SparseSparseProductWithPruning.h (8704B)
1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2008-2014 Gael Guennebaud <gael.guennebaud@inria.fr> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H 11 #define EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H 12 13 namespace Eigen { 14 15 namespace internal { 16 17 18 // perform a pseudo in-place sparse * sparse product assuming all matrices are col major 19 template<typename Lhs, typename Rhs, typename ResultType> 20 static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res, const typename ResultType::RealScalar& tolerance) 21 { 22 // return sparse_sparse_product_with_pruning_impl2(lhs,rhs,res); 23 24 typedef typename remove_all<Rhs>::type::Scalar RhsScalar; 25 typedef typename remove_all<ResultType>::type::Scalar ResScalar; 26 typedef typename remove_all<Lhs>::type::StorageIndex StorageIndex; 27 28 // make sure to call innerSize/outerSize since we fake the storage order. 29 Index rows = lhs.innerSize(); 30 Index cols = rhs.outerSize(); 31 //Index size = lhs.outerSize(); 32 eigen_assert(lhs.outerSize() == rhs.innerSize()); 33 34 // allocate a temporary buffer 35 AmbiVector<ResScalar,StorageIndex> tempVector(rows); 36 37 // mimics a resizeByInnerOuter: 38 if(ResultType::IsRowMajor) 39 res.resize(cols, rows); 40 else 41 res.resize(rows, cols); 42 43 evaluator<Lhs> lhsEval(lhs); 44 evaluator<Rhs> rhsEval(rhs); 45 46 // estimate the number of non zero entries 47 // given a rhs column containing Y non zeros, we assume that the respective Y columns 48 // of the lhs differs in average of one non zeros, thus the number of non zeros for 49 // the product of a rhs column with the lhs is X+Y where X is the average number of non zero 50 // per column of the lhs. 51 // Therefore, we have nnz(lhs*rhs) = nnz(lhs) + nnz(rhs) 52 Index estimated_nnz_prod = lhsEval.nonZerosEstimate() + rhsEval.nonZerosEstimate(); 53 54 res.reserve(estimated_nnz_prod); 55 double ratioColRes = double(estimated_nnz_prod)/(double(lhs.rows())*double(rhs.cols())); 56 for (Index j=0; j<cols; ++j) 57 { 58 // FIXME: 59 //double ratioColRes = (double(rhs.innerVector(j).nonZeros()) + double(lhs.nonZeros())/double(lhs.cols()))/double(lhs.rows()); 60 // let's do a more accurate determination of the nnz ratio for the current column j of res 61 tempVector.init(ratioColRes); 62 tempVector.setZero(); 63 for (typename evaluator<Rhs>::InnerIterator rhsIt(rhsEval, j); rhsIt; ++rhsIt) 64 { 65 // FIXME should be written like this: tmp += rhsIt.value() * lhs.col(rhsIt.index()) 66 tempVector.restart(); 67 RhsScalar x = rhsIt.value(); 68 for (typename evaluator<Lhs>::InnerIterator lhsIt(lhsEval, rhsIt.index()); lhsIt; ++lhsIt) 69 { 70 tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x; 71 } 72 } 73 res.startVec(j); 74 for (typename AmbiVector<ResScalar,StorageIndex>::Iterator it(tempVector,tolerance); it; ++it) 75 res.insertBackByOuterInner(j,it.index()) = it.value(); 76 } 77 res.finalize(); 78 } 79 80 template<typename Lhs, typename Rhs, typename ResultType, 81 int LhsStorageOrder = traits<Lhs>::Flags&RowMajorBit, 82 int RhsStorageOrder = traits<Rhs>::Flags&RowMajorBit, 83 int ResStorageOrder = traits<ResultType>::Flags&RowMajorBit> 84 struct sparse_sparse_product_with_pruning_selector; 85 86 template<typename Lhs, typename Rhs, typename ResultType> 87 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor> 88 { 89 typedef typename ResultType::RealScalar RealScalar; 90 91 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) 92 { 93 typename remove_all<ResultType>::type _res(res.rows(), res.cols()); 94 internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,ResultType>(lhs, rhs, _res, tolerance); 95 res.swap(_res); 96 } 97 }; 98 99 template<typename Lhs, typename Rhs, typename ResultType> 100 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor> 101 { 102 typedef typename ResultType::RealScalar RealScalar; 103 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) 104 { 105 // we need a col-major matrix to hold the result 106 typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> SparseTemporaryType; 107 SparseTemporaryType _res(res.rows(), res.cols()); 108 internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,SparseTemporaryType>(lhs, rhs, _res, tolerance); 109 res = _res; 110 } 111 }; 112 113 template<typename Lhs, typename Rhs, typename ResultType> 114 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor> 115 { 116 typedef typename ResultType::RealScalar RealScalar; 117 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) 118 { 119 // let's transpose the product to get a column x column product 120 typename remove_all<ResultType>::type _res(res.rows(), res.cols()); 121 internal::sparse_sparse_product_with_pruning_impl<Rhs,Lhs,ResultType>(rhs, lhs, _res, tolerance); 122 res.swap(_res); 123 } 124 }; 125 126 template<typename Lhs, typename Rhs, typename ResultType> 127 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor> 128 { 129 typedef typename ResultType::RealScalar RealScalar; 130 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) 131 { 132 typedef SparseMatrix<typename Lhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixLhs; 133 typedef SparseMatrix<typename Rhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixRhs; 134 ColMajorMatrixLhs colLhs(lhs); 135 ColMajorMatrixRhs colRhs(rhs); 136 internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,ColMajorMatrixRhs,ResultType>(colLhs, colRhs, res, tolerance); 137 138 // let's transpose the product to get a column x column product 139 // typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType; 140 // SparseTemporaryType _res(res.cols(), res.rows()); 141 // sparse_sparse_product_with_pruning_impl<Rhs,Lhs,SparseTemporaryType>(rhs, lhs, _res); 142 // res = _res.transpose(); 143 } 144 }; 145 146 template<typename Lhs, typename Rhs, typename ResultType> 147 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,RowMajor> 148 { 149 typedef typename ResultType::RealScalar RealScalar; 150 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) 151 { 152 typedef SparseMatrix<typename Lhs::Scalar,RowMajor,typename Lhs::StorageIndex> RowMajorMatrixLhs; 153 RowMajorMatrixLhs rowLhs(lhs); 154 sparse_sparse_product_with_pruning_selector<RowMajorMatrixLhs,Rhs,ResultType,RowMajor,RowMajor>(rowLhs,rhs,res,tolerance); 155 } 156 }; 157 158 template<typename Lhs, typename Rhs, typename ResultType> 159 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,RowMajor> 160 { 161 typedef typename ResultType::RealScalar RealScalar; 162 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) 163 { 164 typedef SparseMatrix<typename Rhs::Scalar,RowMajor,typename Lhs::StorageIndex> RowMajorMatrixRhs; 165 RowMajorMatrixRhs rowRhs(rhs); 166 sparse_sparse_product_with_pruning_selector<Lhs,RowMajorMatrixRhs,ResultType,RowMajor,RowMajor,RowMajor>(lhs,rowRhs,res,tolerance); 167 } 168 }; 169 170 template<typename Lhs, typename Rhs, typename ResultType> 171 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,ColMajor> 172 { 173 typedef typename ResultType::RealScalar RealScalar; 174 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) 175 { 176 typedef SparseMatrix<typename Rhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixRhs; 177 ColMajorMatrixRhs colRhs(rhs); 178 internal::sparse_sparse_product_with_pruning_impl<Lhs,ColMajorMatrixRhs,ResultType>(lhs, colRhs, res, tolerance); 179 } 180 }; 181 182 template<typename Lhs, typename Rhs, typename ResultType> 183 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,ColMajor> 184 { 185 typedef typename ResultType::RealScalar RealScalar; 186 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) 187 { 188 typedef SparseMatrix<typename Lhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixLhs; 189 ColMajorMatrixLhs colLhs(lhs); 190 internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,Rhs,ResultType>(colLhs, rhs, res, tolerance); 191 } 192 }; 193 194 } // end namespace internal 195 196 } // end namespace Eigen 197 198 #endif // EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H