// xref: /aosp_15_r20/external/armnn/delegate/test/FullyConnectedTest.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2020-2021,2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "FullyConnectedTestHelper.hpp"
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker namespace
9*89c4ff92SAndroid Build Coastguard Worker {
10*89c4ff92SAndroid Build Coastguard Worker 
FullyConnectedFp32Test(std::vector<armnn::BackendId> & backends,bool constantWeights=true)11*89c4ff92SAndroid Build Coastguard Worker void FullyConnectedFp32Test(std::vector<armnn::BackendId>& backends, bool constantWeights = true)
12*89c4ff92SAndroid Build Coastguard Worker {
13*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> inputTensorShape   { 1, 4, 1, 1 };
14*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> weightsTensorShape { 1, 4 };
15*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> biasTensorShape    { 1 };
16*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> outputTensorShape  { 1, 1 };
17*89c4ff92SAndroid Build Coastguard Worker 
18*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputValues = { 10, 20, 30, 40 };
19*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> weightsData = { 2, 3, 4, 5 };
20*89c4ff92SAndroid Build Coastguard Worker 
21*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutputValues = { (400 + 10) };
22*89c4ff92SAndroid Build Coastguard Worker 
23*89c4ff92SAndroid Build Coastguard Worker     // bias is set std::vector<float> biasData = { 10 } in the model
24*89c4ff92SAndroid Build Coastguard Worker     FullyConnectedTest<float>(backends,
25*89c4ff92SAndroid Build Coastguard Worker                               ::tflite::TensorType_FLOAT32,
26*89c4ff92SAndroid Build Coastguard Worker                               tflite::ActivationFunctionType_NONE,
27*89c4ff92SAndroid Build Coastguard Worker                               inputTensorShape,
28*89c4ff92SAndroid Build Coastguard Worker                               weightsTensorShape,
29*89c4ff92SAndroid Build Coastguard Worker                               biasTensorShape,
30*89c4ff92SAndroid Build Coastguard Worker                               outputTensorShape,
31*89c4ff92SAndroid Build Coastguard Worker                               inputValues,
32*89c4ff92SAndroid Build Coastguard Worker                               expectedOutputValues,
33*89c4ff92SAndroid Build Coastguard Worker                               weightsData,
34*89c4ff92SAndroid Build Coastguard Worker                               constantWeights);
35*89c4ff92SAndroid Build Coastguard Worker }
36*89c4ff92SAndroid Build Coastguard Worker 
FullyConnectedActivationTest(std::vector<armnn::BackendId> & backends,bool constantWeights=true)37*89c4ff92SAndroid Build Coastguard Worker void FullyConnectedActivationTest(std::vector<armnn::BackendId>& backends, bool constantWeights = true)
38*89c4ff92SAndroid Build Coastguard Worker {
39*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> inputTensorShape   { 1, 4, 1, 1 };
40*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> weightsTensorShape { 1, 4 };
41*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> biasTensorShape    { 1 };
42*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> outputTensorShape  { 1, 1 };
43*89c4ff92SAndroid Build Coastguard Worker 
44*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputValues = { -10, 20, 30, 40 };
45*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> weightsData = { 2, 3, 4, -5 };
46*89c4ff92SAndroid Build Coastguard Worker 
47*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutputValues = { 0 };
48*89c4ff92SAndroid Build Coastguard Worker 
49*89c4ff92SAndroid Build Coastguard Worker     // bias is set std::vector<float> biasData = { 10 } in the model
50*89c4ff92SAndroid Build Coastguard Worker     FullyConnectedTest<float>(backends,
51*89c4ff92SAndroid Build Coastguard Worker                               ::tflite::TensorType_FLOAT32,
52*89c4ff92SAndroid Build Coastguard Worker                               tflite::ActivationFunctionType_RELU,
53*89c4ff92SAndroid Build Coastguard Worker                               inputTensorShape,
54*89c4ff92SAndroid Build Coastguard Worker                               weightsTensorShape,
55*89c4ff92SAndroid Build Coastguard Worker                               biasTensorShape,
56*89c4ff92SAndroid Build Coastguard Worker                               outputTensorShape,
57*89c4ff92SAndroid Build Coastguard Worker                               inputValues,
58*89c4ff92SAndroid Build Coastguard Worker                               expectedOutputValues,
59*89c4ff92SAndroid Build Coastguard Worker                               weightsData,
60*89c4ff92SAndroid Build Coastguard Worker                               constantWeights);
61*89c4ff92SAndroid Build Coastguard Worker }
62*89c4ff92SAndroid Build Coastguard Worker 
FullyConnectedInt8Test(std::vector<armnn::BackendId> & backends,bool constantWeights=true)63*89c4ff92SAndroid Build Coastguard Worker void FullyConnectedInt8Test(std::vector<armnn::BackendId>& backends, bool constantWeights = true)
64*89c4ff92SAndroid Build Coastguard Worker {
65*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> inputTensorShape   { 1, 4, 2, 1 };
66*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> weightsTensorShape { 1, 4 };
67*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> biasTensorShape    { 1 };
68*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> outputTensorShape  { 2, 1 };
69*89c4ff92SAndroid Build Coastguard Worker 
70*89c4ff92SAndroid Build Coastguard Worker     std::vector<int8_t> inputValues = { 1, 2, 3, 4, 5, 10, 15, 20 };
71*89c4ff92SAndroid Build Coastguard Worker     std::vector<int8_t> weightsData = { 2, 3, 4, 5 };
72*89c4ff92SAndroid Build Coastguard Worker 
73*89c4ff92SAndroid Build Coastguard Worker     std::vector<int8_t> expectedOutputValues = { 25, 105 };  // (40 + 10) / 2, (200 + 10) / 2
74*89c4ff92SAndroid Build Coastguard Worker 
75*89c4ff92SAndroid Build Coastguard Worker     // bias is set std::vector<int32_t> biasData = { 10 } in the model
76*89c4ff92SAndroid Build Coastguard Worker     // input and weights quantization scale 1.0f and offset 0 in the model
77*89c4ff92SAndroid Build Coastguard Worker     // output quantization scale 2.0f and offset 0 in the model
78*89c4ff92SAndroid Build Coastguard Worker     FullyConnectedTest<int8_t>(backends,
79*89c4ff92SAndroid Build Coastguard Worker                                 ::tflite::TensorType_INT8,
80*89c4ff92SAndroid Build Coastguard Worker                                 tflite::ActivationFunctionType_NONE,
81*89c4ff92SAndroid Build Coastguard Worker                                 inputTensorShape,
82*89c4ff92SAndroid Build Coastguard Worker                                 weightsTensorShape,
83*89c4ff92SAndroid Build Coastguard Worker                                 biasTensorShape,
84*89c4ff92SAndroid Build Coastguard Worker                                 outputTensorShape,
85*89c4ff92SAndroid Build Coastguard Worker                                 inputValues,
86*89c4ff92SAndroid Build Coastguard Worker                                 expectedOutputValues,
87*89c4ff92SAndroid Build Coastguard Worker                                 weightsData,
88*89c4ff92SAndroid Build Coastguard Worker                                 constantWeights);
89*89c4ff92SAndroid Build Coastguard Worker }
90*89c4ff92SAndroid Build Coastguard Worker 
91*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("FullyConnected_GpuAccTests")
92*89c4ff92SAndroid Build Coastguard Worker {
93*89c4ff92SAndroid Build Coastguard Worker 
94*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("FullyConnected_FP32_GpuAcc_Test")
95*89c4ff92SAndroid Build Coastguard Worker {
96*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = { armnn::Compute::GpuAcc };
97*89c4ff92SAndroid Build Coastguard Worker     FullyConnectedFp32Test(backends);
98*89c4ff92SAndroid Build Coastguard Worker }
99*89c4ff92SAndroid Build Coastguard Worker 
100*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("FullyConnected_Int8_GpuAcc_Test")
101*89c4ff92SAndroid Build Coastguard Worker {
102*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = { armnn::Compute::GpuAcc };
103*89c4ff92SAndroid Build Coastguard Worker     FullyConnectedInt8Test(backends);
104*89c4ff92SAndroid Build Coastguard Worker }
105*89c4ff92SAndroid Build Coastguard Worker 
106*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("FullyConnected_Activation_GpuAcc_Test")
107*89c4ff92SAndroid Build Coastguard Worker {
108*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = { armnn::Compute::GpuAcc };
109*89c4ff92SAndroid Build Coastguard Worker     FullyConnectedActivationTest(backends);
110*89c4ff92SAndroid Build Coastguard Worker }
111*89c4ff92SAndroid Build Coastguard Worker 
112*89c4ff92SAndroid Build Coastguard Worker } // End of TEST_SUITE("FullyConnected_GpuAccTests")
113*89c4ff92SAndroid Build Coastguard Worker 
114*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("FullyConnected_CpuAccTests")
115*89c4ff92SAndroid Build Coastguard Worker {
116*89c4ff92SAndroid Build Coastguard Worker 
117*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("FullyConnected_FP32_CpuAcc_Test")
118*89c4ff92SAndroid Build Coastguard Worker {
119*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = { armnn::Compute::CpuAcc };
120*89c4ff92SAndroid Build Coastguard Worker     FullyConnectedFp32Test(backends);
121*89c4ff92SAndroid Build Coastguard Worker }
122*89c4ff92SAndroid Build Coastguard Worker 
123*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("FullyConnected_Int8_CpuAcc_Test")
124*89c4ff92SAndroid Build Coastguard Worker {
125*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = { armnn::Compute::CpuAcc };
126*89c4ff92SAndroid Build Coastguard Worker     FullyConnectedInt8Test(backends);
127*89c4ff92SAndroid Build Coastguard Worker }
128*89c4ff92SAndroid Build Coastguard Worker 
129*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("FullyConnected_Activation_CpuAcc_Test")
130*89c4ff92SAndroid Build Coastguard Worker {
131*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = { armnn::Compute::CpuAcc };
132*89c4ff92SAndroid Build Coastguard Worker     FullyConnectedActivationTest(backends);
133*89c4ff92SAndroid Build Coastguard Worker }
134*89c4ff92SAndroid Build Coastguard Worker 
135*89c4ff92SAndroid Build Coastguard Worker } // End of TEST_SUITE("FullyConnected_CpuAccTests")
136*89c4ff92SAndroid Build Coastguard Worker 
137*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("FullyConnected_CpuRefTests")
138*89c4ff92SAndroid Build Coastguard Worker {
139*89c4ff92SAndroid Build Coastguard Worker 
140*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("FullyConnected_FP32_CpuRef_Test")
141*89c4ff92SAndroid Build Coastguard Worker {
142*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
143*89c4ff92SAndroid Build Coastguard Worker     FullyConnectedFp32Test(backends);
144*89c4ff92SAndroid Build Coastguard Worker }
145*89c4ff92SAndroid Build Coastguard Worker 
146*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("FullyConnected_Int8_CpuRef_Test")
147*89c4ff92SAndroid Build Coastguard Worker {
148*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
149*89c4ff92SAndroid Build Coastguard Worker     FullyConnectedInt8Test(backends);
150*89c4ff92SAndroid Build Coastguard Worker }
151*89c4ff92SAndroid Build Coastguard Worker 
152*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("FullyConnected_Activation_CpuRef_Test")
153*89c4ff92SAndroid Build Coastguard Worker {
154*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
155*89c4ff92SAndroid Build Coastguard Worker     FullyConnectedActivationTest(backends);
156*89c4ff92SAndroid Build Coastguard Worker }
157*89c4ff92SAndroid Build Coastguard Worker 
158*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("FullyConnected_Weights_As_Inputs_FP32_CpuRef_Test")
159*89c4ff92SAndroid Build Coastguard Worker {
160*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
161*89c4ff92SAndroid Build Coastguard Worker     FullyConnectedFp32Test(backends, false);
162*89c4ff92SAndroid Build Coastguard Worker }
163*89c4ff92SAndroid Build Coastguard Worker 
164*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("FullyConnected_Weights_As_Inputs_Int8_CpuRef_Test")
165*89c4ff92SAndroid Build Coastguard Worker {
166*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
167*89c4ff92SAndroid Build Coastguard Worker     FullyConnectedInt8Test(backends, false);
168*89c4ff92SAndroid Build Coastguard Worker }
169*89c4ff92SAndroid Build Coastguard Worker 
170*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("FullyConnected_Weights_As_Inputs_Activation_CpuRef_Test")
171*89c4ff92SAndroid Build Coastguard Worker {
172*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
173*89c4ff92SAndroid Build Coastguard Worker     FullyConnectedActivationTest(backends, false);
174*89c4ff92SAndroid Build Coastguard Worker }
175*89c4ff92SAndroid Build Coastguard Worker 
176*89c4ff92SAndroid Build Coastguard Worker } // End of TEST_SUITE("FullyConnected_CpuRefTests")
177*89c4ff92SAndroid Build Coastguard Worker 
178*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace