cart-elc

Source code for CART-ELC
git clone git://git.laack.co/cart-elc.git
Log | Files | Refs | README | LICENSE

TensorContractionGpu.h (63402B)


      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2014-2015 Benoit Steiner <benoit.steiner.goog@gmail.com>
      5 // Copyright (C) 2015 Navdeep Jaitly <ndjaitly@google.com>
      6 // Copyright (C) 2014 Eric Martin <eric@ericmart.in>
      7 //
      8 // This Source Code Form is subject to the terms of the Mozilla
      9 // Public License v. 2.0. If a copy of the MPL was not distributed
     10 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
     11 
     12 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H
     13 #define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H
     14 
     15 #if defined(EIGEN_USE_GPU) && defined(EIGEN_GPUCC)
     16 
     17 namespace Eigen {
     18 
     19 template<typename Scalar, typename Index, typename LhsMapper,
     20          typename RhsMapper, typename OutputMapper, bool needs_edge_check>
     21 __device__ EIGEN_STRONG_INLINE void
     22 EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
     23                                const OutputMapper output, Scalar* lhs_shmem, Scalar* rhs_shmem,
     24                        const Index m_size, const Index n_size, const Index k_size) {
     25 
     26   const Index m_block_idx = blockIdx.x;
     27   const Index n_block_idx = blockIdx.y;
     28 
     29   const Index base_m = 64 * m_block_idx;
     30   const Index base_n = 64 * n_block_idx;
     31 
     32   // declare and initialize 64 registers for output 8x8 block
     33 
     34   // prefetch registers
     35   Scalar lhs_pf0;
     36   Scalar lhs_pf1;
     37   Scalar lhs_pf2;
     38   Scalar lhs_pf3;
     39   Scalar lhs_pf4;
     40   Scalar lhs_pf5;
     41   Scalar lhs_pf6;
     42   Scalar lhs_pf7;
     43 
     44   Scalar rhs_pf0;
     45   Scalar rhs_pf1;
     46   Scalar rhs_pf2;
     47   Scalar rhs_pf3;
     48   Scalar rhs_pf4;
     49   Scalar rhs_pf5;
     50   Scalar rhs_pf6;
     51   Scalar rhs_pf7;
     52 
     53   // shared memory is formatted
     54   // (contract idx in block, nocontract idx in block, block idx)
     55   // where block idx is column major. This transposition limits the number of
     56   // bank conflicts when reading the LHS. The core idea is that since the contracting
     57   // index is shared by both sides, then the contracting index should be in threadIdx.x.
     58 
     59   // On the LHS, we pad each row inside of each block with an extra element. This makes
     60   // each block 8 rows of 9 elements, which is 72 elements. This gives no bank conflicts
     61   // on writes and very few 2-way conflicts on reads. There is an 8x8 grid of these blocks.
     62 
     63   // On the RHS we just add 8 padding elements to the end of each block. This gives no bank
     64   // conflicts on writes and also none on reads.
     65 
     66   // storage indices
     67   const Index lhs_store_idx_base = threadIdx.y * 72 + threadIdx.x * 9 + threadIdx.z;
     68   const Index rhs_store_idx_base = threadIdx.y * 72 + threadIdx.z * 8 + threadIdx.x;
     69 
     70   const Index lhs_store_idx_0 = lhs_store_idx_base + 576 * 0;
     71   const Index lhs_store_idx_1 = lhs_store_idx_base + 576 * 1;
     72   const Index lhs_store_idx_2 = lhs_store_idx_base + 576 * 2;
     73   const Index lhs_store_idx_3 = lhs_store_idx_base + 576 * 3;
     74   const Index lhs_store_idx_4 = lhs_store_idx_base + 576 * 4;
     75   const Index lhs_store_idx_5 = lhs_store_idx_base + 576 * 5;
     76   const Index lhs_store_idx_6 = lhs_store_idx_base + 576 * 6;
     77   const Index lhs_store_idx_7 = lhs_store_idx_base + 576 * 7;
     78 
     79   const Index rhs_store_idx_0 = rhs_store_idx_base + 576 * 0;
     80   const Index rhs_store_idx_1 = rhs_store_idx_base + 576 * 1;
     81   const Index rhs_store_idx_2 = rhs_store_idx_base + 576 * 2;
     82   const Index rhs_store_idx_3 = rhs_store_idx_base + 576 * 3;
     83   const Index rhs_store_idx_4 = rhs_store_idx_base + 576 * 4;
     84   const Index rhs_store_idx_5 = rhs_store_idx_base + 576 * 5;
     85   const Index rhs_store_idx_6 = rhs_store_idx_base + 576 * 6;
     86   const Index rhs_store_idx_7 = rhs_store_idx_base + 576 * 7;
     87 
     88   // in the loading code, the following variables are important:
     89   // threadIdx.x: the vertical position in an 8x8 block
     90   // threadIdx.y: the vertical index of the 8x8 block in the grid
     91   // threadIdx.z: the horizontal position in an 8x8 block
     92   // k: the horizontal index of the 8x8 block in the grid
     93   //
     94   // The k parameter is implicit (it was the loop counter for a loop that went
     95   // from 0 to <8, but now that loop is unrolled in the below code.
     96 
     97   const Index load_idx_vert = threadIdx.x + 8 * threadIdx.y;
     98   const Index lhs_vert = base_m + load_idx_vert;
     99 
    100 #define prefetchIntoRegisters(base_k)                           \
    101   {                                                             \
    102     lhs_pf0 = conv(0);                                          \
    103     lhs_pf1 = conv(0);                                          \
    104     lhs_pf2 = conv(0);                                          \
    105     lhs_pf3 = conv(0);                                          \
    106     lhs_pf4 = conv(0);                                          \
    107     lhs_pf5 = conv(0);                                          \
    108     lhs_pf6 = conv(0);                                          \
    109     lhs_pf7 = conv(0);                                          \
    110                                                                 \
    111     rhs_pf0 = conv(0);                                          \
    112     rhs_pf1 = conv(0);                                          \
    113     rhs_pf2 = conv(0);                                          \
    114     rhs_pf3 = conv(0);                                          \
    115     rhs_pf4 = conv(0);                                          \
    116     rhs_pf5 = conv(0);                                          \
    117     rhs_pf6 = conv(0);                                          \
    118     rhs_pf7 = conv(0);                                          \
    119                                                                 \
    120     if (!needs_edge_check || lhs_vert < m_size) {               \
    121       const Index lhs_horiz_0 = base_k + threadIdx.z + 0 * 8;   \
    122       const Index lhs_horiz_1 = base_k + threadIdx.z + 1 * 8;   \
    123       const Index lhs_horiz_2 = base_k + threadIdx.z + 2 * 8;   \
    124       const Index lhs_horiz_3 = base_k + threadIdx.z + 3 * 8;   \
    125       const Index lhs_horiz_4 = base_k + threadIdx.z + 4 * 8;   \
    126       const Index lhs_horiz_5 = base_k + threadIdx.z + 5 * 8;   \
    127       const Index lhs_horiz_6 = base_k + threadIdx.z + 6 * 8;   \
    128       const Index lhs_horiz_7 = base_k + threadIdx.z + 7 * 8;   \
    129                                                                 \
    130       if (!needs_edge_check || lhs_horiz_7 < k_size) {          \
    131         lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
    132         lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
    133         lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
    134         lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                   \
    135         lhs_pf4 = lhs(lhs_vert, lhs_horiz_4);                   \
    136         lhs_pf5 = lhs(lhs_vert, lhs_horiz_5);                   \
    137         lhs_pf6 = lhs(lhs_vert, lhs_horiz_6);                   \
    138         lhs_pf7 = lhs(lhs_vert, lhs_horiz_7);                   \
    139       } else if (lhs_horiz_6 < k_size) {                        \
    140         lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
    141         lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
    142         lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
    143         lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                   \
    144         lhs_pf4 = lhs(lhs_vert, lhs_horiz_4);                   \
    145         lhs_pf5 = lhs(lhs_vert, lhs_horiz_5);                   \
    146         lhs_pf6 = lhs(lhs_vert, lhs_horiz_6);                   \
    147       } else if (lhs_horiz_5 < k_size) {                        \
    148         lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
    149         lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
    150         lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
    151         lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                   \
    152         lhs_pf4 = lhs(lhs_vert, lhs_horiz_4);                   \
    153         lhs_pf5 = lhs(lhs_vert, lhs_horiz_5);                   \
    154       } else if (lhs_horiz_4 < k_size) {                        \
    155         lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
    156         lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
    157         lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
    158         lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                   \
    159         lhs_pf4 = lhs(lhs_vert, lhs_horiz_4);                   \
    160       } else if (lhs_horiz_3 < k_size) {                        \
    161         lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
    162         lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
    163         lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
    164         lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                   \
    165       } else if (lhs_horiz_2 < k_size) {                        \
    166         lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
    167         lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
    168         lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
    169       } else if (lhs_horiz_1 < k_size) {                        \
    170         lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
    171         lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
    172       } else if (lhs_horiz_0 < k_size) {                        \
    173         lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
    174       }                                                         \
    175     }                                                           \
    176                                                                 \
    177     const Index rhs_vert = base_k + load_idx_vert;              \
    178     if (!needs_edge_check || rhs_vert < k_size) {               \
    179       const Index rhs_horiz_0 = base_n + threadIdx.z + 0 * 8;   \
    180       const Index rhs_horiz_1 = base_n + threadIdx.z + 1 * 8;   \
    181       const Index rhs_horiz_2 = base_n + threadIdx.z + 2 * 8;   \
    182       const Index rhs_horiz_3 = base_n + threadIdx.z + 3 * 8;   \
    183       const Index rhs_horiz_4 = base_n + threadIdx.z + 4 * 8;   \
    184       const Index rhs_horiz_5 = base_n + threadIdx.z + 5 * 8;   \
    185       const Index rhs_horiz_6 = base_n + threadIdx.z + 6 * 8;   \
    186       const Index rhs_horiz_7 = base_n + threadIdx.z + 7 * 8;   \
    187                                                                 \
    188       if (rhs_horiz_7 < n_size) {                               \
    189         rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
    190         rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
    191         rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
    192         rhs_pf3 = rhs(rhs_vert, rhs_horiz_3);                   \
    193         rhs_pf4 = rhs(rhs_vert, rhs_horiz_4);                   \
    194         rhs_pf5 = rhs(rhs_vert, rhs_horiz_5);                   \
    195         rhs_pf6 = rhs(rhs_vert, rhs_horiz_6);                   \
    196         rhs_pf7 = rhs(rhs_vert, rhs_horiz_7);                   \
    197       } else if (rhs_horiz_6 < n_size) {                        \
    198         rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
    199         rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
    200         rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
    201         rhs_pf3 = rhs(rhs_vert, rhs_horiz_3);                   \
    202         rhs_pf4 = rhs(rhs_vert, rhs_horiz_4);                   \
    203         rhs_pf5 = rhs(rhs_vert, rhs_horiz_5);                   \
    204         rhs_pf6 = rhs(rhs_vert, rhs_horiz_6);                   \
    205       } else if (rhs_horiz_5 < n_size) {                        \
    206         rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
    207         rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
    208         rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
    209         rhs_pf3 = rhs(rhs_vert, rhs_horiz_3);                   \
    210         rhs_pf4 = rhs(rhs_vert, rhs_horiz_4);                   \
    211         rhs_pf5 = rhs(rhs_vert, rhs_horiz_5);                   \
    212       } else if (rhs_horiz_4 < n_size) {                        \
    213         rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
    214         rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
    215         rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
    216         rhs_pf3 = rhs(rhs_vert, rhs_horiz_3);                   \
    217         rhs_pf4 = rhs(rhs_vert, rhs_horiz_4);                   \
    218       } else if (rhs_horiz_3 < n_size) {                        \
    219         rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
    220         rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
    221         rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
    222         rhs_pf3 = rhs(rhs_vert, rhs_horiz_3);                   \
    223       } else if (rhs_horiz_2 < n_size) {                        \
    224         rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
    225         rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
    226         rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
    227       } else if (rhs_horiz_1 < n_size) {                        \
    228         rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
    229         rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
    230       } else if (rhs_horiz_0 < n_size) {                        \
    231         rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
    232       }                                                         \
    233     }                                                           \
    234   }                                                             \
    235 
    236 #define writeRegToShmem(_)                      \
    237   lhs_shmem[lhs_store_idx_0] = lhs_pf0;         \
    238   rhs_shmem[rhs_store_idx_0] = rhs_pf0;         \
    239                                                 \
    240   lhs_shmem[lhs_store_idx_1] = lhs_pf1;         \
    241   rhs_shmem[rhs_store_idx_1] = rhs_pf1;         \
    242                                                 \
    243   lhs_shmem[lhs_store_idx_2] = lhs_pf2;         \
    244   rhs_shmem[rhs_store_idx_2] = rhs_pf2;         \
    245                                                 \
    246   lhs_shmem[lhs_store_idx_3] = lhs_pf3;         \
    247   rhs_shmem[rhs_store_idx_3] = rhs_pf3;         \
    248                                                 \
    249   lhs_shmem[lhs_store_idx_4] = lhs_pf4;         \
    250   rhs_shmem[rhs_store_idx_4] = rhs_pf4;         \
    251                                                 \
    252   lhs_shmem[lhs_store_idx_5] = lhs_pf5;         \
    253   rhs_shmem[rhs_store_idx_5] = rhs_pf5;         \
    254                                                 \
    255   lhs_shmem[lhs_store_idx_6] = lhs_pf6;         \
    256   rhs_shmem[rhs_store_idx_6] = rhs_pf6;         \
    257                                                 \
    258   lhs_shmem[lhs_store_idx_7] = lhs_pf7;         \
    259   rhs_shmem[rhs_store_idx_7] = rhs_pf7;         \
    260 
    261   // declare and initialize result array
    262 #define res(i, j) _res_##i##j
    263 #define initResultRow(i)                        \
    264   Scalar res(i, 0) = conv(0);                   \
    265   Scalar res(i, 1) = conv(0);                   \
    266   Scalar res(i, 2) = conv(0);                   \
    267   Scalar res(i, 3) = conv(0);                   \
    268   Scalar res(i, 4) = conv(0);                   \
    269   Scalar res(i, 5) = conv(0);                   \
    270   Scalar res(i, 6) = conv(0);                   \
    271   Scalar res(i, 7) = conv(0);                   \
    272 
    273   internal::scalar_cast_op<int, Scalar> conv;
    274   initResultRow(0);
    275   initResultRow(1);
    276   initResultRow(2);
    277   initResultRow(3);
    278   initResultRow(4);
    279   initResultRow(5);
    280   initResultRow(6);
    281   initResultRow(7);
    282 #undef initResultRow
    283 
    284   for (Index base_k = 0; base_k < k_size; base_k += 64) {
    285     // wait for previous iteration to finish with shmem. Despite common sense,
    286     // the code is a bit faster with this here then at bottom of loop
    287     __syncthreads();
    288 
    289     prefetchIntoRegisters(base_k);
    290     writeRegToShmem();
    291 
    292     #undef prefetchIntoRegisters
    293     #undef writeRegToShmem
    294 
    295     // wait for shared mem packing to be done before starting computation
    296     __syncthreads();
    297 
    298     // compute 8x8 matrix product by outer product. This involves packing one column
    299     // of LHS and one row of RHS into registers (takes 16 registers).
    300 
    301 #define lcol(i) _lcol##i
    302     Scalar lcol(0);
    303     Scalar lcol(1);
    304     Scalar lcol(2);
    305     Scalar lcol(3);
    306     Scalar lcol(4);
    307     Scalar lcol(5);
    308     Scalar lcol(6);
    309     Scalar lcol(7);
    310 
    311 #define rrow(j) _rrow##j
    312     Scalar rrow(0);
    313     Scalar rrow(1);
    314     Scalar rrow(2);
    315     Scalar rrow(3);
    316     Scalar rrow(4);
    317     Scalar rrow(5);
    318     Scalar rrow(6);
    319     Scalar rrow(7);
    320 
    321     // Now x corresponds to k, y to m, and z to n
    322     const Scalar* lhs_block = &lhs_shmem[threadIdx.x + 9 * threadIdx.y];
    323     const Scalar* rhs_block = &rhs_shmem[threadIdx.x + 8 * threadIdx.z];
    324 
    325 #define lhs_element(i, j) lhs_block[72 * ((i) + 8 * (j))]
    326 #define rhs_element(i, j) rhs_block[72 * ((i) + 8 * (j))]
    327 
    328 #define loadData(i, j)                          \
    329     lcol(0) = lhs_element(0, j);               \
    330     rrow(0) = rhs_element(i, 0);               \
    331     lcol(1) = lhs_element(1, j);               \
    332     rrow(1) = rhs_element(i, 1);               \
    333     lcol(2) = lhs_element(2, j);               \
    334     rrow(2) = rhs_element(i, 2);               \
    335     lcol(3) = lhs_element(3, j);               \
    336     rrow(3) = rhs_element(i, 3);               \
    337     lcol(4) = lhs_element(4, j);               \
    338     rrow(4) = rhs_element(i, 4);               \
    339     lcol(5) = lhs_element(5, j);               \
    340     rrow(5) = rhs_element(i, 5);               \
    341     lcol(6) = lhs_element(6, j);               \
    342     rrow(6) = rhs_element(i, 6);               \
    343     lcol(7) = lhs_element(7, j);               \
    344     rrow(7) = rhs_element(i, 7);               \
    345 
    346 #define computeCol(j)                           \
    347     res(0, j) += lcol(0) * rrow(j);             \
    348     res(1, j) += lcol(1) * rrow(j);             \
    349     res(2, j) += lcol(2) * rrow(j);             \
    350     res(3, j) += lcol(3) * rrow(j);             \
    351     res(4, j) += lcol(4) * rrow(j);             \
    352     res(5, j) += lcol(5) * rrow(j);             \
    353     res(6, j) += lcol(6) * rrow(j);             \
    354     res(7, j) += lcol(7) * rrow(j);             \
    355 
    356 #define computePass(i)                          \
    357     loadData(i, i);                             \
    358                                                 \
    359     computeCol(0);                              \
    360     computeCol(1);                              \
    361     computeCol(2);                              \
    362     computeCol(3);                              \
    363     computeCol(4);                              \
    364     computeCol(5);                              \
    365     computeCol(6);                              \
    366     computeCol(7);                              \
    367 
    368     computePass(0);
    369     computePass(1);
    370     computePass(2);
    371     computePass(3);
    372     computePass(4);
    373     computePass(5);
    374     computePass(6);
    375     computePass(7);
    376 
    377 #undef lcol
    378 #undef rrow
    379 #undef lhs_element
    380 #undef rhs_element
    381 #undef loadData
    382 #undef computeCol
    383 #undef computePass
    384   } // end loop over k
    385 
    386   // we've now iterated over all of the large (ie width 64) k blocks and
    387   // accumulated results in registers. At this point thread (x, y, z) contains
    388   // the sum across all big k blocks of the product of little k block of index (x, y)
    389   // with block of index (y, z). To compute the final output, we need to reduce
    390   // the 8 threads over y by summation.
    391 #if defined(EIGEN_HIPCC) || (defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER < 90000)
    392 #define shuffleInc(i, j, mask) res(i, j) += __shfl_xor(res(i, j), mask)
    393 #else
    394 #define shuffleInc(i, j, mask) res(i, j) += __shfl_xor_sync(0xFFFFFFFF, res(i, j), mask)
    395 #endif
    396 
    397 #define reduceRow(i, mask)                      \
    398   shuffleInc(i, 0, mask);                       \
    399   shuffleInc(i, 1, mask);                       \
    400   shuffleInc(i, 2, mask);                       \
    401   shuffleInc(i, 3, mask);                       \
    402   shuffleInc(i, 4, mask);                       \
    403   shuffleInc(i, 5, mask);                       \
    404   shuffleInc(i, 6, mask);                       \
    405   shuffleInc(i, 7, mask);                       \
    406 
    407 #define reduceMatrix(mask)                      \
    408   reduceRow(0, mask);                           \
    409   reduceRow(1, mask);                           \
    410   reduceRow(2, mask);                           \
    411   reduceRow(3, mask);                           \
    412   reduceRow(4, mask);                           \
    413   reduceRow(5, mask);                           \
    414   reduceRow(6, mask);                           \
    415   reduceRow(7, mask);                           \
    416 
    417   // actually perform the reduction, now each thread of index (_, y, z)
    418   // contains the correct values in its registers that belong in the output
    419   // block
    420   reduceMatrix(1);
    421   reduceMatrix(2);
    422   reduceMatrix(4);
    423 
    424 #undef shuffleInc
    425 #undef reduceRow
    426 #undef reduceMatrix
    427 
    428   // now we need to copy the 64 values into main memory. We can't split work
    429   // among threads because all variables are in registers. There's 2 ways
    430   // to do this:
    431   // (1) have 1 thread do 64 writes from registers into global memory
    432   // (2) have 1 thread do 64 writes into shared memory, and then 8 threads
    433   //     each do 8 writes into global memory. We can just overwrite the shared
    434   //     memory from the problem we just solved.
    435   // (2) is slightly faster than (1) due to less branching and more ILP
    436 
    437   // TODO: won't yield much gain, but could just use currently unused shared mem
    438   //       and then we won't have to sync
    439   // wait for shared mem to be out of use
    440   __syncthreads();
    441 
    442 #define writeResultShmem(i, j)                                          \
    443   lhs_shmem[i + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j] = res(i, j); \
    444 
    445 #define writeRow(i)                             \
    446   writeResultShmem(i, 0);                       \
    447   writeResultShmem(i, 1);                       \
    448   writeResultShmem(i, 2);                       \
    449   writeResultShmem(i, 3);                       \
    450   writeResultShmem(i, 4);                       \
    451   writeResultShmem(i, 5);                       \
    452   writeResultShmem(i, 6);                       \
    453   writeResultShmem(i, 7);                       \
    454 
    455   if (threadIdx.x == 0) {
    456     writeRow(0);
    457     writeRow(1);
    458     writeRow(2);
    459     writeRow(3);
    460     writeRow(4);
    461     writeRow(5);
    462     writeRow(6);
    463     writeRow(7);
    464   }
    465 #undef writeResultShmem
    466 #undef writeRow
    467 
    468   const int max_i_write = numext::mini((int)((m_size - base_m - threadIdx.y + 7) / 8), 8);
    469   const int max_j_write = numext::mini((int)((n_size - base_n - threadIdx.z + 7) / 8), 8);
    470 
    471   if (threadIdx.x < max_i_write) {
    472     if (max_j_write == 8) {
    473       // TODO: can i trade bank conflicts for coalesced writes?
    474       Scalar val0 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 0];
    475       Scalar val1 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 1];
    476       Scalar val2 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 2];
    477       Scalar val3 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 3];
    478       Scalar val4 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 4];
    479       Scalar val5 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 5];
    480       Scalar val6 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 6];
    481       Scalar val7 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 7];
    482 
    483       output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 0) = val0;
    484       output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 1) = val1;
    485       output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 2) = val2;
    486       output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 3) = val3;
    487       output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 4) = val4;
    488       output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 5) = val5;
    489       output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 6) = val6;
    490       output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 7) = val7;
    491     } else {
    492 #pragma unroll 7
    493       for (int j = 0; j < max_j_write; j++) {
    494         Scalar val = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j];
    495         output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * j) = val;
    496       }
    497     }
    498   }
    499 #undef res
    500 }
    501 
    502 
    503 template<typename Scalar, typename Index, typename LhsMapper,
    504          typename RhsMapper, typename OutputMapper>
    505 __global__ void
    506 #if defined(EIGEN_HIPCC)
    507 __launch_bounds__(512, 1)
    508 #else
    509 __launch_bounds__(512)
    510 #endif
    511 EigenContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
    512                        const OutputMapper output,
    513                        const Index m_size, const Index n_size, const Index k_size) {
    514   __shared__ Scalar lhs_shmem[72 * 64];
    515   __shared__ Scalar rhs_shmem[72 * 64];
    516 
    517   const Index m_block_idx = blockIdx.x;
    518   const Index n_block_idx = blockIdx.y;
    519 
    520   const Index base_m = 64 * m_block_idx;
    521   const Index base_n = 64 * n_block_idx;
    522 
    523   if (base_m + 63 < m_size && base_n + 63 < n_size) {
    524     EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
    525   } else {
    526     EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
    527   }
    528 }
    529 
    530 
    531 template<typename Index, typename LhsMapper,
    532          typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY,
    533          bool CHECK_RHS_BOUNDARY>
    534 __device__ __forceinline__ void
    535 EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rhs,
    536                        const OutputMapper output, float2 lhs_shmem2[][16],
    537                        float2 rhs_shmem2[][8], const Index m_size,
    538                        const Index n_size, const Index k_size,
    539                        const Index base_m, const Index base_n) {
    540 
    541   // prefetch registers
    542   float4 lhs_pf0, rhs_pf0;
    543 
    544   float4 results[4];
    545   for (int i=0; i < 4; i++) {
    546     results[i].x = results[i].y = results[i].z = results[i].w = 0;
    547   }
    548 
    549 #define prefetch_lhs(reg, row, col)                            \
    550     if (!CHECK_LHS_BOUNDARY) {                                 \
    551       if (col < k_size) {                                      \
    552         reg =lhs.template loadPacket<float4,Unaligned>(row, col);     \
    553       }                                                        \
    554     } else {                                                   \
    555       if (col < k_size) {                                      \
    556         if (row + 3 < m_size) {                                \
    557           reg =lhs.template loadPacket<float4,Unaligned>(row, col);   \
    558         } else if (row + 2 < m_size) {                         \
    559           reg.x =lhs(row + 0, col);                            \
    560           reg.y =lhs(row + 1, col);                            \
    561           reg.z =lhs(row + 2, col);                            \
    562         } else if (row + 1 < m_size) {                         \
    563           reg.x =lhs(row + 0, col);                            \
    564           reg.y =lhs(row + 1, col);                            \
    565         } else if (row  < m_size) {                            \
    566           reg.x =lhs(row + 0, col);                            \
    567         }                                                      \
    568       }                                                        \
    569     }							       \
    570 
    571   Index lhs_vert = base_m+threadIdx.x*4;
    572 
    573   for (Index k = 0; k < k_size; k += 16) {
    574 
    575     lhs_pf0 = internal::pset1<float4>(0);
    576     rhs_pf0 = internal::pset1<float4>(0);
    577 
    578     Index lhs_horiz = threadIdx.y+k;
    579     prefetch_lhs(lhs_pf0, lhs_vert, lhs_horiz)
    580 
    581     Index rhs_vert = k+(threadIdx.x%4)*4;
    582     Index rhs_horiz0 = (threadIdx.x>>2)+threadIdx.y*4+base_n;
    583 
    584     if (!CHECK_RHS_BOUNDARY) {
    585       if ((rhs_vert + 3) < k_size) {
    586         // just CHECK_RHS_BOUNDARY
    587         rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
    588       } else if (rhs_vert + 2 < k_size) {
    589         // just CHECK_RHS_BOUNDARY
    590         rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
    591         rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
    592         rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
    593       } else if (rhs_vert + 1 < k_size) {
    594         rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
    595         rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
    596       } else if (rhs_vert  < k_size) {
    597         rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
    598       }
    599     } else {
    600       if (rhs_horiz0 < n_size) {
    601         if ((rhs_vert + 3) < k_size) {
    602           rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
    603         } else if ((rhs_vert + 2) < k_size) {
    604           rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
    605           rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
    606           rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
    607         } else if ((rhs_vert + 1) < k_size) {
    608           rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
    609           rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
    610         } else if (rhs_vert  < k_size) {
    611           rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
    612         }
    613       }
    614     }
    615     float x1, x2 ;
    616     // the following can be a bitwise operation..... some day.
    617     if((threadIdx.x%8) < 4) {
    618       x1 = rhs_pf0.y;
    619       x2 = rhs_pf0.w;
    620     } else {
    621       x1 = rhs_pf0.x;
    622       x2 = rhs_pf0.z;
    623     }
    624     #if defined(EIGEN_HIPCC) || (defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER < 90000)
    625     x1 = __shfl_xor(x1, 4);
    626     x2 = __shfl_xor(x2, 4);
    627     #else
    628     x1 = __shfl_xor_sync(0xFFFFFFFF, x1, 4);
    629     x2 = __shfl_xor_sync(0xFFFFFFFF, x2, 4);
    630     #endif
    631     if((threadIdx.x%8) < 4) {
    632       rhs_pf0.y = x1;
    633       rhs_pf0.w = x2;
    634     } else {
    635       rhs_pf0.x = x1;
    636       rhs_pf0.z = x2;
    637     }
    638 
    639     // We have 64 features.
    640     // Row 0 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 0, 1.
    641     // Row 1 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 2, 3.
    642     // ...
    643     // Row 31 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 62, 63
    644     // Row 32 -> times (2, 6, 10, 14, 3, 7, 11, 15) for features 0, 1
    645     // ...
    646     rhs_shmem2[(threadIdx.x>>3)+ threadIdx.y*2][threadIdx.x%8] = make_float2(rhs_pf0.x, rhs_pf0.y);
    647     rhs_shmem2[(threadIdx.x>>3)+ threadIdx.y*2+32][threadIdx.x%8] = make_float2(rhs_pf0.z, rhs_pf0.w);
    648 
    649     // Row 0 (time 0) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), ..  (60, 61)
    650     // Row 1 (time 1) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), ..  (60, 61)
    651     // ...
    652     // Row 15 (time 15) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), ..  (60, 61)
    653     // Row 16 (time 0) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), ..  (62, 63)
    654     // ...
    655 
    656     lhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(lhs_pf0.x, lhs_pf0.y);
    657     lhs_shmem2[threadIdx.y+16][threadIdx.x] = make_float2(lhs_pf0.z, lhs_pf0.w);
    658 
    659 
    660 #define add_vals(fl1, fl2, fr1, fr2)\
    661     results[0].x += fl1.x * fr1.x;\
    662     results[0].y += fl1.y * fr1.x;\
    663     results[0].z += fl2.x * fr1.x;\
    664     results[0].w += fl2.y * fr1.x;\
    665 \
    666     results[1].x += fl1.x * fr1.y;\
    667     results[1].y += fl1.y * fr1.y;\
    668     results[1].z += fl2.x * fr1.y;\
    669     results[1].w += fl2.y * fr1.y;\
    670 \
    671     results[2].x += fl1.x * fr2.x;\
    672     results[2].y += fl1.y * fr2.x;\
    673     results[2].z += fl2.x * fr2.x;\
    674     results[2].w += fl2.y * fr2.x;\
    675 \
    676     results[3].x += fl1.x * fr2.y;\
    677     results[3].y += fl1.y * fr2.y;\
    678     results[3].z += fl2.x * fr2.y;\
    679     results[3].w += fl2.y * fr2.y;\
    680 
    681     __syncthreads();
    682 
    683     // Do the multiplies.
    684     #pragma unroll
    685     for (int koff = 0; koff < 16; koff ++) {
    686       // 32 x threads.
    687       float2 fl1 = lhs_shmem2[koff][threadIdx.x];
    688       float2 fl2 = lhs_shmem2[koff + 16][threadIdx.x];
    689 
    690       int start_feature = threadIdx.y * 4;
    691       float2 fr1 = rhs_shmem2[(start_feature>>1) + 32*((koff%4)/2)][koff/4 + (koff%2)*4];
    692       float2 fr2 = rhs_shmem2[(start_feature>>1) + 1 + 32*((koff%4)/2)][koff/4 + (koff%2)*4];
    693 
    694       add_vals(fl1, fl2, fr1, fr2)
    695     }
    696     __syncthreads();
    697   }
    698 
    699 #undef prefetch_lhs
    700 #undef add_vals
    701 
    702   Index horiz_base = threadIdx.y*4+base_n;
    703   if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
    704     for (int i = 0; i < 4; i++) {
    705       output(lhs_vert, horiz_base + i) = results[i].x;
    706       output(lhs_vert + 1, horiz_base + i) = results[i].y;
    707       output(lhs_vert + 2, horiz_base + i) = results[i].z;
    708       output(lhs_vert + 3, horiz_base + i) = results[i].w;
    709     }
    710   } else if (!CHECK_RHS_BOUNDARY) {
    711     // CHECK LHS
    712     if (lhs_vert + 3 < m_size) {
    713       for (int i = 0; i < 4; i++) {
    714         output(lhs_vert, horiz_base + i) = results[i].x;
    715         output(lhs_vert + 1, horiz_base + i) = results[i].y;
    716         output(lhs_vert + 2, horiz_base + i) = results[i].z;
    717         output(lhs_vert + 3, horiz_base + i) = results[i].w;
    718       }
    719     } else if (lhs_vert + 2 < m_size) {
    720       for (int i = 0; i < 4; i++) {
    721         output(lhs_vert, horiz_base + i) = results[i].x;
    722         output(lhs_vert + 1, horiz_base + i) = results[i].y;
    723         output(lhs_vert + 2, horiz_base + i) = results[i].z;
    724       }
    725     } else if (lhs_vert + 1 < m_size) {
    726       for (int i = 0; i < 4; i++) {
    727         output(lhs_vert, horiz_base + i) = results[i].x;
    728         output(lhs_vert + 1, horiz_base + i) = results[i].y;
    729       }
    730     } else if (lhs_vert  < m_size) {
    731       for (int i = 0; i < 4; i++) {
    732         output(lhs_vert, horiz_base + i) = results[i].x;
    733       }
    734     }
    735   } else if (!CHECK_LHS_BOUNDARY) {
    736     // CHECK RHS
    737     /*
    738     int ncols_rem = fminf(n_size- horiz_base, 4);
    739     for (int i = 0; i < ncols_rem; i++) {
    740       output(lhs_vert, horiz_base + i) = results[i].x;
    741       output(lhs_vert + 1, horiz_base + i) = results[i].y;
    742       output(lhs_vert + 2, horiz_base + i) = results[i].z;
    743       output(lhs_vert + 3, horiz_base + i) = results[i].w;
    744     }*/
    745     for (int i = 0; i < 4; i++) {
    746       if (horiz_base+i < n_size) {
    747         output(lhs_vert, horiz_base + i) = results[i].x;
    748         output(lhs_vert + 1, horiz_base + i) = results[i].y;
    749         output(lhs_vert + 2, horiz_base + i) = results[i].z;
    750         output(lhs_vert + 3, horiz_base + i) = results[i].w;
    751        }
    752     }
    753   } else {
    754     // CHECK both boundaries.
    755     for (int i = 0; i < 4; i++) {
    756       if (horiz_base+i < n_size) {
    757         if (lhs_vert < m_size)
    758           output(lhs_vert, horiz_base + i) = results[i].x;
    759         if (lhs_vert + 1 < m_size)
    760           output(lhs_vert + 1, horiz_base + i) = results[i].y;
    761         if (lhs_vert + 2 < m_size)
    762           output(lhs_vert + 2, horiz_base + i) = results[i].z;
    763         if (lhs_vert + 3 < m_size)
    764           output(lhs_vert + 3, horiz_base + i) = results[i].w;
    765       }
    766     }
    767   }
    768 }
    769 
    770 
    771 template<typename Index, typename LhsMapper,
    772          typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY,
    773          bool CHECK_RHS_BOUNDARY>
    774 __device__ __forceinline__ void
    775 EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
    776                        const OutputMapper output, float2 lhs_shmem2[][32],
    777                        float2 rhs_shmem2[][8], const Index m_size,
    778                        const Index n_size, const Index k_size,
    779                        const Index base_m, const Index base_n) {
    780 
    781   // prefetch registers
    782   float4 lhs_pf0, lhs_pf1, lhs_pf2, lhs_pf3;
    783   float4 rhs_pf0, rhs_pf1;
    784 
    785   float4 results[8];
    786   for (int i=0; i < 8; i++) {
    787     results[i].x = results[i].y = results[i].z = results[i].w = 0;
    788   }
    789 
    790   Index lhs_vert = base_m+threadIdx.x*4+(threadIdx.y%4)*32;
    791   for (Index k = 0; k < k_size; k += 32) {
    792     lhs_pf0 = internal::pset1<float4>(0);
    793     lhs_pf1 = internal::pset1<float4>(0);
    794     lhs_pf2 = internal::pset1<float4>(0);
    795     lhs_pf3 = internal::pset1<float4>(0);
    796 
    797     rhs_pf0 = internal::pset1<float4>(0);
    798     rhs_pf1 = internal::pset1<float4>(0);
    799 
    800      if (!CHECK_LHS_BOUNDARY) {
    801       if ((threadIdx.y/4+k+24) < k_size) {
    802         lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
    803         lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
    804         lhs_pf2 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
    805         lhs_pf3 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
    806       } else if ((threadIdx.y/4+k+16) < k_size) {
    807         lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
    808         lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
    809         lhs_pf2 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
    810       } else if ((threadIdx.y/4+k+8) < k_size) {
    811         lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
    812         lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
    813       } else if ((threadIdx.y/4+k) < k_size) {
    814         lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
    815       }
    816     } else {
    817       // just CHECK_LHS_BOUNDARY
    818       if (lhs_vert + 3 < m_size) {
    819         if ((threadIdx.y/4+k+24) < k_size) {
    820           lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
    821           lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
    822           lhs_pf2 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
    823           lhs_pf3 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
    824         } else if ((threadIdx.y/4+k+16) < k_size) {
    825           lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
    826           lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
    827           lhs_pf2 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
    828         } else if ((threadIdx.y/4+k+8) < k_size) {
    829           lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
    830           lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
    831         } else if ((threadIdx.y/4+k) < k_size) {
    832           lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k));
    833         }
    834       } else if (lhs_vert + 2 < m_size) {
    835         if ((threadIdx.y/4+k+24) < k_size) {
    836           lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
    837           lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
    838           lhs_pf0.z =lhs(lhs_vert + 2, (threadIdx.y/4+k));
    839           lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
    840           lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
    841           lhs_pf1.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
    842           lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
    843           lhs_pf2.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
    844           lhs_pf2.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+16));
    845           lhs_pf3.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
    846           lhs_pf3.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+24));
    847           lhs_pf3.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+24));
    848         } else if ((threadIdx.y/4+k+16) < k_size) {
    849           lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
    850           lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
    851           lhs_pf0.z =lhs(lhs_vert + 2, (threadIdx.y/4+k));
    852           lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
    853           lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
    854           lhs_pf1.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
    855           lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
    856           lhs_pf2.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
    857           lhs_pf2.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+16));
    858         } else if ((threadIdx.y/4+k+8) < k_size) {
    859           lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
    860           lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
    861           lhs_pf0.z =lhs(lhs_vert + 2, (threadIdx.y/4+k));
    862           lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
    863           lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
    864           lhs_pf1.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
    865         } else if ((threadIdx.y/4+k) < k_size) {
    866           lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
    867           lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
    868           lhs_pf0.z =lhs(lhs_vert + 2, (threadIdx.y/4+k));
    869         }
    870       } else if (lhs_vert + 1 < m_size) {
    871         if ((threadIdx.y/4+k+24) < k_size) {
    872           lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
    873           lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
    874           lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
    875           lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
    876           lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
    877           lhs_pf2.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
    878           lhs_pf3.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
    879           lhs_pf3.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+24));
    880         } else if ((threadIdx.y/4+k+16) < k_size) {
    881           lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
    882           lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
    883           lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
    884           lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
    885           lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
    886           lhs_pf2.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
    887         } else if ((threadIdx.y/4+k+8) < k_size) {
    888           lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
    889           lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
    890           lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
    891           lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
    892         } else if ((threadIdx.y/4+k) < k_size) {
    893           lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
    894           lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k));
    895         }
    896       } else if (lhs_vert < m_size) {
    897         if ((threadIdx.y/4+k+24) < k_size) {
    898           lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
    899           lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
    900           lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
    901           lhs_pf3.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
    902         } else if ((threadIdx.y/4+k+16) < k_size) {
    903           lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
    904           lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
    905           lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
    906         } else if ((threadIdx.y/4+k+8) < k_size) {
    907           lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
    908           lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
    909         } else if ((threadIdx.y/4+k) < k_size) {
    910           lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k));
    911         }
    912       }
    913     }
    914     __syncthreads();
    915     Index rhs_vert = k+threadIdx.x*4;
    916     Index rhs_horiz0 = threadIdx.y*2+base_n;
    917     Index rhs_horiz1 = threadIdx.y*2+1+base_n;
    918     if (!CHECK_RHS_BOUNDARY) {
    919       if ((rhs_vert + 3) < k_size) {
    920         // just CHECK_RHS_BOUNDARY
    921         rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
    922         rhs_pf1 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz1);
    923       } else if (rhs_vert + 2 < k_size) {
    924         // just CHECK_RHS_BOUNDARY
    925         rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
    926         rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
    927         rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
    928         rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
    929         rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
    930         rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
    931       } else if (rhs_vert + 1 < k_size) {
    932         rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
    933         rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
    934         rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
    935         rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
    936       } else if (rhs_vert  < k_size) {
    937         rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
    938         rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
    939       }
    940     } else {
    941       if (rhs_horiz1 < n_size) {
    942         if ((rhs_vert + 3) < k_size) {
    943           // just CHECK_RHS_BOUNDARY
    944           rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
    945           rhs_pf1 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz1);
    946         } else if (rhs_vert + 2 < k_size) {
    947           // just CHECK_RHS_BOUNDARY
    948           rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
    949           rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
    950           rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
    951           rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
    952           rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
    953           rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
    954         } else if (k+threadIdx.x*4 + 1 < k_size) {
    955           rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
    956           rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
    957           rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
    958           rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
    959         } else if (k+threadIdx.x*4  < k_size) {
    960           rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
    961           rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
    962         }
    963       } else if (rhs_horiz0 < n_size) {
    964         if ((rhs_vert + 3) < k_size) {
    965           // just CHECK_RHS_BOUNDARY
    966           rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
    967         } else if ((rhs_vert + 2) < k_size) {
    968           // just CHECK_RHS_BOUNDARY
    969           rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
    970           rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
    971           rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
    972         } else if ((rhs_vert + 1) < k_size) {
    973           rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
    974           rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
    975         } else if (rhs_vert  < k_size) {
    976           rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
    977         }
    978       }
    979     }
    980     __syncthreads();
    981     // Loaded. Do computation
    982     // Row 0 -> times (0, 4, 8, .. 28) for features 0, 1.
    983     // Row 1 -> times (0, 4, 8, .. 28) for features 2, 3.
    984     // ..
    985     // Row 31 -> times (0, 4, 8, .. 28) for features 62, 63
    986     rhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(rhs_pf0.x, rhs_pf1.x);
    987     // Row 32 -> times (1, 5, 9, .. 29) for features 0, 1.
    988     // Row 33 -> times (1, 5, 9, .. 29) for features 2, 3.
    989     // ..
    990     rhs_shmem2[threadIdx.y+32][threadIdx.x] = make_float2(rhs_pf0.y, rhs_pf1.y);
    991     // Row 64 -> times (2, 6, 10, .. 30) for features 0, 1.
    992     // Row 65 -> times (2, 6, 10, .. 30) for features 2, 3.
    993     rhs_shmem2[threadIdx.y+64][threadIdx.x] = make_float2(rhs_pf0.z, rhs_pf1.z);
    994     // Row 96 -> times (3, 7, 11, .. 31) for features 0, 1.
    995     // Row 97 -> times (3, 7, 11, .. 31) for features 2, 3.
    996     rhs_shmem2[threadIdx.y+96][threadIdx.x] = make_float2(rhs_pf0.w, rhs_pf1.w);
    997 
    998     // LHS.
    999     // Row 0 (time 0) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), ..  (60, 61) .. (124, 125)
   1000     // Row 1 (time 1) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), ..  (60, 61) .. (124, 125)
   1001     // ...
   1002     // Row 8 (time 0) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), ..  (62, 63) .. (126, 127)
   1003     // Row 15 (time 7) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), ..  (62, 63) .. (126, 127)
   1004 
   1005 
   1006 #define add_vals(a_feat1, a_feat2, f1, f2, f3, f4)\
   1007       results[0].x += a_feat1.x * f1.x;\
   1008       results[1].x += a_feat1.x * f1.y;\
   1009       results[2].x += a_feat1.x * f2.x;\
   1010       results[3].x += a_feat1.x * f2.y;\
   1011       results[4].x += a_feat1.x * f3.x;\
   1012       results[5].x += a_feat1.x * f3.y;\
   1013       results[6].x += a_feat1.x * f4.x;\
   1014       results[7].x += a_feat1.x * f4.y;\
   1015 \
   1016       results[0].y += a_feat1.y * f1.x;\
   1017       results[1].y += a_feat1.y * f1.y;\
   1018       results[2].y += a_feat1.y * f2.x;\
   1019       results[3].y += a_feat1.y * f2.y;\
   1020       results[4].y += a_feat1.y * f3.x;\
   1021       results[5].y += a_feat1.y * f3.y;\
   1022       results[6].y += a_feat1.y * f4.x;\
   1023       results[7].y += a_feat1.y * f4.y;\
   1024 \
   1025       results[0].z += a_feat2.x * f1.x;\
   1026       results[1].z += a_feat2.x * f1.y;\
   1027       results[2].z += a_feat2.x * f2.x;\
   1028       results[3].z += a_feat2.x * f2.y;\
   1029       results[4].z += a_feat2.x * f3.x;\
   1030       results[5].z += a_feat2.x * f3.y;\
   1031       results[6].z += a_feat2.x * f4.x;\
   1032       results[7].z += a_feat2.x * f4.y;\
   1033 \
   1034       results[0].w += a_feat2.y * f1.x;\
   1035       results[1].w += a_feat2.y * f1.y;\
   1036       results[2].w += a_feat2.y * f2.x;\
   1037       results[3].w += a_feat2.y * f2.y;\
   1038       results[4].w += a_feat2.y * f3.x;\
   1039       results[5].w += a_feat2.y * f3.y;\
   1040       results[6].w += a_feat2.y * f4.x;\
   1041       results[7].w += a_feat2.y * f4.y;\
   1042 
   1043     lhs_shmem2[threadIdx.y/4][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf0.x, lhs_pf0.y);
   1044     lhs_shmem2[threadIdx.y/4+8][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf1.x, lhs_pf1.y);
   1045     lhs_shmem2[threadIdx.y/4+16][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf2.x, lhs_pf2.y);
   1046     lhs_shmem2[threadIdx.y/4+24][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf3.x, lhs_pf3.y);
   1047 
   1048     lhs_shmem2[threadIdx.y/4 + 32][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf0.z, lhs_pf0.w);
   1049     lhs_shmem2[threadIdx.y/4 + 40][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf1.z, lhs_pf1.w);
   1050     lhs_shmem2[threadIdx.y/4 + 48][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf2.z, lhs_pf2.w);
   1051     lhs_shmem2[threadIdx.y/4 + 56][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf3.z, lhs_pf3.w);
   1052 
   1053     __syncthreads();
   1054 
   1055     // Do the multiplies.
   1056     #pragma unroll
   1057     for (int koff = 0; koff < 32; koff ++) {
   1058       float2 a3 = lhs_shmem2[koff][threadIdx.x + (threadIdx.y % 4) * 8];
   1059       float2 a4 = lhs_shmem2[koff + 32][threadIdx.x + (threadIdx.y % 4) * 8];
   1060 
   1061       // first feature is at (threadIdx.y/4) * 8 last is at start + 8.
   1062       int start_feature = (threadIdx.y / 4) * 8;
   1063 
   1064       float2 br1 = rhs_shmem2[start_feature/2 +     (koff % 4) * 32][koff/4];
   1065       float2 br2 = rhs_shmem2[start_feature/2 + 1 + (koff % 4) * 32][koff/4];
   1066       float2 br3 = rhs_shmem2[start_feature/2 + 2 + (koff % 4) * 32][koff/4];
   1067       float2 br4 = rhs_shmem2[start_feature/2 + 3 + (koff % 4) * 32][koff/4];
   1068 
   1069       add_vals(a3, a4, br1, br2, br3, br4)
   1070     }
   1071     __syncthreads();
   1072   } // end loop over k
   1073 
   1074   __syncthreads();
   1075   Index horiz_base = (threadIdx.y/4)*8+base_n;
   1076   if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
   1077     for (int i = 0; i < 8; i++) {
   1078       output(lhs_vert, horiz_base + i) = results[i].x;
   1079       output(lhs_vert + 1, horiz_base + i) = results[i].y;
   1080       output(lhs_vert + 2, horiz_base + i) = results[i].z;
   1081       output(lhs_vert + 3, horiz_base + i) = results[i].w;
   1082     }
   1083   } else if (!CHECK_RHS_BOUNDARY) {
   1084     if (lhs_vert + 3 < m_size) {
   1085       for (int i = 0; i < 8; i++) {
   1086         output(lhs_vert, horiz_base + i) = results[i].x;
   1087         output(lhs_vert + 1, horiz_base + i) = results[i].y;
   1088         output(lhs_vert + 2, horiz_base + i) = results[i].z;
   1089         output(lhs_vert + 3, horiz_base + i) = results[i].w;
   1090       }
   1091     } else if (lhs_vert + 2 < m_size) {
   1092       for (int i = 0; i < 8; i++) {
   1093         output(lhs_vert, horiz_base + i) = results[i].x;
   1094         output(lhs_vert + 1, horiz_base + i) = results[i].y;
   1095         output(lhs_vert + 2, horiz_base + i) = results[i].z;
   1096       }
   1097     } else if (lhs_vert + 1 < m_size) {
   1098       for (int i = 0; i < 8; i++) {
   1099         output(lhs_vert, horiz_base + i) = results[i].x;
   1100         output(lhs_vert + 1, horiz_base + i) = results[i].y;
   1101       }
   1102     } else if (lhs_vert  < m_size) {
   1103       for (int i = 0; i < 8; i++) {
   1104         output(lhs_vert, horiz_base + i) = results[i].x;
   1105       }
   1106     }
   1107   } else if (!CHECK_LHS_BOUNDARY) {
   1108     // CHECK BOUNDARY_B
   1109     for (int i = 0; i < 8; i++) {
   1110       if (horiz_base + i < n_size) {
   1111         output(lhs_vert, horiz_base + i) = results[i].x;
   1112         output(lhs_vert + 1, horiz_base + i) = results[i].y;
   1113         output(lhs_vert + 2, horiz_base + i) = results[i].z;
   1114         output(lhs_vert + 3, horiz_base + i) = results[i].w;
   1115       }
   1116     }
   1117   } else {
   1118     // CHECK both boundaries.
   1119     for (int i = 0; i < 8; i++) {
   1120       if (horiz_base + i < n_size) {
   1121         if (lhs_vert < m_size)
   1122           output(lhs_vert, horiz_base + i) = results[i].x;
   1123         if (lhs_vert + 1 < m_size)
   1124           output(lhs_vert + 1, horiz_base + i) = results[i].y;
   1125         if (lhs_vert + 2 < m_size)
   1126           output(lhs_vert + 2, horiz_base + i) = results[i].z;
   1127         if (lhs_vert + 3 < m_size)
   1128           output(lhs_vert + 3, horiz_base + i) = results[i].w;
   1129       }
   1130     }
   1131   }
   1132 }
   1133 
   1134 
   1135 template<typename Index, typename LhsMapper,
   1136          typename RhsMapper, typename OutputMapper>
   1137 __global__ void
   1138 #if defined(EIGEN_HIPCC)
   1139 __launch_bounds__(256, 1)
   1140 #else
   1141 __launch_bounds__(256)
   1142 #endif
   1143 EigenFloatContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
   1144                        const OutputMapper output,
   1145                        const Index m_size, const Index n_size, const Index k_size) {
   1146   __shared__ float2 lhs_shmem[64*32];
   1147   __shared__ float2 rhs_shmem[128*8];
   1148 
   1149   typedef float2 LHS_MEM[64][32];
   1150   typedef float2 RHS_MEM[128][8];
   1151 
   1152   const Index m_block_idx = blockIdx.x;
   1153   const Index n_block_idx = blockIdx.y;
   1154 
   1155   const Index base_m = 128 * m_block_idx;
   1156   const Index base_n = 64 * n_block_idx;
   1157 
   1158   bool check_rhs = (base_n + 63) >= n_size;
   1159   bool check_lhs128 = (base_m + 127) >= m_size;
   1160 
   1161   if (!check_rhs) {
   1162     if (!check_lhs128) {
   1163       // >= 128 rows left
   1164       EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(
   1165                      lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
   1166     } else {
   1167       EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(
   1168                      lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
   1169     }
   1170   } else {
   1171     if (!check_lhs128) {
   1172       // >= 128 rows left
   1173       EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(
   1174                      lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
   1175     } else {
   1176       EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(
   1177                      lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
   1178     }
   1179   }
   1180 }
   1181 
   1182 template<typename Index, typename LhsMapper,
   1183          typename RhsMapper, typename OutputMapper>
   1184 __global__ void
   1185 #if defined(EIGEN_HIPCC)
   1186 __launch_bounds__(256, 1)
   1187 #else
   1188 __launch_bounds__(256)
   1189 #endif
   1190 EigenFloatContractionKernel16x16(const LhsMapper lhs, const RhsMapper rhs,
   1191                        const OutputMapper output,
   1192                        const Index m_size, const Index n_size, const Index k_size) {
   1193   __shared__ float2 lhs_shmem[32][16];
   1194   __shared__ float2 rhs_shmem[64][8];
   1195 
   1196   const Index m_block_idx = blockIdx.x;
   1197   const Index n_block_idx = blockIdx.y;
   1198 
   1199   const Index base_m = 64 * m_block_idx;
   1200   const Index base_n = 64 * n_block_idx;
   1201 
   1202   if (base_m + 63 < m_size) {
   1203     if (base_n + 63 < n_size) {
   1204       EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
   1205     } else {
   1206       EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
   1207     }
   1208   } else {
   1209     if (base_n + 63 < n_size) {
   1210       EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
   1211     } else {
   1212       EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
   1213     }
   1214   }
   1215 }
   1216 
   1217 
   1218 template<typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType>
   1219 struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, GpuDevice> :
   1220     public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, GpuDevice> > {
   1221 
   1222   typedef GpuDevice Device;
   1223 
   1224   typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
   1225   typedef TensorContractionEvaluatorBase<Self> Base;
   1226 
   1227   typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
   1228   typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
   1229   typedef typename XprType::Index Index;
   1230   typedef typename XprType::CoeffReturnType CoeffReturnType;
   1231   typedef typename PacketType<CoeffReturnType, GpuDevice>::type PacketReturnType;
   1232 
   1233   enum {
   1234     Layout = TensorEvaluator<LeftArgType, Device>::Layout,
   1235   };
   1236 
   1237   // Most of the code is assuming that both input tensors are ColMajor. If the
   1238   // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS:
   1239   // If we want to compute A * B = C, where A is LHS and B is RHS, the code
   1240   // will pretend B is LHS and A is RHS.
   1241   typedef typename internal::conditional<
   1242     static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
   1243   typedef typename internal::conditional<
   1244     static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;
   1245 
   1246   static const int LDims =
   1247       internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
   1248   static const int RDims =
   1249       internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
   1250   static const int ContractDims = internal::array_size<Indices>::value;
   1251 
   1252   typedef array<Index, LDims> left_dim_mapper_t;
   1253   typedef array<Index, RDims> right_dim_mapper_t;
   1254 
   1255   typedef array<Index, ContractDims> contract_t;
   1256   typedef array<Index, LDims - ContractDims> left_nocontract_t;
   1257   typedef array<Index, RDims - ContractDims> right_nocontract_t;
   1258 
   1259   static const int NumDims = LDims + RDims - 2 * ContractDims;
   1260 
   1261   typedef DSizes<Index, NumDims> Dimensions;
   1262 
   1263   // typedefs needed in evalTo
   1264   typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
   1265   typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;
   1266 
   1267   typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
   1268   typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
   1269 
   1270   typedef typename LeftEvaluator::Dimensions LeftDimensions;
   1271   typedef typename RightEvaluator::Dimensions RightDimensions;
   1272 
   1273   TensorEvaluator(const XprType& op, const Device& device) :
   1274       Base(op, device)
   1275   {
   1276     EIGEN_STATIC_ASSERT( (internal::is_same<OutputKernelType, const NoOpOutputKernel>::value),
   1277                           GPU_TENSOR_CONTRACTION_DOES_NOT_SUPPORT_OUTPUT_KERNELS);
   1278   }
   1279 
   1280   // We need to redefine this method to make nvcc happy
   1281   EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
   1282     this->m_leftImpl.evalSubExprsIfNeeded(NULL);
   1283     this->m_rightImpl.evalSubExprsIfNeeded(NULL);
   1284     if (data) {
   1285       evalTo(data);
   1286       return false;
   1287     } else {
   1288       this->m_result = static_cast<Scalar *>(this->m_device.allocate(this->dimensions().TotalSize() * sizeof(Scalar)));
   1289       evalTo(this->m_result);
   1290       return true;
   1291     }
   1292   }
   1293 
   1294   void evalTo(Scalar* buffer) const {
   1295     if (this->m_lhs_inner_dim_contiguous) {
   1296       if (this->m_rhs_inner_dim_contiguous) {
   1297         if (this->m_rhs_inner_dim_reordered) {
   1298           evalTyped<true, true, true, Unaligned>(buffer);
   1299         }
   1300         else {
   1301           evalTyped<true, true, false, Unaligned>(buffer);
   1302         }
   1303       }
   1304       else {
   1305        if (this->m_rhs_inner_dim_reordered) {
   1306           evalTyped<true, false, true, Unaligned>(buffer);
   1307         }
   1308         else {
   1309           evalTyped<true, false, false, Unaligned>(buffer);
   1310         }
   1311       }
   1312     }
   1313     else {
   1314       if (this->m_rhs_inner_dim_contiguous) {
   1315         if (this->m_rhs_inner_dim_reordered) {
   1316           evalTyped<false, true, true, Unaligned>(buffer);
   1317         }
   1318         else {
   1319           evalTyped<false, true, false, Unaligned>(buffer);
   1320         }
   1321       }
   1322       else {
   1323        if (this->m_rhs_inner_dim_reordered) {
   1324           evalTyped<false, false, true, Unaligned>(buffer);
   1325         }
   1326         else {
   1327           evalTyped<false, false, false, Unaligned>(buffer);
   1328         }
   1329       }
   1330     }
   1331   }
   1332 
   1333   template <typename LhsScalar, typename RhsScalar, typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper> struct LaunchKernels {
   1334     static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output, Index m, Index n, Index k, const GpuDevice& device) {
   1335     const Index m_blocks = (m + 63) / 64;
   1336     const Index n_blocks = (n + 63) / 64;
   1337     const dim3 num_blocks(m_blocks, n_blocks, 1);
   1338     const dim3 block_size(8, 8, 8);
   1339     LAUNCH_GPU_KERNEL((EigenContractionKernel<Scalar, Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
   1340     }
   1341   };
   1342 
   1343   template <typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper> struct LaunchKernels<float, float, Index, LhsMapper, RhsMapper, OutputMapper> {
   1344     static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output, Index m, Index n, Index k, const GpuDevice& device) {
   1345       if (m < 768 || n < 768) {
   1346         const Index m_blocks = (m + 63) / 64;
   1347         const Index n_blocks = (n + 63) / 64;
   1348         const dim3 num_blocks(m_blocks, n_blocks, 1);
   1349         const dim3 block_size(16, 16, 1);
   1350         LAUNCH_GPU_KERNEL((EigenFloatContractionKernel16x16<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
   1351       } else {
   1352         const Index m_blocks = (m + 127) / 128;
   1353         const Index n_blocks = (n + 63) / 64;
   1354         const dim3 num_blocks(m_blocks, n_blocks, 1);
   1355         const dim3 block_size(8, 32, 1);
   1356         LAUNCH_GPU_KERNEL((EigenFloatContractionKernel<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
   1357       }
   1358     }
   1359   };
   1360 
   1361   template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
   1362   void evalTyped(Scalar* buffer) const {
   1363     // columns in left side, rows in right side
   1364     const Index k = this->m_k_size;
   1365     EIGEN_UNUSED_VARIABLE(k)
   1366 
   1367     // rows in left side
   1368     const Index m = this->m_i_size;
   1369 
   1370     // columns in right side
   1371     const Index n = this->m_j_size;
   1372 
   1373     // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar)
   1374     this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));
   1375 
   1376     typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
   1377                                                    LeftEvaluator, left_nocontract_t,
   1378                                                    contract_t, 4,
   1379                                                    lhs_inner_dim_contiguous,
   1380                                                    false, Unaligned> LhsMapper;
   1381 
   1382     typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
   1383                                                    RightEvaluator, right_nocontract_t,
   1384                                                    contract_t, 4,
   1385                                                    rhs_inner_dim_contiguous,
   1386                                                    rhs_inner_dim_reordered, Unaligned> RhsMapper;
   1387 
   1388     typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
   1389 
   1390 
   1391     // initialize data mappers
   1392     LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
   1393                   this->m_left_contracting_strides, this->m_k_strides);
   1394 
   1395     RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
   1396                   this->m_right_contracting_strides, this->m_k_strides);
   1397 
   1398     OutputMapper output(buffer, m);
   1399 
   1400 #if defined(EIGEN_USE_HIP)
   1401     setGpuSharedMemConfig(hipSharedMemBankSizeEightByte);
   1402 #else
   1403     setGpuSharedMemConfig(cudaSharedMemBankSizeEightByte);
   1404 #endif
   1405 
   1406     LaunchKernels<LhsScalar, RhsScalar, Index, LhsMapper, RhsMapper, OutputMapper>::Run(lhs, rhs, output,  m, n, k, this->m_device);
   1407   }
   1408 };
   1409 
   1410 } // end namespace Eigen
   1411 
   1412 #endif // EIGEN_USE_GPU and EIGEN_GPUCC
   1413 #endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H