TensorContractionGpu.h (63402B)
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014-2015 Benoit Steiner <benoit.steiner.goog@gmail.com>
// Copyright (C) 2015 Navdeep Jaitly <ndjaitly@google.com>
// Copyright (C) 2014 Eric Martin <eric@ericmart.in>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H

#if defined(EIGEN_USE_GPU) && defined(EIGEN_GPUCC)

namespace Eigen {

// Computes one 64x64 tile of output = lhs * rhs for an arbitrary Scalar type.
//
// The tile covers output rows [64 * blockIdx.x, 64 * blockIdx.x + 64) and
// columns [64 * blockIdx.y, 64 * blockIdx.y + 64); the contraction dimension
// is consumed in chunks of 64.
//
// Launch assumptions (inferred from the indexing below, from the xor-1/2/4
// shuffle reduction, and from __launch_bounds__(512) on EigenContractionKernel;
// confirm against the launch site): blockDim == (8, 8, 8), so threadIdx.x,
// threadIdx.y and threadIdx.z are each in [0, 8), and the 8 threads that differ
// only in threadIdx.x are 8 consecutive lanes of one warp.
//
// lhs_shmem and rhs_shmem must each hold at least 72 * 64 Scalars
// (EigenContractionKernel allocates exactly that).
//
// needs_edge_check: when false, the caller guarantees the whole 64x64 output
// tile lies inside [0, m_size) x [0, n_size), so the M/N bounds checks compile
// away. Reads along the contraction dimension are always bounds-checked,
// because nothing guarantees k_size is a multiple of 64.
//
// Must be called by every thread of the block (it contains __syncthreads()).
template<typename Scalar, typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper, bool needs_edge_check>
__device__ EIGEN_STRONG_INLINE void
EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
                               const OutputMapper output, Scalar* lhs_shmem, Scalar* rhs_shmem,
                               const Index m_size, const Index n_size, const Index k_size) {

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  // prefetch registers
  Scalar lhs_pf0;
  Scalar lhs_pf1;
  Scalar lhs_pf2;
  Scalar lhs_pf3;
  Scalar lhs_pf4;
  Scalar lhs_pf5;
  Scalar lhs_pf6;
  Scalar lhs_pf7;

  Scalar rhs_pf0;
  Scalar rhs_pf1;
  Scalar rhs_pf2;
  Scalar rhs_pf3;
  Scalar rhs_pf4;
  Scalar rhs_pf5;
  Scalar rhs_pf6;
  Scalar rhs_pf7;

  // shared memory is formatted
  // (contract idx in block, nocontract idx in block, block idx)
  // where block idx is column major. This transposition limits the number of
  // bank conflicts when reading the LHS. The core idea is that since the contracting
  // index is shared by both sides, then the contracting index should be in threadIdx.x.

  // On the LHS, we pad each row inside of each block with an extra element. This makes
  // each block 8 rows of 9 elements, which is 72 elements. This gives no bank conflicts
  // on writes and very few 2-way conflicts on reads. There is an 8x8 grid of these blocks.

  // On the RHS we just add 8 padding elements to the end of each block. This gives no bank
  // conflicts on writes and also none on reads.

  // storage indices
  const Index lhs_store_idx_base = threadIdx.y * 72 + threadIdx.x * 9 + threadIdx.z;
  const Index rhs_store_idx_base = threadIdx.y * 72 + threadIdx.z * 8 + threadIdx.x;

  const Index lhs_store_idx_0 = lhs_store_idx_base + 576 * 0;
  const Index lhs_store_idx_1 = lhs_store_idx_base + 576 * 1;
  const Index lhs_store_idx_2 = lhs_store_idx_base + 576 * 2;
  const Index lhs_store_idx_3 = lhs_store_idx_base + 576 * 3;
  const Index lhs_store_idx_4 = lhs_store_idx_base + 576 * 4;
  const Index lhs_store_idx_5 = lhs_store_idx_base + 576 * 5;
  const Index lhs_store_idx_6 = lhs_store_idx_base + 576 * 6;
  const Index lhs_store_idx_7 = lhs_store_idx_base + 576 * 7;

  const Index rhs_store_idx_0 = rhs_store_idx_base + 576 * 0;
  const Index rhs_store_idx_1 = rhs_store_idx_base + 576 * 1;
  const Index rhs_store_idx_2 = rhs_store_idx_base + 576 * 2;
  const Index rhs_store_idx_3 = rhs_store_idx_base + 576 * 3;
  const Index rhs_store_idx_4 = rhs_store_idx_base + 576 * 4;
  const Index rhs_store_idx_5 = rhs_store_idx_base + 576 * 5;
  const Index rhs_store_idx_6 = rhs_store_idx_base + 576 * 6;
  const Index rhs_store_idx_7 = rhs_store_idx_base + 576 * 7;

  // in the loading code, the following variables are important:
  // threadIdx.x: the vertical position in an 8x8 block
  // threadIdx.y: the vertical index of the 8x8 block in the grid
  // threadIdx.z: the horizontal position in an 8x8 block
  // k: the horizontal index of the 8x8 block in the grid
  //
  // The k parameter is implicit (it was the loop counter for a loop that went
  // from 0 to <8, but now that loop is unrolled in the below code).

  const Index load_idx_vert = threadIdx.x + 8 * threadIdx.y;
  const Index lhs_vert = base_m + load_idx_vert;

  // Loads this thread's share of the next 64-wide K chunk of LHS and RHS into
  // the prefetch registers. Elements outside the matrices are left as zero
  // (not skipped) because every register is stored to shared memory and
  // multiplied below.
  // needs_edge_check only controls the M (lhs_vert) and N (rhs_horiz_*)
  // checks. The K checks (lhs_horiz_*, rhs_vert) are unconditional: the last
  // chunk can extend past k_size even when the output tile is fully in range.
  // (No // comments inside the macro: they would swallow the line splices.)
#define prefetchIntoRegisters(base_k) \
  { \
    lhs_pf0 = conv(0); \
    lhs_pf1 = conv(0); \
    lhs_pf2 = conv(0); \
    lhs_pf3 = conv(0); \
    lhs_pf4 = conv(0); \
    lhs_pf5 = conv(0); \
    lhs_pf6 = conv(0); \
    lhs_pf7 = conv(0); \
 \
    rhs_pf0 = conv(0); \
    rhs_pf1 = conv(0); \
    rhs_pf2 = conv(0); \
    rhs_pf3 = conv(0); \
    rhs_pf4 = conv(0); \
    rhs_pf5 = conv(0); \
    rhs_pf6 = conv(0); \
    rhs_pf7 = conv(0); \
 \
    if (!needs_edge_check || lhs_vert < m_size) { \
      const Index lhs_horiz_0 = base_k + threadIdx.z + 0 * 8; \
      const Index lhs_horiz_1 = base_k + threadIdx.z + 1 * 8; \
      const Index lhs_horiz_2 = base_k + threadIdx.z + 2 * 8; \
      const Index lhs_horiz_3 = base_k + threadIdx.z + 3 * 8; \
      const Index lhs_horiz_4 = base_k + threadIdx.z + 4 * 8; \
      const Index lhs_horiz_5 = base_k + threadIdx.z + 5 * 8; \
      const Index lhs_horiz_6 = base_k + threadIdx.z + 6 * 8; \
      const Index lhs_horiz_7 = base_k + threadIdx.z + 7 * 8; \
 \
      if (lhs_horiz_7 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
        lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \
        lhs_pf7 = lhs(lhs_vert, lhs_horiz_7); \
      } else if (lhs_horiz_6 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
        lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \
      } else if (lhs_horiz_5 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
      } else if (lhs_horiz_4 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
      } else if (lhs_horiz_3 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
      } else if (lhs_horiz_2 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
      } else if (lhs_horiz_1 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
      } else if (lhs_horiz_0 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
      } \
    } \
 \
    const Index rhs_vert = base_k + load_idx_vert; \
    if (rhs_vert < k_size) { \
      const Index rhs_horiz_0 = base_n + threadIdx.z + 0 * 8; \
      const Index rhs_horiz_1 = base_n + threadIdx.z + 1 * 8; \
      const Index rhs_horiz_2 = base_n + threadIdx.z + 2 * 8; \
      const Index rhs_horiz_3 = base_n + threadIdx.z + 3 * 8; \
      const Index rhs_horiz_4 = base_n + threadIdx.z + 4 * 8; \
      const Index rhs_horiz_5 = base_n + threadIdx.z + 5 * 8; \
      const Index rhs_horiz_6 = base_n + threadIdx.z + 6 * 8; \
      const Index rhs_horiz_7 = base_n + threadIdx.z + 7 * 8; \
 \
      if (!needs_edge_check || rhs_horiz_7 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
        rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \
        rhs_pf7 = rhs(rhs_vert, rhs_horiz_7); \
      } else if (rhs_horiz_6 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
        rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \
      } else if (rhs_horiz_5 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
      } else if (rhs_horiz_4 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
      } else if (rhs_horiz_3 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
      } else if (rhs_horiz_2 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
      } else if (rhs_horiz_1 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
      } else if (rhs_horiz_0 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
      } \
    } \
  }

  // Stores the prefetch registers into the padded shared-memory layouts
  // described above.
#define writeRegToShmem(_) \
  lhs_shmem[lhs_store_idx_0] = lhs_pf0; \
  rhs_shmem[rhs_store_idx_0] = rhs_pf0; \
 \
  lhs_shmem[lhs_store_idx_1] = lhs_pf1; \
  rhs_shmem[rhs_store_idx_1] = rhs_pf1; \
 \
  lhs_shmem[lhs_store_idx_2] = lhs_pf2; \
  rhs_shmem[rhs_store_idx_2] = rhs_pf2; \
 \
  lhs_shmem[lhs_store_idx_3] = lhs_pf3; \
  rhs_shmem[rhs_store_idx_3] = rhs_pf3; \
 \
  lhs_shmem[lhs_store_idx_4] = lhs_pf4; \
  rhs_shmem[rhs_store_idx_4] = rhs_pf4; \
 \
  lhs_shmem[lhs_store_idx_5] = lhs_pf5; \
  rhs_shmem[rhs_store_idx_5] = rhs_pf5; \
 \
  lhs_shmem[lhs_store_idx_6] = lhs_pf6; \
  rhs_shmem[rhs_store_idx_6] = rhs_pf6; \
 \
  lhs_shmem[lhs_store_idx_7] = lhs_pf7; \
  rhs_shmem[rhs_store_idx_7] = rhs_pf7;

  // Declare and zero-initialize the 8x8 result block. res(i, j) names the
  // scalar _res_ij; individual variables are used instead of an array,
  // presumably so the compiler keeps all 64 accumulators in registers.
#define res(i, j) _res_##i##j
#define initResultRow(i) \
  Scalar res(i, 0) = conv(0); \
  Scalar res(i, 1) = conv(0); \
  Scalar res(i, 2) = conv(0); \
  Scalar res(i, 3) = conv(0); \
  Scalar res(i, 4) = conv(0); \
  Scalar res(i, 5) = conv(0); \
  Scalar res(i, 6) = conv(0); \
  Scalar res(i, 7) = conv(0);

  internal::scalar_cast_op<int, Scalar> conv;
  initResultRow(0);
  initResultRow(1);
  initResultRow(2);
  initResultRow(3);
  initResultRow(4);
  initResultRow(5);
  initResultRow(6);
  initResultRow(7);
#undef initResultRow

  for (Index base_k = 0; base_k < k_size; base_k += 64) {
    // wait for previous iteration to finish with shmem. Despite common sense,
    // the code is a bit faster with this here than at bottom of loop
    __syncthreads();

    prefetchIntoRegisters(base_k);
    writeRegToShmem();

#undef prefetchIntoRegisters
#undef writeRegToShmem

    // wait for shared mem packing to be done before starting computation
    __syncthreads();

    // compute 8x8 matrix product by outer product. This involves packing one column
    // of LHS and one row of RHS into registers (takes 16 registers).

#define lcol(i) _lcol##i
    Scalar lcol(0);
    Scalar lcol(1);
    Scalar lcol(2);
    Scalar lcol(3);
    Scalar lcol(4);
    Scalar lcol(5);
    Scalar lcol(6);
    Scalar lcol(7);

#define rrow(j) _rrow##j
    Scalar rrow(0);
    Scalar rrow(1);
    Scalar rrow(2);
    Scalar rrow(3);
    Scalar rrow(4);
    Scalar rrow(5);
    Scalar rrow(6);
    Scalar rrow(7);

    // Now x corresponds to k, y to m, and z to n
    const Scalar* lhs_block = &lhs_shmem[threadIdx.x + 9 * threadIdx.y];
    const Scalar* rhs_block = &rhs_shmem[threadIdx.x + 8 * threadIdx.z];

#define lhs_element(i, j) lhs_block[72 * ((i) + 8 * (j))]
#define rhs_element(i, j) rhs_block[72 * ((i) + 8 * (j))]

#define loadData(i, j) \
    lcol(0) = lhs_element(0, j); \
    rrow(0) = rhs_element(i, 0); \
    lcol(1) = lhs_element(1, j); \
    rrow(1) = rhs_element(i, 1); \
    lcol(2) = lhs_element(2, j); \
    rrow(2) = rhs_element(i, 2); \
    lcol(3) = lhs_element(3, j); \
    rrow(3) = rhs_element(i, 3); \
    lcol(4) = lhs_element(4, j); \
    rrow(4) = rhs_element(i, 4); \
    lcol(5) = lhs_element(5, j); \
    rrow(5) = rhs_element(i, 5); \
    lcol(6) = lhs_element(6, j); \
    rrow(6) = rhs_element(i, 6); \
    lcol(7) = lhs_element(7, j); \
    rrow(7) = rhs_element(i, 7);

#define computeCol(j) \
    res(0, j) += lcol(0) * rrow(j); \
    res(1, j) += lcol(1) * rrow(j); \
    res(2, j) += lcol(2) * rrow(j); \
    res(3, j) += lcol(3) * rrow(j); \
    res(4, j) += lcol(4) * rrow(j); \
    res(5, j) += lcol(5) * rrow(j); \
    res(6, j) += lcol(6) * rrow(j); \
    res(7, j) += lcol(7) * rrow(j);

#define computePass(i) \
    loadData(i, i); \
    computeCol(0); \
    computeCol(1); \
    computeCol(2); \
    computeCol(3); \
    computeCol(4); \
    computeCol(5); \
    computeCol(6); \
    computeCol(7);

    computePass(0);
    computePass(1);
    computePass(2);
    computePass(3);
    computePass(4);
    computePass(5);
    computePass(6);
    computePass(7);

#undef lcol
#undef rrow
#undef lhs_element
#undef rhs_element
#undef loadData
#undef computeCol
#undef computePass
  } // end loop over k

  // We've now iterated over all of the large (i.e. width 64) k blocks and
  // accumulated results in registers. Thread (x, y, z) holds in res(i, j) the
  // partial sum for output element (base_m + y + 8 * i, base_n + z + 8 * j)
  // over the k indices congruent to x modulo 8. To get the final output we
  // sum across the 8 threads that differ only in x. With blockDim.x == 8 these
  // are 8 consecutive lanes, so xor-shuffles with lane masks 1, 2 and 4 do it.
#if defined(EIGEN_HIPCC) || (defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER < 90000)
#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor(res(i, j), mask)
#else
#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor_sync(0xFFFFFFFF, res(i, j), mask)
#endif

#define reduceRow(i, mask) \
  shuffleInc(i, 0, mask); \
  shuffleInc(i, 1, mask); \
  shuffleInc(i, 2, mask); \
  shuffleInc(i, 3, mask); \
  shuffleInc(i, 4, mask); \
  shuffleInc(i, 5, mask); \
  shuffleInc(i, 6, mask); \
  shuffleInc(i, 7, mask);

#define reduceMatrix(mask) \
  reduceRow(0, mask); \
  reduceRow(1, mask); \
  reduceRow(2, mask); \
  reduceRow(3, mask); \
  reduceRow(4, mask); \
  reduceRow(5, mask); \
  reduceRow(6, mask); \
  reduceRow(7, mask);

  // actually perform the reduction, now each thread of index (_, y, z)
  // contains the correct values in its registers that belong in the output
  // block
  reduceMatrix(1);
  reduceMatrix(2);
  reduceMatrix(4);

#undef shuffleInc
#undef reduceRow
#undef reduceMatrix

  // now we need to copy the 64 values into main memory. We can't split work
  // among threads because all variables are in registers. There are 2 ways
  // to do this:
  // (1) have 1 thread do 64 writes from registers into global memory
  // (2) have 1 thread do 64 writes into shared memory, and then 8 threads
  //     each do 8 writes into global memory. We can just overwrite the shared
  //     memory from the problem we just solved.
  // (2) is slightly faster than (1) due to less branching and more ILP

  // TODO: won't yield much gain, but could just use currently unused shared mem
  //       and then we won't have to sync
  // wait for shared mem to be out of use
  __syncthreads();

#define writeResultShmem(i, j) \
  lhs_shmem[i + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j] = res(i, j);

#define writeRow(i) \
  writeResultShmem(i, 0); \
  writeResultShmem(i, 1); \
  writeResultShmem(i, 2); \
  writeResultShmem(i, 3); \
  writeResultShmem(i, 4); \
  writeResultShmem(i, 5); \
  writeResultShmem(i, 6); \
  writeResultShmem(i, 7);

  if (threadIdx.x == 0) {
    writeRow(0);
    writeRow(1);
    writeRow(2);
    writeRow(3);
    writeRow(4);
    writeRow(5);
    writeRow(6);
    writeRow(7);
  }
#undef writeResultShmem
#undef writeRow

  // The values just written by the threadIdx.x == 0 thread are read below by
  // its threadIdx.x == 1..7 siblings. Those threads only share a warp, and
  // after the divergent branch above the warp is not guaranteed to have
  // reconverged (independent thread scheduling, sm_70+), so the reads need an
  // explicit barrier. Every thread of the block reaches this line, so a
  // block-wide barrier is safe and portable to HIP.
  __syncthreads();

  const int max_i_write = numext::mini((int)((m_size - base_m - threadIdx.y + 7) / 8), 8);
  const int max_j_write = numext::mini((int)((n_size - base_n - threadIdx.z + 7) / 8), 8);

  if (threadIdx.x < max_i_write) {
    if (max_j_write == 8) {
      // TODO: can i trade bank conflicts for coalesced writes?
      Scalar val0 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 0];
      Scalar val1 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 1];
      Scalar val2 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 2];
      Scalar val3 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 3];
      Scalar val4 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 4];
      Scalar val5 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 5];
      Scalar val6 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 6];
      Scalar val7 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 7];

      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 0) = val0;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 1) = val1;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 2) = val2;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 3) = val3;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 4) = val4;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 5) = val5;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 6) = val6;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 7) = val7;
    } else {
#pragma unroll 7
      for (int j = 0; j < max_j_write; j++) {
        Scalar val = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j];
        output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * j) = val;
      }
    }
  }
