xref: /aosp_15_r20/external/armnn/src/backends/backendsCommon/test/Convolution2dEndToEndTestImpl.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
#include "EndToEndTestImpl.hpp"
#include <armnnUtils/QuantizeHelper.hpp>

#include <ResolveType.hpp>

#include <CommonTestUtils.hpp>
#include <armnnTestUtils/DataLayoutUtils.hpp>

#include <map>
#include <type_traits>
#include <vector>
17 
18 namespace
19 {
20 
/// Builds a network of the form: input -> Convolution2d -> output, where the convolution's weights
/// (and, when biasEnabled is true, its bias) are supplied by constant layers.
///
/// Input slots of the convolution layer: 0 = input tensor, 1 = weights, 2 = bias (only when enabled).
///
/// @param descriptor   Convolution parameters forwarded to AddConvolution2dLayer.
/// @param inputInfo    TensorInfo passed to Connect for the input -> convolution connection.
/// @param weightsInfo  TensorInfo passed to Connect for the weights -> convolution connection.
/// @param biasInfo     TensorInfo passed to Connect for the bias -> convolution connection (ignored if !biasEnabled).
/// @param outputInfo   TensorInfo passed to Connect for the convolution -> output connection.
/// @param weights      Constant weights tensor.
/// @param biases       Constant bias tensor (ignored if !biasEnabled).
/// @param biasEnabled  Whether to add and connect a constant bias layer.
/// @return The constructed network.
armnn::INetworkPtr CreateConstConvolution2dNetwork(const armnn::Convolution2dDescriptor& descriptor,
                                                   const armnn::TensorInfo& inputInfo,
                                                   const armnn::TensorInfo& weightsInfo,
                                                   const armnn::TensorInfo& biasInfo,
                                                   const armnn::TensorInfo& outputInfo,
                                                   const armnn::ConstTensor& weights,
                                                   const armnn::ConstTensor& biases,
                                                   bool biasEnabled)
{
    using namespace armnn;

    INetworkPtr net = INetwork::Create();

    auto* const inputLayer   = net->AddInputLayer(0, "input");
    auto* const weightsLayer = net->AddConstantLayer(weights, "Weights");
    auto* const convLayer    = net->AddConvolution2dLayer(descriptor, "convolution2d");
    auto* const outputLayer  = net->AddOutputLayer(0, "output");

    // Data goes into slot 0 of the convolution, the constant weights into slot 1.
    Connect(inputLayer,   convLayer, inputInfo,   0, 0);
    Connect(weightsLayer, convLayer, weightsInfo, 0, 1);

    // The optional constant bias feeds slot 2.
    if (biasEnabled)
    {
        Connect(net->AddConstantLayer(biases, "Bias"), convLayer, biasInfo, 0, 2);
    }

    Connect(convLayer, outputLayer, outputInfo, 0, 0);

    return net;
}
51 
52 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
Convolution2dEndToEnd(const std::vector<armnn::BackendId> & backends,armnn::DataLayout dataLayout,bool biasEnabled=true)53 void Convolution2dEndToEnd(const std::vector<armnn::BackendId>& backends,
54                            armnn::DataLayout dataLayout,
55                            bool biasEnabled = true)
56 {
57     using namespace armnn;
58 
59     const float   qScale  = IsQuantizedType<T>() ? 0.25f : 1.0f;
60     const int32_t qOffset = IsQuantizedType<T>() ? 50    : 0;
61 
62     TensorInfo inputInfo({ 1, 5, 5, 1 }, ArmnnType, qScale, qOffset, true);
63     TensorInfo outputInfo({ 1, 3, 3, 1 }, ArmnnType, qScale, qOffset);
64     TensorInfo weightsInfo({ 1, 3, 3, 1 }, ArmnnType, qScale, qOffset, true);
65     TensorInfo biasesInfo({ 1 }, ArmnnType, qScale * qScale, 0, true);
66 
67     std::vector<float> inputData =
68     {
69         1.0f, 5.0f, 2.0f, 3.0f, 5.0f,
70         8.0f, 7.0f, 3.0f, 6.0f, 3.0f,
71         3.0f, 3.0f, 9.0f, 1.0f, 9.0f,
72         4.0f, 1.0f, 8.0f, 1.0f, 3.0f,
73         6.0f, 8.0f, 1.0f, 9.0f, 2.0f
74     };
75 
76     std::vector<float> weightsData =
77     {
78         4.0f, 5.0f, 6.0f,
79         0.0f, 0.0f, 0.0f,
80         3.0f, 2.0f, 1.0f
81     };
82 
83     std::vector<float> biasesData = { 1.0f };
84 
85     float bias = biasEnabled ? biasesData[0] : 0.0f;
86     std::vector<float> expectedOutputData =
87     {
88         65.0f + bias,  76.0f + bias,  91.0f + bias,
89         107.0f + bias, 99.0f + bias,  89.0f + bias,
90         116.0f + bias, 98.0f + bias,  118.0f + bias,
91     };
92 
93     Convolution2dDescriptor descriptor;
94     descriptor.m_PadLeft     = 0;
95     descriptor.m_PadRight    = 0;
96     descriptor.m_PadTop      = 0;
97     descriptor.m_PadBottom   = 0;
98     descriptor.m_StrideX     = 1;
99     descriptor.m_StrideY     = 1;
100     descriptor.m_BiasEnabled = biasEnabled;
101     descriptor.m_DataLayout  = dataLayout;
102 
103     if (dataLayout == DataLayout::NCHW)
104     {
105         PermuteTensorNhwcToNchw(inputInfo, inputData);
106         PermuteTensorNhwcToNchw(weightsInfo, weightsData);
107         PermuteTensorNhwcToNchw(outputInfo, expectedOutputData);
108     }
109 
110     // Quantize data
111     std::vector<T> qInputData          = armnnUtils::QuantizedVector<T>(inputData, qScale, qOffset);
112     std::vector<T> qWeightsData        = armnnUtils::QuantizedVector<T>(weightsData, qScale, qOffset);
113     std::vector<T> qExpectedOutputData = armnnUtils::QuantizedVector<T>(expectedOutputData, qScale, qOffset);
114     std::vector<T> qBiasesData         = armnnUtils::QuantizedVector<T>(biasesData, qScale * qScale, 0);
115 
116     ConstTensor weights(weightsInfo, qWeightsData);
117     ConstTensor biases(biasesInfo, qBiasesData);
118 
119     INetworkPtr network = CreateConstConvolution2dNetwork(descriptor,
120                                                           inputInfo,
121                                                           weightsInfo,
122                                                           biasesInfo,
123                                                           outputInfo,
124                                                           weights,
125                                                           biases,
126                                                           biasEnabled);
127 
128     EndToEndLayerTestImpl<ArmnnType, ArmnnType>(std::move(network),
129                                                 {{ 0, qInputData }},
130                                                 {{ 0, qExpectedOutputData }},
131                                                 backends);
132 }
133 
134 } // anonymous namespace
135