//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker
#include <armnn/backends/ITensorHandle.hpp>
#include <armnn/backends/TensorHandle.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/utility/PolymorphicDowncast.hpp>
#include <armnnUtils/Permute.hpp>

#include <Half.hpp>
#include <Profiling.hpp>

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <map>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
16*89c4ff92SAndroid Build Coastguard Worker
17*89c4ff92SAndroid Build Coastguard Worker
18*89c4ff92SAndroid Build Coastguard Worker namespace armnn
19*89c4ff92SAndroid Build Coastguard Worker {
namespace
{

/// Terminal step of the variadic AssignValues below: fills a single output from \p array.
///
/// Elements are consumed from the back of \p array: the output receives
/// array[(num - 1) - idx] and \p idx is then advanced. Once \p idx has reached
/// \p num the call does nothing, so the output keeps whatever value the caller
/// initialised it with.
///
/// \param num   - Number of valid elements in \p array
/// \param idx   - Count of elements already consumed; incremented on a successful assignment
/// \param array - Indexable container (operator[]) holding at least \p num elements
/// \param arg   - Output that receives the next element
template <typename ArrayType, typename Arg>
void AssignValues(unsigned int num, unsigned int& idx, const ArrayType& array, Arg& arg)
{
    if (idx >= num)
    {
        return;
    }

    arg = array[(num - 1) - idx];
    idx++;
}

/// Distributes the elements of \p array, last element first, over the given outputs.
///
/// The first output receives array[num - 1], the second array[num - 2], and so on.
/// Outputs for which no element is left are not written to.
///
/// \param num      - Number of valid elements in \p array
/// \param idx      - Count of elements already consumed; callers pass 0
/// \param array    - Indexable container (operator[]) holding at least \p num elements
/// \param assignee - First output to fill
/// \param args     - Remaining outputs, filled in order
template <typename T, typename ArrayType, typename... Args>
void AssignValues(unsigned int num, unsigned int idx, const ArrayType& array, T& assignee, Args&... args)
{
    // idx is a by-value copy here, but it is an lvalue, so this call binds to the
    // single-output overload above, which advances it through its reference parameter.
    AssignValues(num, idx, array, assignee);

    // Recurse with the advanced index: this overload again while two or more outputs
    // remain, the single-output overload for the final one.
    AssignValues(num, idx, array, args...);
}

} // anonymous namespace
44*89c4ff92SAndroid Build Coastguard Worker
45*89c4ff92SAndroid Build Coastguard Worker template <typename CopyFunc>
CopyTensorContentsGeneric(const ITensorHandle * srcTensor,ITensorHandle * dstTensor,CopyFunc copy)46*89c4ff92SAndroid Build Coastguard Worker void CopyTensorContentsGeneric(const ITensorHandle* srcTensor, ITensorHandle* dstTensor, CopyFunc copy)
47*89c4ff92SAndroid Build Coastguard Worker {
48*89c4ff92SAndroid Build Coastguard Worker // For ease of understanding, names are assigned to the dimensions
49*89c4ff92SAndroid Build Coastguard Worker // of the tensor as if NHWC, however this routine works with any 5D tensor
50*89c4ff92SAndroid Build Coastguard Worker static_assert(MaxNumOfTensorDimensions == 5, "Please update CopyTensorContents");
51*89c4ff92SAndroid Build Coastguard Worker
52*89c4ff92SAndroid Build Coastguard Worker TensorShape srcStrides = srcTensor->GetStrides();
53*89c4ff92SAndroid Build Coastguard Worker const TensorShape& srcShape = srcTensor->GetShape();
54*89c4ff92SAndroid Build Coastguard Worker const auto srcSize = srcTensor->GetStrides()[0] * srcShape[0];
55*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(srcSize); // Only used for asserts
56*89c4ff92SAndroid Build Coastguard Worker TensorShape dstStrides = dstTensor->GetStrides();
57*89c4ff92SAndroid Build Coastguard Worker const TensorShape& dstShape = dstTensor->GetShape();
58*89c4ff92SAndroid Build Coastguard Worker const auto dstSize = dstTensor->GetStrides()[0] * dstShape[0];
59*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(dstSize); // Only used for asserts
60*89c4ff92SAndroid Build Coastguard Worker
61*89c4ff92SAndroid Build Coastguard Worker size_t srcDepth = 1;
62*89c4ff92SAndroid Build Coastguard Worker size_t srcBatches = 1;
63*89c4ff92SAndroid Build Coastguard Worker size_t srcHeight = 1;
64*89c4ff92SAndroid Build Coastguard Worker size_t srcWidth = 1;
65*89c4ff92SAndroid Build Coastguard Worker size_t srcChannels = 1;
66*89c4ff92SAndroid Build Coastguard Worker AssignValues(srcShape.GetNumDimensions(),
67*89c4ff92SAndroid Build Coastguard Worker 0,
68*89c4ff92SAndroid Build Coastguard Worker srcShape,
69*89c4ff92SAndroid Build Coastguard Worker srcChannels,
70*89c4ff92SAndroid Build Coastguard Worker srcWidth,
71*89c4ff92SAndroid Build Coastguard Worker srcHeight,
72*89c4ff92SAndroid Build Coastguard Worker srcBatches,
73*89c4ff92SAndroid Build Coastguard Worker srcDepth);
74*89c4ff92SAndroid Build Coastguard Worker
75*89c4ff92SAndroid Build Coastguard Worker size_t srcDepthStride = 0;
76*89c4ff92SAndroid Build Coastguard Worker size_t srcBatchStride = 0;
77*89c4ff92SAndroid Build Coastguard Worker size_t srcHeightStride = 0;
78*89c4ff92SAndroid Build Coastguard Worker size_t srcWidthStride = 0;
79*89c4ff92SAndroid Build Coastguard Worker size_t srcChannelStride = 0;
80*89c4ff92SAndroid Build Coastguard Worker AssignValues(srcStrides.GetNumDimensions(),
81*89c4ff92SAndroid Build Coastguard Worker 0,
82*89c4ff92SAndroid Build Coastguard Worker srcStrides,
83*89c4ff92SAndroid Build Coastguard Worker srcChannelStride,
84*89c4ff92SAndroid Build Coastguard Worker srcWidthStride,
85*89c4ff92SAndroid Build Coastguard Worker srcHeightStride,
86*89c4ff92SAndroid Build Coastguard Worker srcBatchStride,
87*89c4ff92SAndroid Build Coastguard Worker srcDepthStride);
88*89c4ff92SAndroid Build Coastguard Worker
89*89c4ff92SAndroid Build Coastguard Worker size_t dstDepth = 1;
90*89c4ff92SAndroid Build Coastguard Worker size_t dstBatches = 1;
91*89c4ff92SAndroid Build Coastguard Worker size_t dstHeight = 1;
92*89c4ff92SAndroid Build Coastguard Worker size_t dstWidth = 1;
93*89c4ff92SAndroid Build Coastguard Worker size_t dstChannels = 1;
94*89c4ff92SAndroid Build Coastguard Worker AssignValues(dstShape.GetNumDimensions(),
95*89c4ff92SAndroid Build Coastguard Worker 0,
96*89c4ff92SAndroid Build Coastguard Worker dstShape,
97*89c4ff92SAndroid Build Coastguard Worker dstChannels,
98*89c4ff92SAndroid Build Coastguard Worker dstWidth,
99*89c4ff92SAndroid Build Coastguard Worker dstHeight,
100*89c4ff92SAndroid Build Coastguard Worker dstBatches,
101*89c4ff92SAndroid Build Coastguard Worker dstDepth);
102*89c4ff92SAndroid Build Coastguard Worker
103*89c4ff92SAndroid Build Coastguard Worker size_t dstDepthStride = 0;
104*89c4ff92SAndroid Build Coastguard Worker size_t dstBatchStride = 0;
105*89c4ff92SAndroid Build Coastguard Worker size_t dstHeightStride = 0;
106*89c4ff92SAndroid Build Coastguard Worker size_t dstWidthStride = 0;
107*89c4ff92SAndroid Build Coastguard Worker size_t dstChannelStride = 0;
108*89c4ff92SAndroid Build Coastguard Worker AssignValues(dstStrides.GetNumDimensions(),
109*89c4ff92SAndroid Build Coastguard Worker 0,
110*89c4ff92SAndroid Build Coastguard Worker dstStrides,
111*89c4ff92SAndroid Build Coastguard Worker dstChannelStride,
112*89c4ff92SAndroid Build Coastguard Worker dstWidthStride,
113*89c4ff92SAndroid Build Coastguard Worker dstHeightStride,
114*89c4ff92SAndroid Build Coastguard Worker dstBatchStride,
115*89c4ff92SAndroid Build Coastguard Worker dstDepthStride);
116*89c4ff92SAndroid Build Coastguard Worker
117*89c4ff92SAndroid Build Coastguard Worker const unsigned char* srcDataStart;
118*89c4ff92SAndroid Build Coastguard Worker unsigned char* dstDataStart;
119*89c4ff92SAndroid Build Coastguard Worker {
120*89c4ff92SAndroid Build Coastguard Worker ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "Synchronize buffers");
121*89c4ff92SAndroid Build Coastguard Worker srcDataStart = static_cast<const uint8_t*>(srcTensor->Map());
122*89c4ff92SAndroid Build Coastguard Worker dstDataStart = static_cast<uint8_t*>(dstTensor->Map());
123*89c4ff92SAndroid Build Coastguard Worker }
124*89c4ff92SAndroid Build Coastguard Worker
125*89c4ff92SAndroid Build Coastguard Worker size_t copyLength = std::min(srcChannels * srcChannelStride, dstChannels * dstChannelStride);
126*89c4ff92SAndroid Build Coastguard Worker size_t copyWidth = std::min(srcWidth, dstWidth);
127*89c4ff92SAndroid Build Coastguard Worker size_t copyHeight = std::min(srcHeight, dstHeight);
128*89c4ff92SAndroid Build Coastguard Worker size_t copyBatches = std::min(srcBatches, dstBatches);
129*89c4ff92SAndroid Build Coastguard Worker size_t copyDepth = std::min(srcDepth, dstDepth);
130*89c4ff92SAndroid Build Coastguard Worker
131*89c4ff92SAndroid Build Coastguard Worker // Coalesce inner dimensions where possible
132*89c4ff92SAndroid Build Coastguard Worker // to reduce overheard calling copy() and to
133*89c4ff92SAndroid Build Coastguard Worker // allow for memory bandwidth optimisations
134*89c4ff92SAndroid Build Coastguard Worker if (copyLength == srcWidthStride &&
135*89c4ff92SAndroid Build Coastguard Worker copyLength == dstWidthStride)
136*89c4ff92SAndroid Build Coastguard Worker {
137*89c4ff92SAndroid Build Coastguard Worker // There is no special padding between rows,
138*89c4ff92SAndroid Build Coastguard Worker // and sizes are compatible, so copy whole rows
139*89c4ff92SAndroid Build Coastguard Worker copyLength *= copyWidth;
140*89c4ff92SAndroid Build Coastguard Worker copyWidth = 1;
141*89c4ff92SAndroid Build Coastguard Worker
142*89c4ff92SAndroid Build Coastguard Worker if (copyLength == srcHeightStride &&
143*89c4ff92SAndroid Build Coastguard Worker copyLength == dstHeightStride)
144*89c4ff92SAndroid Build Coastguard Worker {
145*89c4ff92SAndroid Build Coastguard Worker // There is no special padding between batches
146*89c4ff92SAndroid Build Coastguard Worker // and sizes are compatible so copy whole batches
147*89c4ff92SAndroid Build Coastguard Worker copyLength *= copyHeight;
148*89c4ff92SAndroid Build Coastguard Worker copyHeight = 1;
149*89c4ff92SAndroid Build Coastguard Worker }
150*89c4ff92SAndroid Build Coastguard Worker }
151*89c4ff92SAndroid Build Coastguard Worker
152*89c4ff92SAndroid Build Coastguard Worker const unsigned char* srcData = srcDataStart;
153*89c4ff92SAndroid Build Coastguard Worker unsigned char* dstData = dstDataStart;
154*89c4ff92SAndroid Build Coastguard Worker for (unsigned int d = 0; d < copyDepth; ++d)
155*89c4ff92SAndroid Build Coastguard Worker {
156*89c4ff92SAndroid Build Coastguard Worker auto srcPtrDepth = srcData;
157*89c4ff92SAndroid Build Coastguard Worker auto dstPtrDepth = dstData;
158*89c4ff92SAndroid Build Coastguard Worker for (unsigned int b = 0; b < copyBatches; ++b)
159*89c4ff92SAndroid Build Coastguard Worker {
160*89c4ff92SAndroid Build Coastguard Worker auto srcPtrBatch = srcData;
161*89c4ff92SAndroid Build Coastguard Worker auto dstPtrBatch = dstData;
162*89c4ff92SAndroid Build Coastguard Worker for (unsigned int h = 0; h < copyHeight; ++h)
163*89c4ff92SAndroid Build Coastguard Worker {
164*89c4ff92SAndroid Build Coastguard Worker auto srcPtrChannel = srcData;
165*89c4ff92SAndroid Build Coastguard Worker auto dstPtrChannel = dstData;
166*89c4ff92SAndroid Build Coastguard Worker for (unsigned int w = 0; w < copyWidth; ++w)
167*89c4ff92SAndroid Build Coastguard Worker {
168*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(srcData >= srcDataStart && srcData + copyLength <= srcDataStart + srcSize);
169*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(dstData >= dstDataStart && dstData + copyLength <= dstDataStart + dstSize);
170*89c4ff92SAndroid Build Coastguard Worker copy(dstData, srcData, copyLength);
171*89c4ff92SAndroid Build Coastguard Worker dstData += dstWidthStride;
172*89c4ff92SAndroid Build Coastguard Worker srcData += srcWidthStride;
173*89c4ff92SAndroid Build Coastguard Worker }
174*89c4ff92SAndroid Build Coastguard Worker dstData += (static_cast<long>(dstHeightStride) - (dstData - dstPtrChannel));
175*89c4ff92SAndroid Build Coastguard Worker srcData += (static_cast<long>(srcHeightStride) - (srcData - srcPtrChannel));
176*89c4ff92SAndroid Build Coastguard Worker }
177*89c4ff92SAndroid Build Coastguard Worker dstData += (static_cast<long>(dstBatchStride) - (dstData - dstPtrBatch));
178*89c4ff92SAndroid Build Coastguard Worker srcData += (static_cast<long>(srcBatchStride) - (srcData - srcPtrBatch));
179*89c4ff92SAndroid Build Coastguard Worker }
180*89c4ff92SAndroid Build Coastguard Worker dstData += (static_cast<long>(dstDepthStride) - (dstData - dstPtrDepth));
181*89c4ff92SAndroid Build Coastguard Worker srcData += (static_cast<long>(srcDepthStride) - (srcData - srcPtrDepth));
182*89c4ff92SAndroid Build Coastguard Worker }
183*89c4ff92SAndroid Build Coastguard Worker
184*89c4ff92SAndroid Build Coastguard Worker srcTensor->Unmap();
185*89c4ff92SAndroid Build Coastguard Worker dstTensor->Unmap();
186*89c4ff92SAndroid Build Coastguard Worker }
187*89c4ff92SAndroid Build Coastguard Worker
188*89c4ff92SAndroid Build Coastguard Worker template <typename SrcTensorHandleType, typename DstTensorHandleType, typename DescriptorType>
GatherTensorHandlePairs(const DescriptorType & descriptor,std::vector<std::pair<SrcTensorHandleType *,DstTensorHandleType * >> & tensorHandlePairs)189*89c4ff92SAndroid Build Coastguard Worker void GatherTensorHandlePairs(const DescriptorType& descriptor,
190*89c4ff92SAndroid Build Coastguard Worker std::vector<std::pair<SrcTensorHandleType*, DstTensorHandleType*>>& tensorHandlePairs)
191*89c4ff92SAndroid Build Coastguard Worker {
192*89c4ff92SAndroid Build Coastguard Worker const unsigned int numInputs = static_cast<unsigned int>(descriptor.m_Inputs.size());
193*89c4ff92SAndroid Build Coastguard Worker tensorHandlePairs.reserve(numInputs);
194*89c4ff92SAndroid Build Coastguard Worker
195*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < numInputs; ++i)
196*89c4ff92SAndroid Build Coastguard Worker {
197*89c4ff92SAndroid Build Coastguard Worker SrcTensorHandleType* const srcTensorHandle =
198*89c4ff92SAndroid Build Coastguard Worker PolymorphicDowncast<SrcTensorHandleType*>(descriptor.m_Inputs[i]);
199*89c4ff92SAndroid Build Coastguard Worker DstTensorHandleType* const dstTensorHandle =
200*89c4ff92SAndroid Build Coastguard Worker PolymorphicDowncast<DstTensorHandleType*>(descriptor.m_Outputs[i]);
201*89c4ff92SAndroid Build Coastguard Worker
202*89c4ff92SAndroid Build Coastguard Worker tensorHandlePairs.emplace_back(srcTensorHandle, dstTensorHandle);
203*89c4ff92SAndroid Build Coastguard Worker }
204*89c4ff92SAndroid Build Coastguard Worker }
205*89c4ff92SAndroid Build Coastguard Worker
206*89c4ff92SAndroid Build Coastguard Worker int32_t ConvertMaskToACLFormat(int32_t mask, int32_t numDim);
207*89c4ff92SAndroid Build Coastguard Worker
208*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor PermuteTensor(const ConstTensorHandle* tensor,
209*89c4ff92SAndroid Build Coastguard Worker const PermutationVector& permutationVector,
210*89c4ff92SAndroid Build Coastguard Worker void* permuteBuffer);
211*89c4ff92SAndroid Build Coastguard Worker
212*89c4ff92SAndroid Build Coastguard Worker void ReshapeWeightsForAcl(TensorInfo& weightInfo, DataLayout dataLayout);
213*89c4ff92SAndroid Build Coastguard Worker
214*89c4ff92SAndroid Build Coastguard Worker TensorInfo ConvertWeightTensorInfoFromArmnnToAcl(const TensorInfo& weightInfo, DataLayout dataLayout);
215*89c4ff92SAndroid Build Coastguard Worker
216*89c4ff92SAndroid Build Coastguard Worker /// Weights for depthwise have a datalayout of [1,H,W,O] = [1,H,W,I*M]
217*89c4ff92SAndroid Build Coastguard Worker /// This function coverts a TensorInfo from [1,H,W,I*M] to [1,I*M,H,W] (if NCHW) or keeps it at [1,H,W,I*M] (if NHWC)
218*89c4ff92SAndroid Build Coastguard Worker /// as required by the compute library
219*89c4ff92SAndroid Build Coastguard Worker /// Returns a tuple of converted weights tensor info and depth multiplier
220*89c4ff92SAndroid Build Coastguard Worker std::tuple<TensorInfo, unsigned int> Convert1HWOTensorInfoToAcl(const TensorInfo& weightInfo,
221*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputInfo,
222*89c4ff92SAndroid Build Coastguard Worker const DataLayout dataLayout);
223*89c4ff92SAndroid Build Coastguard Worker
224*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor ConvertWeightTensorFromArmnnToAcl(const ConstTensorHandle* weightTensor,
225*89c4ff92SAndroid Build Coastguard Worker DataLayout dataLayout,
226*89c4ff92SAndroid Build Coastguard Worker void* permuteBuffer);
227*89c4ff92SAndroid Build Coastguard Worker
228*89c4ff92SAndroid Build Coastguard Worker /// Weights for depthwise have a datalayout of [1,H,W,O] = [1,H,W,I*M]
229*89c4ff92SAndroid Build Coastguard Worker /// This function coverts a ConstCpuTensorHandle from [1,H,W,I*M] to [1,I*M,H,W] (if NCHW) or
230*89c4ff92SAndroid Build Coastguard Worker /// keeps it at [1,H,W,I*M] (if NHWC) as required by the compute library
231*89c4ff92SAndroid Build Coastguard Worker ///
232*89c4ff92SAndroid Build Coastguard Worker /// \param weightTensor - ConstTensorHandle of weights tensor
233*89c4ff92SAndroid Build Coastguard Worker /// \param inputInfo - TensorInfo of input tensor
234*89c4ff92SAndroid Build Coastguard Worker /// \param dataLayout - DataLayout of the input tensor
235*89c4ff92SAndroid Build Coastguard Worker /// \param permuteBuffer - Pointer to memory with the size of tensor. Used for the permutation
236*89c4ff92SAndroid Build Coastguard Worker /// \return tuple of transformed weights-ConstTensor and depthwise multiplier
237*89c4ff92SAndroid Build Coastguard Worker std::tuple<ConstTensor, unsigned int> Convert1HWOTensorToAcl(const ConstTensorHandle* weightTensor,
238*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputInfo,
239*89c4ff92SAndroid Build Coastguard Worker const DataLayout dataLayout,
240*89c4ff92SAndroid Build Coastguard Worker void* permuteBuffer);
241*89c4ff92SAndroid Build Coastguard Worker
242*89c4ff92SAndroid Build Coastguard Worker /// Converts a (weights) tensor from [1, H, W, I*M] = [1, H, W, O] to [M, I, H, W]
243*89c4ff92SAndroid Build Coastguard Worker ///
244*89c4ff92SAndroid Build Coastguard Worker /// \param weightTensor - ConstTensorHandle of the weight tensor that should be converted
245*89c4ff92SAndroid Build Coastguard Worker /// \param inputInfo - TensorInfo of the corresponding input tensor
246*89c4ff92SAndroid Build Coastguard Worker /// \param dataLayout - DataLayout of the input tensor e.g. NHWC or NCHW
247*89c4ff92SAndroid Build Coastguard Worker /// \param permuteBuffer - Memory location with the same size as the weight tensor to write converted data to
248*89c4ff92SAndroid Build Coastguard Worker /// \return - A tuple of ConstTensor and unsigned int which is the converted weightTensor and the depthMultiplier
249*89c4ff92SAndroid Build Coastguard Worker std::tuple<ConstTensor, unsigned int> Convert1HWOtoMIHW(const ConstTensorHandle* weightTensor,
250*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputInfo,
251*89c4ff92SAndroid Build Coastguard Worker const DataLayout& dataLayout,
252*89c4ff92SAndroid Build Coastguard Worker void* permuteBuffer);
253*89c4ff92SAndroid Build Coastguard Worker
254*89c4ff92SAndroid Build Coastguard Worker /// Calculates the key index values needed for GatherNd: N, ND, K, W, C (N is always 1)
255*89c4ff92SAndroid Build Coastguard Worker ///
256*89c4ff92SAndroid Build Coastguard Worker /// \param inputInfo0 - TensorInfo of the corresponding input tensor: params
257*89c4ff92SAndroid Build Coastguard Worker /// \param inputInfo1 - TensorInfo of the corresponding input tensor: indices
258*89c4ff92SAndroid Build Coastguard Worker /// \return - A map with names and values for N, ND, K, W, C
259*89c4ff92SAndroid Build Coastguard Worker std::map<std::string, unsigned int> CalculateGatherNdKeyIndices(TensorInfo inputInfo0, TensorInfo inputInfo1);
260*89c4ff92SAndroid Build Coastguard Worker
261*89c4ff92SAndroid Build Coastguard Worker /// Generates a permutation vector of size rank that permutes the 2 most right dimensions
262*89c4ff92SAndroid Build Coastguard Worker ///
263*89c4ff92SAndroid Build Coastguard Worker /// \param rank - Tensor rank, i.e. number of dimensions in the tensors
264*89c4ff92SAndroid Build Coastguard Worker /// \return - A permutation vector that permutes the 2 last dimensions
265*89c4ff92SAndroid Build Coastguard Worker armnn::PermutationVector GeneratePermutationVectorOnLastTwoDimensions(unsigned int rank);
266*89c4ff92SAndroid Build Coastguard Worker
267*89c4ff92SAndroid Build Coastguard Worker } //namespace armnn
268