xref: /aosp_15_r20/external/eigen/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h (revision bf2c37156dfe67e5dfebd6d394bad8b2ab5804d4)
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <[email protected]>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_DIMENSIONS_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_DIMENSIONS_H
12 
13 
14 namespace Eigen {
15 
16 /** \internal
17   *
18   * \class TensorDimensions
19   * \ingroup CXX11_Tensor_Module
20   *
21   * \brief Set of classes used to encode and store the dimensions of a Tensor.
22   *
23   * The Sizes class encodes as part of the type the number of dimensions and the
24   * sizes corresponding to each dimension. It uses no storage space since it is
25   * entirely known at compile time.
26   * The DSizes class is its dynamic sibling: the number of dimensions is known
27   * at compile time but the sizes are set during execution.
28   *
29   * \sa Tensor
30   */
31 
32 // Boilerplate code
33 namespace internal {
34 
35 template<std::ptrdiff_t n, typename Dimension> struct dget {
36   static const std::ptrdiff_t value = get<n, Dimension>::value;
37 };
38 
39 
40 template<typename Index, std::ptrdiff_t NumIndices, std::ptrdiff_t n, bool RowMajor>
41 struct fixed_size_tensor_index_linearization_helper
42 {
43   template <typename Dimensions> EIGEN_DEVICE_FUNC
runfixed_size_tensor_index_linearization_helper44   static EIGEN_STRONG_INLINE Index run(array<Index, NumIndices> const& indices,
45                           const Dimensions& dimensions)
46   {
47     return array_get<RowMajor ? n - 1 : (NumIndices - n)>(indices) +
48         dget<RowMajor ? n - 1 : (NumIndices - n), Dimensions>::value *
49         fixed_size_tensor_index_linearization_helper<Index, NumIndices, n - 1, RowMajor>::run(indices, dimensions);
50   }
51 };
52 
53 template<typename Index, std::ptrdiff_t NumIndices, bool RowMajor>
54 struct fixed_size_tensor_index_linearization_helper<Index, NumIndices, 0, RowMajor>
55 {
56   template <typename Dimensions> EIGEN_DEVICE_FUNC
57   static EIGEN_STRONG_INLINE Index run(array<Index, NumIndices> const&, const Dimensions&)
58   {
59     return 0;
60   }
61 };
62 
63 template<typename Index, std::ptrdiff_t n>
64 struct fixed_size_tensor_index_extraction_helper
65 {
66   template <typename Dimensions> EIGEN_DEVICE_FUNC
67   static EIGEN_STRONG_INLINE Index run(const Index index,
68                           const Dimensions& dimensions)
69   {
70     const Index mult = (index == n-1) ? 1 : 0;
71     return array_get<n-1>(dimensions) * mult +
72         fixed_size_tensor_index_extraction_helper<Index, n - 1>::run(index, dimensions);
73   }
74 };
75 
76 template<typename Index>
77 struct fixed_size_tensor_index_extraction_helper<Index, 0>
78 {
79   template <typename Dimensions> EIGEN_DEVICE_FUNC
80   static EIGEN_STRONG_INLINE Index run(const Index,
81                           const Dimensions&)
82   {
83     return 0;
84   }
85   };
86 
87 }  // end namespace internal
88 
89 
90 // Fixed size
91 #ifndef EIGEN_EMULATE_CXX11_META_H
92 template <typename std::ptrdiff_t... Indices>
93 struct Sizes {
94   typedef internal::numeric_list<std::ptrdiff_t, Indices...> Base;
95   const Base t = Base();
96   static const std::ptrdiff_t total_size = internal::arg_prod(Indices...);
97   static const ptrdiff_t count = Base::count;
98 
99   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t rank() const {
100     return Base::count;
101   }
102 
103   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t TotalSize() {
104     return internal::arg_prod(Indices...);
105   }
106 
107   EIGEN_DEVICE_FUNC Sizes() { }
108   template <typename DenseIndex>
109   explicit EIGEN_DEVICE_FUNC Sizes(const array<DenseIndex, Base::count>& /*indices*/) {
110     // todo: add assertion
111   }
112 #if EIGEN_HAS_VARIADIC_TEMPLATES
113   template <typename... DenseIndex> EIGEN_DEVICE_FUNC Sizes(DenseIndex...) { }
114   explicit EIGEN_DEVICE_FUNC Sizes(std::initializer_list<std::ptrdiff_t> /*l*/) {
115     // todo: add assertion
116   }
117 #endif
118 
119   template <typename T> Sizes& operator = (const T& /*other*/) {
120     // add assertion failure if the size of other is different
121     return *this;
122   }
123 
124   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t operator[] (const std::ptrdiff_t index) const {
125     return internal::fixed_size_tensor_index_extraction_helper<std::ptrdiff_t, Base::count>::run(index, t);
126   }
127 
128   template <typename DenseIndex> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
129   ptrdiff_t IndexOfColMajor(const array<DenseIndex, Base::count>& indices) const {
130     return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count, false>::run(indices, t);
131   }
132   template <typename DenseIndex> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
133   ptrdiff_t IndexOfRowMajor(const array<DenseIndex, Base::count>& indices) const {
134     return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count, true>::run(indices, t);
135   }
136 };
137 
138 namespace internal {
139 template <typename std::ptrdiff_t... Indices>
140 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t array_prod(const Sizes<Indices...>&) {
141   return Sizes<Indices...>::total_size;
142 }
143 }
144 
145 #else
146 
147 template <std::ptrdiff_t n>
148 struct non_zero_size {
149   typedef internal::type2val<std::ptrdiff_t, n> type;
150 };
151 template <>
152 struct non_zero_size<0> {
153   typedef internal::null_type type;
154 };
155 
156 template <std::ptrdiff_t V1=0, std::ptrdiff_t V2=0, std::ptrdiff_t V3=0, std::ptrdiff_t V4=0, std::ptrdiff_t V5=0> struct Sizes {
157   typedef typename internal::make_type_list<typename non_zero_size<V1>::type, typename non_zero_size<V2>::type, typename non_zero_size<V3>::type, typename non_zero_size<V4>::type, typename non_zero_size<V5>::type >::type Base;
158   static const std::ptrdiff_t count = Base::count;
159   static const std::ptrdiff_t total_size = internal::arg_prod<Base>::value;
160 
161   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ptrdiff_t rank() const {
162     return count;
163   }
164 
165   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ptrdiff_t TotalSize() {
166     return internal::arg_prod<Base>::value;
167   }
168 
169   Sizes() { }
170   template <typename DenseIndex>
171   explicit Sizes(const array<DenseIndex, Base::count>& /*indices*/) {
172     // todo: add assertion
173   }
174   template <typename T> Sizes& operator = (const T& /*other*/) {
175     // add assertion failure if the size of other is different
176     return *this;
177   }
178 
179 #if EIGEN_HAS_VARIADIC_TEMPLATES
180   template <typename... DenseIndex> Sizes(DenseIndex... /*indices*/) { }
181   explicit Sizes(std::initializer_list<std::ptrdiff_t>) {
182     // todo: add assertion
183   }
184 #else
185   EIGEN_DEVICE_FUNC explicit Sizes(const DenseIndex) {
186   }
187   EIGEN_DEVICE_FUNC Sizes(const DenseIndex, const DenseIndex) {
188   }
189   EIGEN_DEVICE_FUNC Sizes(const DenseIndex, const DenseIndex, const DenseIndex) {
190   }
191   EIGEN_DEVICE_FUNC Sizes(const DenseIndex, const DenseIndex, const DenseIndex, const DenseIndex) {
192   }
193   EIGEN_DEVICE_FUNC Sizes(const DenseIndex, const DenseIndex, const DenseIndex, const DenseIndex, const DenseIndex) {
194   }
195 #endif
196 
197   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index operator[] (const Index index) const {
198     switch (index) {
199       case 0:
200         return internal::get<0, Base>::value;
201       case 1:
202         return internal::get<1, Base>::value;
203       case 2:
204         return internal::get<2, Base>::value;
205       case 3:
206         return internal::get<3, Base>::value;
207       case 4:
208         return internal::get<4, Base>::value;
209       default:
210         eigen_assert(false && "index overflow");
211         return static_cast<Index>(-1);
212     }
213   }
214 
215   template <typename DenseIndex> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
216   ptrdiff_t IndexOfColMajor(const array<DenseIndex, Base::count>& indices) const {
217     return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count, false>::run(indices, *reinterpret_cast<const Base*>(this));
218   }
219   template <typename DenseIndex> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
220   ptrdiff_t IndexOfRowMajor(const array<DenseIndex, Base::count>& indices) const {
221     return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count, true>::run(indices, *reinterpret_cast<const Base*>(this));
222   }
223 };
224 
225 namespace internal {
226 template <std::ptrdiff_t V1, std::ptrdiff_t V2, std::ptrdiff_t V3, std::ptrdiff_t V4, std::ptrdiff_t V5>
227 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t array_prod(const Sizes<V1, V2, V3, V4, V5>&) {
228   return Sizes<V1, V2, V3, V4, V5>::total_size;
229 }
230 }
231 
232 #endif
233 
234 // Boilerplate
235 namespace internal {
236 template<typename Index, std::ptrdiff_t NumIndices, std::ptrdiff_t n, bool RowMajor>
237 struct tensor_index_linearization_helper
238 {
239   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
240   Index run(array<Index, NumIndices> const& indices, array<Index, NumIndices> const& dimensions)
241   {
242     return array_get<RowMajor ? n : (NumIndices - n - 1)>(indices) +
243       array_get<RowMajor ? n : (NumIndices - n - 1)>(dimensions) *
244         tensor_index_linearization_helper<Index, NumIndices, n - 1, RowMajor>::run(indices, dimensions);
245   }
246 };
247 
248 template<typename Index, std::ptrdiff_t NumIndices, bool RowMajor>
249 struct tensor_index_linearization_helper<Index, NumIndices, 0, RowMajor>
250 {
251   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
252   Index run(array<Index, NumIndices> const& indices, array<Index, NumIndices> const&)
253   {
254     return array_get<RowMajor ? 0 : NumIndices - 1>(indices);
255   }
256 };
257 }  // end namespace internal
258 
259 
260 
261 // Dynamic size
262 template <typename DenseIndex, int NumDims>
263 struct DSizes : array<DenseIndex, NumDims> {
264   typedef array<DenseIndex, NumDims> Base;
265   static const int count = NumDims;
266 
267   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index rank() const {
268     return NumDims;
269   }
270 
271   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DenseIndex TotalSize() const {
272     return (NumDims == 0) ? 1 : internal::array_prod(*static_cast<const Base*>(this));
273   }
274 
275   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DSizes() {
276     for (int i = 0 ; i < NumDims; ++i) {
277       (*this)[i] = 0;
278     }
279   }
280   EIGEN_DEVICE_FUNC explicit DSizes(const array<DenseIndex, NumDims>& a) : Base(a) { }
281 
282   EIGEN_DEVICE_FUNC explicit DSizes(const DenseIndex i0) {
283     eigen_assert(NumDims == 1);
284     (*this)[0] = i0;
285   }
286 
287   EIGEN_DEVICE_FUNC DSizes(const DimensionList<DenseIndex, NumDims>& a) {
288     for (int i = 0 ; i < NumDims; ++i) {
289       (*this)[i] = a[i];
290     }
291   }
292 
293   // Enable DSizes index type promotion only if we are promoting to the
294   // larger type, e.g. allow to promote dimensions of type int to long.
295   template<typename OtherIndex>
296   EIGEN_DEVICE_FUNC
297   explicit DSizes(const array<OtherIndex, NumDims>& other,
298                   // Default template parameters require c++11.
299                   typename internal::enable_if<
300                      internal::is_same<
301                          DenseIndex,
302                          typename internal::promote_index_type<
303                              DenseIndex,
304                              OtherIndex
305                          >::type
306                      >::value, void*>::type = 0) {
307     for (int i = 0; i < NumDims; ++i) {
308       (*this)[i] = static_cast<DenseIndex>(other[i]);
309     }
310   }
311 
312 #ifdef EIGEN_HAS_INDEX_LIST
313   template <typename FirstType, typename... OtherTypes>
314   EIGEN_DEVICE_FUNC
315   explicit DSizes(const Eigen::IndexList<FirstType, OtherTypes...>& dimensions) {
316     for (int i = 0; i < dimensions.count; ++i) {
317       (*this)[i] = dimensions[i];
318     }
319   }
320 #endif
321 
322 #ifndef EIGEN_EMULATE_CXX11_META_H
323   template <typename std::ptrdiff_t... Indices>
324   EIGEN_DEVICE_FUNC DSizes(const Sizes<Indices...>& a) {
325     for (int i = 0 ; i < NumDims; ++i) {
326       (*this)[i] = a[i];
327     }
328   }
329 #else
330   template <std::ptrdiff_t V1, std::ptrdiff_t V2, std::ptrdiff_t V3, std::ptrdiff_t V4, std::ptrdiff_t V5>
331   EIGEN_DEVICE_FUNC DSizes(const Sizes<V1, V2, V3, V4, V5>& a) {
332     for (int i = 0 ; i < NumDims; ++i) {
333       (*this)[i] = a[i];
334     }
335   }
336 #endif
337 
338 #if EIGEN_HAS_VARIADIC_TEMPLATES
339   template<typename... IndexTypes> EIGEN_DEVICE_FUNC
340   EIGEN_STRONG_INLINE explicit DSizes(DenseIndex firstDimension, DenseIndex secondDimension, IndexTypes... otherDimensions) : Base({{firstDimension, secondDimension, otherDimensions...}}) {
341     EIGEN_STATIC_ASSERT(sizeof...(otherDimensions) + 2 == NumDims, YOU_MADE_A_PROGRAMMING_MISTAKE)
342   }
343 #else
344   EIGEN_DEVICE_FUNC DSizes(const DenseIndex i0, const DenseIndex i1) {
345     eigen_assert(NumDims == 2);
346     (*this)[0] = i0;
347     (*this)[1] = i1;
348   }
349   EIGEN_DEVICE_FUNC DSizes(const DenseIndex i0, const DenseIndex i1, const DenseIndex i2) {
350     eigen_assert(NumDims == 3);
351     (*this)[0] = i0;
352     (*this)[1] = i1;
353     (*this)[2] = i2;
354   }
355   EIGEN_DEVICE_FUNC DSizes(const DenseIndex i0, const DenseIndex i1, const DenseIndex i2, const DenseIndex i3) {
356     eigen_assert(NumDims == 4);
357     (*this)[0] = i0;
358     (*this)[1] = i1;
359     (*this)[2] = i2;
360     (*this)[3] = i3;
361   }
362   EIGEN_DEVICE_FUNC DSizes(const DenseIndex i0, const DenseIndex i1, const DenseIndex i2, const DenseIndex i3, const DenseIndex i4) {
363     eigen_assert(NumDims == 5);
364     (*this)[0] = i0;
365     (*this)[1] = i1;
366     (*this)[2] = i2;
367     (*this)[3] = i3;
368     (*this)[4] = i4;
369   }
370 #endif
371 
372   EIGEN_DEVICE_FUNC DSizes& operator = (const array<DenseIndex, NumDims>& other) {
373     *static_cast<Base*>(this) = other;
374     return *this;
375   }
376 
377   // A constexpr would be so much better here
378   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DenseIndex IndexOfColMajor(const array<DenseIndex, NumDims>& indices) const {
379     return internal::tensor_index_linearization_helper<DenseIndex, NumDims, NumDims - 1, false>::run(indices, *static_cast<const Base*>(this));
380   }
381   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DenseIndex IndexOfRowMajor(const array<DenseIndex, NumDims>& indices) const {
382     return internal::tensor_index_linearization_helper<DenseIndex, NumDims, NumDims - 1, true>::run(indices, *static_cast<const Base*>(this));
383   }
384 };
385 
386 template <typename IndexType, int NumDims>
387 std::ostream& operator<<(std::ostream& os,
388                          const DSizes<IndexType, NumDims>& dims) {
389   os << "[";
390   for (int i = 0; i < NumDims; ++i) {
391     if (i > 0) os << ", ";
392     os << dims[i];
393   }
394   os << "]";
395   return os;
396 }
397 
398 // Boilerplate
399 namespace internal {
400 template<typename Index, std::ptrdiff_t NumIndices, std::ptrdiff_t n, bool RowMajor>
401 struct tensor_vsize_index_linearization_helper
402 {
403   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
404   Index run(array<Index, NumIndices> const& indices, std::vector<DenseIndex> const& dimensions)
405   {
406     return array_get<RowMajor ? n : (NumIndices - n - 1)>(indices) +
407       array_get<RowMajor ? n : (NumIndices - n - 1)>(dimensions) *
408         tensor_vsize_index_linearization_helper<Index, NumIndices, n - 1, RowMajor>::run(indices, dimensions);
409   }
410 };
411 
412 template<typename Index, std::ptrdiff_t NumIndices, bool RowMajor>
413 struct tensor_vsize_index_linearization_helper<Index, NumIndices, 0, RowMajor>
414 {
415   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
416   Index run(array<Index, NumIndices> const& indices, std::vector<DenseIndex> const&)
417   {
418     return array_get<RowMajor ? 0 : NumIndices - 1>(indices);
419   }
420 };
421 }  // end namespace internal
422 
423 
424 namespace internal {
425 
426 template <typename DenseIndex, int NumDims> struct array_size<const DSizes<DenseIndex, NumDims> > {
427   static const ptrdiff_t value = NumDims;
428 };
429 template <typename DenseIndex, int NumDims> struct array_size<DSizes<DenseIndex, NumDims> > {
430   static const ptrdiff_t value = NumDims;
431 };
432 #ifndef EIGEN_EMULATE_CXX11_META_H
433 template <typename std::ptrdiff_t... Indices> struct array_size<const Sizes<Indices...> > {
434 static const std::ptrdiff_t value = Sizes<Indices...>::count;
435 };
436 template <typename std::ptrdiff_t... Indices> struct array_size<Sizes<Indices...> > {
437 static const std::ptrdiff_t value = Sizes<Indices...>::count;
438 };
439 template <std::ptrdiff_t n, typename std::ptrdiff_t... Indices> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t array_get(const Sizes<Indices...>&) {
440   return get<n, internal::numeric_list<std::ptrdiff_t, Indices...> >::value;
441 }
442 template <std::ptrdiff_t n> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t array_get(const Sizes<>&) {
443   eigen_assert(false && "should never be called");
444   return -1;
445 }
446 #else
447 template <std::ptrdiff_t V1, std::ptrdiff_t V2, std::ptrdiff_t V3, std::ptrdiff_t V4, std::ptrdiff_t V5> struct array_size<const Sizes<V1,V2,V3,V4,V5> > {
448   static const ptrdiff_t value = Sizes<V1,V2,V3,V4,V5>::count;
449 };
450 template <std::ptrdiff_t V1, std::ptrdiff_t V2, std::ptrdiff_t V3, std::ptrdiff_t V4, std::ptrdiff_t V5> struct array_size<Sizes<V1,V2,V3,V4,V5> > {
451   static const ptrdiff_t value = Sizes<V1,V2,V3,V4,V5>::count;
452 };
453 template <std::ptrdiff_t n, std::ptrdiff_t V1, std::ptrdiff_t V2, std::ptrdiff_t V3, std::ptrdiff_t V4, std::ptrdiff_t V5> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t array_get(const Sizes<V1,V2,V3,V4,V5>&) {
454   return get<n, typename Sizes<V1,V2,V3,V4,V5>::Base>::value;
455 }
456 
457 #endif
458 
459 
460 template <typename Dims1, typename Dims2, ptrdiff_t n, ptrdiff_t m>
461 struct sizes_match_below_dim {
462   static EIGEN_DEVICE_FUNC  EIGEN_STRONG_INLINE bool run(Dims1&, Dims2&) {
463     return false;
464   }
465 };
466 template <typename Dims1, typename Dims2, ptrdiff_t n>
467 struct sizes_match_below_dim<Dims1, Dims2, n, n> {
468   static EIGEN_DEVICE_FUNC  EIGEN_STRONG_INLINE bool run(Dims1& dims1, Dims2& dims2) {
469     return (array_get<n-1>(dims1) == array_get<n-1>(dims2)) &&
470         sizes_match_below_dim<Dims1, Dims2, n-1, n-1>::run(dims1, dims2);
471   }
472 };
473 template <typename Dims1, typename Dims2>
474 struct sizes_match_below_dim<Dims1, Dims2, 0, 0> {
475   static EIGEN_DEVICE_FUNC  EIGEN_STRONG_INLINE bool run(Dims1&, Dims2&) {
476     return true;
477   }
478 };
479 
480 } // end namespace internal
481 
482 
483 template <typename Dims1, typename Dims2>
484 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool dimensions_match(Dims1 dims1, Dims2 dims2) {
485   return internal::sizes_match_below_dim<Dims1, Dims2, internal::array_size<Dims1>::value, internal::array_size<Dims2>::value>::run(dims1, dims2);
486 }
487 
488 } // end namespace Eigen
489 
490 #endif // EIGEN_CXX11_TENSOR_TENSOR_DIMENSIONS_H
491