1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker 
7*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Permute.hpp>
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/QuantizeHelper.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <ResolveType.hpp>
11*89c4ff92SAndroid Build Coastguard Worker 
12*89c4ff92SAndroid Build Coastguard Worker #include <CommonTestUtils.hpp>
13*89c4ff92SAndroid Build Coastguard Worker 
14*89c4ff92SAndroid Build Coastguard Worker #include <map>
15*89c4ff92SAndroid Build Coastguard Worker #include <vector>
16*89c4ff92SAndroid Build Coastguard Worker 
17*89c4ff92SAndroid Build Coastguard Worker namespace
18*89c4ff92SAndroid Build Coastguard Worker {
19*89c4ff92SAndroid Build Coastguard Worker 
CreateTransposeConvolution2dNetwork(const armnn::TransposeConvolution2dDescriptor & descriptor,const armnn::TensorInfo & inputInfo,const armnn::TensorInfo & outputInfo,const armnn::ConstTensor & weights,const armnn::Optional<armnn::ConstTensor> & biases)20*89c4ff92SAndroid Build Coastguard Worker INetworkPtr CreateTransposeConvolution2dNetwork(const armnn::TransposeConvolution2dDescriptor& descriptor,
21*89c4ff92SAndroid Build Coastguard Worker                                                 const armnn::TensorInfo& inputInfo,
22*89c4ff92SAndroid Build Coastguard Worker                                                 const armnn::TensorInfo& outputInfo,
23*89c4ff92SAndroid Build Coastguard Worker                                                 const armnn::ConstTensor& weights,
24*89c4ff92SAndroid Build Coastguard Worker                                                 const armnn::Optional<armnn::ConstTensor>& biases)
25*89c4ff92SAndroid Build Coastguard Worker {
26*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
27*89c4ff92SAndroid Build Coastguard Worker 
28*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr network(INetwork::Create());
29*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input = network->AddInputLayer(0, "input");
30*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* transposeConvolution2d =
31*89c4ff92SAndroid Build Coastguard Worker         network->AddTransposeConvolution2dLayer(descriptor, weights, biases, "transposeConvolution2d");
32*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* output = network->AddOutputLayer(0, "output");
33*89c4ff92SAndroid Build Coastguard Worker 
34*89c4ff92SAndroid Build Coastguard Worker     Connect(input, transposeConvolution2d, inputInfo, 0, 0);
35*89c4ff92SAndroid Build Coastguard Worker     Connect(transposeConvolution2d, output, outputInfo, 0, 0);
36*89c4ff92SAndroid Build Coastguard Worker 
37*89c4ff92SAndroid Build Coastguard Worker     return network;
38*89c4ff92SAndroid Build Coastguard Worker }
39*89c4ff92SAndroid Build Coastguard Worker 
40*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
41*89c4ff92SAndroid Build Coastguard Worker 
42*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, armnn::DataType ArmnnBType>
TransposeConvolution2dEndToEnd(const std::vector<armnn::BackendId> & backends,armnn::DataLayout dataLayout)43*89c4ff92SAndroid Build Coastguard Worker void TransposeConvolution2dEndToEnd(const std::vector<armnn::BackendId>& backends,
44*89c4ff92SAndroid Build Coastguard Worker                                     armnn::DataLayout dataLayout)
45*89c4ff92SAndroid Build Coastguard Worker {
46*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
47*89c4ff92SAndroid Build Coastguard Worker     using T = ResolveType<ArmnnType>;
48*89c4ff92SAndroid Build Coastguard Worker 
49*89c4ff92SAndroid Build Coastguard Worker     constexpr unsigned int batches  = 1u;
50*89c4ff92SAndroid Build Coastguard Worker     constexpr unsigned int channels = 1u;
51*89c4ff92SAndroid Build Coastguard Worker 
52*89c4ff92SAndroid Build Coastguard Worker     constexpr unsigned int wInput = 3u;
53*89c4ff92SAndroid Build Coastguard Worker     constexpr unsigned int hInput = wInput;
54*89c4ff92SAndroid Build Coastguard Worker 
55*89c4ff92SAndroid Build Coastguard Worker     constexpr unsigned int wOutput = 5u;
56*89c4ff92SAndroid Build Coastguard Worker     constexpr unsigned int hOutput = wOutput;
57*89c4ff92SAndroid Build Coastguard Worker 
58*89c4ff92SAndroid Build Coastguard Worker     constexpr unsigned int wWeights = 3u;
59*89c4ff92SAndroid Build Coastguard Worker     constexpr unsigned int hWeights = wWeights;
60*89c4ff92SAndroid Build Coastguard Worker 
61*89c4ff92SAndroid Build Coastguard Worker     TensorShape inputShape   = MakeTensorShape(batches, channels, hInput, wInput, dataLayout);
62*89c4ff92SAndroid Build Coastguard Worker     TensorShape outputShape  = MakeTensorShape(batches, channels, hOutput, wOutput, dataLayout);
63*89c4ff92SAndroid Build Coastguard Worker     TensorShape weightsShape = MakeTensorShape(batches, channels, hWeights, wWeights, dataLayout);
64*89c4ff92SAndroid Build Coastguard Worker 
65*89c4ff92SAndroid Build Coastguard Worker     const float   qScale  = IsQuantizedType<T>() ? 0.25f : 1.0f;
66*89c4ff92SAndroid Build Coastguard Worker     const int32_t qOffset = IsQuantizedType<T>() ? 50    : 0;
67*89c4ff92SAndroid Build Coastguard Worker 
68*89c4ff92SAndroid Build Coastguard Worker     TensorInfo inputInfo(inputShape, ArmnnType, qScale, qOffset, true);
69*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputInfo(outputShape, ArmnnType, qScale, qOffset);
70*89c4ff92SAndroid Build Coastguard Worker     TensorInfo weightsInfo(weightsShape, ArmnnType, qScale, qOffset, true);
71*89c4ff92SAndroid Build Coastguard Worker     TensorInfo biasesInfo({ channels }, ArmnnBType, qScale * qScale, 0, true);
72*89c4ff92SAndroid Build Coastguard Worker 
73*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData =
74*89c4ff92SAndroid Build Coastguard Worker     {
75*89c4ff92SAndroid Build Coastguard Worker        1.f, 1.f, 1.f,
76*89c4ff92SAndroid Build Coastguard Worker        1.f, 1.f, 1.f,
77*89c4ff92SAndroid Build Coastguard Worker        1.f, 1.f, 1.f
78*89c4ff92SAndroid Build Coastguard Worker     };
79*89c4ff92SAndroid Build Coastguard Worker 
80*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> weightsData =
81*89c4ff92SAndroid Build Coastguard Worker     {
82*89c4ff92SAndroid Build Coastguard Worker         1.f, 2.f, 3.f,
83*89c4ff92SAndroid Build Coastguard Worker         4.f, 5.f, 6.f,
84*89c4ff92SAndroid Build Coastguard Worker         7.f, 8.f, 9.f
85*89c4ff92SAndroid Build Coastguard Worker     };
86*89c4ff92SAndroid Build Coastguard Worker 
87*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> biasesData = { 1.f };
88*89c4ff92SAndroid Build Coastguard Worker 
89*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutputData =
90*89c4ff92SAndroid Build Coastguard Worker     {
91*89c4ff92SAndroid Build Coastguard Worker          6.f, 11.f,  6.f, 11.f,  6.f,
92*89c4ff92SAndroid Build Coastguard Worker         11.f, 21.f, 11.f, 21.f, 11.f,
93*89c4ff92SAndroid Build Coastguard Worker          6.f, 11.f,  6.f, 11.f,  6.f,
94*89c4ff92SAndroid Build Coastguard Worker         11.f, 21.f, 11.f, 21.f, 11.f,
95*89c4ff92SAndroid Build Coastguard Worker          6.f, 11.f,  6.f, 11.f,  6.f
96*89c4ff92SAndroid Build Coastguard Worker     };
97*89c4ff92SAndroid Build Coastguard Worker 
98*89c4ff92SAndroid Build Coastguard Worker     TransposeConvolution2dDescriptor descriptor;
99*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadLeft     = 1;
100*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadRight    = 1;
101*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadTop      = 1;
102*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadBottom   = 1;
103*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_StrideX     = 2;
104*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_StrideY     = 2;
105*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_BiasEnabled = true;
106*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_DataLayout  = dataLayout;
107*89c4ff92SAndroid Build Coastguard Worker 
108*89c4ff92SAndroid Build Coastguard Worker     // swizzle data if needed
109*89c4ff92SAndroid Build Coastguard Worker     if (dataLayout == armnn::DataLayout::NHWC)
110*89c4ff92SAndroid Build Coastguard Worker     {
111*89c4ff92SAndroid Build Coastguard Worker         constexpr size_t dataTypeSize = sizeof(float);
112*89c4ff92SAndroid Build Coastguard Worker         const armnn::PermutationVector nchwToNhwc = { 0, 3, 1, 2 };
113*89c4ff92SAndroid Build Coastguard Worker 
114*89c4ff92SAndroid Build Coastguard Worker         std::vector<float> tmp(inputData.size());
115*89c4ff92SAndroid Build Coastguard Worker         armnnUtils::Permute(inputInfo.GetShape(), nchwToNhwc, inputData.data(), tmp.data(), dataTypeSize);
116*89c4ff92SAndroid Build Coastguard Worker         inputData = tmp;
117*89c4ff92SAndroid Build Coastguard Worker 
118*89c4ff92SAndroid Build Coastguard Worker         tmp.resize(weightsData.size());
119*89c4ff92SAndroid Build Coastguard Worker         armnnUtils::Permute(weightsInfo.GetShape(), nchwToNhwc, weightsData.data(), tmp.data(), dataTypeSize);
120*89c4ff92SAndroid Build Coastguard Worker         weightsData = tmp;
121*89c4ff92SAndroid Build Coastguard Worker 
122*89c4ff92SAndroid Build Coastguard Worker         tmp.resize(expectedOutputData.size());
123*89c4ff92SAndroid Build Coastguard Worker         armnnUtils::Permute(outputInfo.GetShape(), nchwToNhwc, expectedOutputData.data(), tmp.data(), dataTypeSize);
124*89c4ff92SAndroid Build Coastguard Worker         expectedOutputData = tmp;
125*89c4ff92SAndroid Build Coastguard Worker     }
126*89c4ff92SAndroid Build Coastguard Worker 
127*89c4ff92SAndroid Build Coastguard Worker     // quantize data
128*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> qInputData          = armnnUtils::QuantizedVector<T>(inputData, qScale, qOffset);
129*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> qWeightsData        = armnnUtils::QuantizedVector<T>(weightsData, qScale, qOffset);
130*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> qExpectedOutputData = armnnUtils::QuantizedVector<T>(expectedOutputData, qScale, qOffset);
131*89c4ff92SAndroid Build Coastguard Worker 
132*89c4ff92SAndroid Build Coastguard Worker     using BT = ResolveType<ArmnnBType>;
133*89c4ff92SAndroid Build Coastguard Worker     std::vector<BT> qBiasesData = armnnUtils::QuantizedVector<BT>(biasesData, qScale * qScale, 0);
134*89c4ff92SAndroid Build Coastguard Worker 
135*89c4ff92SAndroid Build Coastguard Worker     ConstTensor weights(weightsInfo, qWeightsData);
136*89c4ff92SAndroid Build Coastguard Worker     ConstTensor biases(biasesInfo, qBiasesData);
137*89c4ff92SAndroid Build Coastguard Worker 
138*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr network = CreateTransposeConvolution2dNetwork(descriptor,
139*89c4ff92SAndroid Build Coastguard Worker                                                               inputInfo,
140*89c4ff92SAndroid Build Coastguard Worker                                                               outputInfo,
141*89c4ff92SAndroid Build Coastguard Worker                                                               weights,
142*89c4ff92SAndroid Build Coastguard Worker                                                               Optional<ConstTensor>(biases));
143*89c4ff92SAndroid Build Coastguard Worker 
144*89c4ff92SAndroid Build Coastguard Worker 
145*89c4ff92SAndroid Build Coastguard Worker     EndToEndLayerTestImpl<ArmnnType, ArmnnType>(std::move(network),
146*89c4ff92SAndroid Build Coastguard Worker                                                 { { 0, qInputData } },
147*89c4ff92SAndroid Build Coastguard Worker                                                 { { 0, qExpectedOutputData } },
148*89c4ff92SAndroid Build Coastguard Worker                                                 backends);
149*89c4ff92SAndroid Build Coastguard Worker }
150*89c4ff92SAndroid Build Coastguard Worker 
151*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, armnn::DataType ArmnnBType>
SimpleTransposeConvolution2dEndToEnd(const std::vector<armnn::BackendId> & backends,armnn::DataLayout dataLayout)152*89c4ff92SAndroid Build Coastguard Worker void SimpleTransposeConvolution2dEndToEnd(const std::vector<armnn::BackendId>& backends,
153*89c4ff92SAndroid Build Coastguard Worker                                           armnn::DataLayout dataLayout)
154*89c4ff92SAndroid Build Coastguard Worker {
155*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
156*89c4ff92SAndroid Build Coastguard Worker     using T = ResolveType<ArmnnType>;
157*89c4ff92SAndroid Build Coastguard Worker 
158*89c4ff92SAndroid Build Coastguard Worker     const float   qScale  = IsQuantizedType<T>() ? 0.25f : 1.0f;
159*89c4ff92SAndroid Build Coastguard Worker     const int32_t qOffset = IsQuantizedType<T>() ? 50    : 0;
160*89c4ff92SAndroid Build Coastguard Worker 
161*89c4ff92SAndroid Build Coastguard Worker     TensorInfo inputInfo({1, 2, 2, 1}, ArmnnType, qScale, qOffset, true);
162*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputInfo({1, 3, 3, 1}, ArmnnType, qScale, qOffset);
163*89c4ff92SAndroid Build Coastguard Worker     TensorInfo weightsInfo({1, 2, 2, 1}, ArmnnType, qScale, qOffset, true);
164*89c4ff92SAndroid Build Coastguard Worker     TensorInfo biasesInfo({ 1 }, ArmnnBType, qScale * qScale, 0, true);
165*89c4ff92SAndroid Build Coastguard Worker 
166*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData =
167*89c4ff92SAndroid Build Coastguard Worker     {
168*89c4ff92SAndroid Build Coastguard Worker         1, 2, 3, 4
169*89c4ff92SAndroid Build Coastguard Worker     };
170*89c4ff92SAndroid Build Coastguard Worker 
171*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> weightsData =
172*89c4ff92SAndroid Build Coastguard Worker     {
173*89c4ff92SAndroid Build Coastguard Worker         0, 1, 2, 4
174*89c4ff92SAndroid Build Coastguard Worker     };
175*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> biasesData = { 0.f };
176*89c4ff92SAndroid Build Coastguard Worker 
177*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutputData =
178*89c4ff92SAndroid Build Coastguard Worker     {
179*89c4ff92SAndroid Build Coastguard Worker         0, 1,  2,
180*89c4ff92SAndroid Build Coastguard Worker         2, 11, 12,
181*89c4ff92SAndroid Build Coastguard Worker         6, 20, 16
182*89c4ff92SAndroid Build Coastguard Worker     };
183*89c4ff92SAndroid Build Coastguard Worker 
184*89c4ff92SAndroid Build Coastguard Worker     TransposeConvolution2dDescriptor descriptor;
185*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadLeft     = 0;
186*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadRight    = 0;
187*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadTop      = 0;
188*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadBottom   = 0;
189*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_StrideX     = 1;
190*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_StrideY     = 1;
191*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_BiasEnabled = true;
192*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_DataLayout  = dataLayout;
193*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_OutputShapeEnabled = true;
194*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_OutputShape = { 1, 3, 3, 1 };
195*89c4ff92SAndroid Build Coastguard Worker 
196*89c4ff92SAndroid Build Coastguard Worker     // quantize data
197*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> qInputData          = armnnUtils::QuantizedVector<T>(inputData, qScale, qOffset);
198*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> qWeightsData        = armnnUtils::QuantizedVector<T>(weightsData, qScale, qOffset);
199*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> qExpectedOutputData = armnnUtils::QuantizedVector<T>(expectedOutputData, qScale, qOffset);
200*89c4ff92SAndroid Build Coastguard Worker 
201*89c4ff92SAndroid Build Coastguard Worker     using BT = ResolveType<ArmnnBType>;
202*89c4ff92SAndroid Build Coastguard Worker     std::vector<BT> qBiasesData = armnnUtils::QuantizedVector<BT>(biasesData, qScale * qScale, 0);
203*89c4ff92SAndroid Build Coastguard Worker 
204*89c4ff92SAndroid Build Coastguard Worker     ConstTensor weights(weightsInfo, qWeightsData);
205*89c4ff92SAndroid Build Coastguard Worker     ConstTensor biases(biasesInfo, qBiasesData);
206*89c4ff92SAndroid Build Coastguard Worker 
207*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr network = CreateTransposeConvolution2dNetwork(descriptor,
208*89c4ff92SAndroid Build Coastguard Worker                                                               inputInfo,
209*89c4ff92SAndroid Build Coastguard Worker                                                               outputInfo,
210*89c4ff92SAndroid Build Coastguard Worker                                                               weights,
211*89c4ff92SAndroid Build Coastguard Worker                                                               Optional<ConstTensor>(biases));
212*89c4ff92SAndroid Build Coastguard Worker 
213*89c4ff92SAndroid Build Coastguard Worker     EndToEndLayerTestImpl<ArmnnType, ArmnnType>(std::move(network),
214*89c4ff92SAndroid Build Coastguard Worker                                                 { { 0, qInputData } },
215*89c4ff92SAndroid Build Coastguard Worker                                                 { { 0, qExpectedOutputData } },
216*89c4ff92SAndroid Build Coastguard Worker                                                 backends);
217*89c4ff92SAndroid Build Coastguard Worker }
218