1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker
7*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Tensor.hpp>
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/DescriptorsFwd.hpp>
9*89c4ff92SAndroid Build Coastguard Worker
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/NumericCast.hpp>
11*89c4ff92SAndroid Build Coastguard Worker
12*89c4ff92SAndroid Build Coastguard Worker #include <arm_compute/core/ITensor.h>
13*89c4ff92SAndroid Build Coastguard Worker #include <arm_compute/core/TensorInfo.h>
14*89c4ff92SAndroid Build Coastguard Worker #include <arm_compute/core/Types.h>
15*89c4ff92SAndroid Build Coastguard Worker
16*89c4ff92SAndroid Build Coastguard Worker #include <Half.hpp>
17*89c4ff92SAndroid Build Coastguard Worker
18*89c4ff92SAndroid Build Coastguard Worker namespace armnn
19*89c4ff92SAndroid Build Coastguard Worker {
20*89c4ff92SAndroid Build Coastguard Worker class ITensorHandle;
21*89c4ff92SAndroid Build Coastguard Worker
22*89c4ff92SAndroid Build Coastguard Worker namespace armcomputetensorutils
23*89c4ff92SAndroid Build Coastguard Worker {
24*89c4ff92SAndroid Build Coastguard Worker
25*89c4ff92SAndroid Build Coastguard Worker /// Utility function to map an armnn::DataType to corresponding arm_compute::DataType.
26*89c4ff92SAndroid Build Coastguard Worker arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType, bool multiScales);
27*89c4ff92SAndroid Build Coastguard Worker
28*89c4ff92SAndroid Build Coastguard Worker /// Utility function to map an arm_compute::DataType to corresponding armnn::DataType.
29*89c4ff92SAndroid Build Coastguard Worker armnn::DataType GetArmNNDataType(arm_compute::DataType datatype);
30*89c4ff92SAndroid Build Coastguard Worker
31*89c4ff92SAndroid Build Coastguard Worker /// Utility function used to set up an arm_compute::Coordinates from a vector of ArmNN Axes for reduction functions
32*89c4ff92SAndroid Build Coastguard Worker arm_compute::Coordinates BuildArmComputeReductionCoordinates(size_t inputDimensions,
33*89c4ff92SAndroid Build Coastguard Worker unsigned int originalInputRank,
34*89c4ff92SAndroid Build Coastguard Worker const std::vector<unsigned int>& armnnAxes);
35*89c4ff92SAndroid Build Coastguard Worker
36*89c4ff92SAndroid Build Coastguard Worker /// Utility function used to setup an arm_compute::TensorShape object from an armnn::TensorShape.
37*89c4ff92SAndroid Build Coastguard Worker arm_compute::TensorShape BuildArmComputeTensorShape(const armnn::TensorShape& tensorShape);
38*89c4ff92SAndroid Build Coastguard Worker
39*89c4ff92SAndroid Build Coastguard Worker /// Utility function used to setup an arm_compute::TensorShape object from an armnn::TensorShape. This will
40*89c4ff92SAndroid Build Coastguard Worker /// attempt to reduce the number of leading 1s until the dimension length is equal to the dimensions passed in.
41*89c4ff92SAndroid Build Coastguard Worker arm_compute::TensorShape BuildArmComputeTensorShape(const armnn::TensorShape& tensorShape, unsigned int dimensions);
42*89c4ff92SAndroid Build Coastguard Worker
43*89c4ff92SAndroid Build Coastguard Worker /// Utility function used to setup an arm_compute::ITensorInfo object whose dimensions are based on the given
44*89c4ff92SAndroid Build Coastguard Worker /// armnn::ITensorInfo.
45*89c4ff92SAndroid Build Coastguard Worker arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo);
46*89c4ff92SAndroid Build Coastguard Worker
47*89c4ff92SAndroid Build Coastguard Worker /// Utility function used to setup an arm_compute::ITensorInfo object whose dimensions are based on the given
48*89c4ff92SAndroid Build Coastguard Worker /// armnn::ITensorInfo. This will attempt to reduce the number of leading 1s until the dimension length is equal
49*89c4ff92SAndroid Build Coastguard Worker /// to the dimensions passed in.
50*89c4ff92SAndroid Build Coastguard Worker arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo, unsigned int dimensions);
51*89c4ff92SAndroid Build Coastguard Worker
52*89c4ff92SAndroid Build Coastguard Worker /// Utility function used to setup an arm_compute::ITensorInfo object whose dimensions are based on the given
53*89c4ff92SAndroid Build Coastguard Worker /// armnn::ITensorInfo. This will attempt to reduce the number of leading 1s until the dimension length is equal
54*89c4ff92SAndroid Build Coastguard Worker /// to the dimensions passed in.
55*89c4ff92SAndroid Build Coastguard Worker arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo,
56*89c4ff92SAndroid Build Coastguard Worker armnn::DataLayout dataLayout,
57*89c4ff92SAndroid Build Coastguard Worker unsigned int dimensions);
58*89c4ff92SAndroid Build Coastguard Worker
/// Utility function used to setup an arm_compute::ITensorInfo object whose dimensions are based on the given
/// armnn::TensorInfo and armnn::DataLayout.
62*89c4ff92SAndroid Build Coastguard Worker arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo,
63*89c4ff92SAndroid Build Coastguard Worker armnn::DataLayout dataLayout);
64*89c4ff92SAndroid Build Coastguard Worker
65*89c4ff92SAndroid Build Coastguard Worker /// Utility function used to setup an arm_compute::ITensorInfo object whose dimensions are based on the given
66*89c4ff92SAndroid Build Coastguard Worker /// armnn::ITensorInfo. This will attempt to reduce the number of leading 1s until the dimension length is equal
67*89c4ff92SAndroid Build Coastguard Worker /// to the dimensions passed in.
68*89c4ff92SAndroid Build Coastguard Worker arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo,
69*89c4ff92SAndroid Build Coastguard Worker armnn::DataLayout dataLayout, unsigned int dimensions);
70*89c4ff92SAndroid Build Coastguard Worker
/// Utility function used to convert an armnn::DataLayout to the corresponding arm_compute::DataLayout.
73*89c4ff92SAndroid Build Coastguard Worker arm_compute::DataLayout ConvertDataLayout(armnn::DataLayout dataLayout);
74*89c4ff92SAndroid Build Coastguard Worker
/// Utility function used to setup an arm_compute::PoolingLayerInfo object from the given
/// armnn::Pooling2dDescriptor and the fpMixedPrecision flag.
78*89c4ff92SAndroid Build Coastguard Worker arm_compute::PoolingLayerInfo BuildArmComputePoolingLayerInfo(const Pooling2dDescriptor& descriptor,
79*89c4ff92SAndroid Build Coastguard Worker bool fpMixedPrecision = false);
80*89c4ff92SAndroid Build Coastguard Worker
/// Utility function used to setup an arm_compute::Pooling3dLayerInfo object from the given
/// armnn::Pooling3dDescriptor and the fpMixedPrecision flag.
84*89c4ff92SAndroid Build Coastguard Worker arm_compute::Pooling3dLayerInfo BuildArmComputePooling3dLayerInfo(const Pooling3dDescriptor& descriptor,
85*89c4ff92SAndroid Build Coastguard Worker bool fpMixedPrecision = false);
86*89c4ff92SAndroid Build Coastguard Worker
87*89c4ff92SAndroid Build Coastguard Worker /// Utility function to setup an arm_compute::NormalizationLayerInfo object from an armnn::NormalizationDescriptor.
88*89c4ff92SAndroid Build Coastguard Worker arm_compute::NormalizationLayerInfo BuildArmComputeNormalizationLayerInfo(const NormalizationDescriptor& desc);
89*89c4ff92SAndroid Build Coastguard Worker
90*89c4ff92SAndroid Build Coastguard Worker /// Utility function used to setup an arm_compute::PermutationVector object from an armnn::PermutationVector.
91*89c4ff92SAndroid Build Coastguard Worker /// \param perm PermutationVector used in Arm NN Permute layer
92*89c4ff92SAndroid Build Coastguard Worker /// \return PermutationVector used in ACL Transpose layer
93*89c4ff92SAndroid Build Coastguard Worker arm_compute::PermutationVector BuildArmComputePermutationVector(const armnn::PermutationVector& perm);
94*89c4ff92SAndroid Build Coastguard Worker
95*89c4ff92SAndroid Build Coastguard Worker /// Utility function used to setup an arm_compute::PermutationVector object from an armnn::PermutationVector.
96*89c4ff92SAndroid Build Coastguard Worker /// \param perm PermutationVector used in Arm NN Transpose layer
97*89c4ff92SAndroid Build Coastguard Worker /// \return PermutationVector used in ACL Transpose layer
98*89c4ff92SAndroid Build Coastguard Worker arm_compute::PermutationVector BuildArmComputeTransposeVector(const armnn::PermutationVector& perm);
99*89c4ff92SAndroid Build Coastguard Worker
100*89c4ff92SAndroid Build Coastguard Worker /// Utility function used to setup an arm_compute::Size2D object from width and height values.
101*89c4ff92SAndroid Build Coastguard Worker arm_compute::Size2D BuildArmComputeSize2D(const unsigned int width, const unsigned int height);
102*89c4ff92SAndroid Build Coastguard Worker
103*89c4ff92SAndroid Build Coastguard Worker /// Gets the appropriate PixelValue for the TensorInfo DataType
104*89c4ff92SAndroid Build Coastguard Worker arm_compute::PixelValue GetPixelValue(const arm_compute::ITensorInfo* tensorInfo, float value);
105*89c4ff92SAndroid Build Coastguard Worker
106*89c4ff92SAndroid Build Coastguard Worker /// Computes the depth multiplier parameter for the Depthwise Conv2d ACL workload.
107*89c4ff92SAndroid Build Coastguard Worker unsigned int ComputeDepthwiseConv2dDepthMultiplier(armnn::DataLayout layout,
108*89c4ff92SAndroid Build Coastguard Worker const arm_compute::TensorShape& weightsShape,
109*89c4ff92SAndroid Build Coastguard Worker const arm_compute::TensorShape& inputShape);
110*89c4ff92SAndroid Build Coastguard Worker
111*89c4ff92SAndroid Build Coastguard Worker /// Utility function used to setup an arm_compute::PadStrideInfo object from an ArmNN layer descriptor.
112*89c4ff92SAndroid Build Coastguard Worker template <typename Descriptor>
BuildArmComputePadStrideInfo(const Descriptor & descriptor)113*89c4ff92SAndroid Build Coastguard Worker arm_compute::PadStrideInfo BuildArmComputePadStrideInfo(const Descriptor &descriptor)
114*89c4ff92SAndroid Build Coastguard Worker {
115*89c4ff92SAndroid Build Coastguard Worker return arm_compute::PadStrideInfo(descriptor.m_StrideX,
116*89c4ff92SAndroid Build Coastguard Worker descriptor.m_StrideY,
117*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadLeft,
118*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadRight,
119*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadTop,
120*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadBottom,
121*89c4ff92SAndroid Build Coastguard Worker arm_compute::DimensionRoundingType::FLOOR);
122*89c4ff92SAndroid Build Coastguard Worker }
123*89c4ff92SAndroid Build Coastguard Worker
124*89c4ff92SAndroid Build Coastguard Worker /// Sets up the given ArmCompute tensor's dimensions based on the given ArmNN tensor.
125*89c4ff92SAndroid Build Coastguard Worker template <typename Tensor>
BuildArmComputeTensor(Tensor & tensor,const armnn::TensorInfo & tensorInfo)126*89c4ff92SAndroid Build Coastguard Worker void BuildArmComputeTensor(Tensor& tensor, const armnn::TensorInfo& tensorInfo)
127*89c4ff92SAndroid Build Coastguard Worker {
128*89c4ff92SAndroid Build Coastguard Worker tensor.allocator()->init(BuildArmComputeTensorInfo(tensorInfo));
129*89c4ff92SAndroid Build Coastguard Worker }
130*89c4ff92SAndroid Build Coastguard Worker
131*89c4ff92SAndroid Build Coastguard Worker /// Sets up the given ArmCompute tensor's dimensions based on the given ArmNN tensor.
132*89c4ff92SAndroid Build Coastguard Worker template <typename Tensor>
BuildArmComputeTensor(Tensor & tensor,const armnn::TensorInfo & tensorInfo,DataLayout dataLayout)133*89c4ff92SAndroid Build Coastguard Worker void BuildArmComputeTensor(Tensor& tensor, const armnn::TensorInfo& tensorInfo, DataLayout dataLayout)
134*89c4ff92SAndroid Build Coastguard Worker {
135*89c4ff92SAndroid Build Coastguard Worker tensor.allocator()->init(BuildArmComputeTensorInfo(tensorInfo, dataLayout));
136*89c4ff92SAndroid Build Coastguard Worker }
137*89c4ff92SAndroid Build Coastguard Worker
/// Asks the allocator of the given ArmCompute tensor to allocate; no data is written by this function.
/// \tparam Tensor Any tensor type exposing allocator(), whose result provides allocate()
/// \param tensor ArmCompute tensor whose allocator is invoked
template <typename Tensor>
void InitialiseArmComputeTensorEmpty(Tensor& tensor)
{
    auto* const tensorAllocator = tensor.allocator();
    tensorAllocator->allocate();
}
143*89c4ff92SAndroid Build Coastguard Worker
/// Releases a tensor that ended up unused once a workload has been configured and prepared.
/// \tparam Tensor Any tensor type exposing is_used()
/// \param tensor Owning pointer to the tensor; reset to null when it holds a tensor whose is_used() is false,
///               left untouched otherwise
template <typename Tensor>
void FreeTensorIfUnused(std::unique_ptr<Tensor>& tensor)
{
    // Nothing to free when there is no tensor, or when the tensor is still in use.
    if (!tensor || tensor->is_used())
    {
        return;
    }
    tensor.reset();
}
153*89c4ff92SAndroid Build Coastguard Worker
154*89c4ff92SAndroid Build Coastguard Worker // Helper function to obtain byte offset into tensor data
GetTensorOffset(const arm_compute::ITensorInfo & info,uint32_t depthIndex,uint32_t batchIndex,uint32_t channelIndex,uint32_t y,uint32_t x)155*89c4ff92SAndroid Build Coastguard Worker inline size_t GetTensorOffset(const arm_compute::ITensorInfo& info,
156*89c4ff92SAndroid Build Coastguard Worker uint32_t depthIndex,
157*89c4ff92SAndroid Build Coastguard Worker uint32_t batchIndex,
158*89c4ff92SAndroid Build Coastguard Worker uint32_t channelIndex,
159*89c4ff92SAndroid Build Coastguard Worker uint32_t y,
160*89c4ff92SAndroid Build Coastguard Worker uint32_t x)
161*89c4ff92SAndroid Build Coastguard Worker {
162*89c4ff92SAndroid Build Coastguard Worker arm_compute::Coordinates coords;
163*89c4ff92SAndroid Build Coastguard Worker coords.set(4, static_cast<int>(depthIndex));
164*89c4ff92SAndroid Build Coastguard Worker coords.set(3, static_cast<int>(batchIndex));
165*89c4ff92SAndroid Build Coastguard Worker coords.set(2, static_cast<int>(channelIndex));
166*89c4ff92SAndroid Build Coastguard Worker coords.set(1, static_cast<int>(y));
167*89c4ff92SAndroid Build Coastguard Worker coords.set(0, static_cast<int>(x));
168*89c4ff92SAndroid Build Coastguard Worker return armnn::numeric_cast<size_t>(info.offset_element_in_bytes(coords));
169*89c4ff92SAndroid Build Coastguard Worker }
170*89c4ff92SAndroid Build Coastguard Worker
171*89c4ff92SAndroid Build Coastguard Worker // Helper function to obtain element offset into data buffer representing tensor data (assuming no strides).
GetLinearBufferOffset(const arm_compute::ITensorInfo & info,uint32_t depthIndex,uint32_t batchIndex,uint32_t channelIndex,uint32_t y,uint32_t x)172*89c4ff92SAndroid Build Coastguard Worker inline size_t GetLinearBufferOffset(const arm_compute::ITensorInfo& info,
173*89c4ff92SAndroid Build Coastguard Worker uint32_t depthIndex,
174*89c4ff92SAndroid Build Coastguard Worker uint32_t batchIndex,
175*89c4ff92SAndroid Build Coastguard Worker uint32_t channelIndex,
176*89c4ff92SAndroid Build Coastguard Worker uint32_t y,
177*89c4ff92SAndroid Build Coastguard Worker uint32_t x)
178*89c4ff92SAndroid Build Coastguard Worker {
179*89c4ff92SAndroid Build Coastguard Worker const arm_compute::TensorShape& shape = info.tensor_shape();
180*89c4ff92SAndroid Build Coastguard Worker uint32_t width = static_cast<uint32_t>(shape[0]);
181*89c4ff92SAndroid Build Coastguard Worker uint32_t height = static_cast<uint32_t>(shape[1]);
182*89c4ff92SAndroid Build Coastguard Worker uint32_t numChannels = static_cast<uint32_t>(shape[2]);
183*89c4ff92SAndroid Build Coastguard Worker uint32_t numBatches = static_cast<uint32_t>(shape[3]);
184*89c4ff92SAndroid Build Coastguard Worker return (((depthIndex * numBatches + batchIndex) * numChannels + channelIndex) * height + y) * width + x;
185*89c4ff92SAndroid Build Coastguard Worker }
186*89c4ff92SAndroid Build Coastguard Worker
187*89c4ff92SAndroid Build Coastguard Worker template <typename T>
CopyArmComputeITensorData(const arm_compute::ITensor & srcTensor,T * dstData)188*89c4ff92SAndroid Build Coastguard Worker void CopyArmComputeITensorData(const arm_compute::ITensor& srcTensor, T* dstData)
189*89c4ff92SAndroid Build Coastguard Worker {
190*89c4ff92SAndroid Build Coastguard Worker // If MaxNumOfTensorDimensions is increased, this loop will need fixing.
191*89c4ff92SAndroid Build Coastguard Worker static_assert(MaxNumOfTensorDimensions == 5, "Please update CopyArmComputeITensorData");
192*89c4ff92SAndroid Build Coastguard Worker {
193*89c4ff92SAndroid Build Coastguard Worker const arm_compute::ITensorInfo& info = *srcTensor.info();
194*89c4ff92SAndroid Build Coastguard Worker const arm_compute::TensorShape& shape = info.tensor_shape();
195*89c4ff92SAndroid Build Coastguard Worker const uint8_t* const bufferPtr = srcTensor.buffer();
196*89c4ff92SAndroid Build Coastguard Worker uint32_t width = static_cast<uint32_t>(shape[0]);
197*89c4ff92SAndroid Build Coastguard Worker uint32_t height = static_cast<uint32_t>(shape[1]);
198*89c4ff92SAndroid Build Coastguard Worker uint32_t numChannels = static_cast<uint32_t>(shape[2]);
199*89c4ff92SAndroid Build Coastguard Worker uint32_t numBatches = static_cast<uint32_t>(shape[3]);
200*89c4ff92SAndroid Build Coastguard Worker uint32_t depth = static_cast<uint32_t>(shape[4]);
201*89c4ff92SAndroid Build Coastguard Worker
202*89c4ff92SAndroid Build Coastguard Worker for (unsigned int depthIndex = 0; depthIndex < depth; ++depthIndex)
203*89c4ff92SAndroid Build Coastguard Worker {
204*89c4ff92SAndroid Build Coastguard Worker for (unsigned int batchIndex = 0; batchIndex < numBatches; ++batchIndex)
205*89c4ff92SAndroid Build Coastguard Worker {
206*89c4ff92SAndroid Build Coastguard Worker for (unsigned int channelIndex = 0; channelIndex < numChannels; ++channelIndex)
207*89c4ff92SAndroid Build Coastguard Worker {
208*89c4ff92SAndroid Build Coastguard Worker for (unsigned int y = 0; y < height; ++y)
209*89c4ff92SAndroid Build Coastguard Worker {
210*89c4ff92SAndroid Build Coastguard Worker // Copies one row from arm_compute tensor buffer to linear memory buffer.
211*89c4ff92SAndroid Build Coastguard Worker // A row is the largest contiguous region we can copy, as the tensor data may be using strides.
212*89c4ff92SAndroid Build Coastguard Worker memcpy(
213*89c4ff92SAndroid Build Coastguard Worker dstData + GetLinearBufferOffset(info, depthIndex, batchIndex, channelIndex, y, 0),
214*89c4ff92SAndroid Build Coastguard Worker bufferPtr + GetTensorOffset(info, depthIndex, batchIndex, channelIndex, y, 0),
215*89c4ff92SAndroid Build Coastguard Worker width * sizeof(T));
216*89c4ff92SAndroid Build Coastguard Worker }
217*89c4ff92SAndroid Build Coastguard Worker }
218*89c4ff92SAndroid Build Coastguard Worker }
219*89c4ff92SAndroid Build Coastguard Worker }
220*89c4ff92SAndroid Build Coastguard Worker }
221*89c4ff92SAndroid Build Coastguard Worker }
222*89c4ff92SAndroid Build Coastguard Worker
223*89c4ff92SAndroid Build Coastguard Worker template <typename T>
CopyArmComputeITensorData(const T * srcData,arm_compute::ITensor & dstTensor)224*89c4ff92SAndroid Build Coastguard Worker void CopyArmComputeITensorData(const T* srcData, arm_compute::ITensor& dstTensor)
225*89c4ff92SAndroid Build Coastguard Worker {
226*89c4ff92SAndroid Build Coastguard Worker // If MaxNumOfTensorDimensions is increased, this loop will need fixing.
227*89c4ff92SAndroid Build Coastguard Worker static_assert(MaxNumOfTensorDimensions == 5, "Please update CopyArmComputeITensorData");
228*89c4ff92SAndroid Build Coastguard Worker {
229*89c4ff92SAndroid Build Coastguard Worker const arm_compute::ITensorInfo& info = *dstTensor.info();
230*89c4ff92SAndroid Build Coastguard Worker const arm_compute::TensorShape& shape = info.tensor_shape();
231*89c4ff92SAndroid Build Coastguard Worker uint8_t* const bufferPtr = dstTensor.buffer();
232*89c4ff92SAndroid Build Coastguard Worker uint32_t width = static_cast<uint32_t>(shape[0]);
233*89c4ff92SAndroid Build Coastguard Worker uint32_t height = static_cast<uint32_t>(shape[1]);
234*89c4ff92SAndroid Build Coastguard Worker uint32_t numChannels = static_cast<uint32_t>(shape[2]);
235*89c4ff92SAndroid Build Coastguard Worker uint32_t numBatches = static_cast<uint32_t>(shape[3]);
236*89c4ff92SAndroid Build Coastguard Worker uint32_t depth = static_cast<uint32_t>(shape[4]);
237*89c4ff92SAndroid Build Coastguard Worker
238*89c4ff92SAndroid Build Coastguard Worker for (unsigned int depthIndex = 0; depthIndex < depth; ++depthIndex)
239*89c4ff92SAndroid Build Coastguard Worker {
240*89c4ff92SAndroid Build Coastguard Worker for (unsigned int batchIndex = 0; batchIndex < numBatches; ++batchIndex)
241*89c4ff92SAndroid Build Coastguard Worker {
242*89c4ff92SAndroid Build Coastguard Worker for (unsigned int channelIndex = 0; channelIndex < numChannels; ++channelIndex)
243*89c4ff92SAndroid Build Coastguard Worker {
244*89c4ff92SAndroid Build Coastguard Worker for (unsigned int y = 0; y < height; ++y)
245*89c4ff92SAndroid Build Coastguard Worker {
246*89c4ff92SAndroid Build Coastguard Worker // Copies one row from linear memory buffer to arm_compute tensor buffer.
247*89c4ff92SAndroid Build Coastguard Worker // A row is the largest contiguous region we can copy, as the tensor data may be using strides.
248*89c4ff92SAndroid Build Coastguard Worker memcpy(
249*89c4ff92SAndroid Build Coastguard Worker bufferPtr + GetTensorOffset(info, depthIndex, batchIndex, channelIndex, y, 0),
250*89c4ff92SAndroid Build Coastguard Worker srcData + GetLinearBufferOffset(info, depthIndex, batchIndex, channelIndex, y, 0),
251*89c4ff92SAndroid Build Coastguard Worker width * sizeof(T));
252*89c4ff92SAndroid Build Coastguard Worker }
253*89c4ff92SAndroid Build Coastguard Worker }
254*89c4ff92SAndroid Build Coastguard Worker }
255*89c4ff92SAndroid Build Coastguard Worker }
256*89c4ff92SAndroid Build Coastguard Worker }
257*89c4ff92SAndroid Build Coastguard Worker }
258*89c4ff92SAndroid Build Coastguard Worker
259*89c4ff92SAndroid Build Coastguard Worker /// Construct a TensorShape object from an ArmCompute object based on arm_compute::Dimensions.
260*89c4ff92SAndroid Build Coastguard Worker /// \tparam ArmComputeType Any type that implements the Dimensions interface
261*89c4ff92SAndroid Build Coastguard Worker /// \tparam T Shape value type
262*89c4ff92SAndroid Build Coastguard Worker /// \param shapelike An ArmCompute object that implements the Dimensions interface
263*89c4ff92SAndroid Build Coastguard Worker /// \param initial A default value to initialise the shape with
264*89c4ff92SAndroid Build Coastguard Worker /// \return A TensorShape object filled from the Acl shapelike object.
265*89c4ff92SAndroid Build Coastguard Worker template<typename ArmComputeType, typename T>
GetTensorShape(const ArmComputeType & shapelike,T initial)266*89c4ff92SAndroid Build Coastguard Worker TensorShape GetTensorShape(const ArmComputeType& shapelike, T initial)
267*89c4ff92SAndroid Build Coastguard Worker {
268*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> s(MaxNumOfTensorDimensions, initial);
269*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i=0; i < shapelike.num_dimensions(); ++i)
270*89c4ff92SAndroid Build Coastguard Worker {
271*89c4ff92SAndroid Build Coastguard Worker s[(shapelike.num_dimensions()-1)-i] = armnn::numeric_cast<unsigned int>(shapelike[i]);
272*89c4ff92SAndroid Build Coastguard Worker }
273*89c4ff92SAndroid Build Coastguard Worker return TensorShape(armnn::numeric_cast<unsigned int>(shapelike.num_dimensions()), s.data());
274*89c4ff92SAndroid Build Coastguard Worker };
275*89c4ff92SAndroid Build Coastguard Worker
276*89c4ff92SAndroid Build Coastguard Worker /// Get the strides from an ACL strides object
GetStrides(const arm_compute::Strides & strides)277*89c4ff92SAndroid Build Coastguard Worker inline TensorShape GetStrides(const arm_compute::Strides& strides)
278*89c4ff92SAndroid Build Coastguard Worker {
279*89c4ff92SAndroid Build Coastguard Worker return GetTensorShape(strides, 0U);
280*89c4ff92SAndroid Build Coastguard Worker }
281*89c4ff92SAndroid Build Coastguard Worker
282*89c4ff92SAndroid Build Coastguard Worker /// Get the shape from an ACL shape object
GetShape(const arm_compute::TensorShape & shape)283*89c4ff92SAndroid Build Coastguard Worker inline TensorShape GetShape(const arm_compute::TensorShape& shape)
284*89c4ff92SAndroid Build Coastguard Worker {
285*89c4ff92SAndroid Build Coastguard Worker return GetTensorShape(shape, 1U);
286*89c4ff92SAndroid Build Coastguard Worker }
287*89c4ff92SAndroid Build Coastguard Worker
288*89c4ff92SAndroid Build Coastguard Worker } // namespace armcomputetensorutils
289*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn
290