#undef res
}


// Generic-Scalar contraction kernel: output = lhs * rhs, one 64x64 output tile
// per thread block (see EigenContractionKernelInternal for the expected block
// shape). Blocks whose tile lies entirely inside the m_size x n_size output
// use the variant without M/N bounds checks. The branch depends only on
// blockIdx, so all threads of a block take the same path, which is required
// because the callee uses __syncthreads().
template<typename Scalar, typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
#if defined(EIGEN_HIPCC)
__launch_bounds__(512, 1)
#else
__launch_bounds__(512)
#endif
EigenContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output,
                       const Index m_size, const Index n_size, const Index k_size) {
  __shared__ Scalar lhs_shmem[72 * 64];
  __shared__ Scalar rhs_shmem[72 * 64];

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  if (base_m + 63 < m_size && base_n + 63 < n_size) {
    EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
  } else {
    EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
  }
}


// Float-only contraction kernel body: computes the 64x64 output tile starting
// at (base_m, base_n), consuming the contraction dimension in chunks of 16.
// Each thread accumulates a 4x4 sub-block: rows lhs_vert .. lhs_vert + 3
// (lhs_vert = base_m + 4 * threadIdx.x) and columns horiz_base .. horiz_base + 3
// (horiz_base = base_n + 4 * threadIdx.y).
//
// NOTE(review): the shared-memory indexing appears to assume blockDim ==
// (16, 16). Under that shape every lhs_shmem2 row 0..31 and rhs_shmem2 row
// 0..63 read in the multiply loop is written first. Confirm against the
// caller. lhs_shmem2 needs at least 32 rows and rhs_shmem2 at least 64 rows.
//
// CHECK_LHS_BOUNDARY / CHECK_RHS_BOUNDARY enable the M / N bounds checks.
// K is always bounds-checked. Contains __syncthreads() and full-warp shuffles,
// so every thread of the block must call it.
template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY,
         bool CHECK_RHS_BOUNDARY>
