xref: /aosp_15_r20/external/armnn/src/backends/aclCommon/ArmComputeTensorUtils.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #include <aclCommon/ArmComputeTensorUtils.hpp>
6*89c4ff92SAndroid Build Coastguard Worker #include <aclCommon/ArmComputeUtils.hpp>
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include "armnn/Exceptions.hpp"
9*89c4ff92SAndroid Build Coastguard Worker #include "ArmComputeUtils.hpp"
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Descriptors.hpp>
11*89c4ff92SAndroid Build Coastguard Worker 
12*89c4ff92SAndroid Build Coastguard Worker #include <fmt/format.h>
13*89c4ff92SAndroid Build Coastguard Worker 
14*89c4ff92SAndroid Build Coastguard Worker namespace armnn
15*89c4ff92SAndroid Build Coastguard Worker {
16*89c4ff92SAndroid Build Coastguard Worker namespace armcomputetensorutils
17*89c4ff92SAndroid Build Coastguard Worker {
18*89c4ff92SAndroid Build Coastguard Worker 
GetArmComputeDataType(armnn::DataType dataType,bool multiScales)19*89c4ff92SAndroid Build Coastguard Worker arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType, bool multiScales)
20*89c4ff92SAndroid Build Coastguard Worker {
21*89c4ff92SAndroid Build Coastguard Worker     switch(dataType)
22*89c4ff92SAndroid Build Coastguard Worker     {
23*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataType::BFloat16:
24*89c4ff92SAndroid Build Coastguard Worker             return arm_compute::DataType::BFLOAT16;
25*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataType::Boolean:
26*89c4ff92SAndroid Build Coastguard Worker             return arm_compute::DataType::U8;
27*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataType::Float16:
28*89c4ff92SAndroid Build Coastguard Worker             return arm_compute::DataType::F16;
29*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataType::Float32:
30*89c4ff92SAndroid Build Coastguard Worker             return arm_compute::DataType::F32;
31*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataType::QAsymmS8:
32*89c4ff92SAndroid Build Coastguard Worker             return arm_compute::DataType::QASYMM8_SIGNED;
33*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataType::QAsymmU8:
34*89c4ff92SAndroid Build Coastguard Worker             return arm_compute::DataType::QASYMM8;
35*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataType::QSymmS16:
36*89c4ff92SAndroid Build Coastguard Worker             return arm_compute::DataType::QSYMM16;
37*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataType::Signed64:
38*89c4ff92SAndroid Build Coastguard Worker             return arm_compute::DataType::S64;
39*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataType::QSymmS8:
40*89c4ff92SAndroid Build Coastguard Worker         {
41*89c4ff92SAndroid Build Coastguard Worker             return multiScales ? arm_compute::DataType::QSYMM8_PER_CHANNEL : arm_compute::DataType::QSYMM8;
42*89c4ff92SAndroid Build Coastguard Worker         }
43*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataType::Signed32:
44*89c4ff92SAndroid Build Coastguard Worker             return arm_compute::DataType::S32;
45*89c4ff92SAndroid Build Coastguard Worker         default:
46*89c4ff92SAndroid Build Coastguard Worker             ARMNN_ASSERT_MSG(false, "Unknown data type");
47*89c4ff92SAndroid Build Coastguard Worker             return arm_compute::DataType::UNKNOWN;
48*89c4ff92SAndroid Build Coastguard Worker     }
49*89c4ff92SAndroid Build Coastguard Worker }
50*89c4ff92SAndroid Build Coastguard Worker 
GetArmNNDataType(arm_compute::DataType dataType)51*89c4ff92SAndroid Build Coastguard Worker armnn::DataType GetArmNNDataType(arm_compute::DataType dataType)
52*89c4ff92SAndroid Build Coastguard Worker {
53*89c4ff92SAndroid Build Coastguard Worker     switch(dataType)
54*89c4ff92SAndroid Build Coastguard Worker     {
55*89c4ff92SAndroid Build Coastguard Worker         case arm_compute::DataType::BFLOAT16:
56*89c4ff92SAndroid Build Coastguard Worker             return armnn::DataType::BFloat16;
57*89c4ff92SAndroid Build Coastguard Worker         case arm_compute::DataType::U8:
58*89c4ff92SAndroid Build Coastguard Worker             return armnn::DataType::Boolean;
59*89c4ff92SAndroid Build Coastguard Worker         case arm_compute::DataType::F16:
60*89c4ff92SAndroid Build Coastguard Worker             return armnn::DataType::Float16;
61*89c4ff92SAndroid Build Coastguard Worker         case arm_compute::DataType::F32:
62*89c4ff92SAndroid Build Coastguard Worker             return armnn::DataType::Float32;
63*89c4ff92SAndroid Build Coastguard Worker         case arm_compute::DataType::QASYMM8_SIGNED:
64*89c4ff92SAndroid Build Coastguard Worker             return armnn::DataType::QAsymmS8;
65*89c4ff92SAndroid Build Coastguard Worker         case arm_compute::DataType::QASYMM8:
66*89c4ff92SAndroid Build Coastguard Worker             return armnn::DataType::QAsymmU8;
67*89c4ff92SAndroid Build Coastguard Worker         case arm_compute::DataType::QSYMM16:
68*89c4ff92SAndroid Build Coastguard Worker             return armnn::DataType::QSymmS16;
69*89c4ff92SAndroid Build Coastguard Worker         case arm_compute::DataType::S64:
70*89c4ff92SAndroid Build Coastguard Worker             return armnn::DataType::Signed64;
71*89c4ff92SAndroid Build Coastguard Worker         case arm_compute::DataType::QSYMM8_PER_CHANNEL:
72*89c4ff92SAndroid Build Coastguard Worker             return armnn::DataType::QSymmS8;
73*89c4ff92SAndroid Build Coastguard Worker         case arm_compute::DataType::QSYMM8:
74*89c4ff92SAndroid Build Coastguard Worker             return armnn::DataType::QSymmS8;
75*89c4ff92SAndroid Build Coastguard Worker         case arm_compute::DataType::S32:
76*89c4ff92SAndroid Build Coastguard Worker             return armnn::DataType::Signed32;
77*89c4ff92SAndroid Build Coastguard Worker         default:
78*89c4ff92SAndroid Build Coastguard Worker             ARMNN_ASSERT_MSG(false, "Unknown data type");
79*89c4ff92SAndroid Build Coastguard Worker             return armnn::DataType::Float32;
80*89c4ff92SAndroid Build Coastguard Worker     }
81*89c4ff92SAndroid Build Coastguard Worker }
82*89c4ff92SAndroid Build Coastguard Worker 
BuildArmComputeReductionCoordinates(size_t inputDimensions,unsigned int originalInputRank,const std::vector<unsigned int> & armnnAxes)83*89c4ff92SAndroid Build Coastguard Worker arm_compute::Coordinates BuildArmComputeReductionCoordinates(size_t inputDimensions,
84*89c4ff92SAndroid Build Coastguard Worker                                                              unsigned int originalInputRank,
85*89c4ff92SAndroid Build Coastguard Worker                                                              const std::vector<unsigned int>& armnnAxes)
86*89c4ff92SAndroid Build Coastguard Worker {
87*89c4ff92SAndroid Build Coastguard Worker     arm_compute::Coordinates outAclCoords;
88*89c4ff92SAndroid Build Coastguard Worker 
89*89c4ff92SAndroid Build Coastguard Worker     if (armnnAxes.empty())
90*89c4ff92SAndroid Build Coastguard Worker     {
91*89c4ff92SAndroid Build Coastguard Worker         // If no reduction axes were provided, then the input must be reduced along all dimensions.
92*89c4ff92SAndroid Build Coastguard Worker         // Since Compute Library does not accept an empty vector as the reduction dimensions, we then
93*89c4ff92SAndroid Build Coastguard Worker         // manually create a vector including all the input dimensions (in reversed order) as:
94*89c4ff92SAndroid Build Coastguard Worker         //
95*89c4ff92SAndroid Build Coastguard Worker         // { inputDimensions - 1, inputDimensions - 2, ..., 1, 0 }
96*89c4ff92SAndroid Build Coastguard Worker         //
97*89c4ff92SAndroid Build Coastguard Worker         outAclCoords.set_num_dimensions(inputDimensions);
98*89c4ff92SAndroid Build Coastguard Worker         std::generate(outAclCoords.begin(), outAclCoords.end(), [d = inputDimensions - 1] () mutable { return d--; });
99*89c4ff92SAndroid Build Coastguard Worker     }
100*89c4ff92SAndroid Build Coastguard Worker     else
101*89c4ff92SAndroid Build Coastguard Worker     {
102*89c4ff92SAndroid Build Coastguard Worker         // Create a vector of reduction dimensions (in reversed order) with the given reduction axes.
103*89c4ff92SAndroid Build Coastguard Worker         //
104*89c4ff92SAndroid Build Coastguard Worker         // Adjust the given reduction axes according to the original rank of the input tensor (before ACL applied any
105*89c4ff92SAndroid Build Coastguard Worker         // dimension correction).
106*89c4ff92SAndroid Build Coastguard Worker         // For example, if the input tensor originally had 4 dimensions, and one of the reduction axes was 2, then the
107*89c4ff92SAndroid Build Coastguard Worker         // new value for that reduction axis should be 1.
108*89c4ff92SAndroid Build Coastguard Worker         //
109*89c4ff92SAndroid Build Coastguard Worker         // Example:
110*89c4ff92SAndroid Build Coastguard Worker         // ArmNN input shape = { 1, 1, 3, 2 } -> ACL input shape = { 2, 3 }
111*89c4ff92SAndroid Build Coastguard Worker         // ArmNN reduction axis = { 2 }       -> ACL reduction axis = { 1 }
112*89c4ff92SAndroid Build Coastguard Worker         // ArmNN reduction axis = { 3 }       -> ACL reduction axis = { 0 }
113*89c4ff92SAndroid Build Coastguard Worker         //
114*89c4ff92SAndroid Build Coastguard Worker         // The transformation: ACL reduction axis index = original rank - ArmNN reduction axis index - 1
115*89c4ff92SAndroid Build Coastguard Worker         //
116*89c4ff92SAndroid Build Coastguard Worker         outAclCoords.set_num_dimensions(armnnAxes.size());
117*89c4ff92SAndroid Build Coastguard Worker         std::transform(armnnAxes.begin(), armnnAxes.end(),
118*89c4ff92SAndroid Build Coastguard Worker                        outAclCoords.begin(),
119*89c4ff92SAndroid Build Coastguard Worker                        [originalInputRank](unsigned int i){ return originalInputRank - i - 1; });
120*89c4ff92SAndroid Build Coastguard Worker     }
121*89c4ff92SAndroid Build Coastguard Worker 
122*89c4ff92SAndroid Build Coastguard Worker     return outAclCoords;
123*89c4ff92SAndroid Build Coastguard Worker }
124*89c4ff92SAndroid Build Coastguard Worker 
BuildArmComputeTensorShape(const armnn::TensorShape & tensorShape)125*89c4ff92SAndroid Build Coastguard Worker arm_compute::TensorShape BuildArmComputeTensorShape(const armnn::TensorShape& tensorShape)
126*89c4ff92SAndroid Build Coastguard Worker {
127*89c4ff92SAndroid Build Coastguard Worker     arm_compute::TensorShape shape;
128*89c4ff92SAndroid Build Coastguard Worker 
129*89c4ff92SAndroid Build Coastguard Worker     // armnn tensors are (batch, channels, height, width).
130*89c4ff92SAndroid Build Coastguard Worker     // arm_compute tensors are (width, height, channels, batch).
131*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++)
132*89c4ff92SAndroid Build Coastguard Worker     {
133*89c4ff92SAndroid Build Coastguard Worker         // Note that our dimensions are stored in the opposite order to ACL's.
134*89c4ff92SAndroid Build Coastguard Worker         shape.set(tensorShape.GetNumDimensions() - i - 1, tensorShape[i], false);
135*89c4ff92SAndroid Build Coastguard Worker 
136*89c4ff92SAndroid Build Coastguard Worker         // TensorShape::set() flattens leading ones, so that batch size 1 cannot happen.
137*89c4ff92SAndroid Build Coastguard Worker         // arm_compute tensors expect this.
138*89c4ff92SAndroid Build Coastguard Worker     }
139*89c4ff92SAndroid Build Coastguard Worker 
140*89c4ff92SAndroid Build Coastguard Worker     // prevent arm_compute issue where tensor is flattened to nothing
141*89c4ff92SAndroid Build Coastguard Worker     if (shape.num_dimensions() == 0)
142*89c4ff92SAndroid Build Coastguard Worker     {
143*89c4ff92SAndroid Build Coastguard Worker         shape.set_num_dimensions(1);
144*89c4ff92SAndroid Build Coastguard Worker     }
145*89c4ff92SAndroid Build Coastguard Worker 
146*89c4ff92SAndroid Build Coastguard Worker     return shape;
147*89c4ff92SAndroid Build Coastguard Worker }
148*89c4ff92SAndroid Build Coastguard Worker 
ReduceDimsForACL(const armnn::TensorShape tensorShape,unsigned int dimensions)149*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> ReduceDimsForACL(const armnn::TensorShape tensorShape, unsigned int dimensions)
150*89c4ff92SAndroid Build Coastguard Worker {
151*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> newShape;
152*89c4ff92SAndroid Build Coastguard Worker 
153*89c4ff92SAndroid Build Coastguard Worker     unsigned int dimsToSkip = 0;
154*89c4ff92SAndroid Build Coastguard Worker 
155*89c4ff92SAndroid Build Coastguard Worker     if (tensorShape.GetNumDimensions() > dimensions)
156*89c4ff92SAndroid Build Coastguard Worker     {
157*89c4ff92SAndroid Build Coastguard Worker         dimsToSkip = tensorShape.GetNumDimensions() - dimensions;
158*89c4ff92SAndroid Build Coastguard Worker     }
159*89c4ff92SAndroid Build Coastguard Worker     unsigned int dimsSkipped = 0;
160*89c4ff92SAndroid Build Coastguard Worker     bool insertRemainder = false;
161*89c4ff92SAndroid Build Coastguard Worker 
162*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); ++i)
163*89c4ff92SAndroid Build Coastguard Worker     {
164*89c4ff92SAndroid Build Coastguard Worker         if (tensorShape[i] == 1 && dimsSkipped < dimsToSkip && !insertRemainder)
165*89c4ff92SAndroid Build Coastguard Worker         {
166*89c4ff92SAndroid Build Coastguard Worker             ++dimsSkipped;
167*89c4ff92SAndroid Build Coastguard Worker             continue;
168*89c4ff92SAndroid Build Coastguard Worker         }
169*89c4ff92SAndroid Build Coastguard Worker         newShape.insert(newShape.begin(), tensorShape[i]);
170*89c4ff92SAndroid Build Coastguard Worker         // Once we insert the first dimension we can't skip any more
171*89c4ff92SAndroid Build Coastguard Worker         insertRemainder = true;
172*89c4ff92SAndroid Build Coastguard Worker     }
173*89c4ff92SAndroid Build Coastguard Worker     return newShape;
174*89c4ff92SAndroid Build Coastguard Worker }
175*89c4ff92SAndroid Build Coastguard Worker 
BuildArmComputeTensorShape(const armnn::TensorShape & tensorShape,unsigned int dimensions)176*89c4ff92SAndroid Build Coastguard Worker arm_compute::TensorShape BuildArmComputeTensorShape(const armnn::TensorShape& tensorShape, unsigned int dimensions)
177*89c4ff92SAndroid Build Coastguard Worker {
178*89c4ff92SAndroid Build Coastguard Worker     arm_compute::TensorShape shape;
179*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> strippedShape = ReduceDimsForACL(tensorShape, dimensions);
180*89c4ff92SAndroid Build Coastguard Worker 
181*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < strippedShape.size(); i++)
182*89c4ff92SAndroid Build Coastguard Worker     {
183*89c4ff92SAndroid Build Coastguard Worker         shape.set(i, strippedShape[i], false);
184*89c4ff92SAndroid Build Coastguard Worker     }
185*89c4ff92SAndroid Build Coastguard Worker 
186*89c4ff92SAndroid Build Coastguard Worker     // prevent arm_compute issue where tensor is flattened to nothing
187*89c4ff92SAndroid Build Coastguard Worker     if (shape.num_dimensions() == 0)
188*89c4ff92SAndroid Build Coastguard Worker     {
189*89c4ff92SAndroid Build Coastguard Worker         shape.set_num_dimensions(1);
190*89c4ff92SAndroid Build Coastguard Worker     }
191*89c4ff92SAndroid Build Coastguard Worker     return shape;
192*89c4ff92SAndroid Build Coastguard Worker }
193*89c4ff92SAndroid Build Coastguard Worker 
194*89c4ff92SAndroid Build Coastguard Worker // Utility function used to build a TensorInfo object, that can be used to initialise
195*89c4ff92SAndroid Build Coastguard Worker // ARM Compute Tensor and CLTensor allocators.
196*89c4ff92SAndroid Build Coastguard Worker // Note: this utility ignores the value of armnn::TensorInfo.IsConstant(). ACL tensors
197*89c4ff92SAndroid Build Coastguard Worker // default to constant but Arm NN ones default to non constant. In the cases where
198*89c4ff92SAndroid Build Coastguard Worker // we expect ACL to treat a tensor as constant that value must be set after this
199*89c4ff92SAndroid Build Coastguard Worker // utility has been called.
BuildArmComputeTensorInfo(const armnn::TensorInfo & tensorInfo)200*89c4ff92SAndroid Build Coastguard Worker arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo)
201*89c4ff92SAndroid Build Coastguard Worker {
202*89c4ff92SAndroid Build Coastguard Worker     bool multiScales = tensorInfo.HasMultipleQuantizationScales();
203*89c4ff92SAndroid Build Coastguard Worker     const arm_compute::TensorShape aclTensorShape = BuildArmComputeTensorShape(tensorInfo.GetShape());
204*89c4ff92SAndroid Build Coastguard Worker     const arm_compute::DataType aclDataType       = GetArmComputeDataType(tensorInfo.GetDataType(), multiScales);
205*89c4ff92SAndroid Build Coastguard Worker 
206*89c4ff92SAndroid Build Coastguard Worker     const arm_compute::QuantizationInfo aclQuantizationInfo = multiScales ?
207*89c4ff92SAndroid Build Coastguard Worker         arm_compute::QuantizationInfo(tensorInfo.GetQuantizationScales()) :
208*89c4ff92SAndroid Build Coastguard Worker         arm_compute::QuantizationInfo(tensorInfo.GetQuantizationScale(), tensorInfo.GetQuantizationOffset());
209*89c4ff92SAndroid Build Coastguard Worker 
210*89c4ff92SAndroid Build Coastguard Worker     return arm_compute::TensorInfo(aclTensorShape, 1, aclDataType, aclQuantizationInfo);
211*89c4ff92SAndroid Build Coastguard Worker }
212*89c4ff92SAndroid Build Coastguard Worker 
BuildArmComputeTensorInfo(const armnn::TensorInfo & tensorInfo,armnn::DataLayout dataLayout)213*89c4ff92SAndroid Build Coastguard Worker arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo,
214*89c4ff92SAndroid Build Coastguard Worker                                                   armnn::DataLayout dataLayout)
215*89c4ff92SAndroid Build Coastguard Worker {
216*89c4ff92SAndroid Build Coastguard Worker     arm_compute::TensorInfo aclTensorInfo = BuildArmComputeTensorInfo(tensorInfo);
217*89c4ff92SAndroid Build Coastguard Worker     aclTensorInfo.set_data_layout(ConvertDataLayout(dataLayout));
218*89c4ff92SAndroid Build Coastguard Worker 
219*89c4ff92SAndroid Build Coastguard Worker     return aclTensorInfo;
220*89c4ff92SAndroid Build Coastguard Worker }
221*89c4ff92SAndroid Build Coastguard Worker 
BuildArmComputeTensorInfo(const armnn::TensorInfo & tensorInfo,unsigned int dimensions)222*89c4ff92SAndroid Build Coastguard Worker arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo, unsigned int dimensions)
223*89c4ff92SAndroid Build Coastguard Worker {
224*89c4ff92SAndroid Build Coastguard Worker     bool multiScales = tensorInfo.HasMultipleQuantizationScales();
225*89c4ff92SAndroid Build Coastguard Worker     const arm_compute::TensorShape aclTensorShape = BuildArmComputeTensorShape(tensorInfo.GetShape(), dimensions);
226*89c4ff92SAndroid Build Coastguard Worker     const arm_compute::DataType aclDataType       = GetArmComputeDataType(tensorInfo.GetDataType(), multiScales);
227*89c4ff92SAndroid Build Coastguard Worker 
228*89c4ff92SAndroid Build Coastguard Worker     const arm_compute::QuantizationInfo aclQuantizationInfo = multiScales ?
229*89c4ff92SAndroid Build Coastguard Worker               arm_compute::QuantizationInfo(tensorInfo.GetQuantizationScales()) :
230*89c4ff92SAndroid Build Coastguard Worker               arm_compute::QuantizationInfo(tensorInfo.GetQuantizationScale(), tensorInfo.GetQuantizationOffset());
231*89c4ff92SAndroid Build Coastguard Worker 
232*89c4ff92SAndroid Build Coastguard Worker     return arm_compute::TensorInfo(aclTensorShape, 1, aclDataType, aclQuantizationInfo);
233*89c4ff92SAndroid Build Coastguard Worker }
BuildArmComputeTensorInfo(const armnn::TensorInfo & tensorInfo,armnn::DataLayout dataLayout,unsigned int dimensions)234*89c4ff92SAndroid Build Coastguard Worker arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo,
235*89c4ff92SAndroid Build Coastguard Worker                                                   armnn::DataLayout dataLayout, unsigned int dimensions)
236*89c4ff92SAndroid Build Coastguard Worker {
237*89c4ff92SAndroid Build Coastguard Worker     arm_compute::TensorInfo aclTensorInfo = BuildArmComputeTensorInfo(tensorInfo, dimensions);
238*89c4ff92SAndroid Build Coastguard Worker     aclTensorInfo.set_data_layout(ConvertDataLayout(dataLayout));
239*89c4ff92SAndroid Build Coastguard Worker 
240*89c4ff92SAndroid Build Coastguard Worker     return aclTensorInfo;
241*89c4ff92SAndroid Build Coastguard Worker }
242*89c4ff92SAndroid Build Coastguard Worker 
243*89c4ff92SAndroid Build Coastguard Worker 
ConvertDataLayout(armnn::DataLayout dataLayout)244*89c4ff92SAndroid Build Coastguard Worker arm_compute::DataLayout ConvertDataLayout(armnn::DataLayout dataLayout)
245*89c4ff92SAndroid Build Coastguard Worker {
246*89c4ff92SAndroid Build Coastguard Worker     switch(dataLayout)
247*89c4ff92SAndroid Build Coastguard Worker     {
248*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataLayout::NHWC : return arm_compute::DataLayout::NHWC;
249*89c4ff92SAndroid Build Coastguard Worker 
250*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataLayout::NCHW : return arm_compute::DataLayout::NCHW;
251*89c4ff92SAndroid Build Coastguard Worker 
252*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataLayout::NDHWC : return arm_compute::DataLayout::NDHWC;
253*89c4ff92SAndroid Build Coastguard Worker 
254*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataLayout::NCDHW : return arm_compute::DataLayout::NCDHW;
255*89c4ff92SAndroid Build Coastguard Worker 
256*89c4ff92SAndroid Build Coastguard Worker         default: throw InvalidArgumentException("Unknown armnn::DataLayout: [" +
257*89c4ff92SAndroid Build Coastguard Worker                                                 std::to_string(static_cast<int>(dataLayout)) + "]");
258*89c4ff92SAndroid Build Coastguard Worker     }
259*89c4ff92SAndroid Build Coastguard Worker }
260*89c4ff92SAndroid Build Coastguard Worker 
BuildArmComputePoolingLayerInfo(const Pooling2dDescriptor & descriptor,bool fpMixedPrecision)261*89c4ff92SAndroid Build Coastguard Worker arm_compute::PoolingLayerInfo BuildArmComputePoolingLayerInfo(const Pooling2dDescriptor& descriptor,
262*89c4ff92SAndroid Build Coastguard Worker                                                               bool fpMixedPrecision)
263*89c4ff92SAndroid Build Coastguard Worker {
264*89c4ff92SAndroid Build Coastguard Worker     // Resolve ARM Compute layer parameters.
265*89c4ff92SAndroid Build Coastguard Worker     const arm_compute::PoolingType poolingType = ConvertPoolingAlgorithmToAclPoolingType(descriptor.m_PoolType);
266*89c4ff92SAndroid Build Coastguard Worker 
267*89c4ff92SAndroid Build Coastguard Worker     const arm_compute::DataLayout dataLayout = ConvertDataLayout(descriptor.m_DataLayout);
268*89c4ff92SAndroid Build Coastguard Worker 
269*89c4ff92SAndroid Build Coastguard Worker     bool isGlobalPooling = (descriptor.m_StrideX==0 && descriptor.m_StrideY==0);
270*89c4ff92SAndroid Build Coastguard Worker     //use specific constructor if global pooling
271*89c4ff92SAndroid Build Coastguard Worker     if(isGlobalPooling)
272*89c4ff92SAndroid Build Coastguard Worker     {
273*89c4ff92SAndroid Build Coastguard Worker         return arm_compute::PoolingLayerInfo(poolingType, dataLayout);
274*89c4ff92SAndroid Build Coastguard Worker     }
275*89c4ff92SAndroid Build Coastguard Worker 
276*89c4ff92SAndroid Build Coastguard Worker     const arm_compute::DimensionRoundingType rounding = ConvertOutputShapeRoundingToAclDimensionRoundingType(
277*89c4ff92SAndroid Build Coastguard Worker                                                                                     descriptor.m_OutputShapeRounding);
278*89c4ff92SAndroid Build Coastguard Worker     const arm_compute::PadStrideInfo padStrideInfo(descriptor.m_StrideX,
279*89c4ff92SAndroid Build Coastguard Worker                                       descriptor.m_StrideY,
280*89c4ff92SAndroid Build Coastguard Worker                                       descriptor.m_PadLeft,
281*89c4ff92SAndroid Build Coastguard Worker                                       descriptor.m_PadRight,
282*89c4ff92SAndroid Build Coastguard Worker                                       descriptor.m_PadTop,
283*89c4ff92SAndroid Build Coastguard Worker                                       descriptor.m_PadBottom,
284*89c4ff92SAndroid Build Coastguard Worker                                       rounding);
285*89c4ff92SAndroid Build Coastguard Worker 
286*89c4ff92SAndroid Build Coastguard Worker     const bool excludePadding = (descriptor.m_PaddingMethod == PaddingMethod::Exclude);
287*89c4ff92SAndroid Build Coastguard Worker 
288*89c4ff92SAndroid Build Coastguard Worker     const arm_compute::Size2D poolSize(descriptor.m_PoolWidth, descriptor.m_PoolHeight);
289*89c4ff92SAndroid Build Coastguard Worker 
290*89c4ff92SAndroid Build Coastguard Worker     return arm_compute::PoolingLayerInfo(poolingType, poolSize, dataLayout, padStrideInfo, excludePadding,
291*89c4ff92SAndroid Build Coastguard Worker                                          fpMixedPrecision);
292*89c4ff92SAndroid Build Coastguard Worker }
293*89c4ff92SAndroid Build Coastguard Worker 
BuildArmComputePooling3dLayerInfo(const Pooling3dDescriptor & descriptor,bool fpMixedPrecision)294*89c4ff92SAndroid Build Coastguard Worker arm_compute::Pooling3dLayerInfo BuildArmComputePooling3dLayerInfo(const Pooling3dDescriptor& descriptor,
295*89c4ff92SAndroid Build Coastguard Worker                                                                   bool fpMixedPrecision)
296*89c4ff92SAndroid Build Coastguard Worker {
297*89c4ff92SAndroid Build Coastguard Worker     const arm_compute::PoolingType poolingType = ConvertPoolingAlgorithmToAclPoolingType(descriptor.m_PoolType);
298*89c4ff92SAndroid Build Coastguard Worker 
299*89c4ff92SAndroid Build Coastguard Worker     bool isGlobalPooling = (descriptor.m_StrideX==0 && descriptor.m_StrideY==0 && descriptor.m_StrideZ==0);
300*89c4ff92SAndroid Build Coastguard Worker     //use specific constructor if global pooling
301*89c4ff92SAndroid Build Coastguard Worker     if(isGlobalPooling)
302*89c4ff92SAndroid Build Coastguard Worker     {
303*89c4ff92SAndroid Build Coastguard Worker         return arm_compute::Pooling3dLayerInfo(poolingType);
304*89c4ff92SAndroid Build Coastguard Worker     }
305*89c4ff92SAndroid Build Coastguard Worker 
306*89c4ff92SAndroid Build Coastguard Worker     const arm_compute::Size3D poolSize(descriptor.m_PoolWidth, descriptor.m_PoolHeight, descriptor.m_PoolDepth);
307*89c4ff92SAndroid Build Coastguard Worker 
308*89c4ff92SAndroid Build Coastguard Worker     const arm_compute::Size3D stride(descriptor.m_StrideX,
309*89c4ff92SAndroid Build Coastguard Worker                         descriptor.m_StrideY,
310*89c4ff92SAndroid Build Coastguard Worker                         descriptor.m_StrideZ);
311*89c4ff92SAndroid Build Coastguard Worker 
312*89c4ff92SAndroid Build Coastguard Worker     const arm_compute::Padding3D padding(descriptor.m_PadLeft,
313*89c4ff92SAndroid Build Coastguard Worker                             descriptor.m_PadRight,
314*89c4ff92SAndroid Build Coastguard Worker                             descriptor.m_PadTop,
315*89c4ff92SAndroid Build Coastguard Worker                             descriptor.m_PadBottom,
316*89c4ff92SAndroid Build Coastguard Worker                             descriptor.m_PadFront,
317*89c4ff92SAndroid Build Coastguard Worker                             descriptor.m_PadBack);
318*89c4ff92SAndroid Build Coastguard Worker 
319*89c4ff92SAndroid Build Coastguard Worker     const bool excludePadding = (descriptor.m_PaddingMethod == PaddingMethod::Exclude);
320*89c4ff92SAndroid Build Coastguard Worker 
321*89c4ff92SAndroid Build Coastguard Worker     const arm_compute::DimensionRoundingType rounding = ConvertOutputShapeRoundingToAclDimensionRoundingType(
322*89c4ff92SAndroid Build Coastguard Worker             descriptor.m_OutputShapeRounding);
323*89c4ff92SAndroid Build Coastguard Worker 
324*89c4ff92SAndroid Build Coastguard Worker     return arm_compute::Pooling3dLayerInfo(poolingType,
325*89c4ff92SAndroid Build Coastguard Worker                                            poolSize,
326*89c4ff92SAndroid Build Coastguard Worker                                            stride,
327*89c4ff92SAndroid Build Coastguard Worker                                            padding,
328*89c4ff92SAndroid Build Coastguard Worker                                            excludePadding,
329*89c4ff92SAndroid Build Coastguard Worker                                            fpMixedPrecision,
330*89c4ff92SAndroid Build Coastguard Worker                                            rounding);
331*89c4ff92SAndroid Build Coastguard Worker }
332*89c4ff92SAndroid Build Coastguard Worker 
BuildArmComputeNormalizationLayerInfo(const NormalizationDescriptor & descriptor)333*89c4ff92SAndroid Build Coastguard Worker arm_compute::NormalizationLayerInfo BuildArmComputeNormalizationLayerInfo(const NormalizationDescriptor& descriptor)
334*89c4ff92SAndroid Build Coastguard Worker {
335*89c4ff92SAndroid Build Coastguard Worker     const arm_compute::NormType normType =
336*89c4ff92SAndroid Build Coastguard Worker         ConvertNormalizationAlgorithmChannelToAclNormType(descriptor.m_NormChannelType);
337*89c4ff92SAndroid Build Coastguard Worker     return arm_compute::NormalizationLayerInfo(normType,
338*89c4ff92SAndroid Build Coastguard Worker                                                descriptor.m_NormSize,
339*89c4ff92SAndroid Build Coastguard Worker                                                descriptor.m_Alpha,
340*89c4ff92SAndroid Build Coastguard Worker                                                descriptor.m_Beta,
341*89c4ff92SAndroid Build Coastguard Worker                                                descriptor.m_K,
342*89c4ff92SAndroid Build Coastguard Worker                                                false);
343*89c4ff92SAndroid Build Coastguard Worker }
344*89c4ff92SAndroid Build Coastguard Worker 
BuildArmComputePermutationVector(const armnn::PermutationVector & perm)345*89c4ff92SAndroid Build Coastguard Worker arm_compute::PermutationVector BuildArmComputePermutationVector(const armnn::PermutationVector& perm)
346*89c4ff92SAndroid Build Coastguard Worker {
347*89c4ff92SAndroid Build Coastguard Worker     arm_compute::PermutationVector aclPerm;
348*89c4ff92SAndroid Build Coastguard Worker 
349*89c4ff92SAndroid Build Coastguard Worker     unsigned int start = 0;
350*89c4ff92SAndroid Build Coastguard Worker     while ((start < perm.GetSize()) && (start == perm[start]))
351*89c4ff92SAndroid Build Coastguard Worker     {
352*89c4ff92SAndroid Build Coastguard Worker         ++start;
353*89c4ff92SAndroid Build Coastguard Worker     }
354*89c4ff92SAndroid Build Coastguard Worker 
355*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = start; i < perm.GetSize(); ++i)
356*89c4ff92SAndroid Build Coastguard Worker     {
357*89c4ff92SAndroid Build Coastguard Worker         aclPerm.set(i - start, perm[i] - start);
358*89c4ff92SAndroid Build Coastguard Worker     }
359*89c4ff92SAndroid Build Coastguard Worker     return aclPerm;
360*89c4ff92SAndroid Build Coastguard Worker }
361*89c4ff92SAndroid Build Coastguard Worker 
BuildArmComputeTransposeVector(const armnn::PermutationVector & perm)362*89c4ff92SAndroid Build Coastguard Worker arm_compute::PermutationVector BuildArmComputeTransposeVector(const armnn::PermutationVector& perm)
363*89c4ff92SAndroid Build Coastguard Worker {
364*89c4ff92SAndroid Build Coastguard Worker     // As ArmNN indexes are left to right and ACL indexes are right to left,
365*89c4ff92SAndroid Build Coastguard Worker     // the permutation vector has to be reversed and then translated into ACL axis.
366*89c4ff92SAndroid Build Coastguard Worker     // i.e. {1, 0, 2, 3} --> {3, 2, 0, 1} --> {0, 1, 3, 2}
367*89c4ff92SAndroid Build Coastguard Worker 
368*89c4ff92SAndroid Build Coastguard Worker     // Below an example of how the ArmNN and ACL index format work:
369*89c4ff92SAndroid Build Coastguard Worker     // ArmNN Format:
370*89c4ff92SAndroid Build Coastguard Worker     // Input Shape        {1, 10, 20, 30}
371*89c4ff92SAndroid Build Coastguard Worker     // Permutation Vector {1,  0,  2,  3}
372*89c4ff92SAndroid Build Coastguard Worker     // Output Shape       {10, 1, 20, 30}
373*89c4ff92SAndroid Build Coastguard Worker     // dim "1" of input goes into index 0 of the output ([ 10, X, X, X])
374*89c4ff92SAndroid Build Coastguard Worker     // dim "0" of input goes into index 1 of the output ([ 10, 1, X, X ])
375*89c4ff92SAndroid Build Coastguard Worker     // dim "2" of input goes into index 2 of the output ([ 10, 1, 20, X ])
376*89c4ff92SAndroid Build Coastguard Worker     // dim "3" of input goes into index 3 of the output ([ 10, 1, 20, 30 ])
377*89c4ff92SAndroid Build Coastguard Worker     // ACL Format:
378*89c4ff92SAndroid Build Coastguard Worker     // Input Shape        {30, 20, 10, 1}
379*89c4ff92SAndroid Build Coastguard Worker     // Permutation Vector {0,  1,  3,  2}
380*89c4ff92SAndroid Build Coastguard Worker     // Output Shape       {30, 20, 1, 10}
381*89c4ff92SAndroid Build Coastguard Worker     // dim "0" of input goes into index 0 of the output ([ 30,  X, X, X])
382*89c4ff92SAndroid Build Coastguard Worker     // dim "1" of input goes into index 1 of the output ([ 30, 20, X, X ])
383*89c4ff92SAndroid Build Coastguard Worker     // dim "3" of input goes into index 2 of the output ([ 30, 20, 1, X ])
384*89c4ff92SAndroid Build Coastguard Worker     // dim "2" of input goes into index 3 of the output ([ 30, 20, 1, 10 ])
385*89c4ff92SAndroid Build Coastguard Worker 
386*89c4ff92SAndroid Build Coastguard Worker     arm_compute::PermutationVector aclPerm;
387*89c4ff92SAndroid Build Coastguard Worker     auto rank = perm.GetSize();
388*89c4ff92SAndroid Build Coastguard Worker 
389*89c4ff92SAndroid Build Coastguard Worker     // Reverse the order. i.e. {1, 0, 2, 3} --> {3, 2, 0, 1}
390*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> reversedPerm;
391*89c4ff92SAndroid Build Coastguard Worker     reversedPerm.reserve(rank);
392*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = rank; i > 0; --i)
393*89c4ff92SAndroid Build Coastguard Worker     {
394*89c4ff92SAndroid Build Coastguard Worker         reversedPerm.push_back(perm[i-1]);
395*89c4ff92SAndroid Build Coastguard Worker     }
396*89c4ff92SAndroid Build Coastguard Worker 
397*89c4ff92SAndroid Build Coastguard Worker     // Translate from Arm NN axis to ACL axis. i.e. {3, 2, 0, 1} --> {0, 1, 3, 2}
398*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < rank; ++i)
399*89c4ff92SAndroid Build Coastguard Worker     {
400*89c4ff92SAndroid Build Coastguard Worker         auto aclAxis = rank - 1 - reversedPerm[i];
401*89c4ff92SAndroid Build Coastguard Worker         aclPerm.set(i, aclAxis);
402*89c4ff92SAndroid Build Coastguard Worker     }
403*89c4ff92SAndroid Build Coastguard Worker     return aclPerm;
404*89c4ff92SAndroid Build Coastguard Worker }
405*89c4ff92SAndroid Build Coastguard Worker 
BuildArmComputeSize2D(const unsigned int width,const unsigned int height)406*89c4ff92SAndroid Build Coastguard Worker arm_compute::Size2D BuildArmComputeSize2D(const unsigned int width, const unsigned int height)
407*89c4ff92SAndroid Build Coastguard Worker {
408*89c4ff92SAndroid Build Coastguard Worker     return arm_compute::Size2D(width, height);
409*89c4ff92SAndroid Build Coastguard Worker }
410*89c4ff92SAndroid Build Coastguard Worker 
GetPixelValue(const arm_compute::ITensorInfo * tensorInfo,float value)411*89c4ff92SAndroid Build Coastguard Worker arm_compute::PixelValue GetPixelValue(const arm_compute::ITensorInfo* tensorInfo, float value)
412*89c4ff92SAndroid Build Coastguard Worker {
413*89c4ff92SAndroid Build Coastguard Worker     switch (tensorInfo->data_type())
414*89c4ff92SAndroid Build Coastguard Worker     {
415*89c4ff92SAndroid Build Coastguard Worker         case arm_compute::DataType::F16:
416*89c4ff92SAndroid Build Coastguard Worker         {
417*89c4ff92SAndroid Build Coastguard Worker             arm_compute::PixelValue pixelValue = arm_compute::PixelValue(static_cast<Half>(value));
418*89c4ff92SAndroid Build Coastguard Worker             if (isinf(pixelValue.get<Half>())) {
419*89c4ff92SAndroid Build Coastguard Worker                 throw InvalidArgumentException("Under/Overflow converting float value [" + std::to_string(value) +
420*89c4ff92SAndroid Build Coastguard Worker                     "] to fp16: [" + std::to_string(pixelValue.get<Half>()) + "]");
421*89c4ff92SAndroid Build Coastguard Worker             }
422*89c4ff92SAndroid Build Coastguard Worker             return pixelValue;
423*89c4ff92SAndroid Build Coastguard Worker         }
424*89c4ff92SAndroid Build Coastguard Worker         case arm_compute::DataType::F32:
425*89c4ff92SAndroid Build Coastguard Worker             return arm_compute::PixelValue(value);
426*89c4ff92SAndroid Build Coastguard Worker         case arm_compute::DataType::QASYMM8:
427*89c4ff92SAndroid Build Coastguard Worker             return arm_compute::PixelValue(static_cast<uint8_t>(value));
428*89c4ff92SAndroid Build Coastguard Worker         case arm_compute::DataType::QSYMM16:
429*89c4ff92SAndroid Build Coastguard Worker             return arm_compute::PixelValue(static_cast<int16_t>(value));
430*89c4ff92SAndroid Build Coastguard Worker         case arm_compute::DataType::QSYMM8:
431*89c4ff92SAndroid Build Coastguard Worker         case arm_compute::DataType::QASYMM8_SIGNED:
432*89c4ff92SAndroid Build Coastguard Worker         case arm_compute::DataType::QSYMM8_PER_CHANNEL:
433*89c4ff92SAndroid Build Coastguard Worker             return arm_compute::PixelValue(static_cast<int8_t>(value));
434*89c4ff92SAndroid Build Coastguard Worker         case arm_compute::DataType::S32:
435*89c4ff92SAndroid Build Coastguard Worker             return arm_compute::PixelValue(static_cast<int32_t>(value));
436*89c4ff92SAndroid Build Coastguard Worker         default:
437*89c4ff92SAndroid Build Coastguard Worker             throw InvalidArgumentException("Unsupported DataType: [" +
438*89c4ff92SAndroid Build Coastguard Worker                                            std::to_string(static_cast<int>(tensorInfo->data_type())) + "]");
439*89c4ff92SAndroid Build Coastguard Worker     }
440*89c4ff92SAndroid Build Coastguard Worker }
441*89c4ff92SAndroid Build Coastguard Worker 
ComputeDepthwiseConv2dDepthMultiplier(armnn::DataLayout layout,const arm_compute::TensorShape & weightsShape,const arm_compute::TensorShape & inputShape)442*89c4ff92SAndroid Build Coastguard Worker unsigned int ComputeDepthwiseConv2dDepthMultiplier(armnn::DataLayout layout,
443*89c4ff92SAndroid Build Coastguard Worker                                                    const arm_compute::TensorShape& weightsShape,
444*89c4ff92SAndroid Build Coastguard Worker                                                    const arm_compute::TensorShape& inputShape)
445*89c4ff92SAndroid Build Coastguard Worker {
446*89c4ff92SAndroid Build Coastguard Worker     unsigned int depthMultiplier;
447*89c4ff92SAndroid Build Coastguard Worker     if (layout == armnn::DataLayout::NHWC)
448*89c4ff92SAndroid Build Coastguard Worker     {
449*89c4ff92SAndroid Build Coastguard Worker         depthMultiplier = static_cast<uint32_t>(weightsShape[0]) / static_cast<uint32_t>(inputShape[0]);
450*89c4ff92SAndroid Build Coastguard Worker     }
451*89c4ff92SAndroid Build Coastguard Worker     else if (layout == armnn::DataLayout::NCHW)
452*89c4ff92SAndroid Build Coastguard Worker     {
453*89c4ff92SAndroid Build Coastguard Worker         depthMultiplier = static_cast<uint32_t>(weightsShape[2]) / static_cast<uint32_t>(inputShape[2]);
454*89c4ff92SAndroid Build Coastguard Worker     }
455*89c4ff92SAndroid Build Coastguard Worker     else
456*89c4ff92SAndroid Build Coastguard Worker     {
457*89c4ff92SAndroid Build Coastguard Worker         throw InvalidArgumentException(fmt::format("Unknown data layout for tensor conversion: {}",
458*89c4ff92SAndroid Build Coastguard Worker                                                    GetDataLayoutName(layout)));
459*89c4ff92SAndroid Build Coastguard Worker     }
460*89c4ff92SAndroid Build Coastguard Worker     return depthMultiplier;
461*89c4ff92SAndroid Build Coastguard Worker }
462*89c4ff92SAndroid Build Coastguard Worker 
463*89c4ff92SAndroid Build Coastguard Worker } // namespace armcomputetensorutils
464*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn
465