cart-elc

Source code for CART-ELC
git clone git://git.laack.co/cart-elc.git
Log | Files | Refs | README | LICENSE

bench_gemm.cpp (11435B)


      1 
      2 // g++-4.4 bench_gemm.cpp -I .. -O2 -DNDEBUG -lrt -fopenmp && OMP_NUM_THREADS=2  ./a.out
      3 // icpc bench_gemm.cpp -I .. -O3 -DNDEBUG -lrt -openmp  && OMP_NUM_THREADS=2  ./a.out
      4 
      5 // Compilation options:
      6 // 
      7 // -DSCALAR=std::complex<double>
      8 // -DSCALARA=double or -DSCALARB=double
      9 // -DHAVE_BLAS
     10 // -DDECOUPLED
     11 //
     12 
     13 #include <iostream>
     14 #include <bench/BenchTimer.h>
     15 #include <Eigen/Core>
     16 
     17 
     18 using namespace std;
     19 using namespace Eigen;
     20 
     21 #ifndef SCALAR
     22 // #define SCALAR std::complex<float>
     23 #define SCALAR float
     24 #endif
     25 
     26 #ifndef SCALARA
     27 #define SCALARA SCALAR
     28 #endif
     29 
     30 #ifndef SCALARB
     31 #define SCALARB SCALAR
     32 #endif
     33 
     34 #ifdef ROWMAJ_A
     35 const int opt_A = RowMajor;
     36 #else
     37 const int opt_A = ColMajor;
     38 #endif
     39 
     40 #ifdef ROWMAJ_B
     41 const int opt_B = RowMajor;
     42 #else
     43 const int opt_B = ColMajor;
     44 #endif
     45 
     46 typedef SCALAR Scalar;
     47 typedef NumTraits<Scalar>::Real RealScalar;
     48 typedef Matrix<SCALARA,Dynamic,Dynamic,opt_A> A;
     49 typedef Matrix<SCALARB,Dynamic,Dynamic,opt_B> B;
     50 typedef Matrix<Scalar,Dynamic,Dynamic> C;
     51 typedef Matrix<RealScalar,Dynamic,Dynamic> M;
     52 
     53 #ifdef HAVE_BLAS
     54 
     55 extern "C" {
     56   #include <Eigen/src/misc/blas.h>
     57 }
     58 
     59 static float fone = 1;
     60 static float fzero = 0;
     61 static double done = 1;
     62 static double szero = 0;
     63 static std::complex<float> cfone = 1;
     64 static std::complex<float> cfzero = 0;
     65 static std::complex<double> cdone = 1;
     66 static std::complex<double> cdzero = 0;
     67 static char notrans = 'N';
     68 static char trans = 'T';  
     69 static char nonunit = 'N';
     70 static char lower = 'L';
     71 static char right = 'R';
     72 static int intone = 1;
     73 
     74 #ifdef ROWMAJ_A
     75 const char transA = trans;
     76 #else
     77 const char transA = notrans;
     78 #endif
     79 
     80 #ifdef ROWMAJ_B
     81 const char transB = trans;
     82 #else
     83 const char transB = notrans;
     84 #endif
     85 
     86 template<typename A,typename B>
     87 void blas_gemm(const A& a, const B& b, MatrixXf& c)
     88 {
     89   int M = c.rows(); int N = c.cols(); int K = a.cols();
     90   int lda = a.outerStride(); int ldb = b.outerStride(); int ldc = c.rows();
     91 
     92   sgemm_(&transA,&transB,&M,&N,&K,&fone,
     93          const_cast<float*>(a.data()),&lda,
     94          const_cast<float*>(b.data()),&ldb,&fone,
     95          c.data(),&ldc);
     96 }
     97 
     98 template<typename A,typename B>
     99 void blas_gemm(const A& a, const B& b, MatrixXd& c)
    100 {
    101   int M = c.rows(); int N = c.cols(); int K = a.cols();
    102   int lda = a.outerStride(); int ldb = b.outerStride(); int ldc = c.rows();
    103 
    104   dgemm_(&transA,&transB,&M,&N,&K,&done,
    105          const_cast<double*>(a.data()),&lda,
    106          const_cast<double*>(b.data()),&ldb,&done,
    107          c.data(),&ldc);
    108 }
    109 
    110 template<typename A,typename B>
    111 void blas_gemm(const A& a, const B& b, MatrixXcf& c)
    112 {
    113   int M = c.rows(); int N = c.cols(); int K = a.cols();
    114   int lda = a.outerStride(); int ldb = b.outerStride(); int ldc = c.rows();
    115 
    116   cgemm_(&transA,&transB,&M,&N,&K,(float*)&cfone,
    117          const_cast<float*>((const float*)a.data()),&lda,
    118          const_cast<float*>((const float*)b.data()),&ldb,(float*)&cfone,
    119          (float*)c.data(),&ldc);
    120 }
    121 
    122 template<typename A,typename B>
    123 void blas_gemm(const A& a, const B& b, MatrixXcd& c)
    124 {
    125   int M = c.rows(); int N = c.cols(); int K = a.cols();
    126   int lda = a.outerStride(); int ldb = b.outerStride(); int ldc = c.rows();
    127 
    128   zgemm_(&transA,&transB,&M,&N,&K,(double*)&cdone,
    129          const_cast<double*>((const double*)a.data()),&lda,
    130          const_cast<double*>((const double*)b.data()),&ldb,(double*)&cdone,
    131          (double*)c.data(),&ldc);
    132 }
    133 
    134 
    135 
    136 #endif
    137 
