//
// Copyright © 2019 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>

#include <armnnUtils/Permute.hpp>

#include <vector>
12*89c4ff92SAndroid Build Coastguard Worker
13*89c4ff92SAndroid Build Coastguard Worker template<typename T>
PermuteTensorNchwToNhwc(armnn::TensorInfo & tensorInfo,std::vector<T> & tensorData)14*89c4ff92SAndroid Build Coastguard Worker void PermuteTensorNchwToNhwc(armnn::TensorInfo& tensorInfo, std::vector<T>& tensorData)
15*89c4ff92SAndroid Build Coastguard Worker {
16*89c4ff92SAndroid Build Coastguard Worker const armnn::PermutationVector nchwToNhwc = { 0, 3, 1, 2 };
17*89c4ff92SAndroid Build Coastguard Worker
18*89c4ff92SAndroid Build Coastguard Worker tensorInfo = armnnUtils::Permuted(tensorInfo, nchwToNhwc);
19*89c4ff92SAndroid Build Coastguard Worker
20*89c4ff92SAndroid Build Coastguard Worker std::vector<T> tmp(tensorData.size());
21*89c4ff92SAndroid Build Coastguard Worker armnnUtils::Permute(tensorInfo.GetShape(), nchwToNhwc, tensorData.data(), tmp.data(), sizeof(T));
22*89c4ff92SAndroid Build Coastguard Worker tensorData = tmp;
23*89c4ff92SAndroid Build Coastguard Worker }
24*89c4ff92SAndroid Build Coastguard Worker
25*89c4ff92SAndroid Build Coastguard Worker template<typename T>
PermuteTensorNhwcToNchw(armnn::TensorInfo & tensorInfo,std::vector<T> & tensorData)26*89c4ff92SAndroid Build Coastguard Worker void PermuteTensorNhwcToNchw(armnn::TensorInfo& tensorInfo, std::vector<T>& tensorData)
27*89c4ff92SAndroid Build Coastguard Worker {
28*89c4ff92SAndroid Build Coastguard Worker const armnn::PermutationVector nhwcToNchw = { 0, 2, 3, 1 };
29*89c4ff92SAndroid Build Coastguard Worker
30*89c4ff92SAndroid Build Coastguard Worker tensorInfo = armnnUtils::Permuted(tensorInfo, nhwcToNchw);
31*89c4ff92SAndroid Build Coastguard Worker
32*89c4ff92SAndroid Build Coastguard Worker std::vector<T> tmp(tensorData.size());
33*89c4ff92SAndroid Build Coastguard Worker armnnUtils::Permute(tensorInfo.GetShape(), nhwcToNchw, tensorData.data(), tmp.data(), sizeof(T));
34*89c4ff92SAndroid Build Coastguard Worker
35*89c4ff92SAndroid Build Coastguard Worker tensorData = tmp;
36*89c4ff92SAndroid Build Coastguard Worker }
37*89c4ff92SAndroid Build Coastguard Worker
38*89c4ff92SAndroid Build Coastguard Worker template<typename T>
PermuteTensorNdhwcToNcdhw(armnn::TensorInfo & tensorInfo,std::vector<T> & tensorData)39*89c4ff92SAndroid Build Coastguard Worker void PermuteTensorNdhwcToNcdhw(armnn::TensorInfo& tensorInfo, std::vector<T>& tensorData)
40*89c4ff92SAndroid Build Coastguard Worker {
41*89c4ff92SAndroid Build Coastguard Worker const armnn::PermutationVector ndhwcToNcdhw = { 0, 2, 3, 4, 1 };
42*89c4ff92SAndroid Build Coastguard Worker
43*89c4ff92SAndroid Build Coastguard Worker tensorInfo = armnnUtils::Permuted(tensorInfo, ndhwcToNcdhw);
44*89c4ff92SAndroid Build Coastguard Worker
45*89c4ff92SAndroid Build Coastguard Worker std::vector<T> tmp(tensorData.size());
46*89c4ff92SAndroid Build Coastguard Worker armnnUtils::Permute(tensorInfo.GetShape(), ndhwcToNcdhw, tensorData.data(), tmp.data(), sizeof(T));
47*89c4ff92SAndroid Build Coastguard Worker tensorData = tmp;
48*89c4ff92SAndroid Build Coastguard Worker }
49*89c4ff92SAndroid Build Coastguard Worker
50*89c4ff92SAndroid Build Coastguard Worker template<typename T>
PermuteTensorNcdhwToNdhwc(armnn::TensorInfo & tensorInfo,std::vector<T> & tensorData)51*89c4ff92SAndroid Build Coastguard Worker void PermuteTensorNcdhwToNdhwc(armnn::TensorInfo& tensorInfo, std::vector<T>& tensorData)
52*89c4ff92SAndroid Build Coastguard Worker {
53*89c4ff92SAndroid Build Coastguard Worker const armnn::PermutationVector ncdhwToNdhwc = { 0, 4, 1, 2, 3 };
54*89c4ff92SAndroid Build Coastguard Worker
55*89c4ff92SAndroid Build Coastguard Worker tensorInfo = armnnUtils::Permuted(tensorInfo, ncdhwToNdhwc);
56*89c4ff92SAndroid Build Coastguard Worker
57*89c4ff92SAndroid Build Coastguard Worker std::vector<T> tmp(tensorData.size());
58*89c4ff92SAndroid Build Coastguard Worker armnnUtils::Permute(tensorInfo.GetShape(), ncdhwToNdhwc, tensorData.data(), tmp.data(), sizeof(T));
59*89c4ff92SAndroid Build Coastguard Worker tensorData = tmp;
60*89c4ff92SAndroid Build Coastguard Worker }
61