cart-elc

Source code for CART-ELC
git clone git://git.laack.co/cart-elc.git
Log | Files | Refs | README | LICENSE

TriangularSolver.h (9657B)


      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2008 Gael Guennebaud <gael.guennebaud@inria.fr>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_SPARSETRIANGULARSOLVER_H
     11 #define EIGEN_SPARSETRIANGULARSOLVER_H
     12 
     13 namespace Eigen { 
     14 
     15 namespace internal {
     16 
     17 template<typename Lhs, typename Rhs, int Mode,
     18   int UpLo = (Mode & Lower)
     19            ? Lower
     20            : (Mode & Upper)
     21            ? Upper
     22            : -1,
     23   int StorageOrder = int(traits<Lhs>::Flags) & RowMajorBit>
     24 struct sparse_solve_triangular_selector;
     25 
     26 // forward substitution, row-major
     27 template<typename Lhs, typename Rhs, int Mode>
     28 struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Lower,RowMajor>
     29 {
     30   typedef typename Rhs::Scalar Scalar;
     31   typedef evaluator<Lhs> LhsEval;
     32   typedef typename evaluator<Lhs>::InnerIterator LhsIterator;
     33   static void run(const Lhs& lhs, Rhs& other)
     34   {
     35     LhsEval lhsEval(lhs);
     36     for(Index col=0 ; col<other.cols() ; ++col)
     37     {
     38       for(Index i=0; i<lhs.rows(); ++i)
     39       {
     40         Scalar tmp = other.coeff(i,col);
     41         Scalar lastVal(0);
     42         Index lastIndex = 0;
     43         for(LhsIterator it(lhsEval, i); it; ++it)
     44         {
     45           lastVal = it.value();
     46           lastIndex = it.index();
     47           if(lastIndex==i)
     48             break;
     49           tmp -= lastVal * other.coeff(lastIndex,col);
     50         }
     51         if (Mode & UnitDiag)
     52           other.coeffRef(i,col) = tmp;
     53         else
     54         {
     55           eigen_assert(lastIndex==i);
     56           other.coeffRef(i,col) = tmp/lastVal;
     57         }
     58       }
     59     }
     60   }
     61 };
     62 
     63 // backward substitution, row-major
     64 template<typename Lhs, typename Rhs, int Mode>
     65 struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Upper,RowMajor>
     66 {
     67   typedef typename Rhs::Scalar Scalar;
     68   typedef evaluator<Lhs> LhsEval;
     69   typedef typename evaluator<Lhs>::InnerIterator LhsIterator;
     70   static void run(const Lhs& lhs, Rhs& other)
     71   {
     72     LhsEval lhsEval(lhs);
     73     for(Index col=0 ; col<other.cols() ; ++col)
     74     {
     75       for(Index i=lhs.rows()-1 ; i>=0 ; --i)
     76       {
     77         Scalar tmp = other.coeff(i,col);
     78         Scalar l_ii(0);
     79         LhsIterator it(lhsEval, i);
     80         while(it && it.index()<i)
     81           ++it;
     82         if(!(Mode & UnitDiag))
     83         {
     84           eigen_assert(it && it.index()==i);
     85           l_ii = it.value();
     86           ++it;
     87         }
     88         else if (it && it.index() == i)
     89           ++it;
     90         for(; it; ++it)
     91         {
     92           tmp -= it.value() * other.coeff(it.index(),col);
     93         }
     94 
     95         if (Mode & UnitDiag)  other.coeffRef(i,col) = tmp;
     96         else                  other.coeffRef(i,col) = tmp/l_ii;
     97       }
     98     }
     99   }
    100 };
    101 
    102 // forward substitution, col-major
    103 template<typename Lhs, typename Rhs, int Mode>
    104 struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Lower,ColMajor>
    105 {
    106   typedef typename Rhs::Scalar Scalar;
    107   typedef evaluator<Lhs> LhsEval;
    108   typedef typename evaluator<Lhs>::InnerIterator LhsIterator;
    109   static void run(const Lhs& lhs, Rhs& other)
    110   {
    111     LhsEval lhsEval(lhs);
    112     for(Index col=0 ; col<other.cols() ; ++col)
    113     {
    114       for(Index i=0; i<lhs.cols(); ++i)
    115       {
    116         Scalar& tmp = other.coeffRef(i,col);
    117         if (tmp!=Scalar(0)) // optimization when other is actually sparse
    118         {
    119           LhsIterator it(lhsEval, i);
    120           while(it && it.index()<i)
    121             ++it;
    122           if(!(Mode & UnitDiag))
    123           {
    124             eigen_assert(it && it.index()==i);
    125             tmp /= it.value();
    126           }
    127           if (it && it.index()==i)
    128             ++it;
    129           for(; it; ++it)
    130             other.coeffRef(it.index(), col) -= tmp * it.value();
    131         }
    132       }
    133     }
    134   }
    135 };
    136 
    137 // backward substitution, col-major
    138 template<typename Lhs, typename Rhs, int Mode>
    139 struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Upper,ColMajor>
    140 {
    141   typedef typename Rhs::Scalar Scalar;
    142   typedef evaluator<Lhs> LhsEval;
    143   typedef typename evaluator<Lhs>::InnerIterator LhsIterator;
    144   static void run(const Lhs& lhs, Rhs& other)
    145   {
    146     LhsEval lhsEval(lhs);
    147     for(Index col=0 ; col<other.cols() ; ++col)
    148     {
    149       for(Index i=lhs.cols()-1; i>=0; --i)
    150       {
    151         Scalar& tmp = other.coeffRef(i,col);
    152         if (tmp!=Scalar(0)) // optimization when other is actually sparse
    153         {
    154           if(!(Mode & UnitDiag))
    155           {
    156             // TODO replace this by a binary search. make sure the binary search is safe for partially sorted elements
    157             LhsIterator it(lhsEval, i);
    158             while(it && it.index()!=i)
    159               ++it;
    160             eigen_assert(it && it.index()==i);
    161             other.coeffRef(i,col) /= it.value();
    162           }
    163           LhsIterator it(lhsEval, i);
    164           for(; it && it.index()<i; ++it)
    165             other.coeffRef(it.index(), col) -= tmp * it.value();
    166         }
    167       }
    168     }
    169   }
    170 };
    171 
    172 } // end namespace internal
    173 
    174 #ifndef EIGEN_PARSED_BY_DOXYGEN
    175 
    176 template<typename ExpressionType,unsigned int Mode>
    177 template<typename OtherDerived>
    178 void TriangularViewImpl<ExpressionType,Mode,Sparse>::solveInPlace(MatrixBase<OtherDerived>& other) const
    179 {
    180   eigen_assert(derived().cols() == derived().rows() && derived().cols() == other.rows());
    181   eigen_assert((!(Mode & ZeroDiag)) && bool(Mode & (Upper|Lower)));
    182 
    183   enum { copy = internal::traits<OtherDerived>::Flags & RowMajorBit };
    184 
    185   typedef typename internal::conditional<copy,
    186     typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&>::type OtherCopy;
    187   OtherCopy otherCopy(other.derived());
    188 
    189   internal::sparse_solve_triangular_selector<ExpressionType, typename internal::remove_reference<OtherCopy>::type, Mode>::run(derived().nestedExpression(), otherCopy);
    190 
    191   if (copy)
    192     other = otherCopy;
    193 }
    194 #endif
    195 
    196 // pure sparse path
    197 
    198 namespace internal {
    199 
    200 template<typename Lhs, typename Rhs, int Mode,
    201   int UpLo = (Mode & Lower)
    202            ? Lower
    203            : (Mode & Upper)
    204            ? Upper
    205            : -1,
    206   int StorageOrder = int(Lhs::Flags) & (RowMajorBit)>
    207 struct sparse_solve_triangular_sparse_selector;
    208 
    209 // forward substitution, col-major
    210 template<typename Lhs, typename Rhs, int Mode, int UpLo>
    211 struct sparse_solve_triangular_sparse_selector<Lhs,Rhs,Mode,UpLo,ColMajor>
    212 {
    213   typedef typename Rhs::Scalar Scalar;
    214   typedef typename promote_index_type<typename traits<Lhs>::StorageIndex,
    215                                       typename traits<Rhs>::StorageIndex>::type StorageIndex;
    216   static void run(const Lhs& lhs, Rhs& other)
    217   {
    218     const bool IsLower = (UpLo==Lower);
    219     AmbiVector<Scalar,StorageIndex> tempVector(other.rows()*2);
    220     tempVector.setBounds(0,other.rows());
    221 
    222     Rhs res(other.rows(), other.cols());
    223     res.reserve(other.nonZeros());
    224 
    225     for(Index col=0 ; col<other.cols() ; ++col)
    226     {
    227       // FIXME estimate number of non zeros
    228       tempVector.init(.99/*float(other.col(col).nonZeros())/float(other.rows())*/);
    229       tempVector.setZero();
    230       tempVector.restart();
    231       for (typename Rhs::InnerIterator rhsIt(other, col); rhsIt; ++rhsIt)
    232       {
    233         tempVector.coeffRef(rhsIt.index()) = rhsIt.value();
    234       }
    235 
    236       for(Index i=IsLower?0:lhs.cols()-1;
    237           IsLower?i<lhs.cols():i>=0;
    238           i+=IsLower?1:-1)
    239       {
    240         tempVector.restart();
    241         Scalar& ci = tempVector.coeffRef(i);
    242         if (ci!=Scalar(0))
    243         {
    244           // find
    245           typename Lhs::InnerIterator it(lhs, i);
    246           if(!(Mode & UnitDiag))
    247           {
    248             if (IsLower)
    249             {
    250               eigen_assert(it.index()==i);
    251               ci /= it.value();
    252             }
    253             else
    254               ci /= lhs.coeff(i,i);
    255           }
    256           tempVector.restart();
    257           if (IsLower)
    258           {
    259             if (it.index()==i)
    260               ++it;
    261             for(; it; ++it)
    262               tempVector.coeffRef(it.index()) -= ci * it.value();
    263           }
    264           else
    265           {
    266             for(; it && it.index()<i; ++it)
    267               tempVector.coeffRef(it.index()) -= ci * it.value();
    268           }
    269         }
    270       }
    271 
    272 
    273       Index count = 0;
    274       // FIXME compute a reference value to filter zeros
    275       for (typename AmbiVector<Scalar,StorageIndex>::Iterator it(tempVector/*,1e-12*/); it; ++it)
    276       {
    277         ++ count;
    278 //         std::cerr << "fill " << it.index() << ", " << col << "\n";
    279 //         std::cout << it.value() << "  ";
    280         // FIXME use insertBack
    281         res.insert(it.index(), col) = it.value();
    282       }
    283 //       std::cout << "tempVector.nonZeros() == " << int(count) << " / " << (other.rows()) << "\n";
    284     }
    285     res.finalize();
    286     other = res.markAsRValue();
    287   }
    288 };
    289 
    290 } // end namespace internal
    291 
    292 #ifndef EIGEN_PARSED_BY_DOXYGEN
    293 template<typename ExpressionType,unsigned int Mode>
    294 template<typename OtherDerived>
    295 void TriangularViewImpl<ExpressionType,Mode,Sparse>::solveInPlace(SparseMatrixBase<OtherDerived>& other) const
    296 {
    297   eigen_assert(derived().cols() == derived().rows() && derived().cols() == other.rows());
    298   eigen_assert( (!(Mode & ZeroDiag)) && bool(Mode & (Upper|Lower)));
    299 
    300 //   enum { copy = internal::traits<OtherDerived>::Flags & RowMajorBit };
    301 
    302 //   typedef typename internal::conditional<copy,
    303 //     typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&>::type OtherCopy;
    304 //   OtherCopy otherCopy(other.derived());
    305 
    306   internal::sparse_solve_triangular_sparse_selector<ExpressionType, OtherDerived, Mode>::run(derived().nestedExpression(), other.derived());
    307 
    308 //   if (copy)
    309 //     other = otherCopy;
    310 }
    311 #endif
    312 
    313 } // end namespace Eigen
    314 
    315 #endif // EIGEN_SPARSETRIANGULARSOLVER_H