1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #include "ParserHelper.hpp"
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Descriptors.hpp>
9*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Permute.hpp>
10*89c4ff92SAndroid Build Coastguard Worker
11*89c4ff92SAndroid Build Coastguard Worker #include <fmt/format.h>
12*89c4ff92SAndroid Build Coastguard Worker
13*89c4ff92SAndroid Build Coastguard Worker namespace armnnUtils
14*89c4ff92SAndroid Build Coastguard Worker {
15*89c4ff92SAndroid Build Coastguard Worker
16*89c4ff92SAndroid Build Coastguard Worker const armnn::PermutationVector NHWCToArmNN = { 0, 2, 3, 1 };
17*89c4ff92SAndroid Build Coastguard Worker const armnn::PermutationVector ArmNNToNHWC = { 0, 3, 1, 2 };
18*89c4ff92SAndroid Build Coastguard Worker
ProcessConcatInputTensorInfo(armnn::TensorInfo & inputTensorInfo,armnn::OriginsDescriptor & concatDescriptor,const unsigned int & concatAxis,unsigned int inputIndex,unsigned int & mergeDimOrigin)19*89c4ff92SAndroid Build Coastguard Worker void ProcessConcatInputTensorInfo(armnn::TensorInfo& inputTensorInfo,
20*89c4ff92SAndroid Build Coastguard Worker armnn::OriginsDescriptor& concatDescriptor,
21*89c4ff92SAndroid Build Coastguard Worker const unsigned int& concatAxis,
22*89c4ff92SAndroid Build Coastguard Worker unsigned int inputIndex,
23*89c4ff92SAndroid Build Coastguard Worker unsigned int& mergeDimOrigin)
24*89c4ff92SAndroid Build Coastguard Worker {
25*89c4ff92SAndroid Build Coastguard Worker const uint32_t inputRank = concatDescriptor.GetNumDimensions();
26*89c4ff92SAndroid Build Coastguard Worker
27*89c4ff92SAndroid Build Coastguard Worker // double check dimensions of the tensors
28*89c4ff92SAndroid Build Coastguard Worker if (inputTensorInfo.GetNumDimensions() != inputRank)
29*89c4ff92SAndroid Build Coastguard Worker {
30*89c4ff92SAndroid Build Coastguard Worker throw armnn::ParseException(fmt::format(
31*89c4ff92SAndroid Build Coastguard Worker "The number of dimensions: {0} for input tensors of the "
32*89c4ff92SAndroid Build Coastguard Worker "concatenation op should be {1} {2}",
33*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo.GetNumDimensions(),
34*89c4ff92SAndroid Build Coastguard Worker inputRank,
35*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION().AsString()));
36*89c4ff92SAndroid Build Coastguard Worker }
37*89c4ff92SAndroid Build Coastguard Worker
38*89c4ff92SAndroid Build Coastguard Worker for (unsigned int j = 0; j < concatAxis; ++j)
39*89c4ff92SAndroid Build Coastguard Worker {
40*89c4ff92SAndroid Build Coastguard Worker concatDescriptor.SetViewOriginCoord(inputIndex, j, 0);
41*89c4ff92SAndroid Build Coastguard Worker }
42*89c4ff92SAndroid Build Coastguard Worker
43*89c4ff92SAndroid Build Coastguard Worker concatDescriptor.SetViewOriginCoord(inputIndex, concatAxis, mergeDimOrigin);
44*89c4ff92SAndroid Build Coastguard Worker mergeDimOrigin += inputTensorInfo.GetShape()[concatAxis];
45*89c4ff92SAndroid Build Coastguard Worker
46*89c4ff92SAndroid Build Coastguard Worker for (unsigned int j = concatAxis + 1; j < inputRank; ++j)
47*89c4ff92SAndroid Build Coastguard Worker {
48*89c4ff92SAndroid Build Coastguard Worker concatDescriptor.SetViewOriginCoord(inputIndex, j, 0);
49*89c4ff92SAndroid Build Coastguard Worker }
50*89c4ff92SAndroid Build Coastguard Worker }
51*89c4ff92SAndroid Build Coastguard Worker
CalculateReducedOutputTensoInfo(const armnn::TensorInfo & inputTensorInfo,const std::set<unsigned int> & axisSet,bool keepDims,armnn::TensorInfo & outputTensorInfo)52*89c4ff92SAndroid Build Coastguard Worker void CalculateReducedOutputTensoInfo(const armnn::TensorInfo& inputTensorInfo,
53*89c4ff92SAndroid Build Coastguard Worker const std::set<unsigned int>& axisSet,
54*89c4ff92SAndroid Build Coastguard Worker bool keepDims,
55*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo& outputTensorInfo)
56*89c4ff92SAndroid Build Coastguard Worker {
57*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> outputShapeVector;
58*89c4ff92SAndroid Build Coastguard Worker bool dimensionFound = false;
59*89c4ff92SAndroid Build Coastguard Worker unsigned int size = 1;
60*89c4ff92SAndroid Build Coastguard Worker
61*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); ++i)
62*89c4ff92SAndroid Build Coastguard Worker {
63*89c4ff92SAndroid Build Coastguard Worker dimensionFound = false;
64*89c4ff92SAndroid Build Coastguard Worker for (unsigned int axis: axisSet)
65*89c4ff92SAndroid Build Coastguard Worker {
66*89c4ff92SAndroid Build Coastguard Worker if (axis == i)
67*89c4ff92SAndroid Build Coastguard Worker {
68*89c4ff92SAndroid Build Coastguard Worker dimensionFound = true;
69*89c4ff92SAndroid Build Coastguard Worker break;
70*89c4ff92SAndroid Build Coastguard Worker }
71*89c4ff92SAndroid Build Coastguard Worker }
72*89c4ff92SAndroid Build Coastguard Worker
73*89c4ff92SAndroid Build Coastguard Worker if (!dimensionFound)
74*89c4ff92SAndroid Build Coastguard Worker {
75*89c4ff92SAndroid Build Coastguard Worker size *= inputTensorInfo.GetShape()[i];
76*89c4ff92SAndroid Build Coastguard Worker
77*89c4ff92SAndroid Build Coastguard Worker if (keepDims)
78*89c4ff92SAndroid Build Coastguard Worker {
79*89c4ff92SAndroid Build Coastguard Worker outputShapeVector.push_back(inputTensorInfo.GetShape()[i]);
80*89c4ff92SAndroid Build Coastguard Worker }
81*89c4ff92SAndroid Build Coastguard Worker }
82*89c4ff92SAndroid Build Coastguard Worker else
83*89c4ff92SAndroid Build Coastguard Worker {
84*89c4ff92SAndroid Build Coastguard Worker if (keepDims)
85*89c4ff92SAndroid Build Coastguard Worker {
86*89c4ff92SAndroid Build Coastguard Worker outputShapeVector.push_back(1);
87*89c4ff92SAndroid Build Coastguard Worker }
88*89c4ff92SAndroid Build Coastguard Worker }
89*89c4ff92SAndroid Build Coastguard Worker }
90*89c4ff92SAndroid Build Coastguard Worker
91*89c4ff92SAndroid Build Coastguard Worker if (keepDims)
92*89c4ff92SAndroid Build Coastguard Worker {
93*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape outputTensorShape(inputTensorInfo.GetNumDimensions(), &outputShapeVector[0]);
94*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo = armnn::TensorInfo(outputTensorShape, inputTensorInfo.GetDataType());
95*89c4ff92SAndroid Build Coastguard Worker }
96*89c4ff92SAndroid Build Coastguard Worker else
97*89c4ff92SAndroid Build Coastguard Worker {
98*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo = armnn::TensorInfo({size}, inputTensorInfo.GetDataType());
99*89c4ff92SAndroid Build Coastguard Worker }
100*89c4ff92SAndroid Build Coastguard Worker }
101*89c4ff92SAndroid Build Coastguard Worker
102*89c4ff92SAndroid Build Coastguard Worker
CalculateStridedSliceOutputTensorInfo(const armnn::TensorInfo & inputTensorInfo,const armnn::StridedSliceDescriptor & desc,armnn::TensorInfo & outputTensorInfo)103*89c4ff92SAndroid Build Coastguard Worker void CalculateStridedSliceOutputTensorInfo(const armnn::TensorInfo& inputTensorInfo,
104*89c4ff92SAndroid Build Coastguard Worker const armnn::StridedSliceDescriptor& desc,
105*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo& outputTensorInfo)
106*89c4ff92SAndroid Build Coastguard Worker {
107*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorShape& inputShape = inputTensorInfo.GetShape();
108*89c4ff92SAndroid Build Coastguard Worker
109*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> outputShapeVector;
110*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); i++)
111*89c4ff92SAndroid Build Coastguard Worker {
112*89c4ff92SAndroid Build Coastguard Worker if (desc.m_ShrinkAxisMask & (1 << i))
113*89c4ff92SAndroid Build Coastguard Worker {
114*89c4ff92SAndroid Build Coastguard Worker continue;
115*89c4ff92SAndroid Build Coastguard Worker }
116*89c4ff92SAndroid Build Coastguard Worker
117*89c4ff92SAndroid Build Coastguard Worker int stride = desc.m_Stride[i];
118*89c4ff92SAndroid Build Coastguard Worker int start = desc.GetStartForAxis(inputShape, i);
119*89c4ff92SAndroid Build Coastguard Worker int stop = desc.GetStopForAxis(inputShape, i, start);
120*89c4ff92SAndroid Build Coastguard Worker
121*89c4ff92SAndroid Build Coastguard Worker int newSize = stride > 0 ? ((stop - start) + stride - 1) / stride :
122*89c4ff92SAndroid Build Coastguard Worker ((start - stop) - stride - 1) / -stride;
123*89c4ff92SAndroid Build Coastguard Worker
124*89c4ff92SAndroid Build Coastguard Worker newSize = std::max(0, newSize);
125*89c4ff92SAndroid Build Coastguard Worker
126*89c4ff92SAndroid Build Coastguard Worker outputShapeVector.push_back(static_cast<unsigned int>(newSize));
127*89c4ff92SAndroid Build Coastguard Worker }
128*89c4ff92SAndroid Build Coastguard Worker
129*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape outputTensorShape(inputTensorInfo.GetNumDimensions(), &outputShapeVector[0]);
130*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo = armnn::TensorInfo(armnn::TensorShape(outputTensorShape), inputTensorInfo.GetDataType());
131*89c4ff92SAndroid Build Coastguard Worker }
132*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnUtils
133