//
// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker
7*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Permute.hpp>
8*89c4ff92SAndroid Build Coastguard Worker
9*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/QuantizeHelper.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <ResolveType.hpp>
11*89c4ff92SAndroid Build Coastguard Worker
12*89c4ff92SAndroid Build Coastguard Worker #include <CommonTestUtils.hpp>
13*89c4ff92SAndroid Build Coastguard Worker
14*89c4ff92SAndroid Build Coastguard Worker #include <map>
15*89c4ff92SAndroid Build Coastguard Worker #include <vector>
16*89c4ff92SAndroid Build Coastguard Worker
17*89c4ff92SAndroid Build Coastguard Worker namespace
18*89c4ff92SAndroid Build Coastguard Worker {
19*89c4ff92SAndroid Build Coastguard Worker
CreateTransposeConvolution2dNetwork(const armnn::TransposeConvolution2dDescriptor & descriptor,const armnn::TensorInfo & inputInfo,const armnn::TensorInfo & outputInfo,const armnn::ConstTensor & weights,const armnn::Optional<armnn::ConstTensor> & biases)20*89c4ff92SAndroid Build Coastguard Worker INetworkPtr CreateTransposeConvolution2dNetwork(const armnn::TransposeConvolution2dDescriptor& descriptor,
21*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& inputInfo,
22*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& outputInfo,
23*89c4ff92SAndroid Build Coastguard Worker const armnn::ConstTensor& weights,
24*89c4ff92SAndroid Build Coastguard Worker const armnn::Optional<armnn::ConstTensor>& biases)
25*89c4ff92SAndroid Build Coastguard Worker {
26*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
27*89c4ff92SAndroid Build Coastguard Worker
28*89c4ff92SAndroid Build Coastguard Worker INetworkPtr network(INetwork::Create());
29*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input = network->AddInputLayer(0, "input");
30*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* transposeConvolution2d =
31*89c4ff92SAndroid Build Coastguard Worker network->AddTransposeConvolution2dLayer(descriptor, weights, biases, "transposeConvolution2d");
32*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* output = network->AddOutputLayer(0, "output");
33*89c4ff92SAndroid Build Coastguard Worker
34*89c4ff92SAndroid Build Coastguard Worker Connect(input, transposeConvolution2d, inputInfo, 0, 0);
35*89c4ff92SAndroid Build Coastguard Worker Connect(transposeConvolution2d, output, outputInfo, 0, 0);
36*89c4ff92SAndroid Build Coastguard Worker
37*89c4ff92SAndroid Build Coastguard Worker return network;
38*89c4ff92SAndroid Build Coastguard Worker }
39*89c4ff92SAndroid Build Coastguard Worker
40*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
41*89c4ff92SAndroid Build Coastguard Worker
42*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, armnn::DataType ArmnnBType>
TransposeConvolution2dEndToEnd(const std::vector<armnn::BackendId> & backends,armnn::DataLayout dataLayout)43*89c4ff92SAndroid Build Coastguard Worker void TransposeConvolution2dEndToEnd(const std::vector<armnn::BackendId>& backends,
44*89c4ff92SAndroid Build Coastguard Worker armnn::DataLayout dataLayout)
45*89c4ff92SAndroid Build Coastguard Worker {
46*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
47*89c4ff92SAndroid Build Coastguard Worker using T = ResolveType<ArmnnType>;
48*89c4ff92SAndroid Build Coastguard Worker
49*89c4ff92SAndroid Build Coastguard Worker constexpr unsigned int batches = 1u;
50*89c4ff92SAndroid Build Coastguard Worker constexpr unsigned int channels = 1u;
51*89c4ff92SAndroid Build Coastguard Worker
52*89c4ff92SAndroid Build Coastguard Worker constexpr unsigned int wInput = 3u;
53*89c4ff92SAndroid Build Coastguard Worker constexpr unsigned int hInput = wInput;
54*89c4ff92SAndroid Build Coastguard Worker
55*89c4ff92SAndroid Build Coastguard Worker constexpr unsigned int wOutput = 5u;
56*89c4ff92SAndroid Build Coastguard Worker constexpr unsigned int hOutput = wOutput;
57*89c4ff92SAndroid Build Coastguard Worker
58*89c4ff92SAndroid Build Coastguard Worker constexpr unsigned int wWeights = 3u;
59*89c4ff92SAndroid Build Coastguard Worker constexpr unsigned int hWeights = wWeights;
60*89c4ff92SAndroid Build Coastguard Worker
61*89c4ff92SAndroid Build Coastguard Worker TensorShape inputShape = MakeTensorShape(batches, channels, hInput, wInput, dataLayout);
62*89c4ff92SAndroid Build Coastguard Worker TensorShape outputShape = MakeTensorShape(batches, channels, hOutput, wOutput, dataLayout);
63*89c4ff92SAndroid Build Coastguard Worker TensorShape weightsShape = MakeTensorShape(batches, channels, hWeights, wWeights, dataLayout);
64*89c4ff92SAndroid Build Coastguard Worker
65*89c4ff92SAndroid Build Coastguard Worker const float qScale = IsQuantizedType<T>() ? 0.25f : 1.0f;
66*89c4ff92SAndroid Build Coastguard Worker const int32_t qOffset = IsQuantizedType<T>() ? 50 : 0;
67*89c4ff92SAndroid Build Coastguard Worker
68*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputInfo(inputShape, ArmnnType, qScale, qOffset, true);
69*89c4ff92SAndroid Build Coastguard Worker TensorInfo outputInfo(outputShape, ArmnnType, qScale, qOffset);
70*89c4ff92SAndroid Build Coastguard Worker TensorInfo weightsInfo(weightsShape, ArmnnType, qScale, qOffset, true);
71*89c4ff92SAndroid Build Coastguard Worker TensorInfo biasesInfo({ channels }, ArmnnBType, qScale * qScale, 0, true);
72*89c4ff92SAndroid Build Coastguard Worker
73*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData =
74*89c4ff92SAndroid Build Coastguard Worker {
75*89c4ff92SAndroid Build Coastguard Worker 1.f, 1.f, 1.f,
76*89c4ff92SAndroid Build Coastguard Worker 1.f, 1.f, 1.f,
77*89c4ff92SAndroid Build Coastguard Worker 1.f, 1.f, 1.f
78*89c4ff92SAndroid Build Coastguard Worker };
79*89c4ff92SAndroid Build Coastguard Worker
80*89c4ff92SAndroid Build Coastguard Worker std::vector<float> weightsData =
81*89c4ff92SAndroid Build Coastguard Worker {
82*89c4ff92SAndroid Build Coastguard Worker 1.f, 2.f, 3.f,
83*89c4ff92SAndroid Build Coastguard Worker 4.f, 5.f, 6.f,
84*89c4ff92SAndroid Build Coastguard Worker 7.f, 8.f, 9.f
85*89c4ff92SAndroid Build Coastguard Worker };
86*89c4ff92SAndroid Build Coastguard Worker
87*89c4ff92SAndroid Build Coastguard Worker std::vector<float> biasesData = { 1.f };
88*89c4ff92SAndroid Build Coastguard Worker
89*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutputData =
90*89c4ff92SAndroid Build Coastguard Worker {
91*89c4ff92SAndroid Build Coastguard Worker 6.f, 11.f, 6.f, 11.f, 6.f,
92*89c4ff92SAndroid Build Coastguard Worker 11.f, 21.f, 11.f, 21.f, 11.f,
93*89c4ff92SAndroid Build Coastguard Worker 6.f, 11.f, 6.f, 11.f, 6.f,
94*89c4ff92SAndroid Build Coastguard Worker 11.f, 21.f, 11.f, 21.f, 11.f,
95*89c4ff92SAndroid Build Coastguard Worker 6.f, 11.f, 6.f, 11.f, 6.f
96*89c4ff92SAndroid Build Coastguard Worker };
97*89c4ff92SAndroid Build Coastguard Worker
98*89c4ff92SAndroid Build Coastguard Worker TransposeConvolution2dDescriptor descriptor;
99*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadLeft = 1;
100*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadRight = 1;
101*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadTop = 1;
102*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadBottom = 1;
103*89c4ff92SAndroid Build Coastguard Worker descriptor.m_StrideX = 2;
104*89c4ff92SAndroid Build Coastguard Worker descriptor.m_StrideY = 2;
105*89c4ff92SAndroid Build Coastguard Worker descriptor.m_BiasEnabled = true;
106*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DataLayout = dataLayout;
107*89c4ff92SAndroid Build Coastguard Worker
108*89c4ff92SAndroid Build Coastguard Worker // swizzle data if needed
109*89c4ff92SAndroid Build Coastguard Worker if (dataLayout == armnn::DataLayout::NHWC)
110*89c4ff92SAndroid Build Coastguard Worker {
111*89c4ff92SAndroid Build Coastguard Worker constexpr size_t dataTypeSize = sizeof(float);
112*89c4ff92SAndroid Build Coastguard Worker const armnn::PermutationVector nchwToNhwc = { 0, 3, 1, 2 };
113*89c4ff92SAndroid Build Coastguard Worker
114*89c4ff92SAndroid Build Coastguard Worker std::vector<float> tmp(inputData.size());
115*89c4ff92SAndroid Build Coastguard Worker armnnUtils::Permute(inputInfo.GetShape(), nchwToNhwc, inputData.data(), tmp.data(), dataTypeSize);
116*89c4ff92SAndroid Build Coastguard Worker inputData = tmp;
117*89c4ff92SAndroid Build Coastguard Worker
118*89c4ff92SAndroid Build Coastguard Worker tmp.resize(weightsData.size());
119*89c4ff92SAndroid Build Coastguard Worker armnnUtils::Permute(weightsInfo.GetShape(), nchwToNhwc, weightsData.data(), tmp.data(), dataTypeSize);
120*89c4ff92SAndroid Build Coastguard Worker weightsData = tmp;
121*89c4ff92SAndroid Build Coastguard Worker
122*89c4ff92SAndroid Build Coastguard Worker tmp.resize(expectedOutputData.size());
123*89c4ff92SAndroid Build Coastguard Worker armnnUtils::Permute(outputInfo.GetShape(), nchwToNhwc, expectedOutputData.data(), tmp.data(), dataTypeSize);
124*89c4ff92SAndroid Build Coastguard Worker expectedOutputData = tmp;
125*89c4ff92SAndroid Build Coastguard Worker }
126*89c4ff92SAndroid Build Coastguard Worker
127*89c4ff92SAndroid Build Coastguard Worker // quantize data
128*89c4ff92SAndroid Build Coastguard Worker std::vector<T> qInputData = armnnUtils::QuantizedVector<T>(inputData, qScale, qOffset);
129*89c4ff92SAndroid Build Coastguard Worker std::vector<T> qWeightsData = armnnUtils::QuantizedVector<T>(weightsData, qScale, qOffset);
130*89c4ff92SAndroid Build Coastguard Worker std::vector<T> qExpectedOutputData = armnnUtils::QuantizedVector<T>(expectedOutputData, qScale, qOffset);
131*89c4ff92SAndroid Build Coastguard Worker
132*89c4ff92SAndroid Build Coastguard Worker using BT = ResolveType<ArmnnBType>;
133*89c4ff92SAndroid Build Coastguard Worker std::vector<BT> qBiasesData = armnnUtils::QuantizedVector<BT>(biasesData, qScale * qScale, 0);
134*89c4ff92SAndroid Build Coastguard Worker
135*89c4ff92SAndroid Build Coastguard Worker ConstTensor weights(weightsInfo, qWeightsData);
136*89c4ff92SAndroid Build Coastguard Worker ConstTensor biases(biasesInfo, qBiasesData);
137*89c4ff92SAndroid Build Coastguard Worker
138*89c4ff92SAndroid Build Coastguard Worker INetworkPtr network = CreateTransposeConvolution2dNetwork(descriptor,
139*89c4ff92SAndroid Build Coastguard Worker inputInfo,
140*89c4ff92SAndroid Build Coastguard Worker outputInfo,
141*89c4ff92SAndroid Build Coastguard Worker weights,
142*89c4ff92SAndroid Build Coastguard Worker Optional<ConstTensor>(biases));
143*89c4ff92SAndroid Build Coastguard Worker
144*89c4ff92SAndroid Build Coastguard Worker
145*89c4ff92SAndroid Build Coastguard Worker EndToEndLayerTestImpl<ArmnnType, ArmnnType>(std::move(network),
146*89c4ff92SAndroid Build Coastguard Worker { { 0, qInputData } },
147*89c4ff92SAndroid Build Coastguard Worker { { 0, qExpectedOutputData } },
148*89c4ff92SAndroid Build Coastguard Worker backends);
149*89c4ff92SAndroid Build Coastguard Worker }
150*89c4ff92SAndroid Build Coastguard Worker
151*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, armnn::DataType ArmnnBType>
SimpleTransposeConvolution2dEndToEnd(const std::vector<armnn::BackendId> & backends,armnn::DataLayout dataLayout)152*89c4ff92SAndroid Build Coastguard Worker void SimpleTransposeConvolution2dEndToEnd(const std::vector<armnn::BackendId>& backends,
153*89c4ff92SAndroid Build Coastguard Worker armnn::DataLayout dataLayout)
154*89c4ff92SAndroid Build Coastguard Worker {
155*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
156*89c4ff92SAndroid Build Coastguard Worker using T = ResolveType<ArmnnType>;
157*89c4ff92SAndroid Build Coastguard Worker
158*89c4ff92SAndroid Build Coastguard Worker const float qScale = IsQuantizedType<T>() ? 0.25f : 1.0f;
159*89c4ff92SAndroid Build Coastguard Worker const int32_t qOffset = IsQuantizedType<T>() ? 50 : 0;
160*89c4ff92SAndroid Build Coastguard Worker
161*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputInfo({1, 2, 2, 1}, ArmnnType, qScale, qOffset, true);
162*89c4ff92SAndroid Build Coastguard Worker TensorInfo outputInfo({1, 3, 3, 1}, ArmnnType, qScale, qOffset);
163*89c4ff92SAndroid Build Coastguard Worker TensorInfo weightsInfo({1, 2, 2, 1}, ArmnnType, qScale, qOffset, true);
164*89c4ff92SAndroid Build Coastguard Worker TensorInfo biasesInfo({ 1 }, ArmnnBType, qScale * qScale, 0, true);
165*89c4ff92SAndroid Build Coastguard Worker
166*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData =
167*89c4ff92SAndroid Build Coastguard Worker {
168*89c4ff92SAndroid Build Coastguard Worker 1, 2, 3, 4
169*89c4ff92SAndroid Build Coastguard Worker };
170*89c4ff92SAndroid Build Coastguard Worker
171*89c4ff92SAndroid Build Coastguard Worker std::vector<float> weightsData =
172*89c4ff92SAndroid Build Coastguard Worker {
173*89c4ff92SAndroid Build Coastguard Worker 0, 1, 2, 4
174*89c4ff92SAndroid Build Coastguard Worker };
175*89c4ff92SAndroid Build Coastguard Worker std::vector<float> biasesData = { 0.f };
176*89c4ff92SAndroid Build Coastguard Worker
177*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutputData =
178*89c4ff92SAndroid Build Coastguard Worker {
179*89c4ff92SAndroid Build Coastguard Worker 0, 1, 2,
180*89c4ff92SAndroid Build Coastguard Worker 2, 11, 12,
181*89c4ff92SAndroid Build Coastguard Worker 6, 20, 16
182*89c4ff92SAndroid Build Coastguard Worker };
183*89c4ff92SAndroid Build Coastguard Worker
184*89c4ff92SAndroid Build Coastguard Worker TransposeConvolution2dDescriptor descriptor;
185*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadLeft = 0;
186*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadRight = 0;
187*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadTop = 0;
188*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadBottom = 0;
189*89c4ff92SAndroid Build Coastguard Worker descriptor.m_StrideX = 1;
190*89c4ff92SAndroid Build Coastguard Worker descriptor.m_StrideY = 1;
191*89c4ff92SAndroid Build Coastguard Worker descriptor.m_BiasEnabled = true;
192*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DataLayout = dataLayout;
193*89c4ff92SAndroid Build Coastguard Worker descriptor.m_OutputShapeEnabled = true;
194*89c4ff92SAndroid Build Coastguard Worker descriptor.m_OutputShape = { 1, 3, 3, 1 };
195*89c4ff92SAndroid Build Coastguard Worker
196*89c4ff92SAndroid Build Coastguard Worker // quantize data
197*89c4ff92SAndroid Build Coastguard Worker std::vector<T> qInputData = armnnUtils::QuantizedVector<T>(inputData, qScale, qOffset);
198*89c4ff92SAndroid Build Coastguard Worker std::vector<T> qWeightsData = armnnUtils::QuantizedVector<T>(weightsData, qScale, qOffset);
199*89c4ff92SAndroid Build Coastguard Worker std::vector<T> qExpectedOutputData = armnnUtils::QuantizedVector<T>(expectedOutputData, qScale, qOffset);
200*89c4ff92SAndroid Build Coastguard Worker
201*89c4ff92SAndroid Build Coastguard Worker using BT = ResolveType<ArmnnBType>;
202*89c4ff92SAndroid Build Coastguard Worker std::vector<BT> qBiasesData = armnnUtils::QuantizedVector<BT>(biasesData, qScale * qScale, 0);
203*89c4ff92SAndroid Build Coastguard Worker
204*89c4ff92SAndroid Build Coastguard Worker ConstTensor weights(weightsInfo, qWeightsData);
205*89c4ff92SAndroid Build Coastguard Worker ConstTensor biases(biasesInfo, qBiasesData);
206*89c4ff92SAndroid Build Coastguard Worker
207*89c4ff92SAndroid Build Coastguard Worker INetworkPtr network = CreateTransposeConvolution2dNetwork(descriptor,
208*89c4ff92SAndroid Build Coastguard Worker inputInfo,
209*89c4ff92SAndroid Build Coastguard Worker outputInfo,
210*89c4ff92SAndroid Build Coastguard Worker weights,
211*89c4ff92SAndroid Build Coastguard Worker Optional<ConstTensor>(biases));
212*89c4ff92SAndroid Build Coastguard Worker
213*89c4ff92SAndroid Build Coastguard Worker EndToEndLayerTestImpl<ArmnnType, ArmnnType>(std::move(network),
214*89c4ff92SAndroid Build Coastguard Worker { { 0, qInputData } },
215*89c4ff92SAndroid Build Coastguard Worker { { 0, qExpectedOutputData } },
216*89c4ff92SAndroid Build Coastguard Worker backends);
217*89c4ff92SAndroid Build Coastguard Worker }
218