1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Tensor.hpp>
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Permute.hpp>
9*89c4ff92SAndroid Build Coastguard Worker
10*89c4ff92SAndroid Build Coastguard Worker #include "Half.hpp"
11*89c4ff92SAndroid Build Coastguard Worker
12*89c4ff92SAndroid Build Coastguard Worker #include <cassert>
13*89c4ff92SAndroid Build Coastguard Worker #include <cstring>
14*89c4ff92SAndroid Build Coastguard Worker
15*89c4ff92SAndroid Build Coastguard Worker namespace
16*89c4ff92SAndroid Build Coastguard Worker {
17*89c4ff92SAndroid Build Coastguard Worker
18*89c4ff92SAndroid Build Coastguard Worker class PermuteLoop
19*89c4ff92SAndroid Build Coastguard Worker {
20*89c4ff92SAndroid Build Coastguard Worker public:
21*89c4ff92SAndroid Build Coastguard Worker using size_type = unsigned int;
22*89c4ff92SAndroid Build Coastguard Worker
PermuteLoop(const armnn::TensorShape & dstShape,const armnn::PermutationVector & mappings)23*89c4ff92SAndroid Build Coastguard Worker PermuteLoop(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings)
24*89c4ff92SAndroid Build Coastguard Worker : m_DstShape(dstShape)
25*89c4ff92SAndroid Build Coastguard Worker {
26*89c4ff92SAndroid Build Coastguard Worker assert(dstShape.GetNumDimensions() == mappings.GetSize());
27*89c4ff92SAndroid Build Coastguard Worker
28*89c4ff92SAndroid Build Coastguard Worker const size_type numDims = dstShape.GetNumDimensions();
29*89c4ff92SAndroid Build Coastguard Worker
30*89c4ff92SAndroid Build Coastguard Worker size_type srcStride = 1U;
31*89c4ff92SAndroid Build Coastguard Worker size_type dstStride = 1U;
32*89c4ff92SAndroid Build Coastguard Worker
33*89c4ff92SAndroid Build Coastguard Worker for (size_type i = numDims - 1U, k = 0U; k < numDims; ++k, --i)
34*89c4ff92SAndroid Build Coastguard Worker {
35*89c4ff92SAndroid Build Coastguard Worker m_SrcStrides[mappings[i]] = srcStride;
36*89c4ff92SAndroid Build Coastguard Worker m_DstStrides[i] = dstStride;
37*89c4ff92SAndroid Build Coastguard Worker
38*89c4ff92SAndroid Build Coastguard Worker srcStride *= dstShape[mappings[i]];
39*89c4ff92SAndroid Build Coastguard Worker dstStride *= dstShape[i];
40*89c4ff92SAndroid Build Coastguard Worker }
41*89c4ff92SAndroid Build Coastguard Worker }
42*89c4ff92SAndroid Build Coastguard Worker
Unroll(const void * srcData,void * dstData,size_t dataTypeSize)43*89c4ff92SAndroid Build Coastguard Worker void Unroll(const void* srcData, void* dstData, size_t dataTypeSize)
44*89c4ff92SAndroid Build Coastguard Worker {
45*89c4ff92SAndroid Build Coastguard Worker assert(srcData);
46*89c4ff92SAndroid Build Coastguard Worker assert(dstData);
47*89c4ff92SAndroid Build Coastguard Worker assert(dataTypeSize > 0);
48*89c4ff92SAndroid Build Coastguard Worker
49*89c4ff92SAndroid Build Coastguard Worker const unsigned char* srcDataPtr = reinterpret_cast<const unsigned char*>(srcData);
50*89c4ff92SAndroid Build Coastguard Worker unsigned char* dstDataPtr = reinterpret_cast<unsigned char*>(dstData);
51*89c4ff92SAndroid Build Coastguard Worker
52*89c4ff92SAndroid Build Coastguard Worker const unsigned char* const srcEndPtr = srcDataPtr + m_DstShape.GetNumElements() * dataTypeSize;
53*89c4ff92SAndroid Build Coastguard Worker unsigned char* const dstEndPtr = dstDataPtr + m_DstShape.GetNumElements() * dataTypeSize;
54*89c4ff92SAndroid Build Coastguard Worker
55*89c4ff92SAndroid Build Coastguard Worker Unroll(0, srcDataPtr, dstDataPtr, srcEndPtr, dstEndPtr, dataTypeSize);
56*89c4ff92SAndroid Build Coastguard Worker }
57*89c4ff92SAndroid Build Coastguard Worker
58*89c4ff92SAndroid Build Coastguard Worker private:
Unroll(size_type dimension,const unsigned char * srcData,unsigned char * dstData,const unsigned char * srcEnd,unsigned char * dstEnd,size_t dataTypeSize)59*89c4ff92SAndroid Build Coastguard Worker void Unroll(size_type dimension,
60*89c4ff92SAndroid Build Coastguard Worker const unsigned char* srcData, unsigned char* dstData,
61*89c4ff92SAndroid Build Coastguard Worker const unsigned char* srcEnd, unsigned char* dstEnd,
62*89c4ff92SAndroid Build Coastguard Worker size_t dataTypeSize)
63*89c4ff92SAndroid Build Coastguard Worker {
64*89c4ff92SAndroid Build Coastguard Worker assert(srcData);
65*89c4ff92SAndroid Build Coastguard Worker assert(dstData);
66*89c4ff92SAndroid Build Coastguard Worker assert(srcEnd);
67*89c4ff92SAndroid Build Coastguard Worker assert(dstEnd);
68*89c4ff92SAndroid Build Coastguard Worker assert(srcData < srcEnd);
69*89c4ff92SAndroid Build Coastguard Worker assert(dstData < dstEnd);
70*89c4ff92SAndroid Build Coastguard Worker assert(dataTypeSize > 0);
71*89c4ff92SAndroid Build Coastguard Worker
72*89c4ff92SAndroid Build Coastguard Worker if (dimension >= m_DstShape.GetNumDimensions())
73*89c4ff92SAndroid Build Coastguard Worker {
74*89c4ff92SAndroid Build Coastguard Worker ::memcpy(dstData, srcData, dataTypeSize);
75*89c4ff92SAndroid Build Coastguard Worker }
76*89c4ff92SAndroid Build Coastguard Worker else
77*89c4ff92SAndroid Build Coastguard Worker {
78*89c4ff92SAndroid Build Coastguard Worker for (size_type i = 0; i < m_DstShape[dimension]; i++)
79*89c4ff92SAndroid Build Coastguard Worker {
80*89c4ff92SAndroid Build Coastguard Worker Unroll(dimension + 1, srcData, dstData, srcEnd, dstEnd, dataTypeSize);
81*89c4ff92SAndroid Build Coastguard Worker
82*89c4ff92SAndroid Build Coastguard Worker srcData += m_SrcStrides[dimension] * dataTypeSize;
83*89c4ff92SAndroid Build Coastguard Worker dstData += m_DstStrides[dimension] * dataTypeSize;
84*89c4ff92SAndroid Build Coastguard Worker }
85*89c4ff92SAndroid Build Coastguard Worker }
86*89c4ff92SAndroid Build Coastguard Worker }
87*89c4ff92SAndroid Build Coastguard Worker
88*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape m_DstShape;
89*89c4ff92SAndroid Build Coastguard Worker std::array<size_type, armnn::MaxNumOfTensorDimensions> m_SrcStrides;
90*89c4ff92SAndroid Build Coastguard Worker std::array<size_type, armnn::MaxNumOfTensorDimensions> m_DstStrides;
91*89c4ff92SAndroid Build Coastguard Worker };
92*89c4ff92SAndroid Build Coastguard Worker
93*89c4ff92SAndroid Build Coastguard Worker } // namespace
94*89c4ff92SAndroid Build Coastguard Worker
95*89c4ff92SAndroid Build Coastguard Worker namespace armnnUtils
96*89c4ff92SAndroid Build Coastguard Worker {
97*89c4ff92SAndroid Build Coastguard Worker
/// Returns the shape obtained by permuting srcShape's dimensions.
///
/// Source dimension s is moved to destination position mappings[s], so the result satisfies
/// result[mappings[s]] == srcShape[s] for every s.
///
/// @param srcShape  Shape to permute; its rank must equal mappings.GetSize() (asserted).
/// @param mappings  Source-to-destination dimension mapping.
armnn::TensorShape Permuted(const armnn::TensorShape& srcShape,
                            const armnn::PermutationVector& mappings)
{
    assert(srcShape.GetNumDimensions() == mappings.GetSize());

    const unsigned int rank = mappings.GetSize();

    // Scatter each source extent into its destination slot.
    unsigned int permutedDims[armnn::MaxNumOfTensorDimensions] = {};
    for (unsigned int srcDim = 0U; srcDim != rank; ++srcDim)
    {
        const unsigned int dstDim = mappings[srcDim];
        permutedDims[dstDim] = srcShape[srcDim];
    }

    return armnn::TensorShape(rank, permutedDims);
}
114*89c4ff92SAndroid Build Coastguard Worker
/// Returns a copy of `info` whose shape is permuted by `mappings`.
///
/// All other properties are copied unchanged. The one exception is the per-axis quantization
/// dimension, when present: it indexes a source dimension, so it is remapped to that
/// dimension's new position, mappings[quantizationDim].
armnn::TensorInfo Permuted(const armnn::TensorInfo& info,
                           const armnn::PermutationVector& mappings)
{
    armnn::TensorInfo permutedInfo(info);
    permutedInfo.SetShape(Permuted(info.GetShape(), mappings));

    const auto quantizationDim = info.GetQuantizationDim();
    if (!quantizationDim.has_value())
    {
        return permutedInfo;
    }

    permutedInfo.SetQuantizationDim(mappings[quantizationDim.value()]);
    return permutedInfo;
}
130*89c4ff92SAndroid Build Coastguard Worker
Permute(const armnn::TensorShape & dstShape,const armnn::PermutationVector & mappings,const void * src,void * dst,size_t dataTypeSize)131*89c4ff92SAndroid Build Coastguard Worker void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings,
132*89c4ff92SAndroid Build Coastguard Worker const void* src, void* dst, size_t dataTypeSize)
133*89c4ff92SAndroid Build Coastguard Worker {
134*89c4ff92SAndroid Build Coastguard Worker PermuteLoop(dstShape, mappings).Unroll(src, dst, dataTypeSize);
135*89c4ff92SAndroid Build Coastguard Worker }
136*89c4ff92SAndroid Build Coastguard Worker
137*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnUtils
138