// Complex * complex product with real and imaginary parts held in separate
// real matrices (the "decoupled" layout; presumably what the "matlab" name
// refers to).
//
// Accumulates into the outputs:
//   cr += ar*br - ai*bi
//   ci += ar*bi + ai*br
//
// ar, ai : real and imaginary parts of the left operand
// br, bi : real and imaginary parts of the right operand
// cr, ci : real and imaginary parts of the result, updated in place
void matlab_cplx_cplx(const M& ar, const M& ai, const M& br, const M& bi, M& cr, M& ci)
{
  // Four real products, each accumulated without an aliasing temporary.
  cr.noalias() += ar * br;
  cr.noalias() -= ai * bi;
  ci.noalias() += ar * bi;
  ci.noalias() += ai * br;
  // [cr ci] += [ar ai] * br + [-ai ar] * bi
}
    146 
// Real * complex product with the complex operand and the result split into
// separate real and imaginary matrices.
//
// Accumulates into the outputs:
//   cr += a*br
//   ci += a*bi
//
// a      : real left operand
// br, bi : real and imaginary parts of the right operand
// cr, ci : real and imaginary parts of the result, updated in place
void matlab_real_cplx(const M& a, const M& br, const M& bi, M& cr, M& ci)
{
  cr.noalias() += a * br;
  ci.noalias() += a * bi;
}
    152 
// Complex * real product with the complex operand and the result split into
// separate real and imaginary matrices.
//
// Accumulates into the outputs:
//   cr += ar*b
//   ci += ai*b
//
// ar, ai : real and imaginary parts of the left operand
// b      : real right operand
// cr, ci : real and imaginary parts of the result, updated in place
void matlab_cplx_real(const M& ar, const M& ai, const M& b, M& cr, M& ci)
{
  cr.noalias() += ar * b;
  ci.noalias() += ai * b;
}
    158 
    159 
    160 
