xref: /aosp_15_r20/external/eigen/unsupported/Eigen/CXX11/src/Tensor/TensorCustomOp.h (revision bf2c37156dfe67e5dfebd6d394bad8b2ab5804d4)
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <[email protected]>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
12 
13 namespace Eigen {
14 
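/** \class TensorCustomUnaryOp
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor expression that applies a user-supplied functor to one argument expression.
  *
  * The functor is called as func.dimensions(expr) to obtain the dimensions of
  * the result and as func.eval(expr, result, device) to compute the result.
  */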
22 namespace internal {
23 template<typename CustomUnaryFunc, typename XprType>
24 struct traits<TensorCustomUnaryOp<CustomUnaryFunc, XprType> >
25 {
26   typedef typename XprType::Scalar Scalar;
27   typedef typename XprType::StorageKind StorageKind;
28   typedef typename XprType::Index Index;
29   typedef typename XprType::Nested Nested;
30   typedef typename remove_reference<Nested>::type _Nested;
31   static const int NumDimensions = traits<XprType>::NumDimensions;
32   static const int Layout = traits<XprType>::Layout;
33   typedef typename traits<XprType>::PointerType PointerType;
34 };
35 
36 template<typename CustomUnaryFunc, typename XprType>
37 struct eval<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Eigen::Dense>
38 {
39   typedef const TensorCustomUnaryOp<CustomUnaryFunc, XprType>EIGEN_DEVICE_REF type;
40 };
41 
42 template<typename CustomUnaryFunc, typename XprType>
43 struct nested<TensorCustomUnaryOp<CustomUnaryFunc, XprType> >
44 {
45   typedef TensorCustomUnaryOp<CustomUnaryFunc, XprType> type;
46 };
47 
48 }  // end namespace internal
49 
50 
51 
52 template<typename CustomUnaryFunc, typename XprType>
53 class TensorCustomUnaryOp : public TensorBase<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, ReadOnlyAccessors>
54 {
55   public:
56   typedef typename internal::traits<TensorCustomUnaryOp>::Scalar Scalar;
57   typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
58   typedef typename XprType::CoeffReturnType CoeffReturnType;
59   typedef typename internal::nested<TensorCustomUnaryOp>::type Nested;
60   typedef typename internal::traits<TensorCustomUnaryOp>::StorageKind StorageKind;
61   typedef typename internal::traits<TensorCustomUnaryOp>::Index Index;
62 
63   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomUnaryOp(const XprType& expr, const CustomUnaryFunc& func)
64       : m_expr(expr), m_func(func) {}
65 
66   EIGEN_DEVICE_FUNC
67   const CustomUnaryFunc& func() const { return m_func; }
68 
69   EIGEN_DEVICE_FUNC
70   const typename internal::remove_all<typename XprType::Nested>::type&
71   expression() const { return m_expr; }
72 
73   protected:
74     typename XprType::Nested m_expr;
75     const CustomUnaryFunc m_func;
76 };
77 
78 
79 // Eval as rvalue
80 template<typename CustomUnaryFunc, typename XprType, typename Device>
81 struct TensorEvaluator<const TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Device>
82 {
83   typedef TensorCustomUnaryOp<CustomUnaryFunc, XprType> ArgType;
84   typedef typename internal::traits<ArgType>::Index Index;
85   static const int NumDims = internal::traits<ArgType>::NumDimensions;
86   typedef DSizes<Index, NumDims> Dimensions;
87   typedef typename internal::remove_const<typename ArgType::Scalar>::type Scalar;
88   typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;
89   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
90   static const int PacketSize = PacketType<CoeffReturnType, Device>::size;
91   typedef typename Eigen::internal::traits<XprType>::PointerType TensorPointerType;
92   typedef StorageMemory<CoeffReturnType, Device> Storage;
93   typedef typename Storage::Type EvaluatorPointerType;
94 
95   enum {
96     IsAligned = false,
97     PacketAccess = (PacketType<CoeffReturnType, Device>::size > 1),
98     BlockAccess = false,
99     PreferBlockAccess = false,
100     Layout = TensorEvaluator<XprType, Device>::Layout,
101     CoordAccess = false,  // to be implemented
102     RawAccess = false
103   };
104 
105   //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
106   typedef internal::TensorBlockNotImplemented TensorBlock;
107   //===--------------------------------------------------------------------===//
108 
109   EIGEN_STRONG_INLINE TensorEvaluator(const ArgType& op, const Device& device)
110       : m_op(op), m_device(device), m_result(NULL)
111   {
112     m_dimensions = op.func().dimensions(op.expression());
113   }
114 
115   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
116 
117   EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType data) {
118     if (data) {
119       evalTo(data);
120       return false;
121     } else {
122       m_result = static_cast<EvaluatorPointerType>(m_device.get( (CoeffReturnType*)
123           m_device.allocate_temp(dimensions().TotalSize() * sizeof(Scalar))));
124       evalTo(m_result);
125       return true;
126     }
127   }
128 
129   EIGEN_STRONG_INLINE void cleanup() {
130     if (m_result) {
131       m_device.deallocate_temp(m_result);
132       m_result = NULL;
133     }
134   }
135 
136   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
137     return m_result[index];
138   }
139 
140   template<int LoadMode>
141   EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const {
142     return internal::ploadt<PacketReturnType, LoadMode>(m_result + index);
143   }
144 
145   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
146     // TODO(rmlarsen): Extend CustomOp API to return its cost estimate.
147     return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize);
148   }
149 
150   EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return m_result; }
151 
152 #ifdef EIGEN_USE_SYCL
153   // binding placeholder accessors to a command group handler for SYCL
154   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
155     m_result.bind(cgh);
156   }
157 #endif
158 
159  protected:
160   void evalTo(EvaluatorPointerType data) {
161     TensorMap<Tensor<CoeffReturnType, NumDims, Layout, Index> > result(m_device.get(data), m_dimensions);
162     m_op.func().eval(m_op.expression(), result, m_device);
163   }
164 
165   Dimensions m_dimensions;
166   const ArgType m_op;
167   const Device EIGEN_DEVICE_REF m_device;
168   EvaluatorPointerType m_result;
169 };
170 
171 
172 
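/** \class TensorCustomBinaryOp
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor expression that applies a user-supplied functor to two argument expressions.
  *
  * The functor is called as func.dimensions(lhs, rhs) to obtain the dimensions
  * of the result and as func.eval(lhs, rhs, result, device) to compute the result.
  */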
180 namespace internal {
181 template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
182 struct traits<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> >
183 {
184   typedef typename internal::promote_storage_type<typename LhsXprType::Scalar,
185                                                   typename RhsXprType::Scalar>::ret Scalar;
186   typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType,
187                                                   typename RhsXprType::CoeffReturnType>::ret CoeffReturnType;
188   typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
189                                         typename traits<RhsXprType>::StorageKind>::ret StorageKind;
190   typedef typename promote_index_type<typename traits<LhsXprType>::Index,
191                                       typename traits<RhsXprType>::Index>::type Index;
192   typedef typename LhsXprType::Nested LhsNested;
193   typedef typename RhsXprType::Nested RhsNested;
194   typedef typename remove_reference<LhsNested>::type _LhsNested;
195   typedef typename remove_reference<RhsNested>::type _RhsNested;
196   static const int NumDimensions = traits<LhsXprType>::NumDimensions;
197   static const int Layout = traits<LhsXprType>::Layout;
198   typedef typename conditional<Pointer_type_promotion<typename LhsXprType::Scalar, Scalar>::val,
199                                 typename traits<LhsXprType>::PointerType, typename traits<RhsXprType>::PointerType>::type PointerType;
200 };
201 
202 template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
203 struct eval<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, Eigen::Dense>
204 {
205   typedef const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>& type;
206 };
207 
208 template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
209 struct nested<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> >
210 {
211   typedef TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> type;
212 };
213 
214 }  // end namespace internal
215 
216 
217 
218 template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
219 class TensorCustomBinaryOp : public TensorBase<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, ReadOnlyAccessors>
220 {
221   public:
222   typedef typename internal::traits<TensorCustomBinaryOp>::Scalar Scalar;
223   typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
224   typedef typename internal::traits<TensorCustomBinaryOp>::CoeffReturnType CoeffReturnType;
225   typedef typename internal::nested<TensorCustomBinaryOp>::type Nested;
226   typedef typename internal::traits<TensorCustomBinaryOp>::StorageKind StorageKind;
227   typedef typename internal::traits<TensorCustomBinaryOp>::Index Index;
228 
229   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const CustomBinaryFunc& func)
230 
231       : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_func(func) {}
232 
233   EIGEN_DEVICE_FUNC
234   const CustomBinaryFunc& func() const { return m_func; }
235 
236   EIGEN_DEVICE_FUNC
237   const typename internal::remove_all<typename LhsXprType::Nested>::type&
238   lhsExpression() const { return m_lhs_xpr; }
239 
240   EIGEN_DEVICE_FUNC
241   const typename internal::remove_all<typename RhsXprType::Nested>::type&
242   rhsExpression() const { return m_rhs_xpr; }
243 
244   protected:
245     typename LhsXprType::Nested m_lhs_xpr;
246     typename RhsXprType::Nested m_rhs_xpr;
247     const CustomBinaryFunc m_func;
248 };
249 
250 
251 // Eval as rvalue
252 template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType, typename Device>
253 struct TensorEvaluator<const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, Device>
254 {
255   typedef TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> XprType;
256   typedef typename internal::traits<XprType>::Index Index;
257   static const int NumDims = internal::traits<XprType>::NumDimensions;
258   typedef DSizes<Index, NumDims> Dimensions;
259   typedef typename XprType::Scalar Scalar;
260   typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;
261   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
262   static const int PacketSize = PacketType<CoeffReturnType, Device>::size;
263 
264   typedef typename Eigen::internal::traits<XprType>::PointerType TensorPointerType;
265   typedef StorageMemory<CoeffReturnType, Device> Storage;
266   typedef typename Storage::Type EvaluatorPointerType;
267 
268   enum {
269     IsAligned = false,
270     PacketAccess = (PacketType<CoeffReturnType, Device>::size > 1),
271     BlockAccess = false,
272     PreferBlockAccess = false,
273     Layout = TensorEvaluator<LhsXprType, Device>::Layout,
274     CoordAccess = false,  // to be implemented
275     RawAccess = false
276   };
277 
278   //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
279   typedef internal::TensorBlockNotImplemented TensorBlock;
280   //===--------------------------------------------------------------------===//
281 
282   EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
283       : m_op(op), m_device(device), m_result(NULL)
284   {
285     m_dimensions = op.func().dimensions(op.lhsExpression(), op.rhsExpression());
286   }
287 
288   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
289 
290   EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType data) {
291     if (data) {
292       evalTo(data);
293       return false;
294     } else {
295       m_result = static_cast<EvaluatorPointerType>(m_device.get( (CoeffReturnType*)
296         m_device.allocate_temp(dimensions().TotalSize() * sizeof(CoeffReturnType))));
297       evalTo(m_result);
298       return true;
299     }
300   }
301 
302   EIGEN_STRONG_INLINE void cleanup() {
303     if (m_result != NULL) {
304       m_device.deallocate_temp(m_result);
305       m_result = NULL;
306     }
307   }
308 
309   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
310     return m_result[index];
311   }
312 
313   template<int LoadMode>
314   EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const {
315     return internal::ploadt<PacketReturnType, LoadMode>(m_result + index);
316   }
317 
318   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
319     // TODO(rmlarsen): Extend CustomOp API to return its cost estimate.
320     return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize);
321   }
322 
323   EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return m_result; }
324 
325 #ifdef EIGEN_USE_SYCL
326   // binding placeholder accessors to a command group handler for SYCL
327   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
328     m_result.bind(cgh);
329   }
330 #endif
331 
332  protected:
333   void evalTo(EvaluatorPointerType data) {
334     TensorMap<Tensor<CoeffReturnType, NumDims, Layout> > result(m_device.get(data), m_dimensions);
335     m_op.func().eval(m_op.lhsExpression(), m_op.rhsExpression(), result, m_device);
336   }
337 
338   Dimensions m_dimensions;
339   const XprType m_op;
340   const Device EIGEN_DEVICE_REF m_device;
341   EvaluatorPointerType m_result;
342 };
343 
344 
345 } // end namespace Eigen
346 
347 #endif // EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
348