__device__ __forceinline__ void
EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output, float2 lhs_shmem2[][16],
                       float2 rhs_shmem2[][8], const Index m_size,
                       const Index n_size, const Index k_size,
                       const Index base_m, const Index base_n) {

  // prefetch registers
  float4 lhs_pf0, rhs_pf0;

  float4 results[4];
  for (int i=0; i < 4; i++) {
    results[i].x = results[i].y = results[i].z = results[i].w = 0;
  }

  // Loads up to 4 consecutive LHS rows (row .. row + 3) of column col into
  // reg, leaving out-of-range lanes untouched (the caller zeroes reg first).
#define prefetch_lhs(reg, row, col) \
    if (!CHECK_LHS_BOUNDARY) { \
      if (col < k_size) { \
        reg =lhs.template loadPacket<float4,Unaligned>(row, col); \
      } \
    } else { \
      if (col < k_size) { \
        if (row + 3 < m_size) { \
          reg =lhs.template loadPacket<float4,Unaligned>(row, col); \
        } else if (row + 2 < m_size) { \
          reg.x =lhs(row + 0, col); \
          reg.y =lhs(row + 1, col); \
          reg.z =lhs(row + 2, col); \
        } else if (row + 1 < m_size) { \
          reg.x =lhs(row + 0, col); \
          reg.y =lhs(row + 1, col); \
        } else if (row < m_size) { \
          reg.x =lhs(row + 0, col); \
        } \
      } \
    }

  Index lhs_vert = base_m+threadIdx.x*4;

  for (Index k = 0; k < k_size; k += 16) {

    lhs_pf0 = internal::pset1<float4>(0);
    rhs_pf0 = internal::pset1<float4>(0);

    Index lhs_horiz = threadIdx.y+k;
    prefetch_lhs(lhs_pf0, lhs_vert, lhs_horiz)

    Index rhs_vert = k+(threadIdx.x%4)*4;
    Index rhs_horiz0 = (threadIdx.x>>2)+threadIdx.y*4+base_n;

    if (!CHECK_RHS_BOUNDARY) {
      if ((rhs_vert + 3) < k_size) {
        // N is in range by contract; all 4 K rows are in range
        rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
      } else if (rhs_vert + 2 < k_size) {
        // K tail: load only the rows that exist
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
      } else if (rhs_vert + 1 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
      } else if (rhs_vert < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
      }
    } else {
      if (rhs_horiz0 < n_size) {
        if ((rhs_vert + 3) < k_size) {
          rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0);
        } else if ((rhs_vert + 2) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
        } else if ((rhs_vert + 1) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        } else if (rhs_vert < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        }
      }
    }
    float x1, x2 ;
    // the following can be a bitwise operation..... some day.
    if((threadIdx.x%8) < 4) {
      x1 = rhs_pf0.y;
      x2 = rhs_pf0.w;
    } else {
      x1 = rhs_pf0.x;
      x2 = rhs_pf0.z;
    }
#if defined(EIGEN_HIPCC) || (defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER < 90000)
    x1 = __shfl_xor(x1, 4);
    x2 = __shfl_xor(x2, 4);
#else
    x1 = __shfl_xor_sync(0xFFFFFFFF, x1, 4);
    x2 = __shfl_xor_sync(0xFFFFFFFF, x2, 4);
#endif
    if((threadIdx.x%8) < 4) {
      rhs_pf0.y = x1;
      rhs_pf0.w = x2;
    } else {
      rhs_pf0.x = x1;
      rhs_pf0.z = x2;
    }

    // We have 64 features.
    // Row 0 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 0, 1.
    // Row 1 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 2, 3.
    // ...
    // Row 31 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 62, 63
    // Row 32 -> times (2, 6, 10, 14, 3, 7, 11, 15) for features 0, 1
    // ...
    rhs_shmem2[(threadIdx.x>>3)+ threadIdx.y*2][threadIdx.x%8] = make_float2(rhs_pf0.x, rhs_pf0.y);
    rhs_shmem2[(threadIdx.x>>3)+ threadIdx.y*2+32][threadIdx.x%8] = make_float2(rhs_pf0.z, rhs_pf0.w);

    // Row 0 (time 0) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), ..  (60, 61)
    // Row 1 (time 1) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), ..  (60, 61)
    // ...
    // Row 15 (time 15) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), ..  (60, 61)
    // Row 16 (time 0) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), ..  (62, 63)
    // ...

    lhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(lhs_pf0.x, lhs_pf0.y);
    lhs_shmem2[threadIdx.y+16][threadIdx.x] = make_float2(lhs_pf0.z, lhs_pf0.w);


#define add_vals(fl1, fl2, fr1, fr2) \
    results[0].x += fl1.x * fr1.x; \
    results[0].y += fl1.y * fr1.x; \
    results[0].z += fl2.x * fr1.x; \
    results[0].w += fl2.y * fr1.x; \
 \
    results[1].x += fl1.x * fr1.y; \
    results[1].y += fl1.y * fr1.y; \
    results[1].z += fl2.x * fr1.y; \
    results[1].w += fl2.y * fr1.y; \
 \
    results[2].x += fl1.x * fr2.x; \
    results[2].y += fl1.y * fr2.x; \
    results[2].z += fl2.x * fr2.x; \
    results[2].w += fl2.y * fr2.x; \
 \
    results[3].x += fl1.x * fr2.y; \
    results[3].y += fl1.y * fr2.y; \
    results[3].z += fl2.x * fr2.y; \
    results[3].w += fl2.y * fr2.y;

    __syncthreads();

    // Do the multiplies.
    #pragma unroll
    for (int koff = 0; koff < 16; koff ++) {
      // 32 x threads.
      float2 fl1 = lhs_shmem2[koff][threadIdx.x];
      float2 fl2 = lhs_shmem2[koff + 16][threadIdx.x];

      int start_feature = threadIdx.y * 4;
      float2 fr1 = rhs_shmem2[(start_feature>>1) + 32*((koff%4)/2)][koff/4 + (koff%2)*4];
      float2 fr2 = rhs_shmem2[(start_feature>>1) + 1 + 32*((koff%4)/2)][koff/4 + (koff%2)*4];

      add_vals(fl1, fl2, fr1, fr2)
    }
    __syncthreads();
  }

