// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <[email protected]>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H

namespace Eigen {

namespace internal {

enum {
  Rhs = 0,
  Lhs = 1
};

/*
 * Implementation of the Eigen blas_data_mapper class for tensors.
 */
/// The make pointer class is used by SYCL in order to build the mapper class on the device.
/// For other platforms the default make pointer is used, which is scalar* for CoeffLoader.
template <typename Tensor, bool HasRawAccess, template <class> class MakePointer_ = MakePointer>
struct CoeffLoader;

template <typename Scalar, typename Index, int side, typename Tensor,
          typename nocontract_t, typename contract_t, int packet_size,
          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
          template <class> class MakePointer_ = MakePointer>
class BaseTensorContractionMapper;

template <typename Tensor, bool HasRawAccess, template <class> class MakePointer_>
struct CoeffLoader {
  enum {
    DirectOffsets = false
  };

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffLoader(const Tensor& tensor) : m_tensor(tensor) { }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index) {
    eigen_assert(false && "unsupported");
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const typename MakePointer_<const typename Tensor::Scalar>::Type
  data() const {
    eigen_assert(false && "unsupported");
    return NULL;
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const { return m_tensor.coeff(index); }

  template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  typename Tensor::PacketReturnType packet(typename Tensor::Index index) const
  {
    return m_tensor.template packet<LoadMode>(index);
  }

#ifdef EIGEN_USE_SYCL
  // The placeholder accessors need to be bound to a SYCL command group handler.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
    m_tensor.bind(cgh);
  }
#endif

 private:
  const Tensor m_tensor;
};
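
// The primary CoeffLoader template above goes through the evaluator's
// coeff()/packet() interface and therefore cannot be re-based onto an offset
// buffer (offsetBuffer() and data() assert). The HasRawAccess == true
// specialization below instead wraps the evaluator's raw scalar pointer, so
// offsets can be folded directly into the pointer. Rough usage sketch
// (EvaluatorType is a hypothetical evaluator with RawAccess):
//
//   CoeffLoader<EvaluatorType, /*HasRawAccess=*/true> loader(evaluator);
//   loader.offsetBuffer(16);        // advance the raw pointer by 16 scalars
//   Scalar s = loader.coeff(0);     // reads evaluator.data()[16]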

template <typename Tensor, template <class> class MakePointer_>
struct CoeffLoader<Tensor, true, MakePointer_> {
  enum {
    DirectOffsets = true
  };

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffLoader(const Tensor& tensor) : m_data(tensor.data()) {}

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index offset) {
    m_data += offset;
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const typename MakePointer_<const typename Tensor::Scalar>::Type
  data() const {
    return m_data;
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const { return loadConstant(m_data+index); }

  template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  typename Tensor::PacketReturnType packet(typename Tensor::Index index) const
  {
    return internal::ploadt_ro<typename Tensor::PacketReturnType, LoadMode>(m_data + index);
  }

#ifdef EIGEN_USE_SYCL
  // The placeholder accessors need to be bound to a SYCL command group handler.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
    m_data.bind(cgh);
  }
#endif
 private:
  typedef typename Tensor::Scalar Scalar;

  typename MakePointer_<const Scalar>::Type m_data;
};

template<typename Scalar, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t,
         int packet_size, bool inner_dim_contiguous, int Alignment, template <class> class MakePointer_ = MakePointer>
class SimpleTensorContractionMapper {
 public:
  EIGEN_DEVICE_FUNC
  SimpleTensorContractionMapper(const Tensor& tensor,
                                const nocontract_t& nocontract_strides,
                                const nocontract_t& ij_strides,
                                const contract_t& contract_strides,
                                const contract_t& k_strides) :
      m_tensor(tensor),
      m_nocontract_strides(nocontract_strides),
      m_ij_strides(ij_strides),
      m_contract_strides(contract_strides),
      m_k_strides(k_strides) { }

  enum {
    DirectOffsets = CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_>::DirectOffsets
  };

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index offset) {
    m_tensor.offsetBuffer(offset);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE void prefetch(Index /*i*/) { }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar operator()(Index row) const {
    // column major assumption
    return operator()(row, 0);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar operator()(Index row, Index col) const {
    return m_tensor.coeff(computeIndex(row, col));
  }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Index computeIndex(Index row, Index col) const {
    const bool left = (side == Lhs);
    EIGEN_UNUSED_VARIABLE(left); // annoying bug in g++8.1: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=85963
    Index nocontract_val = left ? row : col;
    Index linidx = 0;
    EIGEN_UNROLL_LOOP
    for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) {
      const Index idx = nocontract_val / m_ij_strides[i];
      linidx += idx * m_nocontract_strides[i];
      nocontract_val -= idx * m_ij_strides[i];
    }
    if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) {
      if (side == Lhs && inner_dim_contiguous) {
        eigen_assert(m_nocontract_strides[0] == 1);
        linidx += nocontract_val;
      } else {
        linidx += nocontract_val * m_nocontract_strides[0];
      }
    }

    Index contract_val = left ? col : row;
    if (array_size<contract_t>::value > 0) {
      EIGEN_UNROLL_LOOP
      for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
        const Index idx = contract_val / m_k_strides[i];
        linidx += idx * m_contract_strides[i];
        contract_val -= idx * m_k_strides[i];
      }

      if (side == Rhs && inner_dim_contiguous) {
        eigen_assert(m_contract_strides[0] == 1);
        linidx += contract_val;
      } else {
        linidx += contract_val * m_contract_strides[0];
      }
    }

    return linidx;
  }
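
  // computeIndex() above flattens a logical (row, col) coordinate of the
  // contraction back into a linear index of the underlying tensor: for the
  // Lhs the row spans the non-contracting dimensions and the col spans the
  // contracting ones (swapped for the Rhs). Each coordinate is decomposed
  // digit by digit with the m_ij_strides / m_k_strides arrays, and each digit
  // is re-encoded with the tensor's own m_nocontract_strides /
  // m_contract_strides. As an illustration (assuming two non-contracting
  // dimensions of sizes d0 and d1, with ij strides {1, d0}):
  //
  //   idx1   = row / d0;           // coordinate in the second dimension
  //   idx0   = row - idx1 * d0;    // coordinate in the first dimension
  //   linidx = idx1 * m_nocontract_strides[1] + idx0 * m_nocontract_strides[0];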

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE IndexPair<Index> computeIndexPair(Index row, Index col, const Index distance) const {
    const bool left = (side == Lhs);
    EIGEN_UNUSED_VARIABLE(left); // annoying bug in g++8.1: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=85963
    Index nocontract_val[2] = {left ? row : col, left ? row + distance : col};
    Index linidx[2] = {0, 0};
    if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) {
      EIGEN_UNROLL_LOOP
      for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) {
        const Index idx0 = nocontract_val[0] / m_ij_strides[i];
        const Index idx1 = nocontract_val[1] / m_ij_strides[i];
        linidx[0] += idx0 * m_nocontract_strides[i];
        linidx[1] += idx1 * m_nocontract_strides[i];
        nocontract_val[0] -= idx0 * m_ij_strides[i];
        nocontract_val[1] -= idx1 * m_ij_strides[i];
      }
      if (side == Lhs && inner_dim_contiguous) {
        eigen_assert(m_nocontract_strides[0] == 1);
        linidx[0] += nocontract_val[0];
        linidx[1] += nocontract_val[1];
      } else {
        linidx[0] += nocontract_val[0] * m_nocontract_strides[0];
        linidx[1] += nocontract_val[1] * m_nocontract_strides[0];
      }
    }

    Index contract_val[2] = {left ? col : row, left ? col : row + distance};
    if (array_size<contract_t>::value > 0) {
      EIGEN_UNROLL_LOOP
      for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
        const Index idx0 = contract_val[0] / m_k_strides[i];
        const Index idx1 = contract_val[1] / m_k_strides[i];
        linidx[0] += idx0 * m_contract_strides[i];
        linidx[1] += idx1 * m_contract_strides[i];
        contract_val[0] -= idx0 * m_k_strides[i];
        contract_val[1] -= idx1 * m_k_strides[i];
      }

      if (side == Rhs && inner_dim_contiguous) {
        eigen_assert(m_contract_strides[0] == 1);
        linidx[0] += contract_val[0];
        linidx[1] += contract_val[1];
      } else {
        linidx[0] += contract_val[0] * m_contract_strides[0];
        linidx[1] += contract_val[1] * m_contract_strides[0];
      }
    }
    return IndexPair<Index>(linidx[0], linidx[1]);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index firstAligned(Index size) const {
    // Only claim alignment when we can compute the actual stride (i.e. when we're
    // dealing with the lhs with inner_dim_contiguous). This is because the
    // matrix-vector product relies on the stride when dealing with aligned inputs.
    return (Alignment == Aligned) && (side == Lhs) && inner_dim_contiguous ? 0 : size;
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index stride() const {
    return ((side == Lhs) && inner_dim_contiguous && array_size<contract_t>::value > 0) ? m_contract_strides[0] : 1;
  }
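
  // stride() above is the distance, in the underlying buffer, between two
  // consecutive columns of the logical Lhs matrix. It is only meaningful in
  // the Lhs / inner_dim_contiguous case; TensorContractionSubMapper relies on
  // it when UseDirectOffsets is enabled to bake a (vert, horiz) block offset
  // into the CoeffLoader as vert_offset + horiz_offset * stride().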

#ifdef EIGEN_USE_SYCL
  // The placeholder accessors need to be bound to a SYCL command group handler.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
    m_tensor.bind(cgh);
  }
#endif

  const CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_>& tensor() const {
    return m_tensor;
  }

  const nocontract_t& nocontract_strides() const {
    return m_nocontract_strides;
  }
  const nocontract_t& ij_strides() const { return m_ij_strides; }
  const contract_t& contract_strides() const { return m_contract_strides; }
  const contract_t& k_strides() const { return m_k_strides; }

 protected:
  CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_> m_tensor;
  const nocontract_t m_nocontract_strides;
  const nocontract_t m_ij_strides;
  const contract_t m_contract_strides;
  const contract_t m_k_strides;
};

template<typename Scalar, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t,
         int packet_size, bool inner_dim_contiguous,
         bool inner_dim_reordered, int Alignment, template <class> class MakePointer_>
class BaseTensorContractionMapper : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, Alignment, MakePointer_>
{
 public:
  typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, Alignment, MakePointer_> ParentMapper;

  EIGEN_DEVICE_FUNC
  BaseTensorContractionMapper(const Tensor& tensor,
                              const nocontract_t& nocontract_strides,
                              const nocontract_t& ij_strides,
                              const contract_t& contract_strides,
                              const contract_t& k_strides) :
      ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }

  template <typename PacketT, int AlignmentType>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  typename internal::enable_if<internal::unpacket_traits<PacketT>::size==packet_size,PacketT>::type
  load(Index i, Index j) const
  {
    // whole method makes column major assumption

    // don't need to add offsets for now (because operator handles that)
    // current code assumes packet size must be a multiple of 2
    EIGEN_STATIC_ASSERT(packet_size % 2 == 0, YOU_MADE_A_PROGRAMMING_MISTAKE);

    if (Tensor::PacketAccess && inner_dim_contiguous && !inner_dim_reordered) {
      const Index index = this->computeIndex(i, j);
      eigen_assert(this->computeIndex(i+packet_size-1, j) == index + packet_size-1);
      return this->m_tensor.template packet<AlignmentType>(index);
    }

    const IndexPair<Index> indexPair = this->computeIndexPair(i, j, packet_size - 1);
    const Index first = indexPair.first;
    const Index lastIdx = indexPair.second;

    // We can always do optimized packet reads from the left hand side right now, because
    // the vertical matrix dimension on the left hand side is never contracting.
    // On the right hand side we need to check if the contracting dimensions may have
    // been shuffled first.
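    // The condition below additionally requires the two boundary indices
    // returned by computeIndexPair to be exactly packet_size - 1 apart,
    // i.e. the whole packet is contiguous in memory; otherwise we fall back
    // to gathering the coefficients one by one into an aligned buffer.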
    if (Tensor::PacketAccess &&
        (side == Lhs || internal::array_size<contract_t>::value <= 1 || !inner_dim_reordered) &&
        (lastIdx - first) == (packet_size - 1)) {

      return this->m_tensor.template packet<AlignmentType>(first);
    }

    EIGEN_ALIGN_MAX Scalar data[packet_size];

    data[0] = this->m_tensor.coeff(first);
    EIGEN_UNROLL_LOOP
    for (Index k = 1; k < packet_size - 1; k += 2) {
      const IndexPair<Index> internal_pair = this->computeIndexPair(i + k, j, 1);
      data[k] = this->m_tensor.coeff(internal_pair.first);
      data[k + 1] = this->m_tensor.coeff(internal_pair.second);
    }
    data[packet_size - 1] = this->m_tensor.coeff(lastIdx);

    return pload<PacketT>(data);
  }

  template <typename PacketT, int AlignmentType>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  typename internal::enable_if<internal::unpacket_traits<PacketT>::size!=packet_size,PacketT>::type
  load(Index i, Index j) const
  {
    const Index requested_packet_size = internal::unpacket_traits<PacketT>::size;
    EIGEN_ALIGN_MAX Scalar data[requested_packet_size];

    const IndexPair<Index> indexPair = this->computeIndexPair(i, j, requested_packet_size - 1);
    const Index first = indexPair.first;
    const Index lastIdx = indexPair.second;

    data[0] = this->m_tensor.coeff(first);
    for (Index k = 1; k < requested_packet_size - 1; k += 2) {
      const IndexPair<Index> internal_pair = this->computeIndexPair(i + k, j, 1);
      data[k] = this->m_tensor.coeff(internal_pair.first);
      data[k + 1] = this->m_tensor.coeff(internal_pair.second);
    }
    data[requested_packet_size - 1] = this->m_tensor.coeff(lastIdx);

    return pload<PacketT>(data);
  }

  template <typename PacketT, int AlignmentType>
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE PacketT loadPacket(Index i, Index j) const {
    return this->load<PacketT, AlignmentType>(i, j);
  }
};


template<typename Scalar, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t,
         bool inner_dim_contiguous,
         bool inner_dim_reordered, int Alignment, template <class> class MakePointer_>
class BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_>
  : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, Alignment, MakePointer_>
{
 public:
  typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, Alignment, MakePointer_> ParentMapper;

  EIGEN_DEVICE_FUNC
  BaseTensorContractionMapper(const Tensor& tensor,
                              const nocontract_t& nocontract_strides,
                              const nocontract_t& ij_strides,
                              const contract_t& contract_strides,
                              const contract_t& k_strides) :
      ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }

  template <typename PacketT, int> EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE PacketT loadPacket(Index i, Index j) const {
    EIGEN_ALIGN_MAX Scalar data[1];
    data[0] = this->m_tensor.coeff(this->computeIndex(i, j));
    return pload<PacketT>(data);
  }
  template <typename PacketT, int> EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE PacketT load(Index i, Index j) const {
    EIGEN_ALIGN_MAX Scalar data[1];
    data[0] = this->m_tensor.coeff(this->computeIndex(i, j));
    return pload<PacketT>(data);
  }
};
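
// The specialization above covers packet_size == 1: a "packet" is a single
// scalar, so load() and loadPacket() reduce to one coeff() read wrapped in a
// one-element pload.
//
// TensorContractionSubMapper below views a rectangular block of the logical
// matrix at a fixed (vert_offset, horiz_offset); it is the object that the
// gemm-style packing kernels actually iterate over.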

template<typename Scalar, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t,
         int packet_size,
         bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment, template <class> class MakePointer_=MakePointer>
class TensorContractionSubMapper {
 public:

  typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> ParentMapper;
  typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> Self;
  typedef Self LinearMapper;

  enum {
    // We can use direct offsets iff the parent mapper supports them and we can compute the strides.
    // TODO: we should also enable direct offsets for the Rhs case.
    UseDirectOffsets = ParentMapper::DirectOffsets && (side == Lhs) && inner_dim_contiguous && (array_size<contract_t>::value > 0)
  };

  EIGEN_DEVICE_FUNC TensorContractionSubMapper(const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
      : m_base_mapper(base_mapper), m_vert_offset(vert_offset), m_horiz_offset(horiz_offset) {
    // Bake the offsets into the buffer used by the base mapper whenever possible. This avoids the need to recompute
    // this offset every time we attempt to access a coefficient.
    if (UseDirectOffsets) {
      Index stride = m_base_mapper.stride();
      m_base_mapper.offsetBuffer(vert_offset + horiz_offset * stride);
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
    if (UseDirectOffsets) {
      return m_base_mapper(i, 0);
    }
    return m_base_mapper(i + m_vert_offset, m_horiz_offset);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, Index j) const {
    if (UseDirectOffsets) {
      return m_base_mapper(i, j);
    }
    return m_base_mapper(i + m_vert_offset, j + m_horiz_offset);
  }

  template <typename PacketT>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT loadPacket(Index i) const {
    if (UseDirectOffsets) {
      return m_base_mapper.template loadPacket<PacketT,Alignment>(i, 0);
    }
    return m_base_mapper.template loadPacket<PacketT,Alignment>(i + m_vert_offset, m_horiz_offset);
  }

  template <typename PacketT>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT loadPacket(Index i, Index j) const {
    if (UseDirectOffsets) {
      return m_base_mapper.template loadPacket<PacketT,Alignment>(i, j);
    }
    return m_base_mapper.template loadPacket<PacketT,Alignment>(i + m_vert_offset, j + m_horiz_offset);
  }

  template <typename PacketT, int AlignmentType>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT loadPacket(Index i, Index j) const {
    if (UseDirectOffsets) {
      return m_base_mapper.template load<PacketT,AlignmentType>(i, j);
    }
    return m_base_mapper.template loadPacket<PacketT,AlignmentType>(i + m_vert_offset, j + m_horiz_offset);
  }

  template <typename PacketT>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, const PacketT& p) const {
    if (UseDirectOffsets) {
      m_base_mapper.storePacket(i, 0, p);
      return;  // avoid issuing the offset store a second time below
    }
    m_base_mapper.storePacket(i + m_vert_offset, m_horiz_offset, p);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
    if (UseDirectOffsets) {
      return LinearMapper(m_base_mapper, i, j);
    }
    return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
  }
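
  // load() below downgrades the requested alignment: the packet read is only
  // tagged Aligned when both the caller-requested AlignmentType and the
  // mapper's static Alignment are Aligned; otherwise it is issued Unaligned.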

  template <typename PacketT, int AlignmentType>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i) const {
    EIGEN_STATIC_ASSERT((internal::is_same<PacketT, PacketT>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
    const int ActualAlignment = (AlignmentType == Aligned) && (Alignment == Aligned) ? Aligned : Unaligned;
    if (UseDirectOffsets) {
      return m_base_mapper.template loadPacket<PacketT,ActualAlignment>(i, 0);
    }
    return m_base_mapper.template loadPacket<PacketT,ActualAlignment>(i + m_vert_offset, m_horiz_offset);
  }

  template <typename PacketT>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool aligned(Index) const {
    return false;
  }

#ifdef EIGEN_USE_SYCL
  // The placeholder accessors need to be bound to a SYCL command group handler.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
    m_base_mapper.bind(cgh);
  }
#endif

  const ParentMapper& base_mapper() const { return m_base_mapper; }
  Index vert_offset() const { return m_vert_offset; }
  Index horiz_offset() const { return m_horiz_offset; }

 private:
  ParentMapper m_base_mapper;
  const Index m_vert_offset;
  const Index m_horiz_offset;
};


template<typename Scalar_, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t,
         int packet_size,
         bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment, template <class> class MakePointer_=MakePointer>
class TensorContractionInputMapper
  : public BaseTensorContractionMapper<Scalar_, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> {

 public:
  typedef Scalar_ Scalar;
  typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> Base;
  typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> SubMapper;
  typedef SubMapper VectorMapper;

  EIGEN_DEVICE_FUNC TensorContractionInputMapper(const Tensor& tensor,
                                                 const nocontract_t& nocontract_strides,
                                                 const nocontract_t& ij_strides,
                                                 const contract_t& contract_strides,
                                                 const contract_t& k_strides)
      : Base(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
    return SubMapper(*this, i, j);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
    return VectorMapper(*this, i, j);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_>& get_tensor() const {
    return Base::m_tensor;
  }
};
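
// Rough usage sketch (illustrative only; the exact types and member names
// live in the TensorContractionOp evaluator, not in this header): the
// contraction evaluator builds one input mapper per operand from that
// operand's evaluator plus the precomputed stride arrays, then hands
// sub-mappers to the packing and kernel routines.
//
//   typedef internal::TensorContractionInputMapper<
//       LhsScalar, Index, internal::Lhs, LeftEvaluator,
//       left_nocontract_t, contract_t, lhs_packet_size,
//       /*inner_dim_contiguous=*/true, /*inner_dim_reordered=*/false,
//       Unaligned> LhsMapper;
//
//   LhsMapper lhs(left_evaluator, left_nocontract_strides, i_strides,
//                 left_contracting_strides, k_strides);
//   LhsMapper::SubMapper block = lhs.getSubMapper(i2, k2);  // block at (i2, k2)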

template <typename T> struct TensorContractionInputMapperTrait;

template<typename Scalar_, typename Index_, int side_,
         typename Tensor_,
         typename nocontract_t_, typename contract_t_,
         int packet_size_,
         bool inner_dim_contiguous_, bool inner_dim_reordered_, int Alignment_, template <class> class MakePointer_>
struct TensorContractionInputMapperTrait<TensorContractionInputMapper<Scalar_, Index_, side_, Tensor_,
                                         nocontract_t_, contract_t_, packet_size_, inner_dim_contiguous_,
                                         inner_dim_reordered_, Alignment_, MakePointer_> > {

  typedef Tensor_ XprType;
  static const bool inner_dim_contiguous = inner_dim_contiguous_;
  static const bool inner_dim_reordered = inner_dim_reordered_;
};


}  // end namespace internal
}  // end namespace Eigen

#endif  // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H