1 // This file is part of Eigen, a lightweight C++ template library for linear algebra.
2 //
3 // Mehdi Goli    Codeplay Software Ltd.
4 // Ralph Potter  Codeplay Software Ltd.
5 // Luke Iwanski  Codeplay Software Ltd.
6 // Contact: <[email protected]>
7 //
8 // This Source Code Form is subject to the terms of the Mozilla Public License v. 2.0. If a copy of the MPL was not
9 // distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
10 
11 /*****************************************************************
12  * TensorContractionSycl.h
13  *
14  * \brief:
15  *  TensorContractionSycl.h, provides various tensor contraction kernels for the SYCL backend
16  *
17  *****************************************************************/
18 
19 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_SYCL_H
20 #define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_SYCL_H
21 
22 namespace Eigen {
23 
24 namespace TensorSycl {
25 namespace internal {
26 
27 #ifndef EIGEN_SYCL_DISABLE_GEMV
28 /*!
29  * \brief TVPanelSize, a template class used for setting the panel size required for launching General TensorVector
30  * contraction kernel on various hardware devices.
31  *
32  * \tparam Scalar: determines the element type of the tensor/vector
33  *
34  * \tparam StorageIndex  determines the Index type.
35  *
36  * \tparam NCWindow: determines the number of non-contracting elements to be processed by each work-group
37  *
38  * \tparam CFactor: determines the number of contracting elements to be processed by each thread
39  *
40  * \tparam NCFactor: determines the number of non-contracting elements to be processed by each thread
41  */
42 template <typename Scalar, typename StorageIndex, StorageIndex NCWindow, StorageIndex CFactor, StorageIndex NCFactor>
43 struct TVPanelSize {
44   // LocalThreadSizeC: determines the total number of threads per workgroup for the contracting dimension
45   static EIGEN_CONSTEXPR StorageIndex LocalThreadSizeC = EIGEN_SYCL_LOCAL_THREAD_DIM0;
46   // LocalThreadSizeNC: determines the total number of threads per workgroup for the non-contracting dimension
47   static EIGEN_CONSTEXPR StorageIndex LocalThreadSizeNC = EIGEN_SYCL_LOCAL_THREAD_DIM1;
48   // TileSizeDimNC: determines the tile size for the non-contracting dimension
49   static EIGEN_CONSTEXPR StorageIndex TileSizeDimNC = NCWindow / NCFactor;
50   // TileSizeDimC: determines the tile size for the contracting dimension
51   static EIGEN_CONSTEXPR StorageIndex TileSizeDimC = CFactor * LocalThreadSizeNC * LocalThreadSizeC;
52   // WorkLoadPerThreadNC : determines workload per thread for loading the non-contracting dimension
53   static EIGEN_CONSTEXPR StorageIndex WorkLoadPerThreadNC = TileSizeDimNC / LocalThreadSizeNC;
54   // WorkLoadPerThreadC: determines workload per thread for loading the contracting dimension
55   static EIGEN_CONSTEXPR StorageIndex WorkLoadPerThreadC = TileSizeDimC / LocalThreadSizeC;
56   // BC : determines if supporting bank conflict is required
57   static EIGEN_CONSTEXPR bool BC = false;
58 };
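// Illustrative only (not part of the original code): a hypothetical instantiation showing how the panel
// constants are derived, assuming EIGEN_SYCL_LOCAL_THREAD_DIM0 == EIGEN_SYCL_LOCAL_THREAD_DIM1 == 16:
//   using Panel = TVPanelSize<float, int, /*NCWindow=*/64, /*CFactor=*/4, /*NCFactor=*/4>;
//   // Panel::TileSizeDimNC       == 64 / 4        == 16
//   // Panel::TileSizeDimC        == 4 * 16 * 16   == 1024
//   // Panel::WorkLoadPerThreadNC == 16 / 16       == 1
//   // Panel::WorkLoadPerThreadC  == 1024 / 16     == 64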
59 #endif
60 
61 /*!
62  * \brief TTPanelSize, a template class used for setting the panel size required for launching General Tensor Tensor
63  contraction kernel on various hardware devices.
64  *
65  * \tparam Scalar: determines the element type of the tensor
66  *
67  * \tparam StorageIndex: determines the Index type.
68  *
69  * \tparam REG_SIZE_M: determines the workload per thread for loading the M dimension. This can be varied based on the
70  available registers on a chosen device (can be controlled by the EIGEN_SYCL_REG_M macro).
71  *
72  * \tparam REG_SIZE_N: determines the workload per thread for loading the N dimension. This can be varied based on the
73  available registers on a chosen device (can be controlled by the EIGEN_SYCL_REG_N macro).
74  *
75  * \tparam TSDK: determines the tile size for the K dimension. It is assumed to already account for the packet size.
76  */
77 
78 template <typename Scalar, typename StorageIndex, StorageIndex REG_SIZE_M, StorageIndex REG_SIZE_N, StorageIndex TSDK>
79 struct TTPanelSize {
80   // TileSizeDimK: determines the tile size for the K dimension. It is assumed to already account for the packet size.
81   static EIGEN_CONSTEXPR StorageIndex TileSizeDimK = TSDK;
82   // WorkLoadPerThreadM : determines the workload per thread for loading the M dimension. This can be varied based on the
83   // available registers on a chosen device (can be controlled by the EIGEN_SYCL_REG_M macro).
84 #ifndef EIGEN_SYCL_REG_M
85   static EIGEN_CONSTEXPR StorageIndex WorkLoadPerThreadM = REG_SIZE_M;
86 #else
87   static EIGEN_CONSTEXPR StorageIndex WorkLoadPerThreadM = EIGEN_SYCL_REG_M;
88 #endif
89 // WorkLoadPerThreadN : determines the workload per thread for loading the N dimension. This can be varied based on the
90 // available registers on a chosen device (can be controlled by the EIGEN_SYCL_REG_N macro).
91 #ifndef EIGEN_SYCL_REG_N
92   static EIGEN_CONSTEXPR StorageIndex WorkLoadPerThreadN = REG_SIZE_N;
93 #else
94   static EIGEN_CONSTEXPR StorageIndex WorkLoadPerThreadN = EIGEN_SYCL_REG_N;
95 #endif
96   // LocalThreadSizeM: determines the total number of threads per workgroup for the m dimension
97   static EIGEN_CONSTEXPR StorageIndex LocalThreadSizeM = EIGEN_SYCL_LOCAL_THREAD_DIM0;
98   // LocalThreadSizeN: determines the total number of threads per workgroup for the n dimension
99   static EIGEN_CONSTEXPR StorageIndex LocalThreadSizeN = EIGEN_SYCL_LOCAL_THREAD_DIM1;
100   // TileSizeDimM: determines the tile size for the m dimension
101   static EIGEN_CONSTEXPR StorageIndex TileSizeDimM = LocalThreadSizeM * WorkLoadPerThreadM;
102   // TileSizeDimN: determines the tile size for the n dimension
103   static EIGEN_CONSTEXPR StorageIndex TileSizeDimN = LocalThreadSizeN * WorkLoadPerThreadN;
104   // LoadPerThreadLhs: determines the workload per thread for loading the Lhs Tensor. This must be divisible by the packet size
105   static EIGEN_CONSTEXPR StorageIndex LoadPerThreadLhs =
106       ((TileSizeDimK * WorkLoadPerThreadM * WorkLoadPerThreadN) / (TileSizeDimN));
107   // LoadPerThreadRhs: determines the workload per thread for loading the Rhs Tensor. This must be divisible by the packet size
108   static EIGEN_CONSTEXPR StorageIndex LoadPerThreadRhs =
109       ((TileSizeDimK * WorkLoadPerThreadM * WorkLoadPerThreadN) / (TileSizeDimM));
110   // BC : determines if supporting bank conflict is required
111   static EIGEN_CONSTEXPR bool BC = true;
112   // DoubleBuffer: determines if the double buffering technique should be used (this can be disabled by the
113   // EIGEN_SYCL_DISABLE_DOUBLE_BUFFER macro when the device does not have sufficient local memory)
114   static EIGEN_CONSTEXPR bool DoubleBuffer =
115 #ifdef EIGEN_SYCL_DISABLE_DOUBLE_BUFFER
116       false;
117 #else
118       true;
119 #endif
120 };
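// Illustrative only (not part of the original code): a hypothetical instantiation showing how the tile
// constants compose, assuming EIGEN_SYCL_LOCAL_THREAD_DIM0 == EIGEN_SYCL_LOCAL_THREAD_DIM1 == 16 and no
// EIGEN_SYCL_REG_M / EIGEN_SYCL_REG_N overrides:
//   using Panel = TTPanelSize<float, int, /*REG_SIZE_M=*/4, /*REG_SIZE_N=*/4, /*TSDK=*/16>;
//   // Panel::TileSizeDimK     == 16
//   // Panel::TileSizeDimM     == 16 * 4            == 64
//   // Panel::TileSizeDimN     == 16 * 4            == 64
//   // Panel::LoadPerThreadLhs == (16 * 4 * 4) / 64 == 4
//   // Panel::LoadPerThreadRhs == (16 * 4 * 4) / 64 == 4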
121 
122 /*!
123  * \brief contraction_type: an enum class representing the Tensor Contraction implementation algorithm. This is used to
124  * specialize the contraction algorithm based on device support for dedicated local memory.
125  */
126 enum class contraction_type { local, no_local };
127 /*!
128  * \brief data_source an enum class determining the location of the data in a memory hierarchy (global, local, private).
129  */
130 enum class data_source { global_mem, local_mem, private_mem };
131 
132 /*!
133  * \brief read, a template function used for loading the data from global
134  memory. This function is used to guarantee coalesced and vectorized load whenever possible
135  *
136  * \tparam PacketLoad: determines whether each element of this tensor block should be loaded in packet mode
137  *
138  * \param is_coalesced_layout: determines whether or not the Tensor data in memory can be accessed in a coalesced and
139  vectorized way when possible. Coalesced memory access is a key factor in kernel performance. When a tensor is 2d and the
140  contracting dimension is 1, it is always possible to access the tensor data in a coalesced and vectorized way. This is the case
141  when the RHS (right hand side) Tensor is transposed or when the LHS (left hand side) Tensor is not transposed.
142  *
143  * \tparam PacketType:  determines the type of packet
144  *
145  * \tparam TensorMapper: determines the input tensor mapper type
146  *
147  * \tparam StorageIndex: determines the Index type
148 
149  * \param tensorMapper: is the input tensor
150  *
151  * \param NCIndex: is the non-contracting dim index
152  *
153  * \param CIndex is the contracting dim index
154  *
155  * \param ld: is the leading dimension of the flattened tensor
156  */
157 template <bool PacketLoad, bool is_coalesced_layout, bool, typename PacketType, typename TensorMapper,
158           typename StorageIndex>
159 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename ::Eigen::internal::enable_if<PacketLoad, PacketType>::type read(
160     const TensorMapper &tensorMapper, const StorageIndex &NCIndex, const StorageIndex &CIndex, const StorageIndex &ld) {
161   const StorageIndex row = (is_coalesced_layout) ? NCIndex : CIndex;
162   const StorageIndex col = (is_coalesced_layout) ? CIndex : NCIndex;
163   return tensorMapper.get_tensor().template packet<Unaligned>(row + (col * ld));
164 }
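// Clarifying note (added): the packet load above addresses the flattened tensor linearly as row + col * ld,
// so with is_coalesced_layout == true the non-contracting index walks the contiguous dimension and
// PacketSize consecutive elements can be fetched in a single packet.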
165 
166 /*!
167  * \brief read, special overload of read function, when the read access is not vectorized
168  *
169  * \tparam PacketLoad: determines whether each element of this tensor block should be loaded in packet mode
170  *
171  * \param is_coalesced_layout: determines whether or not the Tensor data in memory can be accessed in a coalesced and
172   vectorized way when possible. Coalesced memory access is a key factor in kernel performance. When a tensor is 2d and the
173   contracting dimension is 1, it is always possible to access the tensor data in a coalesced and vectorized way. This is the case
174   when the RHS (right hand side) Tensor is transposed or when the LHS (left hand side) Tensor is not transposed.
175  *
176  * \tparam PacketType: determines the type of packet
177  *
178  * \tparam TensorMapper: determines the input tensor mapper type
179  *
180  * \tparam StorageIndex: determines the Index type
181 
182  * \param tensorMapper: is the input tensor
183  *
184  * \param NCIndex: is the non-contracting dim index
185  *
186  * \param CIndex: is the contracting dim index
187  */
188 template <bool PacketLoad, bool, bool IsRhs, typename PacketType, typename TensorMapper, typename StorageIndex>
189 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename ::Eigen::internal::enable_if<!PacketLoad, PacketType>::type read(
190     const TensorMapper &tensorMapper, const StorageIndex &NCIndex, const StorageIndex &CIndex, const StorageIndex &) {
191   const StorageIndex row = (IsRhs) ? CIndex : NCIndex;
192   const StorageIndex col = (IsRhs) ? NCIndex : CIndex;
193   return tensorMapper(row, col);
194 }
195 
196 /*!
197  * \brief write, a template function used for storing the data to local memory. This function is used to guarantee
198  * coalesced and vectorized store whenever possible.
199  *
200  * \tparam StorageIndex: determines the Index type
201  *
202  * \param ld is the leading dimension of the local memory. ld is a compile-time value for the local memory
203  *
204  * \tparam data_source: an enum value representing the location of the data in the memory hierarchy.
205  *
206  * \tparam PacketType:  determines the type of packet
207  *
208  * \tparam DataScalar: determines the output data type
209  *
210  * \param packet_data: the data to be written in the local memory
211  *
212  * \param ptr: a pointer to the local memory
213  *
214  * \param CIndex is the contracting dim index
215  */
216 
217 template <typename StorageIndex, StorageIndex ld, data_source dt, typename PacketType, typename DataScalar>
218 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
219     typename ::Eigen::internal::enable_if<dt != data_source::global_mem, void>::type
220     write(PacketType &packet_data, DataScalar ptr) {
221   EIGEN_CONSTEXPR int PacketSize = Eigen::internal::unpacket_traits<PacketType>::size;
222   EIGEN_UNROLL_LOOP
223   for (int i = 0; i < PacketSize; i++) {
224     *ptr = PacketWrapper<PacketType, PacketSize>::scalarize(i, packet_data);
225     ptr += ld;
226   }
227 }
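// Clarifying example (added): for a 4-wide packet and a compile-time leading dimension ld, the loop above
// scatters the packet lanes to ptr[0], ptr[ld], ptr[2 * ld] and ptr[3 * ld], i.e. one scalar per
// local-memory column so the tile is laid out with stride ld.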
228 
229 /*!
230  * \brief Overloading the write function for storing the data to global memory, when vectorization is enabled. This
231  * function is used to guarantee a coalesced and vectorized store whenever possible.
232  *
233  * \tparam data_source: an enum value representing the location of the data in the memory hierarchy.
234  *
235  * \tparam PacketType:  determines the type of packet
236  *
237  * \tparam DataScalar: determines the output data type
238  *
239  * \param packet_data: the data to be written to global memory
240  *
241  * \param ptr: a pointer to global memory
242  */
243 
244 template <data_source dt, typename PacketType, typename DataScalar>
245 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename ::Eigen::internal::enable_if<
246     Eigen::internal::unpacket_traits<PacketType>::size != 1 && dt == data_source::global_mem, void>::type
247 write(PacketType &packet_data, DataScalar *ptr) {
248   ::Eigen::internal::pstoreu<DataScalar, PacketType>(ptr, packet_data);
249 }
250 
251 /*!
252  * \brief Overloading the write function for storing the data to global memory, when vectorization is disabled.
253  *
254  * \tparam data_source: an enum value representing the location of the data in the memory hierarchy.
255  *
256  * \tparam PacketType:  determines the type of packet
257  *
258  * \tparam DataScalar: determines the output data type
259  *
260  * \param packet_data: the data to be written to global memory
261  *
262  * \param ptr: a pointer to global memory
263  */
264 template <data_source dt, typename PacketType, typename DataScalar>
265 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename ::Eigen::internal::enable_if<
266     Eigen::internal::unpacket_traits<PacketType>::size == 1 && dt == data_source::global_mem, void>::type
267 write(PacketType &packet_data, DataScalar *ptr) {
268   *ptr = packet_data;
269 }
270 
271 /*!
272  * \brief check_boundary: is used to check the edge condition for non-internal blocks.
273  *
274  * \tparam is_internal: determines if the block is internal
275  */
276 template <bool is_internal>
277 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool check_boundary(bool) {
278   return true;
279 }
280 
281 /*!
282  * \brief check_boundary: specialization of the check_boundary for non-internal blocks.
283  *
284  * \param cond: true when the data is in range. Otherwise false
285  */
286 template <>
287 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool check_boundary<false>(bool cond) {
288   return cond;
289 }
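// Usage note (added): callers invoke check_boundary<is_internal_block>(cond). For internal blocks the
// primary template returns a compile-time true, so the compiler can drop the boundary branch entirely;
// only edge blocks pay for the runtime check through the <false> specialization above.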
290 
291 /*!
292  * \brief BlockProperties is a template class that provides the different characteristics of a block of each Tensor
293  * processed by each workgroup.
294  *
295  * \tparam is_transposed: true if and only if the block of the Tensor is transposed
296  *
297  * \tparam packet_load_: determines whether each element of this tensor block should be loaded in packet mode
298  *
299  * \tparam PacketType:  determines the type of packet
300  *
301  * \tparam OutType: determines the type of each element for this block of the tensor. If packet_load is true, it will
302  * be PacketType; otherwise it will be the scalar type
303  *
304  * \param elements_per_access determines the number of elements per access, based on OutType
305  *
306  * \param is_coalesced_layout  determines whether or not the Tensor data in memory can be accessed in a coalesced and
307  * vectorized way when possible. Coalesced memory access is a key factor in kernel performance. When a tensor is 2d and the
308  * contracting dimension is 1, it is always possible to access the tensor data in a coalesced and vectorized way. This is the case
309  * when the RHS (right hand side) Tensor is transposed or when the LHS (left hand side) Tensor is not transposed.
310  *
311  * \param nc_stride determines the stride of the non-contracting dimension to access the next adjacent element within the
312  * Tensor Block for each workgroup
313  *
314  * \param c_stride  determines the stride of the contracting dimension to access the next adjacent element within the
315  * Tensor Block for each workgroup
316  */
317 template <bool is_transposed, bool is_rhs_, bool packet_load_, typename PacketType>
318 struct BlockProperties {
319   static EIGEN_CONSTEXPR bool packet_load = packet_load_;
320   typedef typename Eigen::internal::unpacket_traits<PacketType>::type OutScalar;
321   static EIGEN_CONSTEXPR bool is_rhs = is_rhs_;
322   typedef typename Eigen::internal::conditional<packet_load, PacketType, OutScalar>::type OutType;
323   static EIGEN_CONSTEXPR int elements_per_access = Eigen::internal::unpacket_traits<OutType>::size;
324   static EIGEN_CONSTEXPR bool is_coalesced_layout = !(is_transposed ^ is_rhs);
325   static EIGEN_CONSTEXPR int nc_stride = (is_coalesced_layout ? elements_per_access : 1);
326   static EIGEN_CONSTEXPR int c_stride = (is_coalesced_layout ? 1 : elements_per_access);
327 };
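// Illustrative only (not part of the original code): for a non-transposed LHS block loaded with a
// hypothetical 4-wide packet,
//   using LhsBlock = BlockProperties</*is_transposed=*/false, /*is_rhs_=*/false,
//                                    /*packet_load_=*/true, PacketType>;
//   // LhsBlock::is_coalesced_layout == !(false ^ false) == true
//   // LhsBlock::elements_per_access == 4
//   // LhsBlock::nc_stride           == 4   (the packet advances along the non-contracting dimension)
//   // LhsBlock::c_stride             == 1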
328 
329 /*!
330  * \brief ThreadProperties is a template class that provides each thread's properties within a workgroup. Please see
331  * the sycl-1.2.1 specification (https://www.khronos.org/registry/SYCL/specs/sycl-1.2.1.pdf) for the definition of
332  * work-groups and work-items.
333  *
334  * \tparam StorageIndex: determines the StorageIndex Type
335  *
336  * \param linearLocalThreadId: determines the linearized location of a thread within a work-group
337  *
338  * \param kGroupId: determines the logical group id in the k dimension of the flattened tensor. It will be > 1 when
339  * the tall/skinny algorithm is used
340  *
341  * \param mGroupOffset: determines the logical start position of all threads within a workgroup for the m dimension of
342  * the flattened tensor.
343  *
344  * \param kGroupOffset determines the logical start position of all threads within a workgroup for the k dimension of the
345  * flattened tensor. It will be > 1 when the tall/skinny algorithm is used.
346  *
347  * \param mLocalOffset: determines the logical start position of each thread within a workgroup for the m dimension of a
348  * flattened tensor. The position determines the distance of each thread within the workgroup from each other
349  * independent from their global position.
350  *
351  * \param nLocalOffset: determines the logical start position of each thread within a workgroup for the n dimension of a
352  * flattened tensor. The position determines the distance of each thread within the workgroup from each other
353  * independent from their global position.
354  *
355  * \param mGlobalOffset: determines the logical start position of each thread for the m dimension on a
356  * flattened tensor
357  *
358  * \param nGlobalOffset: determines the logical start position of each thread for the n dimension on a
359  * flattened tensor
360  *
361  * \param kSize : determines the number of k elements of the flattened Tensor to be processed by each thread for the
362  * given tensor block. This is not equal to the K dimension of the flattened Tensor when the tall/skinny algorithm is used.
363  *
364  * \param is_internal : determines whether the thread within the work-group computes an internal block of the tensor or
365  * an edge block. When the block is internal, there is no need to check the boundaries and all the if statements can be
366  * resolved by the compiler.
367  */
368 template <typename StorageIndex>
369 struct ThreadProperties {
370   const StorageIndex linearLocalThreadId;
371   const StorageIndex kGroupId;
372   const StorageIndex mGroupOffset;
373   const StorageIndex nGroupOffset;
374   const StorageIndex kGroupOffset;
375   const StorageIndex mLocalOffset;
376   const StorageIndex nLocalOffset;
377   const StorageIndex mGlobalOffset;
378   const StorageIndex nGlobalOffset;
379   StorageIndex kSize;
380   const bool is_internal;
381   // this is used to adjust the last block
382   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ThreadProperties(
383       const StorageIndex linearLocalThreadId_, const StorageIndex kGroupId_, const StorageIndex mGroupOffset_,
384       const StorageIndex nGroupOffset_, const StorageIndex kGroupOffset_, const StorageIndex mLocalOffset_,
385       const StorageIndex nLocalOffset_, const StorageIndex mGlobalOffset_, const StorageIndex nGlobalOffset_,
386       StorageIndex kSize_, const bool is_internal_)
387       : linearLocalThreadId(linearLocalThreadId_),
388         kGroupId(kGroupId_),
389         mGroupOffset(mGroupOffset_),
390         nGroupOffset(nGroupOffset_),
391         kGroupOffset(kGroupOffset_),
392         mLocalOffset(mLocalOffset_),
393         nLocalOffset(nLocalOffset_),
394         mGlobalOffset(mGlobalOffset_),
395         nGlobalOffset(nGlobalOffset_),
396         kSize(kSize_),
397         is_internal(is_internal_) {}
398 };
399 
400 /*!
401  * \brief TensorContractionKernel is a template class that provides the Tensor-Tensor contraction operation.
402  *
403  * \tparam OutScalar: determines the output scalar type
404  *
405  * \tparam LhsScalar: determines the left-hand-side scalar type
406  *
407  * \tparam RhsScalar: determines the right-hand-side scalar type
408  *
409  * \tparam OutAccessor: determines the sycl accessor type for the output (please see the sycl-1.2.1 specification
410  (https://www.khronos.org/registry/SYCL/specs/sycl-1.2.1.pdf) for accessor definition)
411  *
412  * \tparam LhsMapper determines the tensor contraction mapper type for left-hand-side matrix
413  *
414  * \tparam RhsMapper determines the tensor contraction mapper type for right-hand-side matrix
415  *
416  * \tparam StorageIndex: determines the StorageIndex Type
417  *
418  * \tparam Properties: determines the Contraction Panel properties
419  *
420  * \tparam TripleDim: determines the M, K, N dimensions for the flattened tensors in order to treat them as matrices
421  *
422  * \tparam Vectorizable: determines whether or not the vectorization is enabled for the Eigen expression.
423  *
424  * \tparam input_mapper_properties : determines whether the input tensors are matrices. If they are, special memory
425  access is used to guarantee that the memory accesses are always coalesced.
426  *
427  * \tparam IsFinal : determines if this is the final kernel. If so, the result will be written to the final output.
428  Otherwise, the result of the contraction will be written to a temporary buffer. This is the case when the tall/skinny
429  contraction is used; a final reduction step is then required to compute the final output.
430 
431  * \tparam contraction_tp: an enum value representing whether the local-memory or no-local-memory implementation of
432  the algorithm is to be used
433  *
434  * \param scratch: local memory containing tiles of LHS and RHS tensors for each work-group
435  *
436  * \param lhs: determines the left-hand-side flattened tensor (tensor mapper)
437  *
438  * \param rhs: determines the right-hand-side flattened tensor (tensor mapper)
439  *
440  * \param out_res: determines the output tensor containing the contraction result
441  *
442  * \param groupSizeM: a logical number determining the number of work-groups for the m dimension
443  *
444  * \param groupSizeN: a logical number determining the number of work-groups for the n dimension
445  *
446  * \param numTiles: determines the total number of tiles on the k dimension
447  *
448  * \param TripleDim: determines the M, K, N dimensions for the flattened tensors in order to treat them as matrices
449  */
450 template <typename OutScalar, typename LhsScalar, typename RhsScalar, typename OutAccessor, typename LhsMapper,
451           typename RhsMapper, typename StorageIndex, typename Properties, typename TripleDim, bool Vectorizable,
452           typename input_mapper_properties, bool IsFinal, contraction_type contraction_tp>
453 class TensorContractionKernel {
454  public:
455   typedef typename Eigen::TensorSycl::internal::Vectorise<OutScalar, Eigen::SyclDevice, Vectorizable>::PacketReturnType
456       PacketReturnType;
457   static EIGEN_CONSTEXPR int PacketSize =
458       Eigen::TensorSycl::internal::Vectorise<OutScalar, Eigen::SyclDevice, Vectorizable>::PacketSize;
459   static EIGEN_CONSTEXPR bool is_lhs_transposed =
460       !::Eigen::internal::TensorContractionInputMapperTrait<LhsMapper>::inner_dim_contiguous;
461   static EIGEN_CONSTEXPR bool is_rhs_transposed =
462       !::Eigen::internal::TensorContractionInputMapperTrait<RhsMapper>::inner_dim_contiguous;
463 
464   typedef BlockProperties<is_lhs_transposed, false, input_mapper_properties::is_lhs_matrix && Vectorizable,
465                           PacketReturnType>
466       LHSBlockProperties;
467 
468   typedef BlockProperties<is_rhs_transposed, true, input_mapper_properties::is_rhs_matrix && Vectorizable,
469                           PacketReturnType>
470       RHSBlockProperties;
471 
472   static EIGEN_CONSTEXPR StorageIndex NStride =
473       contraction_tp == contraction_type::local ? Properties::WorkLoadPerThreadN : RHSBlockProperties::nc_stride;
474 
475   typedef cl::sycl::accessor<OutScalar, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local> Scratch;
476   typedef cl::sycl::multi_ptr<OutScalar, cl::sycl::access::address_space::local_space> local_ptr;
477   typedef OutScalar * /*cl::sycl::multi_ptr<OutScalar, cl::sycl::access::address_space::private_space>*/ private_ptr;
478   typedef
479       typename ::Eigen::internal::conditional<contraction_tp == contraction_type::local, local_ptr, private_ptr>::type
480           tile_ptr;
481   static EIGEN_CONSTEXPR StorageIndex LSDL = contraction_tp == contraction_type::local
482                                                  ? Properties::TileSizeDimM + Properties::BC
483                                                  : Properties::WorkLoadPerThreadM;
484   static EIGEN_CONSTEXPR StorageIndex LSDR = contraction_tp == contraction_type::local
485                                                  ? Properties::TileSizeDimN + Properties::BC
486                                                  : Properties::WorkLoadPerThreadN;
487   static EIGEN_CONSTEXPR StorageIndex LocalOffset = Properties::LocalThreadSizeM * Properties::LocalThreadSizeN;
488 
489   /**
490    * \brief MemHolder is a placeholder struct for creating the memory hierarchy in SYCL. Inside a SYCL kernel it is not
491    * allowed to have dynamic memory allocation. While the local memory is created outside of the kernel and passed to
492    * the kernel as an accessor, the private memory can only be allocated statically. Since we are abstracting
493    * the TiledMemory for both local and private memory, the MemHolder struct is used as a helper to abstract out the
494    * different types of memory needed when the local/no_local memory computation is called.
495    *
496    * \tparam contraction_type: an enum value representing whether the local-memory or no-local-memory implementation
497    of the algorithm is to be used
498    * \tparam the private memory size
499    * \param ptr the tile memory pointer type
500    */
501   template <contraction_type, StorageIndex>
502   struct MemHolder {
503     tile_ptr ptr;
504     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE MemHolder(local_ptr block_start_ptr) : ptr(block_start_ptr) {}
505   };
506   /**
507    * \brief specialization of the MemHolder class when the no-local-memory kernel is used.
508    */
509   template <StorageIndex MemSize>
510   struct MemHolder<contraction_type::no_local, MemSize> {
511     OutScalar ptr[MemSize] = {OutScalar{0}};
512   };
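  // Clarifying note (added): with contraction_type::local, MemHolder simply wraps a pointer into the shared
  // scratch accessor, whereas the no_local specialisation above owns a statically sized private array of
  // MemSize OutScalars (e.g. WorkLoadPerThreadM * TileSizeDimK elements for the LHS tile in TiledMemory).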
513   /**
514    * \brief TiledMemory: contains required memory pointer for loading  each tile of the TensorContraction panel from
515    * global memory to local/private memory when local/no_local algorithm used.
516    *
517    * \param lhs_scratch_extract : determines the LHS tile memory. It is either private or local memory based on the
518    * selected contraction_type.
519    *
520    * \param rhs_scratch_extract : determines the RHS tile memory. It is either private or local memory based on the
521    * selected contraction_type.
522    *
523    * \param lhs_extract_index: determines the position of each thread in local memory for the lhs input. When private
524    * memory is used this is set to zero, as it is not applicable in the case of private memory.
525    *
526    * \param rhs_extract_index: determines the position of each thread in local memory for the rhs input. When private
527    * memory is used this is set to zero, as it is not applicable in the case of private memory.
528    *
529    * \param lhs_scratch_compute : determines the location to load from for computation in the lhs local memory. This is
530    * the same as lhs_scratch_extract for private memory.
531    *
532    * \param rhs_scratch_compute : determines the location to load from for computation in the rhs local memory. This is
533    * the same as rhs_scratch_extract for private memory.
534    */
535   struct TiledMemory {
536     MemHolder<contraction_tp, Properties::WorkLoadPerThreadM * Properties::TileSizeDimK> lhs_scratch_extract;
537     MemHolder<contraction_tp, Properties::WorkLoadPerThreadN * Properties::TileSizeDimK> rhs_scratch_extract;
538     tile_ptr lhs_scratch_ptr_compute;
539     tile_ptr rhs_scratch_ptr_compute;
540     const std::pair<StorageIndex, StorageIndex> lhs_extract_index;
541     const std::pair<StorageIndex, StorageIndex> rhs_extract_index;
542     template <contraction_type tp = contraction_tp>
543     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
544     TiledMemory(const ThreadProperties<StorageIndex> &, local_ptr,
545                 typename ::Eigen::internal::enable_if<tp == contraction_type::no_local>::type * = 0)
546         : lhs_scratch_extract{},
547           rhs_scratch_extract{},
548           lhs_scratch_ptr_compute(lhs_scratch_extract.ptr),
549           rhs_scratch_ptr_compute(rhs_scratch_extract.ptr),
550           lhs_extract_index(std::pair<StorageIndex, StorageIndex>(StorageIndex{0}, StorageIndex{0})),
551           rhs_extract_index(std::pair<StorageIndex, StorageIndex>(StorageIndex{0}, StorageIndex{0})) {}
552 
553     template <contraction_type tp = contraction_tp>
554     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
555     TiledMemory(const ThreadProperties<StorageIndex> &thread_properties, local_ptr block_start_ptr,
556                 typename ::Eigen::internal::enable_if<tp == contraction_type::local>::type * = 0)
557         : lhs_scratch_extract{block_start_ptr},
558           rhs_scratch_extract{lhs_scratch_extract.ptr +
559                               ((Properties::DoubleBuffer + 1) * LSDL * Properties::TileSizeDimK)},
560           lhs_scratch_ptr_compute(lhs_scratch_extract.ptr + thread_properties.mLocalOffset),
561           rhs_scratch_ptr_compute(rhs_scratch_extract.ptr + thread_properties.nLocalOffset),
562           lhs_extract_index(
563               local_id_extract<LHSBlockProperties, Properties::TileSizeDimM>(thread_properties.linearLocalThreadId)),
564           rhs_extract_index(
565               local_id_extract<RHSBlockProperties, Properties::TileSizeDimN>(thread_properties.linearLocalThreadId)) {}
566   };
567 
568   Scratch scratch;
569   const LhsMapper lhs;
570   const RhsMapper rhs;
571   OutAccessor out_res;
572   const StorageIndex groupSizeM;
573   const StorageIndex groupSizeN;
574   const StorageIndex numTiles;
575   const TripleDim triple_dim;
576 
577   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionKernel(Scratch scratch_, const LhsMapper lhs_,
578                                                                 const RhsMapper rhs_, OutAccessor out_res_,
579                                                                 const StorageIndex groupSizeM_,
580                                                                 const StorageIndex groupSizeN_,
581                                                                 const StorageIndex numTiles_,
582                                                                 const TripleDim triple_dim_)
583       : scratch(scratch_),
584         lhs(lhs_),
585         rhs(rhs_),
586         out_res(out_res_),
587         groupSizeM(groupSizeM_),
588         groupSizeN(groupSizeN_),
589         numTiles(numTiles_),
590         triple_dim(triple_dim_) {}
591 
592   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionKernel(Scratch scratch_, const LhsMapper lhs_,
593                                                                 const RhsMapper rhs_, OutAccessor out_res_,
594                                                                 const StorageIndex groupSizeM_,
595                                                                 const StorageIndex numTiles_,
596                                                                 const TripleDim triple_dim_)
597       : TensorContractionKernel(scratch_, lhs_, rhs_, out_res_, groupSizeM_, 1, numTiles_, triple_dim_) {}
598 
599   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(cl::sycl::nd_item<1> itemID) {
600     const StorageIndex linearLocalThreadId = itemID.get_local_id(0);
601     const StorageIndex nLocalThreadId = linearLocalThreadId / Properties::LocalThreadSizeM;
602     const StorageIndex mLocalThreadId = linearLocalThreadId % Properties::LocalThreadSizeM;
603     const StorageIndex mGroupId = itemID.get_group(0) % groupSizeM;
604     const StorageIndex tmp = itemID.get_group(0) / groupSizeM;
605     const StorageIndex nGroupId = IsFinal ? tmp : tmp % groupSizeN;
606     const StorageIndex kGroupId = IsFinal ? 0 : tmp / groupSizeN;
607     const StorageIndex mGroupOffset = mGroupId * Properties::TileSizeDimM;
608     const StorageIndex nGroupOffset = nGroupId * Properties::TileSizeDimN;
609     const StorageIndex mLocalOffset = PacketSize * mLocalThreadId;
610     const StorageIndex nLocalOffset = NStride * nLocalThreadId;
611     const StorageIndex mGlobalOffset = mGroupOffset + mLocalOffset;
612     const StorageIndex nGlobalOffset = nGroupOffset + nLocalOffset;
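    // Clarifying note (added): itemID.get_group(0) is decomposed above as
    //   group = mGroupId + groupSizeM * (nGroupId + groupSizeN * kGroupId)
    // so work-groups first sweep the M tiles, then the N tiles, and (for tall/skinny contractions,
    // i.e. when IsFinal is false) finally the K partitions.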
613 
614     const StorageIndex kSizePerWG = IsFinal ? triple_dim.K : numTiles * Properties::TileSizeDimK;
615     StorageIndex kGroupOffset = kGroupId * kSizePerWG;
616     const bool is_internal = triple_dim.M - mGroupOffset >= Properties::TileSizeDimM &&
617                              triple_dim.N - nGroupOffset >= Properties::TileSizeDimN &&
618                              triple_dim.K - kGroupOffset >= kSizePerWG;
619     // this is used to adjust the last block
620     StorageIndex kSize = IsFinal ? triple_dim.K : std::min(kSizePerWG, triple_dim.K - kGroupOffset);
621     // This is used to find out the last K offset so that kGroupOffset - kSize can compute the k offset for loading the
622     // tile
623     kGroupOffset += kSize;
624 
625     auto thread_properties =
626         ThreadProperties<StorageIndex>(linearLocalThreadId, kGroupId, mGroupOffset, nGroupOffset, kGroupOffset,
627                                        mLocalOffset, nLocalOffset, mGlobalOffset, nGlobalOffset, kSize, is_internal);
628 
629     auto out_ptr = out_res.get_pointer() + (IsFinal ? 0 : thread_properties.kGroupId * triple_dim.M * triple_dim.N);
630 
631     (thread_properties.is_internal) ? compute_panel<true>(itemID, thread_properties, out_ptr)
632                                     : compute_panel<false>(itemID, thread_properties, out_ptr);
633   }
634   // The compute block computes the contraction operation for each thread's private block and stores the result in the
635   // privateRes memory of each computation. The compute block function is independent of the local and no_local concepts
636   // as it only computes the block in each thread's private memory space
637   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void compute_block_per_tile(OutScalar *lhs_block_ptr, OutScalar *rhs_block_ptr,
638                                                                     PacketReturnType *privateRes) {
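    // Clarifying note (added): privateRes is a register tile of WorkLoadPerThreadN *
    // (WorkLoadPerThreadM / PacketSize) packets; the loops below walk it in the order
    // idx = wLPTN * (WorkLoadPerThreadM / PacketSize) + wLPTM, which is the layout the store() function consumes.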
639     StorageIndex idx = 0;
640     EIGEN_CONSTEXPR StorageIndex lhs_stride =
641         contraction_tp == contraction_type::local ? (PacketSize * Properties::LocalThreadSizeM) : 1;
642     EIGEN_UNROLL_LOOP
643     for (StorageIndex wLPTN = 0; wLPTN < Properties::WorkLoadPerThreadN; wLPTN++) {
644       auto rhsPacket = PacketReturnType{*(rhs_block_ptr + wLPTN)};
645       StorageIndex lhs_index = 0;
646       EIGEN_UNROLL_LOOP
647       for (StorageIndex wLPTM = 0; wLPTM < Properties::WorkLoadPerThreadM / PacketSize; wLPTM++) {
648         PacketReturnType lhsPack{};
649         Eigen::TensorSycl::internal::PacketWrapper<PacketReturnType, PacketSize>::set_packet(lhsPack,
650                                                                                              lhs_block_ptr + lhs_index);
651         privateRes[idx] = ::Eigen::internal::pmadd(lhsPack, rhsPacket, privateRes[idx]);
652 
653         lhs_index += lhs_stride;
654         idx++;
655       }
656     }
657   }
658   // The store function writes the computed contraction result in the private memory of each thread to the global
659   // memory. The store function is independent of the local and no_local concepts so that it can be abstracted out in the
660   // base class.
661   template <bool is_internal_block, StorageIndex PrivateNStride, typename OutPtr>
662   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void store(OutPtr *out_ptr, PacketReturnType *privateRes,
663                                                    StorageIndex mGlobalOffset, StorageIndex nGlobalOffset) {
664     auto chk_bound = [&](const StorageIndex &mIndex, const StorageIndex &nIndex) EIGEN_DEVICE_FUNC {
665       return (mIndex + PacketSize - 1 < triple_dim.M && nGlobalOffset + nIndex < triple_dim.N);
666     };
667     // when local memory is not used, M and N are both accessed in a coalesced way. However, when local memory is
668     // available, the K*N tile is transposed in local memory to N*K; therefore, each block operates on a blockId *
669     // WorkLoadPerThreadN slice of N
670     EIGEN_CONSTEXPR StorageIndex GlobalNStride =
671         contraction_tp == contraction_type::local ? 1 : Properties::LocalThreadSizeN;
672     EIGEN_UNROLL_LOOP
673     for (StorageIndex wLPTN = 0; wLPTN < Properties::WorkLoadPerThreadN / PrivateNStride; wLPTN++) {
674       // output leading dimension
675       StorageIndex outputLD = 0;
676       // When local memory is used, PrivateNStride is always 1 because the coalesced access on N is loaded into local
677       // memory and extracting from local to global is the same as the non-transposed version. However, when local memory
678       // is not used and RHS is transposed, we packetize the load for RHS.
679       EIGEN_UNROLL_LOOP
680       for (StorageIndex nId = 0; nId < PrivateNStride; nId++) {
681         StorageIndex globalRow = mGlobalOffset;
682         EIGEN_UNROLL_LOOP
683         for (StorageIndex wLPTM = 0; wLPTM < Properties::WorkLoadPerThreadM / PacketSize; wLPTM++) {
684           PacketReturnType privetOut = privateRes[wLPTM];
685           if (check_boundary<is_internal_block>(chk_bound(globalRow, nId))) {
686             // Store the final results in C. The C matrix always has M as its first StorageIndex and N as its second
687             // StorageIndex; therefore it always has a coalesced layout
688             write<data_source::global_mem>(privetOut, out_ptr + outputLD + globalRow);
689           } else {
690             EIGEN_UNROLL_LOOP
691             for (StorageIndex mId = 0; mId < PacketSize; mId++) {
692               StorageIndex mOffset = globalRow + mId;
693               if (mOffset < triple_dim.M && (nGlobalOffset + nId < triple_dim.N)) {
694                 out_ptr[mOffset + outputLD] =
695                     Eigen::TensorSycl::internal::PacketWrapper<PacketReturnType, PacketSize>::scalarize(mId, privetOut);
696               }
697             }
698           }
699           globalRow += (PacketSize * Properties::LocalThreadSizeM);
700         }
701         outputLD += triple_dim.M;
702         privateRes += Properties::WorkLoadPerThreadM / PacketSize;
703       }
704       out_ptr += (GlobalNStride * outputLD);
705 
706       nGlobalOffset += (PrivateNStride * GlobalNStride);
707     }
708   }
709   // when no local memory is used the following extract_block will be enabled
710   template <typename InputBlockProperties, bool is_internal_block, typename Input, typename PrivateReg,
711             contraction_type contract_tp = contraction_tp>
712   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
713       typename ::Eigen::internal::enable_if<contract_tp == contraction_type::no_local>::type
714       extract_block(const Input &inpt, PrivateReg private_ptr, const std::pair<StorageIndex, StorageIndex> &,
715                     const StorageIndex &ncOffset, const StorageIndex cOffset) {
716     EIGEN_CONSTEXPR StorageIndex LocalThreadSizeNC =
717         InputBlockProperties::is_rhs ? Properties::LocalThreadSizeN : Properties::LocalThreadSizeM;
718     EIGEN_CONSTEXPR StorageIndex WorkLoadPerThreadNC =
719         InputBlockProperties::is_rhs ? Properties::WorkLoadPerThreadN : Properties::WorkLoadPerThreadM;
720     const StorageIndex &NC = InputBlockProperties::is_rhs ? triple_dim.N : triple_dim.M;
721 
722     auto chk_bound = [&](const StorageIndex &CIndex, const StorageIndex &NCIndex) EIGEN_DEVICE_FUNC {
723       return ((CIndex + InputBlockProperties::c_stride - 1 < triple_dim.K) &&
724               (NCIndex + InputBlockProperties::nc_stride - 1 < NC));
725     };
726     const StorageIndex ld = InputBlockProperties::is_coalesced_layout ? NC : triple_dim.K;
727     StorageIndex cIndex = cOffset;
728 
729     EIGEN_UNROLL_LOOP
730     for (StorageIndex cId = 0; cId < Properties::TileSizeDimK / InputBlockProperties::c_stride; cId++) {
731       StorageIndex ncIndex = ncOffset;
732       EIGEN_UNROLL_LOOP
733       for (StorageIndex ncId = 0; ncId < WorkLoadPerThreadNC / InputBlockProperties::nc_stride; ncId++) {
734         if (check_boundary<is_internal_block>(chk_bound(cIndex, ncIndex))) {
735           auto val =
736               read<InputBlockProperties::packet_load, InputBlockProperties::is_coalesced_layout,
737                    InputBlockProperties::is_rhs, typename InputBlockProperties::OutType>(inpt, ncIndex, cIndex, ld);
738 
739           write<StorageIndex, (InputBlockProperties::is_coalesced_layout ? 1 : WorkLoadPerThreadNC),
740                 data_source::private_mem>(val, private_ptr);
741         } else {
742           EIGEN_UNROLL_LOOP
743           for (StorageIndex i = 0; i < InputBlockProperties::elements_per_access; i++) {
744             const StorageIndex ncInd = ncIndex + (InputBlockProperties::is_coalesced_layout ? i : 0);
745             const StorageIndex cInd = cIndex + (InputBlockProperties::is_coalesced_layout ? 0 : i);
746             OutScalar val =
747                 (ncInd < NC && cInd < triple_dim.K)
748                     ? read<false, InputBlockProperties::is_coalesced_layout, InputBlockProperties::is_rhs, OutScalar>(
749                           inpt, ncInd, cInd, ld)
750                     : OutScalar(0);
751             write<StorageIndex, (InputBlockProperties::is_coalesced_layout ? 1 : WorkLoadPerThreadNC),
752                   data_source::private_mem>(
753                 val, private_ptr + (InputBlockProperties::is_coalesced_layout ? i : 0) +
754                          ((InputBlockProperties::is_coalesced_layout ? 0 : i) * WorkLoadPerThreadNC));
755           }
756         }
757 
758         // if it is lhs, we have to load it packetised when the packet size is > 1, because the output is coalesced. So
759         // even if M is not accessed in a coalesced mode, we have to load packet_size m values per thread.
760         ncIndex = (!InputBlockProperties::is_rhs && InputBlockProperties::nc_stride == 1 && PacketSize != 1)
761                       ? ncOffset + (ncId + 1) % PacketSize + ((ncId + 1) / PacketSize) * LocalThreadSizeNC
762                       : (ncIndex + InputBlockProperties::nc_stride * LocalThreadSizeNC);
763         private_ptr += InputBlockProperties::nc_stride;
764       }
765       // the previous for loop (private_ptr += nc_stride per iteration) has already moved the pointer by one WorkLoadPerThreadNC
766       private_ptr += (InputBlockProperties::c_stride - 1) * WorkLoadPerThreadNC;
767       cIndex += InputBlockProperties::c_stride;
768     }
769   }
770   template <typename InputBlockProperties, StorageIndex TileSizeDimNC>
771   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::pair<StorageIndex, StorageIndex> local_id_extract(
772       const StorageIndex &linearLocalThreadId) {
773     const StorageIndex localThreadNC =
774         (InputBlockProperties::is_coalesced_layout)
775             ? linearLocalThreadId % (TileSizeDimNC / InputBlockProperties::nc_stride)
776             : linearLocalThreadId / (Properties::TileSizeDimK / InputBlockProperties::c_stride);
777     const StorageIndex localThreadC =
778         (InputBlockProperties::is_coalesced_layout)
779             ? linearLocalThreadId / (TileSizeDimNC / InputBlockProperties::nc_stride)
780             : linearLocalThreadId % (Properties::TileSizeDimK / InputBlockProperties::c_stride);
781     return std::pair<StorageIndex, StorageIndex>(localThreadNC, localThreadC);
782   }
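  // Clarifying example (added): for a coalesced layout, local_id_extract numbers the threads NC-major within
  // the tile, i.e. linearLocalThreadId == localThreadC * (TileSizeDimNC / nc_stride) + localThreadNC; for the
  // non-coalesced layout the roles of the two divisors are swapped.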
783 
784   template <bool db = Properties::DoubleBuffer, contraction_type ctp = contraction_tp>
785   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
786       typename ::Eigen::internal::enable_if<db && ctp == contraction_type::local>::type
787       sync_mem(const cl::sycl::nd_item<1> &, bool &db_offset) noexcept {
788     db_offset = !db_offset;
789   }
790 
791   template <bool db = Properties::DoubleBuffer, contraction_type ctp = contraction_tp>
792   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
793       typename ::Eigen::internal::enable_if<!db && ctp == contraction_type::local>::type
794       sync_mem(const cl::sycl::nd_item<1> &itemID, bool &) noexcept {
795     itemID.barrier(cl::sycl::access::fence_space::local_space);
796   }
797 
798   template <contraction_type ctp = contraction_tp>
799   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
800       typename ::Eigen::internal::enable_if<ctp == contraction_type::no_local>::type
801       sync_mem(const cl::sycl::nd_item<1> &, bool &) noexcept {
802     return;
803   }
804 
805   template <bool need_sync, contraction_type ctp = contraction_tp>
806   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
807       typename ::Eigen::internal::enable_if<need_sync && ctp == contraction_type::no_local>::type
808       sync_thread(const cl::sycl::nd_item<1> &
809 #ifdef EIGEN_SYCL_ARM_GPU_CACHE_OPTIMISATION
810                       itemID
811 #endif
812                   ) noexcept {
813 #ifdef EIGEN_SYCL_ARM_GPU_CACHE_OPTIMISATION
814     itemID.barrier(cl::sycl::access::fence_space::local_space);
815 #else
816     return;
817 #endif
818   }
819   template <bool need_sync, contraction_type ctp = contraction_tp>
820   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
821       typename ::Eigen::internal::enable_if<need_sync && ctp == contraction_type::local>::type
822       sync_thread(const cl::sycl::nd_item<1> &itemID) {
823     itemID.barrier(cl::sycl::access::fence_space::local_space);
824   }
825   template <bool need_sync>
826   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename ::Eigen::internal::enable_if<!need_sync>::type sync_thread(
827       const cl::sycl::nd_item<1> &) {
828     return;
829   }
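  // Clarifying note (added): sync_mem and sync_thread encode the synchronisation policy: with double buffering
  // enabled, sync_mem only flips db_offset so the next tile is loaded into the other half of the scratch buffer;
  // without double buffering it issues a work-group barrier, and the no_local variants are no-ops (apart from
  // the optional ARM GPU cache-optimisation barrier above).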
830 
831   template <bool is_internal_block>
832   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void compute_tile_per_panel(const cl::sycl::nd_item<1> &itemID,
833                                                                     ThreadProperties<StorageIndex> &thread_properties,
834                                                                     TiledMemory &tiled_input_block,
835                                                                     PacketReturnType *privateRes, bool &db_offset) {
836     // Tiling the Rhs block from global to local memory
837     extract_block<RHSBlockProperties, is_internal_block>(
838         rhs, tiled_input_block.rhs_scratch_extract.ptr + (db_offset * Properties::TileSizeDimK * LSDR),
839         tiled_input_block.rhs_extract_index,
840         contraction_tp == contraction_type::local ? thread_properties.nGroupOffset : thread_properties.nGlobalOffset,
841         thread_properties.kGroupOffset - thread_properties.kSize);
842 
843     sync_thread<contraction_tp == contraction_type::no_local>(itemID);
844 
845     // Tiling the Lhs block from global to local memory
846     extract_block<LHSBlockProperties, is_internal_block>(
847         lhs, tiled_input_block.lhs_scratch_extract.ptr + (db_offset * LSDL * Properties::TileSizeDimK),
848         tiled_input_block.lhs_extract_index,
849         contraction_tp == contraction_type::local ? thread_properties.mGroupOffset : thread_properties.mGlobalOffset,
850         thread_properties.kGroupOffset - thread_properties.kSize);
851 
852     // itemID.barrier(cl::sycl::access::fence_space::local_space);
853     sync_thread<contraction_tp == contraction_type::local>(itemID);
854     // switch to compute mode
855     StorageIndex lhs_offset = (db_offset * LSDL * Properties::TileSizeDimK);
856     StorageIndex rhs_offset = (db_offset * Properties::TileSizeDimK * LSDR);
857     // Loop over the values of a single tile
858     for (StorageIndex k = 0; k < Properties::TileSizeDimK; k++) {
859       compute_block_per_tile(tiled_input_block.lhs_scratch_ptr_compute + lhs_offset,
860                              tiled_input_block.rhs_scratch_ptr_compute + rhs_offset, privateRes);
861       lhs_offset += LSDL;
862       rhs_offset += LSDR;
863     }
864     // computing the K index for the next tile
865     thread_properties.kSize -= Properties::TileSizeDimK;
866     sync_mem(itemID, db_offset);
867   }
868 
869   // compute_panel computes one work-group's panel of the contraction (used for both the local and no_local code paths)
870   template <bool is_internal_block, typename OutPtr>
871   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void compute_panel(const cl::sycl::nd_item<1> &itemID,
872                                                            ThreadProperties<StorageIndex> &thread_properties,
873                                                            OutPtr out_ptr) {
874     auto tiled_input_block = TiledMemory{thread_properties, scratch.get_pointer()};
875     // Allocate register space
876     PacketReturnType privateRes[Properties::WorkLoadPerThreadM * Properties::WorkLoadPerThreadN / PacketSize] = {
877         PacketReturnType{0}};
878     bool db_offset = 0;
879 
880     while (thread_properties.kSize >= Properties::TileSizeDimK) {
881       compute_tile_per_panel<is_internal_block>(itemID, thread_properties, tiled_input_block, privateRes, db_offset);
882     }
883     if (thread_properties.kSize > 0) {
884       compute_tile_per_panel<false>(itemID, thread_properties, tiled_input_block, privateRes, db_offset);
885     }
886 
887     // Storing the final results in the output
888     store<is_internal_block,
889           contraction_tp == contraction_type::local ? static_cast<StorageIndex>(1) : RHSBlockProperties::nc_stride>(
890         out_ptr + thread_properties.nGlobalOffset * triple_dim.M, privateRes, thread_properties.mGlobalOffset,
891         thread_properties.nGlobalOffset);
892   }
893   // When local memory is available the following extract_block will be enabled
894   template <typename InputBlockProperties, bool is_internal_block, typename Input, typename Local,
895             contraction_type contract_tp = contraction_tp>
896   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
897       typename ::Eigen::internal::enable_if<contract_tp == contraction_type::local>::type
898       extract_block(const Input &inpt, Local local_ptr, const std::pair<StorageIndex, StorageIndex>& local_index,
899                     const StorageIndex &ncOffset, const StorageIndex cOffset) {
900     EIGEN_CONSTEXPR StorageIndex TileSizeDimNC =
901         InputBlockProperties::is_rhs ? Properties::TileSizeDimN : Properties::TileSizeDimM;
902     EIGEN_CONSTEXPR StorageIndex LoadPerThread =
903         InputBlockProperties::is_rhs ? Properties::LoadPerThreadRhs : Properties::LoadPerThreadLhs;
904     EIGEN_CONSTEXPR StorageIndex LSD = InputBlockProperties::is_rhs ? LSDR : LSDL;
905     static_assert(((LocalOffset % (TileSizeDimNC / InputBlockProperties::nc_stride) == 0) &&
906                    (LocalOffset % (Properties::TileSizeDimK / InputBlockProperties::c_stride) == 0)),
907                   " LocalOffset must be divisible by stride");
908     const StorageIndex &NC = InputBlockProperties::is_rhs ? triple_dim.N : triple_dim.M;
909     StorageIndex localThreadNC = local_index.first;
910     StorageIndex localThreadC = local_index.second;
911     auto chk_bound = [&](const StorageIndex &CIndex, const StorageIndex &NCIndex) EIGEN_DEVICE_FUNC {
912       return ((CIndex + InputBlockProperties::c_stride - 1 < triple_dim.K) &&
913               (NCIndex + InputBlockProperties::nc_stride - 1 < NC));
914     };
915     EIGEN_UNROLL_LOOP
916     for (StorageIndex lPT = 0; lPT < LoadPerThread / InputBlockProperties::elements_per_access; lPT++) {
917       const StorageIndex CIndex = cOffset + (InputBlockProperties::c_stride * localThreadC);
918       const StorageIndex NCIndex = ncOffset + (InputBlockProperties::nc_stride * localThreadNC);
919       const StorageIndex ld = InputBlockProperties::is_coalesced_layout ? NC : triple_dim.K;
920       if (check_boundary<is_internal_block>(chk_bound(CIndex, NCIndex))) {
921         auto val =
922             read<InputBlockProperties::packet_load, InputBlockProperties::is_coalesced_layout,
923                  InputBlockProperties::is_rhs, typename InputBlockProperties::OutType>(inpt, NCIndex, CIndex, ld);
924         write<StorageIndex, (InputBlockProperties::is_coalesced_layout ? 1 : LSD), data_source::local_mem>(
925             val, local_ptr + (InputBlockProperties::nc_stride * localThreadNC) +
926                      (InputBlockProperties::c_stride * localThreadC * LSD));
927       } else {
928         EIGEN_UNROLL_LOOP
929         for (StorageIndex i = 0; i < InputBlockProperties::elements_per_access; i++) {
930           const StorageIndex nCInd = NCIndex + (InputBlockProperties::is_coalesced_layout ? i : 0);
931           const StorageIndex cInd = CIndex + (InputBlockProperties::is_coalesced_layout ? 0 : i);
932           OutScalar val =
933               (nCInd < NC && cInd < triple_dim.K)
934                   ? read<false, InputBlockProperties::is_coalesced_layout, InputBlockProperties::is_rhs, OutScalar>(
935                         inpt, nCInd, cInd, ld)
936                   : OutScalar(0);
937 
938           write<StorageIndex, (InputBlockProperties::is_coalesced_layout ? 1 : LSD), data_source::local_mem>(
939               val, local_ptr + (InputBlockProperties::nc_stride * localThreadNC) +
940                        (InputBlockProperties::is_coalesced_layout ? i : 0) +
941                        ((InputBlockProperties::c_stride * localThreadC +
942                          (InputBlockProperties::is_coalesced_layout ? 0 : i)) *
943                         LSD));
944         }
945       }
946       localThreadNC += (InputBlockProperties::is_coalesced_layout)
947                            ? LocalOffset % (TileSizeDimNC / InputBlockProperties::nc_stride)
948                            : LocalOffset / (Properties::TileSizeDimK / InputBlockProperties::c_stride);
949       localThreadC += (InputBlockProperties::is_coalesced_layout)
950                           ? LocalOffset / (TileSizeDimNC / InputBlockProperties::nc_stride)
951                           : LocalOffset % (Properties::TileSizeDimK / InputBlockProperties::c_stride);
952     }
953   }
954 };
955 
956 #ifndef EIGEN_SYCL_DISABLE_GEMV
957 
958 /*!
959  * \brief GeneralVectorTensor is a template class that provides the Tensor-vector contraction operation, which is a
960  * special case of the Tensor-Tensor contraction.
961  *
962  * \tparam OutScalar: determines the output scalar type
963  *
964  * \tparam OutAccessor: determines the sycl accessor type for the output (please see the sycl-1.2.1 specification
965  * (https://www.khronos.org/registry/SYCL/specs/sycl-1.2.1.pdf) for accessor definition)
966  *
967  * \tparam VectorMapper: determines the tensor contraction mapper for the vector input (can be lhs or rhs)
968  *
969  * \tparam TensorMapper: determines the tensor contraction mapper for the tensor input (can be lhs or rhs)
970  *
971  * \tparam StorageIndex: determines the StorageIndex Type
972  *
973  * \tparam Properties: determines the Contraction Panel properties
974  *
975  * \tparam KFactor: determines the number of elements in K dimension in a Tile
976  *
977  * \tparam Vectorizable: determines whether or not the vectorization is enabled for the Eigen expression.
978  *
979  * \tparam is_lhs_vec: determines whether lhs is a vector or rhs is a vector
980  *
981  * \tparam IsFinal: determines whether this is the final kernel. If so, the result will be written to the final output.
982  * Otherwise, the result of the contraction will be written to a temporary buffer.
983  *
984  * \param scratch: determines the local memory containing the vector block for each work-group
985  *
986  * \param vec: determines the vector input (tensor mapper)
987  *
988  * \param mat: determines the tensor input (tensor mapper)
989  *
990  * \param out_res: determines the output vector containing the contraction result
991  *
992  * \param nonContractGroupSize: a logical number determining the number of work-groups for the non-contracting dimension
993  *
994  * \param nonContractDim: determines the size of the non-contracting dimension for the flattened tensor
995  *
996  * \param contractDim: determines the size of the contracting dimension for the flattened tensor
997  *
998  */
999 template <typename OutScalar, typename OutAccessor, typename VectorMapper, typename TensorMapper, typename StorageIndex,
1000           typename Properties, StorageIndex KFactor, bool Vectorizable, bool is_lhs_vec, bool IsFinal>
1001 struct GeneralVectorTensor {
1002   typedef typename Eigen::TensorSycl::internal::Vectorise<OutScalar, Eigen::SyclDevice, Vectorizable>::PacketReturnType
1003       PacketReturnType;
1004   static EIGEN_CONSTEXPR int PacketSize =
1005       Eigen::TensorSycl::internal::Vectorise<OutScalar, Eigen::SyclDevice, Vectorizable>::PacketSize;
1006   typedef cl::sycl::accessor<OutScalar, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local> Scratch;
1007 
1008   static EIGEN_CONSTEXPR StorageIndex OutScratchOffset =
1009       KFactor * Properties::LocalThreadSizeC * Properties::LocalThreadSizeNC;
1010 
1011   // Since the access layout for a vector can always be coalesced, when LHS is the vector, we pass false and false to
1012   // make sure that the !^ (XNOR) is true. When RHS is the vector, we pass true and true to make sure that the !^ is true.
1013   typedef BlockProperties<is_lhs_vec ? false : true, is_lhs_vec ? false : true, Vectorizable, PacketReturnType>
1014       VecBlockProperties;
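  // Illustrative sketch (an assumption drawn from the comment above, not a definition from this file): if
  // BlockProperties derives its is_coalesced_layout flag as the XNOR (!^) of its first two boolean template
  // parameters, then passing equal flags always yields a coalesced layout for the vector operand:
  //   is_lhs_vec == true  -> BlockProperties<false, false, ...> : !(false ^ false) == true
  //   is_lhs_vec == false -> BlockProperties<true,  true,  ...> : !(true  ^ true)  == true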
1015 
1016   Scratch scratch;
1017   const VectorMapper vec;
1018   const TensorMapper mat;
1019   OutAccessor out_res;
1020   const StorageIndex nonContractGroupSize;
1021   const StorageIndex nonContractDim;
1022   const StorageIndex contractDim;
1023 
1024   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE GeneralVectorTensor(Scratch scratch_, const VectorMapper vec_,
1025                                                             const TensorMapper mat_, OutAccessor out_res_,
1026                                                             const StorageIndex nonContractGroupSize_,
1027                                                             const StorageIndex nonContractDim_,
1028                                                             const StorageIndex contractDim_)
1029       : scratch(scratch_),
1030         vec(vec_),
1031         mat(mat_),
1032         out_res(out_res_),
1033         nonContractGroupSize(nonContractGroupSize_),
1034         nonContractDim(nonContractDim_),
1035         contractDim(contractDim_) {}
1036 
1037   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(cl::sycl::nd_item<1> itemID) {
1038     auto scratch_ptr = scratch.get_pointer();
1039     const StorageIndex linearLocalThreadId = itemID.get_local_id(0);
1040     StorageIndex nonContractId = is_lhs_vec ? linearLocalThreadId / Properties::LocalThreadSizeC
1041                                             : linearLocalThreadId % Properties::LocalThreadSizeNC;
1042     StorageIndex contractId = is_lhs_vec ? linearLocalThreadId % Properties::LocalThreadSizeC
1043                                          : linearLocalThreadId / Properties::LocalThreadSizeNC;
1044     const StorageIndex cGroupSize = itemID.get_group_range(0) / nonContractGroupSize;
1045     const StorageIndex nonContractGroupId =
1046         is_lhs_vec ? itemID.get_group(0) / cGroupSize : itemID.get_group(0) % nonContractGroupSize;
1047     const StorageIndex contractGroupId =
1048         is_lhs_vec ? itemID.get_group(0) % cGroupSize : itemID.get_group(0) / nonContractGroupSize;
1049     auto out_ptr = out_res.get_pointer() + (IsFinal ? 0 : contractGroupId * nonContractDim);
1050 
1051     const StorageIndex nonContractGroupOffset = nonContractGroupId * Properties::TileSizeDimNC;
1052     const StorageIndex contractGroupOffset = contractGroupId * Properties::TileSizeDimC;
1053     auto outScratchIndex = nonContractId + contractId * Properties::LocalThreadSizeNC;
1054     const StorageIndex globalNonContractDimOffset = nonContractGroupOffset + nonContractId;
1055     const StorageIndex globalContractDimOffset = contractGroupOffset + contractId;
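    // Local-memory layout (inferred from OutScratchOffset and the extract_block call in compute_panel): when local
    // memory is enabled, the first KFactor * LocalThreadSizeC * LocalThreadSizeNC elements of scratch cache the
    // vector tile, and local_output points just past it to the per-thread partial results that are tree-reduced at
    // the end of compute_panel.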
1056     auto local_output = scratch_ptr + OutScratchOffset;
1057     const bool is_internal = nonContractDim - nonContractGroupOffset >= Properties::TileSizeDimNC &&
1058                              contractDim - contractGroupOffset >= Properties::TileSizeDimC;
1059     is_internal
1060         ? compute_panel<true>(itemID, vec, mat, local_output, out_ptr,
1061 #ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
1062                               scratch_ptr, contractGroupOffset,
1063 #endif
1064                               nonContractGroupOffset, linearLocalThreadId, contractDim, nonContractDim, contractId,
1065                               nonContractId, globalContractDimOffset, globalNonContractDimOffset, outScratchIndex)
1066         : compute_panel<false>(itemID, vec, mat, local_output, out_ptr,
1067 #ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
1068                                scratch_ptr, contractGroupOffset,
1069 #endif
1070                                nonContractGroupOffset, linearLocalThreadId, contractDim, nonContractDim, contractId,
1071                                nonContractId, globalContractDimOffset, globalNonContractDimOffset, outScratchIndex);
1072   }
1073   template <bool is_internal_block, typename OutPtr>
1074   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void compute_panel(
1075       const cl::sycl::nd_item<1> &itemID, const VectorMapper &vec, const TensorMapper &mat, OutScalar *local_output,
1076       OutPtr out_ptr,
1077 #ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
1078       OutScalar *scratch_ptr, const StorageIndex contractGroupOffset,
1079 #endif
1080       const StorageIndex nonContractGroupOffset, const StorageIndex linearLocalThreadId, StorageIndex contractDim,
1081       StorageIndex nonContractDim, StorageIndex contractId, StorageIndex nonContractId,
1082       StorageIndex globalContractDimOffset, StorageIndex globalNonContractDimOffset, StorageIndex outScratchIndex) {
1083     OutScalar outScalar[Properties::WorkLoadPerThreadNC] = {OutScalar(0)};
1084     // Reading the vector
1085 #ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
1086     const StorageIndex vectorOffset = contractGroupOffset + linearLocalThreadId;
1087     extract_block<VecBlockProperties, is_internal_block, KFactor,
1088                   Properties::LocalThreadSizeNC * Properties::LocalThreadSizeC>(vec, scratch_ptr, linearLocalThreadId,
1089                                                                                 vectorOffset, contractDim);
1090 
1091     itemID.barrier(cl::sycl::access::fence_space::local_space);
1092     auto in_scratch_ptr = scratch_ptr + contractId;
1093 #endif
1094 
1095     StorageIndex privateOffsetC = 0;
1096     EIGEN_UNROLL_LOOP
1097     for (StorageIndex i = 0; i < Properties::WorkLoadPerThreadC; i++) {
1098       StorageIndex privateOffsetNC = 0;
1099       bool contract_conds = ((globalContractDimOffset + privateOffsetC) < contractDim);
1100 #ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
1101       auto vecScalar = *in_scratch_ptr;
1102 #else
1103       auto vecScalar = (check_boundary<is_internal_block>(contract_conds))
1104                            ? vec(is_lhs_vec ? StorageIndex(0) : globalContractDimOffset + privateOffsetC,
1105                                  is_lhs_vec ? globalContractDimOffset + privateOffsetC : StorageIndex(0))
1106                            : OutScalar(0);
1107 #endif
1108       EIGEN_UNROLL_LOOP
1109       for (StorageIndex j = 0; j < Properties::WorkLoadPerThreadNC; j++) {
1110         auto matScalar = (check_boundary<is_internal_block>(
1111                              contract_conds && ((globalNonContractDimOffset + privateOffsetNC) < nonContractDim)))
1112                              ? mat(is_lhs_vec ? globalContractDimOffset + privateOffsetC
1113                                               : globalNonContractDimOffset + privateOffsetNC,
1114                                    is_lhs_vec ? globalNonContractDimOffset + privateOffsetNC
1115                                               : globalContractDimOffset + privateOffsetC)
1116                              : OutScalar(0);
1117 
1118         outScalar[j] = cl::sycl::mad(matScalar, vecScalar, outScalar[j]);
1119         privateOffsetNC += Properties::LocalThreadSizeNC;
1120       }
1121       privateOffsetC += Properties::LocalThreadSizeC;
1122 #ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
1123       in_scratch_ptr += Properties::LocalThreadSizeC;
1124 #endif
1125     }
1126 
1127     auto out_scratch_ptr = local_output + outScratchIndex;
1128     // Each block of 16*16 elements in shared memory should reduce to 16*1
1129     EIGEN_UNROLL_LOOP
1130     for (StorageIndex j = 0; j < Properties::WorkLoadPerThreadNC; j++) {
1131       *out_scratch_ptr = outScalar[j];
1132 
1133       out_scratch_ptr += (Properties::LocalThreadSizeNC * Properties::LocalThreadSizeC);
1134     }
1135     if (is_lhs_vec) {
1136       nonContractId = linearLocalThreadId % Properties::LocalThreadSizeNC;
1137       contractId = linearLocalThreadId / Properties::LocalThreadSizeNC;
1138       outScratchIndex = nonContractId + contractId * Properties::LocalThreadSizeNC;
1139     }
1140 
1141     out_scratch_ptr = local_output + outScratchIndex;
1142     EIGEN_UNROLL_LOOP
1143     for (StorageIndex j = 0; j < Properties::WorkLoadPerThreadNC; j++) {
1144       EIGEN_UNROLL_LOOP
1145       for (StorageIndex offset = Properties::LocalThreadSizeC >> 1; offset > 0; offset >>= 1) {
1146         itemID.barrier(cl::sycl::access::fence_space::local_space);
1147         if (contractId < offset) {
1148           StorageIndex myNeighbourId = (Properties::LocalThreadSizeNC * offset);
1149           *out_scratch_ptr += out_scratch_ptr[myNeighbourId];
1150         }
1151       }
1152       // moving to the next 16 by 16 block
1153       out_scratch_ptr += (Properties::LocalThreadSizeNC * Properties::LocalThreadSizeC);
1154     }
1155 
1156     if (contractId == 0) {
1157       out_scratch_ptr = local_output + nonContractId;
1158       StorageIndex global_final_offset = nonContractGroupOffset + nonContractId;
1159       out_ptr += global_final_offset;
1160       EIGEN_UNROLL_LOOP
1161       for (StorageIndex j = 0; j < Properties::WorkLoadPerThreadNC; j++) {
1162         if (check_boundary<is_internal_block>(global_final_offset < nonContractDim)) {
1163           auto res = *out_scratch_ptr;
1164 
1165           *out_ptr = res;
1166           out_ptr += Properties::LocalThreadSizeNC;
1167         }
1168         // moving to the next 16 by 16 block to get the next 16 reduced elements
1169         out_scratch_ptr += (Properties::LocalThreadSizeNC * Properties::LocalThreadSizeC);
1170         if (!(is_internal_block)) global_final_offset += Properties::LocalThreadSizeNC;
1171       }
1172     }
1173   }
1174 
1175   template <typename InputBlockProperties, bool is_internal_block, int CFactor, int GroupSize, typename Input,
1176             typename Local>
1177   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void extract_block(const Input &inpt, Local *local_ptr,
1178                                                                   const StorageIndex &linearLocalThreadId,
1179                                                                   const StorageIndex &cOffset, const StorageIndex &C) {
1180     local_ptr += InputBlockProperties::c_stride * linearLocalThreadId;
1181     StorageIndex cIndex = cOffset;
1182     for (StorageIndex cId = 0; cId < CFactor / InputBlockProperties::c_stride; cId++) {
1183       if (check_boundary<is_internal_block>(cIndex + InputBlockProperties::c_stride - 1 < C)) {
1184         auto val = read<InputBlockProperties::packet_load, InputBlockProperties::is_coalesced_layout,
1185                         InputBlockProperties::is_rhs, typename InputBlockProperties::OutType>(inpt, StorageIndex(0),
1186                                                                                               cIndex, StorageIndex(1));
1187         write<StorageIndex, 1, data_source::local_mem>(val, local_ptr);
1188       } else {
1189         EIGEN_UNROLL_LOOP
1190         for (StorageIndex i = 0; i < InputBlockProperties::elements_per_access; i++) {
1191           OutScalar val =
1192               (cIndex + i < C)
1193                   ? read<false, InputBlockProperties::is_coalesced_layout, InputBlockProperties::is_rhs, OutScalar>(
1194                         inpt, StorageIndex(0), cIndex + i, StorageIndex(1))
1195                   : OutScalar(0);
1196           write<StorageIndex, 1, data_source::local_mem>(val, local_ptr + i);
1197         }
1198       }
1199       local_ptr += InputBlockProperties::c_stride * GroupSize;
1200       cIndex += InputBlockProperties::c_stride * GroupSize;
1201     }
1202   }
1203 };
1204 #endif
1205 
1206 #ifndef EIGEN_SYCL_DISABLE_SCALAR
1207 
1208 /*!
1209  * \brief GeneralScalarContraction is a template class that provides the scalar value of a tensor-tensor contraction
1210  * operation when all the dimensions are contracting dimensions. This kernel reduces two tensors to a scalar.
1211  *
1212  * \tparam OutScalar: determines the output scalar type
1213  *
1214  * \tparam LhsScalar: determines the left-hand-side scalar type
1215  *
1216  * \tparam RhsScalar: determines the right-hand-side scalar type
1217  *
1218  * \tparam OutAccessor: determines the sycl accessor type for the output (please see the sycl-1.2.1 specification
1219  * (https://www.khronos.org/registry/SYCL/specs/sycl-1.2.1.pdf) for accessor definition)
1220  *
1221  * \tparam LhsMapper: determines the tensor contraction mapper type for left-hand-side matrix
1222  *
1223  * \tparam RhsMapper: determines the tensor contraction mapper type for right-hand-side matrix
1224  *
1225  * \tparam StorageIndex: determines the StorageIndex Type
1226  *
1227  * \tparam Vectorizable: determines whether or not the vectorization is enabled for the Eigen expression.
1228  *
1229  * \param scratch: local memory containing tiles of LHS and RHS tensors for each work-group
1230  *
1231  * \param lhs: determines the left-hand-side flattened tensor (tensor mapper)
1232  *
1233  * \param rhs: determines the right-hand-side flattened tensor (tensor mapper)
1234  *
1235  * \param out_res: determines the output tensor containing the contraction result
1236  *
1237  * \param rng: determines the total input data size
1238  */
1239 template <typename OutScalar, typename LhsScalar, typename RhsScalar, typename OutAccessor, typename LhsMapper,
1240           typename RhsMapper, typename StorageIndex, bool Vectorizable>
1241 struct GeneralScalarContraction {
1242   typedef cl::sycl::accessor<OutScalar, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local> Scratch;
1243   Scratch scratch;
1244   const LhsMapper lhs;
1245   const RhsMapper rhs;
1246   OutAccessor out_res;
1247   const StorageIndex rng;
1248 
1249   EIGEN_DEVICE_FUNC
1250   GeneralScalarContraction(Scratch scratch_, const LhsMapper lhs_, const RhsMapper rhs_, OutAccessor out_res_,
1251                            const StorageIndex rng_)
1252       : scratch(scratch_), lhs(lhs_), rhs(rhs_), out_res(out_res_), rng(rng_) {}
1253 
1254   EIGEN_DEVICE_FUNC void operator()(cl::sycl::nd_item<1> itemID) {
1255     auto out_ptr = out_res.get_pointer();
1256     auto scratch_ptr = scratch.get_pointer().get();
1257 
1258     StorageIndex globalid = itemID.get_global_id(0);
1259     StorageIndex localid = itemID.get_local_id(0);
1260     OutScalar accumulator = OutScalar(0);
1261     for (StorageIndex i = globalid; i < rng; i += itemID.get_global_range(0)) {
1262       accumulator = cl::sycl::mad(lhs(0, i), rhs(i, 0), accumulator);
1263     }
1264     auto out_scratch_ptr = scratch_ptr + localid;
1265     *out_scratch_ptr = accumulator;
1266     for (StorageIndex offset = itemID.get_local_range(0) >> 1; offset > 0; offset >>= 1) {
1267       itemID.barrier(cl::sycl::access::fence_space::local_space);
1268       if (localid < offset) {
1269         *out_scratch_ptr = (accumulator += out_scratch_ptr[offset]);
1270       }
1271     }
1272     if (localid == 0) {
1273       out_ptr[itemID.get_group(0)] = accumulator;
1274     }
1275   }
1276 };
1277 #endif
1278 
1279 }  // namespace internal
1280 }  // namespace TensorSycl
1281 
1282 template <typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType>
1283 struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>,
1284                        Eigen::SyclDevice>
1285     : public TensorContractionEvaluatorBase<TensorEvaluator<
1286           const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Eigen::SyclDevice>> {
1287   static_assert(std::is_same<OutputKernelType, const NoOpOutputKernel>::value,
1288                 "SYCL tensor contraction does not support output kernels.");
1289 
1290   typedef Eigen::SyclDevice Device;
1291 
1292   typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
1293   typedef TensorContractionEvaluatorBase<Self> Base;
1294   typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
1295   typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
1296   typedef typename XprType::Index StorageIndex;
1297   typedef typename XprType::CoeffReturnType CoeffReturnType;
1298   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
1299   typedef typename Base::Storage Storage;
1300   typedef typename Base::EvaluatorPointerType EvaluatorPointerType;
1301   struct TripleDim {
1302     const StorageIndex M;
1303     const StorageIndex N;
1304     const StorageIndex K;
1305     TripleDim(const StorageIndex M_, const StorageIndex N_, const StorageIndex K_) : M(M_), N(N_), K(K_) {}
1306   };
1307   enum {
1308     Layout = TensorEvaluator<LeftArgType, Device>::Layout,
1309     PacketAccess = (PacketType<CoeffReturnType, Device>::size > 1),
1310     BlockAccess = false,
1311   };
1312 
1313   static EIGEN_CONSTEXPR int LDims = Base::LDims;
1314   static EIGEN_CONSTEXPR int RDims = Base::RDims;
1315   static EIGEN_CONSTEXPR int ContractDims = Base::ContractDims;
1316 
1317   typedef array<StorageIndex, LDims> left_dim_mapper_t;
1318   typedef array<StorageIndex, RDims> right_dim_mapper_t;
1319 
1320   typedef array<StorageIndex, ContractDims> contract_t;
1321   typedef array<StorageIndex, LDims - ContractDims> left_nocontract_t;
1322   typedef array<StorageIndex, RDims - ContractDims> right_nocontract_t;
1323 
1324   static const int NumDims = LDims + RDims - 2 * ContractDims;
1325 
1326   typedef DSizes<StorageIndex, NumDims> Dimensions;
1327 
1328   typedef TensorEvaluator<typename Base::EvalLeftArgType, Device> LeftEvaluator;
1329   typedef TensorEvaluator<typename Base::EvalRightArgType, Device> RightEvaluator;
1330   typedef typename Eigen::internal::remove_const<typename LeftEvaluator::CoeffReturnType>::type LhsScalar;
1331   typedef typename Eigen::internal::remove_const<typename RightEvaluator::CoeffReturnType>::type RhsScalar;
1332 
1333   typedef typename LeftEvaluator::Dimensions LeftDimensions;
1334   typedef typename RightEvaluator::Dimensions RightDimensions;
1335 
1336   template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered>
1337   struct input_mapper_propertis {
1338     static EIGEN_CONSTEXPR bool is_lhs_matrix = (LDims == 2 && ContractDims == 1) || lhs_inner_dim_contiguous;
1339     static EIGEN_CONSTEXPR bool is_rhs_matrix =
1340         (RDims == 2 && ContractDims == 1) || (rhs_inner_dim_contiguous && !rhs_inner_dim_reordered);
1341   };
1342 
1343   TensorEvaluator(const XprType &op, const Device &device) : Base(op, device) {}
1344 
1345   // We need to redefine this method to make nvcc happy
1346   EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(typename Base::EvaluatorPointerType data) {
1347     this->m_leftImpl.evalSubExprsIfNeeded(NULL);
1348     this->m_rightImpl.evalSubExprsIfNeeded(NULL);
1349     if (!data) {
1350       this->m_result = this->m_device.get(
1351           static_cast<Scalar *>(this->m_device.allocate_temp(this->dimensions().TotalSize() * sizeof(Scalar))));
1352       data = this->m_result;
1353     }
1354     evalToSycl(data);
1355     return (this->m_result != NULL);
1356   }
1357   const Eigen::SyclDevice &device() const { return this->m_device; }
1358   void evalToSycl(typename Base::EvaluatorPointerType buffer) const {
1359     if (this->m_lhs_inner_dim_contiguous) {
1360       if (this->m_rhs_inner_dim_contiguous) {
1361         if (this->m_rhs_inner_dim_reordered) {
1362           evalTyped<true, true, true, Unaligned>(buffer);
1363         } else {
1364           evalTyped<true, true, false, Unaligned>(buffer);
1365         }
1366       } else {
1367         if (this->m_rhs_inner_dim_reordered) {
1368           evalTyped<true, false, true, Unaligned>(buffer);
1369         } else {
1370           evalTyped<true, false, false, Unaligned>(buffer);
1371         }
1372       }
1373     } else {
1374       if (this->m_rhs_inner_dim_contiguous) {
1375         if (this->m_rhs_inner_dim_reordered) {
1376           evalTyped<false, true, true, Unaligned>(buffer);
1377         } else {
1378           evalTyped<false, true, false, Unaligned>(buffer);
1379         }
1380       } else {
1381         if (this->m_rhs_inner_dim_reordered) {
1382           evalTyped<false, false, true, Unaligned>(buffer);
1383         } else {
1384           evalTyped<false, false, false, Unaligned>(buffer);
1385         }
1386       }
1387     }
1388   }
1389 
1390   template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
1391   void evalTyped(typename Base::EvaluatorPointerType buffer) const {
1392     const auto triple_dim = TripleDim{this->m_i_size, this->m_j_size, this->m_k_size};
1393     typedef internal::TensorContractionInputMapper<
1394         LhsScalar, StorageIndex, internal::Lhs, LeftEvaluator, left_nocontract_t, contract_t,
1395         PacketType<CoeffReturnType, Device>::size, lhs_inner_dim_contiguous, false, Unaligned, MakeSYCLPointer>
1396         LhsMapper;
1397 
1398     typedef internal::TensorContractionInputMapper<RhsScalar, StorageIndex, internal::Rhs, RightEvaluator,
1399                                                    right_nocontract_t, contract_t,
1400                                                    PacketType<CoeffReturnType, Device>::size, rhs_inner_dim_contiguous,
1401                                                    rhs_inner_dim_reordered, Unaligned, MakeSYCLPointer>
1402         RhsMapper;
1403 
1404     // initialize data mappers
1405     LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
1406                   this->m_left_contracting_strides, this->m_k_strides);
1407 
1408     RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
1409                   this->m_right_contracting_strides, this->m_k_strides);
1410 
1411 #ifndef EIGEN_SYCL_DISABLE_SCALAR
1412     if (triple_dim.M == 1 && triple_dim.N == 1) {
1413       launchSC(buffer, lhs, rhs, triple_dim.K);
1414     } else
1415 #endif
1416 #ifndef EIGEN_SYCL_DISABLE_GEMV
1417         if (triple_dim.M != 1 && triple_dim.N == 1) {
1418       LaunchVT<false>(buffer, rhs, lhs, triple_dim.M, triple_dim.K);
1419     } else if (triple_dim.M == 1 && triple_dim.N != 1) {
1420       LaunchVT<true>(buffer, lhs, rhs, triple_dim.N, triple_dim.K);
1421     } else  // This is equivalent to if (m != 1 && n != 1)
1422 #endif
1423     {
1424       typedef input_mapper_propertis<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered>
1425           inpt_mapper_properties;
1426 #ifndef EIGEN_SYCL_DISABLE_SKINNY
1427       bool skinny = false;
1428       auto platform_name = this->device().getPlatformName();
1429       // This is based on empirical tuning for the AMD R9 Nano and Fiji GPUs
1430       if (platform_name.find("AMD") == 0) {
1431         skinny = (triple_dim.M < triple_dim.K || triple_dim.N < triple_dim.K) &&
1432                  ((triple_dim.M < 1024 && triple_dim.N < 1024) ||
1433                   (uint64_t(triple_dim.M * triple_dim.N) < uint64_t(triple_dim.K)));
1434       } else {
1435         skinny = (((std::max(triple_dim.K, triple_dim.N) / std::min(triple_dim.K, triple_dim.N)) > 100) ||
1436                   ((std::max(triple_dim.K, triple_dim.M) / std::min(triple_dim.K, triple_dim.M)) > 100) ||
1437                   ((std::max(triple_dim.N, triple_dim.M) / std::min(triple_dim.N, triple_dim.M)) > 100));
1438       }
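      // A worked example of this heuristic with hypothetical sizes (not taken from any benchmark): on a non-AMD
      // platform with M = 4096, N = 32, K = 4096, max(K, N) / min(K, N) = 4096 / 32 = 128 > 100, so the contraction
      // is treated as skinny and dispatched through adjustTT<true, ...>; with M = N = K = 1024 all three ratios are
      // 1 and the regular adjustTT<false, ...> path below is taken.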
1439       if (skinny)
1440         adjustTT<true, inpt_mapper_properties>(buffer, lhs, rhs, triple_dim);
1441       else
1442 #endif  // EIGEN_SYCL_DISABLE_SKINNY
1443         adjustTT<false, inpt_mapper_properties>(buffer, lhs, rhs, triple_dim);
1444     }
1445   }
1446 
1447   template <bool skinny, typename input_mapper_properties, typename LhsMapper, typename RhsMapper>
1448   void EIGEN_ALWAYS_INLINE adjustTT(EvaluatorPointerType buffer, const LhsMapper &lhs, const RhsMapper &rhs,
1449                                     const TripleDim &triple_dim) const {
1450 #ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_ON
1451     if (device().has_local_memory()) {
1452       typedef TensorSycl::internal::TTPanelSize<CoeffReturnType, StorageIndex, 4, 4, 16> PanelParameters;
1453       launchTT<TensorSycl::internal::contraction_type::local, skinny, input_mapper_properties, PanelParameters>(
1454           buffer, lhs, rhs, triple_dim);
1455     }
1456 #endif
1457 #ifdef EIGEN_SYCL_LOCAL_MEM_UNSET_OR_OFF
1458     if (!(device().has_local_memory())) {
1459       typedef TensorSycl::internal::TTPanelSize<CoeffReturnType, StorageIndex, 4, 4, 4> PanelParameters;
1460       launchTT<TensorSycl::internal::contraction_type::no_local, skinny, input_mapper_properties, PanelParameters>(
1461           buffer, lhs, rhs, triple_dim);
1462     }
1463 #endif
1464   }
1465 
1466   template <TensorSycl::internal::contraction_type ct, bool skinny, typename input_mapper_properties,
1467             typename Properties, typename LhsMapper, typename RhsMapper>
1468   void launchTT(EvaluatorPointerType buffer, const LhsMapper &lhs, const RhsMapper &rhs,
1469                 const TripleDim &triple_dim) const {
1470     const StorageIndex roundUpM = Eigen::TensorSycl::internal::roundUp(triple_dim.M, Properties::TileSizeDimM);
1471     const StorageIndex roundUpN = Eigen::TensorSycl::internal::roundUp(triple_dim.N, Properties::TileSizeDimN);
1472     const StorageIndex groupSizeM = roundUpM / Properties::TileSizeDimM;
1473     const StorageIndex groupSizeN = roundUpN / Properties::TileSizeDimN;
1474 
1475     const StorageIndex roundUpK = Eigen::TensorSycl::internal::roundUp(triple_dim.K, Properties::TileSizeDimK);
1476     StorageIndex totalTilesK = roundUpK / Properties::TileSizeDimK;
1477     StorageIndex groupSizeK =
1478         skinny
1479             ? std::max(std::min(totalTilesK,
1480                                 (StorageIndex)(device().getPowerOfTwo(device().getNumSyclMultiProcessors(), true) * 4) /
1481                                     (groupSizeM * groupSizeN)),
1482                        StorageIndex(1))
1483             : StorageIndex(1);
1484 
1485     const StorageIndex numTilesPerGroup = Eigen::TensorSycl::internal::roundUp(totalTilesK, groupSizeK) / groupSizeK;
1486 
1487     const StorageIndex totalGroupSize = groupSizeM * groupSizeN * groupSizeK;
1488 
1489     const StorageIndex localRange = Properties::LocalThreadSizeM * Properties::LocalThreadSizeN;
1490     const StorageIndex globalRange = totalGroupSize * localRange;
1491 
1492     const StorageIndex scratchSize = (ct == TensorSycl::internal::contraction_type::local)
1493                                          ? ((Properties::DoubleBuffer + 1) *
1494                                             (Properties::TileSizeDimM + Properties::BC) * (Properties::TileSizeDimK)) +
1495                                                ((Properties::DoubleBuffer + 1) * (Properties::TileSizeDimK) *
1496                                                 (Properties::TileSizeDimN + Properties::BC))
1497                                          : StorageIndex(1);
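    // Rough scratch-size illustration with hypothetical panel parameters (not the values chosen by adjustTT): with
    // DoubleBuffer = 1, BC = 1, TileSizeDimM = TileSizeDimN = 64 and TileSizeDimK = 16, the local contraction would
    // reserve 2 * (65 * 16) + 2 * (16 * 65) = 4160 elements per work-group, while the no_local path only needs the
    // dummy size of 1.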
1498 
1499     auto thread_range = cl::sycl::nd_range<1>(cl::sycl::range<1>(globalRange), cl::sycl::range<1>(localRange));
1500     if (groupSizeK == 1) {
1501       typedef TensorSycl::internal::TensorContractionKernel<CoeffReturnType, LhsScalar, RhsScalar, EvaluatorPointerType,
1502                                                             LhsMapper, RhsMapper, StorageIndex, Properties, TripleDim,
1503                                                             PacketAccess, input_mapper_properties, true, ct>
1504           ContractKernelName;
1505       device().template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(
1506           lhs, rhs, buffer, thread_range, scratchSize, groupSizeM, groupSizeN, numTilesPerGroup, triple_dim);
1507     } else {
1508       typedef TensorSycl::internal::TensorContractionKernel<CoeffReturnType, LhsScalar, RhsScalar, EvaluatorPointerType,
1509                                                             LhsMapper, RhsMapper, StorageIndex, Properties, TripleDim,
1510                                                             PacketAccess, input_mapper_properties, false, ct>
1511           ContractKernelName;
1512       CoeffReturnType *temp_pointer = static_cast<CoeffReturnType *>(
1513           device().allocate_temp(triple_dim.M * triple_dim.N * groupSizeK * sizeof(CoeffReturnType)));
1514       EvaluatorPointerType tmp_global_accessor = device().get(temp_pointer);
1515 
1516       device().template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(
1517           lhs, rhs, tmp_global_accessor, thread_range, scratchSize, groupSizeM, groupSizeN, numTilesPerGroup,
1518           triple_dim);
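      // At this point tmp_global_accessor holds groupSizeK partial results for each of the M * N output elements;
      // the SecondStepPartialReduction kernel launched below sums those groupSizeK partials per output element and
      // writes the final contraction result into buffer.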
1519 
1520       typedef Eigen::internal::SumReducer<CoeffReturnType> Op;
1521       auto op = Op();
1522       typedef TensorSycl::internal::SecondStepPartialReduction<CoeffReturnType, StorageIndex, EvaluatorPointerType,
1523                                                                EvaluatorPointerType, Op>
1524           ReductionKernel;
1525 
1526       device().template unary_kernel_launcher<CoeffReturnType, ReductionKernel>(
1527           tmp_global_accessor, buffer,
1528           cl::sycl::nd_range<1>(cl::sycl::range<1>(StorageIndex(
1529                                     Eigen::TensorSycl::internal::roundUp(triple_dim.M * triple_dim.N, localRange))),
1530                                 cl::sycl::range<1>(localRange)),
1531           StorageIndex(1), op, StorageIndex(triple_dim.M * triple_dim.N), groupSizeK);
1532 
1533       device().deallocate_temp(temp_pointer);
1534     }
1535   }
1536 
1537 #ifndef EIGEN_SYCL_DISABLE_GEMV
1538   template <bool is_lhs_vec, typename VectorMapper, typename TensorMapper, typename StorageIndex>
1539   void EIGEN_ALWAYS_INLINE LaunchVT(EvaluatorPointerType buffer, const VectorMapper &vec, const TensorMapper &mat,
1540                                     StorageIndex NC, StorageIndex C) const {
1541     const StorageIndex nonContractDim = NC;
1542     EIGEN_CONSTEXPR StorageIndex NCFactor = 1;
1543     EIGEN_CONSTEXPR StorageIndex CFactor = 1;
1544     EIGEN_CONSTEXPR StorageIndex NCWindow = 16;
1545     typedef Eigen::TensorSycl::internal::TVPanelSize<CoeffReturnType, StorageIndex, NCWindow, CFactor, NCFactor>
1546         Properties;
1547     const StorageIndex roundUpC = Eigen::TensorSycl::internal::roundUp(C, Properties::TileSizeDimC);
1548     const StorageIndex cNumGroups = roundUpC / (Properties::LocalThreadSizeC * Properties::WorkLoadPerThreadC);
1549     const StorageIndex roundUpNC = Eigen::TensorSycl::internal::roundUp(nonContractDim, Properties::TileSizeDimNC);
1550     const StorageIndex nCNumGroups = roundUpNC / (Properties::LocalThreadSizeNC * Properties::WorkLoadPerThreadNC);
1551     const StorageIndex globalRange =
1552         (roundUpNC / (Properties::WorkLoadPerThreadNC)) * (roundUpC / (Properties::WorkLoadPerThreadC));
1553     const StorageIndex localRange = Properties::LocalThreadSizeNC * Properties::LocalThreadSizeC;
1554     const StorageIndex scratchSize =
1555         (Properties::WorkLoadPerThreadNC + CFactor) * Properties::LocalThreadSizeC * Properties::LocalThreadSizeNC;
1556     auto thread_range = cl::sycl::nd_range<1>(cl::sycl::range<1>(globalRange), cl::sycl::range<1>(localRange));
1557     if (cNumGroups > 1) {
1558       typedef Eigen::TensorSycl::internal::GeneralVectorTensor<CoeffReturnType, EvaluatorPointerType, VectorMapper,
1559                                                                TensorMapper, StorageIndex, Properties, CFactor, false,
1560                                                                is_lhs_vec, false>
1561           ContractKernelName;
1562       CoeffReturnType *temp_pointer =
1563           static_cast<CoeffReturnType *>(device().allocate_temp(nonContractDim * cNumGroups * sizeof(CoeffReturnType)));
1564       EvaluatorPointerType tmp_global_accessor = device().get(temp_pointer);
1565 
1566       device().template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(
1567           vec, mat, tmp_global_accessor, thread_range, scratchSize, nCNumGroups, nonContractDim, C);
1568 
1569       typedef Eigen::internal::SumReducer<CoeffReturnType> Op;
1570       typedef TensorSycl::internal::SecondStepPartialReduction<CoeffReturnType, StorageIndex, EvaluatorPointerType,
1571                                                                EvaluatorPointerType, Op>
1572           ReductionKernel;
1573 
1574       device().template unary_kernel_launcher<CoeffReturnType, ReductionKernel>(
1575           tmp_global_accessor, buffer,
1576           cl::sycl::nd_range<1>(cl::sycl::range<1>(Eigen::TensorSycl::internal::roundUp(nonContractDim, localRange)),
1577                                 cl::sycl::range<1>(localRange)),
1578           StorageIndex(1), Op(), nonContractDim, cNumGroups);
1579 
1580       device().deallocate_temp(temp_pointer);
1581     } else {
1582       typedef Eigen::TensorSycl::internal::GeneralVectorTensor<CoeffReturnType, EvaluatorPointerType, VectorMapper,
1583                                                                TensorMapper, StorageIndex, Properties, CFactor, false,
1584                                                                is_lhs_vec, true>
1585           ContractKernelName;
1586       device().template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(
1587           vec, mat, buffer, thread_range, scratchSize, nCNumGroups, nonContractDim, C);
1588     }
1589   }
1590 #endif
1591 
1592 #ifndef EIGEN_SYCL_DISABLE_SCALAR
1593   template <typename LhsMapper, typename RhsMapper>
1594   EIGEN_ALWAYS_INLINE void launchSC(EvaluatorPointerType buffer, const LhsMapper &lhs, const RhsMapper &rhs,
1595                                     StorageIndex K) const {
1596     EIGEN_STATIC_ASSERT(!((EIGEN_SYCL_LOCAL_THREAD_DIM0 * EIGEN_SYCL_LOCAL_THREAD_DIM1) &
1597                           (EIGEN_SYCL_LOCAL_THREAD_DIM0 * EIGEN_SYCL_LOCAL_THREAD_DIM1 - 1)),
1598                         "The Local thread size must be a power of 2 for the reduction "
1599                         "operation");
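    // The (n & (n - 1)) == 0 idiom above only holds for powers of two, e.g. 256 & 255 == 0 while 96 & 95 == 64,
    // so a 96-thread work-group would be rejected at compile time.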
1600     EIGEN_CONSTEXPR StorageIndex local_range = EIGEN_SYCL_LOCAL_THREAD_DIM0 * EIGEN_SYCL_LOCAL_THREAD_DIM1;
1601 
1602     // Here we force the code to be at most a 2-step reduction: our empirical research shows that if each thread
1603     // reduces at least 512 elements individually, we get better performance.
1604     const StorageIndex num_work_group = ((K + (512 * local_range - 1)) / (512 * local_range) > 1 ? local_range : 1);
1605     const StorageIndex global_range = num_work_group * local_range;
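    // Hedged numeric sketch assuming the default 16 x 16 local dimensions (local_range = 256): for K = 200000,
    // (K + 131071) / 131072 = 2 > 1, so num_work_group = 256 and the partial results are combined by the
    // SecondStepFullReducer launched below; for K = 100000 the quotient is 1, num_work_group = 1 and the single
    // kernel writes its result directly into buffer.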
1606 
1607     typedef Eigen::TensorSycl::internal::GeneralScalarContraction<
1608         CoeffReturnType, LhsScalar, RhsScalar, EvaluatorPointerType, LhsMapper, RhsMapper, StorageIndex, false>
1609         ContractKernelName;
1610     auto thread_range = cl::sycl::nd_range<1>(cl::sycl::range<1>(global_range), cl::sycl::range<1>(local_range));
1611     if (num_work_group > 1) {
1612       CoeffReturnType *temp_pointer =
1613           static_cast<CoeffReturnType *>(device().allocate_temp(num_work_group * sizeof(CoeffReturnType)));
1614       EvaluatorPointerType tmp_global_accessor = device().get(temp_pointer);
1615       device().template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(lhs, rhs, tmp_global_accessor,
1616                                                                                     thread_range, local_range, K);
1617       typedef Eigen::internal::SumReducer<CoeffReturnType> Op;
1618       typedef TensorSycl::internal::SecondStepFullReducer<CoeffReturnType, Op, EvaluatorPointerType,
1619                                                           EvaluatorPointerType, StorageIndex, local_range>
1620           GenericRKernel;
1621       device().template unary_kernel_launcher<CoeffReturnType, GenericRKernel>(
1622           tmp_global_accessor, buffer,
1623           cl::sycl::nd_range<1>(cl::sycl::range<1>(local_range), cl::sycl::range<1>(local_range)), local_range, Op());
1624 
1625       device().deallocate_temp(temp_pointer);
1626     } else {
1627       device().template binary_kernel_launcher<CoeffReturnType, ContractKernelName>(lhs, rhs, buffer, thread_range,
1628                                                                                     local_range, K);
1629     }
1630   }
1631 #endif
1632 
1633   EIGEN_STRONG_INLINE void cleanup() {
1634     this->m_leftImpl.cleanup();
1635     this->m_rightImpl.cleanup();
1636 
1637     if (this->m_result) {
1638       this->m_device.deallocate_temp(this->m_result);
1639       this->m_result = NULL;
1640     }
1641   }
1642   // The placeholder accessors must be bound to a command group handler for SYCL
1643   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
1644     this->m_leftImpl.bind(cgh);
1645     this->m_rightImpl.bind(cgh);
1646     this->m_result.bind(cgh);
1647   }
1648 };
1649 }  // namespace Eigen
1650 #endif  // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_SYCL_H
1651