#undef prefetch_lhs
#undef add_vals

  Index horiz_base = threadIdx.y*4+base_n;
  if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
    for (int i = 0; i < 4; i++) {
      output(lhs_vert, horiz_base + i) = results[i].x;
      output(lhs_vert + 1, horiz_base + i) = results[i].y;
      output(lhs_vert + 2, horiz_base + i) = results[i].z;
      output(lhs_vert + 3, horiz_base + i) = results[i].w;
    }
  } else if (!CHECK_RHS_BOUNDARY) {
    // CHECK LHS
    if (lhs_vert + 3 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    } else if (lhs_vert + 2 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
      }
    } else if (lhs_vert + 1 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
      }
    } else if (lhs_vert < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
      }
    }
  } else if (!CHECK_LHS_BOUNDARY) {
    // CHECK RHS
    for (int i = 0; i < 4; i++) {
      if (horiz_base+i < n_size) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  } else {
    // CHECK both boundaries.
    for (int i = 0; i < 4; i++) {
      if (horiz_base+i < n_size) {
        if (lhs_vert < m_size)
          output(lhs_vert, horiz_base + i) = results[i].x;
        if (lhs_vert + 1 < m_size)
          output(lhs_vert + 1, horiz_base + i) = results[i].y;
        if (lhs_vert + 2 < m_size)
          output(lhs_vert + 2, horiz_base + i) = results[i].z;
        if (lhs_vert + 3 < m_size)
          output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  }
}

// NOTE(review): the line below starts EigenFloatContractionKernelInternal, which continues past this edit; it is left exactly as found (still in collapsed-listing form).
769 770 771 template<typename Index, typename LhsMapper, 772 typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY, 773 bool CHECK_RHS_BOUNDARY> 774 __device__ __forceinline__ void 775 EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs, 776 const OutputMapper output, float2 lhs_shmem2[][32], 777 float2 rhs_shmem2[][8], const Index m_size, 778 const Index n_size, const Index k_size, 779 const Index base_m, const Index base_n) { 780 781 // prefetch registers 782 float4 lhs_pf0, lhs_pf1, lhs_pf2, lhs_pf3; 783 float4 rhs_pf0, rhs_pf1; 784 785 float4 results[8]; 786 for (int i=0; i < 8; i++) { 787 results[i].x = results[i].y = results[i].z = results[i].w = 0; 788 } 789 790 Index lhs_vert = base_m+threadIdx.x*4+(threadIdx.y%4)*32; 791 for (Index k = 0; k < k_size; k += 32) { 792 lhs_pf0 = internal::pset1<float4>(0); 793 lhs_pf1 = internal::pset1<float4>(0); 794 lhs_pf2 = internal::pset1<float4>(0); 795 lhs_pf3 = internal::pset1<float4>(0); 796 797 rhs_pf0 = internal::pset1<float4>(0); 798 rhs_pf1 = internal::pset1<float4>(0); 799 800 if (!CHECK_LHS_BOUNDARY) { 801 if ((threadIdx.y/4+k+24) < k_size) { 802 lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k)); 803 lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8)); 804 lhs_pf2 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16)); 805 lhs_pf3 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+24)); 806 } else if ((threadIdx.y/4+k+16) < k_size) { 807 lhs_pf0
=lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k)); 808 lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8)); 809 lhs_pf2 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16)); 810 } else if ((threadIdx.y/4+k+8) < k_size) { 811 lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k)); 812 lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8)); 813 } else if ((threadIdx.y/4+k) < k_size) { 814 lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k)); 815 } 816 } else { 817 // just CHECK_LHS_BOUNDARY 818 if (lhs_vert + 3 < m_size) { 819 if ((threadIdx.y/4+k+24) < k_size) { 820 lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k)); 821 lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8)); 822 lhs_pf2 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16)); 823 lhs_pf3 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+24)); 824 } else if ((threadIdx.y/4+k+16) < k_size) { 825 lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k)); 826 lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8)); 827 lhs_pf2 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+16)); 828 } else if ((threadIdx.y/4+k+8) < k_size) { 829 lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k)); 830 lhs_pf1 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k+8)); 831 } else if ((threadIdx.y/4+k) < k_size) { 832 lhs_pf0 =lhs.template loadPacket<float4,Unaligned>(lhs_vert, (threadIdx.y/4+k)); 833 } 834 } else if (lhs_vert + 2 < m_size) { 835 if ((threadIdx.y/4+k+24) < k_size) { 836 lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k)); 837 lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k)); 838 lhs_pf0.z =lhs(lhs_vert + 2, (threadIdx.y/4+k)); 839 lhs_pf1.x =lhs(lhs_vert 
+ 0, (threadIdx.y/4+k+8)); 840 lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8)); 841 lhs_pf1.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+8)); 842 lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16)); 843 lhs_pf2.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+16)); 844 lhs_pf2.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+16)); 845 lhs_pf3.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+24)); 846 lhs_pf3.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+24)); 847 lhs_pf3.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+24)); 848 } else if ((threadIdx.y/4+k+16) < k_size) { 849 lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k)); 850 lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k)); 851 lhs_pf0.z =lhs(lhs_vert + 2, (threadIdx.y/4+k)); 852 lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8)); 853 lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8)); 854 lhs_pf1.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+8)); 855 lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16)); 856 lhs_pf2.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+16)); 857 lhs_pf2.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+16)); 858 } else if ((threadIdx.y/4+k+8) < k_size) { 859 lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k)); 860 lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k)); 861 lhs_pf0.z =lhs(lhs_vert + 2, (threadIdx.y/4+k)); 862 lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8)); 863 lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8)); 864 lhs_pf1.z =lhs(lhs_vert + 2, (threadIdx.y/4+k+8)); 865 } else if ((threadIdx.y/4+k) < k_size) { 866 lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k)); 867 lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k)); 868 lhs_pf0.z =lhs(lhs_vert + 2, (threadIdx.y/4+k)); 869 } 870 } else if (lhs_vert + 1 < m_size) { 871 if ((threadIdx.y/4+k+24) < k_size) { 872 lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k)); 873 lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k)); 874 lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8)); 875 lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8)); 876 lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16)); 877 lhs_pf2.y =lhs(lhs_vert + 1, 
(threadIdx.y/4+k+16)); 878 lhs_pf3.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+24)); 879 lhs_pf3.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+24)); 880 } else if ((threadIdx.y/4+k+16) < k_size) { 881 lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k)); 882 lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k)); 883 lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8)); 884 lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8)); 885 lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16)); 886 lhs_pf2.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+16)); 887 } else if ((threadIdx.y/4+k+8) < k_size) { 888 lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k)); 889 lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k)); 890 lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8)); 891 lhs_pf1.y =lhs(lhs_vert + 1, (threadIdx.y/4+k+8)); 892 } else if ((threadIdx.y/4+k) < k_size) { 893 lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k)); 894 lhs_pf0.y =lhs(lhs_vert + 1, (threadIdx.y/4+k)); 895 } 896 } else if (lhs_vert < m_size) { 897 if ((threadIdx.y/4+k+24) < k_size) { 898 lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k)); 899 lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8)); 900 lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16)); 901 lhs_pf3.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+24)); 902 } else if ((threadIdx.y/4+k+16) < k_size) { 903 lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k)); 904 lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8)); 905 lhs_pf2.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+16)); 906 } else if ((threadIdx.y/4+k+8) < k_size) { 907 lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k)); 908 lhs_pf1.x =lhs(lhs_vert + 0, (threadIdx.y/4+k+8)); 909 } else if ((threadIdx.y/4+k) < k_size) { 910 lhs_pf0.x =lhs(lhs_vert + 0, (threadIdx.y/4+k)); 911 } 912 } 913 } 914 __syncthreads(); 915 Index rhs_vert = k+threadIdx.x*4; 916 Index rhs_horiz0 = threadIdx.y*2+base_n; 917 Index rhs_horiz1 = threadIdx.y*2+1+base_n; 918 if (!CHECK_RHS_BOUNDARY) { 919 if ((rhs_vert + 3) < k_size) { 920 // just CHECK_RHS_BOUNDARY 921 rhs_pf0 = rhs.template 
loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0); 922 rhs_pf1 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz1); 923 } else if (rhs_vert + 2 < k_size) { 924 // just CHECK_RHS_BOUNDARY 925 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0); 926 rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0); 927 rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0); 928 rhs_pf1.x = rhs(rhs_vert, rhs_horiz1); 929 rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1); 930 rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1); 931 } else if (rhs_vert + 1 < k_size) { 932 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0); 933 rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0); 934 rhs_pf1.x = rhs(rhs_vert, rhs_horiz1); 935 rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1); 936 } else if (rhs_vert < k_size) { 937 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0); 938 rhs_pf1.x = rhs(rhs_vert, rhs_horiz1); 939 } 940 } else { 941 if (rhs_horiz1 < n_size) { 942 if ((rhs_vert + 3) < k_size) { 943 // just CHECK_RHS_BOUNDARY 944 rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0); 945 rhs_pf1 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz1); 946 } else if (rhs_vert + 2 < k_size) { 947 // just CHECK_RHS_BOUNDARY 948 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0); 949 rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0); 950 rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0); 951 rhs_pf1.x = rhs(rhs_vert, rhs_horiz1); 952 rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1); 953 rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1); 954 } else if (k+threadIdx.x*4 + 1 < k_size) { 955 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0); 956 rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0); 957 rhs_pf1.x = rhs(rhs_vert, rhs_horiz1); 958 rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1); 959 } else if (k+threadIdx.x*4 < k_size) { 960 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0); 961 rhs_pf1.x = rhs(rhs_vert, rhs_horiz1); 962 } 963 } else if (rhs_horiz0 < n_size) { 964 if ((rhs_vert + 3) < k_size) { 965 // just CHECK_RHS_BOUNDARY 966 rhs_pf0 = rhs.template loadPacket<float4,Unaligned>(rhs_vert, rhs_horiz0); 967 } else if 
((rhs_vert + 2) < k_size) { 968 // just CHECK_RHS_BOUNDARY 969 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0); 970 rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0); 971 rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0); 972 } else if ((rhs_vert + 1) < k_size) { 973 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0); 974 rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0); 975 } else if (rhs_vert < k_size) { 976 rhs_pf0.x = rhs(rhs_vert, rhs_horiz0); 977 } 978 } 979 } 980 __syncthreads(); 981 // Loaded. Do computation 982 // Row 0 -> times (0, 4, 8, .. 28) for features 0, 1. 983 // Row 1 -> times (0, 4, 8, .. 28) for features 2, 3. 984 // .. 985 // Row 31 -> times (0, 4, 8, .. 28) for features 62, 63 986 rhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(rhs_pf0.x, rhs_pf1.x); 987 // Row 32 -> times (1, 5, 9, .. 29) for features 0, 1. 988 // Row 33 -> times (1, 5, 9, .. 29) for features 2, 3. 989 // .. 990 rhs_shmem2[threadIdx.y+32][threadIdx.x] = make_float2(rhs_pf0.y, rhs_pf1.y); 991 // Row 64 -> times (2, 6, 10, .. 30) for features 0, 1. 992 // Row 65 -> times (2, 6, 10, .. 30) for features 2, 3. 993 rhs_shmem2[threadIdx.y+64][threadIdx.x] = make_float2(rhs_pf0.z, rhs_pf1.z); 994 // Row 96 -> times (3, 7, 11, .. 31) for features 0, 1. 995 // Row 97 -> times (3, 7, 11, .. 31) for features 2, 3. 996 rhs_shmem2[threadIdx.y+96][threadIdx.x] = make_float2(rhs_pf0.w, rhs_pf1.w); 997 998 // LHS. 999 // Row 0 (time 0) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61) .. (124, 125) 1000 // Row 1 (time 1) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61) .. (124, 125) 1001 // ... 1002 // Row 8 (time 0) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), .. (62, 63) .. (126, 127) 1003 // Row 15 (time 7) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), .. (62, 63) .. 
(126, 127) 1004 1005 1006 #define add_vals(a_feat1, a_feat2, f1, f2, f3, f4)\ 1007 results[0].x += a_feat1.x * f1.x;\ 1008 results[1].x += a_feat1.x * f1.y;\ 1009 results[2].x += a_feat1.x * f2.x;\ 1010 results[3].x += a_feat1.x * f2.y;\ 1011 results[4].x += a_feat1.x * f3.x;\ 1012 results[5].x += a_feat1.x * f3.y;\ 1013 results[6].x += a_feat1.x * f4.x;\ 1014 results[7].x += a_feat1.x * f4.y;\ 1015 \ 1016 results[0].y += a_feat1.y * f1.x;\ 1017 results[1].y += a_feat1.y * f1.y;\ 1018 results[2].y += a_feat1.y * f2.x;\ 1019 results[3].y += a_feat1.y * f2.y;\ 1020 results[4].y += a_feat1.y * f3.x;\ 1021 results[5].y += a_feat1.y * f3.y;\ 1022 results[6].y += a_feat1.y * f4.x;\ 1023 results[7].y += a_feat1.y * f4.y;\ 1024 \ 1025 results[0].z += a_feat2.x * f1.x;\ 1026 results[1].z += a_feat2.x * f1.y;\ 1027 results[2].z += a_feat2.x * f2.x;\ 1028 results[3].z += a_feat2.x * f2.y;\ 1029 results[4].z += a_feat2.x * f3.x;\ 1030 results[5].z += a_feat2.x * f3.y;\ 1031 results[6].z += a_feat2.x * f4.x;\ 1032 results[7].z += a_feat2.x * f4.y;\ 1033 \ 1034 results[0].w += a_feat2.y * f1.x;\ 1035 results[1].w += a_feat2.y * f1.y;\ 1036 results[2].w += a_feat2.y * f2.x;\ 1037 results[3].w += a_feat2.y * f2.y;\ 1038 results[4].w += a_feat2.y * f3.x;\ 1039 results[5].w += a_feat2.y * f3.y;\ 1040 results[6].w += a_feat2.y * f4.x;\ 1041 results[7].w += a_feat2.y * f4.y;\ 1042 1043 lhs_shmem2[threadIdx.y/4][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf0.x, lhs_pf0.y); 1044 lhs_shmem2[threadIdx.y/4+8][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf1.x, lhs_pf1.y); 1045 lhs_shmem2[threadIdx.y/4+16][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf2.x, lhs_pf2.y); 1046 lhs_shmem2[threadIdx.y/4+24][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf3.x, lhs_pf3.y); 1047 1048 lhs_shmem2[threadIdx.y/4 + 32][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf0.z, lhs_pf0.w); 1049 lhs_shmem2[threadIdx.y/4 + 40][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf1.z, 
lhs_pf1.w); 1050 lhs_shmem2[threadIdx.y/4 + 48][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf2.z, lhs_pf2.w); 1051 lhs_shmem2[threadIdx.y/4 + 56][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf3.z, lhs_pf3.w); 1052 1053 __syncthreads(); 1054 1055 // Do the multiplies. 1056 #pragma unroll 1057 for (int koff = 0; koff < 32; koff ++) { 1058 float2 a3 = lhs_shmem2[koff][threadIdx.x + (threadIdx.y % 4) * 8]; 1059 float2 a4 = lhs_shmem2[koff + 32][threadIdx.x + (threadIdx.y % 4) * 8]; 1060 1061 // first feature is at (threadIdx.y/4) * 8 last is at start + 8. 1062 int start_feature = (threadIdx.y / 4) * 8; 1063 1064 float2 br1 = rhs_shmem2[start_feature/2 + (koff % 4) * 32][koff/4]; 1065 float2 br2 = rhs_shmem2[start_feature/2 + 1 + (koff % 4) * 32][koff/4]; 1066 float2 br3 = rhs_shmem2[start_feature/2 + 2 + (koff % 4) * 32][koff/4]; 1067 float2 br4 = rhs_shmem2[start_feature/2 + 3 + (koff % 4) * 32][koff/4]; 1068 1069 add_vals(a3, a4, br1, br2, br3, br4) 1070 } 1071 __syncthreads(); 1072 } // end loop over k 1073 1074 __syncthreads(); 1075 Index horiz_base = (threadIdx.y/4)*8+base_n; 1076 if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) { 1077 for (int i = 0; i < 8; i++) { 1078 output(lhs_vert, horiz_base + i) = results[i].x; 1079 output(lhs_vert + 1, horiz_base + i) = results[i].y; 1080 output(lhs_vert + 2, horiz_base + i) = results[i].z; 1081 output(lhs_vert + 3, horiz_base + i) = results[i].w; 1082 } 1083 } else if (!CHECK_RHS_BOUNDARY) { 1084 if (lhs_vert + 3 < m_size) { 1085 for (int i = 0; i < 8; i++) { 1086 output(lhs_vert, horiz_base + i) = results[i].x; 1087 output(lhs_vert + 1, horiz_base + i) = results[i].y; 1088 output(lhs_vert + 2, horiz_base + i) = results[i].z; 1089 output(lhs_vert + 3, horiz_base + i) = results[i].w; 1090 } 1091 } else if (lhs_vert + 2 < m_size) { 1092 for (int i = 0; i < 8; i++) { 1093 output(lhs_vert, horiz_base + i) = results[i].x; 1094 output(lhs_vert + 1, horiz_base + i) = results[i].y; 1095 output(lhs_vert + 2, horiz_base + 
i) = results[i].z;
      }
    } else if (lhs_vert + 1 < m_size) {
      // Only rows lhs_vert and lhs_vert + 1 lie below m_size.
      for (int i = 0; i < 8; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
      }
    } else if (lhs_vert < m_size) {
      // Only row lhs_vert lies below m_size.
      for (int i = 0; i < 8; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
      }
    }
  } else if (!CHECK_LHS_BOUNDARY) {
    // CHECK BOUNDARY_B: only the column (N) boundary needs checking here, so
    // all four rows are written and columns at or beyond n_size are skipped.
    for (int i = 0; i < 8; i++) {
      if (horiz_base + i < n_size) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  } else {
    // CHECK both boundaries: each of the four rows is tested against m_size
    // and each of the eight columns against n_size before it is written.
    for (int i = 0; i < 8; i++) {
      if (horiz_base + i < n_size) {
        if (lhs_vert < m_size)
          output(lhs_vert, horiz_base + i) = results[i].x;
        if (lhs_vert + 1 < m_size)
          output(lhs_vert + 1, horiz_base + i) = results[i].y;
        if (lhs_vert + 2 < m_size)
          output(lhs_vert + 2, horiz_base + i) = results[i].z;
        if (lhs_vert + 3 < m_size)
          output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  }
}


// Float contraction kernel that computes one 128 (M) x 64 (N) output tile per
// thread block.
//
// Expected launch (see LaunchKernels<float, float, ...>::Run in the evaluator
// below): block = (8, 32, 1) = 256 threads,
// grid = ((m + 127) / 128, (n + 63) / 64, 1), no dynamic shared memory.
// blockIdx.x selects the M tile and blockIdx.y selects the N tile.
//
// Static shared memory: 64*32 float2 (16 KiB) for the LHS and 128*8 float2
// (8 KiB) for the RHS. The flat buffers are reinterpreted as 2-D arrays
// before being handed to EigenFloatContractionKernelInternal.
//
// The two bool template arguments of the internal routine are chosen per
// block. The first is set from check_lhs128 (the tile's rows reach past
// m_size), the second from check_rhs (the tile's columns reach past n_size).
// Tiles fully inside the matrix therefore use the instantiation without
// boundary tests.
template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
#if defined(EIGEN_HIPCC)
__launch_bounds__(256, 1)
#else
__launch_bounds__(256)
#endif
EigenFloatContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output,
                       const Index m_size, const Index n_size, const Index k_size) {
  __shared__ float2 lhs_shmem[64*32];
  __shared__ float2 rhs_shmem[128*8];

  // 2-D views of the flat shared buffers above (same byte sizes).
  typedef float2 LHS_MEM[64][32];
  typedef float2 RHS_MEM[128][8];

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  // Top-left corner of this block's 128 x 64 output tile.
  const Index base_m = 128 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  // True when the tile's last column (base_n + 63) is outside the output.
  bool check_rhs = (base_n + 63) >= n_size;
  // True when the tile's last row (base_m + 127) is outside the output.
  bool check_lhs128 = (base_m + 127) >= m_size;

  if (!check_rhs) {
    if (!check_lhs128) {
      // >= 128 rows left
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(
                     lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(
                     lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    }
  } else {
    if (!check_lhs128) {
      // >= 128 rows left
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(
                     lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(
                     lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    }
  }
}

// Float contraction kernel that computes one 64 (M) x 64 (N) output tile per
// thread block.
//
// Expected launch (see LaunchKernels<float, float, ...>::Run in the evaluator
// below, used when m < 768 or n < 768): block = (16, 16, 1) = 256 threads,
// grid = ((m + 63) / 64, (n + 63) / 64, 1), no dynamic shared memory.
//
// Static shared memory: 32x16 float2 (4 KiB) for the LHS and 64x8 float2
// (4 KiB) for the RHS.
//
// As in EigenFloatContractionKernel, the first bool template argument of the
// internal routine is true when the tile's rows reach past m_size, and the
// second is true when its columns reach past n_size.
template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
#if defined(EIGEN_HIPCC)
__launch_bounds__(256, 1)
#else
__launch_bounds__(256)
#endif
EigenFloatContractionKernel16x16(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output,
                       const Index m_size, const Index n_size, const Index k_size) {
  __shared__ float2 lhs_shmem[32][16];
  __shared__ float2 rhs_shmem[64][8];

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  // Top-left corner of this block's 64 x 64 output tile.
  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  if (base_m + 63 < m_size) {
    if (base_n + 63 < n_size) {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    }
  } else {
    if (base_n + 63 < n_size) {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    }
  }
}


// GpuDevice evaluator for TensorContractionOp.
//
// Both operands are viewed as matrices through TensorContractionInputMapper:
// the LHS as m x k and the RHS as k x n, where k counts the contracted
// coefficients. The product is written to a column-major m x n buffer by one
// of the GPU kernels launched from LaunchKernels.
//
// Custom output kernels are rejected at compile time (see the static assert
// in the constructor).
template<typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, GpuDevice> :
    public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, GpuDevice> > {

  typedef GpuDevice Device;

  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
  typedef TensorContractionEvaluatorBase<Self> Base;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
  typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
  typedef typename XprType::Index Index;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, GpuDevice>::type PacketReturnType;

  enum {
    Layout = TensorEvaluator<LeftArgType, Device>::Layout,
  };

  // Most of the code is assuming that both input tensors are ColMajor. If the
  // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS:
  // If we want to compute A * B = C, where A is LHS and B is RHS, the code
  // will pretend B is LHS and A is RHS.
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;

  // Ranks of the (possibly swapped) operands and number of contracted index pairs.
  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;

  typedef array<Index, LDims> left_dim_mapper_t;
  typedef array<Index, RDims> right_dim_mapper_t;

  typedef array<Index, ContractDims> contract_t;
  typedef array<Index, LDims - ContractDims> left_nocontract_t;
  typedef array<Index, RDims - ContractDims> right_nocontract_t;

  // Output rank: every non-contracted dimension of both operands.
  static const int NumDims = LDims + RDims - 2 * ContractDims;

  typedef DSizes<Index, NumDims> Dimensions;

  // typedefs needed in evalTo
  typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
  typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;

  typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
  typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;

  typedef typename LeftEvaluator::Dimensions LeftDimensions;
  typedef typename RightEvaluator::Dimensions RightDimensions;

  // Only the no-op output kernel is accepted on the GPU path.
  TensorEvaluator(const XprType& op, const Device& device) :
      Base(op, device)
  {
    EIGEN_STATIC_ASSERT( (internal::is_same<OutputKernelType, const NoOpOutputKernel>::value),
                          GPU_TENSOR_CONTRACTION_DOES_NOT_SUPPORT_OUTPUT_KERNELS);
  }

  // We need to redefine this method to make nvcc happy
  //
  // Evaluates both operands' sub-expressions, then runs the contraction.
  // - If `data` is non-null, the result is written there and false is returned.
  // - Otherwise a device buffer of dimensions().TotalSize() scalars is
  //   allocated into m_result, filled, and true is returned.
  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
    this->m_leftImpl.evalSubExprsIfNeeded(NULL);
    this->m_rightImpl.evalSubExprsIfNeeded(NULL);
    if (data) {
      evalTo(data);
      return false;
    } else {
      this->m_result = static_cast<Scalar *>(this->m_device.allocate(this->dimensions().TotalSize() * sizeof(Scalar)));
      evalTo(this->m_result);
      return true;
    }
  }

  // Converts the three runtime layout flags computed by the base evaluator
  // into compile-time template arguments of evalTyped. The alignment
  // argument is always Unaligned.
  void evalTo(Scalar* buffer) const {
    if (this->m_lhs_inner_dim_contiguous) {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<true, true, true, Unaligned>(buffer);
        }
        else {
          evalTyped<true, true, false, Unaligned>(buffer);
        }
      }
      else {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<true, false, true, Unaligned>(buffer);
        }
        else {
          evalTyped<true, false, false, Unaligned>(buffer);
        }
      }
    }
    else {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<false, true, true, Unaligned>(buffer);
        }
        else {
          evalTyped<false, true, false, Unaligned>(buffer);
        }
      }
      else {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<false, false, true, Unaligned>(buffer);
        }
        else {
          evalTyped<false, false, false, Unaligned>(buffer);
        }
      }
    }
  }

  // Generic (non float x float) launcher: one 64 x 64 output tile per block,
  // block = (8, 8, 8) = 512 threads, grid = ((m + 63) / 64, (n + 63) / 64, 1),
  // no dynamic shared memory.
  // NOTE(review): the kernel is instantiated with the evaluator's result
  // `Scalar`, not with the LhsScalar/RhsScalar template parameters.
  template <typename LhsScalar, typename RhsScalar, typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper> struct LaunchKernels {
    static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output, Index m, Index n, Index k, const GpuDevice& device) {
      const Index m_blocks = (m + 63) / 64;
      const Index n_blocks = (n + 63) / 64;
      const dim3 num_blocks(m_blocks, n_blocks, 1);
      const dim3 block_size(8, 8, 8);
      LAUNCH_GPU_KERNEL((EigenContractionKernel<Scalar, Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
    }
  };

  // float x float launcher. If either output dimension is below 768, the
  // 64 x 64-tile kernel (16 x 16 threads) is used; otherwise the
  // 128 x 64-tile kernel (8 x 32 threads) is used. The 768 cutoff is a fixed
  // heuristic whose derivation is not recorded here.
  template <typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper> struct LaunchKernels<float, float, Index, LhsMapper, RhsMapper, OutputMapper> {
    static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output, Index m, Index n, Index k, const GpuDevice& device) {
      if (m < 768 || n < 768) {
        const Index m_blocks = (m + 63) / 64;
        const Index n_blocks = (n + 63) / 64;
        const dim3 num_blocks(m_blocks, n_blocks, 1);
        const dim3 block_size(16, 16, 1);
        LAUNCH_GPU_KERNEL((EigenFloatContractionKernel16x16<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
      } else {
        const Index m_blocks = (m + 127) / 128;
        const Index n_blocks = (n + 63) / 64;
        const dim3 num_blocks(m_blocks, n_blocks, 1);
        const dim3 block_size(8, 32, 1);
        LAUNCH_GPU_KERNEL((EigenFloatContractionKernel<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
      }
    }
  };

  // Zeroes `buffer`, builds the input/output mappers for the given layout
  // flags, and launches the matching contraction kernel on m_device.
  // NOTE(review): the Alignment template parameter is not used; both input
  // mappers are instantiated with Unaligned.
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  void evalTyped(Scalar* buffer) const {
    // columns in left side, rows in right side
    const Index k = this->m_k_size;
    // NOTE(review): k is passed to LaunchKernels below, so this macro appears redundant.
    EIGEN_UNUSED_VARIABLE(k)

    // rows in left side
    const Index m = this->m_i_size;

    // columns in right side
    const Index n = this->m_j_size;

    // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar) bytes)
    this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));

    // The LHS mapper always passes `false` in the slot where the RHS mapper
    // passes rhs_inner_dim_reordered.
    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, left_nocontract_t,
                                                   contract_t, 4,
                                                   lhs_inner_dim_contiguous,
                                                   false, Unaligned> LhsMapper;

    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, right_nocontract_t,
                                                   contract_t, 4,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered, Unaligned> RhsMapper;

    // Column-major m x n view of `buffer` (leading dimension m).
    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;


    // initialize data mappers
    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
                  this->m_left_contracting_strides, this->m_k_strides);

    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
                  this->m_right_contracting_strides, this->m_k_strides);

    OutputMapper output(buffer, m);

    // Request 8-byte shared-memory banks; presumably chosen to match the
    // 8-byte float2 elements the float kernels keep in shared memory.
#if defined(EIGEN_USE_HIP)
    setGpuSharedMemConfig(hipSharedMemBankSizeEightByte);
#else
    setGpuSharedMemConfig(cudaSharedMemBankSizeEightByte);
#endif

    LaunchKernels<LhsScalar, RhsScalar, Index, LhsMapper, RhsMapper, OutputMapper>::Run(lhs, rhs, output, m, n, k, this->m_device);
  }
};

} // end namespace Eigen

#endif // EIGEN_USE_GPU and EIGEN_GPUCC
#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H