xref: /aosp_15_r20/external/armnn/src/armnnUtils/Permute.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Tensor.hpp>
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Permute.hpp>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include "Half.hpp"
11*89c4ff92SAndroid Build Coastguard Worker 
12*89c4ff92SAndroid Build Coastguard Worker #include <cassert>
13*89c4ff92SAndroid Build Coastguard Worker #include <cstring>
14*89c4ff92SAndroid Build Coastguard Worker 
15*89c4ff92SAndroid Build Coastguard Worker namespace
16*89c4ff92SAndroid Build Coastguard Worker {
17*89c4ff92SAndroid Build Coastguard Worker 
18*89c4ff92SAndroid Build Coastguard Worker class PermuteLoop
19*89c4ff92SAndroid Build Coastguard Worker {
20*89c4ff92SAndroid Build Coastguard Worker public:
21*89c4ff92SAndroid Build Coastguard Worker     using size_type = unsigned int;
22*89c4ff92SAndroid Build Coastguard Worker 
PermuteLoop(const armnn::TensorShape & dstShape,const armnn::PermutationVector & mappings)23*89c4ff92SAndroid Build Coastguard Worker     PermuteLoop(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings)
24*89c4ff92SAndroid Build Coastguard Worker         : m_DstShape(dstShape)
25*89c4ff92SAndroid Build Coastguard Worker     {
26*89c4ff92SAndroid Build Coastguard Worker         assert(dstShape.GetNumDimensions() == mappings.GetSize());
27*89c4ff92SAndroid Build Coastguard Worker 
28*89c4ff92SAndroid Build Coastguard Worker         const size_type numDims = dstShape.GetNumDimensions();
29*89c4ff92SAndroid Build Coastguard Worker 
30*89c4ff92SAndroid Build Coastguard Worker         size_type srcStride = 1U;
31*89c4ff92SAndroid Build Coastguard Worker         size_type dstStride = 1U;
32*89c4ff92SAndroid Build Coastguard Worker 
33*89c4ff92SAndroid Build Coastguard Worker         for (size_type i = numDims - 1U, k = 0U; k < numDims; ++k, --i)
34*89c4ff92SAndroid Build Coastguard Worker         {
35*89c4ff92SAndroid Build Coastguard Worker             m_SrcStrides[mappings[i]] = srcStride;
36*89c4ff92SAndroid Build Coastguard Worker             m_DstStrides[i] = dstStride;
37*89c4ff92SAndroid Build Coastguard Worker 
38*89c4ff92SAndroid Build Coastguard Worker             srcStride *= dstShape[mappings[i]];
39*89c4ff92SAndroid Build Coastguard Worker             dstStride *= dstShape[i];
40*89c4ff92SAndroid Build Coastguard Worker         }
41*89c4ff92SAndroid Build Coastguard Worker     }
42*89c4ff92SAndroid Build Coastguard Worker 
Unroll(const void * srcData,void * dstData,size_t dataTypeSize)43*89c4ff92SAndroid Build Coastguard Worker     void Unroll(const void* srcData, void* dstData, size_t dataTypeSize)
44*89c4ff92SAndroid Build Coastguard Worker     {
45*89c4ff92SAndroid Build Coastguard Worker         assert(srcData);
46*89c4ff92SAndroid Build Coastguard Worker         assert(dstData);
47*89c4ff92SAndroid Build Coastguard Worker         assert(dataTypeSize > 0);
48*89c4ff92SAndroid Build Coastguard Worker 
49*89c4ff92SAndroid Build Coastguard Worker         const unsigned char* srcDataPtr = reinterpret_cast<const unsigned char*>(srcData);
50*89c4ff92SAndroid Build Coastguard Worker         unsigned char* dstDataPtr       = reinterpret_cast<unsigned char*>(dstData);
51*89c4ff92SAndroid Build Coastguard Worker 
52*89c4ff92SAndroid Build Coastguard Worker         const unsigned char* const srcEndPtr = srcDataPtr + m_DstShape.GetNumElements() * dataTypeSize;
53*89c4ff92SAndroid Build Coastguard Worker         unsigned char* const       dstEndPtr = dstDataPtr + m_DstShape.GetNumElements() * dataTypeSize;
54*89c4ff92SAndroid Build Coastguard Worker 
55*89c4ff92SAndroid Build Coastguard Worker         Unroll(0, srcDataPtr, dstDataPtr, srcEndPtr, dstEndPtr, dataTypeSize);
56*89c4ff92SAndroid Build Coastguard Worker     }
57*89c4ff92SAndroid Build Coastguard Worker 
58*89c4ff92SAndroid Build Coastguard Worker private:
Unroll(size_type dimension,const unsigned char * srcData,unsigned char * dstData,const unsigned char * srcEnd,unsigned char * dstEnd,size_t dataTypeSize)59*89c4ff92SAndroid Build Coastguard Worker     void Unroll(size_type dimension,
60*89c4ff92SAndroid Build Coastguard Worker                 const unsigned char* srcData, unsigned char* dstData,
61*89c4ff92SAndroid Build Coastguard Worker                 const unsigned char* srcEnd, unsigned char* dstEnd,
62*89c4ff92SAndroid Build Coastguard Worker                 size_t dataTypeSize)
63*89c4ff92SAndroid Build Coastguard Worker     {
64*89c4ff92SAndroid Build Coastguard Worker         assert(srcData);
65*89c4ff92SAndroid Build Coastguard Worker         assert(dstData);
66*89c4ff92SAndroid Build Coastguard Worker         assert(srcEnd);
67*89c4ff92SAndroid Build Coastguard Worker         assert(dstEnd);
68*89c4ff92SAndroid Build Coastguard Worker         assert(srcData < srcEnd);
69*89c4ff92SAndroid Build Coastguard Worker         assert(dstData < dstEnd);
70*89c4ff92SAndroid Build Coastguard Worker         assert(dataTypeSize > 0);
71*89c4ff92SAndroid Build Coastguard Worker 
72*89c4ff92SAndroid Build Coastguard Worker         if (dimension >= m_DstShape.GetNumDimensions())
73*89c4ff92SAndroid Build Coastguard Worker         {
74*89c4ff92SAndroid Build Coastguard Worker             ::memcpy(dstData, srcData, dataTypeSize);
75*89c4ff92SAndroid Build Coastguard Worker         }
76*89c4ff92SAndroid Build Coastguard Worker         else
77*89c4ff92SAndroid Build Coastguard Worker         {
78*89c4ff92SAndroid Build Coastguard Worker             for (size_type i = 0; i < m_DstShape[dimension]; i++)
79*89c4ff92SAndroid Build Coastguard Worker             {
80*89c4ff92SAndroid Build Coastguard Worker                 Unroll(dimension + 1, srcData, dstData, srcEnd, dstEnd, dataTypeSize);
81*89c4ff92SAndroid Build Coastguard Worker 
82*89c4ff92SAndroid Build Coastguard Worker                 srcData += m_SrcStrides[dimension] * dataTypeSize;
83*89c4ff92SAndroid Build Coastguard Worker                 dstData += m_DstStrides[dimension] * dataTypeSize;
84*89c4ff92SAndroid Build Coastguard Worker             }
85*89c4ff92SAndroid Build Coastguard Worker         }
86*89c4ff92SAndroid Build Coastguard Worker     }
87*89c4ff92SAndroid Build Coastguard Worker 
88*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape m_DstShape;
89*89c4ff92SAndroid Build Coastguard Worker     std::array<size_type, armnn::MaxNumOfTensorDimensions> m_SrcStrides;
90*89c4ff92SAndroid Build Coastguard Worker     std::array<size_type, armnn::MaxNumOfTensorDimensions> m_DstStrides;
91*89c4ff92SAndroid Build Coastguard Worker };
92*89c4ff92SAndroid Build Coastguard Worker 
93*89c4ff92SAndroid Build Coastguard Worker } // namespace
94*89c4ff92SAndroid Build Coastguard Worker 
95*89c4ff92SAndroid Build Coastguard Worker namespace armnnUtils
96*89c4ff92SAndroid Build Coastguard Worker {
97*89c4ff92SAndroid Build Coastguard Worker 
/// Returns the shape obtained by moving each source dimension d to position
/// mappings[d] of the result.
/// @param srcShape  Shape to permute; its rank must equal mappings.GetSize().
/// @param mappings  Destination index for every source dimension.
armnn::TensorShape Permuted(const armnn::TensorShape& srcShape,
                            const armnn::PermutationVector& mappings)
{
    assert(srcShape.GetNumDimensions() == mappings.GetSize());

    const unsigned int rank = mappings.GetSize();

    // Only the first 'rank' entries are filled in and read by the TensorShape constructor.
    unsigned int dims[armnn::MaxNumOfTensorDimensions];

    unsigned int srcDim = 0U;
    while (srcDim < rank)
    {
        dims[mappings[srcDim]] = srcShape[srcDim];
        ++srcDim;
    }

    return armnn::TensorShape(rank, dims);
}
114*89c4ff92SAndroid Build Coastguard Worker 
/// Returns a copy of 'info' whose shape is permuted by 'mappings'.
/// If 'info' has a quantization dimension (per-axis quantization), that dimension
/// index is remapped the same way, so it keeps referring to the same data axis.
armnn::TensorInfo Permuted(const armnn::TensorInfo& info,
                           const armnn::PermutationVector& mappings)
{
    armnn::TensorInfo result(info);
    result.SetShape(Permuted(info.GetShape(), mappings));

    const auto quantizationDim = info.GetQuantizationDim();
    if (quantizationDim.has_value())
    {
        // Source axis q moves to axis mappings[q] in the permuted tensor.
        result.SetQuantizationDim(mappings[quantizationDim.value()]);
    }

    return result;
}
130*89c4ff92SAndroid Build Coastguard Worker 
Permute(const armnn::TensorShape & dstShape,const armnn::PermutationVector & mappings,const void * src,void * dst,size_t dataTypeSize)131*89c4ff92SAndroid Build Coastguard Worker void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings,
132*89c4ff92SAndroid Build Coastguard Worker              const void* src, void* dst, size_t dataTypeSize)
133*89c4ff92SAndroid Build Coastguard Worker {
134*89c4ff92SAndroid Build Coastguard Worker     PermuteLoop(dstShape, mappings).Unroll(src, dst, dataTypeSize);
135*89c4ff92SAndroid Build Coastguard Worker }
136*89c4ff92SAndroid Build Coastguard Worker 
137*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnUtils
138