1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #include <backendsCommon/WorkloadUtils.hpp>
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Utils.hpp>
9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/NumericCast.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/DataLayoutIndexed.hpp>
11*89c4ff92SAndroid Build Coastguard Worker
12*89c4ff92SAndroid Build Coastguard Worker #include <fmt/format.h>
13*89c4ff92SAndroid Build Coastguard Worker #include <numeric>
14*89c4ff92SAndroid Build Coastguard Worker
15*89c4ff92SAndroid Build Coastguard Worker namespace armnn
16*89c4ff92SAndroid Build Coastguard Worker {
17*89c4ff92SAndroid Build Coastguard Worker
PermuteTensor(const ConstTensorHandle * tensor,const PermutationVector & permutationVector,void * permuteBuffer)18*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor PermuteTensor(const ConstTensorHandle* tensor,
19*89c4ff92SAndroid Build Coastguard Worker const PermutationVector& permutationVector, void* permuteBuffer)
20*89c4ff92SAndroid Build Coastguard Worker {
21*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(tensor, "Invalid input tensor");
22*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(permuteBuffer, "Invalid permute buffer");
23*89c4ff92SAndroid Build Coastguard Worker
24*89c4ff92SAndroid Build Coastguard Worker TensorInfo tensorInfo = tensor->GetTensorInfo();
25*89c4ff92SAndroid Build Coastguard Worker
26*89c4ff92SAndroid Build Coastguard Worker if (permutationVector.GetSize() > 0)
27*89c4ff92SAndroid Build Coastguard Worker {
28*89c4ff92SAndroid Build Coastguard Worker tensorInfo = armnnUtils::Permuted(tensorInfo, permutationVector);
29*89c4ff92SAndroid Build Coastguard Worker armnnUtils::Permute(tensorInfo.GetShape(), permutationVector,
30*89c4ff92SAndroid Build Coastguard Worker tensor->GetConstTensor<void>(), permuteBuffer,
31*89c4ff92SAndroid Build Coastguard Worker GetDataTypeSize(tensorInfo.GetDataType()));
32*89c4ff92SAndroid Build Coastguard Worker }
33*89c4ff92SAndroid Build Coastguard Worker else
34*89c4ff92SAndroid Build Coastguard Worker {
35*89c4ff92SAndroid Build Coastguard Worker ::memcpy(permuteBuffer, tensor->GetConstTensor<void>(), tensorInfo.GetNumBytes());
36*89c4ff92SAndroid Build Coastguard Worker }
37*89c4ff92SAndroid Build Coastguard Worker tensorInfo.SetConstant(true);
38*89c4ff92SAndroid Build Coastguard Worker return ConstTensor(tensorInfo, permuteBuffer);
39*89c4ff92SAndroid Build Coastguard Worker }
40*89c4ff92SAndroid Build Coastguard Worker
ReshapeWeightsForAcl(TensorInfo & weightInfo,DataLayout dataLayout)41*89c4ff92SAndroid Build Coastguard Worker void ReshapeWeightsForAcl(TensorInfo& weightInfo, DataLayout dataLayout)
42*89c4ff92SAndroid Build Coastguard Worker {
43*89c4ff92SAndroid Build Coastguard Worker // Reshape the weights in-place
44*89c4ff92SAndroid Build Coastguard Worker const TensorShape& weightShape = weightInfo.GetShape();
45*89c4ff92SAndroid Build Coastguard Worker switch (dataLayout)
46*89c4ff92SAndroid Build Coastguard Worker {
47*89c4ff92SAndroid Build Coastguard Worker case DataLayout::NHWC:
48*89c4ff92SAndroid Build Coastguard Worker // The data layout is NHWC, reshape from [ H, W, I, M ] to [ 1, H, W, I * M ]
49*89c4ff92SAndroid Build Coastguard Worker weightInfo.SetShape({ 1,
50*89c4ff92SAndroid Build Coastguard Worker weightShape[0],
51*89c4ff92SAndroid Build Coastguard Worker weightShape[1],
52*89c4ff92SAndroid Build Coastguard Worker weightShape[2] * weightShape[3] });
53*89c4ff92SAndroid Build Coastguard Worker weightInfo.SetShape({ 1,
54*89c4ff92SAndroid Build Coastguard Worker weightShape[0] * weightShape[1],
55*89c4ff92SAndroid Build Coastguard Worker weightShape[2],
56*89c4ff92SAndroid Build Coastguard Worker weightShape[3] });
57*89c4ff92SAndroid Build Coastguard Worker break;
58*89c4ff92SAndroid Build Coastguard Worker case DataLayout::NCHW:
59*89c4ff92SAndroid Build Coastguard Worker default:
60*89c4ff92SAndroid Build Coastguard Worker // The data layout is NCHW, reshape from [ M, I, H, W ] to [ 1, I * M, H, W, ]
61*89c4ff92SAndroid Build Coastguard Worker weightInfo.SetShape({ 1, weightShape[0] * weightShape[1], weightShape[2], weightShape[3] });
62*89c4ff92SAndroid Build Coastguard Worker break;
63*89c4ff92SAndroid Build Coastguard Worker }
64*89c4ff92SAndroid Build Coastguard Worker }
65*89c4ff92SAndroid Build Coastguard Worker
66*89c4ff92SAndroid Build Coastguard Worker template <typename DataType>
ReorderWeightChannelsForAcl(const ConstTensor & weightHandle,DataLayout dataLayout,void * permuteBuffer)67*89c4ff92SAndroid Build Coastguard Worker ConstTensor ReorderWeightChannelsForAcl(const ConstTensor& weightHandle, DataLayout dataLayout, void* permuteBuffer)
68*89c4ff92SAndroid Build Coastguard Worker {
69*89c4ff92SAndroid Build Coastguard Worker DataType* weight = static_cast<DataType*>(permuteBuffer);
70*89c4ff92SAndroid Build Coastguard Worker const TensorShape& weightShape = weightHandle.GetShape();
71*89c4ff92SAndroid Build Coastguard Worker unsigned int multiplier;
72*89c4ff92SAndroid Build Coastguard Worker unsigned int height;
73*89c4ff92SAndroid Build Coastguard Worker unsigned int width;
74*89c4ff92SAndroid Build Coastguard Worker unsigned int inputChannels;
75*89c4ff92SAndroid Build Coastguard Worker switch (dataLayout)
76*89c4ff92SAndroid Build Coastguard Worker {
77*89c4ff92SAndroid Build Coastguard Worker case DataLayout::NHWC: //It actually is [ H, W, I, M ]
78*89c4ff92SAndroid Build Coastguard Worker height = weightShape[0];
79*89c4ff92SAndroid Build Coastguard Worker width = weightShape[1];
80*89c4ff92SAndroid Build Coastguard Worker inputChannels = weightShape[2];
81*89c4ff92SAndroid Build Coastguard Worker multiplier = weightShape[3];
82*89c4ff92SAndroid Build Coastguard Worker break;
83*89c4ff92SAndroid Build Coastguard Worker case DataLayout::NCHW: //It actually is [ M, I, H, W ]
84*89c4ff92SAndroid Build Coastguard Worker default:
85*89c4ff92SAndroid Build Coastguard Worker height = weightShape[2];
86*89c4ff92SAndroid Build Coastguard Worker width = weightShape[3];
87*89c4ff92SAndroid Build Coastguard Worker inputChannels = weightShape[1];
88*89c4ff92SAndroid Build Coastguard Worker multiplier = weightShape[0];
89*89c4ff92SAndroid Build Coastguard Worker break;
90*89c4ff92SAndroid Build Coastguard Worker }
91*89c4ff92SAndroid Build Coastguard Worker
92*89c4ff92SAndroid Build Coastguard Worker std::vector<DataType> weightAclOrder(height*width*inputChannels*multiplier);
93*89c4ff92SAndroid Build Coastguard Worker unsigned int destinationWeightsChannel;
94*89c4ff92SAndroid Build Coastguard Worker unsigned int totalChannels = inputChannels * multiplier;
95*89c4ff92SAndroid Build Coastguard Worker unsigned int channelSize = height * width;
96*89c4ff92SAndroid Build Coastguard Worker unsigned int inputChannel = 0;
97*89c4ff92SAndroid Build Coastguard Worker
98*89c4ff92SAndroid Build Coastguard Worker for (unsigned int originWeightsChannel = 0; originWeightsChannel < totalChannels; originWeightsChannel++)
99*89c4ff92SAndroid Build Coastguard Worker {
100*89c4ff92SAndroid Build Coastguard Worker inputChannel = originWeightsChannel % inputChannels;
101*89c4ff92SAndroid Build Coastguard Worker destinationWeightsChannel = (originWeightsChannel - inputChannel) / inputChannels + multiplier * inputChannel;
102*89c4ff92SAndroid Build Coastguard Worker
103*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < channelSize; i++)
104*89c4ff92SAndroid Build Coastguard Worker {
105*89c4ff92SAndroid Build Coastguard Worker weightAclOrder[i + destinationWeightsChannel * channelSize] =
106*89c4ff92SAndroid Build Coastguard Worker weight[i + originWeightsChannel * channelSize];
107*89c4ff92SAndroid Build Coastguard Worker }
108*89c4ff92SAndroid Build Coastguard Worker }
109*89c4ff92SAndroid Build Coastguard Worker
110*89c4ff92SAndroid Build Coastguard Worker ::memcpy(permuteBuffer, weightAclOrder.data(), weightHandle.GetInfo().GetNumBytes());
111*89c4ff92SAndroid Build Coastguard Worker return ConstTensor(weightHandle.GetInfo(), permuteBuffer);
112*89c4ff92SAndroid Build Coastguard Worker }
113*89c4ff92SAndroid Build Coastguard Worker
114*89c4ff92SAndroid Build Coastguard Worker
/// Computes the TensorInfo that weights in ArmNN's [ M, I, H, W ] format (which does NOT depend on the
/// data layout) have once converted to the form required by the compute library:
/// [ 1, H, W, I * M ] for NHWC or [ 1, I * M, H, W ] for NCHW.
///
/// Only tensor metadata is produced here; no weight data is moved.
///
/// @param weightInfo  Info of the weights in [ M, I, H, W ] format.
/// @param dataLayout  Data layout the converted weights are meant for.
/// @return The converted tensor info.
TensorInfo ConvertWeightTensorInfoFromArmnnToAcl(const TensorInfo& weightInfo, DataLayout dataLayout)
{
    // 1. Permute if necessary. For NHWC the weights go from [ M, I, H, W ] to [ H, W, I, M ]; for NCHW the
    //    reshape to [ 1, I * M, H, W ] can be done directly from [ M, I, H, W ], so the info is just copied.
    TensorInfo convertedInfo = (dataLayout == DataLayout::NHWC)
                               ? armnnUtils::Permuted(weightInfo, PermutationVector{ 3, 2, 0, 1 })
                               : weightInfo;

    // 2. Reshape to the 4-D ACL form.
    ReshapeWeightsForAcl(convertedInfo, dataLayout);

    // 3. Hand back the converted info.
    return convertedInfo;
}
137*89c4ff92SAndroid Build Coastguard Worker
138*89c4ff92SAndroid Build Coastguard Worker
Convert1HWOTensorToAcl(const ConstTensorHandle * weightTensor,const TensorInfo & inputInfo,const DataLayout dataLayout,void * permuteBuffer)139*89c4ff92SAndroid Build Coastguard Worker std::tuple<ConstTensor, unsigned int> Convert1HWOTensorToAcl(const ConstTensorHandle* weightTensor,
140*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputInfo,
141*89c4ff92SAndroid Build Coastguard Worker const DataLayout dataLayout,
142*89c4ff92SAndroid Build Coastguard Worker void* permuteBuffer)
143*89c4ff92SAndroid Build Coastguard Worker {
144*89c4ff92SAndroid Build Coastguard Worker TensorInfo weightsInfo = weightTensor->GetTensorInfo();
145*89c4ff92SAndroid Build Coastguard Worker unsigned int depthMultiplier = 1;
146*89c4ff92SAndroid Build Coastguard Worker PermutationVector permutationVector{};
147*89c4ff92SAndroid Build Coastguard Worker if (dataLayout == armnn::DataLayout::NHWC)
148*89c4ff92SAndroid Build Coastguard Worker {
149*89c4ff92SAndroid Build Coastguard Worker // No permutation required. Data layouts are the same.
150*89c4ff92SAndroid Build Coastguard Worker
151*89c4ff92SAndroid Build Coastguard Worker depthMultiplier = weightsInfo.GetShape()[3] / inputInfo.GetShape()[3];
152*89c4ff92SAndroid Build Coastguard Worker }
153*89c4ff92SAndroid Build Coastguard Worker else if (dataLayout == armnn::DataLayout::NCHW)
154*89c4ff92SAndroid Build Coastguard Worker {
155*89c4ff92SAndroid Build Coastguard Worker // [ 1, H, W, I*M] --> [ 1, I * M, H, W ]
156*89c4ff92SAndroid Build Coastguard Worker depthMultiplier = weightsInfo.GetShape()[3] / inputInfo.GetShape()[1];
157*89c4ff92SAndroid Build Coastguard Worker permutationVector = { 0, 2, 3, 1 };
158*89c4ff92SAndroid Build Coastguard Worker }
159*89c4ff92SAndroid Build Coastguard Worker else
160*89c4ff92SAndroid Build Coastguard Worker {
161*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(fmt::format("Unknown data layout for tensor conversion: {}",
162*89c4ff92SAndroid Build Coastguard Worker GetDataLayoutName(dataLayout)));
163*89c4ff92SAndroid Build Coastguard Worker }
164*89c4ff92SAndroid Build Coastguard Worker
165*89c4ff92SAndroid Build Coastguard Worker ConstTensor weightsPermuted = PermuteTensor(weightTensor, permutationVector, permuteBuffer);
166*89c4ff92SAndroid Build Coastguard Worker
167*89c4ff92SAndroid Build Coastguard Worker return std::make_tuple(weightsPermuted, depthMultiplier);
168*89c4ff92SAndroid Build Coastguard Worker }
169*89c4ff92SAndroid Build Coastguard Worker
Convert1HWOTensorInfoToAcl(const TensorInfo & weightInfo,const TensorInfo & inputInfo,const DataLayout dataLayout)170*89c4ff92SAndroid Build Coastguard Worker std::tuple<TensorInfo, unsigned int> Convert1HWOTensorInfoToAcl(const TensorInfo& weightInfo,
171*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputInfo,
172*89c4ff92SAndroid Build Coastguard Worker const DataLayout dataLayout)
173*89c4ff92SAndroid Build Coastguard Worker {
174*89c4ff92SAndroid Build Coastguard Worker unsigned int aclDepthMultiplier = 1;
175*89c4ff92SAndroid Build Coastguard Worker TensorInfo weightsPermuted;
176*89c4ff92SAndroid Build Coastguard Worker if (dataLayout == armnn::DataLayout::NHWC)
177*89c4ff92SAndroid Build Coastguard Worker {
178*89c4ff92SAndroid Build Coastguard Worker // No permutation required. Input and weights data layouts are the same.
179*89c4ff92SAndroid Build Coastguard Worker aclDepthMultiplier = weightInfo.GetShape()[3] / inputInfo.GetShape()[3];
180*89c4ff92SAndroid Build Coastguard Worker weightsPermuted = weightInfo;
181*89c4ff92SAndroid Build Coastguard Worker }
182*89c4ff92SAndroid Build Coastguard Worker
183*89c4ff92SAndroid Build Coastguard Worker else if (dataLayout == armnn::DataLayout::NCHW)
184*89c4ff92SAndroid Build Coastguard Worker {
185*89c4ff92SAndroid Build Coastguard Worker // Weights permutation required. Weights [N,H,W,C] and input [N,C,H,W] data layouts are different.
186*89c4ff92SAndroid Build Coastguard Worker // [ 1, H, W, I*M] --> [ 1, I * M, H, W ]
187*89c4ff92SAndroid Build Coastguard Worker aclDepthMultiplier = weightInfo.GetShape()[3] / inputInfo.GetShape()[1];
188*89c4ff92SAndroid Build Coastguard Worker PermutationVector permutationVector{ 0, 2, 3, 1 };
189*89c4ff92SAndroid Build Coastguard Worker weightsPermuted = armnnUtils::Permuted(weightInfo, permutationVector);
190*89c4ff92SAndroid Build Coastguard Worker }
191*89c4ff92SAndroid Build Coastguard Worker else
192*89c4ff92SAndroid Build Coastguard Worker {
193*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(fmt::format("Unknown data layout for tensor info conversion: {}",
194*89c4ff92SAndroid Build Coastguard Worker GetDataLayoutName(dataLayout)));
195*89c4ff92SAndroid Build Coastguard Worker }
196*89c4ff92SAndroid Build Coastguard Worker
197*89c4ff92SAndroid Build Coastguard Worker return std::make_tuple(weightsPermuted, aclDepthMultiplier);
198*89c4ff92SAndroid Build Coastguard Worker }
199*89c4ff92SAndroid Build Coastguard Worker
200*89c4ff92SAndroid Build Coastguard Worker
Convert1HWOtoMIHW(const ConstTensorHandle * weightTensor,const TensorInfo & inputInfo,const DataLayout & dataLayout,void * permuteBuffer)201*89c4ff92SAndroid Build Coastguard Worker std::tuple<ConstTensor, unsigned int> Convert1HWOtoMIHW(const ConstTensorHandle* weightTensor,
202*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputInfo,
203*89c4ff92SAndroid Build Coastguard Worker const DataLayout& dataLayout,
204*89c4ff92SAndroid Build Coastguard Worker void* permuteBuffer)
205*89c4ff92SAndroid Build Coastguard Worker {
206*89c4ff92SAndroid Build Coastguard Worker TensorInfo weightsInfo = weightTensor->GetTensorInfo();
207*89c4ff92SAndroid Build Coastguard Worker
208*89c4ff92SAndroid Build Coastguard Worker if (weightsInfo.HasPerAxisQuantization())
209*89c4ff92SAndroid Build Coastguard Worker {
210*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("Can't convert tensor from [1,H,W,Cout] to [M,Cin,H,W] when per channel "
211*89c4ff92SAndroid Build Coastguard Worker "quantization is applied.");
212*89c4ff92SAndroid Build Coastguard Worker }
213*89c4ff92SAndroid Build Coastguard Worker
214*89c4ff92SAndroid Build Coastguard Worker // Reshape weights [ 1, H, W, I*M ] --> [ H, W, I, M ]
215*89c4ff92SAndroid Build Coastguard Worker auto weightsShape = weightsInfo.GetShape();
216*89c4ff92SAndroid Build Coastguard Worker auto channelIndex = armnnUtils::DataLayoutIndexed(dataLayout).GetChannelsIndex();
217*89c4ff92SAndroid Build Coastguard Worker unsigned int depthMultiplier = weightsShape[3] / inputInfo.GetShape()[channelIndex];
218*89c4ff92SAndroid Build Coastguard Worker weightsInfo.SetShape({ weightsShape[1],
219*89c4ff92SAndroid Build Coastguard Worker weightsShape[2],
220*89c4ff92SAndroid Build Coastguard Worker inputInfo.GetShape()[channelIndex],
221*89c4ff92SAndroid Build Coastguard Worker depthMultiplier});
222*89c4ff92SAndroid Build Coastguard Worker
223*89c4ff92SAndroid Build Coastguard Worker // Permute [ H, W, I, M ] --> [ M, I, H, W ]
224*89c4ff92SAndroid Build Coastguard Worker PermutationVector permutationVector = { 2, 3, 1, 0 };
225*89c4ff92SAndroid Build Coastguard Worker ConstTensor weightsPermuted = PermuteTensor(weightTensor, permutationVector, permuteBuffer);
226*89c4ff92SAndroid Build Coastguard Worker
227*89c4ff92SAndroid Build Coastguard Worker return std::make_tuple(weightsPermuted, depthMultiplier);
228*89c4ff92SAndroid Build Coastguard Worker }
229*89c4ff92SAndroid Build Coastguard Worker
ConvertWeightTensorFromArmnnToAcl(const ConstTensorHandle * weightTensor,DataLayout dataLayout,void * permuteBuffer)230*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor ConvertWeightTensorFromArmnnToAcl(const ConstTensorHandle* weightTensor,
231*89c4ff92SAndroid Build Coastguard Worker DataLayout dataLayout,
232*89c4ff92SAndroid Build Coastguard Worker void* permuteBuffer)
233*89c4ff92SAndroid Build Coastguard Worker {
234*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(weightTensor, "Invalid input tensor");
235*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(permuteBuffer, "Invalid permute buffer");
236*89c4ff92SAndroid Build Coastguard Worker
237*89c4ff92SAndroid Build Coastguard Worker auto multiplier = weightTensor->GetTensorInfo().GetShape()[0];
238*89c4ff92SAndroid Build Coastguard Worker auto inputChannels = weightTensor->GetTensorInfo().GetShape()[1];
239*89c4ff92SAndroid Build Coastguard Worker
240*89c4ff92SAndroid Build Coastguard Worker // Convert the weight format from ArmNN's [ M, I, H, W ] (does NOT depend on the data layout) to either
241*89c4ff92SAndroid Build Coastguard Worker // [ 1, H, W, I * M ] (if NHWC) or [ 1, I * M, H, W ] (if NCHW), as required by the compute library
242*89c4ff92SAndroid Build Coastguard Worker
243*89c4ff92SAndroid Build Coastguard Worker // 1. Permute the weights if necessary
244*89c4ff92SAndroid Build Coastguard Worker // If the data layout is NCHW no permutation is necessary, as a reshape to [ 1, I * M, H, W ] can be better done
245*89c4ff92SAndroid Build Coastguard Worker // starting from the current shape of [ M, I, H, W ]
246*89c4ff92SAndroid Build Coastguard Worker // If no permutation is necessary, leave the permutation vector empty
247*89c4ff92SAndroid Build Coastguard Worker PermutationVector permutationVector{};
248*89c4ff92SAndroid Build Coastguard Worker if (dataLayout == DataLayout::NHWC)
249*89c4ff92SAndroid Build Coastguard Worker {
250*89c4ff92SAndroid Build Coastguard Worker // The data layout is NHWC, then permute the weights from [ M, I, H, W ] to [ H, W, I, M ]
251*89c4ff92SAndroid Build Coastguard Worker permutationVector = { 3, 2, 0, 1 };
252*89c4ff92SAndroid Build Coastguard Worker }
253*89c4ff92SAndroid Build Coastguard Worker ConstTensor weightPermuted = PermuteTensor(weightTensor, permutationVector, permuteBuffer);
254*89c4ff92SAndroid Build Coastguard Worker
255*89c4ff92SAndroid Build Coastguard Worker // Shuffle the weights data to obtain the channel order needed used by Acl
256*89c4ff92SAndroid Build Coastguard Worker if (multiplier > 1 && inputChannels > 1 && dataLayout == DataLayout::NCHW)
257*89c4ff92SAndroid Build Coastguard Worker {
258*89c4ff92SAndroid Build Coastguard Worker switch (weightPermuted.GetDataType())
259*89c4ff92SAndroid Build Coastguard Worker {
260*89c4ff92SAndroid Build Coastguard Worker case DataType::Float32:
261*89c4ff92SAndroid Build Coastguard Worker weightPermuted = ReorderWeightChannelsForAcl<float>(weightPermuted, dataLayout, permuteBuffer);
262*89c4ff92SAndroid Build Coastguard Worker break;
263*89c4ff92SAndroid Build Coastguard Worker case DataType::Float16:
264*89c4ff92SAndroid Build Coastguard Worker weightPermuted =
265*89c4ff92SAndroid Build Coastguard Worker ReorderWeightChannelsForAcl<half_float::half>(weightPermuted, dataLayout, permuteBuffer);
266*89c4ff92SAndroid Build Coastguard Worker break;
267*89c4ff92SAndroid Build Coastguard Worker case DataType::QAsymmS8:
268*89c4ff92SAndroid Build Coastguard Worker case DataType::QAsymmU8:
269*89c4ff92SAndroid Build Coastguard Worker weightPermuted = ReorderWeightChannelsForAcl<uint8_t>(weightPermuted, dataLayout, permuteBuffer);
270*89c4ff92SAndroid Build Coastguard Worker break;
271*89c4ff92SAndroid Build Coastguard Worker case DataType::QSymmS8:
272*89c4ff92SAndroid Build Coastguard Worker weightPermuted = ReorderWeightChannelsForAcl<int8_t>(weightPermuted, dataLayout, permuteBuffer);
273*89c4ff92SAndroid Build Coastguard Worker break;
274*89c4ff92SAndroid Build Coastguard Worker default:
275*89c4ff92SAndroid Build Coastguard Worker break;
276*89c4ff92SAndroid Build Coastguard Worker }
277*89c4ff92SAndroid Build Coastguard Worker }
278*89c4ff92SAndroid Build Coastguard Worker
279*89c4ff92SAndroid Build Coastguard Worker // 2. Reshape the weights
280*89c4ff92SAndroid Build Coastguard Worker ReshapeWeightsForAcl(weightPermuted.GetInfo(), dataLayout);
281*89c4ff92SAndroid Build Coastguard Worker
282*89c4ff92SAndroid Build Coastguard Worker // 3. Return both the tensor and the allocated storage to ensure that the data stays alive
283*89c4ff92SAndroid Build Coastguard Worker return weightPermuted;
284*89c4ff92SAndroid Build Coastguard Worker }
285*89c4ff92SAndroid Build Coastguard Worker
/// Mirrors the lowest @p numDim bits of @p mask.
///
/// Bit i of the input (0 <= i < numDim) ends up at bit (numDim - 1 - i) of the result; input bits at or
/// above @p numDim are ignored.
///
/// @param mask    Bit mask with one bit per dimension.
/// @param numDim  Number of dimensions (bits) to mirror; expected to be at most 32.
///                A value of zero or less yields 0.
/// @return The mirrored mask.
int32_t ConvertMaskToACLFormat(int32_t mask, int32_t numDim)
{
    // Work on unsigned values so that every shift below is well defined.
    const uint32_t sourceBits = static_cast<uint32_t>(mask);
    uint32_t reversedBits = 0;

    for (int32_t dim = 0; dim < numDim; ++dim)
    {
        // Check if the bit for this dimension is set in the mask
        if ((sourceBits >> dim) & 1u)
        {
            // dim < numDim, so the target position is never negative and needs no clamping.
            reversedBits |= 1u << (numDim - 1 - dim);
        }
    }

    return static_cast<int32_t>(reversedBits);
}
299*89c4ff92SAndroid Build Coastguard Worker
CalculateGatherNdKeyIndices(TensorInfo inputInfo0,TensorInfo inputInfo1)300*89c4ff92SAndroid Build Coastguard Worker std::map<std::string, unsigned int> CalculateGatherNdKeyIndices(TensorInfo inputInfo0, TensorInfo inputInfo1)
301*89c4ff92SAndroid Build Coastguard Worker {
302*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> paramsShape;
303*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < inputInfo0.GetNumDimensions(); ++i)
304*89c4ff92SAndroid Build Coastguard Worker {
305*89c4ff92SAndroid Build Coastguard Worker paramsShape.push_back(inputInfo0.GetShape()[i]);
306*89c4ff92SAndroid Build Coastguard Worker }
307*89c4ff92SAndroid Build Coastguard Worker
308*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> indicesShape;
309*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < inputInfo1.GetNumDimensions(); ++i)
310*89c4ff92SAndroid Build Coastguard Worker {
311*89c4ff92SAndroid Build Coastguard Worker indicesShape.push_back(inputInfo1.GetShape()[i]);
312*89c4ff92SAndroid Build Coastguard Worker }
313*89c4ff92SAndroid Build Coastguard Worker
314*89c4ff92SAndroid Build Coastguard Worker std::map<std::string, unsigned int> keyIndices;
315*89c4ff92SAndroid Build Coastguard Worker
316*89c4ff92SAndroid Build Coastguard Worker // N: number of batches
317*89c4ff92SAndroid Build Coastguard Worker keyIndices["N"] = 1;
318*89c4ff92SAndroid Build Coastguard Worker
319*89c4ff92SAndroid Build Coastguard Worker // ND: number of dimensions that are sliced from params
320*89c4ff92SAndroid Build Coastguard Worker keyIndices["ND"] = indicesShape.back();
321*89c4ff92SAndroid Build Coastguard Worker
322*89c4ff92SAndroid Build Coastguard Worker // W: number of indices in each batch (all but the last dimension)
323*89c4ff92SAndroid Build Coastguard Worker keyIndices["W"] =
324*89c4ff92SAndroid Build Coastguard Worker static_cast<unsigned int>(std::accumulate(std::begin(indicesShape),
325*89c4ff92SAndroid Build Coastguard Worker std::end(indicesShape) - 1,
326*89c4ff92SAndroid Build Coastguard Worker 1,
327*89c4ff92SAndroid Build Coastguard Worker std::multiplies<>() ));
328*89c4ff92SAndroid Build Coastguard Worker // K: range of each index
329*89c4ff92SAndroid Build Coastguard Worker keyIndices["K"] =
330*89c4ff92SAndroid Build Coastguard Worker static_cast<unsigned int>(std::accumulate(std::begin(paramsShape),
331*89c4ff92SAndroid Build Coastguard Worker std::begin(paramsShape) + static_cast<int>(keyIndices["ND"]),
332*89c4ff92SAndroid Build Coastguard Worker 1,
333*89c4ff92SAndroid Build Coastguard Worker std::multiplies<>() ));
334*89c4ff92SAndroid Build Coastguard Worker // C: number of channels for each index
335*89c4ff92SAndroid Build Coastguard Worker keyIndices["C"] =
336*89c4ff92SAndroid Build Coastguard Worker static_cast<unsigned int>(std::accumulate(std::begin(paramsShape) + static_cast<int>(keyIndices["ND"]),
337*89c4ff92SAndroid Build Coastguard Worker std::end(paramsShape),
338*89c4ff92SAndroid Build Coastguard Worker 1,
339*89c4ff92SAndroid Build Coastguard Worker std::multiplies<>() ));
340*89c4ff92SAndroid Build Coastguard Worker
341*89c4ff92SAndroid Build Coastguard Worker return keyIndices;
342*89c4ff92SAndroid Build Coastguard Worker }
343*89c4ff92SAndroid Build Coastguard Worker
GeneratePermutationVectorOnLastTwoDimensions(unsigned int rank)344*89c4ff92SAndroid Build Coastguard Worker armnn::PermutationVector GeneratePermutationVectorOnLastTwoDimensions(unsigned int rank)
345*89c4ff92SAndroid Build Coastguard Worker {
346*89c4ff92SAndroid Build Coastguard Worker armnn::PermutationVector permutationVector{};
347*89c4ff92SAndroid Build Coastguard Worker switch (rank)
348*89c4ff92SAndroid Build Coastguard Worker {
349*89c4ff92SAndroid Build Coastguard Worker case 2:
350*89c4ff92SAndroid Build Coastguard Worker permutationVector = {1U, 0U};
351*89c4ff92SAndroid Build Coastguard Worker break;
352*89c4ff92SAndroid Build Coastguard Worker case 3:
353*89c4ff92SAndroid Build Coastguard Worker permutationVector = {0U, 2U, 1U};
354*89c4ff92SAndroid Build Coastguard Worker break;
355*89c4ff92SAndroid Build Coastguard Worker case 4:
356*89c4ff92SAndroid Build Coastguard Worker permutationVector = {0U, 1U, 3U, 2U};
357*89c4ff92SAndroid Build Coastguard Worker break;
358*89c4ff92SAndroid Build Coastguard Worker default:
359*89c4ff92SAndroid Build Coastguard Worker throw Exception("Invalid number of dimensions.");
360*89c4ff92SAndroid Build Coastguard Worker }
361*89c4ff92SAndroid Build Coastguard Worker return permutationVector;
362*89c4ff92SAndroid Build Coastguard Worker }
363*89c4ff92SAndroid Build Coastguard Worker
364*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn
365