TensorContractionSycl.h (89042B)
1 // This file is part of Eigen, a lightweight C++ template library for linear algebra. 2 // 3 // Mehdi Goli Codeplay Software Ltd. 4 // Ralph Potter Codeplay Software Ltd. 5 // Luke Iwanski Codeplay Software Ltd. 6 // Contact: <eigen@codeplay.com> 7 // 8 // This Source Code Form is subject to the terms of the Mozilla Public License v. 2.0. If a copy of the MPL was not 9 // distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 10 11 /***************************************************************** 12 * TensorContractionSycl.h 13 * 14 * \brief: 15 * TensorContractionSycl.h, provides various tensor contraction kernel for SYCL backend 16 * 17 *****************************************************************/ 18 19 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_SYCL_H 20 #define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_SYCL_H 21 22 namespace Eigen { 23 24 namespace TensorSycl { 25 namespace internal { 26 27 #ifndef EIGEN_SYCL_DISABLE_GEMV 28 /*! 29 * \brief TVPanelSize, a template class used for setting the panel size required for launching General TensorVector 30 * contraction kernel on various hardware devices. 31 * 32 * \tparam Scalar: determines the element type of the tensor/vector 33 * 34 * \tparam StorageIndex determines the Index type. 
35 * 36 * \tparam NCWindow: determines the number of non-contracting element to be process by each work-group 37 * 38 * \tparam CFactor: determines the number of contracting element to be process by each thread 39 * 40 * \tparam NCFactor: determines the number of non-contracting element to be process by each thread 41 */ 42 template <typename Scalar, typename StorageIndex, StorageIndex NCWindow, StorageIndex CFactor, StorageIndex NCFactor> 43 struct TVPanelSize { 44 // LocalThreadSizeC: determines total number of thread per workgroup for the contracting dimension 45 static EIGEN_CONSTEXPR StorageIndex LocalThreadSizeC = EIGEN_SYCL_LOCAL_THREAD_DIM0; 46 // LocalThreadSizeNC: determines total number of thread per workgroup for the non-contracting dimension 47 static EIGEN_CONSTEXPR StorageIndex LocalThreadSizeNC = EIGEN_SYCL_LOCAL_THREAD_DIM1; 48 // TileSizeDimNC: determines the tile size for the non-contracting dimension 49 static EIGEN_CONSTEXPR StorageIndex TileSizeDimNC = NCWindow / NCFactor; 50 // TileSizeDimC: determines the tile size for the contracting dimension 51 static EIGEN_CONSTEXPR StorageIndex TileSizeDimC = CFactor * LocalThreadSizeNC * LocalThreadSizeC; 52 // WorkLoadPerThreadNC : determines workload per thread for loading the non-contracting dimension 53 static EIGEN_CONSTEXPR StorageIndex WorkLoadPerThreadNC = TileSizeDimNC / LocalThreadSizeNC; 54 // WorkLoadPerThreadC: determines workload per thread for loading the non-contracting dimension 55 static EIGEN_CONSTEXPR StorageIndex WorkLoadPerThreadC = TileSizeDimC / LocalThreadSizeC; 56 // BC : determines if supporting bank conflict is required 57 static EIGEN_CONSTEXPR bool BC = false; 58 }; 59 #endif 60 61 /*! 62 * \brief TTPanelSize, a template class used for setting the panel size required for launching General Tensor Tensor 63 contraction kernel on various hardware devices. 
64 * 65 * \tparam Scalar: determines the element type of the tensor 66 * 67 * \tparam StorageIndex: determines the Index type. 68 * 69 * \tparam REG_SIZE_M: determines workload per thread for loading the M dimension This can be varied based on the 70 available register on a chosen device(can be controlled by EIGEN_SYCL_REG_M macro). 71 * 72 * \tparam REG_SIZE_N: determines workload per thread for loading the N dimension This can be varied based on the 73 available register on a chosen device(can be controlled by EIGEN_SYCL_REG_N macro). 74 * 75 * \tparam TSDK: determines Tile size for dimension K. The packet size is assumed to be considered 76 */ 77 78 template <typename Scalar, typename StorageIndex, StorageIndex REG_SIZE_M, StorageIndex REG_SIZE_N, StorageIndex TSDK> 79 struct TTPanelSize { 80 // TileSizeDimK: determines Tile size for dimension K. The packet size is assumed to be considered 81 static EIGEN_CONSTEXPR StorageIndex TileSizeDimK = TSDK; 82 // WorkLoadPerThreadM : determines workload per thread for loading the M dimension This can be varied based on the 83 // available register on a chosen device(can be controlled by EIGEN_SYCL_REG_M macro// 84 #ifndef EIGEN_SYCL_REG_M 85 static EIGEN_CONSTEXPR StorageIndex WorkLoadPerThreadM = REG_SIZE_M; 86 #else 87 static EIGEN_CONSTEXPR StorageIndex WorkLoadPerThreadM = EIGEN_SYCL_REG_M; 88 #endif 89 // WorkLoadPerThreadN : determines workload per thread for loading the N dimension This can be varied based on the 90 // available register on a chosen device(can be controlled by EIGEN_SYCL_REG_N macro 91 #ifndef EIGEN_SYCL_REG_N 92 static EIGEN_CONSTEXPR StorageIndex WorkLoadPerThreadN = REG_SIZE_N; 93 #else 94 static EIGEN_CONSTEXPR StorageIndex WorkLoadPerThreadN = EIGEN_SYCL_REG_N; 95 #endif 96 // LocalThreadSizeM: determines total number of thread per workgroup for the m dimension 97 static EIGEN_CONSTEXPR StorageIndex LocalThreadSizeM = EIGEN_SYCL_LOCAL_THREAD_DIM0; 98 // LocalThreadSizeN: determines total 
number of thread per workgroup for the n dimension 99 static EIGEN_CONSTEXPR StorageIndex LocalThreadSizeN = EIGEN_SYCL_LOCAL_THREAD_DIM1; 100 // TileSizeDimM: determines the tile size for the m dimension 101 static EIGEN_CONSTEXPR StorageIndex TileSizeDimM = LocalThreadSizeM * WorkLoadPerThreadM; 102 // TileSizeDimN: determines the tile size for the n dimension 103 static EIGEN_CONSTEXPR StorageIndex TileSizeDimN = LocalThreadSizeN * WorkLoadPerThreadN; 104 // LoadPerThreadLhs: determines workload per thread for loading Lhs Tensor. This must be divisable by packetsize 105 static EIGEN_CONSTEXPR StorageIndex LoadPerThreadLhs = 106 ((TileSizeDimK * WorkLoadPerThreadM * WorkLoadPerThreadN) / (TileSizeDimN)); 107 // LoadPerThreadRhs: determines workload per thread for loading Rhs Tensor. This must be divisable by packetsize 108 static EIGEN_CONSTEXPR StorageIndex LoadPerThreadRhs = 109 ((TileSizeDimK * WorkLoadPerThreadM * WorkLoadPerThreadN) / (TileSizeDimM)); 110 // BC : determines if supporting bank conflict is required 111 static EIGEN_CONSTEXPR bool BC = true; 112 // DoubleBuffer: determines if double buffering technique should be used (This can be disabled by 113 // EIGEN_SYCL_DISABLE_DOUBLE_BUFFER macro when the device doesnot have sufficient local memory) 114 static EIGEN_CONSTEXPR bool DoubleBuffer = 115 #ifdef EIGEN_SYCL_DISABLE_DOUBLE_BUFFER 116 false; 117 #else 118 true; 119 #endif 120 }; 121 122 /* ! 123 * \brief contraction_type: an enum class representing the Tensor Contraction implementation algorithm. This is used to 124 * specialize the contraction algorithm based on device support for dedicated local memory. 125 */ 126 enum class contraction_type { local, no_local }; 127 /* ! 128 * \brief data_source an enum class determining the location of the data in a memory hierarchy (global, local, private). 129 */ 130 enum class data_source { global_mem, local_mem, private_mem }; 131 132 /*! 
133 * \brief read, a template function used for loading the data from global 134 memory. This function is used to guarantee coalesced and vectorized load whenever possible 135 * 136 * \tparam PacketLoad: determines if the each element of this tensor block should be loaded in a packet mode 137 * 138 * \param is_coalesced_layout: determines whether or not the Tensor data in a memory can be access coalesced and 139 vectorized when possible. Coalesced memory access is a key factor in Kernel performance. When a tensor is 2d and the 140 contracting dimension is 1, it is always possible to accessed tensor data coalesced and vectorized. This is the case 141 when RHS(right hand side) Tensor is transposed or when LHS(left hand side) Tensor is not transposed. 142 * 143 * \tparam PacketType: determines the type of packet 144 * 145 * \tparam TensorMapper: determines the input tensor mapper type 146 * 147 * \tparam StorageIndex: determines the Index type 148 149 * \param tensorMapper: is the input tensor 150 * 151 * \param NCIndex: is the non-contracting dim index 152 * 153 * \param CIndex is the contracting dim index 154 * 155 * \param ld: is the leading dimension of the flattened tensor 156 */ 157 template <bool PacketLoad, bool is_coalesced_layout, bool, typename PacketType, typename TensorMapper, 158 typename StorageIndex> 159 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename ::Eigen::internal::enable_if<PacketLoad, PacketType>::type read( 160 const TensorMapper &tensorMapper, const StorageIndex &NCIndex, const StorageIndex &CIndex, const StorageIndex &ld) { 161 const StorageIndex row = (is_coalesced_layout) ? NCIndex : CIndex; 162 const StorageIndex col = (is_coalesced_layout) ? CIndex : NCIndex; 163 return tensorMapper.get_tensor().template packet<Unaligned>(row + (col * ld)); 164 } 165 166 /*! 
167 * \brief read, special overload of read function, when the read access is not vectorized 168 * 169 * \tparam PacketLoad: determines if the each element of this tensor block should be loaded in a packet mode 170 * 171 * \param is_coalesced_layout: determines whether or not the Tensor data in a memory can be access coalesced and 172 vectorized when possible. Coalesced memory access is a key factor in Kernel performance. When a tensor is 2d and the 173 contracting dimension is 1, it is always possible to accessed tensor data coalesced and vectorized. This is the case 174 when RHS(right hand side) Tensor is transposed or when LHS(left hand side) Tensor is not transposed. 175 * 176 * \tparam PacketType: determines the type of packet 177 * 178 * \tparam TensorMapper: determines the input tensor mapper type 179 * 180 * \tparam StorageIndex: determines the Index type 181 182 * \param tensorMapper: is the input tensor 183 * 184 * \param NCIndex: is the non-contracting dim index 185 * 186 * \param CIndex: is the contracting dim index 187 */ 188 template <bool PacketLoad, bool, bool IsRhs, typename PacketType, typename TensorMapper, typename StorageIndex> 189 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename ::Eigen::internal::enable_if<!PacketLoad, PacketType>::type read( 190 const TensorMapper &tensorMapper, const StorageIndex &NCIndex, const StorageIndex &CIndex, const StorageIndex &) { 191 const StorageIndex row = (IsRhs) ? CIndex : NCIndex; 192 const StorageIndex col = (IsRhs) ? NCIndex : CIndex; 193 return tensorMapper(row, col); 194 } 195 196 /*! 197 * \brief write, a template function used for storing the data to local memory. This function is used to guarantee 198 * coalesced and vectorized store whenever possible. 199 * 200 * \tparam StorageIndex: determines the Index type 201 * 202 * \param ld is the leading dimension of the local memory. 
ld is a compile time value for the local memory 203 * 204 * \tparam data_source: an enum value representing if the location of the data in a memory hierarchy. 205 * 206 * \tparam PacketType: determines the type of packet 207 * 208 * \tparam DataScalar: determines the output data type 209 * 210 * \param packet_data: the data to be written in the local memory 211 * 212 * \param ptr: a pointer to the local memory 213 * 214 * \param CIndex is the contracting dim index 215 */ 216 217 template <typename StorageIndex, StorageIndex ld, data_source dt, typename PacketType, typename DataScalar> 218 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE 219 typename ::Eigen::internal::enable_if<dt != data_source::global_mem, void>::type 220 write(PacketType &packet_data, DataScalar ptr) { 221 EIGEN_CONSTEXPR int PacketSize = Eigen::internal::unpacket_traits<PacketType>::size; 222 EIGEN_UNROLL_LOOP 223 for (int i = 0; i < PacketSize; i++) { 224 *ptr = PacketWrapper<PacketType, PacketSize>::scalarize(i, packet_data); 225 ptr += ld; 226 } 227 } 228 229 /*! 230 * \brief Overloading the write function for storing the data to global memory, when vectorization enabled This function 231 * is used to guarantee coalesced and vectorized store whenever possible. 232 * 233 * \tparam data_source: an enum value representing if the location of the data in a memory hierarchy. 
234 * 235 * \tparam PacketType: determines the type of packet 236 * 237 * \tparam DataScalar: determines the output data type 238 * 239 * \param packet_data: the data to be written in the local memory 240 * 241 * \param ptr: a pointer to the local memory 242 */ 243 244 template <data_source dt, typename PacketType, typename DataScalar> 245 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename ::Eigen::internal::enable_if< 246 Eigen::internal::unpacket_traits<PacketType>::size != 1 && dt == data_source::global_mem, void>::type 247 write(PacketType &packet_data, DataScalar *ptr) { 248 ::Eigen::internal::pstoreu<DataScalar, PacketType>(ptr, packet_data); 249 } 250 251 /*! 252 * \brief Overloading the write function for storing the data to global memory, when vectorization is disabled. 253 * 254 * \tparam data_source: an enum value representing if the location of the data in a memory hierarchy. 255 * 256 * \tparam PacketType: determines the type of packet 257 * 258 * \tparam DataScalar: determines the output data type 259 * 260 * \param packet_data: the data to be written in the local memory 261 * 262 * \param ptr: a pointer to the local memory 263 */ 264 template <data_source dt, typename PacketType, typename DataScalar> 265 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename ::Eigen::internal::enable_if< 266 Eigen::internal::unpacket_traits<PacketType>::size == 1 && dt == data_source::global_mem, void>::type 267 write(PacketType &packet_data, DataScalar *ptr) { 268 *ptr = packet_data; 269 } 270 271 /*! 272 * \brief check_boundary: is used to check the edge condition for non-internal blocks. 273 * 274 * \tparam is_internal: determines if the block is internal 275 */ 276 template <bool is_internal> 277 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool check_boundary(bool) { 278 return true; 279 } 280 281 /*! 282 * \brief check_boundary: specialization of the check_boundary for non-internal blocks. 283 * 284 * \param cond: true when the data is in range. 
Otherwise false 285 */ 286 template <> 287 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool check_boundary<false>(bool cond) { 288 return cond; 289 } 290 291 /*! 292 * \brief BlockProperties is a template class that provides different characteristic of a block of each Tensor processed 293 * by each workgroup. 294 * 295 * \tparam is_transposed: iff true, determines whether or not the block of the Tensor is transposed 296 * 297 * \tparam packet_load_: determines if the each element of this tensor block should be loaded in a packet mode 298 * 299 * \tparam PacketType: determines the type of packet 300 * 301 * \tparam OutType: determines the type of each element for this block of tensor. If packet load is true, it will be 302 * packetType; Otherwise it will be scalar Type 303 * 304 * \param elements_per_access determines the size of each element based on OutType 305 * 306 * \param is_coalesced_layout determines whether or not the Tensor data in a memory can be access coalesced and 307 * vectorized when possible. Coalesced memory access is a key factor in Kernel performance. When a tensor is 2d and the 308 * contracting dimension is 1, it is always possible to accessed tensor data coalesced and vectorized. This is the case 309 * when RHS(right hand side) Tensor is transposed or when LHS(left hand side) Tensor is not transposed. 
310 * 311 * \param nc_stride determines the stride of non-contracting dimension to access the next adjustment element within the 312 * Tensor Block for each workgroup 313 * 314 * \param c_stride determines the stride of contracting dimension to access the next adjustment element within the 315 * Tensor Block for each workgroup 316 */ 317 template <bool is_transposed, bool is_rhs_, bool packet_load_, typename PacketType> 318 struct BlockProperties { 319 static EIGEN_CONSTEXPR bool packet_load = packet_load_; 320 typedef typename Eigen::internal::unpacket_traits<PacketType>::type OutScalar; 321 static EIGEN_CONSTEXPR bool is_rhs = is_rhs_; 322 typedef typename Eigen::internal::conditional<packet_load, PacketType, OutScalar>::type OutType; 323 static EIGEN_CONSTEXPR int elements_per_access = Eigen::internal::unpacket_traits<OutType>::size; 324 static EIGEN_CONSTEXPR bool is_coalesced_layout = !(is_transposed ^ is_rhs); 325 static EIGEN_CONSTEXPR int nc_stride = (is_coalesced_layout ? elements_per_access : 1); 326 static EIGEN_CONSTEXPR int c_stride = (is_coalesced_layout ? 1 : elements_per_access); 327 }; 328 329 /*! 330 * \brief ThreadProperties is a template class that provides each thread's properties within a workgroup. Please see 331 * the sycl-1.2.1 specification (https://www.khronos.org/registry/SYCL/specs/sycl-1.2.1.pdf) for the workgroup, 332 * work-items 333 * 334 * \tparam StorageIndex: determines the StorageIndex Type 335 * 336 * \param linearLocalThreadId: determines the linearized location of a thread within a work-group 337 * 338 * \param kGroupId: determines the logical group id in a k dimension of the flattened tensor. It will be > 1 when 339 * tall/skinny algorithm is used 340 * 341 * \param mGroupOffset: determines the logical start position of all thread within a workgroup for the m dimension of 342 * the flattened tensor. 
343 * 344 * \param kGroupOffset determines the logical start position of all thread within a workgroup for the k dimension of the 345 * flattened tensor. It will be > 1 when tall/skinny algorithm is used. 346 * 347 * \param mLocalOffset: determines the logical start position of each thread within a workgroup for the m dimension of a 348 * flattened tensor. The position determines the distance of each thread within the workgroup from each other 349 * independent from their global position. 350 * 351 * \param nLocalOffset: determines the logical start position of each thread within a workgroup for the n dimension of a 352 * flattened tensor. The position determines the distance of each thread within the workgroup from each other 353 * independent from their global position. 354 * 355 * \param mGlobalOffset: determines the logical start position of each thread a thread for the m dimension on a 356 * flattened tensor 357 * 358 * \param nGlobalOffset: determines the logical start position of each thread a thread for the n dimension on a 359 * flattened tensor 360 * 361 * \param kSize : determine the number of the k elements of the flattened Tensor to be processed by each thread for the 362 * given tensor block. This is !=K dimension of Flattened Tensor when Tall/Skinny matrix is used. 363 * 364 * \param is_internal : this will determined if the thread within the work-group computes an internal block of tensor or 365 * the edge blocks. When it is internal, there is no need to check the boundaries and all the if stantement can be 366 * resolve by compiler. 
367 */ 368 template <typename StorageIndex> 369 struct ThreadProperties { 370 const StorageIndex linearLocalThreadId; 371 const StorageIndex kGroupId; 372 const StorageIndex mGroupOffset; 373 const StorageIndex nGroupOffset; 374 const StorageIndex kGroupOffset; 375 const StorageIndex mLocalOffset; 376 const StorageIndex nLocalOffset; 377 const StorageIndex mGlobalOffset; 378 const StorageIndex nGlobalOffset; 379 StorageIndex kSize; 380 const bool is_internal; 381 // this is used to adjust the last block 382 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ThreadProperties( 383 const StorageIndex linearLocalThreadId_, const StorageIndex kGroupId_, const StorageIndex mGroupOffset_, 384 const StorageIndex nGroupOffset_, const StorageIndex kGroupOffset_, const StorageIndex mLocalOffset_, 385 const StorageIndex nLocalOffset_, const StorageIndex mGlobalOffset_, const StorageIndex nGlobalOffset_, 386 StorageIndex kSize_, const bool is_internal_) 387 : linearLocalThreadId(linearLocalThreadId_), 388 kGroupId(kGroupId_), 389 mGroupOffset(mGroupOffset_), 390 nGroupOffset(nGroupOffset_), 391 kGroupOffset(kGroupOffset_), 392 mLocalOffset(mLocalOffset_), 393 nLocalOffset(nLocalOffset_), 394 mGlobalOffset(mGlobalOffset_), 395 nGlobalOffset(nGlobalOffset_), 396 kSize(kSize_), 397 is_internal(is_internal_) {} 398 }; 399 400 /*! 401 * \brief TensorContractionKernel is a template class that provides Tensor -Tensor contraction operation. 
*
 * \tparam OutScalar: determines the output scalar type
 *
 * \tparam LhsScalar: determines the left-hand-side scalar type
 *
 * \tparam RhsScalar: determines the right-hand-side scalar type
 *
 * \tparam OutAccessor: determines the sycl accessor type for output (please see the sycl-1.2.1 specification
 (https://www.khronos.org/registry/SYCL/specs/sycl-1.2.1.pdf) for accessor definition)
 *
 * \tparam LhsMapper: determines the tensor contraction mapper type for left-hand-side matrix
 *
 * \tparam RhsMapper: determines the tensor contraction mapper type for right-hand-side matrix
 *
 * \tparam StorageIndex: determines the StorageIndex Type
 *
 * \tparam Properties: determines the Contraction Panel properties
 *
 * \tparam TripleDim: determines the M, K, N dimensions for the flattened tensors in order to treat them as a matrix
 *
 * \tparam Vectorizable: determines whether or not the vectorization is enabled for the Eigen expression.
 *
 * \tparam input_mapper_properties: determines if the input tensors are matrices. If they are matrices, special memory
 access is used to guarantee that the memory accesses are always coalesced.
 *
 * \tparam IsFinal: determines if this is the final kernel. If so, the result will be written in a final output.
 Otherwise, the result of contraction will be written in a temporary buffer. This is the case when Tall/Skinny
 contraction is used. So in this case, a final reduction step is required to compute final output.
430 431 * \tparam contraction_tp: it is an enum value representing whether the local memroy/no local memory implementation of 432 the algorithm to be used 433 * 434 * \param scratch: local memory containing tiles of LHS and RHS tensors for each work-group 435 * 436 * \param lhs: determines the left-hand-side flattened tensor (tensor mapper) 437 * 438 * \param rhs: determines the right-hand-side flattened tensor (tensor mapper) 439 * 440 * \param out_res: determines the output tensor containing the contraction result 441 * 442 * \param groupSizeM: a logical number determining the number of work-group for m dimension 443 * 444 * \param groupSizeN: a logical number determining the number of work-group for n dimension 445 * 446 * \param numTiles: determines total number of tiles on the k dimension 447 * 448 * \param TripleDim: determines the M, K, N dimensions for the flatten tensors in order to treat them as a matrix 449 */ 450 template <typename OutScalar, typename LhsScalar, typename RhsScalar, typename OutAccessor, typename LhsMapper, 451 typename RhsMapper, typename StorageIndex, typename Properties, typename TripleDim, bool Vectorizable, 452 typename input_mapper_properties, bool IsFinal, contraction_type contraction_tp> 453 class TensorContractionKernel { 454 public: 455 typedef typename Eigen::TensorSycl::internal::Vectorise<OutScalar, Eigen::SyclDevice, Vectorizable>::PacketReturnType 456 PacketReturnType; 457 static EIGEN_CONSTEXPR int PacketSize = 458 Eigen::TensorSycl::internal::Vectorise<OutScalar, Eigen::SyclDevice, Vectorizable>::PacketSize; 459 static EIGEN_CONSTEXPR bool is_lhs_transposed = 460 !::Eigen::internal::TensorContractionInputMapperTrait<LhsMapper>::inner_dim_contiguous; 461 static EIGEN_CONSTEXPR bool is_rhs_transposed = 462 !::Eigen::internal::TensorContractionInputMapperTrait<RhsMapper>::inner_dim_contiguous; 463 464 typedef BlockProperties<is_lhs_transposed, false, input_mapper_properties::is_lhs_matrix && Vectorizable, 465 
PacketReturnType> 466 LHSBlockProperties; 467 468 typedef BlockProperties<is_rhs_transposed, true, input_mapper_properties::is_rhs_matrix && Vectorizable, 469 PacketReturnType> 470 RHSBlockProperties; 471 472 static EIGEN_CONSTEXPR StorageIndex NStride = 473 contraction_tp == contraction_type::local ? Properties::WorkLoadPerThreadN : RHSBlockProperties::nc_stride; 474 475 typedef cl::sycl::accessor<OutScalar, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local> Scratch; 476 typedef cl::sycl::multi_ptr<OutScalar, cl::sycl::access::address_space::local_space> local_ptr; 477 typedef OutScalar * /*cl::sycl::multi_ptr<OutScalar, cl::sycl::access::address_space::private_space>*/ private_ptr; 478 typedef 479 typename ::Eigen::internal::conditional<contraction_tp == contraction_type::local, local_ptr, private_ptr>::type 480 tile_ptr; 481 static EIGEN_CONSTEXPR StorageIndex LSDL = contraction_tp == contraction_type::local 482 ? Properties::TileSizeDimM + Properties::BC 483 : Properties::WorkLoadPerThreadM; 484 static EIGEN_CONSTEXPR StorageIndex LSDR = contraction_tp == contraction_type::local 485 ? Properties::TileSizeDimN + Properties::BC 486 : Properties::WorkLoadPerThreadN; 487 static EIGEN_CONSTEXPR StorageIndex LocalOffset = Properties::LocalThreadSizeM * Properties::LocalThreadSizeN; 488 489 /** 490 * \brief MemHolder this is a place holder struct for creating memory hierarchy in SYCL. Inside SYCL kernel it is not 491 * allowed to have dynamic memory allocation. While the local memory is created outside of the kernel and passed to 492 * the kernel as an accessor, the private memory can only allowed to be allocated statically. Since we are abstracting 493 * the TiledMemory for both local and private memory, the MemHolder structs is used as a helper to abstract out 494 * different type of memory needed when local/no_local memory computation is called. 
495 * 496 * \tparam contraction_type: it is an enum value representing whether the local memroy/no local memory implementation 497 of the algorithm to be used 498 * \tparam the private memory size 499 * \param ptr the tile memory pointer type 500 */ 501 template <contraction_type, StorageIndex> 502 struct MemHolder { 503 tile_ptr ptr; 504 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE MemHolder(local_ptr block_start_ptr) : ptr(block_start_ptr) {} 505 }; 506 /** 507 * \brief specialization of memHolder class when no local memory kernel is used. 508 */ 509 template <StorageIndex MemSize> 510 struct MemHolder<contraction_type::no_local, MemSize> { 511 OutScalar ptr[MemSize] = {OutScalar{0}}; 512 }; 513 /** 514 * \brief TiledMemory: contains required memory pointer for loading each tile of the TensorContraction panel from 515 * global memory to local/private memory when local/no_local algorithm used. 516 * 517 * \param lhs_scratch_extract : determines the LHS tile memory. It is either private or local memory based on the 518 * selected contraction_type. 519 * 520 * \param rhs_scratch_extract : determines the RHS tile memory. It is either private or local memory based on the 521 * selected contraction_type. 522 * 523 * \param lhs_extract_index: determins the position of each thread on a local memory for lhs input. When private 524 * memory is used this is set to zero as this is not applicable in case of private memory. 525 * 526 * \param rhs_extract_index: determins the position of each thread on a local memory for rhs input. When private 527 * memory is used this is set to zero as this is not applicable in case of private memory. 528 * 529 * \param lhs_scratch_compute : determines the location to load for computation for lhs_local memory. This is the 530 * same as lhs_scratch_extract for private memory. 531 * 532 * \param rhs_scratch_compute : determines the location to load for computation for rhs_local memory. This is the 533 * same as rhs_scratch_extract for private memory. 
534 */ 535 struct TiledMemory { 536 MemHolder<contraction_tp, Properties::WorkLoadPerThreadM * Properties::TileSizeDimK> lhs_scratch_extract; 537 MemHolder<contraction_tp, Properties::WorkLoadPerThreadN * Properties::TileSizeDimK> rhs_scratch_extract; 538 tile_ptr lhs_scratch_ptr_compute; 539 tile_ptr rhs_scratch_ptr_compute; 540 const std::pair<StorageIndex, StorageIndex> lhs_extract_index; 541 const std::pair<StorageIndex, StorageIndex> rhs_extract_index; 542 template <contraction_type tp = contraction_tp> 543 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE 544 TiledMemory(const ThreadProperties<StorageIndex> &, local_ptr, 545 typename ::Eigen::internal::enable_if<tp == contraction_type::no_local>::type * = 0) 546 : lhs_scratch_extract{}, 547 rhs_scratch_extract{}, 548 lhs_scratch_ptr_compute(lhs_scratch_extract.ptr), 549 rhs_scratch_ptr_compute(rhs_scratch_extract.ptr), 550 lhs_extract_index(std::pair<StorageIndex, StorageIndex>(StorageIndex{0}, StorageIndex{0})), 551 rhs_extract_index(std::pair<StorageIndex, StorageIndex>(StorageIndex{0}, StorageIndex{0})) {} 552 553 template <contraction_type tp = contraction_tp> 554 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE 555 TiledMemory(const ThreadProperties<StorageIndex> &thread_properties, local_ptr block_start_ptr, 556 typename ::Eigen::internal::enable_if<tp == contraction_type::local>::type * = 0) 557 : lhs_scratch_extract{block_start_ptr}, 558 rhs_scratch_extract{lhs_scratch_extract.ptr + 559 ((Properties::DoubleBuffer + 1) * LSDL * Properties::TileSizeDimK)}, 560 lhs_scratch_ptr_compute(lhs_scratch_extract.ptr + thread_properties.mLocalOffset), 561 rhs_scratch_ptr_compute(rhs_scratch_extract.ptr + thread_properties.nLocalOffset), 562 lhs_extract_index( 563 local_id_extract<LHSBlockProperties, Properties::TileSizeDimM>(thread_properties.linearLocalThreadId)), 564 rhs_extract_index( 565 local_id_extract<RHSBlockProperties, Properties::TileSizeDimN>(thread_properties.linearLocalThreadId)) {} 566 }; 567 568 Scratch scratch; 569 
  // Input mappers, output accessor and launch geometry captured by value for the kernel invocation.
  const LhsMapper lhs;
  const RhsMapper rhs;
  OutAccessor out_res;
  // Number of work-groups along M and N; used in operator() to decompose the linear group id.
  const StorageIndex groupSizeM;
  const StorageIndex groupSizeN;
  // Number of K tiles handled by one work-group when IsFinal is false (see kSizePerWG in operator()).
  const StorageIndex numTiles;
  // Problem sizes; the members M, N and K are read by the kernel.
  const TripleDim triple_dim;

  /**
   * \brief Constructs the kernel functor, copying every argument into the corresponding member.
   */
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionKernel(Scratch scratch_, const LhsMapper lhs_,
                                                                const RhsMapper rhs_, OutAccessor out_res_,
                                                                const StorageIndex groupSizeM_,
                                                                const StorageIndex groupSizeN_,
                                                                const StorageIndex numTiles_,
                                                                const TripleDim triple_dim_)
      : scratch(scratch_),
        lhs(lhs_),
        rhs(rhs_),
        out_res(out_res_),
        groupSizeM(groupSizeM_),
        groupSizeN(groupSizeN_),
        numTiles(numTiles_),
        triple_dim(triple_dim_) {}

  /**
   * \brief Convenience overload without groupSizeN: delegates to the full constructor with groupSizeN set to 1.
   */
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionKernel(Scratch scratch_, const LhsMapper lhs_,
                                                                const RhsMapper rhs_, OutAccessor out_res_,
                                                                const StorageIndex groupSizeM_,
                                                                const StorageIndex numTiles_,
                                                                const TripleDim triple_dim_)
      : TensorContractionKernel(scratch_, lhs_, rhs_, out_res_, groupSizeM_, 1, numTiles_, triple_dim_) {}

  /**
   * \brief Kernel entry point: derives the per-thread and per-work-group offsets from the 1-D nd_item, packs them
   * into a ThreadProperties object and dispatches to compute_panel.
   *
   * \param itemID: the 1-D nd_item of the calling work-item.
   */
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(cl::sycl::nd_item<1> itemID) {
    // Split the linear local id into an (m, n) thread coordinate; m is the fastest-varying component.
    const StorageIndex linearLocalThreadId = itemID.get_local_id(0);
    const StorageIndex nLocalThreadId = linearLocalThreadId / Properties::LocalThreadSizeM;
    const StorageIndex mLocalThreadId = linearLocalThreadId % Properties::LocalThreadSizeM;
    // Split the linear group id into (m, n, k) group coordinates. When IsFinal is true there is no split along K,
    // so kGroupId is 0 and everything above M belongs to N.
    const StorageIndex mGroupId = itemID.get_group(0) % groupSizeM;
    const StorageIndex tmp = itemID.get_group(0) / groupSizeM;
    const StorageIndex nGroupId = IsFinal ? tmp : tmp % groupSizeN;
    const StorageIndex kGroupId = IsFinal ? 0 : tmp / groupSizeN;
    const StorageIndex mGroupOffset = mGroupId * Properties::TileSizeDimM;
    const StorageIndex nGroupOffset = nGroupId * Properties::TileSizeDimN;
    const StorageIndex mLocalOffset = PacketSize * mLocalThreadId;
    const StorageIndex nLocalOffset = NStride * nLocalThreadId;
    const StorageIndex mGlobalOffset = mGroupOffset + mLocalOffset;
    const StorageIndex nGlobalOffset = nGroupOffset + nLocalOffset;

    // Amount of K handled by one work-group: all of K when IsFinal, otherwise numTiles tiles.
    const StorageIndex kSizePerWG = IsFinal ? triple_dim.K : numTiles * Properties::TileSizeDimK;
    StorageIndex kGroupOffset = kGroupId * kSizePerWG;
    // A work-group is "internal" when its full M, N and K extents fit inside the problem sizes.
    const bool is_internal = triple_dim.M - mGroupOffset >= Properties::TileSizeDimM &&
                             triple_dim.N - nGroupOffset >= Properties::TileSizeDimN &&
                             triple_dim.K - kGroupOffset >= kSizePerWG;
    // this is used to adjust the last block
    StorageIndex kSize = IsFinal ? triple_dim.K : std::min(kSizePerWG, triple_dim.K - kGroupOffset);
    // kGroupOffset is moved to the end of this work-group's K range, so that (kGroupOffset - kSize) gives the
    // contracting offset (cOffset) to load for the current tile; kSize shrinks as tiles are consumed.
    kGroupOffset += kSize;

    auto thread_properties =
        ThreadProperties<StorageIndex>(linearLocalThreadId, kGroupId, mGroupOffset, nGroupOffset, kGroupOffset,
                                       mLocalOffset, nLocalOffset, mGlobalOffset, nGlobalOffset, kSize, is_internal);

    // When the kernel is not final each K group writes to its own M * N slice of the output buffer.
    auto out_ptr = out_res.get_pointer() + (IsFinal ? 0 : thread_properties.kGroupId * triple_dim.M * triple_dim.N);

    (thread_properties.is_internal) ? compute_panel<true>(itemID, thread_properties, out_ptr)
                                    : compute_panel<false>(itemID, thread_properties, out_ptr);
  }
  // The compute block computes the contraction operation of the private block for each thread and stores the result
  // in the privateRes memory of each thread. The compute block function is independent of the local and no-local
  // concepts, as it only computes the block on each thread's private memory space.
  //
  // For every one of the WorkLoadPerThreadN rhs values it loads WorkLoadPerThreadM / PacketSize lhs packets and
  // accumulates lhsPack * rhsPacket into consecutive privateRes entries.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void compute_block_per_tile(OutScalar *lhs_block_ptr, OutScalar *rhs_block_ptr,
                                                                    PacketReturnType *privateRes) {
    StorageIndex idx = 0;
    // Distance between two consecutive lhs packets of the same thread: strided by the M thread count for the local
    // algorithm, contiguous otherwise.
    EIGEN_CONSTEXPR StorageIndex lhs_stride =
        contraction_tp == contraction_type::local ? (PacketSize * Properties::LocalThreadSizeM) : 1;
    EIGEN_UNROLL_LOOP
    for (StorageIndex wLPTN = 0; wLPTN < Properties::WorkLoadPerThreadN; wLPTN++) {
      auto rhsPacket = PacketReturnType{*(rhs_block_ptr + wLPTN)};
      StorageIndex lhs_index = 0;
      EIGEN_UNROLL_LOOP
      for (StorageIndex wLPTM = 0; wLPTM < Properties::WorkLoadPerThreadM / PacketSize; wLPTM++) {
        PacketReturnType lhsPack{};
        Eigen::TensorSycl::internal::PacketWrapper<PacketReturnType, PacketSize>::set_packet(lhsPack,
                                                                                             lhs_block_ptr + lhs_index);
        privateRes[idx] = ::Eigen::internal::pmadd(lhsPack, rhsPacket, privateRes[idx]);

        lhs_index += lhs_stride;
        idx++;
      }
    }
  }
  // The store function writes the computed contraction result held in the private memory of each thread to the
  // global memory. The store function is independent of the local and no-local concepts, so that it can be
  // abstracted out in the base class.
661 template <bool is_internal_block, StorageIndex PrivateNStride, typename OutPtr> 662 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void store(OutPtr *out_ptr, PacketReturnType *privateRes, 663 StorageIndex mGlobalOffset, StorageIndex nGlobalOffset) { 664 auto chk_bound = [&](const StorageIndex &mIndex, const StorageIndex &nIndex) EIGEN_DEVICE_FUNC { 665 return (mIndex + PacketSize - 1 < triple_dim.M && nGlobalOffset + nIndex < triple_dim.N); 666 }; 667 // when local memory is not used M and N are both accessed in a coalesced way. However, when local memory is 668 // available the k*N is transposed in the local to N*K therefore, each blocks operates on blockId* 669 // WorkLoadPerThreadN slice of N 670 EIGEN_CONSTEXPR StorageIndex GlobalNStride = 671 contraction_tp == contraction_type::local ? 1 : Properties::LocalThreadSizeN; 672 EIGEN_UNROLL_LOOP 673 for (StorageIndex wLPTN = 0; wLPTN < Properties::WorkLoadPerThreadN / PrivateNStride; wLPTN++) { 674 // output leading dimension 675 StorageIndex outputLD = 0; 676 // When local memory is used the PrivateNstride is always 1 because the coalesed access on N is loaded into Local 677 // memory and extracting from local to global is the same as no transposed version. However, when local memory is 678 // not used and RHS is transposed we packetize the load for RHS. 679 EIGEN_UNROLL_LOOP 680 for (StorageIndex nId = 0; nId < PrivateNStride; nId++) { 681 StorageIndex globalRow = mGlobalOffset; 682 EIGEN_UNROLL_LOOP 683 for (StorageIndex wLPTM = 0; wLPTM < Properties::WorkLoadPerThreadM / PacketSize; wLPTM++) { 684 PacketReturnType privetOut = privateRes[wLPTM]; 685 if (check_boundary<is_internal_block>(chk_bound(globalRow, nId))) { 686 // Store the final results in C. 
The C matrix has always M as a first StorageIndex and N as a second 687 // StorageIndex Therefore it is always coalesced layout 688 write<data_source::global_mem>(privetOut, out_ptr + outputLD + globalRow); 689 } else { 690 EIGEN_UNROLL_LOOP 691 for (StorageIndex mId = 0; mId < PacketSize; mId++) { 692 StorageIndex mOffset = globalRow + mId; 693 if (mOffset < triple_dim.M && (nGlobalOffset + nId < triple_dim.N)) { 694 out_ptr[mOffset + outputLD] = 695 Eigen::TensorSycl::internal::PacketWrapper<PacketReturnType, PacketSize>::scalarize(mId, privetOut); 696 } 697 } 698 } 699 globalRow += (PacketSize * Properties::LocalThreadSizeM); 700 } 701 outputLD += triple_dim.M; 702 privateRes += Properties::WorkLoadPerThreadM / PacketSize; 703 } 704 out_ptr += (GlobalNStride * outputLD); 705 706 nGlobalOffset += (PrivateNStride * GlobalNStride); 707 } 708 } 709 // when no local memory is used the following extract_block will be enabled 710 template <typename InputBlockProperties, bool is_internal_block, typename Input, typename PrivateReg, 711 contraction_type contract_tp = contraction_tp> 712 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE 713 typename ::Eigen::internal::enable_if<contract_tp == contraction_type::no_local>::type 714 extract_block(const Input &inpt, PrivateReg private_ptr, const std::pair<StorageIndex, StorageIndex> &, 715 const StorageIndex &ncOffset, const StorageIndex cOffset) { 716 EIGEN_CONSTEXPR StorageIndex LocalThreadSizeNC = 717 InputBlockProperties::is_rhs ? Properties::LocalThreadSizeN : Properties::LocalThreadSizeM; 718 EIGEN_CONSTEXPR StorageIndex WorkLoadPerThreadNC = 719 InputBlockProperties::is_rhs ? Properties::WorkLoadPerThreadN : Properties::WorkLoadPerThreadM; 720 const StorageIndex &NC = InputBlockProperties::is_rhs ? 
triple_dim.N : triple_dim.M; 721 722 auto chk_bound = [&](const StorageIndex &CIndex, const StorageIndex &NCIndex) EIGEN_DEVICE_FUNC { 723 return ((CIndex + InputBlockProperties::c_stride - 1 < triple_dim.K) && 724 (NCIndex + InputBlockProperties::nc_stride - 1 < NC)); 725 }; 726 const StorageIndex ld = InputBlockProperties::is_coalesced_layout ? NC : triple_dim.K; 727 StorageIndex cIndex = cOffset; 728 729 EIGEN_UNROLL_LOOP 730 for (StorageIndex cId = 0; cId < Properties::TileSizeDimK / InputBlockProperties::c_stride; cId++) { 731 StorageIndex ncIndex = ncOffset; 732 EIGEN_UNROLL_LOOP 733 for (StorageIndex ncId = 0; ncId < WorkLoadPerThreadNC / InputBlockProperties::nc_stride; ncId++) { 734 if (check_boundary<is_internal_block>(chk_bound(cIndex, ncIndex))) { 735 auto val = 736 read<InputBlockProperties::packet_load, InputBlockProperties::is_coalesced_layout, 737 InputBlockProperties::is_rhs, typename InputBlockProperties::OutType>(inpt, ncIndex, cIndex, ld); 738 739 write<StorageIndex, (InputBlockProperties::is_coalesced_layout ? 1 : WorkLoadPerThreadNC), 740 data_source::private_mem>(val, private_ptr); 741 } else { 742 EIGEN_UNROLL_LOOP 743 for (StorageIndex i = 0; i < InputBlockProperties::elements_per_access; i++) { 744 const StorageIndex ncInd = ncIndex + (InputBlockProperties::is_coalesced_layout ? i : 0); 745 const StorageIndex cInd = cIndex + (InputBlockProperties::is_coalesced_layout ? 0 : i); 746 OutScalar val = 747 (ncInd < NC && cInd < triple_dim.K) 748 ? read<false, InputBlockProperties::is_coalesced_layout, InputBlockProperties::is_rhs, OutScalar>( 749 inpt, ncInd, cInd, ld) 750 : OutScalar(0); 751 write<StorageIndex, (InputBlockProperties::is_coalesced_layout ? 1 : WorkLoadPerThreadNC), 752 data_source::private_mem>( 753 val, private_ptr + (InputBlockProperties::is_coalesced_layout ? i : 0) + 754 ((InputBlockProperties::is_coalesced_layout ? 
0 : i) * WorkLoadPerThreadNC)); 755 } 756 } 757 758 // if it is lhs we have to load it packetised when the packet size is > 1, because the output is coalesced. So 759 // even if M is not accessed in a coalesced mode, we have to load packet_size number of m per thread. 760 ncIndex = (!InputBlockProperties::is_rhs && InputBlockProperties::nc_stride == 1 && PacketSize != 1) 761 ? ncOffset + (ncId + 1) % PacketSize + ((ncId + 1) / PacketSize) * LocalThreadSizeNC 762 : (ncIndex + InputBlockProperties::nc_stride * LocalThreadSizeNC); 763 private_ptr += InputBlockProperties::nc_stride; 764 } 765 // the previous for loop ( private_ptr += (ncId * nc_stride)) has already moved ptr with one WorkLoadPerThreadNC 766 private_ptr += (InputBlockProperties::c_stride - 1) * WorkLoadPerThreadNC; 767 cIndex += InputBlockProperties::c_stride; 768 } 769 } 770 template <typename InputBlockProperties, StorageIndex TileSizeDimNC> 771 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::pair<StorageIndex, StorageIndex> local_id_extract( 772 const StorageIndex &linearLocalThreadId) { 773 const StorageIndex localThreadNC = 774 (InputBlockProperties::is_coalesced_layout) 775 ? linearLocalThreadId % (TileSizeDimNC / InputBlockProperties::nc_stride) 776 : linearLocalThreadId / (Properties::TileSizeDimK / InputBlockProperties::c_stride); 777 const StorageIndex localThreadC = 778 (InputBlockProperties::is_coalesced_layout) 779 ? 
linearLocalThreadId / (TileSizeDimNC / InputBlockProperties::nc_stride) 780 : linearLocalThreadId % (Properties::TileSizeDimK / InputBlockProperties::c_stride); 781 return std::pair<StorageIndex, StorageIndex>(localThreadNC, localThreadC); 782 } 783 784 template <bool db = Properties::DoubleBuffer, contraction_type ctp = contraction_tp> 785 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE 786 typename ::Eigen::internal::enable_if<db && ctp == contraction_type::local>::type 787 sync_mem(const cl::sycl::nd_item<1> &, bool &db_offset) noexcept { 788 db_offset = !db_offset; 789 } 790 791 template <bool db = Properties::DoubleBuffer, contraction_type ctp = contraction_tp> 792 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE 793 typename ::Eigen::internal::enable_if<!db && ctp == contraction_type::local>::type 794 sync_mem(const cl::sycl::nd_item<1> &itemID, bool &) noexcept { 795 itemID.barrier(cl::sycl::access::fence_space::local_space); 796 } 797 798 template <contraction_type ctp = contraction_tp> 799 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE 800 typename ::Eigen::internal::enable_if<ctp == contraction_type::no_local>::type 801 sync_mem(const cl::sycl::nd_item<1> &, bool &) noexcept { 802 return; 803 } 804 805 template <bool need_sync, contraction_type ctp = contraction_tp> 806 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE 807 typename ::Eigen::internal::enable_if<need_sync && ctp == contraction_type::no_local>::type 808 sync_thread(const cl::sycl::nd_item<1> & 809 #ifdef EIGEN_SYCL_ARM_GPU_CACHE_OPTIMISATION 810 itemID 811 #endif 812 ) noexcept { 813 #ifdef EIGEN_SYCL_ARM_GPU_CACHE_OPTIMISATION 814 itemID.barrier(cl::sycl::access::fence_spacce::local_space); 815 #else 816 return; 817 #endif 818 } 819 template <bool need_sync, contraction_type ctp = contraction_tp> 820 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE 821 typename ::Eigen::internal::enable_if<need_sync && ctp == contraction_type::local>::type 822 sync_thread(const cl::sycl::nd_item<1> &itemID) { 823 
itemID.barrier(cl::sycl::access::fence_space::local_space); 824 } 825 template <bool need_sync> 826 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename ::Eigen::internal::enable_if<!need_sync>::type sync_thread( 827 const cl::sycl::nd_item<1> &) { 828 return; 829 } 830 831 template <bool is_internal_block> 832 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void compute_tile_per_panel(const cl::sycl::nd_item<1> &itemID, 833 ThreadProperties<StorageIndex> &thread_properties, 834 TiledMemory &tiled_input_block, 835 PacketReturnType *privateRes, bool &db_offset) { 836 // Tiling the Rhs block from global to local memory 837 extract_block<RHSBlockProperties, is_internal_block>( 838 rhs, tiled_input_block.rhs_scratch_extract.ptr + (db_offset * Properties::TileSizeDimK * LSDR), 839 tiled_input_block.rhs_extract_index, 840 contraction_tp == contraction_type::local ? thread_properties.nGroupOffset : thread_properties.nGlobalOffset, 841 thread_properties.kGroupOffset - thread_properties.kSize); 842 843 sync_thread<contraction_tp == contraction_type::no_local>(itemID); 844 845 // Tiling the Lhs block from global to local memory 846 extract_block<LHSBlockProperties, is_internal_block>( 847 lhs, tiled_input_block.lhs_scratch_extract.ptr + (db_offset * LSDL * Properties::TileSizeDimK), 848 tiled_input_block.lhs_extract_index, 849 contraction_tp == contraction_type::local ? 
thread_properties.mGroupOffset : thread_properties.mGlobalOffset, 850 thread_properties.kGroupOffset - thread_properties.kSize); 851 852 // itemID.barrier(cl::sycl::access::fence_space::local_space); 853 sync_thread<contraction_tp == contraction_type::local>(itemID); 854 // switch to compute mede 855 StorageIndex lhs_offset = (db_offset * LSDL * Properties::TileSizeDimK); 856 StorageIndex rhs_offset = (db_offset * Properties::TileSizeDimK * LSDR); 857 // Loop over the values of a single tile 858 for (StorageIndex k = 0; k < Properties::TileSizeDimK; k++) { 859 compute_block_per_tile(tiled_input_block.lhs_scratch_ptr_compute + lhs_offset, 860 tiled_input_block.rhs_scratch_ptr_compute + rhs_offset, privateRes); 861 lhs_offset += LSDL; 862 rhs_offset += LSDR; 863 } 864 // computing the K index for the next tile 865 thread_properties.kSize -= Properties::TileSizeDimK; 866 sync_mem(itemID, db_offset); 867 } 868 869 // when local memory is available the following compute_panel will be enabled 870 template <bool is_internal_block, typename OutPtr> 871 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void compute_panel(const cl::sycl::nd_item<1> &itemID, 872 ThreadProperties<StorageIndex> &thread_properties, 873 OutPtr out_ptr) { 874 auto tiled_input_block = TiledMemory{thread_properties, scratch.get_pointer()}; 875 // Allocate register space 876 PacketReturnType privateRes[Properties::WorkLoadPerThreadM * Properties::WorkLoadPerThreadN / PacketSize] = { 877 PacketReturnType{0}}; 878 bool db_offset = 0; 879 880 while (thread_properties.kSize >= Properties::TileSizeDimK) { 881 compute_tile_per_panel<is_internal_block>(itemID, thread_properties, tiled_input_block, privateRes, db_offset); 882 } 883 if (thread_properties.kSize > 0) { 884 compute_tile_per_panel<false>(itemID, thread_properties, tiled_input_block, privateRes, db_offset); 885 } 886 887 // Storing the final results in the output 888 store<is_internal_block, 889 contraction_tp == contraction_type::local ? 
static_cast<StorageIndex>(1) : RHSBlockProperties::nc_stride>( 890 out_ptr + thread_properties.nGlobalOffset * triple_dim.M, privateRes, thread_properties.mGlobalOffset, 891 thread_properties.nGlobalOffset); 892 } 893 // When local memory is available the following extract_block will be enabled 894 template <typename InputBlockProperties, bool is_internal_block, typename Input, typename Local, 895 contraction_type contract_tp = contraction_tp> 896 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE 897 typename ::Eigen::internal::enable_if<contract_tp == contraction_type::local>::type 898 extract_block(const Input &inpt, Local local_ptr, const std::pair<StorageIndex, StorageIndex>& local_index, 899 const StorageIndex &ncOffset, const StorageIndex cOffset) { 900 EIGEN_CONSTEXPR StorageIndex TileSizeDimNC = 901 InputBlockProperties::is_rhs ? Properties::TileSizeDimN : Properties::TileSizeDimM; 902 EIGEN_CONSTEXPR StorageIndex LoadPerThread = 903 InputBlockProperties::is_rhs ? Properties::LoadPerThreadRhs : Properties::LoadPerThreadLhs; 904 EIGEN_CONSTEXPR StorageIndex LSD = InputBlockProperties::is_rhs ? LSDR : LSDL; 905 static_assert(((LocalOffset % (TileSizeDimNC / InputBlockProperties::nc_stride) == 0) && 906 (LocalOffset % (Properties::TileSizeDimK / InputBlockProperties::c_stride) == 0)), 907 " LocalOffset must be divisable by stride"); 908 const StorageIndex &NC = InputBlockProperties::is_rhs ? 
triple_dim.N : triple_dim.M; 909 StorageIndex localThreadNC = local_index.first; 910 StorageIndex localThreadC = local_index.second; 911 auto chk_bound = [&](const StorageIndex &CIndex, const StorageIndex &NCIndex) EIGEN_DEVICE_FUNC { 912 return ((CIndex + InputBlockProperties::c_stride - 1 < triple_dim.K) && 913 (NCIndex + InputBlockProperties::nc_stride - 1 < NC)); 914 }; 915 EIGEN_UNROLL_LOOP 916 for (StorageIndex lPT = 0; lPT < LoadPerThread / InputBlockProperties::elements_per_access; lPT++) { 917 const StorageIndex CIndex = cOffset + (InputBlockProperties::c_stride * localThreadC); 918 const StorageIndex NCIndex = ncOffset + (InputBlockProperties::nc_stride * localThreadNC); 919 const StorageIndex ld = InputBlockProperties::is_coalesced_layout ? NC : triple_dim.K; 920 if (check_boundary<is_internal_block>(chk_bound(CIndex, NCIndex))) { 921 auto val = 922 read<InputBlockProperties::packet_load, InputBlockProperties::is_coalesced_layout, 923 InputBlockProperties::is_rhs, typename InputBlockProperties::OutType>(inpt, NCIndex, CIndex, ld); 924 write<StorageIndex, (InputBlockProperties::is_coalesced_layout ? 1 : LSD), data_source::local_mem>( 925 val, local_ptr + (InputBlockProperties::nc_stride * localThreadNC) + 926 (InputBlockProperties::c_stride * localThreadC * LSD)); 927 } else { 928 EIGEN_UNROLL_LOOP 929 for (StorageIndex i = 0; i < InputBlockProperties::elements_per_access; i++) { 930 const StorageIndex nCInd = NCIndex + (InputBlockProperties::is_coalesced_layout ? i : 0); 931 const StorageIndex cInd = CIndex + (InputBlockProperties::is_coalesced_layout ? 0 : i); 932 OutScalar val = 933 (nCInd < NC && cInd < triple_dim.K) 934 ? read<false, InputBlockProperties::is_coalesced_layout, InputBlockProperties::is_rhs, OutScalar>( 935 inpt, nCInd, cInd, ld) 936 : OutScalar(0); 937 938 write<StorageIndex, (InputBlockProperties::is_coalesced_layout ? 
1 : LSD), data_source::local_mem>( 939 val, local_ptr + (InputBlockProperties::nc_stride * localThreadNC) + 940 (InputBlockProperties::is_coalesced_layout ? i : 0) + 941 ((InputBlockProperties::c_stride * localThreadC + 942 (InputBlockProperties::is_coalesced_layout ? 0 : i)) * 943 LSD)); 944 } 945 } 946 localThreadNC += (InputBlockProperties::is_coalesced_layout) 947 ? LocalOffset % (TileSizeDimNC / InputBlockProperties::nc_stride) 948 : LocalOffset / (Properties::TileSizeDimK / InputBlockProperties::c_stride); 949 localThreadC += (InputBlockProperties::is_coalesced_layout) 950 ? LocalOffset / (TileSizeDimNC / InputBlockProperties::nc_stride) 951 : LocalOffset % (Properties::TileSizeDimK / InputBlockProperties::c_stride); 952 } 953 } 954 }; 955 956 #ifndef EIGEN_SYCL_DISABLE_GEMV 957 958 /*! 959 * \brief GeneralVectorTensor is a template class that provides Tensor -vector contraction operation, which is a special 960 * case of Tensor Tensor contraction. 961 * 962 * \tparam OutScalar: determines the output scalar type 963 * 964 * \tparam OutAccessor: determines the sycl accessor type for out put (please see the sycl-1.2.1 specification 965 * (https://www.khronos.org/registry/SYCL/specs/sycl-1.2.1.pdf) for accessor definition) 966 * 967 * \tparam VectorMapper: determines the tensor contraction mapper for the vector input (can be lhs or rhs) 968 * 969 * \tparam TensorMapper: determines the tensor contraction mapper for the tensor input (can be lhs or rhs) 970 * 971 * \tparam StorageIndex: determines the StorageIndex Type 972 * 973 * \tparam Properties: determines the Contraction Panel properties 974 * 975 * \tparam KFactor: determines the number of elements in K dimension in a Tile 976 * 977 * \tparam Vectorizable: determines whether or not the vectorization is enabled for the Eigen expression. 978 * 979 * \tparam is_lhs_vec: determines whether lhs is a vector or rhs is a vector 980 * 981 * \tparam IsFinal: determine if this is the final kernel. 
 * If so, the result will be written in a final output.
 * Otherwise, the result of contraction will be written in a temporary buffer.
 *
 * \param scratch: determines the local memory containing the vector block for each work-group
 *
 * \param vec: determines the vector input (tensor mapper)
 *
 * \param mat: determines the tensor input (tensor mapper)
 *
 * \param out_res: determines the output vector containing the contraction result
 *
 * \param nonContractGroupSize: a logical number determining the number of work-group for non-contracting dimension
 *
 * \param nonContractDim: determines the size of non contracting dimension for the flattened tensor
 *
 * \param contractDim: determines the size of contracting dimension for the flattened tensor
 *
 */
template <typename OutScalar, typename OutAccessor, typename VectorMapper, typename TensorMapper, typename StorageIndex,
          typename Properties, StorageIndex KFactor, bool Vectorizable, bool is_lhs_vec, bool IsFinal>
struct GeneralVectorTensor {
  typedef typename Eigen::TensorSycl::internal::Vectorise<OutScalar, Eigen::SyclDevice, Vectorizable>::PacketReturnType
      PacketReturnType;
  static EIGEN_CONSTEXPR int PacketSize =
      Eigen::TensorSycl::internal::Vectorise<OutScalar, Eigen::SyclDevice, Vectorizable>::PacketSize;
  typedef cl::sycl::accessor<OutScalar, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local> Scratch;

  // Offset, in elements, from the start of the scratch memory to the region used for the partial results
  // (local_output in operator()); the region before it is where the vector block is loaded by extract_block.
  static EIGEN_CONSTEXPR StorageIndex OutScratchOffset =
      KFactor * Properties::LocalThreadSizeC * Properties::LocalThreadSizeNC;

  // Since the access layout for a vector can always be coalesced, when LHS is a vector, we pass false and false to make
  // sure that the !^ is true When RHS is a vector, we pass true and true to make sure that the !^ is true.
  typedef BlockProperties<is_lhs_vec ? false : true, is_lhs_vec ? false : true, Vectorizable, PacketReturnType>
      VecBlockProperties;

  Scratch scratch;
  const VectorMapper vec;
  const TensorMapper mat;
  OutAccessor out_res;
  const StorageIndex nonContractGroupSize;
  const StorageIndex nonContractDim;
  const StorageIndex contractDim;

  /**
   * \brief Constructs the kernel functor, copying every argument into the corresponding member.
   */
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE GeneralVectorTensor(Scratch scratch_, const VectorMapper vec_,
                                                            const TensorMapper mat_, OutAccessor out_res_,
                                                            const StorageIndex nonContractGroupSize_,
                                                            const StorageIndex nonContractDim_,
                                                            const StorageIndex contractDim_)
      : scratch(scratch_),
        vec(vec_),
        mat(mat_),
        out_res(out_res_),
        nonContractGroupSize(nonContractGroupSize_),
        nonContractDim(nonContractDim_),
        contractDim(contractDim_) {}

  /**
   * \brief Kernel entry point: derives the thread and work-group coordinates in the contracting / non-contracting
   * dimensions from the 1-D nd_item and dispatches to compute_panel.
   *
   * \param itemID: the 1-D nd_item of the calling work-item.
   */
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(cl::sycl::nd_item<1> itemID) {
    auto scratch_ptr = scratch.get_pointer();
    const StorageIndex linearLocalThreadId = itemID.get_local_id(0);
    // Which of the two coordinates is the fastest-varying one depends on is_lhs_vec.
    StorageIndex nonContractId = is_lhs_vec ? linearLocalThreadId / Properties::LocalThreadSizeC
                                            : linearLocalThreadId % Properties::LocalThreadSizeNC;
    StorageIndex contractId = is_lhs_vec ? linearLocalThreadId % Properties::LocalThreadSizeC
                                         : linearLocalThreadId / Properties::LocalThreadSizeNC;
    // Number of work-groups along the contracting dimension.
    const StorageIndex cGroupSize = itemID.get_group_range(0) / nonContractGroupSize;
    const StorageIndex nonContractGroupId =
        is_lhs_vec ? itemID.get_group(0) / cGroupSize : itemID.get_group(0) % nonContractGroupSize;
    const StorageIndex contractGroupId =
        is_lhs_vec ? itemID.get_group(0) % cGroupSize : itemID.get_group(0) / nonContractGroupSize;
    // When the kernel is not final each contracting group writes to its own nonContractDim-sized output slice.
    auto out_ptr = out_res.get_pointer() + (IsFinal ? 0 : contractGroupId * nonContractDim);

    const StorageIndex nonContractGroupOffset = nonContractGroupId * Properties::TileSizeDimNC;
    const StorageIndex contractGroupOffset = contractGroupId * Properties::TileSizeDimC;
    auto outScratchIndex = nonContractId + contractId * Properties::LocalThreadSizeNC;
    const StorageIndex globalNonContractDimOffset = nonContractGroupOffset + nonContractId;
    const StorageIndex globalContractDimOffset = contractGroupOffset + contractId;
    auto local_output = scratch_ptr + OutScratchOffset;
    // A work-group is "internal" when its full tile fits inside both dimensions.
    const bool is_internal = nonContractDim - nonContractGroupOffset >= Properties::TileSizeDimNC &&
                             contractDim - contractGroupOffset >= Properties::TileSizeDimC;
    is_internal
        ? compute_panel<true>(itemID, vec, mat, local_output, out_ptr,
#ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
                              scratch_ptr, contractGroupOffset,
#endif
                              nonContractGroupOffset, linearLocalThreadId, contractDim, nonContractDim, contractId,
                              nonContractId, globalContractDimOffset, globalNonContractDimOffset, outScratchIndex)
        : compute_panel<false>(itemID, vec, mat, local_output, out_ptr,
#ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
                               scratch_ptr, contractGroupOffset,
#endif
                               nonContractGroupOffset, linearLocalThreadId, contractDim, nonContractDim, contractId,
                               nonContractId, globalContractDimOffset, globalNonContractDimOffset, outScratchIndex);
  }
  /**
   * \brief Computes the partial products of one panel per thread, reduces them across the contracting threads of
   * the work-group through local_output, and lets the threads with contractId == 0 write the results to out_ptr.
   *
   * The scratch_ptr and contractGroupOffset parameters only exist when EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON is defined;
   * in that configuration the vector is first staged in scratch_ptr by extract_block, otherwise it is read directly
   * from vec.
   */
  template <bool is_internal_block, typename OutPtr>
  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void compute_panel(
      const cl::sycl::nd_item<1> &itemID, const VectorMapper &vec, const TensorMapper &mat, OutScalar *local_output,
      OutPtr out_ptr,
#ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
      OutScalar *scratch_ptr, const StorageIndex contractGroupOffset,
#endif
      const StorageIndex nonContractGroupOffset, const StorageIndex linearLocalThreadId, StorageIndex contractDim,
      StorageIndex nonContractDim, StorageIndex contractId, StorageIndex nonContractId,
      StorageIndex globalContractDimOffset, StorageIndex globalNonContractDimOffset, StorageIndex outScratchIndex) {
    // Per-thread accumulators, one per non-contracting element handled by this thread.
    OutScalar outScalar[Properties::WorkLoadPerThreadNC] = {OutScalar(0)};
    // Reading the vector
#ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
    const StorageIndex vectorOffset = contractGroupOffset + linearLocalThreadId;
    extract_block<VecBlockProperties, is_internal_block, KFactor,
                  Properties::LocalThreadSizeNC * Properties::LocalThreadSizeC>(vec, scratch_ptr, linearLocalThreadId,
                                                                                vectorOffset, contractDim);

    // The staged vector block must be complete before any thread reads it.
    itemID.barrier(cl::sycl::access::fence_space::local_space);
    auto in_scratch_ptr = scratch_ptr + contractId;
#endif

    StorageIndex privateOffsetC = 0;
    EIGEN_UNROLL_LOOP
    for (StorageIndex i = 0; i < Properties::WorkLoadPerThreadC; i++) {
      StorageIndex privateOffsetNC = 0;
      bool contract_conds = ((globalContractDimOffset + privateOffsetC) < contractDim);
#ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
      auto vecScalar = *in_scratch_ptr;
#else
      auto vecScalar = (check_boundary<is_internal_block>(contract_conds))
                           ? vec(is_lhs_vec ? StorageIndex(0) : globalContractDimOffset + privateOffsetC,
                                 is_lhs_vec ? globalContractDimOffset + privateOffsetC : StorageIndex(0))
                           : OutScalar(0);
#endif
      EIGEN_UNROLL_LOOP
      for (StorageIndex j = 0; j < Properties::WorkLoadPerThreadNC; j++) {
        // Out-of-range matrix elements contribute OutScalar(0) to the accumulation.
        auto matScalar = (check_boundary<is_internal_block>(
                             contract_conds && ((globalNonContractDimOffset + privateOffsetNC) < nonContractDim)))
                             ? mat(is_lhs_vec ? globalContractDimOffset + privateOffsetC
                                              : globalNonContractDimOffset + privateOffsetNC,
                                   is_lhs_vec ? globalNonContractDimOffset + privateOffsetNC
                                              : globalContractDimOffset + privateOffsetC)
                             : OutScalar(0);

        outScalar[j] = cl::sycl::mad(matScalar, vecScalar, outScalar[j]);
        privateOffsetNC += Properties::LocalThreadSizeNC;
      }
      privateOffsetC += Properties::LocalThreadSizeC;
#ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
      in_scratch_ptr += Properties::LocalThreadSizeC;
#endif
    }

    // Publish the per-thread partial results; consecutive j values are one
    // LocalThreadSizeNC * LocalThreadSizeC block apart.
    auto out_scratch_ptr = local_output + outScratchIndex;
    // Each block of 16*16 element in shared memory should reduce to 16*1
    EIGEN_UNROLL_LOOP
    for (StorageIndex j = 0; j < Properties::WorkLoadPerThreadNC; j++) {
      *out_scratch_ptr = outScalar[j];

      out_scratch_ptr += (Properties::LocalThreadSizeNC * Properties::LocalThreadSizeC);
    }
    // For a lhs vector the thread coordinates are recomputed with the non-contracting id as the fastest-varying
    // component before the reduction.
    if (is_lhs_vec) {
      nonContractId = linearLocalThreadId % Properties::LocalThreadSizeNC;
      contractId = linearLocalThreadId / Properties::LocalThreadSizeNC;
      outScratchIndex = nonContractId + contractId * Properties::LocalThreadSizeNC;
    }

    out_scratch_ptr = local_output + outScratchIndex;
    EIGEN_UNROLL_LOOP
    for (StorageIndex j = 0; j < Properties::WorkLoadPerThreadNC; j++) {
      // Reduction over the contracting threads: the offset is halved every step and a barrier precedes each step.
      EIGEN_UNROLL_LOOP
      for (StorageIndex offset = Properties::LocalThreadSizeC >> 1; offset > 0; offset >>= 1) {
        itemID.barrier(cl::sycl::access::fence_space::local_space);
        if (contractId < offset) {
          StorageIndex myNeigbourId = (Properties::LocalThreadSizeNC * offset);
          *out_scratch_ptr += out_scratch_ptr[myNeigbourId];
        }
      }
      // moving to next 16 by 16 block
      out_scratch_ptr += (Properties::LocalThreadSizeNC * Properties::LocalThreadSizeC);
    }

    // Only the threads with contractId == 0 hold the reduced values and write them out.
    if (contractId == 0) {
      out_scratch_ptr = local_output + nonContractId;
      StorageIndex global_final_offset = nonContractGroupOffset + nonContractId;
      out_ptr += global_final_offset;
      EIGEN_UNROLL_LOOP
      for (StorageIndex j = 0; j < Properties::WorkLoadPerThreadNC; j++) {
        if (check_boundary<is_internal_block>(global_final_offset < nonContractDim)) {
          auto res = *out_scratch_ptr;

          *out_ptr = res;
          out_ptr += Properties::LocalThreadSizeNC;
        }
        // moving to next 16 by 16 block to get the next 16 reduced elements
        out_scratch_ptr += (Properties::LocalThreadSizeNC * Properties::LocalThreadSizeC);
        // The running offset is only tracked for non-internal blocks, where it is tested above.
        if (!(is_internal_block)) global_final_offset += Properties::LocalThreadSizeNC;
      }
    }
  }

  /**
   * \brief Loads the block of the vector input starting at cOffset into local_ptr.
   *
   * Each work-item performs CFactor / c_stride accesses, GroupSize accesses apart. Elements at or beyond C are
   * written as OutScalar(0).
   *
   * \param inpt: the vector input mapper.
   * \param local_ptr: destination; offset internally by c_stride * linearLocalThreadId.
   * \param linearLocalThreadId: linear local id of the calling work-item.
   * \param cOffset: first contracting index read by this work-item.
   * \param C: size of the contracting dimension.
   */
  template <typename InputBlockProperties, bool is_internal_block, int CFactor, int GroupSize, typename Input,
            typename Local>
  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void extract_block(const Input &inpt, Local *local_ptr,
                                                                  const StorageIndex &linearLocalThreadId,
                                                                  const StorageIndex &cOffset, const StorageIndex &C) {
    local_ptr += InputBlockProperties::c_stride * linearLocalThreadId;
    StorageIndex cIndex = cOffset;
    for (StorageIndex cId = 0; cId < CFactor / InputBlockProperties::c_stride; cId++) {
      if (check_boundary<is_internal_block>(cIndex + InputBlockProperties::c_stride - 1 < C)) {
        auto val = read<InputBlockProperties::packet_load, InputBlockProperties::is_coalesced_layout,
                        InputBlockProperties::is_rhs, typename InputBlockProperties::OutType>(inpt, StorageIndex(0),
                                                                                              cIndex, StorageIndex(1));
        write<StorageIndex, 1, data_source::local_mem>(val, local_ptr);
      } else {
        // Boundary access: load scalar by scalar and zero-fill whatever is out of range.
        EIGEN_UNROLL_LOOP
        for (StorageIndex i = 0; i < InputBlockProperties::elements_per_access; i++) {
          OutScalar val =
              (cIndex + i < C)
                  ? read<false, InputBlockProperties::is_coalesced_layout, InputBlockProperties::is_rhs, OutScalar>(
                        inpt, StorageIndex(0), cIndex + i, StorageIndex(1))
                  : OutScalar(0);
          write<StorageIndex, 1, data_source::local_mem>(val, local_ptr + i);
        }
      }
      local_ptr += InputBlockProperties::c_stride * GroupSize;
      cIndex += InputBlockProperties::c_stride * GroupSize;
    }
  }
};
#endif

#ifndef EIGEN_SYCL_DISABLE_SCALAR

/*!
 * \brief GeneralScalarContraction is a template class that provides the scalar value of Tensor -Tensor contraction
 * operation, when all the dimensions are contracting dimensions. This Kernel reduces two tensors to an scalar
 *
 * \tparam OutScalar: determines the output scalar type
 *
 * \tparam LhsScalar: determines the left-hand-side scalar type
 *
 * \tparam RhsScalar: determines the right-hand-side scalar type
 *
 * \tparam OutAccessor: determines the sycl accessor type for out put (please see the sycl-1.2.1 specification
 * (https://www.khronos.org/registry/SYCL/specs/sycl-1.2.1.pdf) for accessor definition)
 *
 * \tparam LhsMapper: determines the tensor contraction mapper type for left-hand-side matrix
 *
 * \tparam RhsMapper: determines the tensor contraction mapper type for right-hand-side matrix
 *
 * \tparam StorageIndex: determines the StorageIndex Type
 *
 * \tparam Vectorizable: determines whether or not the vectorization is enabled for the Eigen expression.
1228 * 1229 * \param scratch: local memory containing tiles of LHS and RHS tensors for each work-group 1230 * 1231 * \param lhs: determines the left-hand-side flattened tensor (tensor mapper) 1232 * 1233 * \param rhs: determines the right-hand-side flattened tensor (tensor mapper) 1234 * 1235 * \param out_res: determines the output tensor containing the contraction result 1236 * 1237 * \param rng: determins the total input data size 1238 */ 1239 template <typename OutScalar, typename LhsScalar, typename RhsScalar, typename OutAccessor, typename LhsMapper, 1240 typename RhsMapper, typename StorageIndex, bool Vectorizable> 1241 struct GeneralScalarContraction { 1242 typedef cl::sycl::accessor<OutScalar, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local> Scratch; 1243 Scratch scratch; 1244 const LhsMapper lhs; 1245 const RhsMapper rhs; 1246 OutAccessor out_res; 1247 const StorageIndex rng; 1248 1249 EIGEN_DEVICE_FUNC 1250 GeneralScalarContraction(Scratch scratch_, const LhsMapper lhs_, const RhsMapper rhs_, OutAccessor out_res_, 1251 const StorageIndex rng_) 1252 : scratch(scratch_), lhs(lhs_), rhs(rhs_), out_res(out_res_), rng(rng_) {} 1253 1254 EIGEN_DEVICE_FUNC void operator()(cl::sycl::nd_item<1> itemID) { 1255 auto out_ptr = out_res.get_pointer(); 1256 auto scratch_ptr = scratch.get_pointer().get(); 1257 1258 StorageIndex globalid = itemID.get_global_id(0); 1259 StorageIndex localid = itemID.get_local_id(0); 1260 OutScalar accumulator = OutScalar(0); 1261 for (StorageIndex i = globalid; i < rng; i += itemID.get_global_range(0)) { 1262 accumulator = cl::sycl::mad(lhs(0, i), rhs(i, 0), accumulator); 1263 } 1264 auto out_scratch_ptr = scratch_ptr + localid; 1265 *out_scratch_ptr = accumulator; 1266 for (StorageIndex offset = itemID.get_local_range(0) >> 1; offset > 0; offset >>= 1) { 1267 itemID.barrier(cl::sycl::access::fence_space::local_space); 1268 if (localid < offset) { 1269 *out_scratch_ptr = (accumulator += out_scratch_ptr[offset]); 1270 
} 1271 } 1272 if (localid == 0) { 1273 out_ptr[itemID.get_group(0)] = accumulator; 1274 } 1275 } 1276 }; 1277 #endif 1278 1279 } // namespace internal 1280 } // namespace TensorSycl 1281 1282 template <typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType> 1283 struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, 1284 Eigen::SyclDevice> 1285 : public TensorContractionEvaluatorBase<TensorEvaluator< 1286 const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Eigen::SyclDevice>> { 1287 static_assert(std::is_same<OutputKernelType, const NoOpOutputKernel>::value, 1288 "SYCL tensor contraction does not support output kernels."); 1289 1290 typedef Eigen::SyclDevice Device; 1291 1292 typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self; 1293 typedef TensorContractionEvaluatorBase<Self> Base; 1294 typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType; 1295 typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar; 1296 typedef typename XprType::Index StorageIndex; 1297 typedef typename XprType::CoeffReturnType CoeffReturnType; 1298 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType; 1299 typedef typename Base::Storage Storage; 1300 typedef typename Base::EvaluatorPointerType EvaluatorPointerType; 1301 struct TripleDim { 1302 const StorageIndex M; 1303 const StorageIndex N; 1304 const StorageIndex K; 1305 TripleDim(const StorageIndex M_, const StorageIndex N_, const StorageIndex K_) : M(M_), N(N_), K(K_) {} 1306 }; 1307 enum { 1308 Layout = TensorEvaluator<LeftArgType, Device>::Layout, 1309 PacketAccess = (PacketType<CoeffReturnType, Device>::size > 1), 1310 BlockAccess = false, 1311 }; 1312 1313 static EIGEN_CONSTEXPR int LDims = Base::LDims; 1314 static EIGEN_CONSTEXPR int RDims = Base::RDims; 1315 static 
EIGEN_CONSTEXPR int ContractDims = Base::ContractDims; 1316 1317 typedef array<StorageIndex, LDims> left_dim_mapper_t; 1318 typedef array<StorageIndex, RDims> right_dim_mapper_t; 1319 1320 typedef array<StorageIndex, ContractDims> contract_t; 1321 typedef array<StorageIndex, LDims - ContractDims> left_nocontract_t; 1322 typedef array<StorageIndex, RDims - ContractDims> right_nocontract_t; 1323 1324 static const int NumDims = LDims + RDims - 2 * ContractDims; 1325 1326 typedef DSizes<StorageIndex, NumDims> Dimensions; 1327 1328 typedef TensorEvaluator<typename Base::EvalLeftArgType, Device> LeftEvaluator; 1329 typedef TensorEvaluator<typename Base::EvalRightArgType, Device> RightEvaluator; 1330 typedef typename Eigen::internal::remove_const<typename LeftEvaluator::CoeffReturnType>::type LhsScalar; 1331 typedef typename Eigen::internal::remove_const<typename RightEvaluator::CoeffReturnType>::type RhsScalar; 1332 1333 typedef typename LeftEvaluator::Dimensions LeftDimensions; 1334 typedef typename RightEvaluator::Dimensions RightDimensions; 1335 1336 template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered> 1337 struct input_mapper_propertis { 1338 static EIGEN_CONSTEXPR bool is_lhs_matrix = (LDims == 2 && ContractDims == 1) || lhs_inner_dim_contiguous; 1339 static EIGEN_CONSTEXPR bool is_rhs_matrix = 1340 (RDims == 2 && ContractDims == 1) || (rhs_inner_dim_contiguous && !rhs_inner_dim_reordered); 1341 }; 1342 1343 TensorEvaluator(const XprType &op, const Device &device) : Base(op, device) {} 1344 1345 // We need to redefine this method to make nvcc happy 1346 EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(typename Base::EvaluatorPointerType data) { 1347 this->m_leftImpl.evalSubExprsIfNeeded(NULL); 1348 this->m_rightImpl.evalSubExprsIfNeeded(NULL); 1349 if (!data) { 1350 this->m_result = this->m_device.get( 1351 static_cast<Scalar *>(this->m_device.allocate_temp(this->dimensions().TotalSize() * sizeof(Scalar)))); 1352 data = 
this->m_result; 1353 } 1354 evalToSycl(data); 1355 return (this->m_result != NULL); 1356 } 1357 const Eigen::SyclDevice &device() const { return this->m_device; } 1358 void evalToSycl(typename Base::EvaluatorPointerType buffer) const { 1359 if (this->m_lhs_inner_dim_contiguous) { 1360 if (this->m_rhs_inner_dim_contiguous) { 1361 if (this->m_rhs_inner_dim_reordered) { 1362 evalTyped<true, true, true, Unaligned>(buffer); 1363 } else { 1364 evalTyped<true, true, false, Unaligned>(buffer); 1365 } 1366 } else { 1367 if (this->m_rhs_inner_dim_reordered) { 1368 evalTyped<true, false, true, Unaligned>(buffer); 1369 } else { 1370 evalTyped<true, false, false, Unaligned>(buffer); 1371 } 1372 } 1373 } else { 1374 if (this->m_rhs_inner_dim_contiguous) { 1375 if (this->m_rhs_inner_dim_reordered) { 1376 evalTyped<false, true, true, Unaligned>(buffer); 1377 } else { 1378 evalTyped<false, true, false, Unaligned>(buffer); 1379 } 1380 } else { 1381 if (this->m_rhs_inner_dim_reordered) { 1382 evalTyped<false, false, true, Unaligned>(buffer); 1383 } else { 1384 evalTyped<false, false, false, Unaligned>(buffer); 1385 } 1386 } 1387 } 1388 } 1389 1390 template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment> 1391 void evalTyped(typename Base::EvaluatorPointerType buffer) const { 1392 const auto triple_dim = TripleDim{this->m_i_size, this->m_j_size, this->m_k_size}; 1393 typedef internal::TensorContractionInputMapper< 1394 LhsScalar, StorageIndex, internal::Lhs, LeftEvaluator, left_nocontract_t, contract_t, 1395 PacketType<CoeffReturnType, Device>::size, lhs_inner_dim_contiguous, false, Unaligned, MakeSYCLPointer> 1396 LhsMapper; 1397 1398 typedef internal::TensorContractionInputMapper<RhsScalar, StorageIndex, internal::Rhs, RightEvaluator, 1399 right_nocontract_t, contract_t, 1400 PacketType<CoeffReturnType, Device>::size, rhs_inner_dim_contiguous, 1401 rhs_inner_dim_reordered, Unaligned, MakeSYCLPointer> 1402 RhsMapper; 1403 
1404 // initialize data mappers 1405 LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides, 1406 this->m_left_contracting_strides, this->m_k_strides); 1407 1408 RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides, 1409 this->m_right_contracting_strides, this->m_k_strides); 1410 1411 #ifndef EIGEN_SYCL_DISABLE_SCALAR 1412 if (triple_dim.M == 1 && triple_dim.N == 1) { 1413 launchSC(buffer, lhs, rhs, triple_dim.K); 1414 } else 1415 #endif 1416 #ifndef EIGEN_SYCL_DISABLE_GEMV 1417 if (triple_dim.M != 1 && triple_dim.N == 1) { 1418 LaunchVT<false>(buffer, rhs, lhs, triple_dim.M, triple_dim.K); 1419 } else if (triple_dim.M == 1 && triple_dim.N != 1) { 1420 LaunchVT<true>(buffer, lhs, rhs, triple_dim.N, triple_dim.K); 1421 } else // This is equivalent of if (m!=1 && n!=1) 1422 #endif 1423 { 1424 typedef input_mapper_propertis<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered> 1425 inpt_mapper_properties; 1426 #ifndef EIGEN_SYCL_DISABLE_SKINNY 1427 bool skinny = false; 1428 auto platform_name = this->device().getPlatformName(); 1429 // This is based on empirical calculation for AMD r9-nano and Fiji 1430 if (platform_name.find("AMD") == 0) { 1431 skinny = (triple_dim.M < triple_dim.K || triple_dim.N < triple_dim.K) && 1432 ((triple_dim.M < 1024 && triple_dim.N < 1024) || 1433 (uint64_t(triple_dim.M * triple_dim.N) < uint64_t(triple_dim.K))); 1434 } else { 1435 skinny = (((std::max(triple_dim.K, triple_dim.N) / std::min(triple_dim.K, triple_dim.N)) > 100) || 1436 ((std::max(triple_dim.K, triple_dim.M) / std::min(triple_dim.K, triple_dim.M)) > 100) || 1437 ((std::max(triple_dim.N, triple_dim.M) / std::min(triple_dim.N, triple_dim.M)) > 100)); 1438 } 1439 if (skinny) 1440 adjustTT<true, inpt_mapper_properties>(buffer, lhs, rhs, triple_dim); 1441 else 1442 #endif // EIGEN_SYCL_DISABLE_SKINNY 1443 adjustTT<false, inpt_mapper_properties>(buffer, lhs, rhs, triple_dim); 1444 } 1445 } 1446 1447 
template <bool skinny, typename input_mapper_properties, typename LhsMapper, typename RhsMapper> 1448 void EIGEN_ALWAYS_INLINE adjustTT(EvaluatorPointerType buffer, const LhsMapper &lhs, const RhsMapper &rhs, 1449 const TripleDim &triple_dim) const { 1450 #ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON 1451 if (device().has_local_memory()) { 1452 typedef TensorSycl::internal::TTPanelSize<CoeffReturnType, StorageIndex, 4, 4, 16> PanelParameters; 1453 launchTT<TensorSycl::internal::contraction_type::local, skinny, input_mapper_properties, PanelParameters>( 1454 buffer, lhs, rhs, triple_dim); 1455 } 1456 #endif 1457 #ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_OFF 1458 if (!(device().has_local_memory())) { 1459 typedef TensorSycl::internal::TTPanelSize<CoeffReturnType, StorageIndex, 4, 4, 4> PanelParameters; 1460 launchTT<TensorSycl::internal::contraction_type::no_local, skinny, input_mapper_properties, PanelParameters>( 1461 buffer, lhs, rhs, triple_dim); 1462 } 1463 #endif 1464 } 1465 1466 template <TensorSycl::internal::contraction_type ct, bool skinny, typename input_mapper_properties, 1467 typename Properties, typename LhsMapper, typename RhsMapper> 1468 void launchTT(EvaluatorPointerType buffer, const LhsMapper &lhs, const RhsMapper &rhs, 1469 const TripleDim &triple_dim) const { 1470 const StorageIndex roundUpM = Eigen::TensorSycl::internal::roundUp(triple_dim.M, Properties::TileSizeDimM); 1471 const StorageIndex roundUpN = Eigen::TensorSycl::internal::roundUp(triple_dim.N, Properties::TileSizeDimN); 1472 const StorageIndex groupSizeM = roundUpM / Properties::TileSizeDimM; 1473 const StorageIndex groupSizeN = roundUpN / Properties::TileSizeDimN; 1474 1475 const StorageIndex roundUpK = Eigen::TensorSycl::internal::roundUp(triple_dim.K, Properties::TileSizeDimK); 1476 StorageIndex totalTilesK = roundUpK / Properties::TileSizeDimK; 1477 StorageIndex groupSizeK = 1478 skinny 1479 ? 
std::max(std::min(totalTilesK, 1480 (StorageIndex)(device().getPowerOfTwo(device().getNumSyclMultiProcessors(), true) * 4) / 1481 (groupSizeM * groupSizeN)), 1482 StorageIndex(1)) 1483 : StorageIndex(1); 1484 1485 const StorageIndex numTilesPerGroup = Eigen::TensorSycl::internal::roundUp(totalTilesK, groupSizeK) / groupSizeK; 1486 1487 const StorageIndex totalGroupSize = groupSizeM * groupSizeN * groupSizeK; 1488 1489 const StorageIndex localRange = Properties::LocalThreadSizeM * Properties::LocalThreadSizeN; 1490 const StorageIndex globalRange = totalGroupSize * localRange; 1491 1492 const StorageIndex scratchSize = (ct == TensorSycl::internal::contraction_type::local) 1493 ? ((Properties::DoubleBuffer + 1) * 1494 (Properties::TileSizeDimM + Properties::BC) * (Properties::TileSizeDimK)) + 1495 ((Properties::DoubleBuffer + 1) * (Properties::TileSizeDimK) * 1496 (Properties::TileSizeDimN + Properties::BC)) 1497 : StorageIndex(1); 1498 1499 auto thread_range = cl::sycl::nd_range<1>(cl::sycl::range<1>(globalRange), cl::sycl::range<1>(localRange)); 1500 if (groupSizeK == 1) { 1501 typedef TensorSycl::internal::TensorContractionKernel<CoeffReturnType, LhsScalar, RhsScalar, EvaluatorPointerType, 1502 LhsMapper, RhsMapper, StorageIndex, Properties, TripleDim, 1503 PacketAccess, input_mapper_properties, true, ct> 1504 ContractKernelName; 1505 device().template binary_kernel_launcher<CoeffReturnType, ContractKernelName>( 1506 lhs, rhs, buffer, thread_range, scratchSize, groupSizeM, groupSizeN, numTilesPerGroup, triple_dim); 1507 } else { 1508 typedef TensorSycl::internal::TensorContractionKernel<CoeffReturnType, LhsScalar, RhsScalar, EvaluatorPointerType, 1509 LhsMapper, RhsMapper, StorageIndex, Properties, TripleDim, 1510 PacketAccess, input_mapper_properties, false, ct> 1511 ContractKernelName; 1512 CoeffReturnType *temp_pointer = static_cast<CoeffReturnType *>( 1513 device().allocate_temp(triple_dim.M * triple_dim.N * groupSizeK * sizeof(CoeffReturnType))); 1514 
EvaluatorPointerType tmp_global_accessor = device().get(temp_pointer); 1515 1516 device().template binary_kernel_launcher<CoeffReturnType, ContractKernelName>( 1517 lhs, rhs, tmp_global_accessor, thread_range, scratchSize, groupSizeM, groupSizeN, numTilesPerGroup, 1518 triple_dim); 1519 1520 typedef Eigen::internal::SumReducer<CoeffReturnType> Op; 1521 auto op = Op(); 1522 typedef TensorSycl::internal::SecondStepPartialReduction<CoeffReturnType, StorageIndex, EvaluatorPointerType, 1523 EvaluatorPointerType, Op> 1524 ReductionKernel; 1525 1526 device().template unary_kernel_launcher<CoeffReturnType, ReductionKernel>( 1527 tmp_global_accessor, buffer, 1528 cl::sycl::nd_range<1>(cl::sycl::range<1>(StorageIndex( 1529 Eigen::TensorSycl::internal::roundUp(triple_dim.M * triple_dim.N, localRange))), 1530 cl::sycl::range<1>(localRange)), 1531 StorageIndex(1), op, StorageIndex(triple_dim.M * triple_dim.N), groupSizeK); 1532 1533 device().deallocate_temp(temp_pointer); 1534 } 1535 } 1536 1537 #ifndef EIGEN_SYCL_DISABLE_GEMV 1538 template <bool is_lhs_vec, typename VectorMapper, typename TensorMapper, typename StorageIndex> 1539 void EIGEN_ALWAYS_INLINE LaunchVT(EvaluatorPointerType buffer, const VectorMapper &vec, const TensorMapper &mat, 1540 StorageIndex NC, StorageIndex C) const { 1541 const StorageIndex nonContractDim = NC; 1542 EIGEN_CONSTEXPR StorageIndex NCFactor = 1; 1543 EIGEN_CONSTEXPR StorageIndex CFactor = 1; 1544 EIGEN_CONSTEXPR StorageIndex NCWindow = 16; 1545 typedef Eigen::TensorSycl::internal::TVPanelSize<CoeffReturnType, StorageIndex, NCWindow, CFactor, NCFactor> 1546 Properties; 1547 const StorageIndex roundUpC = Eigen::TensorSycl::internal::roundUp(C, Properties::TileSizeDimC); 1548 const StorageIndex cNumGroups = roundUpC / (Properties::LocalThreadSizeC * Properties::WorkLoadPerThreadC); 1549 const StorageIndex roundUpNC = Eigen::TensorSycl::internal::roundUp(nonContractDim, Properties::TileSizeDimNC); 1550 const StorageIndex nCNumGroups = roundUpNC / 
(Properties::LocalThreadSizeNC * Properties::WorkLoadPerThreadNC); 1551 const StorageIndex globalRange = 1552 (roundUpNC / (Properties::WorkLoadPerThreadNC)) * (roundUpC / (Properties::WorkLoadPerThreadC)); 1553 const StorageIndex localRange = Properties::LocalThreadSizeNC * Properties::LocalThreadSizeC; 1554 const StorageIndex scratchSize = 1555 (Properties::WorkLoadPerThreadNC + CFactor) * Properties::LocalThreadSizeC * Properties::LocalThreadSizeNC; 1556 auto thread_range = cl::sycl::nd_range<1>(cl::sycl::range<1>(globalRange), cl::sycl::range<1>(localRange)); 1557 if (cNumGroups > 1) { 1558 typedef Eigen::TensorSycl::internal::GeneralVectorTensor<CoeffReturnType, EvaluatorPointerType, VectorMapper, 1559 TensorMapper, StorageIndex, Properties, CFactor, false, 1560 is_lhs_vec, false> 1561 ContractKernelName; 1562 CoeffReturnType *temp_pointer = 1563 static_cast<CoeffReturnType *>(device().allocate_temp(nonContractDim * cNumGroups * sizeof(CoeffReturnType))); 1564 EvaluatorPointerType tmp_global_accessor = device().get(temp_pointer); 1565 1566 device().template binary_kernel_launcher<CoeffReturnType, ContractKernelName>( 1567 vec, mat, tmp_global_accessor, thread_range, scratchSize, nCNumGroups, nonContractDim, C); 1568 1569 typedef Eigen::internal::SumReducer<CoeffReturnType> Op; 1570 typedef TensorSycl::internal::SecondStepPartialReduction<CoeffReturnType, StorageIndex, EvaluatorPointerType, 1571 EvaluatorPointerType, Op> 1572 ReductionKernel; 1573 1574 device().template unary_kernel_launcher<CoeffReturnType, ReductionKernel>( 1575 tmp_global_accessor, buffer, 1576 cl::sycl::nd_range<1>(cl::sycl::range<1>(Eigen::TensorSycl::internal::roundUp(nonContractDim, localRange)), 1577 cl::sycl::range<1>(localRange)), 1578 StorageIndex(1), Op(), nonContractDim, cNumGroups); 1579 1580 device().deallocate_temp(temp_pointer); 1581 } else { 1582 typedef Eigen::TensorSycl::internal::GeneralVectorTensor<CoeffReturnType, EvaluatorPointerType, VectorMapper, 1583 TensorMapper, 
StorageIndex, Properties, CFactor, false, 1584 is_lhs_vec, true> 1585 ContractKernelName; 1586 device().template binary_kernel_launcher<CoeffReturnType, ContractKernelName>( 1587 vec, mat, buffer, thread_range, scratchSize, nCNumGroups, nonContractDim, C); 1588 } 1589 } 1590 #endif 1591 1592 #ifndef EIGEN_SYCL_DISABLE_SCALAR 1593 template <typename LhsMapper, typename RhsMapper> 1594 EIGEN_ALWAYS_INLINE void launchSC(EvaluatorPointerType buffer, const LhsMapper &lhs, const RhsMapper &rhs, 1595 StorageIndex K) const { 1596 EIGEN_STATIC_ASSERT(!((EIGEN_SYCL_LOCAL_THREAD_DIM0 * EIGEN_SYCL_LOCAL_THREAD_DIM1) & 1597 (EIGEN_SYCL_LOCAL_THREAD_DIM0 * EIGEN_SYCL_LOCAL_THREAD_DIM1 - 1)), 1598 "The Local thread size must be a power of 2 for the reduction " 1599 "operation"); 1600 EIGEN_CONSTEXPR StorageIndex local_range = EIGEN_SYCL_LOCAL_THREAD_DIM0 * EIGEN_SYCL_LOCAL_THREAD_DIM1; 1601 1602 // Here we force the code not to be more than 2-step reduction: Our empirical research shows that if each thread 1603 // reduces at least 512 elementss individually, we get better performance. 1604 const StorageIndex num_work_group = ((K + (512 * local_range - 1)) / (512 * local_range) > 1 ? 
local_range : 1); 1605 const StorageIndex global_range = num_work_group * local_range; 1606 1607 typedef Eigen::TensorSycl::internal::GeneralScalarContraction< 1608 CoeffReturnType, LhsScalar, RhsScalar, EvaluatorPointerType, LhsMapper, RhsMapper, StorageIndex, false> 1609 ContractKernelName; 1610 auto thread_range = cl::sycl::nd_range<1>(cl::sycl::range<1>(global_range), cl::sycl::range<1>(local_range)); 1611 if (num_work_group > 1) { 1612 CoeffReturnType *temp_pointer = 1613 static_cast<CoeffReturnType *>(device().allocate_temp(num_work_group * sizeof(CoeffReturnType))); 1614 EvaluatorPointerType tmp_global_accessor = device().get(temp_pointer); 1615 device().template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(lhs, rhs, tmp_global_accessor, 1616 thread_range, local_range, K); 1617 typedef Eigen::internal::SumReducer<CoeffReturnType> Op; 1618 typedef TensorSycl::internal::SecondStepFullReducer<CoeffReturnType, Op, EvaluatorPointerType, 1619 EvaluatorPointerType, StorageIndex, local_range> 1620 GenericRKernel; 1621 device().template unary_kernel_launcher<CoeffReturnType, GenericRKernel>( 1622 tmp_global_accessor, buffer, 1623 cl::sycl::nd_range<1>(cl::sycl::range<1>(local_range), cl::sycl::range<1>(local_range)), local_range, Op()); 1624 1625 device().deallocate_temp(temp_pointer); 1626 } else { 1627 device().template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(lhs, rhs, buffer, thread_range, 1628 local_range, K); 1629 } 1630 } 1631 #endif 1632 1633 EIGEN_STRONG_INLINE void cleanup() { 1634 this->m_leftImpl.cleanup(); 1635 this->m_rightImpl.cleanup(); 1636 1637 if (this->m_result) { 1638 this->m_device.deallocate_temp(this->m_result); 1639 this->m_result = NULL; 1640 } 1641 } 1642 // The placeholder accessors must bound to a command group handler for SYCL 1643 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const { 1644 this->m_leftImpl.bind(cgh); 1645 this->m_rightImpl.bind(cgh); 1646 
this->m_result.bind(cgh); 1647 } 1648 }; 1649 } // namespace Eigen 1650 #endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_SYCL_H