matrixfree_cg.cpp (4275B)
1 #include <iostream> 2 #include <Eigen/Core> 3 #include <Eigen/Dense> 4 #include <Eigen/IterativeLinearSolvers> 5 #include <unsupported/Eigen/IterativeSolvers> 6 7 class MatrixReplacement; 8 using Eigen::SparseMatrix; 9 10 namespace Eigen { 11 namespace internal { 12 // MatrixReplacement looks-like a SparseMatrix, so let's inherits its traits: 13 template<> 14 struct traits<MatrixReplacement> : public Eigen::internal::traits<Eigen::SparseMatrix<double> > 15 {}; 16 } 17 } 18 19 // Example of a matrix-free wrapper from a user type to Eigen's compatible type 20 // For the sake of simplicity, this example simply wrap a Eigen::SparseMatrix. 21 class MatrixReplacement : public Eigen::EigenBase<MatrixReplacement> { 22 public: 23 // Required typedefs, constants, and method: 24 typedef double Scalar; 25 typedef double RealScalar; 26 typedef int StorageIndex; 27 enum { 28 ColsAtCompileTime = Eigen::Dynamic, 29 MaxColsAtCompileTime = Eigen::Dynamic, 30 IsRowMajor = false 31 }; 32 33 Index rows() const { return mp_mat->rows(); } 34 Index cols() const { return mp_mat->cols(); } 35 36 template<typename Rhs> 37 Eigen::Product<MatrixReplacement,Rhs,Eigen::AliasFreeProduct> operator*(const Eigen::MatrixBase<Rhs>& x) const { 38 return Eigen::Product<MatrixReplacement,Rhs,Eigen::AliasFreeProduct>(*this, x.derived()); 39 } 40 41 // Custom API: 42 MatrixReplacement() : mp_mat(0) {} 43 44 void attachMyMatrix(const SparseMatrix<double> &mat) { 45 mp_mat = &mat; 46 } 47 const SparseMatrix<double> my_matrix() const { return *mp_mat; } 48 49 private: 50 const SparseMatrix<double> *mp_mat; 51 }; 52 53 54 // Implementation of MatrixReplacement * Eigen::DenseVector though a specialization of internal::generic_product_impl: 55 namespace Eigen { 56 namespace internal { 57 58 template<typename Rhs> 59 struct generic_product_impl<MatrixReplacement, Rhs, SparseShape, DenseShape, GemvProduct> // GEMV stands for matrix-vector 60 : generic_product_impl_base<MatrixReplacement,Rhs,generic_product_impl<MatrixReplacement,Rhs> > 61 { 62 typedef typename Product<MatrixReplacement,Rhs>::Scalar Scalar; 63 64 template<typename Dest> 65 static void scaleAndAddTo(Dest& dst, const MatrixReplacement& lhs, const Rhs& rhs, const Scalar& alpha) 66 { 67 // This method should implement "dst += alpha * lhs * rhs" inplace, 68 // however, for iterative solvers, alpha is always equal to 1, so let's not bother about it. 69 assert(alpha==Scalar(1) && "scaling is not implemented"); 70 EIGEN_ONLY_USED_FOR_DEBUG(alpha); 71 72 // Here we could simply call dst.noalias() += lhs.my_matrix() * rhs, 73 // but let's do something fancier (and less efficient): 74 for(Index i=0; i<lhs.cols(); ++i) 75 dst += rhs(i) * lhs.my_matrix().col(i); 76 } 77 }; 78 79 } 80 } 81 82 int main() 83 { 84 int n = 10; 85 Eigen::SparseMatrix<double> S = Eigen::MatrixXd::Random(n,n).sparseView(0.5,1); 86 S = S.transpose()*S; 87 88 MatrixReplacement A; 89 A.attachMyMatrix(S); 90 91 Eigen::VectorXd b(n), x; 92 b.setRandom(); 93 94 // Solve Ax = b using various iterative solver with matrix-free version: 95 { 96 Eigen::ConjugateGradient<MatrixReplacement, Eigen::Lower|Eigen::Upper, Eigen::IdentityPreconditioner> cg; 97 cg.compute(A); 98 x = cg.solve(b); 99 std::cout << "CG: #iterations: " << cg.iterations() << ", estimated error: " << cg.error() << std::endl; 100 } 101 102 { 103 Eigen::BiCGSTAB<MatrixReplacement, Eigen::IdentityPreconditioner> bicg; 104 bicg.compute(A); 105 x = bicg.solve(b); 106 std::cout << "BiCGSTAB: #iterations: " << bicg.iterations() << ", estimated error: " << bicg.error() << std::endl; 107 } 108 109 { 110 Eigen::GMRES<MatrixReplacement, Eigen::IdentityPreconditioner> gmres; 111 gmres.compute(A); 112 x = gmres.solve(b); 113 std::cout << "GMRES: #iterations: " << gmres.iterations() << ", estimated error: " << gmres.error() << std::endl; 114 } 115 116 { 117 Eigen::DGMRES<MatrixReplacement, Eigen::IdentityPreconditioner> gmres; 118 gmres.compute(A); 119 x = gmres.solve(b); 120 std::cout << "DGMRES: #iterations: " << gmres.iterations() << ", estimated error: " << gmres.error() << std::endl; 121 } 122 123 { 124 Eigen::MINRES<MatrixReplacement, Eigen::Lower|Eigen::Upper, Eigen::IdentityPreconditioner> minres; 125 minres.compute(A); 126 x = minres.solve(b); 127 std::cout << "MINRES: #iterations: " << minres.iterations() << ", estimated error: " << minres.error() << std::endl; 128 } 129 }