xref: /aosp_15_r20/external/eigen/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h (revision bf2c37156dfe67e5dfebd6d394bad8b2ab5804d4)
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2015 Benoit Steiner <[email protected]>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONVERSION_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_CONVERSION_H
12 
13 namespace Eigen {
14 
/** \class TensorConversionOp
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor conversion class. This class makes it possible to vectorize
  * type casting operations when the number of scalars per packet in the source
  * and the destination types differ.
  */
22 namespace internal {
23 template<typename TargetType, typename XprType>
24 struct traits<TensorConversionOp<TargetType, XprType> >
25 {
26   // Type promotion to handle the case where the types of the lhs and the rhs are different.
27   typedef TargetType Scalar;
28   typedef typename traits<XprType>::StorageKind StorageKind;
29   typedef typename traits<XprType>::Index Index;
30   typedef typename XprType::Nested Nested;
31   typedef typename remove_reference<Nested>::type _Nested;
32   static const int NumDimensions = traits<XprType>::NumDimensions;
33   static const int Layout = traits<XprType>::Layout;
34   enum { Flags = 0 };
35   typedef typename TypeConversion<Scalar, typename traits<XprType>::PointerType>::type PointerType;
36 };
37 
38 template<typename TargetType, typename XprType>
39 struct eval<TensorConversionOp<TargetType, XprType>, Eigen::Dense>
40 {
41   typedef const TensorConversionOp<TargetType, XprType>& type;
42 };
43 
44 template<typename TargetType, typename XprType>
45 struct nested<TensorConversionOp<TargetType, XprType>, 1, typename eval<TensorConversionOp<TargetType, XprType> >::type>
46 {
47   typedef TensorConversionOp<TargetType, XprType> type;
48 };
49 
50 }  // end namespace internal
51 
52 
// Turns packets read from a source evaluator into packets of the target type.
// SrcCoeffRatio / TgtCoeffRatio are the ratios reported by
// internal::type_casting_traits. They select one of the specializations below,
// e.g. <2, 1> combines two source packets into one target packet.
// Only the specializations are defined; the primary template is just declared.
template <typename TensorEvaluator, typename SrcPacket, typename TgtPacket,
          int SrcCoeffRatio, int TgtCoeffRatio>
struct PacketConverter;
55 
56 template <typename TensorEvaluator, typename SrcPacket, typename TgtPacket>
57 struct PacketConverter<TensorEvaluator, SrcPacket, TgtPacket, 1, 1> {
58   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
59   PacketConverter(const TensorEvaluator& impl)
60       : m_impl(impl) {}
61 
62   template<int LoadMode, typename Index>
63   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TgtPacket packet(Index index) const {
64     return internal::pcast<SrcPacket, TgtPacket>(m_impl.template packet<LoadMode>(index));
65   }
66 
67  private:
68   const TensorEvaluator& m_impl;
69 };
70 
71 
72 template <typename TensorEvaluator, typename SrcPacket, typename TgtPacket>
73 struct PacketConverter<TensorEvaluator, SrcPacket, TgtPacket, 2, 1> {
74   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
75   PacketConverter(const TensorEvaluator& impl)
76       : m_impl(impl) {}
77 
78   template<int LoadMode, typename Index>
79   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TgtPacket packet(Index index) const {
80     const int SrcPacketSize = internal::unpacket_traits<SrcPacket>::size;
81 
82     SrcPacket src1 = m_impl.template packet<LoadMode>(index);
83     SrcPacket src2 = m_impl.template packet<LoadMode>(index + SrcPacketSize);
84     TgtPacket result = internal::pcast<SrcPacket, TgtPacket>(src1, src2);
85     return result;
86   }
87 
88  private:
89   const TensorEvaluator& m_impl;
90 };
91 
92 template <typename TensorEvaluator, typename SrcPacket, typename TgtPacket>
93 struct PacketConverter<TensorEvaluator, SrcPacket, TgtPacket, 4, 1> {
94   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
95   PacketConverter(const TensorEvaluator& impl)
96       : m_impl(impl) {}
97 
98   template<int LoadMode, typename Index>
99   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TgtPacket packet(Index index) const {
100     const int SrcPacketSize = internal::unpacket_traits<SrcPacket>::size;
101 
102     SrcPacket src1 = m_impl.template packet<LoadMode>(index);
103     SrcPacket src2 = m_impl.template packet<LoadMode>(index + SrcPacketSize);
104     SrcPacket src3 = m_impl.template packet<LoadMode>(index + 2 * SrcPacketSize);
105     SrcPacket src4 = m_impl.template packet<LoadMode>(index + 3 * SrcPacketSize);
106     TgtPacket result = internal::pcast<SrcPacket, TgtPacket>(src1, src2, src3, src4);
107     return result;
108   }
109 
110  private:
111   const TensorEvaluator& m_impl;
112 };
113 
114 template <typename TensorEvaluator, typename SrcPacket, typename TgtPacket>
115 struct PacketConverter<TensorEvaluator, SrcPacket, TgtPacket, 8, 1> {
116   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
117   PacketConverter(const TensorEvaluator& impl)
118       : m_impl(impl) {}
119 
120   template<int LoadMode, typename Index>
121   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TgtPacket packet(Index index) const {
122     const int SrcPacketSize = internal::unpacket_traits<SrcPacket>::size;
123 
124     SrcPacket src1 = m_impl.template packet<LoadMode>(index);
125     SrcPacket src2 = m_impl.template packet<LoadMode>(index + 1 * SrcPacketSize);
126     SrcPacket src3 = m_impl.template packet<LoadMode>(index + 2 * SrcPacketSize);
127     SrcPacket src4 = m_impl.template packet<LoadMode>(index + 3 * SrcPacketSize);
128     SrcPacket src5 = m_impl.template packet<LoadMode>(index + 4 * SrcPacketSize);
129     SrcPacket src6 = m_impl.template packet<LoadMode>(index + 5 * SrcPacketSize);
130     SrcPacket src7 = m_impl.template packet<LoadMode>(index + 6 * SrcPacketSize);
131     SrcPacket src8 = m_impl.template packet<LoadMode>(index + 7 * SrcPacketSize);
132     TgtPacket result = internal::pcast<SrcPacket, TgtPacket>(src1, src2, src3, src4, src5, src6, src7, src8);
133     return result;
134   }
135 
136  private:
137   const TensorEvaluator& m_impl;
138 };
139 
140 template <typename TensorEvaluator, typename SrcPacket, typename TgtPacket, int TgtCoeffRatio>
141 struct PacketConverter<TensorEvaluator, SrcPacket, TgtPacket, 1, TgtCoeffRatio> {
142   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
143   PacketConverter(const TensorEvaluator& impl)
144       : m_impl(impl), m_maxIndex(impl.dimensions().TotalSize()) {}
145 
146   template<int LoadMode, typename Index>
147   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TgtPacket packet(Index index) const {
148     const int SrcPacketSize = internal::unpacket_traits<SrcPacket>::size;
149     // Only call m_impl.packet() when we have direct access to the underlying data. This
150     // ensures that we don't compute the subexpression twice. We may however load some
151     // coefficients twice, but in practice this doesn't negatively impact performance.
152     if (m_impl.data() && (index + SrcPacketSize < m_maxIndex)) {
153       // Force unaligned memory loads since we can't ensure alignment anymore
154       return internal::pcast<SrcPacket, TgtPacket>(m_impl.template packet<Unaligned>(index));
155     } else {
156       const int TgtPacketSize = internal::unpacket_traits<TgtPacket>::size;
157       typedef typename internal::unpacket_traits<SrcPacket>::type SrcType;
158       typedef typename internal::unpacket_traits<TgtPacket>::type TgtType;
159       internal::scalar_cast_op<SrcType, TgtType> converter;
160       EIGEN_ALIGN_MAX typename internal::unpacket_traits<TgtPacket>::type values[TgtPacketSize];
161       EIGEN_UNROLL_LOOP
162       for (int i = 0; i < TgtPacketSize; ++i) {
163         values[i] = converter(m_impl.coeff(index+i));
164       }
165       TgtPacket rslt = internal::pload<TgtPacket>(values);
166       return rslt;
167     }
168   }
169 
170  private:
171   const TensorEvaluator& m_impl;
172   const typename TensorEvaluator::Index m_maxIndex;
173 };
174 
175 template<typename TargetType, typename XprType>
176 class TensorConversionOp : public TensorBase<TensorConversionOp<TargetType, XprType>, ReadOnlyAccessors>
177 {
178   public:
179     typedef typename internal::traits<TensorConversionOp>::Scalar Scalar;
180     typedef typename internal::traits<TensorConversionOp>::StorageKind StorageKind;
181     typedef typename internal::traits<TensorConversionOp>::Index Index;
182     typedef typename internal::nested<TensorConversionOp>::type Nested;
183     typedef Scalar CoeffReturnType;
184     typedef typename NumTraits<Scalar>::Real RealScalar;
185 
186     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorConversionOp(const XprType& xpr)
187         : m_xpr(xpr) {}
188 
189     EIGEN_DEVICE_FUNC
190     const typename internal::remove_all<typename XprType::Nested>::type&
191     expression() const { return m_xpr; }
192 
193   protected:
194     typename XprType::Nested m_xpr;
195 };
196 
197 template <bool SameType, typename Eval, typename EvalPointerType> struct ConversionSubExprEval {
198   static EIGEN_STRONG_INLINE bool run(Eval& impl, EvalPointerType) {
199     impl.evalSubExprsIfNeeded(NULL);
200     return true;
201   }
202 };
203 
204 template <typename Eval, typename EvalPointerType> struct ConversionSubExprEval<true, Eval, EvalPointerType> {
205   static EIGEN_STRONG_INLINE bool run(Eval& impl, EvalPointerType data) {
206     return impl.evalSubExprsIfNeeded(data);
207   }
208 };
209 
#ifdef EIGEN_USE_THREADS
// Asynchronous counterparts of ConversionSubExprEval. They are only compiled
// when EIGEN_USE_THREADS is defined.
//
// Primary template (scalar types differ): the destination buffer is ignored.
// The source evaluator receives nullptr together with the moved callback.
template <bool SameType, typename Eval, typename EvalPointerType,
          typename EvalSubExprsCallback>
struct ConversionSubExprEvalAsync {
  static EIGEN_STRONG_INLINE void run(Eval& impl, EvalPointerType /*unused*/,
                                      EvalSubExprsCallback done) {
    impl.evalSubExprsIfNeededAsync(nullptr, std::move(done));
  }
};

// Identical scalar types: forward both the destination buffer and the callback.
template <typename Eval, typename EvalPointerType, typename EvalSubExprsCallback>
struct ConversionSubExprEvalAsync<true, Eval, EvalPointerType, EvalSubExprsCallback> {
  static EIGEN_STRONG_INLINE void run(Eval& impl, EvalPointerType data,
                                      EvalSubExprsCallback done) {
    impl.evalSubExprsIfNeededAsync(data, std::move(done));
  }
};
#endif  // EIGEN_USE_THREADS
228 
229 namespace internal {
230 
231 template <typename SrcType, typename TargetType, bool IsSameT>
232 struct CoeffConv {
233   template <typename ArgType, typename Device>
234   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TargetType run(const TensorEvaluator<ArgType, Device>& impl, Index index) {
235     internal::scalar_cast_op<SrcType, TargetType> converter;
236     return converter(impl.coeff(index));
237   }
238 };
239 
240 template <typename SrcType, typename TargetType>
241 struct CoeffConv<SrcType, TargetType, true> {
242   template <typename ArgType, typename Device>
243   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TargetType run(const TensorEvaluator<ArgType, Device>& impl, Index index) {
244     return impl.coeff(index);
245   }
246 };
247 
248 template <typename SrcPacket, typename TargetPacket, int LoadMode, bool ActuallyVectorize, bool IsSameT>
249 struct PacketConv {
250   typedef typename internal::unpacket_traits<SrcPacket>::type SrcType;
251   typedef typename internal::unpacket_traits<TargetPacket>::type TargetType;
252 
253   static const int PacketSize = internal::unpacket_traits<TargetPacket>::size;
254 
255   template <typename ArgType, typename Device>
256   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TargetPacket run(const TensorEvaluator<ArgType, Device>& impl, Index index) {
257     internal::scalar_cast_op<SrcType, TargetType> converter;
258     EIGEN_ALIGN_MAX typename internal::remove_const<TargetType>::type values[PacketSize];
259     EIGEN_UNROLL_LOOP
260     for (int i = 0; i < PacketSize; ++i) {
261       values[i] = converter(impl.coeff(index+i));
262     }
263     TargetPacket rslt = internal::pload<TargetPacket>(values);
264     return rslt;
265   }
266 };
267 
268 template <typename SrcPacket, typename TargetPacket, int LoadMode, bool IsSameT>
269 struct PacketConv<SrcPacket, TargetPacket, LoadMode, true, IsSameT> {
270   typedef typename internal::unpacket_traits<SrcPacket>::type SrcType;
271   typedef typename internal::unpacket_traits<TargetPacket>::type TargetType;
272 
273   template <typename ArgType, typename Device>
274   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TargetPacket run(const TensorEvaluator<ArgType, Device>& impl, Index index) {
275     const int SrcCoeffRatio = internal::type_casting_traits<SrcType, TargetType>::SrcCoeffRatio;
276     const int TgtCoeffRatio = internal::type_casting_traits<SrcType, TargetType>::TgtCoeffRatio;
277     PacketConverter<TensorEvaluator<ArgType, Device>, SrcPacket, TargetPacket,
278                     SrcCoeffRatio, TgtCoeffRatio> converter(impl);
279     return converter.template packet<LoadMode>(index);
280   }
281 };
282 
283 template <typename SrcPacket, typename TargetPacket, int LoadMode>
284 struct PacketConv<SrcPacket, TargetPacket, LoadMode, /*ActuallyVectorize=*/false, /*IsSameT=*/true> {
285   typedef typename internal::unpacket_traits<TargetPacket>::type TargetType;
286   static const int PacketSize = internal::unpacket_traits<TargetPacket>::size;
287 
288   template <typename ArgType, typename Device>
289   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TargetPacket run(const TensorEvaluator<ArgType, Device>& impl, Index index) {
290     EIGEN_ALIGN_MAX typename internal::remove_const<TargetType>::type values[PacketSize];
291     for (int i = 0; i < PacketSize; ++i) values[i] = impl.coeff(index+i);
292     return internal::pload<TargetPacket>(values);
293   }
294 };
295 
296 template <typename SrcPacket, typename TargetPacket, int LoadMode>
297 struct PacketConv<SrcPacket, TargetPacket, LoadMode, /*ActuallyVectorize=*/true, /*IsSameT=*/true> {
298   template <typename ArgType, typename Device>
299   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TargetPacket run(const TensorEvaluator<ArgType, Device>& impl, Index index) {
300     return impl.template packet<LoadMode>(index);
301   }
302 };
303 
304 }  // namespace internal
305 
306 // Eval as rvalue
307 template<typename TargetType, typename ArgType, typename Device>
308 struct TensorEvaluator<const TensorConversionOp<TargetType, ArgType>, Device>
309 {
310   typedef TensorConversionOp<TargetType, ArgType> XprType;
311   typedef typename XprType::Index Index;
312   typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;
313   typedef TargetType Scalar;
314   typedef TargetType CoeffReturnType;
315   typedef typename internal::remove_all<typename internal::traits<ArgType>::Scalar>::type SrcType;
316   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
317   typedef typename PacketType<SrcType, Device>::type PacketSourceType;
318   static const int PacketSize = PacketType<CoeffReturnType, Device>::size;
319   static const bool IsSameType = internal::is_same<TargetType, SrcType>::value;
320   typedef StorageMemory<CoeffReturnType, Device> Storage;
321   typedef typename Storage::Type EvaluatorPointerType;
322 
323   enum {
324     IsAligned         = false,
325     PacketAccess      =
326     #ifndef EIGEN_USE_SYCL
327                         true,
328     #else
329                         TensorEvaluator<ArgType, Device>::PacketAccess &
330                         internal::type_casting_traits<SrcType, TargetType>::VectorizedCast,
331     #endif
332     BlockAccess       = TensorEvaluator<ArgType, Device>::BlockAccess,
333     PreferBlockAccess = TensorEvaluator<ArgType, Device>::PreferBlockAccess,
334     Layout            = TensorEvaluator<ArgType, Device>::Layout,
335     RawAccess         = false
336   };
337 
338   static const int NumDims = internal::array_size<Dimensions>::value;
339 
340   //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
341   typedef internal::TensorBlockDescriptor<NumDims, Index> TensorBlockDesc;
342   typedef internal::TensorBlockScratchAllocator<Device> TensorBlockScratch;
343 
344   typedef typename TensorEvaluator<const ArgType, Device>::TensorBlock
345       ArgTensorBlock;
346 
347   struct TensorConversionOpBlockFactory {
348     template <typename ArgXprType>
349     struct XprType {
350       typedef TensorConversionOp<TargetType, const ArgXprType> type;
351     };
352 
353     template <typename ArgXprType>
354     typename XprType<ArgXprType>::type expr(const ArgXprType& expr) const {
355       return typename XprType<ArgXprType>::type(expr);
356     }
357   };
358 
359   typedef internal::TensorUnaryExprBlock<TensorConversionOpBlockFactory,
360                                          ArgTensorBlock>
361       TensorBlock;
362   //===--------------------------------------------------------------------===//
363 
364   EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
365     : m_impl(op.expression(), device)
366   {
367   }
368 
369   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_impl.dimensions(); }
370 
371   EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType data)
372   {
373     return ConversionSubExprEval<IsSameType, TensorEvaluator<ArgType, Device>, EvaluatorPointerType>::run(m_impl, data);
374   }
375 
376 #ifdef EIGEN_USE_THREADS
377   template <typename EvalSubExprsCallback>
378   EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync(
379       EvaluatorPointerType data, EvalSubExprsCallback done) {
380     ConversionSubExprEvalAsync<IsSameType, TensorEvaluator<ArgType, Device>,
381                                EvaluatorPointerType,
382         EvalSubExprsCallback>::run(m_impl, data, std::move(done));
383   }
384 #endif
385 
386   EIGEN_STRONG_INLINE void cleanup()
387   {
388     m_impl.cleanup();
389   }
390 
391   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
392   {
393     return internal::CoeffConv<SrcType, TargetType, IsSameType>::run(m_impl,index);
394   }
395 
396   template<int LoadMode>
397   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType
398   packet(Index index) const {
399     // If we are not going to do the cast, we just need to check that base
400     // TensorEvaluator has packet access. Otherwise we also need to make sure,
401     // that we have an implementation of vectorized cast.
402     const bool Vectorizable =
403         IsSameType
404         ? TensorEvaluator<ArgType, Device>::PacketAccess
405         : int(TensorEvaluator<ArgType, Device>::PacketAccess) &
406           int(internal::type_casting_traits<SrcType, TargetType>::VectorizedCast);
407 
408     return internal::PacketConv<PacketSourceType, PacketReturnType, LoadMode,
409                                 Vectorizable, IsSameType>::run(m_impl, index);
410   }
411 
412   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
413   costPerCoeff(bool vectorized) const {
414     const double cast_cost = TensorOpCost::CastCost<SrcType, TargetType>();
415     if (vectorized) {
416       const double SrcCoeffRatio =
417           internal::type_casting_traits<SrcType, TargetType>::SrcCoeffRatio;
418       const double TgtCoeffRatio =
419           internal::type_casting_traits<SrcType, TargetType>::TgtCoeffRatio;
420       return m_impl.costPerCoeff(vectorized) * (SrcCoeffRatio / PacketSize) +
421           TensorOpCost(0, 0, TgtCoeffRatio * (cast_cost / PacketSize));
422     } else {
423       return m_impl.costPerCoeff(vectorized) + TensorOpCost(0, 0, cast_cost);
424     }
425   }
426 
427   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
428   internal::TensorBlockResourceRequirements getResourceRequirements() const {
429     return m_impl.getResourceRequirements();
430   }
431 
432   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlock
433   block(TensorBlockDesc& desc, TensorBlockScratch& scratch,
434           bool /*root_of_expr_ast*/ = false) const {
435     return TensorBlock(m_impl.block(desc, scratch),
436                          TensorConversionOpBlockFactory());
437   }
438 
439   EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; }
440 
441   /// required by sycl in order to extract the sycl accessor
442   const TensorEvaluator<ArgType, Device>& impl() const { return m_impl; }
443 #ifdef EIGEN_USE_SYCL
444   // binding placeholder accessors to a command group handler for SYCL
445   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
446     m_impl.bind(cgh);
447   }
448 #endif
449 
450  protected:
451   TensorEvaluator<ArgType, Device> m_impl;
452 };
453 
454 } // end namespace Eigen
455 
456 #endif // EIGEN_CXX11_TENSOR_TENSOR_CONVERSION_H
457