// The kernel timed as "eigen" in main(): accumulates the product a*b into c
// through Eigen's noalias() product assignment.
// EIGEN_DONT_INLINE keeps this a real function call; presumably so the timed
// product is not merged into the surrounding benchmark loop.
template<typename A, typename B, typename C>
EIGEN_DONT_INLINE void gemm(const A& a, const B& b, C& c)
{
  c.noalias() += a * b;
}
    166 
    167 int main(int argc, char ** argv)
    168 {
    169   std::ptrdiff_t l1 = internal::queryL1CacheSize();
    170   std::ptrdiff_t l2 = internal::queryTopLevelCacheSize();
    171   std::cout << "L1 cache size     = " << (l1>0 ? l1/1024 : -1) << " KB\n";
    172   std::cout << "L2/L3 cache size  = " << (l2>0 ? l2/1024 : -1) << " KB\n";
    173   typedef internal::gebp_traits<Scalar,Scalar> Traits;
    174   std::cout << "Register blocking = " << Traits::mr << " x " << Traits::nr << "\n";
    175 
    176   int rep = 1;    // number of repetitions per try
    177   int tries = 2;  // number of tries, we keep the best
    178 
    179   int s = 2048;
    180   int m = s;
    181   int n = s;
    182   int p = s;
    183   int cache_size1=-1, cache_size2=l2, cache_size3 = 0;
    184 
    185   bool need_help = false;
    186   for (int i=1; i<argc;)
    187   {
    188     if(argv[i][0]=='-')
    189     {
    190       if(argv[i][1]=='s')
    191       {
    192         ++i;
    193         s = atoi(argv[i++]);
    194         m = n = p = s;
    195         if(argv[i][0]!='-')
    196         {
    197           n = atoi(argv[i++]);
    198           p = atoi(argv[i++]);
    199         }
    200       }
    201       else if(argv[i][1]=='c')
    202       {
    203         ++i;
    204         cache_size1 = atoi(argv[i++]);
    205         if(argv[i][0]!='-')
    206         {
    207           cache_size2 = atoi(argv[i++]);
    208           if(argv[i][0]!='-')
    209             cache_size3 = atoi(argv[i++]);
    210         }
    211       }
    212       else if(argv[i][1]=='t')
    213       {
    214         tries = atoi(argv[++i]);
    215         ++i;
    216       }
    217       else if(argv[i][1]=='p')
    218       {
    219         ++i;
    220         rep = atoi(argv[i++]);
    221       }
    222     }
    223     else
    224     {
    225       need_help = true;
    226       break;
    227     }
    228   }
    229 
    230   if(need_help)
    231   {
    232     std::cout << argv[0] << " -s <matrix sizes> -c <cache sizes> -t <nb tries> -p <nb repeats>\n";
    233     std::cout << "   <matrix sizes> : size\n";
    234     std::cout << "   <matrix sizes> : rows columns depth\n";
    235     return 1;
    236   }
    237 
    238 #if EIGEN_VERSION_AT_LEAST(3,2,90)
    239   if(cache_size1>0)
    240     setCpuCacheSizes(cache_size1,cache_size2,cache_size3);
    241 #endif
    242   
    243   A a(m,p); a.setRandom();
    244   B b(p,n); b.setRandom();
    245   C c(m,n); c.setOnes();
    246   C rc = c;
    247 
    248   std::cout << "Matrix sizes = " << m << "x" << p << " * " << p << "x" << n << "\n";
    249   std::ptrdiff_t mc(m), nc(n), kc(p);
    250   internal::computeProductBlockingSizes<Scalar,Scalar>(kc, mc, nc);
    251   std::cout << "blocking size (mc x kc) = " << mc << " x " << kc << " x " << nc << "\n";
    252 
    253   C r = c;
    254 
    255   // check the parallel product is correct
    256   #if defined EIGEN_HAS_OPENMP
    257   Eigen::initParallel();
    258   int procs = omp_get_max_threads();
    259   if(procs>1)
    260   {
    261     #ifdef HAVE_BLAS
    262     blas_gemm(a,b,r);
    263     #else
    264     omp_set_num_threads(1);
    265     r.noalias() += a * b;
    266     omp_set_num_threads(procs);
    267     #endif
    268     c.noalias() += a * b;
    269     if(!r.isApprox(c)) std::cerr << "Warning, your parallel product is crap!\n\n";
    270   }
    271   #elif defined HAVE_BLAS
    272     blas_gemm(a,b,r);
    273     c.noalias() += a * b;
    274     if(!r.isApprox(c)) {
    275       std::cout << (r  - c).norm()/r.norm() << "\n";
    276       std::cerr << "Warning, your product is crap!\n\n";
    277     }
    278   #else
    279     if(1.*m*n*p<2000.*2000*2000)
    280     {
    281       gemm(a,b,c);
    282       r.noalias() += a.cast<Scalar>() .lazyProduct( b.cast<Scalar>() );
    283       if(!r.isApprox(c)) {
    284         std::cout << (r  - c).norm()/r.norm() << "\n";
    285         std::cerr << "Warning, your product is crap!\n\n";
    286       }
    287     }
    288   #endif
    289 
    290   #ifdef HAVE_BLAS
    291   BenchTimer tblas;
    292   c = rc;
    293   BENCH(tblas, tries, rep, blas_gemm(a,b,c));
    294   std::cout << "blas  cpu         " << tblas.best(CPU_TIMER)/rep  << "s  \t" << (double(m)*n*p*rep*2/tblas.best(CPU_TIMER))*1e-9  <<  " GFLOPS \t(" << tblas.total(CPU_TIMER)  << "s)\n";
    295   std::cout << "blas  real        " << tblas.best(REAL_TIMER)/rep << "s  \t" << (double(m)*n*p*rep*2/tblas.best(REAL_TIMER))*1e-9 <<  " GFLOPS \t(" << tblas.total(REAL_TIMER) << "s)\n";
    296   #endif
    297 
    298   // warm start
    299   if(b.norm()+a.norm()==123.554) std::cout << "\n";
    300 
    301   BenchTimer tmt;
    302   c = rc;
    303   BENCH(tmt, tries, rep, gemm(a,b,c));
    304   std::cout << "eigen cpu         " << tmt.best(CPU_TIMER)/rep  << "s  \t" << (double(m)*n*p*rep*2/tmt.best(CPU_TIMER))*1e-9  <<  " GFLOPS \t(" << tmt.total(CPU_TIMER)  << "s)\n";
    305   std::cout << "eigen real        " << tmt.best(REAL_TIMER)/rep << "s  \t" << (double(m)*n*p*rep*2/tmt.best(REAL_TIMER))*1e-9 <<  " GFLOPS \t(" << tmt.total(REAL_TIMER) << "s)\n";
    306 
    307   #ifdef EIGEN_HAS_OPENMP
    308   if(procs>1)
    309   {
    310     BenchTimer tmono;
    311     omp_set_num_threads(1);
    312     Eigen::setNbThreads(1);
    313     c = rc;
    314     BENCH(tmono, tries, rep, gemm(a,b,c));
    315     std::cout << "eigen mono cpu    " << tmono.best(CPU_TIMER)/rep  << "s  \t" << (double(m)*n*p*rep*2/tmono.best(CPU_TIMER))*1e-9  <<  " GFLOPS \t(" << tmono.total(CPU_TIMER)  << "s)\n";
    316     std::cout << "eigen mono real   " << tmono.best(REAL_TIMER)/rep << "s  \t" << (double(m)*n*p*rep*2/tmono.best(REAL_TIMER))*1e-9 <<  " GFLOPS \t(" << tmono.total(REAL_TIMER) << "s)\n";
    317     std::cout << "mt speed up x" << tmono.best(CPU_TIMER) / tmt.best(REAL_TIMER)  << " => " << (100.0*tmono.best(CPU_TIMER) / tmt.best(REAL_TIMER))/procs << "%\n";
    318   }
    319   #endif
    320   
    321   if(1.*m*n*p<30*30*30)
    322   {
    323     BenchTimer tmt;
    324     c = rc;
    325     BENCH(tmt, tries, rep, c.noalias()+=a.lazyProduct(b));
    326     std::cout << "lazy cpu         " << tmt.best(CPU_TIMER)/rep  << "s  \t" << (double(m)*n*p*rep*2/tmt.best(CPU_TIMER))*1e-9  <<  " GFLOPS \t(" << tmt.total(CPU_TIMER)  << "s)\n";
    327     std::cout << "lazy real        " << tmt.best(REAL_TIMER)/rep << "s  \t" << (double(m)*n*p*rep*2/tmt.best(REAL_TIMER))*1e-9 <<  " GFLOPS \t(" << tmt.total(REAL_TIMER) << "s)\n";
    328   }
    329   
    330   #ifdef DECOUPLED
    331   if((NumTraits<A::Scalar>::IsComplex) && (NumTraits<B::Scalar>::IsComplex))
    332   {
    333     M ar(m,p); ar.setRandom();
    334     M ai(m,p); ai.setRandom();
    335     M br(p,n); br.setRandom();
    336     M bi(p,n); bi.setRandom();
    337     M cr(m,n); cr.setRandom();
    338     M ci(m,n); ci.setRandom();
    339     
    340     BenchTimer t;
    341     BENCH(t, tries, rep, matlab_cplx_cplx(ar,ai,br,bi,cr,ci));
    342     std::cout << "\"matlab\" cpu    " << t.best(CPU_TIMER)/rep  << "s  \t" << (double(m)*n*p*rep*2/t.best(CPU_TIMER))*1e-9  <<  " GFLOPS \t(" << t.total(CPU_TIMER)  << "s)\n";
    343     std::cout << "\"matlab\" real   " << t.best(REAL_TIMER)/rep << "s  \t" << (double(m)*n*p*rep*2/t.best(REAL_TIMER))*1e-9 <<  " GFLOPS \t(" << t.total(REAL_TIMER) << "s)\n";
    344   }
    345   if((!NumTraits<A::Scalar>::IsComplex) && (NumTraits<B::Scalar>::IsComplex))
    346   {
    347     M a(m,p);  a.setRandom();
    348     M br(p,n); br.setRandom();
    349     M bi(p,n); bi.setRandom();
    350     M cr(m,n); cr.setRandom();
    351     M ci(m,n); ci.setRandom();
    352     
    353     BenchTimer t;
    354     BENCH(t, tries, rep, matlab_real_cplx(a,br,bi,cr,ci));
    355     std::cout << "\"matlab\" cpu    " << t.best(CPU_TIMER)/rep  << "s  \t" << (double(m)*n*p*rep*2/t.best(CPU_TIMER))*1e-9  <<  " GFLOPS \t(" << t.total(CPU_TIMER)  << "s)\n";
    356     std::cout << "\"matlab\" real   " << t.best(REAL_TIMER)/rep << "s  \t" << (double(m)*n*p*rep*2/t.best(REAL_TIMER))*1e-9 <<  " GFLOPS \t(" << t.total(REAL_TIMER) << "s)\n";
    357   }
    358   if((NumTraits<A::Scalar>::IsComplex) && (!NumTraits<B::Scalar>::IsComplex))
    359   {
    360     M ar(m,p); ar.setRandom();
    361     M ai(m,p); ai.setRandom();
    362     M b(p,n);  b.setRandom();
    363     M cr(m,n); cr.setRandom();
    364     M ci(m,n); ci.setRandom();
    365     
    366     BenchTimer t;
    367     BENCH(t, tries, rep, matlab_cplx_real(ar,ai,b,cr,ci));
    368     std::cout << "\"matlab\" cpu    " << t.best(CPU_TIMER)/rep  << "s  \t" << (double(m)*n*p*rep*2/t.best(CPU_TIMER))*1e-9  <<  " GFLOPS \t(" << t.total(CPU_TIMER)  << "s)\n";
    369     std::cout << "\"matlab\" real   " << t.best(REAL_TIMER)/rep << "s  \t" << (double(m)*n*p*rep*2/t.best(REAL_TIMER))*1e-9 <<  " GFLOPS \t(" << t.total(REAL_TIMER) << "s)\n";
    370   }
    371   #endif
    372 
    373   return 0;
    374 }
    375