1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017-2023 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/TensorUtils.hpp>
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/ITensorHandle.hpp>
9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Assert.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/NumericCast.hpp>
11*89c4ff92SAndroid Build Coastguard Worker
12*89c4ff92SAndroid Build Coastguard Worker #include <fmt/format.h>
13*89c4ff92SAndroid Build Coastguard Worker
14*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
15*89c4ff92SAndroid Build Coastguard Worker
16*89c4ff92SAndroid Build Coastguard Worker namespace armnnUtils
17*89c4ff92SAndroid Build Coastguard Worker {
18*89c4ff92SAndroid Build Coastguard Worker
GetTensorShape(unsigned int numberOfBatches,unsigned int numberOfChannels,unsigned int height,unsigned int width,const DataLayout dataLayout)19*89c4ff92SAndroid Build Coastguard Worker TensorShape GetTensorShape(unsigned int numberOfBatches,
20*89c4ff92SAndroid Build Coastguard Worker unsigned int numberOfChannels,
21*89c4ff92SAndroid Build Coastguard Worker unsigned int height,
22*89c4ff92SAndroid Build Coastguard Worker unsigned int width,
23*89c4ff92SAndroid Build Coastguard Worker const DataLayout dataLayout)
24*89c4ff92SAndroid Build Coastguard Worker {
25*89c4ff92SAndroid Build Coastguard Worker switch (dataLayout)
26*89c4ff92SAndroid Build Coastguard Worker {
27*89c4ff92SAndroid Build Coastguard Worker case DataLayout::NCHW:
28*89c4ff92SAndroid Build Coastguard Worker return TensorShape({numberOfBatches, numberOfChannels, height, width});
29*89c4ff92SAndroid Build Coastguard Worker case DataLayout::NHWC:
30*89c4ff92SAndroid Build Coastguard Worker return TensorShape({numberOfBatches, height, width, numberOfChannels});
31*89c4ff92SAndroid Build Coastguard Worker default:
32*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("Unknown data layout ["
33*89c4ff92SAndroid Build Coastguard Worker + std::to_string(static_cast<int>(dataLayout)) +
34*89c4ff92SAndroid Build Coastguard Worker "]", CHECK_LOCATION());
35*89c4ff92SAndroid Build Coastguard Worker }
36*89c4ff92SAndroid Build Coastguard Worker }
37*89c4ff92SAndroid Build Coastguard Worker
GetTensorInfo(unsigned int numberOfBatches,unsigned int numberOfChannels,unsigned int height,unsigned int width,const DataLayout dataLayout,const DataType dataType)38*89c4ff92SAndroid Build Coastguard Worker TensorInfo GetTensorInfo(unsigned int numberOfBatches,
39*89c4ff92SAndroid Build Coastguard Worker unsigned int numberOfChannels,
40*89c4ff92SAndroid Build Coastguard Worker unsigned int height,
41*89c4ff92SAndroid Build Coastguard Worker unsigned int width,
42*89c4ff92SAndroid Build Coastguard Worker const DataLayout dataLayout,
43*89c4ff92SAndroid Build Coastguard Worker const DataType dataType)
44*89c4ff92SAndroid Build Coastguard Worker {
45*89c4ff92SAndroid Build Coastguard Worker switch (dataLayout)
46*89c4ff92SAndroid Build Coastguard Worker {
47*89c4ff92SAndroid Build Coastguard Worker case DataLayout::NCHW:
48*89c4ff92SAndroid Build Coastguard Worker return TensorInfo({numberOfBatches, numberOfChannels, height, width}, dataType);
49*89c4ff92SAndroid Build Coastguard Worker case DataLayout::NHWC:
50*89c4ff92SAndroid Build Coastguard Worker return TensorInfo({numberOfBatches, height, width, numberOfChannels}, dataType);
51*89c4ff92SAndroid Build Coastguard Worker default:
52*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("Unknown data layout ["
53*89c4ff92SAndroid Build Coastguard Worker + std::to_string(static_cast<int>(dataLayout)) +
54*89c4ff92SAndroid Build Coastguard Worker "]", CHECK_LOCATION());
55*89c4ff92SAndroid Build Coastguard Worker }
56*89c4ff92SAndroid Build Coastguard Worker }
57*89c4ff92SAndroid Build Coastguard Worker
GetTensorInfo(unsigned int numberOfBatches,unsigned int numberOfChannels,unsigned int depth,unsigned int height,unsigned int width,const DataLayout dataLayout,const DataType dataType)58*89c4ff92SAndroid Build Coastguard Worker TensorInfo GetTensorInfo(unsigned int numberOfBatches,
59*89c4ff92SAndroid Build Coastguard Worker unsigned int numberOfChannels,
60*89c4ff92SAndroid Build Coastguard Worker unsigned int depth,
61*89c4ff92SAndroid Build Coastguard Worker unsigned int height,
62*89c4ff92SAndroid Build Coastguard Worker unsigned int width,
63*89c4ff92SAndroid Build Coastguard Worker const DataLayout dataLayout,
64*89c4ff92SAndroid Build Coastguard Worker const DataType dataType)
65*89c4ff92SAndroid Build Coastguard Worker {
66*89c4ff92SAndroid Build Coastguard Worker switch (dataLayout)
67*89c4ff92SAndroid Build Coastguard Worker {
68*89c4ff92SAndroid Build Coastguard Worker case DataLayout::NDHWC:
69*89c4ff92SAndroid Build Coastguard Worker return TensorInfo({numberOfBatches, depth, height, width, numberOfChannels}, dataType);
70*89c4ff92SAndroid Build Coastguard Worker case DataLayout::NCDHW:
71*89c4ff92SAndroid Build Coastguard Worker return TensorInfo({numberOfBatches, numberOfChannels, depth, height, width}, dataType);
72*89c4ff92SAndroid Build Coastguard Worker default:
73*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("Unknown data layout ["
74*89c4ff92SAndroid Build Coastguard Worker + std::to_string(static_cast<int>(dataLayout)) +
75*89c4ff92SAndroid Build Coastguard Worker "]", CHECK_LOCATION());
76*89c4ff92SAndroid Build Coastguard Worker }
77*89c4ff92SAndroid Build Coastguard Worker }
78*89c4ff92SAndroid Build Coastguard Worker
FindMinMax(ITensorHandle * tensorHandle)79*89c4ff92SAndroid Build Coastguard Worker std::pair<float, float> FindMinMax(ITensorHandle* tensorHandle)
80*89c4ff92SAndroid Build Coastguard Worker {
81*89c4ff92SAndroid Build Coastguard Worker auto tensor_data = static_cast<const float *>(tensorHandle->Map(true));
82*89c4ff92SAndroid Build Coastguard Worker auto tensor_size = tensorHandle->GetShape().GetNumElements();
83*89c4ff92SAndroid Build Coastguard Worker
84*89c4ff92SAndroid Build Coastguard Worker // Set min/max initially to first value in tensor
85*89c4ff92SAndroid Build Coastguard Worker float min = tensor_data[0];
86*89c4ff92SAndroid Build Coastguard Worker float max = tensor_data[0];
87*89c4ff92SAndroid Build Coastguard Worker
88*89c4ff92SAndroid Build Coastguard Worker // Loop over rest of tensor and update min/max if necessary
89*89c4ff92SAndroid Build Coastguard Worker for (unsigned int val = 1; val < tensor_size; val++)
90*89c4ff92SAndroid Build Coastguard Worker {
91*89c4ff92SAndroid Build Coastguard Worker if (tensor_data[val] < min)
92*89c4ff92SAndroid Build Coastguard Worker {
93*89c4ff92SAndroid Build Coastguard Worker min = tensor_data[val];
94*89c4ff92SAndroid Build Coastguard Worker }
95*89c4ff92SAndroid Build Coastguard Worker else if (tensor_data[val] > max)
96*89c4ff92SAndroid Build Coastguard Worker {
97*89c4ff92SAndroid Build Coastguard Worker max = tensor_data[val];
98*89c4ff92SAndroid Build Coastguard Worker }
99*89c4ff92SAndroid Build Coastguard Worker }
100*89c4ff92SAndroid Build Coastguard Worker
101*89c4ff92SAndroid Build Coastguard Worker tensorHandle->Unmap();
102*89c4ff92SAndroid Build Coastguard Worker
103*89c4ff92SAndroid Build Coastguard Worker return std::make_pair(min, max);
104*89c4ff92SAndroid Build Coastguard Worker }
105*89c4ff92SAndroid Build Coastguard Worker
/// Attempts to reduce @p tensorShape to @p dimensions dimensions by dropping leading
/// dimensions of size 1.
///
/// Only a leading run of 1s is removed, and at most (rank - dimensions) of them. The first
/// dimension that is kept ends the dropping. The result can therefore still have more than
/// @p dimensions dimensions. Shapes whose rank is already <= @p dimensions are returned unchanged.
TensorShape ReduceDims(const TensorShape& tensorShape, unsigned int dimensions)
{
    const unsigned int numDims = tensorShape.GetNumDimensions();
    if (dimensions >= numDims)
    {
        return tensorShape;
    }

    const unsigned int maxToSkip = numDims - dimensions;

    // Find the first index we keep: stop at the first non-1 size or once enough are dropped.
    unsigned int firstKept = 0;
    while (firstKept < maxToSkip && tensorShape[firstKept] == 1)
    {
        ++firstKept;
    }

    std::vector<unsigned int> keptSizes;
    keptSizes.reserve(numDims - firstKept);
    for (unsigned int axis = firstKept; axis < numDims; ++axis)
    {
        keptSizes.push_back(tensorShape[axis]);
    }

    return TensorShape(static_cast<unsigned int>(keptSizes.size()), keptSizes.data());
}
131*89c4ff92SAndroid Build Coastguard Worker
/// TensorInfo overload of ReduceDims: returns a copy of @p tensorInfo whose shape has been
/// passed through ReduceDims(TensorShape, dimensions); nothing but the shape is changed.
TensorInfo ReduceDims(const TensorInfo& tensorInfo, unsigned int dimensions)
{
    TensorInfo result = tensorInfo;
    result.SetShape(ReduceDims(tensorInfo.GetShape(), dimensions));
    return result;
}
139*89c4ff92SAndroid Build Coastguard Worker
/// Returns @p tensorShape with a new dimension of size 1 inserted at position @p axis.
///
/// @param tensorShape  Input shape of rank N.
/// @param axis         Position of the new dimension in the (N+1)-D output. Valid range is
///                     [-(N+1), N]; negative values count from the end (-1 appends).
/// @return             Shape of rank N+1.
/// @throws InvalidArgumentException if @p axis is outside the valid range.
TensorShape ExpandDims(const TensorShape& tensorShape, int axis)
{
    const unsigned int inputDim  = tensorShape.GetNumDimensions();
    const unsigned int outputDim = inputDim + 1;

    if (axis < -armnn::numeric_cast<int>(outputDim) || axis > armnn::numeric_cast<int>(inputDim))
    {
        throw InvalidArgumentException(fmt::format("Invalid expansion axis {} for {}D input tensor. {}",
                                                   axis,
                                                   inputDim,
                                                   CHECK_LOCATION().AsString()));
    }

    // Normalise a negative axis to its position in the output shape.
    if (axis < 0)
    {
        axis = armnn::numeric_cast<int>(outputDim) + axis;
    }

    std::vector<unsigned int> outputShape;
    // Reserve for the output rank (not the input rank) so the insert below never reallocates.
    outputShape.reserve(outputDim);
    for (unsigned int i = 0; i < inputDim; ++i)
    {
        outputShape.push_back(tensorShape[i]);
    }
    outputShape.insert(outputShape.begin() + axis, 1u);

    return TensorShape(outputDim, outputShape.data());
}
167*89c4ff92SAndroid Build Coastguard Worker
/// Left-pads @p tensorShape with dimensions of size 1 until it has @p rank dimensions.
/// Shapes that already have @p rank or more dimensions are returned unchanged.
TensorShape ExpandDimsToRank(const TensorShape& tensorShape, unsigned int rank)
{
    const unsigned int currentRank = tensorShape.GetNumDimensions();
    if (rank <= currentRank)
    {
        return tensorShape;
    }

    // Start with the required number of leading 1s, then append the original sizes in order.
    std::vector<unsigned int> paddedShape(rank - currentRank, 1u);
    for (unsigned int axis = 0; axis < currentRank; ++axis)
    {
        paddedShape.push_back(tensorShape[axis]);
    }

    return TensorShape(static_cast<unsigned int>(paddedShape.size()), paddedShape.data());
}
192*89c4ff92SAndroid Build Coastguard Worker
SqueezeDims(const TensorShape & tensorShape)193*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> SqueezeDims(const TensorShape& tensorShape)
194*89c4ff92SAndroid Build Coastguard Worker {
195*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> squeezedDims;
196*89c4ff92SAndroid Build Coastguard Worker
197*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); ++i)
198*89c4ff92SAndroid Build Coastguard Worker {
199*89c4ff92SAndroid Build Coastguard Worker if (tensorShape[i] != 1)
200*89c4ff92SAndroid Build Coastguard Worker {
201*89c4ff92SAndroid Build Coastguard Worker squeezedDims.push_back(tensorShape[i]);
202*89c4ff92SAndroid Build Coastguard Worker }
203*89c4ff92SAndroid Build Coastguard Worker }
204*89c4ff92SAndroid Build Coastguard Worker return squeezedDims;
205*89c4ff92SAndroid Build Coastguard Worker }
206*89c4ff92SAndroid Build Coastguard Worker
GetNumElementsBetween(const TensorShape & shape,const unsigned int firstAxisInclusive,const unsigned int lastAxisExclusive)207*89c4ff92SAndroid Build Coastguard Worker unsigned int GetNumElementsBetween(const TensorShape& shape,
208*89c4ff92SAndroid Build Coastguard Worker const unsigned int firstAxisInclusive,
209*89c4ff92SAndroid Build Coastguard Worker const unsigned int lastAxisExclusive)
210*89c4ff92SAndroid Build Coastguard Worker {
211*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(firstAxisInclusive <= lastAxisExclusive);
212*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(lastAxisExclusive <= shape.GetNumDimensions());
213*89c4ff92SAndroid Build Coastguard Worker unsigned int count = 1;
214*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = firstAxisInclusive; i < lastAxisExclusive; i++)
215*89c4ff92SAndroid Build Coastguard Worker {
216*89c4ff92SAndroid Build Coastguard Worker count *= shape[i];
217*89c4ff92SAndroid Build Coastguard Worker }
218*89c4ff92SAndroid Build Coastguard Worker return count;
219*89c4ff92SAndroid Build Coastguard Worker }
220*89c4ff92SAndroid Build Coastguard Worker
GetUnsignedAxis(const unsigned int inputDimension,const int axis)221*89c4ff92SAndroid Build Coastguard Worker unsigned int GetUnsignedAxis(const unsigned int inputDimension, const int axis)
222*89c4ff92SAndroid Build Coastguard Worker {
223*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(axis < armnn::numeric_cast<int>(inputDimension),
224*89c4ff92SAndroid Build Coastguard Worker "Required axis index greater than number of dimensions.");
225*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(axis >= -armnn::numeric_cast<int>(inputDimension),
226*89c4ff92SAndroid Build Coastguard Worker "Required axis index lower than negative of the number of dimensions");
227*89c4ff92SAndroid Build Coastguard Worker
228*89c4ff92SAndroid Build Coastguard Worker unsigned int uAxis = axis < 0 ?
229*89c4ff92SAndroid Build Coastguard Worker inputDimension - armnn::numeric_cast<unsigned int>(abs(axis))
230*89c4ff92SAndroid Build Coastguard Worker : armnn::numeric_cast<unsigned int>(axis);
231*89c4ff92SAndroid Build Coastguard Worker return uAxis;
232*89c4ff92SAndroid Build Coastguard Worker }
233*89c4ff92SAndroid Build Coastguard Worker
GetNumElementsAfter(const armnn::TensorShape & shape,unsigned int axis)234*89c4ff92SAndroid Build Coastguard Worker unsigned int GetNumElementsAfter(const armnn::TensorShape& shape, unsigned int axis)
235*89c4ff92SAndroid Build Coastguard Worker {
236*89c4ff92SAndroid Build Coastguard Worker unsigned int numDim = shape.GetNumDimensions();
237*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(axis <= numDim - 1);
238*89c4ff92SAndroid Build Coastguard Worker unsigned int count = 1;
239*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = axis+1; i < numDim; i++)
240*89c4ff92SAndroid Build Coastguard Worker {
241*89c4ff92SAndroid Build Coastguard Worker count *= shape[i];
242*89c4ff92SAndroid Build Coastguard Worker }
243*89c4ff92SAndroid Build Coastguard Worker return count;
244*89c4ff92SAndroid Build Coastguard Worker }
245*89c4ff92SAndroid Build Coastguard Worker
GetPerAxisParams(const armnn::TensorInfo & info)246*89c4ff92SAndroid Build Coastguard Worker std::pair<unsigned int, std::vector<float>> GetPerAxisParams(const armnn::TensorInfo& info)
247*89c4ff92SAndroid Build Coastguard Worker {
248*89c4ff92SAndroid Build Coastguard Worker const std::vector<float>& scales = info.GetQuantizationScales();
249*89c4ff92SAndroid Build Coastguard Worker armnn::Optional<unsigned int> quantizationDim = info.GetQuantizationDim();
250*89c4ff92SAndroid Build Coastguard Worker if (!info.HasPerAxisQuantization())
251*89c4ff92SAndroid Build Coastguard Worker {
252*89c4ff92SAndroid Build Coastguard Worker throw armnn::InvalidArgumentException(
253*89c4ff92SAndroid Build Coastguard Worker std::string("Per-axis quantization params not set for tensor of type ") +
254*89c4ff92SAndroid Build Coastguard Worker armnn::GetDataTypeName(info.GetDataType()), CHECK_LOCATION());
255*89c4ff92SAndroid Build Coastguard Worker }
256*89c4ff92SAndroid Build Coastguard Worker unsigned int axisFactor = GetNumElementsAfter(info.GetShape(), quantizationDim.value()) ;
257*89c4ff92SAndroid Build Coastguard Worker
258*89c4ff92SAndroid Build Coastguard Worker return { axisFactor, scales };
259*89c4ff92SAndroid Build Coastguard Worker }
260*89c4ff92SAndroid Build Coastguard Worker
261*89c4ff92SAndroid Build Coastguard Worker template<typename PrimitiveType>
CheckSizes(const std::vector<PrimitiveType> & data,const armnn::TensorInfo & tensorInfo,unsigned int size=1)262*89c4ff92SAndroid Build Coastguard Worker void CheckSizes(const std::vector<PrimitiveType>& data, const armnn::TensorInfo& tensorInfo, unsigned int size = 1)
263*89c4ff92SAndroid Build Coastguard Worker {
264*89c4ff92SAndroid Build Coastguard Worker if (data.size() / size != tensorInfo.GetNumElements())
265*89c4ff92SAndroid Build Coastguard Worker {
266*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(
267*89c4ff92SAndroid Build Coastguard Worker fmt::format("The data does not contain the expected number of elements {} != {}. {}",
268*89c4ff92SAndroid Build Coastguard Worker data.size(), tensorInfo.GetNumElements(), CHECK_LOCATION().AsString()));
269*89c4ff92SAndroid Build Coastguard Worker }
270*89c4ff92SAndroid Build Coastguard Worker }
271*89c4ff92SAndroid Build Coastguard Worker
272*89c4ff92SAndroid Build Coastguard Worker template<typename PrimitiveType>
ToFloatArray(const std::vector<PrimitiveType> & data,const armnn::TensorInfo & tensorInfo)273*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<float[]> ToFloatArray(const std::vector<PrimitiveType>& data, const armnn::TensorInfo& tensorInfo)
274*89c4ff92SAndroid Build Coastguard Worker {
275*89c4ff92SAndroid Build Coastguard Worker CheckSizes(data, tensorInfo);
276*89c4ff92SAndroid Build Coastguard Worker
277*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<float[]> returnBuffer(new float[tensorInfo.GetNumElements()]);
278*89c4ff92SAndroid Build Coastguard Worker
279*89c4ff92SAndroid Build Coastguard Worker if (tensorInfo.HasPerAxisQuantization())
280*89c4ff92SAndroid Build Coastguard Worker {
281*89c4ff92SAndroid Build Coastguard Worker unsigned int axis = tensorInfo.GetQuantizationDim().value();
282*89c4ff92SAndroid Build Coastguard Worker auto axisDimensionality = tensorInfo.GetShape()[axis];
283*89c4ff92SAndroid Build Coastguard Worker auto axisFactor = armnnUtils::GetNumElementsAfter(tensorInfo.GetShape(), axis);
284*89c4ff92SAndroid Build Coastguard Worker
285*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < tensorInfo.GetNumElements(); ++i)
286*89c4ff92SAndroid Build Coastguard Worker {
287*89c4ff92SAndroid Build Coastguard Worker unsigned int axisIndex;
288*89c4ff92SAndroid Build Coastguard Worker
289*89c4ff92SAndroid Build Coastguard Worker if (i < axisFactor)
290*89c4ff92SAndroid Build Coastguard Worker {
291*89c4ff92SAndroid Build Coastguard Worker axisIndex = 0;
292*89c4ff92SAndroid Build Coastguard Worker }
293*89c4ff92SAndroid Build Coastguard Worker else
294*89c4ff92SAndroid Build Coastguard Worker {
295*89c4ff92SAndroid Build Coastguard Worker axisIndex = (i / axisFactor) % axisDimensionality;
296*89c4ff92SAndroid Build Coastguard Worker }
297*89c4ff92SAndroid Build Coastguard Worker returnBuffer[i] = Dequantize<PrimitiveType>(data[i],
298*89c4ff92SAndroid Build Coastguard Worker tensorInfo.GetQuantizationScales()[axisIndex],
299*89c4ff92SAndroid Build Coastguard Worker tensorInfo.GetQuantizationOffset());
300*89c4ff92SAndroid Build Coastguard Worker }
301*89c4ff92SAndroid Build Coastguard Worker }
302*89c4ff92SAndroid Build Coastguard Worker else
303*89c4ff92SAndroid Build Coastguard Worker {
304*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < tensorInfo.GetNumElements(); ++i)
305*89c4ff92SAndroid Build Coastguard Worker {
306*89c4ff92SAndroid Build Coastguard Worker returnBuffer[i] = Dequantize<PrimitiveType>(data[i],
307*89c4ff92SAndroid Build Coastguard Worker tensorInfo.GetQuantizationScale(),
308*89c4ff92SAndroid Build Coastguard Worker tensorInfo.GetQuantizationOffset());
309*89c4ff92SAndroid Build Coastguard Worker }
310*89c4ff92SAndroid Build Coastguard Worker }
311*89c4ff92SAndroid Build Coastguard Worker return returnBuffer;
312*89c4ff92SAndroid Build Coastguard Worker }
313*89c4ff92SAndroid Build Coastguard Worker
ToFloatArray(const std::vector<uint8_t> & data,const armnn::TensorInfo & tensorInfo)314*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<float[]> ToFloatArray(const std::vector<uint8_t>& data, const armnn::TensorInfo& tensorInfo)
315*89c4ff92SAndroid Build Coastguard Worker {
316*89c4ff92SAndroid Build Coastguard Worker if (tensorInfo.GetDataType() == DataType::QAsymmS8 || tensorInfo.GetDataType() == DataType::QSymmS8)
317*89c4ff92SAndroid Build Coastguard Worker {
318*89c4ff92SAndroid Build Coastguard Worker CheckSizes(data, tensorInfo);
319*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> buffer(tensorInfo.GetNumElements());
320*89c4ff92SAndroid Build Coastguard Worker ::memcpy(buffer.data(), data.data(), data.size());
321*89c4ff92SAndroid Build Coastguard Worker return ToFloatArray<int8_t>(buffer, tensorInfo);
322*89c4ff92SAndroid Build Coastguard Worker }
323*89c4ff92SAndroid Build Coastguard Worker else if (tensorInfo.GetDataType() == DataType::QAsymmU8)
324*89c4ff92SAndroid Build Coastguard Worker {
325*89c4ff92SAndroid Build Coastguard Worker CheckSizes(data, tensorInfo);
326*89c4ff92SAndroid Build Coastguard Worker return ToFloatArray<uint8_t>(data, tensorInfo);
327*89c4ff92SAndroid Build Coastguard Worker }
328*89c4ff92SAndroid Build Coastguard Worker else if (tensorInfo.GetDataType() == DataType::Signed32)
329*89c4ff92SAndroid Build Coastguard Worker {
330*89c4ff92SAndroid Build Coastguard Worker CheckSizes(data, tensorInfo, 4);
331*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> buffer(tensorInfo.GetNumElements());
332*89c4ff92SAndroid Build Coastguard Worker ::memcpy(buffer.data(), data.data(), data.size());
333*89c4ff92SAndroid Build Coastguard Worker return ToFloatArray<int32_t>(buffer, tensorInfo);
334*89c4ff92SAndroid Build Coastguard Worker }
335*89c4ff92SAndroid Build Coastguard Worker else if (tensorInfo.GetDataType() == DataType::Signed64)
336*89c4ff92SAndroid Build Coastguard Worker {
337*89c4ff92SAndroid Build Coastguard Worker CheckSizes(data, tensorInfo, 8);
338*89c4ff92SAndroid Build Coastguard Worker std::vector<int64_t> buffer(tensorInfo.GetNumElements());
339*89c4ff92SAndroid Build Coastguard Worker ::memcpy(buffer.data(), data.data(), data.size());
340*89c4ff92SAndroid Build Coastguard Worker return ToFloatArray<int64_t>(buffer, tensorInfo);
341*89c4ff92SAndroid Build Coastguard Worker }
342*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(
343*89c4ff92SAndroid Build Coastguard Worker fmt::format("Unsupported datatype {}. {}",
344*89c4ff92SAndroid Build Coastguard Worker GetDataTypeName(tensorInfo.GetDataType()),
345*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION().AsString()));
346*89c4ff92SAndroid Build Coastguard Worker }
347*89c4ff92SAndroid Build Coastguard Worker
348*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnUtils
349