// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <[email protected]>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_PATCH_H
#define EIGEN_CXX11_TENSOR_TENSOR_PATCH_H

namespace Eigen {

/** \class TensorPatch
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor patch class.
  *
  * Extracts every contiguous patch of size \c patch_dims from the input
  * tensor. The result has one more dimension than the input; the extra
  * dimension enumerates the patches, and is the last dimension for ColMajor
  * layouts and the first dimension for RowMajor layouts.
  */
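// A minimal usage sketch, assuming the extract_patches() helper that
// TensorBase provides for building this op: extract every contiguous 2x2
// patch from a 3x3 ColMajor tensor.
//
//   Eigen::Tensor<float, 2> input(3, 3);
//   input.setRandom();
//   Eigen::array<ptrdiff_t, 2> patch_dims{{2, 2}};
//   // There are (3-2+1) * (3-2+1) == 4 patch positions, so the result is a
//   // 2x2x4 tensor with the patch index in the last dimension.
//   Eigen::Tensor<float, 3> patches = input.extract_patches(patch_dims);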
namespace internal {
template<typename PatchDim, typename XprType>
struct traits<TensorPatchOp<PatchDim, XprType> > : public traits<XprType>
{
  typedef typename XprType::Scalar Scalar;
  typedef traits<XprType> XprTraits;
  typedef typename XprTraits::StorageKind StorageKind;
  typedef typename XprTraits::Index Index;
  typedef typename XprType::Nested Nested;
  typedef typename remove_reference<Nested>::type _Nested;
  static const int NumDimensions = XprTraits::NumDimensions + 1;
  static const int Layout = XprTraits::Layout;
  typedef typename XprTraits::PointerType PointerType;
};

template<typename PatchDim, typename XprType>
struct eval<TensorPatchOp<PatchDim, XprType>, Eigen::Dense>
{
  typedef const TensorPatchOp<PatchDim, XprType>& type;
};

template<typename PatchDim, typename XprType>
struct nested<TensorPatchOp<PatchDim, XprType>, 1, typename eval<TensorPatchOp<PatchDim, XprType> >::type>
{
  typedef TensorPatchOp<PatchDim, XprType> type;
};

}  // end namespace internal


template<typename PatchDim, typename XprType>
class TensorPatchOp : public TensorBase<TensorPatchOp<PatchDim, XprType>, ReadOnlyAccessors>
{
  public:
    typedef typename Eigen::internal::traits<TensorPatchOp>::Scalar Scalar;
    typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
    typedef typename XprType::CoeffReturnType CoeffReturnType;
    typedef typename Eigen::internal::nested<TensorPatchOp>::type Nested;
    typedef typename Eigen::internal::traits<TensorPatchOp>::StorageKind StorageKind;
    typedef typename Eigen::internal::traits<TensorPatchOp>::Index Index;

    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorPatchOp(const XprType& expr, const PatchDim& patch_dims)
        : m_xpr(expr), m_patch_dims(patch_dims) {}

    EIGEN_DEVICE_FUNC
    const PatchDim& patch_dims() const { return m_patch_dims; }

    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename XprType::Nested>::type&
    expression() const { return m_xpr; }

  protected:
    typename XprType::Nested m_xpr;
    const PatchDim m_patch_dims;
};


// Eval as rvalue
template<typename PatchDim, typename ArgType, typename Device>
struct TensorEvaluator<const TensorPatchOp<PatchDim, ArgType>, Device>
{
  typedef TensorPatchOp<PatchDim, ArgType> XprType;
  typedef typename XprType::Index Index;
  static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value + 1;
  typedef DSizes<Index, NumDims> Dimensions;
  typedef typename XprType::Scalar Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = PacketType<CoeffReturnType, Device>::size;
  typedef StorageMemory<CoeffReturnType, Device> Storage;
  typedef typename Storage::Type EvaluatorPointerType;

  enum {
    IsAligned = false,
    PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
    BlockAccess = false,
    PreferBlockAccess = TensorEvaluator<ArgType, Device>::PreferBlockAccess,
    Layout = TensorEvaluator<ArgType, Device>::Layout,
    CoordAccess = false,
    RawAccess = false
  };

  //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
  typedef internal::TensorBlockNotImplemented TensorBlock;
  //===--------------------------------------------------------------------===//

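  // The constructor precomputes three stride arrays from the input and patch
  // dimensions. Each dimension i offers input_dims[i] - patch_dims[i] + 1
  // patch positions:
  //  - m_inputStrides: strides of the input tensor,
  //  - m_patchStrides: strides over the grid of patch origins,
  //  - m_outputStrides: strides of the output tensor.
  // Worked example (ColMajor, 3x3 input, 2x2 patches): the patch grid is
  // 2x2, so m_inputStrides == {1, 3}, m_patchStrides == {1, 2}, the output
  // is 2x2x4, and m_outputStrides == {1, 2, 4}.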
  EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
      : m_impl(op.expression(), device)
  {
    Index num_patches = 1;
    const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
    const PatchDim& patch_dims = op.patch_dims();
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      for (int i = 0; i < NumDims-1; ++i) {
        m_dimensions[i] = patch_dims[i];
        num_patches *= (input_dims[i] - patch_dims[i] + 1);
      }
      m_dimensions[NumDims-1] = num_patches;

      m_inputStrides[0] = 1;
      m_patchStrides[0] = 1;
      for (int i = 1; i < NumDims-1; ++i) {
        m_inputStrides[i] = m_inputStrides[i-1] * input_dims[i-1];
        m_patchStrides[i] = m_patchStrides[i-1] * (input_dims[i-1] - patch_dims[i-1] + 1);
      }
      m_outputStrides[0] = 1;
      for (int i = 1; i < NumDims; ++i) {
        m_outputStrides[i] = m_outputStrides[i-1] * m_dimensions[i-1];
      }
    } else {
      for (int i = 0; i < NumDims-1; ++i) {
        m_dimensions[i+1] = patch_dims[i];
        num_patches *= (input_dims[i] - patch_dims[i] + 1);
      }
      m_dimensions[0] = num_patches;

      m_inputStrides[NumDims-2] = 1;
      m_patchStrides[NumDims-2] = 1;
      for (int i = NumDims-3; i >= 0; --i) {
        m_inputStrides[i] = m_inputStrides[i+1] * input_dims[i+1];
        m_patchStrides[i] = m_patchStrides[i+1] * (input_dims[i+1] - patch_dims[i+1] + 1);
      }
      m_outputStrides[NumDims-1] = 1;
      for (int i = NumDims-2; i >= 0; --i) {
        m_outputStrides[i] = m_outputStrides[i+1] * m_dimensions[i+1];
      }
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }

  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType /*data*/) {
    m_impl.evalSubExprsIfNeeded(NULL);
    return true;
  }

  EIGEN_STRONG_INLINE void cleanup() {
    m_impl.cleanup();
  }

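  // Maps an output coefficient index back into the input tensor. Worked
  // example (ColMajor, 3x3 input, 2x2 patches, strides as above): index == 5
  // splits into patchIndex == 5 / 4 == 1 and patchOffset == 1. The loop
  // divides both against strides {2, 2} and peels off nothing, so the final
  // sum gives inputIndex == 1 + 1 == 2, i.e. input element (2,0), which is
  // indeed the second element of the second patch.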
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
  {
    Index output_stride_index = (static_cast<int>(Layout) == static_cast<int>(ColMajor)) ? NumDims - 1 : 0;
    // Find the index of the patch containing this coefficient.
    Index patchIndex = index / m_outputStrides[output_stride_index];
    // Find the offset of the coefficient relative to the patch's first element.
    Index patchOffset = index - patchIndex * m_outputStrides[output_stride_index];
    Index inputIndex = 0;
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      EIGEN_UNROLL_LOOP
      for (int i = NumDims - 2; i > 0; --i) {
        const Index patchIdx = patchIndex / m_patchStrides[i];
        patchIndex -= patchIdx * m_patchStrides[i];
        const Index offsetIdx = patchOffset / m_outputStrides[i];
        patchOffset -= offsetIdx * m_outputStrides[i];
        inputIndex += (patchIdx + offsetIdx) * m_inputStrides[i];
      }
    } else {
      EIGEN_UNROLL_LOOP
      for (int i = 0; i < NumDims - 2; ++i) {
        const Index patchIdx = patchIndex / m_patchStrides[i];
        patchIndex -= patchIdx * m_patchStrides[i];
        const Index offsetIdx = patchOffset / m_outputStrides[i+1];
        patchOffset -= offsetIdx * m_outputStrides[i+1];
        inputIndex += (patchIdx + offsetIdx) * m_inputStrides[i];
      }
    }
    inputIndex += (patchIndex + patchOffset);
    return m_impl.coeff(inputIndex);
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
  {
    EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index+PacketSize-1 < dimensions().TotalSize());

    Index output_stride_index = (static_cast<int>(Layout) == static_cast<int>(ColMajor)) ? NumDims - 1 : 0;
    Index indices[2] = {index, index + PacketSize - 1};
    Index patchIndices[2] = {indices[0] / m_outputStrides[output_stride_index],
                             indices[1] / m_outputStrides[output_stride_index]};
    Index patchOffsets[2] = {indices[0] - patchIndices[0] * m_outputStrides[output_stride_index],
                             indices[1] - patchIndices[1] * m_outputStrides[output_stride_index]};

    Index inputIndices[2] = {0, 0};
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      EIGEN_UNROLL_LOOP
      for (int i = NumDims - 2; i > 0; --i) {
        const Index patchIdx[2] = {patchIndices[0] / m_patchStrides[i],
                                   patchIndices[1] / m_patchStrides[i]};
        patchIndices[0] -= patchIdx[0] * m_patchStrides[i];
        patchIndices[1] -= patchIdx[1] * m_patchStrides[i];

        const Index offsetIdx[2] = {patchOffsets[0] / m_outputStrides[i],
                                    patchOffsets[1] / m_outputStrides[i]};
        patchOffsets[0] -= offsetIdx[0] * m_outputStrides[i];
        patchOffsets[1] -= offsetIdx[1] * m_outputStrides[i];

        inputIndices[0] += (patchIdx[0] + offsetIdx[0]) * m_inputStrides[i];
        inputIndices[1] += (patchIdx[1] + offsetIdx[1]) * m_inputStrides[i];
      }
    } else {
      EIGEN_UNROLL_LOOP
      for (int i = 0; i < NumDims - 2; ++i) {
        const Index patchIdx[2] = {patchIndices[0] / m_patchStrides[i],
                                   patchIndices[1] / m_patchStrides[i]};
        patchIndices[0] -= patchIdx[0] * m_patchStrides[i];
        patchIndices[1] -= patchIdx[1] * m_patchStrides[i];

        const Index offsetIdx[2] = {patchOffsets[0] / m_outputStrides[i+1],
                                    patchOffsets[1] / m_outputStrides[i+1]};
        patchOffsets[0] -= offsetIdx[0] * m_outputStrides[i+1];
        patchOffsets[1] -= offsetIdx[1] * m_outputStrides[i+1];

        inputIndices[0] += (patchIdx[0] + offsetIdx[0]) * m_inputStrides[i];
        inputIndices[1] += (patchIdx[1] + offsetIdx[1]) * m_inputStrides[i];
      }
    }
    inputIndices[0] += (patchIndices[0] + patchOffsets[0]);
    inputIndices[1] += (patchIndices[1] + patchOffsets[1]);

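    // When the first and last coefficients of the packet map to input
    // positions exactly PacketSize - 1 apart, the whole packet is contiguous
    // in the input and a single unaligned load suffices. Otherwise gather
    // the coefficients one at a time.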
    if (inputIndices[1] - inputIndices[0] == PacketSize - 1) {
      PacketReturnType rslt = m_impl.template packet<Unaligned>(inputIndices[0]);
      return rslt;
    }
    else {
      EIGEN_ALIGN_MAX CoeffReturnType values[PacketSize];
      values[0] = m_impl.coeff(inputIndices[0]);
      values[PacketSize-1] = m_impl.coeff(inputIndices[1]);
      EIGEN_UNROLL_LOOP
      for (int i = 1; i < PacketSize-1; ++i) {
        values[i] = coeff(index+i);
      }
      PacketReturnType rslt = internal::pload<PacketReturnType>(values);
      return rslt;
    }
  }

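  // Each coefficient lookup costs roughly one integer division, one
  // multiplication and two additions per dimension in the decomposition
  // loops above, which is what the compute cost below accounts for.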
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
    const double compute_cost = NumDims * (TensorOpCost::DivCost<Index>() +
                                           TensorOpCost::MulCost<Index>() +
                                           2 * TensorOpCost::AddCost<Index>());
    return m_impl.costPerCoeff(vectorized) +
           TensorOpCost(0, 0, compute_cost, vectorized, PacketSize);
  }

  EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; }

#ifdef EIGEN_USE_SYCL
  // Binding placeholder accessors to a command group handler for SYCL.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
    m_impl.bind(cgh);
  }
#endif

 protected:
  Dimensions m_dimensions;
  array<Index, NumDims> m_outputStrides;
  array<Index, NumDims-1> m_inputStrides;
  array<Index, NumDims-1> m_patchStrides;

  TensorEvaluator<ArgType, Device> m_impl;

};

} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_PATCH_H