//
// Copyright © 2020-2021,2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #include "FullyConnectedTestHelper.hpp"
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker namespace
9*89c4ff92SAndroid Build Coastguard Worker {
10*89c4ff92SAndroid Build Coastguard Worker
FullyConnectedFp32Test(std::vector<armnn::BackendId> & backends,bool constantWeights=true)11*89c4ff92SAndroid Build Coastguard Worker void FullyConnectedFp32Test(std::vector<armnn::BackendId>& backends, bool constantWeights = true)
12*89c4ff92SAndroid Build Coastguard Worker {
13*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> inputTensorShape { 1, 4, 1, 1 };
14*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> weightsTensorShape { 1, 4 };
15*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> biasTensorShape { 1 };
16*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> outputTensorShape { 1, 1 };
17*89c4ff92SAndroid Build Coastguard Worker
18*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputValues = { 10, 20, 30, 40 };
19*89c4ff92SAndroid Build Coastguard Worker std::vector<float> weightsData = { 2, 3, 4, 5 };
20*89c4ff92SAndroid Build Coastguard Worker
21*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutputValues = { (400 + 10) };
22*89c4ff92SAndroid Build Coastguard Worker
23*89c4ff92SAndroid Build Coastguard Worker // bias is set std::vector<float> biasData = { 10 } in the model
24*89c4ff92SAndroid Build Coastguard Worker FullyConnectedTest<float>(backends,
25*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_FLOAT32,
26*89c4ff92SAndroid Build Coastguard Worker tflite::ActivationFunctionType_NONE,
27*89c4ff92SAndroid Build Coastguard Worker inputTensorShape,
28*89c4ff92SAndroid Build Coastguard Worker weightsTensorShape,
29*89c4ff92SAndroid Build Coastguard Worker biasTensorShape,
30*89c4ff92SAndroid Build Coastguard Worker outputTensorShape,
31*89c4ff92SAndroid Build Coastguard Worker inputValues,
32*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues,
33*89c4ff92SAndroid Build Coastguard Worker weightsData,
34*89c4ff92SAndroid Build Coastguard Worker constantWeights);
35*89c4ff92SAndroid Build Coastguard Worker }
36*89c4ff92SAndroid Build Coastguard Worker
FullyConnectedActivationTest(std::vector<armnn::BackendId> & backends,bool constantWeights=true)37*89c4ff92SAndroid Build Coastguard Worker void FullyConnectedActivationTest(std::vector<armnn::BackendId>& backends, bool constantWeights = true)
38*89c4ff92SAndroid Build Coastguard Worker {
39*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> inputTensorShape { 1, 4, 1, 1 };
40*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> weightsTensorShape { 1, 4 };
41*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> biasTensorShape { 1 };
42*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> outputTensorShape { 1, 1 };
43*89c4ff92SAndroid Build Coastguard Worker
44*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputValues = { -10, 20, 30, 40 };
45*89c4ff92SAndroid Build Coastguard Worker std::vector<float> weightsData = { 2, 3, 4, -5 };
46*89c4ff92SAndroid Build Coastguard Worker
47*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutputValues = { 0 };
48*89c4ff92SAndroid Build Coastguard Worker
49*89c4ff92SAndroid Build Coastguard Worker // bias is set std::vector<float> biasData = { 10 } in the model
50*89c4ff92SAndroid Build Coastguard Worker FullyConnectedTest<float>(backends,
51*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_FLOAT32,
52*89c4ff92SAndroid Build Coastguard Worker tflite::ActivationFunctionType_RELU,
53*89c4ff92SAndroid Build Coastguard Worker inputTensorShape,
54*89c4ff92SAndroid Build Coastguard Worker weightsTensorShape,
55*89c4ff92SAndroid Build Coastguard Worker biasTensorShape,
56*89c4ff92SAndroid Build Coastguard Worker outputTensorShape,
57*89c4ff92SAndroid Build Coastguard Worker inputValues,
58*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues,
59*89c4ff92SAndroid Build Coastguard Worker weightsData,
60*89c4ff92SAndroid Build Coastguard Worker constantWeights);
61*89c4ff92SAndroid Build Coastguard Worker }
62*89c4ff92SAndroid Build Coastguard Worker
FullyConnectedInt8Test(std::vector<armnn::BackendId> & backends,bool constantWeights=true)63*89c4ff92SAndroid Build Coastguard Worker void FullyConnectedInt8Test(std::vector<armnn::BackendId>& backends, bool constantWeights = true)
64*89c4ff92SAndroid Build Coastguard Worker {
65*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> inputTensorShape { 1, 4, 2, 1 };
66*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> weightsTensorShape { 1, 4 };
67*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> biasTensorShape { 1 };
68*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> outputTensorShape { 2, 1 };
69*89c4ff92SAndroid Build Coastguard Worker
70*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> inputValues = { 1, 2, 3, 4, 5, 10, 15, 20 };
71*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> weightsData = { 2, 3, 4, 5 };
72*89c4ff92SAndroid Build Coastguard Worker
73*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> expectedOutputValues = { 25, 105 }; // (40 + 10) / 2, (200 + 10) / 2
74*89c4ff92SAndroid Build Coastguard Worker
75*89c4ff92SAndroid Build Coastguard Worker // bias is set std::vector<int32_t> biasData = { 10 } in the model
76*89c4ff92SAndroid Build Coastguard Worker // input and weights quantization scale 1.0f and offset 0 in the model
77*89c4ff92SAndroid Build Coastguard Worker // output quantization scale 2.0f and offset 0 in the model
78*89c4ff92SAndroid Build Coastguard Worker FullyConnectedTest<int8_t>(backends,
79*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_INT8,
80*89c4ff92SAndroid Build Coastguard Worker tflite::ActivationFunctionType_NONE,
81*89c4ff92SAndroid Build Coastguard Worker inputTensorShape,
82*89c4ff92SAndroid Build Coastguard Worker weightsTensorShape,
83*89c4ff92SAndroid Build Coastguard Worker biasTensorShape,
84*89c4ff92SAndroid Build Coastguard Worker outputTensorShape,
85*89c4ff92SAndroid Build Coastguard Worker inputValues,
86*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues,
87*89c4ff92SAndroid Build Coastguard Worker weightsData,
88*89c4ff92SAndroid Build Coastguard Worker constantWeights);
89*89c4ff92SAndroid Build Coastguard Worker }
90*89c4ff92SAndroid Build Coastguard Worker
91*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("FullyConnected_GpuAccTests")
92*89c4ff92SAndroid Build Coastguard Worker {
93*89c4ff92SAndroid Build Coastguard Worker
94*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("FullyConnected_FP32_GpuAcc_Test")
95*89c4ff92SAndroid Build Coastguard Worker {
96*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::GpuAcc };
97*89c4ff92SAndroid Build Coastguard Worker FullyConnectedFp32Test(backends);
98*89c4ff92SAndroid Build Coastguard Worker }
99*89c4ff92SAndroid Build Coastguard Worker
100*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("FullyConnected_Int8_GpuAcc_Test")
101*89c4ff92SAndroid Build Coastguard Worker {
102*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::GpuAcc };
103*89c4ff92SAndroid Build Coastguard Worker FullyConnectedInt8Test(backends);
104*89c4ff92SAndroid Build Coastguard Worker }
105*89c4ff92SAndroid Build Coastguard Worker
106*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("FullyConnected_Activation_GpuAcc_Test")
107*89c4ff92SAndroid Build Coastguard Worker {
108*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::GpuAcc };
109*89c4ff92SAndroid Build Coastguard Worker FullyConnectedActivationTest(backends);
110*89c4ff92SAndroid Build Coastguard Worker }
111*89c4ff92SAndroid Build Coastguard Worker
112*89c4ff92SAndroid Build Coastguard Worker } // End of TEST_SUITE("FullyConnected_GpuAccTests")
113*89c4ff92SAndroid Build Coastguard Worker
114*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("FullyConnected_CpuAccTests")
115*89c4ff92SAndroid Build Coastguard Worker {
116*89c4ff92SAndroid Build Coastguard Worker
117*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("FullyConnected_FP32_CpuAcc_Test")
118*89c4ff92SAndroid Build Coastguard Worker {
119*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::CpuAcc };
120*89c4ff92SAndroid Build Coastguard Worker FullyConnectedFp32Test(backends);
121*89c4ff92SAndroid Build Coastguard Worker }
122*89c4ff92SAndroid Build Coastguard Worker
123*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("FullyConnected_Int8_CpuAcc_Test")
124*89c4ff92SAndroid Build Coastguard Worker {
125*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::CpuAcc };
126*89c4ff92SAndroid Build Coastguard Worker FullyConnectedInt8Test(backends);
127*89c4ff92SAndroid Build Coastguard Worker }
128*89c4ff92SAndroid Build Coastguard Worker
129*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("FullyConnected_Activation_CpuAcc_Test")
130*89c4ff92SAndroid Build Coastguard Worker {
131*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::CpuAcc };
132*89c4ff92SAndroid Build Coastguard Worker FullyConnectedActivationTest(backends);
133*89c4ff92SAndroid Build Coastguard Worker }
134*89c4ff92SAndroid Build Coastguard Worker
135*89c4ff92SAndroid Build Coastguard Worker } // End of TEST_SUITE("FullyConnected_CpuAccTests")
136*89c4ff92SAndroid Build Coastguard Worker
137*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("FullyConnected_CpuRefTests")
138*89c4ff92SAndroid Build Coastguard Worker {
139*89c4ff92SAndroid Build Coastguard Worker
140*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("FullyConnected_FP32_CpuRef_Test")
141*89c4ff92SAndroid Build Coastguard Worker {
142*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
143*89c4ff92SAndroid Build Coastguard Worker FullyConnectedFp32Test(backends);
144*89c4ff92SAndroid Build Coastguard Worker }
145*89c4ff92SAndroid Build Coastguard Worker
146*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("FullyConnected_Int8_CpuRef_Test")
147*89c4ff92SAndroid Build Coastguard Worker {
148*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
149*89c4ff92SAndroid Build Coastguard Worker FullyConnectedInt8Test(backends);
150*89c4ff92SAndroid Build Coastguard Worker }
151*89c4ff92SAndroid Build Coastguard Worker
152*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("FullyConnected_Activation_CpuRef_Test")
153*89c4ff92SAndroid Build Coastguard Worker {
154*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
155*89c4ff92SAndroid Build Coastguard Worker FullyConnectedActivationTest(backends);
156*89c4ff92SAndroid Build Coastguard Worker }
157*89c4ff92SAndroid Build Coastguard Worker
158*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("FullyConnected_Weights_As_Inputs_FP32_CpuRef_Test")
159*89c4ff92SAndroid Build Coastguard Worker {
160*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
161*89c4ff92SAndroid Build Coastguard Worker FullyConnectedFp32Test(backends, false);
162*89c4ff92SAndroid Build Coastguard Worker }
163*89c4ff92SAndroid Build Coastguard Worker
164*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("FullyConnected_Weights_As_Inputs_Int8_CpuRef_Test")
165*89c4ff92SAndroid Build Coastguard Worker {
166*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
167*89c4ff92SAndroid Build Coastguard Worker FullyConnectedInt8Test(backends, false);
168*89c4ff92SAndroid Build Coastguard Worker }
169*89c4ff92SAndroid Build Coastguard Worker
170*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("FullyConnected_Weights_As_Inputs_Activation_CpuRef_Test")
171*89c4ff92SAndroid Build Coastguard Worker {
172*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
173*89c4ff92SAndroid Build Coastguard Worker FullyConnectedActivationTest(backends, false);
174*89c4ff92SAndroid Build Coastguard Worker }
175*89c4ff92SAndroid Build Coastguard Worker
176*89c4ff92SAndroid Build Coastguard Worker } // End of TEST_SUITE("FullyConnected_CpuRefTests")
177*89c4ff92SAndroid Build Coastguard Worker
178*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace