1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2014 Benoit Steiner <[email protected]> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_DIMENSIONS_H 11 #define EIGEN_CXX11_TENSOR_TENSOR_DIMENSIONS_H 12 13 14 namespace Eigen { 15 16 /** \internal 17 * 18 * \class TensorDimensions 19 * \ingroup CXX11_Tensor_Module 20 * 21 * \brief Set of classes used to encode and store the dimensions of a Tensor. 22 * 23 * The Sizes class encodes as part of the type the number of dimensions and the 24 * sizes corresponding to each dimension. It uses no storage space since it is 25 * entirely known at compile time. 26 * The DSizes class is its dynamic sibling: the number of dimensions is known 27 * at compile time but the sizes are set during execution. 28 * 29 * \sa Tensor 30 */ 31 32 // Boilerplate code 33 namespace internal { 34 35 template<std::ptrdiff_t n, typename Dimension> struct dget { 36 static const std::ptrdiff_t value = get<n, Dimension>::value; 37 }; 38 39 40 template<typename Index, std::ptrdiff_t NumIndices, std::ptrdiff_t n, bool RowMajor> 41 struct fixed_size_tensor_index_linearization_helper 42 { 43 template <typename Dimensions> EIGEN_DEVICE_FUNC runfixed_size_tensor_index_linearization_helper44 static EIGEN_STRONG_INLINE Index run(array<Index, NumIndices> const& indices, 45 const Dimensions& dimensions) 46 { 47 return array_get<RowMajor ? n - 1 : (NumIndices - n)>(indices) + 48 dget<RowMajor ? n - 1 : (NumIndices - n), Dimensions>::value * 49 fixed_size_tensor_index_linearization_helper<Index, NumIndices, n - 1, RowMajor>::run(indices, dimensions); 50 } 51 }; 52 53 template<typename Index, std::ptrdiff_t NumIndices, bool RowMajor> 54 struct fixed_size_tensor_index_linearization_helper<Index, NumIndices, 0, RowMajor> 55 { 56 template <typename Dimensions> EIGEN_DEVICE_FUNC 57 static EIGEN_STRONG_INLINE Index run(array<Index, NumIndices> const&, const Dimensions&) 58 { 59 return 0; 60 } 61 }; 62 63 template<typename Index, std::ptrdiff_t n> 64 struct fixed_size_tensor_index_extraction_helper 65 { 66 template <typename Dimensions> EIGEN_DEVICE_FUNC 67 static EIGEN_STRONG_INLINE Index run(const Index index, 68 const Dimensions& dimensions) 69 { 70 const Index mult = (index == n-1) ? 1 : 0; 71 return array_get<n-1>(dimensions) * mult + 72 fixed_size_tensor_index_extraction_helper<Index, n - 1>::run(index, dimensions); 73 } 74 }; 75 76 template<typename Index> 77 struct fixed_size_tensor_index_extraction_helper<Index, 0> 78 { 79 template <typename Dimensions> EIGEN_DEVICE_FUNC 80 static EIGEN_STRONG_INLINE Index run(const Index, 81 const Dimensions&) 82 { 83 return 0; 84 } 85 }; 86 87 } // end namespace internal 88 89 90 // Fixed size 91 #ifndef EIGEN_EMULATE_CXX11_META_H 92 template <typename std::ptrdiff_t... Indices> 93 struct Sizes { 94 typedef internal::numeric_list<std::ptrdiff_t, Indices...> Base; 95 const Base t = Base(); 96 static const std::ptrdiff_t total_size = internal::arg_prod(Indices...); 97 static const ptrdiff_t count = Base::count; 98 99 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t rank() const { 100 return Base::count; 101 } 102 103 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t TotalSize() { 104 return internal::arg_prod(Indices...); 105 } 106 107 EIGEN_DEVICE_FUNC Sizes() { } 108 template <typename DenseIndex> 109 explicit EIGEN_DEVICE_FUNC Sizes(const array<DenseIndex, Base::count>& /*indices*/) { 110 // todo: add assertion 111 } 112 #if EIGEN_HAS_VARIADIC_TEMPLATES 113 template <typename... DenseIndex> EIGEN_DEVICE_FUNC Sizes(DenseIndex...) { } 114 explicit EIGEN_DEVICE_FUNC Sizes(std::initializer_list<std::ptrdiff_t> /*l*/) { 115 // todo: add assertion 116 } 117 #endif 118 119 template <typename T> Sizes& operator = (const T& /*other*/) { 120 // add assertion failure if the size of other is different 121 return *this; 122 } 123 124 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t operator[] (const std::ptrdiff_t index) const { 125 return internal::fixed_size_tensor_index_extraction_helper<std::ptrdiff_t, Base::count>::run(index, t); 126 } 127 128 template <typename DenseIndex> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE 129 ptrdiff_t IndexOfColMajor(const array<DenseIndex, Base::count>& indices) const { 130 return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count, false>::run(indices, t); 131 } 132 template <typename DenseIndex> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE 133 ptrdiff_t IndexOfRowMajor(const array<DenseIndex, Base::count>& indices) const { 134 return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count, true>::run(indices, t); 135 } 136 }; 137 138 namespace internal { 139 template <typename std::ptrdiff_t... Indices> 140 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t array_prod(const Sizes<Indices...>&) { 141 return Sizes<Indices...>::total_size; 142 } 143 } 144 145 #else 146 147 template <std::ptrdiff_t n> 148 struct non_zero_size { 149 typedef internal::type2val<std::ptrdiff_t, n> type; 150 }; 151 template <> 152 struct non_zero_size<0> { 153 typedef internal::null_type type; 154 }; 155 156 template <std::ptrdiff_t V1=0, std::ptrdiff_t V2=0, std::ptrdiff_t V3=0, std::ptrdiff_t V4=0, std::ptrdiff_t V5=0> struct Sizes { 157 typedef typename internal::make_type_list<typename non_zero_size<V1>::type, typename non_zero_size<V2>::type, typename non_zero_size<V3>::type, typename non_zero_size<V4>::type, typename non_zero_size<V5>::type >::type Base; 158 static const std::ptrdiff_t count = Base::count; 159 static const std::ptrdiff_t total_size = internal::arg_prod<Base>::value; 160 161 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ptrdiff_t rank() const { 162 return count; 163 } 164 165 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ptrdiff_t TotalSize() { 166 return internal::arg_prod<Base>::value; 167 } 168 169 Sizes() { } 170 template <typename DenseIndex> 171 explicit Sizes(const array<DenseIndex, Base::count>& /*indices*/) { 172 // todo: add assertion 173 } 174 template <typename T> Sizes& operator = (const T& /*other*/) { 175 // add assertion failure if the size of other is different 176 return *this; 177 } 178 179 #if EIGEN_HAS_VARIADIC_TEMPLATES 180 template <typename... DenseIndex> Sizes(DenseIndex... /*indices*/) { } 181 explicit Sizes(std::initializer_list<std::ptrdiff_t>) { 182 // todo: add assertion 183 } 184 #else 185 EIGEN_DEVICE_FUNC explicit Sizes(const DenseIndex) { 186 } 187 EIGEN_DEVICE_FUNC Sizes(const DenseIndex, const DenseIndex) { 188 } 189 EIGEN_DEVICE_FUNC Sizes(const DenseIndex, const DenseIndex, const DenseIndex) { 190 } 191 EIGEN_DEVICE_FUNC Sizes(const DenseIndex, const DenseIndex, const DenseIndex, const DenseIndex) { 192 } 193 EIGEN_DEVICE_FUNC Sizes(const DenseIndex, const DenseIndex, const DenseIndex, const DenseIndex, const DenseIndex) { 194 } 195 #endif 196 197 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index operator[] (const Index index) const { 198 switch (index) { 199 case 0: 200 return internal::get<0, Base>::value; 201 case 1: 202 return internal::get<1, Base>::value; 203 case 2: 204 return internal::get<2, Base>::value; 205 case 3: 206 return internal::get<3, Base>::value; 207 case 4: 208 return internal::get<4, Base>::value; 209 default: 210 eigen_assert(false && "index overflow"); 211 return static_cast<Index>(-1); 212 } 213 } 214 215 template <typename DenseIndex> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE 216 ptrdiff_t IndexOfColMajor(const array<DenseIndex, Base::count>& indices) const { 217 return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count, false>::run(indices, *reinterpret_cast<const Base*>(this)); 218 } 219 template <typename DenseIndex> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE 220 ptrdiff_t IndexOfRowMajor(const array<DenseIndex, Base::count>& indices) const { 221 return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count, true>::run(indices, *reinterpret_cast<const Base*>(this)); 222 } 223 }; 224 225 namespace internal { 226 template <std::ptrdiff_t V1, std::ptrdiff_t V2, std::ptrdiff_t V3, std::ptrdiff_t V4, std::ptrdiff_t V5> 227 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t array_prod(const Sizes<V1, V2, V3, V4, V5>&) { 228 return Sizes<V1, V2, V3, V4, V5>::total_size; 229 } 230 } 231 232 #endif 233 234 // Boilerplate 235 namespace internal { 236 template<typename Index, std::ptrdiff_t NumIndices, std::ptrdiff_t n, bool RowMajor> 237 struct tensor_index_linearization_helper 238 { 239 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE 240 Index run(array<Index, NumIndices> const& indices, array<Index, NumIndices> const& dimensions) 241 { 242 return array_get<RowMajor ? n : (NumIndices - n - 1)>(indices) + 243 array_get<RowMajor ? n : (NumIndices - n - 1)>(dimensions) * 244 tensor_index_linearization_helper<Index, NumIndices, n - 1, RowMajor>::run(indices, dimensions); 245 } 246 }; 247 248 template<typename Index, std::ptrdiff_t NumIndices, bool RowMajor> 249 struct tensor_index_linearization_helper<Index, NumIndices, 0, RowMajor> 250 { 251 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE 252 Index run(array<Index, NumIndices> const& indices, array<Index, NumIndices> const&) 253 { 254 return array_get<RowMajor ? 0 : NumIndices - 1>(indices); 255 } 256 }; 257 } // end namespace internal 258 259 260 261 // Dynamic size 262 template <typename DenseIndex, int NumDims> 263 struct DSizes : array<DenseIndex, NumDims> { 264 typedef array<DenseIndex, NumDims> Base; 265 static const int count = NumDims; 266 267 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index rank() const { 268 return NumDims; 269 } 270 271 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DenseIndex TotalSize() const { 272 return (NumDims == 0) ? 1 : internal::array_prod(*static_cast<const Base*>(this)); 273 } 274 275 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DSizes() { 276 for (int i = 0 ; i < NumDims; ++i) { 277 (*this)[i] = 0; 278 } 279 } 280 EIGEN_DEVICE_FUNC explicit DSizes(const array<DenseIndex, NumDims>& a) : Base(a) { } 281 282 EIGEN_DEVICE_FUNC explicit DSizes(const DenseIndex i0) { 283 eigen_assert(NumDims == 1); 284 (*this)[0] = i0; 285 } 286 287 EIGEN_DEVICE_FUNC DSizes(const DimensionList<DenseIndex, NumDims>& a) { 288 for (int i = 0 ; i < NumDims; ++i) { 289 (*this)[i] = a[i]; 290 } 291 } 292 293 // Enable DSizes index type promotion only if we are promoting to the 294 // larger type, e.g. allow to promote dimensions of type int to long. 295 template<typename OtherIndex> 296 EIGEN_DEVICE_FUNC 297 explicit DSizes(const array<OtherIndex, NumDims>& other, 298 // Default template parameters require c++11. 299 typename internal::enable_if< 300 internal::is_same< 301 DenseIndex, 302 typename internal::promote_index_type< 303 DenseIndex, 304 OtherIndex 305 >::type 306 >::value, void*>::type = 0) { 307 for (int i = 0; i < NumDims; ++i) { 308 (*this)[i] = static_cast<DenseIndex>(other[i]); 309 } 310 } 311 312 #ifdef EIGEN_HAS_INDEX_LIST 313 template <typename FirstType, typename... OtherTypes> 314 EIGEN_DEVICE_FUNC 315 explicit DSizes(const Eigen::IndexList<FirstType, OtherTypes...>& dimensions) { 316 for (int i = 0; i < dimensions.count; ++i) { 317 (*this)[i] = dimensions[i]; 318 } 319 } 320 #endif 321 322 #ifndef EIGEN_EMULATE_CXX11_META_H 323 template <typename std::ptrdiff_t... Indices> 324 EIGEN_DEVICE_FUNC DSizes(const Sizes<Indices...>& a) { 325 for (int i = 0 ; i < NumDims; ++i) { 326 (*this)[i] = a[i]; 327 } 328 } 329 #else 330 template <std::ptrdiff_t V1, std::ptrdiff_t V2, std::ptrdiff_t V3, std::ptrdiff_t V4, std::ptrdiff_t V5> 331 EIGEN_DEVICE_FUNC DSizes(const Sizes<V1, V2, V3, V4, V5>& a) { 332 for (int i = 0 ; i < NumDims; ++i) { 333 (*this)[i] = a[i]; 334 } 335 } 336 #endif 337 338 #if EIGEN_HAS_VARIADIC_TEMPLATES 339 template<typename... IndexTypes> EIGEN_DEVICE_FUNC 340 EIGEN_STRONG_INLINE explicit DSizes(DenseIndex firstDimension, DenseIndex secondDimension, IndexTypes... otherDimensions) : Base({{firstDimension, secondDimension, otherDimensions...}}) { 341 EIGEN_STATIC_ASSERT(sizeof...(otherDimensions) + 2 == NumDims, YOU_MADE_A_PROGRAMMING_MISTAKE) 342 } 343 #else 344 EIGEN_DEVICE_FUNC DSizes(const DenseIndex i0, const DenseIndex i1) { 345 eigen_assert(NumDims == 2); 346 (*this)[0] = i0; 347 (*this)[1] = i1; 348 } 349 EIGEN_DEVICE_FUNC DSizes(const DenseIndex i0, const DenseIndex i1, const DenseIndex i2) { 350 eigen_assert(NumDims == 3); 351 (*this)[0] = i0; 352 (*this)[1] = i1; 353 (*this)[2] = i2; 354 } 355 EIGEN_DEVICE_FUNC DSizes(const DenseIndex i0, const DenseIndex i1, const DenseIndex i2, const DenseIndex i3) { 356 eigen_assert(NumDims == 4); 357 (*this)[0] = i0; 358 (*this)[1] = i1; 359 (*this)[2] = i2; 360 (*this)[3] = i3; 361 } 362 EIGEN_DEVICE_FUNC DSizes(const DenseIndex i0, const DenseIndex i1, const DenseIndex i2, const DenseIndex i3, const DenseIndex i4) { 363 eigen_assert(NumDims == 5); 364 (*this)[0] = i0; 365 (*this)[1] = i1; 366 (*this)[2] = i2; 367 (*this)[3] = i3; 368 (*this)[4] = i4; 369 } 370 #endif 371 372 EIGEN_DEVICE_FUNC DSizes& operator = (const array<DenseIndex, NumDims>& other) { 373 *static_cast<Base*>(this) = other; 374 return *this; 375 } 376 377 // A constexpr would be so much better here 378 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DenseIndex IndexOfColMajor(const array<DenseIndex, NumDims>& indices) const { 379 return internal::tensor_index_linearization_helper<DenseIndex, NumDims, NumDims - 1, false>::run(indices, *static_cast<const Base*>(this)); 380 } 381 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DenseIndex IndexOfRowMajor(const array<DenseIndex, NumDims>& indices) const { 382 return internal::tensor_index_linearization_helper<DenseIndex, NumDims, NumDims - 1, true>::run(indices, *static_cast<const Base*>(this)); 383 } 384 }; 385 386 template <typename IndexType, int NumDims> 387 std::ostream& operator<<(std::ostream& os, 388 const DSizes<IndexType, NumDims>& dims) { 389 os << "["; 390 for (int i = 0; i < NumDims; ++i) { 391 if (i > 0) os << ", "; 392 os << dims[i]; 393 } 394 os << "]"; 395 return os; 396 } 397 398 // Boilerplate 399 namespace internal { 400 template<typename Index, std::ptrdiff_t NumIndices, std::ptrdiff_t n, bool RowMajor> 401 struct tensor_vsize_index_linearization_helper 402 { 403 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE 404 Index run(array<Index, NumIndices> const& indices, std::vector<DenseIndex> const& dimensions) 405 { 406 return array_get<RowMajor ? n : (NumIndices - n - 1)>(indices) + 407 array_get<RowMajor ? n : (NumIndices - n - 1)>(dimensions) * 408 tensor_vsize_index_linearization_helper<Index, NumIndices, n - 1, RowMajor>::run(indices, dimensions); 409 } 410 }; 411 412 template<typename Index, std::ptrdiff_t NumIndices, bool RowMajor> 413 struct tensor_vsize_index_linearization_helper<Index, NumIndices, 0, RowMajor> 414 { 415 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE 416 Index run(array<Index, NumIndices> const& indices, std::vector<DenseIndex> const&) 417 { 418 return array_get<RowMajor ? 0 : NumIndices - 1>(indices); 419 } 420 }; 421 } // end namespace internal 422 423 424 namespace internal { 425 426 template <typename DenseIndex, int NumDims> struct array_size<const DSizes<DenseIndex, NumDims> > { 427 static const ptrdiff_t value = NumDims; 428 }; 429 template <typename DenseIndex, int NumDims> struct array_size<DSizes<DenseIndex, NumDims> > { 430 static const ptrdiff_t value = NumDims; 431 }; 432 #ifndef EIGEN_EMULATE_CXX11_META_H 433 template <typename std::ptrdiff_t... Indices> struct array_size<const Sizes<Indices...> > { 434 static const std::ptrdiff_t value = Sizes<Indices...>::count; 435 }; 436 template <typename std::ptrdiff_t... Indices> struct array_size<Sizes<Indices...> > { 437 static const std::ptrdiff_t value = Sizes<Indices...>::count; 438 }; 439 template <std::ptrdiff_t n, typename std::ptrdiff_t... Indices> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t array_get(const Sizes<Indices...>&) { 440 return get<n, internal::numeric_list<std::ptrdiff_t, Indices...> >::value; 441 } 442 template <std::ptrdiff_t n> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t array_get(const Sizes<>&) { 443 eigen_assert(false && "should never be called"); 444 return -1; 445 } 446 #else 447 template <std::ptrdiff_t V1, std::ptrdiff_t V2, std::ptrdiff_t V3, std::ptrdiff_t V4, std::ptrdiff_t V5> struct array_size<const Sizes<V1,V2,V3,V4,V5> > { 448 static const ptrdiff_t value = Sizes<V1,V2,V3,V4,V5>::count; 449 }; 450 template <std::ptrdiff_t V1, std::ptrdiff_t V2, std::ptrdiff_t V3, std::ptrdiff_t V4, std::ptrdiff_t V5> struct array_size<Sizes<V1,V2,V3,V4,V5> > { 451 static const ptrdiff_t value = Sizes<V1,V2,V3,V4,V5>::count; 452 }; 453 template <std::ptrdiff_t n, std::ptrdiff_t V1, std::ptrdiff_t V2, std::ptrdiff_t V3, std::ptrdiff_t V4, std::ptrdiff_t V5> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t array_get(const Sizes<V1,V2,V3,V4,V5>&) { 454 return get<n, typename Sizes<V1,V2,V3,V4,V5>::Base>::value; 455 } 456 457 #endif 458 459 460 template <typename Dims1, typename Dims2, ptrdiff_t n, ptrdiff_t m> 461 struct sizes_match_below_dim { 462 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool run(Dims1&, Dims2&) { 463 return false; 464 } 465 }; 466 template <typename Dims1, typename Dims2, ptrdiff_t n> 467 struct sizes_match_below_dim<Dims1, Dims2, n, n> { 468 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool run(Dims1& dims1, Dims2& dims2) { 469 return (array_get<n-1>(dims1) == array_get<n-1>(dims2)) && 470 sizes_match_below_dim<Dims1, Dims2, n-1, n-1>::run(dims1, dims2); 471 } 472 }; 473 template <typename Dims1, typename Dims2> 474 struct sizes_match_below_dim<Dims1, Dims2, 0, 0> { 475 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool run(Dims1&, Dims2&) { 476 return true; 477 } 478 }; 479 480 } // end namespace internal 481 482 483 template <typename Dims1, typename Dims2> 484 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool dimensions_match(Dims1 dims1, Dims2 dims2) { 485 return internal::sizes_match_below_dim<Dims1, Dims2, internal::array_size<Dims1>::value, internal::array_size<Dims2>::value>::run(dims1, dims2); 486 } 487 488 } // end namespace Eigen 489 490 #endif // EIGEN_CXX11_TENSOR_TENSOR_DIMENSIONS_H 491