cart-elc

Source code for CART-ELC
git clone git://git.laack.co/cart-elc.git
Log | Files | Refs | README | LICENSE

SparseSparseProductWithPruning.h (8704B)


      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2008-2014 Gael Guennebaud <gael.guennebaud@inria.fr>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
     11 #define EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
     12 
     13 namespace Eigen { 
     14 
     15 namespace internal {
     16 
     17 
     18 // perform a pseudo in-place sparse * sparse product assuming all matrices are col major
     19 template<typename Lhs, typename Rhs, typename ResultType>
     20 static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res, const typename ResultType::RealScalar& tolerance)
     21 {
     22   // return sparse_sparse_product_with_pruning_impl2(lhs,rhs,res);
     23 
     24   typedef typename remove_all<Rhs>::type::Scalar RhsScalar;
     25   typedef typename remove_all<ResultType>::type::Scalar ResScalar;
     26   typedef typename remove_all<Lhs>::type::StorageIndex StorageIndex;
     27 
     28   // make sure to call innerSize/outerSize since we fake the storage order.
     29   Index rows = lhs.innerSize();
     30   Index cols = rhs.outerSize();
     31   //Index size = lhs.outerSize();
     32   eigen_assert(lhs.outerSize() == rhs.innerSize());
     33 
     34   // allocate a temporary buffer
     35   AmbiVector<ResScalar,StorageIndex> tempVector(rows);
     36 
     37   // mimics a resizeByInnerOuter:
     38   if(ResultType::IsRowMajor)
     39     res.resize(cols, rows);
     40   else
     41     res.resize(rows, cols);
     42   
     43   evaluator<Lhs> lhsEval(lhs);
     44   evaluator<Rhs> rhsEval(rhs);
     45   
     46   // estimate the number of non zero entries
     47   // given a rhs column containing Y non zeros, we assume that the respective Y columns
     48   // of the lhs differs in average of one non zeros, thus the number of non zeros for
     49   // the product of a rhs column with the lhs is X+Y where X is the average number of non zero
     50   // per column of the lhs.
     51   // Therefore, we have nnz(lhs*rhs) = nnz(lhs) + nnz(rhs)
     52   Index estimated_nnz_prod = lhsEval.nonZerosEstimate() + rhsEval.nonZerosEstimate();
     53 
     54   res.reserve(estimated_nnz_prod);
     55   double ratioColRes = double(estimated_nnz_prod)/(double(lhs.rows())*double(rhs.cols()));
     56   for (Index j=0; j<cols; ++j)
     57   {
     58     // FIXME:
     59     //double ratioColRes = (double(rhs.innerVector(j).nonZeros()) + double(lhs.nonZeros())/double(lhs.cols()))/double(lhs.rows());
     60     // let's do a more accurate determination of the nnz ratio for the current column j of res
     61     tempVector.init(ratioColRes);
     62     tempVector.setZero();
     63     for (typename evaluator<Rhs>::InnerIterator rhsIt(rhsEval, j); rhsIt; ++rhsIt)
     64     {
     65       // FIXME should be written like this: tmp += rhsIt.value() * lhs.col(rhsIt.index())
     66       tempVector.restart();
     67       RhsScalar x = rhsIt.value();
     68       for (typename evaluator<Lhs>::InnerIterator lhsIt(lhsEval, rhsIt.index()); lhsIt; ++lhsIt)
     69       {
     70         tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x;
     71       }
     72     }
     73     res.startVec(j);
     74     for (typename AmbiVector<ResScalar,StorageIndex>::Iterator it(tempVector,tolerance); it; ++it)
     75       res.insertBackByOuterInner(j,it.index()) = it.value();
     76   }
     77   res.finalize();
     78 }
     79 
     80 template<typename Lhs, typename Rhs, typename ResultType,
     81   int LhsStorageOrder = traits<Lhs>::Flags&RowMajorBit,
     82   int RhsStorageOrder = traits<Rhs>::Flags&RowMajorBit,
     83   int ResStorageOrder = traits<ResultType>::Flags&RowMajorBit>
     84 struct sparse_sparse_product_with_pruning_selector;
     85 
     86 template<typename Lhs, typename Rhs, typename ResultType>
     87 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor>
     88 {
     89   typedef typename ResultType::RealScalar RealScalar;
     90 
     91   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
     92   {
     93     typename remove_all<ResultType>::type _res(res.rows(), res.cols());
     94     internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,ResultType>(lhs, rhs, _res, tolerance);
     95     res.swap(_res);
     96   }
     97 };
     98 
     99 template<typename Lhs, typename Rhs, typename ResultType>
    100 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor>
    101 {
    102   typedef typename ResultType::RealScalar RealScalar;
    103   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
    104   {
    105     // we need a col-major matrix to hold the result
    106     typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> SparseTemporaryType;
    107     SparseTemporaryType _res(res.rows(), res.cols());
    108     internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,SparseTemporaryType>(lhs, rhs, _res, tolerance);
    109     res = _res;
    110   }
    111 };
    112 
    113 template<typename Lhs, typename Rhs, typename ResultType>
    114 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor>
    115 {
    116   typedef typename ResultType::RealScalar RealScalar;
    117   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
    118   {
    119     // let's transpose the product to get a column x column product
    120     typename remove_all<ResultType>::type _res(res.rows(), res.cols());
    121     internal::sparse_sparse_product_with_pruning_impl<Rhs,Lhs,ResultType>(rhs, lhs, _res, tolerance);
    122     res.swap(_res);
    123   }
    124 };
    125 
    126 template<typename Lhs, typename Rhs, typename ResultType>
    127 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor>
    128 {
    129   typedef typename ResultType::RealScalar RealScalar;
    130   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
    131   {
    132     typedef SparseMatrix<typename Lhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixLhs;
    133     typedef SparseMatrix<typename Rhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixRhs;
    134     ColMajorMatrixLhs colLhs(lhs);
    135     ColMajorMatrixRhs colRhs(rhs);
    136     internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,ColMajorMatrixRhs,ResultType>(colLhs, colRhs, res, tolerance);
    137 
    138     // let's transpose the product to get a column x column product
    139 //     typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType;
    140 //     SparseTemporaryType _res(res.cols(), res.rows());
    141 //     sparse_sparse_product_with_pruning_impl<Rhs,Lhs,SparseTemporaryType>(rhs, lhs, _res);
    142 //     res = _res.transpose();
    143   }
    144 };
    145 
    146 template<typename Lhs, typename Rhs, typename ResultType>
    147 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,RowMajor>
    148 {
    149   typedef typename ResultType::RealScalar RealScalar;
    150   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
    151   {
    152     typedef SparseMatrix<typename Lhs::Scalar,RowMajor,typename Lhs::StorageIndex> RowMajorMatrixLhs;
    153     RowMajorMatrixLhs rowLhs(lhs);
    154     sparse_sparse_product_with_pruning_selector<RowMajorMatrixLhs,Rhs,ResultType,RowMajor,RowMajor>(rowLhs,rhs,res,tolerance);
    155   }
    156 };
    157 
    158 template<typename Lhs, typename Rhs, typename ResultType>
    159 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,RowMajor>
    160 {
    161   typedef typename ResultType::RealScalar RealScalar;
    162   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
    163   {
    164     typedef SparseMatrix<typename Rhs::Scalar,RowMajor,typename Lhs::StorageIndex> RowMajorMatrixRhs;
    165     RowMajorMatrixRhs rowRhs(rhs);
    166     sparse_sparse_product_with_pruning_selector<Lhs,RowMajorMatrixRhs,ResultType,RowMajor,RowMajor,RowMajor>(lhs,rowRhs,res,tolerance);
    167   }
    168 };
    169 
    170 template<typename Lhs, typename Rhs, typename ResultType>
    171 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,ColMajor>
    172 {
    173   typedef typename ResultType::RealScalar RealScalar;
    174   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
    175   {
    176     typedef SparseMatrix<typename Rhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixRhs;
    177     ColMajorMatrixRhs colRhs(rhs);
    178     internal::sparse_sparse_product_with_pruning_impl<Lhs,ColMajorMatrixRhs,ResultType>(lhs, colRhs, res, tolerance);
    179   }
    180 };
    181 
    182 template<typename Lhs, typename Rhs, typename ResultType>
    183 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,ColMajor>
    184 {
    185   typedef typename ResultType::RealScalar RealScalar;
    186   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
    187   {
    188     typedef SparseMatrix<typename Lhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixLhs;
    189     ColMajorMatrixLhs colLhs(lhs);
    190     internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,Rhs,ResultType>(colLhs, rhs, res, tolerance);
    191   }
    192 };
    193 
    194 } // end namespace internal
    195 
    196 } // end namespace Eigen
    197 
    198 #endif // EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H