//
// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker #include <aclCommon/ArmComputeTensorUtils.hpp>
6*89c4ff92SAndroid Build Coastguard Worker #include <aclCommon/ArmComputeUtils.hpp>
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include "armnn/Exceptions.hpp"
9*89c4ff92SAndroid Build Coastguard Worker #include "ArmComputeUtils.hpp"
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Descriptors.hpp>
11*89c4ff92SAndroid Build Coastguard Worker
12*89c4ff92SAndroid Build Coastguard Worker #include <fmt/format.h>
13*89c4ff92SAndroid Build Coastguard Worker
14*89c4ff92SAndroid Build Coastguard Worker namespace armnn
15*89c4ff92SAndroid Build Coastguard Worker {
16*89c4ff92SAndroid Build Coastguard Worker namespace armcomputetensorutils
17*89c4ff92SAndroid Build Coastguard Worker {
18*89c4ff92SAndroid Build Coastguard Worker
GetArmComputeDataType(armnn::DataType dataType,bool multiScales)19*89c4ff92SAndroid Build Coastguard Worker arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType, bool multiScales)
20*89c4ff92SAndroid Build Coastguard Worker {
21*89c4ff92SAndroid Build Coastguard Worker switch(dataType)
22*89c4ff92SAndroid Build Coastguard Worker {
23*89c4ff92SAndroid Build Coastguard Worker case armnn::DataType::BFloat16:
24*89c4ff92SAndroid Build Coastguard Worker return arm_compute::DataType::BFLOAT16;
25*89c4ff92SAndroid Build Coastguard Worker case armnn::DataType::Boolean:
26*89c4ff92SAndroid Build Coastguard Worker return arm_compute::DataType::U8;
27*89c4ff92SAndroid Build Coastguard Worker case armnn::DataType::Float16:
28*89c4ff92SAndroid Build Coastguard Worker return arm_compute::DataType::F16;
29*89c4ff92SAndroid Build Coastguard Worker case armnn::DataType::Float32:
30*89c4ff92SAndroid Build Coastguard Worker return arm_compute::DataType::F32;
31*89c4ff92SAndroid Build Coastguard Worker case armnn::DataType::QAsymmS8:
32*89c4ff92SAndroid Build Coastguard Worker return arm_compute::DataType::QASYMM8_SIGNED;
33*89c4ff92SAndroid Build Coastguard Worker case armnn::DataType::QAsymmU8:
34*89c4ff92SAndroid Build Coastguard Worker return arm_compute::DataType::QASYMM8;
35*89c4ff92SAndroid Build Coastguard Worker case armnn::DataType::QSymmS16:
36*89c4ff92SAndroid Build Coastguard Worker return arm_compute::DataType::QSYMM16;
37*89c4ff92SAndroid Build Coastguard Worker case armnn::DataType::Signed64:
38*89c4ff92SAndroid Build Coastguard Worker return arm_compute::DataType::S64;
39*89c4ff92SAndroid Build Coastguard Worker case armnn::DataType::QSymmS8:
40*89c4ff92SAndroid Build Coastguard Worker {
41*89c4ff92SAndroid Build Coastguard Worker return multiScales ? arm_compute::DataType::QSYMM8_PER_CHANNEL : arm_compute::DataType::QSYMM8;
42*89c4ff92SAndroid Build Coastguard Worker }
43*89c4ff92SAndroid Build Coastguard Worker case armnn::DataType::Signed32:
44*89c4ff92SAndroid Build Coastguard Worker return arm_compute::DataType::S32;
45*89c4ff92SAndroid Build Coastguard Worker default:
46*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(false, "Unknown data type");
47*89c4ff92SAndroid Build Coastguard Worker return arm_compute::DataType::UNKNOWN;
48*89c4ff92SAndroid Build Coastguard Worker }
49*89c4ff92SAndroid Build Coastguard Worker }
50*89c4ff92SAndroid Build Coastguard Worker
GetArmNNDataType(arm_compute::DataType dataType)51*89c4ff92SAndroid Build Coastguard Worker armnn::DataType GetArmNNDataType(arm_compute::DataType dataType)
52*89c4ff92SAndroid Build Coastguard Worker {
53*89c4ff92SAndroid Build Coastguard Worker switch(dataType)
54*89c4ff92SAndroid Build Coastguard Worker {
55*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::BFLOAT16:
56*89c4ff92SAndroid Build Coastguard Worker return armnn::DataType::BFloat16;
57*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::U8:
58*89c4ff92SAndroid Build Coastguard Worker return armnn::DataType::Boolean;
59*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::F16:
60*89c4ff92SAndroid Build Coastguard Worker return armnn::DataType::Float16;
61*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::F32:
62*89c4ff92SAndroid Build Coastguard Worker return armnn::DataType::Float32;
63*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QASYMM8_SIGNED:
64*89c4ff92SAndroid Build Coastguard Worker return armnn::DataType::QAsymmS8;
65*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QASYMM8:
66*89c4ff92SAndroid Build Coastguard Worker return armnn::DataType::QAsymmU8;
67*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QSYMM16:
68*89c4ff92SAndroid Build Coastguard Worker return armnn::DataType::QSymmS16;
69*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::S64:
70*89c4ff92SAndroid Build Coastguard Worker return armnn::DataType::Signed64;
71*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QSYMM8_PER_CHANNEL:
72*89c4ff92SAndroid Build Coastguard Worker return armnn::DataType::QSymmS8;
73*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QSYMM8:
74*89c4ff92SAndroid Build Coastguard Worker return armnn::DataType::QSymmS8;
75*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::S32:
76*89c4ff92SAndroid Build Coastguard Worker return armnn::DataType::Signed32;
77*89c4ff92SAndroid Build Coastguard Worker default:
78*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(false, "Unknown data type");
79*89c4ff92SAndroid Build Coastguard Worker return armnn::DataType::Float32;
80*89c4ff92SAndroid Build Coastguard Worker }
81*89c4ff92SAndroid Build Coastguard Worker }
82*89c4ff92SAndroid Build Coastguard Worker
BuildArmComputeReductionCoordinates(size_t inputDimensions,unsigned int originalInputRank,const std::vector<unsigned int> & armnnAxes)83*89c4ff92SAndroid Build Coastguard Worker arm_compute::Coordinates BuildArmComputeReductionCoordinates(size_t inputDimensions,
84*89c4ff92SAndroid Build Coastguard Worker unsigned int originalInputRank,
85*89c4ff92SAndroid Build Coastguard Worker const std::vector<unsigned int>& armnnAxes)
86*89c4ff92SAndroid Build Coastguard Worker {
87*89c4ff92SAndroid Build Coastguard Worker arm_compute::Coordinates outAclCoords;
88*89c4ff92SAndroid Build Coastguard Worker
89*89c4ff92SAndroid Build Coastguard Worker if (armnnAxes.empty())
90*89c4ff92SAndroid Build Coastguard Worker {
91*89c4ff92SAndroid Build Coastguard Worker // If no reduction axes were provided, then the input must be reduced along all dimensions.
92*89c4ff92SAndroid Build Coastguard Worker // Since Compute Library does not accept an empty vector as the reduction dimensions, we then
93*89c4ff92SAndroid Build Coastguard Worker // manually create a vector including all the input dimensions (in reversed order) as:
94*89c4ff92SAndroid Build Coastguard Worker //
95*89c4ff92SAndroid Build Coastguard Worker // { inputDimensions - 1, inputDimensions - 2, ..., 1, 0 }
96*89c4ff92SAndroid Build Coastguard Worker //
97*89c4ff92SAndroid Build Coastguard Worker outAclCoords.set_num_dimensions(inputDimensions);
98*89c4ff92SAndroid Build Coastguard Worker std::generate(outAclCoords.begin(), outAclCoords.end(), [d = inputDimensions - 1] () mutable { return d--; });
99*89c4ff92SAndroid Build Coastguard Worker }
100*89c4ff92SAndroid Build Coastguard Worker else
101*89c4ff92SAndroid Build Coastguard Worker {
102*89c4ff92SAndroid Build Coastguard Worker // Create a vector of reduction dimensions (in reversed order) with the given reduction axes.
103*89c4ff92SAndroid Build Coastguard Worker //
104*89c4ff92SAndroid Build Coastguard Worker // Adjust the given reduction axes according to the original rank of the input tensor (before ACL applied any
105*89c4ff92SAndroid Build Coastguard Worker // dimension correction).
106*89c4ff92SAndroid Build Coastguard Worker // For example, if the input tensor originally had 4 dimensions, and one of the reduction axes was 2, then the
107*89c4ff92SAndroid Build Coastguard Worker // new value for that reduction axis should be 1.
108*89c4ff92SAndroid Build Coastguard Worker //
109*89c4ff92SAndroid Build Coastguard Worker // Example:
110*89c4ff92SAndroid Build Coastguard Worker // ArmNN input shape = { 1, 1, 3, 2 } -> ACL input shape = { 2, 3 }
111*89c4ff92SAndroid Build Coastguard Worker // ArmNN reduction axis = { 2 } -> ACL reduction axis = { 1 }
112*89c4ff92SAndroid Build Coastguard Worker // ArmNN reduction axis = { 3 } -> ACL reduction axis = { 0 }
113*89c4ff92SAndroid Build Coastguard Worker //
114*89c4ff92SAndroid Build Coastguard Worker // The transformation: ACL reduction axis index = original rank - ArmNN reduction axis index - 1
115*89c4ff92SAndroid Build Coastguard Worker //
116*89c4ff92SAndroid Build Coastguard Worker outAclCoords.set_num_dimensions(armnnAxes.size());
117*89c4ff92SAndroid Build Coastguard Worker std::transform(armnnAxes.begin(), armnnAxes.end(),
118*89c4ff92SAndroid Build Coastguard Worker outAclCoords.begin(),
119*89c4ff92SAndroid Build Coastguard Worker [originalInputRank](unsigned int i){ return originalInputRank - i - 1; });
120*89c4ff92SAndroid Build Coastguard Worker }
121*89c4ff92SAndroid Build Coastguard Worker
122*89c4ff92SAndroid Build Coastguard Worker return outAclCoords;
123*89c4ff92SAndroid Build Coastguard Worker }
124*89c4ff92SAndroid Build Coastguard Worker
BuildArmComputeTensorShape(const armnn::TensorShape & tensorShape)125*89c4ff92SAndroid Build Coastguard Worker arm_compute::TensorShape BuildArmComputeTensorShape(const armnn::TensorShape& tensorShape)
126*89c4ff92SAndroid Build Coastguard Worker {
127*89c4ff92SAndroid Build Coastguard Worker arm_compute::TensorShape shape;
128*89c4ff92SAndroid Build Coastguard Worker
129*89c4ff92SAndroid Build Coastguard Worker // armnn tensors are (batch, channels, height, width).
130*89c4ff92SAndroid Build Coastguard Worker // arm_compute tensors are (width, height, channels, batch).
131*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++)
132*89c4ff92SAndroid Build Coastguard Worker {
133*89c4ff92SAndroid Build Coastguard Worker // Note that our dimensions are stored in the opposite order to ACL's.
134*89c4ff92SAndroid Build Coastguard Worker shape.set(tensorShape.GetNumDimensions() - i - 1, tensorShape[i], false);
135*89c4ff92SAndroid Build Coastguard Worker
136*89c4ff92SAndroid Build Coastguard Worker // TensorShape::set() flattens leading ones, so that batch size 1 cannot happen.
137*89c4ff92SAndroid Build Coastguard Worker // arm_compute tensors expect this.
138*89c4ff92SAndroid Build Coastguard Worker }
139*89c4ff92SAndroid Build Coastguard Worker
140*89c4ff92SAndroid Build Coastguard Worker // prevent arm_compute issue where tensor is flattened to nothing
141*89c4ff92SAndroid Build Coastguard Worker if (shape.num_dimensions() == 0)
142*89c4ff92SAndroid Build Coastguard Worker {
143*89c4ff92SAndroid Build Coastguard Worker shape.set_num_dimensions(1);
144*89c4ff92SAndroid Build Coastguard Worker }
145*89c4ff92SAndroid Build Coastguard Worker
146*89c4ff92SAndroid Build Coastguard Worker return shape;
147*89c4ff92SAndroid Build Coastguard Worker }
148*89c4ff92SAndroid Build Coastguard Worker
ReduceDimsForACL(const armnn::TensorShape tensorShape,unsigned int dimensions)149*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> ReduceDimsForACL(const armnn::TensorShape tensorShape, unsigned int dimensions)
150*89c4ff92SAndroid Build Coastguard Worker {
151*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> newShape;
152*89c4ff92SAndroid Build Coastguard Worker
153*89c4ff92SAndroid Build Coastguard Worker unsigned int dimsToSkip = 0;
154*89c4ff92SAndroid Build Coastguard Worker
155*89c4ff92SAndroid Build Coastguard Worker if (tensorShape.GetNumDimensions() > dimensions)
156*89c4ff92SAndroid Build Coastguard Worker {
157*89c4ff92SAndroid Build Coastguard Worker dimsToSkip = tensorShape.GetNumDimensions() - dimensions;
158*89c4ff92SAndroid Build Coastguard Worker }
159*89c4ff92SAndroid Build Coastguard Worker unsigned int dimsSkipped = 0;
160*89c4ff92SAndroid Build Coastguard Worker bool insertRemainder = false;
161*89c4ff92SAndroid Build Coastguard Worker
162*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); ++i)
163*89c4ff92SAndroid Build Coastguard Worker {
164*89c4ff92SAndroid Build Coastguard Worker if (tensorShape[i] == 1 && dimsSkipped < dimsToSkip && !insertRemainder)
165*89c4ff92SAndroid Build Coastguard Worker {
166*89c4ff92SAndroid Build Coastguard Worker ++dimsSkipped;
167*89c4ff92SAndroid Build Coastguard Worker continue;
168*89c4ff92SAndroid Build Coastguard Worker }
169*89c4ff92SAndroid Build Coastguard Worker newShape.insert(newShape.begin(), tensorShape[i]);
170*89c4ff92SAndroid Build Coastguard Worker // Once we insert the first dimension we can't skip any more
171*89c4ff92SAndroid Build Coastguard Worker insertRemainder = true;
172*89c4ff92SAndroid Build Coastguard Worker }
173*89c4ff92SAndroid Build Coastguard Worker return newShape;
174*89c4ff92SAndroid Build Coastguard Worker }
175*89c4ff92SAndroid Build Coastguard Worker
BuildArmComputeTensorShape(const armnn::TensorShape & tensorShape,unsigned int dimensions)176*89c4ff92SAndroid Build Coastguard Worker arm_compute::TensorShape BuildArmComputeTensorShape(const armnn::TensorShape& tensorShape, unsigned int dimensions)
177*89c4ff92SAndroid Build Coastguard Worker {
178*89c4ff92SAndroid Build Coastguard Worker arm_compute::TensorShape shape;
179*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> strippedShape = ReduceDimsForACL(tensorShape, dimensions);
180*89c4ff92SAndroid Build Coastguard Worker
181*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < strippedShape.size(); i++)
182*89c4ff92SAndroid Build Coastguard Worker {
183*89c4ff92SAndroid Build Coastguard Worker shape.set(i, strippedShape[i], false);
184*89c4ff92SAndroid Build Coastguard Worker }
185*89c4ff92SAndroid Build Coastguard Worker
186*89c4ff92SAndroid Build Coastguard Worker // prevent arm_compute issue where tensor is flattened to nothing
187*89c4ff92SAndroid Build Coastguard Worker if (shape.num_dimensions() == 0)
188*89c4ff92SAndroid Build Coastguard Worker {
189*89c4ff92SAndroid Build Coastguard Worker shape.set_num_dimensions(1);
190*89c4ff92SAndroid Build Coastguard Worker }
191*89c4ff92SAndroid Build Coastguard Worker return shape;
192*89c4ff92SAndroid Build Coastguard Worker }
193*89c4ff92SAndroid Build Coastguard Worker
194*89c4ff92SAndroid Build Coastguard Worker // Utility function used to build a TensorInfo object, that can be used to initialise
195*89c4ff92SAndroid Build Coastguard Worker // ARM Compute Tensor and CLTensor allocators.
196*89c4ff92SAndroid Build Coastguard Worker // Note: this utility ignores the value of armnn::TensorInfo.IsConstant(). ACL tensors
197*89c4ff92SAndroid Build Coastguard Worker // default to constant but Arm NN ones default to non constant. In the cases where
198*89c4ff92SAndroid Build Coastguard Worker // we expect ACL to treat a tensor as constant that value must be set after this
199*89c4ff92SAndroid Build Coastguard Worker // utility has been called.
BuildArmComputeTensorInfo(const armnn::TensorInfo & tensorInfo)200*89c4ff92SAndroid Build Coastguard Worker arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo)
201*89c4ff92SAndroid Build Coastguard Worker {
202*89c4ff92SAndroid Build Coastguard Worker bool multiScales = tensorInfo.HasMultipleQuantizationScales();
203*89c4ff92SAndroid Build Coastguard Worker const arm_compute::TensorShape aclTensorShape = BuildArmComputeTensorShape(tensorInfo.GetShape());
204*89c4ff92SAndroid Build Coastguard Worker const arm_compute::DataType aclDataType = GetArmComputeDataType(tensorInfo.GetDataType(), multiScales);
205*89c4ff92SAndroid Build Coastguard Worker
206*89c4ff92SAndroid Build Coastguard Worker const arm_compute::QuantizationInfo aclQuantizationInfo = multiScales ?
207*89c4ff92SAndroid Build Coastguard Worker arm_compute::QuantizationInfo(tensorInfo.GetQuantizationScales()) :
208*89c4ff92SAndroid Build Coastguard Worker arm_compute::QuantizationInfo(tensorInfo.GetQuantizationScale(), tensorInfo.GetQuantizationOffset());
209*89c4ff92SAndroid Build Coastguard Worker
210*89c4ff92SAndroid Build Coastguard Worker return arm_compute::TensorInfo(aclTensorShape, 1, aclDataType, aclQuantizationInfo);
211*89c4ff92SAndroid Build Coastguard Worker }
212*89c4ff92SAndroid Build Coastguard Worker
BuildArmComputeTensorInfo(const armnn::TensorInfo & tensorInfo,armnn::DataLayout dataLayout)213*89c4ff92SAndroid Build Coastguard Worker arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo,
214*89c4ff92SAndroid Build Coastguard Worker armnn::DataLayout dataLayout)
215*89c4ff92SAndroid Build Coastguard Worker {
216*89c4ff92SAndroid Build Coastguard Worker arm_compute::TensorInfo aclTensorInfo = BuildArmComputeTensorInfo(tensorInfo);
217*89c4ff92SAndroid Build Coastguard Worker aclTensorInfo.set_data_layout(ConvertDataLayout(dataLayout));
218*89c4ff92SAndroid Build Coastguard Worker
219*89c4ff92SAndroid Build Coastguard Worker return aclTensorInfo;
220*89c4ff92SAndroid Build Coastguard Worker }
221*89c4ff92SAndroid Build Coastguard Worker
BuildArmComputeTensorInfo(const armnn::TensorInfo & tensorInfo,unsigned int dimensions)222*89c4ff92SAndroid Build Coastguard Worker arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo, unsigned int dimensions)
223*89c4ff92SAndroid Build Coastguard Worker {
224*89c4ff92SAndroid Build Coastguard Worker bool multiScales = tensorInfo.HasMultipleQuantizationScales();
225*89c4ff92SAndroid Build Coastguard Worker const arm_compute::TensorShape aclTensorShape = BuildArmComputeTensorShape(tensorInfo.GetShape(), dimensions);
226*89c4ff92SAndroid Build Coastguard Worker const arm_compute::DataType aclDataType = GetArmComputeDataType(tensorInfo.GetDataType(), multiScales);
227*89c4ff92SAndroid Build Coastguard Worker
228*89c4ff92SAndroid Build Coastguard Worker const arm_compute::QuantizationInfo aclQuantizationInfo = multiScales ?
229*89c4ff92SAndroid Build Coastguard Worker arm_compute::QuantizationInfo(tensorInfo.GetQuantizationScales()) :
230*89c4ff92SAndroid Build Coastguard Worker arm_compute::QuantizationInfo(tensorInfo.GetQuantizationScale(), tensorInfo.GetQuantizationOffset());
231*89c4ff92SAndroid Build Coastguard Worker
232*89c4ff92SAndroid Build Coastguard Worker return arm_compute::TensorInfo(aclTensorShape, 1, aclDataType, aclQuantizationInfo);
233*89c4ff92SAndroid Build Coastguard Worker }
BuildArmComputeTensorInfo(const armnn::TensorInfo & tensorInfo,armnn::DataLayout dataLayout,unsigned int dimensions)234*89c4ff92SAndroid Build Coastguard Worker arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo,
235*89c4ff92SAndroid Build Coastguard Worker armnn::DataLayout dataLayout, unsigned int dimensions)
236*89c4ff92SAndroid Build Coastguard Worker {
237*89c4ff92SAndroid Build Coastguard Worker arm_compute::TensorInfo aclTensorInfo = BuildArmComputeTensorInfo(tensorInfo, dimensions);
238*89c4ff92SAndroid Build Coastguard Worker aclTensorInfo.set_data_layout(ConvertDataLayout(dataLayout));
239*89c4ff92SAndroid Build Coastguard Worker
240*89c4ff92SAndroid Build Coastguard Worker return aclTensorInfo;
241*89c4ff92SAndroid Build Coastguard Worker }
242*89c4ff92SAndroid Build Coastguard Worker
243*89c4ff92SAndroid Build Coastguard Worker
ConvertDataLayout(armnn::DataLayout dataLayout)244*89c4ff92SAndroid Build Coastguard Worker arm_compute::DataLayout ConvertDataLayout(armnn::DataLayout dataLayout)
245*89c4ff92SAndroid Build Coastguard Worker {
246*89c4ff92SAndroid Build Coastguard Worker switch(dataLayout)
247*89c4ff92SAndroid Build Coastguard Worker {
248*89c4ff92SAndroid Build Coastguard Worker case armnn::DataLayout::NHWC : return arm_compute::DataLayout::NHWC;
249*89c4ff92SAndroid Build Coastguard Worker
250*89c4ff92SAndroid Build Coastguard Worker case armnn::DataLayout::NCHW : return arm_compute::DataLayout::NCHW;
251*89c4ff92SAndroid Build Coastguard Worker
252*89c4ff92SAndroid Build Coastguard Worker case armnn::DataLayout::NDHWC : return arm_compute::DataLayout::NDHWC;
253*89c4ff92SAndroid Build Coastguard Worker
254*89c4ff92SAndroid Build Coastguard Worker case armnn::DataLayout::NCDHW : return arm_compute::DataLayout::NCDHW;
255*89c4ff92SAndroid Build Coastguard Worker
256*89c4ff92SAndroid Build Coastguard Worker default: throw InvalidArgumentException("Unknown armnn::DataLayout: [" +
257*89c4ff92SAndroid Build Coastguard Worker std::to_string(static_cast<int>(dataLayout)) + "]");
258*89c4ff92SAndroid Build Coastguard Worker }
259*89c4ff92SAndroid Build Coastguard Worker }
260*89c4ff92SAndroid Build Coastguard Worker
BuildArmComputePoolingLayerInfo(const Pooling2dDescriptor & descriptor,bool fpMixedPrecision)261*89c4ff92SAndroid Build Coastguard Worker arm_compute::PoolingLayerInfo BuildArmComputePoolingLayerInfo(const Pooling2dDescriptor& descriptor,
262*89c4ff92SAndroid Build Coastguard Worker bool fpMixedPrecision)
263*89c4ff92SAndroid Build Coastguard Worker {
264*89c4ff92SAndroid Build Coastguard Worker // Resolve ARM Compute layer parameters.
265*89c4ff92SAndroid Build Coastguard Worker const arm_compute::PoolingType poolingType = ConvertPoolingAlgorithmToAclPoolingType(descriptor.m_PoolType);
266*89c4ff92SAndroid Build Coastguard Worker
267*89c4ff92SAndroid Build Coastguard Worker const arm_compute::DataLayout dataLayout = ConvertDataLayout(descriptor.m_DataLayout);
268*89c4ff92SAndroid Build Coastguard Worker
269*89c4ff92SAndroid Build Coastguard Worker bool isGlobalPooling = (descriptor.m_StrideX==0 && descriptor.m_StrideY==0);
270*89c4ff92SAndroid Build Coastguard Worker //use specific constructor if global pooling
271*89c4ff92SAndroid Build Coastguard Worker if(isGlobalPooling)
272*89c4ff92SAndroid Build Coastguard Worker {
273*89c4ff92SAndroid Build Coastguard Worker return arm_compute::PoolingLayerInfo(poolingType, dataLayout);
274*89c4ff92SAndroid Build Coastguard Worker }
275*89c4ff92SAndroid Build Coastguard Worker
276*89c4ff92SAndroid Build Coastguard Worker const arm_compute::DimensionRoundingType rounding = ConvertOutputShapeRoundingToAclDimensionRoundingType(
277*89c4ff92SAndroid Build Coastguard Worker descriptor.m_OutputShapeRounding);
278*89c4ff92SAndroid Build Coastguard Worker const arm_compute::PadStrideInfo padStrideInfo(descriptor.m_StrideX,
279*89c4ff92SAndroid Build Coastguard Worker descriptor.m_StrideY,
280*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadLeft,
281*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadRight,
282*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadTop,
283*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadBottom,
284*89c4ff92SAndroid Build Coastguard Worker rounding);
285*89c4ff92SAndroid Build Coastguard Worker
286*89c4ff92SAndroid Build Coastguard Worker const bool excludePadding = (descriptor.m_PaddingMethod == PaddingMethod::Exclude);
287*89c4ff92SAndroid Build Coastguard Worker
288*89c4ff92SAndroid Build Coastguard Worker const arm_compute::Size2D poolSize(descriptor.m_PoolWidth, descriptor.m_PoolHeight);
289*89c4ff92SAndroid Build Coastguard Worker
290*89c4ff92SAndroid Build Coastguard Worker return arm_compute::PoolingLayerInfo(poolingType, poolSize, dataLayout, padStrideInfo, excludePadding,
291*89c4ff92SAndroid Build Coastguard Worker fpMixedPrecision);
292*89c4ff92SAndroid Build Coastguard Worker }
293*89c4ff92SAndroid Build Coastguard Worker
BuildArmComputePooling3dLayerInfo(const Pooling3dDescriptor & descriptor,bool fpMixedPrecision)294*89c4ff92SAndroid Build Coastguard Worker arm_compute::Pooling3dLayerInfo BuildArmComputePooling3dLayerInfo(const Pooling3dDescriptor& descriptor,
295*89c4ff92SAndroid Build Coastguard Worker bool fpMixedPrecision)
296*89c4ff92SAndroid Build Coastguard Worker {
297*89c4ff92SAndroid Build Coastguard Worker const arm_compute::PoolingType poolingType = ConvertPoolingAlgorithmToAclPoolingType(descriptor.m_PoolType);
298*89c4ff92SAndroid Build Coastguard Worker
299*89c4ff92SAndroid Build Coastguard Worker bool isGlobalPooling = (descriptor.m_StrideX==0 && descriptor.m_StrideY==0 && descriptor.m_StrideZ==0);
300*89c4ff92SAndroid Build Coastguard Worker //use specific constructor if global pooling
301*89c4ff92SAndroid Build Coastguard Worker if(isGlobalPooling)
302*89c4ff92SAndroid Build Coastguard Worker {
303*89c4ff92SAndroid Build Coastguard Worker return arm_compute::Pooling3dLayerInfo(poolingType);
304*89c4ff92SAndroid Build Coastguard Worker }
305*89c4ff92SAndroid Build Coastguard Worker
306*89c4ff92SAndroid Build Coastguard Worker const arm_compute::Size3D poolSize(descriptor.m_PoolWidth, descriptor.m_PoolHeight, descriptor.m_PoolDepth);
307*89c4ff92SAndroid Build Coastguard Worker
308*89c4ff92SAndroid Build Coastguard Worker const arm_compute::Size3D stride(descriptor.m_StrideX,
309*89c4ff92SAndroid Build Coastguard Worker descriptor.m_StrideY,
310*89c4ff92SAndroid Build Coastguard Worker descriptor.m_StrideZ);
311*89c4ff92SAndroid Build Coastguard Worker
312*89c4ff92SAndroid Build Coastguard Worker const arm_compute::Padding3D padding(descriptor.m_PadLeft,
313*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadRight,
314*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadTop,
315*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadBottom,
316*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadFront,
317*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadBack);
318*89c4ff92SAndroid Build Coastguard Worker
319*89c4ff92SAndroid Build Coastguard Worker const bool excludePadding = (descriptor.m_PaddingMethod == PaddingMethod::Exclude);
320*89c4ff92SAndroid Build Coastguard Worker
321*89c4ff92SAndroid Build Coastguard Worker const arm_compute::DimensionRoundingType rounding = ConvertOutputShapeRoundingToAclDimensionRoundingType(
322*89c4ff92SAndroid Build Coastguard Worker descriptor.m_OutputShapeRounding);
323*89c4ff92SAndroid Build Coastguard Worker
324*89c4ff92SAndroid Build Coastguard Worker return arm_compute::Pooling3dLayerInfo(poolingType,
325*89c4ff92SAndroid Build Coastguard Worker poolSize,
326*89c4ff92SAndroid Build Coastguard Worker stride,
327*89c4ff92SAndroid Build Coastguard Worker padding,
328*89c4ff92SAndroid Build Coastguard Worker excludePadding,
329*89c4ff92SAndroid Build Coastguard Worker fpMixedPrecision,
330*89c4ff92SAndroid Build Coastguard Worker rounding);
331*89c4ff92SAndroid Build Coastguard Worker }
332*89c4ff92SAndroid Build Coastguard Worker
BuildArmComputeNormalizationLayerInfo(const NormalizationDescriptor & descriptor)333*89c4ff92SAndroid Build Coastguard Worker arm_compute::NormalizationLayerInfo BuildArmComputeNormalizationLayerInfo(const NormalizationDescriptor& descriptor)
334*89c4ff92SAndroid Build Coastguard Worker {
335*89c4ff92SAndroid Build Coastguard Worker const arm_compute::NormType normType =
336*89c4ff92SAndroid Build Coastguard Worker ConvertNormalizationAlgorithmChannelToAclNormType(descriptor.m_NormChannelType);
337*89c4ff92SAndroid Build Coastguard Worker return arm_compute::NormalizationLayerInfo(normType,
338*89c4ff92SAndroid Build Coastguard Worker descriptor.m_NormSize,
339*89c4ff92SAndroid Build Coastguard Worker descriptor.m_Alpha,
340*89c4ff92SAndroid Build Coastguard Worker descriptor.m_Beta,
341*89c4ff92SAndroid Build Coastguard Worker descriptor.m_K,
342*89c4ff92SAndroid Build Coastguard Worker false);
343*89c4ff92SAndroid Build Coastguard Worker }
344*89c4ff92SAndroid Build Coastguard Worker
BuildArmComputePermutationVector(const armnn::PermutationVector & perm)345*89c4ff92SAndroid Build Coastguard Worker arm_compute::PermutationVector BuildArmComputePermutationVector(const armnn::PermutationVector& perm)
346*89c4ff92SAndroid Build Coastguard Worker {
347*89c4ff92SAndroid Build Coastguard Worker arm_compute::PermutationVector aclPerm;
348*89c4ff92SAndroid Build Coastguard Worker
349*89c4ff92SAndroid Build Coastguard Worker unsigned int start = 0;
350*89c4ff92SAndroid Build Coastguard Worker while ((start < perm.GetSize()) && (start == perm[start]))
351*89c4ff92SAndroid Build Coastguard Worker {
352*89c4ff92SAndroid Build Coastguard Worker ++start;
353*89c4ff92SAndroid Build Coastguard Worker }
354*89c4ff92SAndroid Build Coastguard Worker
355*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = start; i < perm.GetSize(); ++i)
356*89c4ff92SAndroid Build Coastguard Worker {
357*89c4ff92SAndroid Build Coastguard Worker aclPerm.set(i - start, perm[i] - start);
358*89c4ff92SAndroid Build Coastguard Worker }
359*89c4ff92SAndroid Build Coastguard Worker return aclPerm;
360*89c4ff92SAndroid Build Coastguard Worker }
361*89c4ff92SAndroid Build Coastguard Worker
BuildArmComputeTransposeVector(const armnn::PermutationVector & perm)362*89c4ff92SAndroid Build Coastguard Worker arm_compute::PermutationVector BuildArmComputeTransposeVector(const armnn::PermutationVector& perm)
363*89c4ff92SAndroid Build Coastguard Worker {
364*89c4ff92SAndroid Build Coastguard Worker // As ArmNN indexes are left to right and ACL indexes are right to left,
365*89c4ff92SAndroid Build Coastguard Worker // the permutation vector has to be reversed and then translated into ACL axis.
366*89c4ff92SAndroid Build Coastguard Worker // i.e. {1, 0, 2, 3} --> {3, 2, 0, 1} --> {0, 1, 3, 2}
367*89c4ff92SAndroid Build Coastguard Worker
368*89c4ff92SAndroid Build Coastguard Worker // Below an example of how the ArmNN and ACL index format work:
369*89c4ff92SAndroid Build Coastguard Worker // ArmNN Format:
370*89c4ff92SAndroid Build Coastguard Worker // Input Shape {1, 10, 20, 30}
371*89c4ff92SAndroid Build Coastguard Worker // Permutation Vector {1, 0, 2, 3}
372*89c4ff92SAndroid Build Coastguard Worker // Output Shape {10, 1, 20, 30}
373*89c4ff92SAndroid Build Coastguard Worker // dim "1" of input goes into index 0 of the output ([ 10, X, X, X])
374*89c4ff92SAndroid Build Coastguard Worker // dim "0" of input goes into index 1 of the output ([ 10, 1, X, X ])
375*89c4ff92SAndroid Build Coastguard Worker // dim "2" of input goes into index 2 of the output ([ 10, 1, 20, X ])
376*89c4ff92SAndroid Build Coastguard Worker // dim "3" of input goes into index 3 of the output ([ 10, 1, 20, 30 ])
377*89c4ff92SAndroid Build Coastguard Worker // ACL Format:
378*89c4ff92SAndroid Build Coastguard Worker // Input Shape {30, 20, 10, 1}
379*89c4ff92SAndroid Build Coastguard Worker // Permutation Vector {0, 1, 3, 2}
380*89c4ff92SAndroid Build Coastguard Worker // Output Shape {30, 20, 1, 10}
381*89c4ff92SAndroid Build Coastguard Worker // dim "0" of input goes into index 0 of the output ([ 30, X, X, X])
382*89c4ff92SAndroid Build Coastguard Worker // dim "1" of input goes into index 1 of the output ([ 30, 20, X, X ])
383*89c4ff92SAndroid Build Coastguard Worker // dim "3" of input goes into index 2 of the output ([ 30, 20, 1, X ])
384*89c4ff92SAndroid Build Coastguard Worker // dim "2" of input goes into index 3 of the output ([ 30, 20, 1, 10 ])
385*89c4ff92SAndroid Build Coastguard Worker
386*89c4ff92SAndroid Build Coastguard Worker arm_compute::PermutationVector aclPerm;
387*89c4ff92SAndroid Build Coastguard Worker auto rank = perm.GetSize();
388*89c4ff92SAndroid Build Coastguard Worker
389*89c4ff92SAndroid Build Coastguard Worker // Reverse the order. i.e. {1, 0, 2, 3} --> {3, 2, 0, 1}
390*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> reversedPerm;
391*89c4ff92SAndroid Build Coastguard Worker reversedPerm.reserve(rank);
392*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = rank; i > 0; --i)
393*89c4ff92SAndroid Build Coastguard Worker {
394*89c4ff92SAndroid Build Coastguard Worker reversedPerm.push_back(perm[i-1]);
395*89c4ff92SAndroid Build Coastguard Worker }
396*89c4ff92SAndroid Build Coastguard Worker
397*89c4ff92SAndroid Build Coastguard Worker // Translate from Arm NN axis to ACL axis. i.e. {3, 2, 0, 1} --> {0, 1, 3, 2}
398*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < rank; ++i)
399*89c4ff92SAndroid Build Coastguard Worker {
400*89c4ff92SAndroid Build Coastguard Worker auto aclAxis = rank - 1 - reversedPerm[i];
401*89c4ff92SAndroid Build Coastguard Worker aclPerm.set(i, aclAxis);
402*89c4ff92SAndroid Build Coastguard Worker }
403*89c4ff92SAndroid Build Coastguard Worker return aclPerm;
404*89c4ff92SAndroid Build Coastguard Worker }
405*89c4ff92SAndroid Build Coastguard Worker
BuildArmComputeSize2D(const unsigned int width,const unsigned int height)406*89c4ff92SAndroid Build Coastguard Worker arm_compute::Size2D BuildArmComputeSize2D(const unsigned int width, const unsigned int height)
407*89c4ff92SAndroid Build Coastguard Worker {
408*89c4ff92SAndroid Build Coastguard Worker return arm_compute::Size2D(width, height);
409*89c4ff92SAndroid Build Coastguard Worker }
410*89c4ff92SAndroid Build Coastguard Worker
// Wraps a float in an arm_compute::PixelValue, casting it to the C++ type
// selected by the data type reported by tensorInfo.
//
// tensorInfo - tensor description; only data_type() is queried. It is
//              dereferenced without a null check.
// value      - the float to convert.
//
// Returns a PixelValue constructed from:
//   F16                                         -> Half
//   F32                                         -> float (value passed through)
//   QASYMM8                                     -> uint8_t
//   QSYMM16                                     -> int16_t
//   QSYMM8, QASYMM8_SIGNED, QSYMM8_PER_CHANNEL  -> int8_t
//   S32                                         -> int32_t
//
// Throws InvalidArgumentException when the F16 conversion yields an infinite
// half value, or when the data type is not one of those listed above.
arm_compute::PixelValue GetPixelValue(const arm_compute::ITensorInfo* tensorInfo, float value)
{
    using AclDataType = arm_compute::DataType;

    const AclDataType aclType = tensorInfo->data_type();

    // Half precision is handled up front because it is the only conversion
    // that is validated after the cast.
    if (aclType == AclDataType::F16)
    {
        const Half halfValue = static_cast<Half>(value);
        arm_compute::PixelValue halfPixel(halfValue);
        if (isinf(halfPixel.get<Half>()))
        {
            const std::string original  = std::to_string(value);
            const std::string converted = std::to_string(halfPixel.get<Half>());
            throw InvalidArgumentException("Under/Overflow converting float value [" + original +
                                           "] to fp16: [" + converted + "]");
        }
        return halfPixel;
    }

    switch (aclType)
    {
        case AclDataType::F32:
        {
            return arm_compute::PixelValue(value);
        }
        case AclDataType::QASYMM8:
        {
            const auto asUint8 = static_cast<uint8_t>(value);
            return arm_compute::PixelValue(asUint8);
        }
        case AclDataType::QSYMM16:
        {
            const auto asInt16 = static_cast<int16_t>(value);
            return arm_compute::PixelValue(asInt16);
        }
        case AclDataType::QSYMM8:
        case AclDataType::QASYMM8_SIGNED:
        case AclDataType::QSYMM8_PER_CHANNEL:
        {
            // All signed 8-bit quantized flavours share the same storage type.
            const auto asInt8 = static_cast<int8_t>(value);
            return arm_compute::PixelValue(asInt8);
        }
        case AclDataType::S32:
        {
            const auto asInt32 = static_cast<int32_t>(value);
            return arm_compute::PixelValue(asInt32);
        }
        default:
        {
            const std::string typeId = std::to_string(static_cast<int>(aclType));
            throw InvalidArgumentException("Unsupported DataType: [" + typeId + "]");
        }
    }
}
441*89c4ff92SAndroid Build Coastguard Worker
ComputeDepthwiseConv2dDepthMultiplier(armnn::DataLayout layout,const arm_compute::TensorShape & weightsShape,const arm_compute::TensorShape & inputShape)442*89c4ff92SAndroid Build Coastguard Worker unsigned int ComputeDepthwiseConv2dDepthMultiplier(armnn::DataLayout layout,
443*89c4ff92SAndroid Build Coastguard Worker const arm_compute::TensorShape& weightsShape,
444*89c4ff92SAndroid Build Coastguard Worker const arm_compute::TensorShape& inputShape)
445*89c4ff92SAndroid Build Coastguard Worker {
446*89c4ff92SAndroid Build Coastguard Worker unsigned int depthMultiplier;
447*89c4ff92SAndroid Build Coastguard Worker if (layout == armnn::DataLayout::NHWC)
448*89c4ff92SAndroid Build Coastguard Worker {
449*89c4ff92SAndroid Build Coastguard Worker depthMultiplier = static_cast<uint32_t>(weightsShape[0]) / static_cast<uint32_t>(inputShape[0]);
450*89c4ff92SAndroid Build Coastguard Worker }
451*89c4ff92SAndroid Build Coastguard Worker else if (layout == armnn::DataLayout::NCHW)
452*89c4ff92SAndroid Build Coastguard Worker {
453*89c4ff92SAndroid Build Coastguard Worker depthMultiplier = static_cast<uint32_t>(weightsShape[2]) / static_cast<uint32_t>(inputShape[2]);
454*89c4ff92SAndroid Build Coastguard Worker }
455*89c4ff92SAndroid Build Coastguard Worker else
456*89c4ff92SAndroid Build Coastguard Worker {
457*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(fmt::format("Unknown data layout for tensor conversion: {}",
458*89c4ff92SAndroid Build Coastguard Worker GetDataLayoutName(layout)));
459*89c4ff92SAndroid Build Coastguard Worker }
460*89c4ff92SAndroid Build Coastguard Worker return depthMultiplier;
461*89c4ff92SAndroid Build Coastguard Worker }
462*89c4ff92SAndroid Build Coastguard Worker
463*89c4ff92SAndroid Build Coastguard Worker } // namespace armcomputetensorutils
464*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn
465