// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <[email protected]>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONCATENATION_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONCATENATION_H

namespace Eigen {

/** \class TensorConcatenationOp
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor concatenation class.
  *
  * Concatenates two tensor expressions along a given axis. All dimensions of
  * the two inputs must match except along the concatenation axis, where the
  * result's size is the sum of the two input sizes.
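  *
  * A minimal usage sketch (concatenation expressions are normally created
  * through TensorBase::concatenate rather than by constructing this class
  * directly; the dimensions below are illustrative):
  * \code
  * Eigen::Tensor<float, 2> a(2, 3);
  * Eigen::Tensor<float, 2> b(4, 3);
  * a.setRandom();
  * b.setRandom();
  * // Concatenate along dimension 0; the result has dimensions (6, 3).
  * Eigen::Tensor<float, 2> c = a.concatenate(b, 0);
  * \endcode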
  */
namespace internal {
template<typename Axis, typename LhsXprType, typename RhsXprType>
struct traits<TensorConcatenationOp<Axis, LhsXprType, RhsXprType> >
{
  // Type promotion to handle the case where the types of the lhs and the rhs are different.
  typedef typename promote_storage_type<typename LhsXprType::Scalar,
                                        typename RhsXprType::Scalar>::ret Scalar;
  typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
                                        typename traits<RhsXprType>::StorageKind>::ret StorageKind;
  typedef typename promote_index_type<typename traits<LhsXprType>::Index,
                                      typename traits<RhsXprType>::Index>::type Index;
  typedef typename LhsXprType::Nested LhsNested;
  typedef typename RhsXprType::Nested RhsNested;
  typedef typename remove_reference<LhsNested>::type _LhsNested;
  typedef typename remove_reference<RhsNested>::type _RhsNested;
  static const int NumDimensions = traits<LhsXprType>::NumDimensions;
  static const int Layout = traits<LhsXprType>::Layout;
  enum { Flags = 0 };
  typedef typename conditional<Pointer_type_promotion<typename LhsXprType::Scalar, Scalar>::val,
                               typename traits<LhsXprType>::PointerType, typename traits<RhsXprType>::PointerType>::type PointerType;
};

template<typename Axis, typename LhsXprType, typename RhsXprType>
struct eval<TensorConcatenationOp<Axis, LhsXprType, RhsXprType>, Eigen::Dense>
{
  typedef const TensorConcatenationOp<Axis, LhsXprType, RhsXprType>& type;
};

template<typename Axis, typename LhsXprType, typename RhsXprType>
struct nested<TensorConcatenationOp<Axis, LhsXprType, RhsXprType>, 1, typename eval<TensorConcatenationOp<Axis, LhsXprType, RhsXprType> >::type>
{
  typedef TensorConcatenationOp<Axis, LhsXprType, RhsXprType> type;
};

}  // end namespace internal


template<typename Axis, typename LhsXprType, typename RhsXprType>
class TensorConcatenationOp : public TensorBase<TensorConcatenationOp<Axis, LhsXprType, RhsXprType>, WriteAccessors>
{
  public:
    typedef TensorBase<TensorConcatenationOp<Axis, LhsXprType, RhsXprType>, WriteAccessors> Base;
    typedef typename internal::traits<TensorConcatenationOp>::Scalar Scalar;
    typedef typename internal::traits<TensorConcatenationOp>::StorageKind StorageKind;
    typedef typename internal::traits<TensorConcatenationOp>::Index Index;
    typedef typename internal::nested<TensorConcatenationOp>::type Nested;
    typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType,
                                                    typename RhsXprType::CoeffReturnType>::ret CoeffReturnType;
    typedef typename NumTraits<Scalar>::Real RealScalar;

    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorConcatenationOp(const LhsXprType& lhs, const RhsXprType& rhs, Axis axis)
        : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_axis(axis) {}

    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename LhsXprType::Nested>::type&
    lhsExpression() const { return m_lhs_xpr; }

    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename RhsXprType::Nested>::type&
    rhsExpression() const { return m_rhs_xpr; }

    EIGEN_DEVICE_FUNC const Axis& axis() const { return m_axis; }

    EIGEN_TENSOR_INHERIT_ASSIGNMENT_OPERATORS(TensorConcatenationOp)
  protected:
    typename LhsXprType::Nested m_lhs_xpr;
    typename RhsXprType::Nested m_rhs_xpr;
    const Axis m_axis;
};


// Eval as rvalue
template<typename Axis, typename LeftArgType, typename RightArgType, typename Device>
struct TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device>
{
  typedef TensorConcatenationOp<Axis, LeftArgType, RightArgType> XprType;
  typedef typename XprType::Index Index;
  static const int NumDims = internal::array_size<typename TensorEvaluator<LeftArgType, Device>::Dimensions>::value;
  static const int RightNumDims = internal::array_size<typename TensorEvaluator<RightArgType, Device>::Dimensions>::value;
  typedef DSizes<Index, NumDims> Dimensions;
  typedef typename XprType::Scalar Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  typedef StorageMemory<CoeffReturnType, Device> Storage;
  typedef typename Storage::Type EvaluatorPointerType;
  enum {
    IsAligned         = false,
    PacketAccess      = TensorEvaluator<LeftArgType, Device>::PacketAccess &&
                        TensorEvaluator<RightArgType, Device>::PacketAccess,
    BlockAccess       = false,
    PreferBlockAccess = TensorEvaluator<LeftArgType, Device>::PreferBlockAccess ||
                        TensorEvaluator<RightArgType, Device>::PreferBlockAccess,
    Layout            = TensorEvaluator<LeftArgType, Device>::Layout,
    RawAccess         = false
  };

  //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
  typedef internal::TensorBlockNotImplemented TensorBlock;
  //===--------------------------------------------------------------------===//

  EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
    : m_leftImpl(op.lhsExpression(), device), m_rightImpl(op.rhsExpression(), device), m_axis(op.axis())
  {
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) == static_cast<int>(TensorEvaluator<RightArgType, Device>::Layout) || NumDims == 1), YOU_MADE_A_PROGRAMMING_MISTAKE);
    EIGEN_STATIC_ASSERT((NumDims == RightNumDims), YOU_MADE_A_PROGRAMMING_MISTAKE);
    EIGEN_STATIC_ASSERT((NumDims > 0), YOU_MADE_A_PROGRAMMING_MISTAKE);

    eigen_assert(0 <= m_axis && m_axis < NumDims);
    const Dimensions& lhs_dims = m_leftImpl.dimensions();
    const Dimensions& rhs_dims = m_rightImpl.dimensions();
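    // The two inputs must agree on every dimension except the concatenation
    // axis; along that axis the output size is the sum of the input sizes.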
    {
      int i = 0;
      for (; i < m_axis; ++i) {
        eigen_assert(lhs_dims[i] > 0);
        eigen_assert(lhs_dims[i] == rhs_dims[i]);
        m_dimensions[i] = lhs_dims[i];
      }
      eigen_assert(lhs_dims[i] > 0);  // Now i == m_axis.
      eigen_assert(rhs_dims[i] > 0);
      m_dimensions[i] = lhs_dims[i] + rhs_dims[i];
      for (++i; i < NumDims; ++i) {
        eigen_assert(lhs_dims[i] > 0);
        eigen_assert(lhs_dims[i] == rhs_dims[i]);
        m_dimensions[i] = lhs_dims[i];
      }
    }

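    // Precompute the strides of the lhs, rhs and output shapes so that a flat
    // output index can later be decomposed into per-dimension indices. The
    // innermost (stride-1) dimension is 0 for ColMajor and NumDims-1 for
    // RowMajor.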
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      m_leftStrides[0] = 1;
      m_rightStrides[0] = 1;
      m_outputStrides[0] = 1;

      for (int j = 1; j < NumDims; ++j) {
        m_leftStrides[j] = m_leftStrides[j-1] * lhs_dims[j-1];
        m_rightStrides[j] = m_rightStrides[j-1] * rhs_dims[j-1];
        m_outputStrides[j] = m_outputStrides[j-1] * m_dimensions[j-1];
      }
    } else {
      m_leftStrides[NumDims - 1] = 1;
      m_rightStrides[NumDims - 1] = 1;
      m_outputStrides[NumDims - 1] = 1;

      for (int j = NumDims - 2; j >= 0; --j) {
        m_leftStrides[j] = m_leftStrides[j+1] * lhs_dims[j+1];
        m_rightStrides[j] = m_rightStrides[j+1] * rhs_dims[j+1];
        m_outputStrides[j] = m_outputStrides[j+1] * m_dimensions[j+1];
      }
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }

  // TODO(phli): Add short-circuit memcpy evaluation if underlying data are linear?
  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType)
  {
    m_leftImpl.evalSubExprsIfNeeded(NULL);
    m_rightImpl.evalSubExprsIfNeeded(NULL);
    return true;
  }

  EIGEN_STRONG_INLINE void cleanup()
  {
    m_leftImpl.cleanup();
    m_rightImpl.cleanup();
  }

  // TODO(phli): attempt to speed this up. The integer divisions and modulo are slow.
  // See CL/76180724 comments for more ideas.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
  {
    // Collect dimension-wise indices (subs).
    array<Index, NumDims> subs;
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      for (int i = NumDims - 1; i > 0; --i) {
        subs[i] = index / m_outputStrides[i];
        index -= subs[i] * m_outputStrides[i];
      }
      subs[0] = index;
    } else {
      for (int i = 0; i < NumDims - 1; ++i) {
        subs[i] = index / m_outputStrides[i];
        index -= subs[i] * m_outputStrides[i];
      }
      subs[NumDims - 1] = index;
    }

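    // If the index along the concatenation axis falls inside the lhs extent,
    // rebuild a flat lhs index and read from the left evaluator; otherwise
    // shift the axis index by the lhs extent and read from the right one.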
    const Dimensions& left_dims = m_leftImpl.dimensions();
    if (subs[m_axis] < left_dims[m_axis]) {
      Index left_index;
      if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
        left_index = subs[0];
        EIGEN_UNROLL_LOOP
        for (int i = 1; i < NumDims; ++i) {
          left_index += (subs[i] % left_dims[i]) * m_leftStrides[i];
        }
      } else {
        left_index = subs[NumDims - 1];
        EIGEN_UNROLL_LOOP
        for (int i = NumDims - 2; i >= 0; --i) {
          left_index += (subs[i] % left_dims[i]) * m_leftStrides[i];
        }
      }
      return m_leftImpl.coeff(left_index);
    } else {
      subs[m_axis] -= left_dims[m_axis];
      const Dimensions& right_dims = m_rightImpl.dimensions();
      Index right_index;
      if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
        right_index = subs[0];
        EIGEN_UNROLL_LOOP
        for (int i = 1; i < NumDims; ++i) {
          right_index += (subs[i] % right_dims[i]) * m_rightStrides[i];
        }
      } else {
        right_index = subs[NumDims - 1];
        EIGEN_UNROLL_LOOP
        for (int i = NumDims - 2; i >= 0; --i) {
          right_index += (subs[i] % right_dims[i]) * m_rightStrides[i];
        }
      }
      return m_rightImpl.coeff(right_index);
    }
  }

  // TODO(phli): Add a real vectorization.
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
  {
    const int packetSize = PacketType<CoeffReturnType, Device>::size;
    EIGEN_STATIC_ASSERT((packetSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index + packetSize - 1 < dimensions().TotalSize());

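    // Packet loads are emulated: gather packetSize scalar coefficients into an
    // aligned buffer, then load the buffer as a single packet.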
    EIGEN_ALIGN_MAX CoeffReturnType values[packetSize];
    EIGEN_UNROLL_LOOP
    for (int i = 0; i < packetSize; ++i) {
      values[i] = coeff(index+i);
    }
    PacketReturnType rslt = internal::pload<PacketReturnType>(values);
    return rslt;
  }

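  // The per-coefficient cost is a size-weighted average of the two input
  // evaluators' costs, plus the index arithmetic (divisions and modulos per
  // dimension) performed in coeff().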
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    const double compute_cost = NumDims * (2 * TensorOpCost::AddCost<Index>() +
                                           2 * TensorOpCost::MulCost<Index>() +
                                           TensorOpCost::DivCost<Index>() +
                                           TensorOpCost::ModCost<Index>());
    const double lhs_size = m_leftImpl.dimensions().TotalSize();
    const double rhs_size = m_rightImpl.dimensions().TotalSize();
    return (lhs_size / (lhs_size + rhs_size)) *
               m_leftImpl.costPerCoeff(vectorized) +
           (rhs_size / (lhs_size + rhs_size)) *
               m_rightImpl.costPerCoeff(vectorized) +
           TensorOpCost(0, 0, compute_cost);
  }

  EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; }

  #ifdef EIGEN_USE_SYCL
  // binding placeholder accessors to a command group handler for SYCL
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
    m_leftImpl.bind(cgh);
    m_rightImpl.bind(cgh);
  }
  #endif

  protected:
    Dimensions m_dimensions;
    array<Index, NumDims> m_outputStrides;
    array<Index, NumDims> m_leftStrides;
    array<Index, NumDims> m_rightStrides;
    TensorEvaluator<LeftArgType, Device> m_leftImpl;
    TensorEvaluator<RightArgType, Device> m_rightImpl;
    const Axis m_axis;
};

// Eval as lvalue
template<typename Axis, typename LeftArgType, typename RightArgType, typename Device>
  struct TensorEvaluator<TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device>
  : public TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device>
{
  typedef TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device> Base;
  typedef TensorConcatenationOp<Axis, LeftArgType, RightArgType> XprType;
  typedef typename Base::Dimensions Dimensions;
  enum {
    IsAligned         = false,
    PacketAccess      = TensorEvaluator<LeftArgType, Device>::PacketAccess &&
                        TensorEvaluator<RightArgType, Device>::PacketAccess,
    BlockAccess       = false,
    PreferBlockAccess = TensorEvaluator<LeftArgType, Device>::PreferBlockAccess ||
                        TensorEvaluator<RightArgType, Device>::PreferBlockAccess,
    Layout            = TensorEvaluator<LeftArgType, Device>::Layout,
    RawAccess         = false
  };

  //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
  typedef internal::TensorBlockNotImplemented TensorBlock;
  //===--------------------------------------------------------------------===//

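  // Coefficient writes below decompose the flat index assuming column-major
  // strides, so this lvalue evaluator only supports the ColMajor layout (see
  // the static assertion in the constructor).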
  EIGEN_STRONG_INLINE TensorEvaluator(XprType& op, const Device& device)
    : Base(op, device)
  {
    EIGEN_STATIC_ASSERT((static_cast<int>(Layout) == static_cast<int>(ColMajor)), YOU_MADE_A_PROGRAMMING_MISTAKE);
  }

  typedef typename XprType::Index Index;
  typedef typename XprType::Scalar Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType& coeffRef(Index index)
  {
    // Collect dimension-wise indices (subs).
    array<Index, Base::NumDims> subs;
    for (int i = Base::NumDims - 1; i > 0; --i) {
      subs[i] = index / this->m_outputStrides[i];
      index -= subs[i] * this->m_outputStrides[i];
    }
    subs[0] = index;

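    // Same lhs/rhs dispatch as in the rvalue evaluator's coeff(), but
    // returning a writable reference into the corresponding input evaluator.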
    const Dimensions& left_dims = this->m_leftImpl.dimensions();
    if (subs[this->m_axis] < left_dims[this->m_axis]) {
      Index left_index = subs[0];
      for (int i = 1; i < Base::NumDims; ++i) {
        left_index += (subs[i] % left_dims[i]) * this->m_leftStrides[i];
      }
      return this->m_leftImpl.coeffRef(left_index);
    } else {
      subs[this->m_axis] -= left_dims[this->m_axis];
      const Dimensions& right_dims = this->m_rightImpl.dimensions();
      Index right_index = subs[0];
      for (int i = 1; i < Base::NumDims; ++i) {
        right_index += (subs[i] % right_dims[i]) * this->m_rightStrides[i];
      }
      return this->m_rightImpl.coeffRef(right_index);
    }
  }

  template <int StoreMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  void writePacket(Index index, const PacketReturnType& x)
  {
    const int packetSize = PacketType<CoeffReturnType, Device>::size;
    EIGEN_STATIC_ASSERT((packetSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index + packetSize - 1 < this->dimensions().TotalSize());

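    // Packet stores are emulated: spill the packet to an aligned buffer and
    // write each scalar back through coeffRef().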
    EIGEN_ALIGN_MAX CoeffReturnType values[packetSize];
    internal::pstore<CoeffReturnType, PacketReturnType>(values, x);
    for (int i = 0; i < packetSize; ++i) {
      coeffRef(index+i) = values[i];
    }
  }
};

} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_CONCATENATION_H