xref: /aosp_15_r20/external/armnn/src/backends/backendsCommon/WorkloadUtils.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/ITensorHandle.hpp>
9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/TensorHandle.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Tensor.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/PolymorphicDowncast.hpp>
12*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Permute.hpp>
13*89c4ff92SAndroid Build Coastguard Worker 
14*89c4ff92SAndroid Build Coastguard Worker #include <Half.hpp>
15*89c4ff92SAndroid Build Coastguard Worker #include <Profiling.hpp>
16*89c4ff92SAndroid Build Coastguard Worker 
17*89c4ff92SAndroid Build Coastguard Worker 
18*89c4ff92SAndroid Build Coastguard Worker namespace armnn
19*89c4ff92SAndroid Build Coastguard Worker {
namespace
{

/// Terminal overload: assigns a single value taken from the back of \p array.
///
/// Elements are consumed from the last index towards the first, so the first target receives
/// array[num - 1], the next one array[num - 2], and so on. Once \p idx has reached \p num the
/// call does nothing and \p arg keeps whatever value the caller initialised it with.
///
/// \param num   - Number of usable elements in array
/// \param idx   - Count of elements consumed so far; incremented when a value is assigned
/// \param array - Indexable source of values (anything providing operator[])
/// \param arg   - Destination for the value
template <typename ArrayType, typename Arg>
void AssignValues(unsigned int num, unsigned int& idx, const ArrayType& array, Arg& arg)
{
    if (idx < num)
    {
        arg = array[(num - 1) - idx];
        ++idx;
    }
}

/// Variadic overload: assigns to \p assignee and then to each of \p args in turn, walking
/// \p array from its last element backwards. Targets beyond the first \p num are left untouched.
///
/// \p idx is deliberately taken by value here: it lets callers pass a literal start index, and
/// the local copy is what the terminal overload above advances (through its reference
/// parameter) before it is forwarded to the next recursion step.
///
/// \param num      - Number of usable elements in array
/// \param idx      - Count of elements already consumed (normally 0 at the outermost call)
/// \param array    - Indexable source of values
/// \param assignee - First destination
/// \param args     - Remaining destinations, in the order they should be filled
template <typename T, typename ArrayType, typename... Args>
void AssignValues(unsigned int num, unsigned int idx, const ArrayType& array, T& assignee, Args&... args)
{
    // Four-argument call with an lvalue index: resolves to the terminal overload, which bumps idx.
    AssignValues(num, idx, array, assignee);

    // Hand the advanced index on to the remaining destinations.
    AssignValues(num, idx, array, args...);
}

}    // anonymous namespace
44*89c4ff92SAndroid Build Coastguard Worker 
45*89c4ff92SAndroid Build Coastguard Worker template <typename CopyFunc>
CopyTensorContentsGeneric(const ITensorHandle * srcTensor,ITensorHandle * dstTensor,CopyFunc copy)46*89c4ff92SAndroid Build Coastguard Worker void CopyTensorContentsGeneric(const ITensorHandle* srcTensor, ITensorHandle* dstTensor, CopyFunc copy)
47*89c4ff92SAndroid Build Coastguard Worker {
48*89c4ff92SAndroid Build Coastguard Worker     // For ease of understanding, names are assigned to the dimensions
49*89c4ff92SAndroid Build Coastguard Worker     // of the tensor as if NHWC, however this routine works with any 5D tensor
50*89c4ff92SAndroid Build Coastguard Worker     static_assert(MaxNumOfTensorDimensions == 5, "Please update CopyTensorContents");
51*89c4ff92SAndroid Build Coastguard Worker 
52*89c4ff92SAndroid Build Coastguard Worker     TensorShape srcStrides      = srcTensor->GetStrides();
53*89c4ff92SAndroid Build Coastguard Worker     const TensorShape& srcShape = srcTensor->GetShape();
54*89c4ff92SAndroid Build Coastguard Worker     const auto srcSize          = srcTensor->GetStrides()[0] * srcShape[0];
55*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(srcSize);  // Only used for asserts
56*89c4ff92SAndroid Build Coastguard Worker     TensorShape dstStrides      = dstTensor->GetStrides();
57*89c4ff92SAndroid Build Coastguard Worker     const TensorShape& dstShape = dstTensor->GetShape();
58*89c4ff92SAndroid Build Coastguard Worker     const auto dstSize          = dstTensor->GetStrides()[0] * dstShape[0];
59*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(dstSize);  // Only used for asserts
60*89c4ff92SAndroid Build Coastguard Worker 
61*89c4ff92SAndroid Build Coastguard Worker     size_t srcDepth    = 1;
62*89c4ff92SAndroid Build Coastguard Worker     size_t srcBatches  = 1;
63*89c4ff92SAndroid Build Coastguard Worker     size_t srcHeight   = 1;
64*89c4ff92SAndroid Build Coastguard Worker     size_t srcWidth    = 1;
65*89c4ff92SAndroid Build Coastguard Worker     size_t srcChannels = 1;
66*89c4ff92SAndroid Build Coastguard Worker     AssignValues(srcShape.GetNumDimensions(),
67*89c4ff92SAndroid Build Coastguard Worker                  0,
68*89c4ff92SAndroid Build Coastguard Worker                  srcShape,
69*89c4ff92SAndroid Build Coastguard Worker                  srcChannels,
70*89c4ff92SAndroid Build Coastguard Worker                  srcWidth,
71*89c4ff92SAndroid Build Coastguard Worker                  srcHeight,
72*89c4ff92SAndroid Build Coastguard Worker                  srcBatches,
73*89c4ff92SAndroid Build Coastguard Worker                  srcDepth);
74*89c4ff92SAndroid Build Coastguard Worker 
75*89c4ff92SAndroid Build Coastguard Worker     size_t srcDepthStride   = 0;
76*89c4ff92SAndroid Build Coastguard Worker     size_t srcBatchStride   = 0;
77*89c4ff92SAndroid Build Coastguard Worker     size_t srcHeightStride  = 0;
78*89c4ff92SAndroid Build Coastguard Worker     size_t srcWidthStride   = 0;
79*89c4ff92SAndroid Build Coastguard Worker     size_t srcChannelStride = 0;
80*89c4ff92SAndroid Build Coastguard Worker     AssignValues(srcStrides.GetNumDimensions(),
81*89c4ff92SAndroid Build Coastguard Worker                  0,
82*89c4ff92SAndroid Build Coastguard Worker                  srcStrides,
83*89c4ff92SAndroid Build Coastguard Worker                  srcChannelStride,
84*89c4ff92SAndroid Build Coastguard Worker                  srcWidthStride,
85*89c4ff92SAndroid Build Coastguard Worker                  srcHeightStride,
86*89c4ff92SAndroid Build Coastguard Worker                  srcBatchStride,
87*89c4ff92SAndroid Build Coastguard Worker                  srcDepthStride);
88*89c4ff92SAndroid Build Coastguard Worker 
89*89c4ff92SAndroid Build Coastguard Worker     size_t dstDepth    = 1;
90*89c4ff92SAndroid Build Coastguard Worker     size_t dstBatches  = 1;
91*89c4ff92SAndroid Build Coastguard Worker     size_t dstHeight   = 1;
92*89c4ff92SAndroid Build Coastguard Worker     size_t dstWidth    = 1;
93*89c4ff92SAndroid Build Coastguard Worker     size_t dstChannels = 1;
94*89c4ff92SAndroid Build Coastguard Worker     AssignValues(dstShape.GetNumDimensions(),
95*89c4ff92SAndroid Build Coastguard Worker                  0,
96*89c4ff92SAndroid Build Coastguard Worker                  dstShape,
97*89c4ff92SAndroid Build Coastguard Worker                  dstChannels,
98*89c4ff92SAndroid Build Coastguard Worker                  dstWidth,
99*89c4ff92SAndroid Build Coastguard Worker                  dstHeight,
100*89c4ff92SAndroid Build Coastguard Worker                  dstBatches,
101*89c4ff92SAndroid Build Coastguard Worker                  dstDepth);
102*89c4ff92SAndroid Build Coastguard Worker 
103*89c4ff92SAndroid Build Coastguard Worker     size_t dstDepthStride   = 0;
104*89c4ff92SAndroid Build Coastguard Worker     size_t dstBatchStride   = 0;
105*89c4ff92SAndroid Build Coastguard Worker     size_t dstHeightStride  = 0;
106*89c4ff92SAndroid Build Coastguard Worker     size_t dstWidthStride   = 0;
107*89c4ff92SAndroid Build Coastguard Worker     size_t dstChannelStride = 0;
108*89c4ff92SAndroid Build Coastguard Worker     AssignValues(dstStrides.GetNumDimensions(),
109*89c4ff92SAndroid Build Coastguard Worker                  0,
110*89c4ff92SAndroid Build Coastguard Worker                  dstStrides,
111*89c4ff92SAndroid Build Coastguard Worker                  dstChannelStride,
112*89c4ff92SAndroid Build Coastguard Worker                  dstWidthStride,
113*89c4ff92SAndroid Build Coastguard Worker                  dstHeightStride,
114*89c4ff92SAndroid Build Coastguard Worker                  dstBatchStride,
115*89c4ff92SAndroid Build Coastguard Worker                  dstDepthStride);
116*89c4ff92SAndroid Build Coastguard Worker 
117*89c4ff92SAndroid Build Coastguard Worker     const unsigned char* srcDataStart;
118*89c4ff92SAndroid Build Coastguard Worker     unsigned char* dstDataStart;
119*89c4ff92SAndroid Build Coastguard Worker     {
120*89c4ff92SAndroid Build Coastguard Worker         ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "Synchronize buffers");
121*89c4ff92SAndroid Build Coastguard Worker         srcDataStart = static_cast<const uint8_t*>(srcTensor->Map());
122*89c4ff92SAndroid Build Coastguard Worker         dstDataStart = static_cast<uint8_t*>(dstTensor->Map());
123*89c4ff92SAndroid Build Coastguard Worker     }
124*89c4ff92SAndroid Build Coastguard Worker 
125*89c4ff92SAndroid Build Coastguard Worker     size_t copyLength  = std::min(srcChannels * srcChannelStride, dstChannels * dstChannelStride);
126*89c4ff92SAndroid Build Coastguard Worker     size_t copyWidth   = std::min(srcWidth, dstWidth);
127*89c4ff92SAndroid Build Coastguard Worker     size_t copyHeight  = std::min(srcHeight, dstHeight);
128*89c4ff92SAndroid Build Coastguard Worker     size_t copyBatches = std::min(srcBatches, dstBatches);
129*89c4ff92SAndroid Build Coastguard Worker     size_t copyDepth   = std::min(srcDepth, dstDepth);
130*89c4ff92SAndroid Build Coastguard Worker 
131*89c4ff92SAndroid Build Coastguard Worker     // Coalesce inner dimensions where possible
132*89c4ff92SAndroid Build Coastguard Worker     // to reduce overheard calling copy() and to
133*89c4ff92SAndroid Build Coastguard Worker     // allow for memory bandwidth optimisations
134*89c4ff92SAndroid Build Coastguard Worker     if (copyLength == srcWidthStride &&
135*89c4ff92SAndroid Build Coastguard Worker         copyLength == dstWidthStride)
136*89c4ff92SAndroid Build Coastguard Worker     {
137*89c4ff92SAndroid Build Coastguard Worker         // There is no special padding between rows,
138*89c4ff92SAndroid Build Coastguard Worker         // and sizes are compatible, so copy whole rows
139*89c4ff92SAndroid Build Coastguard Worker         copyLength *= copyWidth;
140*89c4ff92SAndroid Build Coastguard Worker         copyWidth = 1;
141*89c4ff92SAndroid Build Coastguard Worker 
142*89c4ff92SAndroid Build Coastguard Worker         if (copyLength == srcHeightStride &&
143*89c4ff92SAndroid Build Coastguard Worker             copyLength == dstHeightStride)
144*89c4ff92SAndroid Build Coastguard Worker         {
145*89c4ff92SAndroid Build Coastguard Worker             // There is no special padding between batches
146*89c4ff92SAndroid Build Coastguard Worker             // and sizes are compatible so copy whole batches
147*89c4ff92SAndroid Build Coastguard Worker             copyLength *= copyHeight;
148*89c4ff92SAndroid Build Coastguard Worker             copyHeight = 1;
149*89c4ff92SAndroid Build Coastguard Worker         }
150*89c4ff92SAndroid Build Coastguard Worker     }
151*89c4ff92SAndroid Build Coastguard Worker 
152*89c4ff92SAndroid Build Coastguard Worker     const unsigned char* srcData = srcDataStart;
153*89c4ff92SAndroid Build Coastguard Worker     unsigned char* dstData = dstDataStart;
154*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int d = 0; d < copyDepth; ++d)
155*89c4ff92SAndroid Build Coastguard Worker     {
156*89c4ff92SAndroid Build Coastguard Worker         auto srcPtrDepth = srcData;
157*89c4ff92SAndroid Build Coastguard Worker         auto dstPtrDepth = dstData;
158*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int b = 0; b < copyBatches; ++b)
159*89c4ff92SAndroid Build Coastguard Worker         {
160*89c4ff92SAndroid Build Coastguard Worker             auto srcPtrBatch = srcData;
161*89c4ff92SAndroid Build Coastguard Worker             auto dstPtrBatch = dstData;
162*89c4ff92SAndroid Build Coastguard Worker             for (unsigned int h = 0; h < copyHeight; ++h)
163*89c4ff92SAndroid Build Coastguard Worker             {
164*89c4ff92SAndroid Build Coastguard Worker                 auto srcPtrChannel = srcData;
165*89c4ff92SAndroid Build Coastguard Worker                 auto dstPtrChannel = dstData;
166*89c4ff92SAndroid Build Coastguard Worker                 for (unsigned int w = 0; w < copyWidth; ++w)
167*89c4ff92SAndroid Build Coastguard Worker                 {
168*89c4ff92SAndroid Build Coastguard Worker                     ARMNN_ASSERT(srcData >= srcDataStart && srcData + copyLength <= srcDataStart + srcSize);
169*89c4ff92SAndroid Build Coastguard Worker                     ARMNN_ASSERT(dstData >= dstDataStart && dstData + copyLength <= dstDataStart + dstSize);
170*89c4ff92SAndroid Build Coastguard Worker                     copy(dstData, srcData, copyLength);
171*89c4ff92SAndroid Build Coastguard Worker                     dstData += dstWidthStride;
172*89c4ff92SAndroid Build Coastguard Worker                     srcData += srcWidthStride;
173*89c4ff92SAndroid Build Coastguard Worker                 }
174*89c4ff92SAndroid Build Coastguard Worker                 dstData += (static_cast<long>(dstHeightStride) - (dstData - dstPtrChannel));
175*89c4ff92SAndroid Build Coastguard Worker                 srcData += (static_cast<long>(srcHeightStride) - (srcData - srcPtrChannel));
176*89c4ff92SAndroid Build Coastguard Worker             }
177*89c4ff92SAndroid Build Coastguard Worker             dstData += (static_cast<long>(dstBatchStride) - (dstData - dstPtrBatch));
178*89c4ff92SAndroid Build Coastguard Worker             srcData += (static_cast<long>(srcBatchStride) - (srcData - srcPtrBatch));
179*89c4ff92SAndroid Build Coastguard Worker         }
180*89c4ff92SAndroid Build Coastguard Worker         dstData += (static_cast<long>(dstDepthStride) - (dstData - dstPtrDepth));
181*89c4ff92SAndroid Build Coastguard Worker         srcData += (static_cast<long>(srcDepthStride) - (srcData - srcPtrDepth));
182*89c4ff92SAndroid Build Coastguard Worker     }
183*89c4ff92SAndroid Build Coastguard Worker 
184*89c4ff92SAndroid Build Coastguard Worker     srcTensor->Unmap();
185*89c4ff92SAndroid Build Coastguard Worker     dstTensor->Unmap();
186*89c4ff92SAndroid Build Coastguard Worker }
187*89c4ff92SAndroid Build Coastguard Worker 
188*89c4ff92SAndroid Build Coastguard Worker template <typename SrcTensorHandleType, typename DstTensorHandleType, typename DescriptorType>
GatherTensorHandlePairs(const DescriptorType & descriptor,std::vector<std::pair<SrcTensorHandleType *,DstTensorHandleType * >> & tensorHandlePairs)189*89c4ff92SAndroid Build Coastguard Worker void GatherTensorHandlePairs(const DescriptorType& descriptor,
190*89c4ff92SAndroid Build Coastguard Worker                              std::vector<std::pair<SrcTensorHandleType*, DstTensorHandleType*>>& tensorHandlePairs)
191*89c4ff92SAndroid Build Coastguard Worker {
192*89c4ff92SAndroid Build Coastguard Worker     const unsigned int numInputs = static_cast<unsigned int>(descriptor.m_Inputs.size());
193*89c4ff92SAndroid Build Coastguard Worker     tensorHandlePairs.reserve(numInputs);
194*89c4ff92SAndroid Build Coastguard Worker 
195*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < numInputs; ++i)
196*89c4ff92SAndroid Build Coastguard Worker     {
197*89c4ff92SAndroid Build Coastguard Worker         SrcTensorHandleType* const srcTensorHandle =
198*89c4ff92SAndroid Build Coastguard Worker             PolymorphicDowncast<SrcTensorHandleType*>(descriptor.m_Inputs[i]);
199*89c4ff92SAndroid Build Coastguard Worker         DstTensorHandleType* const dstTensorHandle =
200*89c4ff92SAndroid Build Coastguard Worker             PolymorphicDowncast<DstTensorHandleType*>(descriptor.m_Outputs[i]);
201*89c4ff92SAndroid Build Coastguard Worker 
202*89c4ff92SAndroid Build Coastguard Worker         tensorHandlePairs.emplace_back(srcTensorHandle, dstTensorHandle);
203*89c4ff92SAndroid Build Coastguard Worker     }
204*89c4ff92SAndroid Build Coastguard Worker }
205*89c4ff92SAndroid Build Coastguard Worker 
206*89c4ff92SAndroid Build Coastguard Worker int32_t ConvertMaskToACLFormat(int32_t mask, int32_t numDim);
207*89c4ff92SAndroid Build Coastguard Worker 
208*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor PermuteTensor(const ConstTensorHandle* tensor,
209*89c4ff92SAndroid Build Coastguard Worker                                  const PermutationVector& permutationVector,
210*89c4ff92SAndroid Build Coastguard Worker                                  void* permuteBuffer);
211*89c4ff92SAndroid Build Coastguard Worker 
212*89c4ff92SAndroid Build Coastguard Worker void ReshapeWeightsForAcl(TensorInfo& weightInfo, DataLayout dataLayout);
213*89c4ff92SAndroid Build Coastguard Worker 
214*89c4ff92SAndroid Build Coastguard Worker TensorInfo ConvertWeightTensorInfoFromArmnnToAcl(const TensorInfo& weightInfo, DataLayout dataLayout);
215*89c4ff92SAndroid Build Coastguard Worker 
216*89c4ff92SAndroid Build Coastguard Worker /// Weights for depthwise have a datalayout of [1,H,W,O] = [1,H,W,I*M]
217*89c4ff92SAndroid Build Coastguard Worker /// This function coverts a TensorInfo from [1,H,W,I*M] to [1,I*M,H,W] (if NCHW) or keeps it at [1,H,W,I*M] (if NHWC)
218*89c4ff92SAndroid Build Coastguard Worker /// as required by the compute library
219*89c4ff92SAndroid Build Coastguard Worker /// Returns a tuple of converted weights tensor info and depth multiplier
220*89c4ff92SAndroid Build Coastguard Worker std::tuple<TensorInfo, unsigned int> Convert1HWOTensorInfoToAcl(const TensorInfo& weightInfo,
221*89c4ff92SAndroid Build Coastguard Worker                                                                 const TensorInfo& inputInfo,
222*89c4ff92SAndroid Build Coastguard Worker                                                                 const DataLayout dataLayout);
223*89c4ff92SAndroid Build Coastguard Worker 
224*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor ConvertWeightTensorFromArmnnToAcl(const ConstTensorHandle* weightTensor,
225*89c4ff92SAndroid Build Coastguard Worker                                                      DataLayout dataLayout,
226*89c4ff92SAndroid Build Coastguard Worker                                                      void* permuteBuffer);
227*89c4ff92SAndroid Build Coastguard Worker 
228*89c4ff92SAndroid Build Coastguard Worker /// Weights for depthwise have a datalayout of [1,H,W,O] = [1,H,W,I*M]
229*89c4ff92SAndroid Build Coastguard Worker /// This function coverts a ConstCpuTensorHandle from [1,H,W,I*M] to [1,I*M,H,W] (if NCHW) or
230*89c4ff92SAndroid Build Coastguard Worker /// keeps it at [1,H,W,I*M] (if NHWC) as required by the compute library
231*89c4ff92SAndroid Build Coastguard Worker ///
232*89c4ff92SAndroid Build Coastguard Worker /// \param weightTensor - ConstTensorHandle of weights tensor
233*89c4ff92SAndroid Build Coastguard Worker /// \param inputInfo - TensorInfo of input tensor
234*89c4ff92SAndroid Build Coastguard Worker /// \param dataLayout - DataLayout of the input tensor
235*89c4ff92SAndroid Build Coastguard Worker /// \param permuteBuffer - Pointer to memory with the size of tensor. Used for the permutation
236*89c4ff92SAndroid Build Coastguard Worker /// \return tuple of transformed weights-ConstTensor and depthwise multiplier
237*89c4ff92SAndroid Build Coastguard Worker std::tuple<ConstTensor, unsigned int> Convert1HWOTensorToAcl(const ConstTensorHandle* weightTensor,
238*89c4ff92SAndroid Build Coastguard Worker                                                              const TensorInfo& inputInfo,
239*89c4ff92SAndroid Build Coastguard Worker                                                              const DataLayout dataLayout,
240*89c4ff92SAndroid Build Coastguard Worker                                                              void* permuteBuffer);
241*89c4ff92SAndroid Build Coastguard Worker 
242*89c4ff92SAndroid Build Coastguard Worker /// Converts a (weights) tensor from [1, H, W, I*M] = [1, H, W, O] to [M, I, H, W]
243*89c4ff92SAndroid Build Coastguard Worker ///
244*89c4ff92SAndroid Build Coastguard Worker /// \param weightTensor - ConstTensorHandle of the weight tensor that should be converted
245*89c4ff92SAndroid Build Coastguard Worker /// \param inputInfo - TensorInfo of the corresponding input tensor
246*89c4ff92SAndroid Build Coastguard Worker /// \param dataLayout - DataLayout of the input tensor e.g. NHWC or NCHW
247*89c4ff92SAndroid Build Coastguard Worker /// \param permuteBuffer - Memory location with the same size as the weight tensor to write converted data to
248*89c4ff92SAndroid Build Coastguard Worker /// \return - A tuple of ConstTensor and unsigned int which is the converted weightTensor and the depthMultiplier
249*89c4ff92SAndroid Build Coastguard Worker std::tuple<ConstTensor, unsigned int> Convert1HWOtoMIHW(const ConstTensorHandle* weightTensor,
250*89c4ff92SAndroid Build Coastguard Worker                                                         const TensorInfo& inputInfo,
251*89c4ff92SAndroid Build Coastguard Worker                                                         const DataLayout& dataLayout,
252*89c4ff92SAndroid Build Coastguard Worker                                                         void* permuteBuffer);
253*89c4ff92SAndroid Build Coastguard Worker 
254*89c4ff92SAndroid Build Coastguard Worker /// Calculates the key index values needed for GatherNd: N, ND, K, W, C (N is always 1)
255*89c4ff92SAndroid Build Coastguard Worker ///
256*89c4ff92SAndroid Build Coastguard Worker /// \param inputInfo0 - TensorInfo of the corresponding input tensor: params
257*89c4ff92SAndroid Build Coastguard Worker /// \param inputInfo1 - TensorInfo of the corresponding input tensor: indices
258*89c4ff92SAndroid Build Coastguard Worker /// \return - A map with names and values for  N, ND, K, W, C
259*89c4ff92SAndroid Build Coastguard Worker std::map<std::string, unsigned int> CalculateGatherNdKeyIndices(TensorInfo inputInfo0, TensorInfo inputInfo1);
260*89c4ff92SAndroid Build Coastguard Worker 
261*89c4ff92SAndroid Build Coastguard Worker /// Generates a permutation vector of size rank that permutes the 2 most right dimensions
262*89c4ff92SAndroid Build Coastguard Worker ///
263*89c4ff92SAndroid Build Coastguard Worker /// \param rank - Tensor rank, i.e. number of dimensions in the tensors
264*89c4ff92SAndroid Build Coastguard Worker /// \return - A permutation vector that permutes the 2 last dimensions
265*89c4ff92SAndroid Build Coastguard Worker armnn::PermutationVector GeneratePermutationVectorOnLastTwoDimensions(unsigned int rank);
266*89c4ff92SAndroid Build Coastguard Worker 
267*89c4ff92SAndroid Build Coastguard Worker }  //namespace armnn
268