xref: /aosp_15_r20/external/armnn/src/armnnUtils/ParserHelper.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "ParserHelper.hpp"
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Descriptors.hpp>
9*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Permute.hpp>
10*89c4ff92SAndroid Build Coastguard Worker 
11*89c4ff92SAndroid Build Coastguard Worker #include <fmt/format.h>
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker namespace armnnUtils
14*89c4ff92SAndroid Build Coastguard Worker {
15*89c4ff92SAndroid Build Coastguard Worker 
16*89c4ff92SAndroid Build Coastguard Worker const armnn::PermutationVector NHWCToArmNN = { 0, 2, 3, 1 };
17*89c4ff92SAndroid Build Coastguard Worker const armnn::PermutationVector ArmNNToNHWC = { 0, 3, 1, 2 };
18*89c4ff92SAndroid Build Coastguard Worker 
ProcessConcatInputTensorInfo(armnn::TensorInfo & inputTensorInfo,armnn::OriginsDescriptor & concatDescriptor,const unsigned int & concatAxis,unsigned int inputIndex,unsigned int & mergeDimOrigin)19*89c4ff92SAndroid Build Coastguard Worker void ProcessConcatInputTensorInfo(armnn::TensorInfo& inputTensorInfo,
20*89c4ff92SAndroid Build Coastguard Worker                                   armnn::OriginsDescriptor& concatDescriptor,
21*89c4ff92SAndroid Build Coastguard Worker                                   const unsigned int& concatAxis,
22*89c4ff92SAndroid Build Coastguard Worker                                   unsigned int inputIndex,
23*89c4ff92SAndroid Build Coastguard Worker                                   unsigned int& mergeDimOrigin)
24*89c4ff92SAndroid Build Coastguard Worker {
25*89c4ff92SAndroid Build Coastguard Worker     const uint32_t inputRank = concatDescriptor.GetNumDimensions();
26*89c4ff92SAndroid Build Coastguard Worker 
27*89c4ff92SAndroid Build Coastguard Worker     // double check dimensions of the tensors
28*89c4ff92SAndroid Build Coastguard Worker     if (inputTensorInfo.GetNumDimensions() != inputRank)
29*89c4ff92SAndroid Build Coastguard Worker     {
30*89c4ff92SAndroid Build Coastguard Worker         throw armnn::ParseException(fmt::format(
31*89c4ff92SAndroid Build Coastguard Worker                                     "The number of dimensions: {0} for input tensors of the "
32*89c4ff92SAndroid Build Coastguard Worker                                     "concatenation op should be {1} {2}",
33*89c4ff92SAndroid Build Coastguard Worker                                     inputTensorInfo.GetNumDimensions(),
34*89c4ff92SAndroid Build Coastguard Worker                                     inputRank,
35*89c4ff92SAndroid Build Coastguard Worker                                     CHECK_LOCATION().AsString()));
36*89c4ff92SAndroid Build Coastguard Worker     }
37*89c4ff92SAndroid Build Coastguard Worker 
38*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int j = 0; j < concatAxis; ++j)
39*89c4ff92SAndroid Build Coastguard Worker     {
40*89c4ff92SAndroid Build Coastguard Worker         concatDescriptor.SetViewOriginCoord(inputIndex, j, 0);
41*89c4ff92SAndroid Build Coastguard Worker     }
42*89c4ff92SAndroid Build Coastguard Worker 
43*89c4ff92SAndroid Build Coastguard Worker     concatDescriptor.SetViewOriginCoord(inputIndex, concatAxis, mergeDimOrigin);
44*89c4ff92SAndroid Build Coastguard Worker     mergeDimOrigin += inputTensorInfo.GetShape()[concatAxis];
45*89c4ff92SAndroid Build Coastguard Worker 
46*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int j = concatAxis + 1; j < inputRank; ++j)
47*89c4ff92SAndroid Build Coastguard Worker     {
48*89c4ff92SAndroid Build Coastguard Worker         concatDescriptor.SetViewOriginCoord(inputIndex, j, 0);
49*89c4ff92SAndroid Build Coastguard Worker     }
50*89c4ff92SAndroid Build Coastguard Worker }
51*89c4ff92SAndroid Build Coastguard Worker 
CalculateReducedOutputTensoInfo(const armnn::TensorInfo & inputTensorInfo,const std::set<unsigned int> & axisSet,bool keepDims,armnn::TensorInfo & outputTensorInfo)52*89c4ff92SAndroid Build Coastguard Worker void CalculateReducedOutputTensoInfo(const armnn::TensorInfo& inputTensorInfo,
53*89c4ff92SAndroid Build Coastguard Worker                                      const std::set<unsigned int>& axisSet,
54*89c4ff92SAndroid Build Coastguard Worker                                      bool keepDims,
55*89c4ff92SAndroid Build Coastguard Worker                                      armnn::TensorInfo& outputTensorInfo)
56*89c4ff92SAndroid Build Coastguard Worker {
57*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> outputShapeVector;
58*89c4ff92SAndroid Build Coastguard Worker     bool dimensionFound = false;
59*89c4ff92SAndroid Build Coastguard Worker     unsigned int size = 1;
60*89c4ff92SAndroid Build Coastguard Worker 
61*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); ++i)
62*89c4ff92SAndroid Build Coastguard Worker     {
63*89c4ff92SAndroid Build Coastguard Worker         dimensionFound = false;
64*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int axis: axisSet)
65*89c4ff92SAndroid Build Coastguard Worker         {
66*89c4ff92SAndroid Build Coastguard Worker             if (axis == i)
67*89c4ff92SAndroid Build Coastguard Worker             {
68*89c4ff92SAndroid Build Coastguard Worker                 dimensionFound = true;
69*89c4ff92SAndroid Build Coastguard Worker                 break;
70*89c4ff92SAndroid Build Coastguard Worker             }
71*89c4ff92SAndroid Build Coastguard Worker         }
72*89c4ff92SAndroid Build Coastguard Worker 
73*89c4ff92SAndroid Build Coastguard Worker         if (!dimensionFound)
74*89c4ff92SAndroid Build Coastguard Worker         {
75*89c4ff92SAndroid Build Coastguard Worker             size *= inputTensorInfo.GetShape()[i];
76*89c4ff92SAndroid Build Coastguard Worker 
77*89c4ff92SAndroid Build Coastguard Worker             if (keepDims)
78*89c4ff92SAndroid Build Coastguard Worker             {
79*89c4ff92SAndroid Build Coastguard Worker                 outputShapeVector.push_back(inputTensorInfo.GetShape()[i]);
80*89c4ff92SAndroid Build Coastguard Worker             }
81*89c4ff92SAndroid Build Coastguard Worker         }
82*89c4ff92SAndroid Build Coastguard Worker         else
83*89c4ff92SAndroid Build Coastguard Worker         {
84*89c4ff92SAndroid Build Coastguard Worker             if (keepDims)
85*89c4ff92SAndroid Build Coastguard Worker             {
86*89c4ff92SAndroid Build Coastguard Worker                 outputShapeVector.push_back(1);
87*89c4ff92SAndroid Build Coastguard Worker             }
88*89c4ff92SAndroid Build Coastguard Worker         }
89*89c4ff92SAndroid Build Coastguard Worker     }
90*89c4ff92SAndroid Build Coastguard Worker 
91*89c4ff92SAndroid Build Coastguard Worker     if (keepDims)
92*89c4ff92SAndroid Build Coastguard Worker     {
93*89c4ff92SAndroid Build Coastguard Worker         armnn::TensorShape outputTensorShape(inputTensorInfo.GetNumDimensions(), &outputShapeVector[0]);
94*89c4ff92SAndroid Build Coastguard Worker         outputTensorInfo = armnn::TensorInfo(outputTensorShape, inputTensorInfo.GetDataType());
95*89c4ff92SAndroid Build Coastguard Worker     }
96*89c4ff92SAndroid Build Coastguard Worker     else
97*89c4ff92SAndroid Build Coastguard Worker     {
98*89c4ff92SAndroid Build Coastguard Worker         outputTensorInfo = armnn::TensorInfo({size}, inputTensorInfo.GetDataType());
99*89c4ff92SAndroid Build Coastguard Worker     }
100*89c4ff92SAndroid Build Coastguard Worker }
101*89c4ff92SAndroid Build Coastguard Worker 
102*89c4ff92SAndroid Build Coastguard Worker 
CalculateStridedSliceOutputTensorInfo(const armnn::TensorInfo & inputTensorInfo,const armnn::StridedSliceDescriptor & desc,armnn::TensorInfo & outputTensorInfo)103*89c4ff92SAndroid Build Coastguard Worker void CalculateStridedSliceOutputTensorInfo(const armnn::TensorInfo& inputTensorInfo,
104*89c4ff92SAndroid Build Coastguard Worker                                            const armnn::StridedSliceDescriptor& desc,
105*89c4ff92SAndroid Build Coastguard Worker                                            armnn::TensorInfo& outputTensorInfo)
106*89c4ff92SAndroid Build Coastguard Worker {
107*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorShape& inputShape = inputTensorInfo.GetShape();
108*89c4ff92SAndroid Build Coastguard Worker 
109*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> outputShapeVector;
110*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); i++)
111*89c4ff92SAndroid Build Coastguard Worker     {
112*89c4ff92SAndroid Build Coastguard Worker         if (desc.m_ShrinkAxisMask & (1 << i))
113*89c4ff92SAndroid Build Coastguard Worker         {
114*89c4ff92SAndroid Build Coastguard Worker             continue;
115*89c4ff92SAndroid Build Coastguard Worker         }
116*89c4ff92SAndroid Build Coastguard Worker 
117*89c4ff92SAndroid Build Coastguard Worker         int stride = desc.m_Stride[i];
118*89c4ff92SAndroid Build Coastguard Worker         int start = desc.GetStartForAxis(inputShape, i);
119*89c4ff92SAndroid Build Coastguard Worker         int stop = desc.GetStopForAxis(inputShape, i, start);
120*89c4ff92SAndroid Build Coastguard Worker 
121*89c4ff92SAndroid Build Coastguard Worker         int newSize = stride > 0 ? ((stop - start) + stride - 1) / stride :
122*89c4ff92SAndroid Build Coastguard Worker                       ((start - stop) - stride - 1) / -stride;
123*89c4ff92SAndroid Build Coastguard Worker 
124*89c4ff92SAndroid Build Coastguard Worker         newSize = std::max(0, newSize);
125*89c4ff92SAndroid Build Coastguard Worker 
126*89c4ff92SAndroid Build Coastguard Worker         outputShapeVector.push_back(static_cast<unsigned int>(newSize));
127*89c4ff92SAndroid Build Coastguard Worker     }
128*89c4ff92SAndroid Build Coastguard Worker 
129*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape outputTensorShape(inputTensorInfo.GetNumDimensions(), &outputShapeVector[0]);
130*89c4ff92SAndroid Build Coastguard Worker     outputTensorInfo = armnn::TensorInfo(armnn::TensorShape(outputTensorShape), inputTensorInfo.GetDataType());
131*89c4ff92SAndroid Build Coastguard Worker }
132*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnUtils
133