xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/hexagon/builders/tests/matmul_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <gtest/gtest.h>
16 #include "tensorflow/lite/delegates/hexagon/builders/tests/hexagon_delegate_op_model.h"
17 #include "tensorflow/lite/kernels/internal/types.h"
18 #include "tensorflow/lite/schema/schema_generated.h"
19 
20 namespace tflite {
21 using testing::ElementsAreArray;
22 
23 class FullyConnectedOpModel : public SingleOpModelWithHexagon {
24  public:
FullyConnectedOpModel(int units,int batches,const TensorData & input,const TensorData & output,bool optional_bias,bool const_weights,ActivationFunctionType activation_function=ActivationFunctionType_NONE)25   FullyConnectedOpModel(
26       int units, int batches, const TensorData& input, const TensorData& output,
27       bool optional_bias, bool const_weights,
28       ActivationFunctionType activation_function = ActivationFunctionType_NONE)
29       : batches_(batches), units_(units) {
30     int total_input_size = 1;
31     for (size_t i = 0; i < input.shape.size(); ++i) {
32       total_input_size *= input.shape[i];
33     }
34     input_size_ = total_input_size / batches_;
35 
36     input_ = AddInput(input);
37     weights_ =
38         AddInput({input.type, {units_, input_size_}, input.min, input.max});
39 
40     if (optional_bias) {
41       bias_ = AddNullInput();
42     } else {
43       auto bias_scale = GetScale(input_) * GetScale(weights_);
44       TensorData bias{TensorType_INT32, {units_}, 0, 0, bias_scale};
45       bias_ = AddInput(bias);
46     }
47 
48     output_ = AddOutput(output);
49 
50     SetBuiltinOp(
51         BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions,
52         CreateFullyConnectedOptions(builder_, activation_function,
53                                     FullyConnectedOptionsWeightsFormat_DEFAULT,
54                                     /*keep_num_dims=*/false)
55             .Union());
56     BuildInterpreter({GetShape(input_), GetShape(weights_)});
57 
58     // Weights & bias tensors need to be constant.
59     // We don't use AddConstInput to allow setting filter values later.
60     if (const_weights) {
61       auto* weights_tensor = interpreter_->tensor(weights_);
62       weights_tensor->allocation_type = kTfLiteMmapRo;
63     }
64     if (!optional_bias) {
65       auto* bias_tensor = interpreter_->tensor(bias_);
66       bias_tensor->allocation_type = kTfLiteMmapRo;
67     }
68   }
69 
SetBias(const std::vector<float> & data)70   void SetBias(const std::vector<float>& data) {
71     QuantizeAndPopulate<int>(bias_, data);
72   }
73 
74   template <typename T>
SetWeights(const std::vector<float> & data)75   void SetWeights(const std::vector<float>& data) {
76     QuantizeAndPopulate<T>(weights_, data);
77   }
78 
79   template <typename T>
SetInput(const std::vector<float> & data)80   void SetInput(const std::vector<float>& data) {
81     QuantizeAndPopulate<T>(input_, data);
82   }
83 
84   template <typename T>
GetOutput()85   std::vector<T> GetOutput() {
86     return ExtractVector<T>(output_);
87   }
88 
89   template <typename T>
GetDequantizedOutput()90   std::vector<float> GetDequantizedOutput() {
91     return Dequantize<T>(ExtractVector<T>(output_), GetScale(output_),
92                          GetZeroPoint(output_));
93   }
94 
95  protected:
96   int input_;
97   int weights_;
98   int bias_;
99   int output_;
100 
101   int batches_;
102   int units_;
103   int input_size_;
104 };
105 
106 class QuantizedFullyConnectedOpTest
107     : public ::testing::TestWithParam<ActivationFunctionType> {};
108 
// Int8 fully-connected op with constant (read-only) weights and bias, run for
// each parameterized activation. The delegated result must match the output
// of the built-in CPU kernel. The non-constant-weights variant is covered by
// TestQuantizedInt8_NonConstWeights, so this test pins the constant path.
TEST_P(QuantizedFullyConnectedOpTest, TestQuantizedInt8) {
  FullyConnectedOpModel m(/*units=*/3, /*batches*/ 2,
                          /*input=*/{TensorType_INT8, {2, 10}, -63.5, 64},
                          /*output=*/{TensorType_INT8, {}, -127, 128},
                          /*optional_bias=*/false, /*const_weights=*/true,
                          GetParam());

  m.SetWeights<int8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput<int8_t>({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  // Reference result from the CPU kernel, captured before delegation.
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  auto reference_output = m.GetDequantizedOutput<int8_t>();

  m.ApplyDelegateAndInvoke();

  EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
              ElementsAreArray(ArrayFloatNear(reference_output)));
}
136 
// Uint8 fully-connected op with constant (read-only) weights and bias, run for
// each parameterized activation. The delegated result must match the output
// of the built-in CPU kernel. The non-constant-weights variant is covered by
// TestQuantizedUint8_NonConstWeights, so this test pins the constant path.
TEST_P(QuantizedFullyConnectedOpTest, TestQuantizedUint8) {
  FullyConnectedOpModel m(
      /*units=*/3, /*batches*/ 2,
      /*input=*/{TensorType_UINT8, {2, 10}, -63.5, 64},
      /*output=*/{TensorType_UINT8, {}, -127, 128}, /*optional_bias=*/false,
      /*const_weights=*/true, GetParam());

  m.SetWeights<uint8_t>({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput<uint8_t>({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  // Reference result from the CPU kernel, captured before delegation.
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  auto reference_output = m.GetDequantizedOutput<uint8_t>();

  m.ApplyDelegateAndInvoke();

  EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear(reference_output)));
}
164 
// Uint8 fully-connected op with no bias tensor and non-constant weights: the
// Hexagon delegate's dequantized output must match the CPU kernel's.
TEST_P(QuantizedFullyConnectedOpTest, TestQuantizedUint8_NoBias) {
  const TensorData input_spec{TensorType_UINT8, {2, 10}, -63.5, 64};
  const TensorData output_spec{TensorType_UINT8, {}, -127, 128};
  FullyConnectedOpModel model(/*units=*/3, /*batches=*/2, input_spec,
                              output_spec, /*optional_bias=*/true,
                              /*const_weights=*/false, GetParam());

  // Every unit shares the same weight row 1..10.
  const std::vector<float> weights = {
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  };
  model.SetWeights<uint8_t>(weights);

  const std::vector<float> input = {
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  };
  model.SetInput<uint8_t>(input);

  // Expected values come from the CPU kernel, run before delegation.
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  const std::vector<float> expected = model.GetDequantizedOutput<uint8_t>();

  model.ApplyDelegateAndInvoke();
  EXPECT_THAT(model.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear(expected)));
}
191 
// Int8 fully-connected op with no bias tensor and non-constant weights: the
// Hexagon delegate's dequantized output must match the CPU kernel's.
TEST_P(QuantizedFullyConnectedOpTest, TestQuantizedInt8_NoBias) {
  const TensorData input_spec{TensorType_INT8, {2, 10}, -63.5, 64};
  const TensorData output_spec{TensorType_INT8, {}, -127, 128};
  FullyConnectedOpModel model(/*units=*/3, /*batches=*/2, input_spec,
                              output_spec, /*optional_bias=*/true,
                              /*const_weights=*/false, GetParam());

  // Every unit shares the same weight row 1..10.
  const std::vector<float> weights = {
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  };
  model.SetWeights<int8_t>(weights);

  const std::vector<float> input = {
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  };
  model.SetInput<int8_t>(input);

  // Expected values come from the CPU kernel, run before delegation.
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  const std::vector<float> expected = model.GetDequantizedOutput<int8_t>();

  model.ApplyDelegateAndInvoke();
  EXPECT_THAT(model.GetDequantizedOutput<int8_t>(),
              ElementsAreArray(ArrayFloatNear(expected)));
}
218 
// Int8 fully-connected op whose weights stay writable (non-constant) while the
// bias is constant: the Hexagon delegate must reproduce the CPU kernel result.
TEST_P(QuantizedFullyConnectedOpTest, TestQuantizedInt8_NonConstWeights) {
  const TensorData input_spec{TensorType_INT8, {2, 10}, -63.5, 64};
  const TensorData output_spec{TensorType_INT8, {}, -127, 128};
  FullyConnectedOpModel model(/*units=*/3, /*batches=*/2, input_spec,
                              output_spec, /*optional_bias=*/false,
                              /*const_weights=*/false, GetParam());

  // Every unit shares the same weight row 1..10.
  const std::vector<float> weights = {
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  };
  model.SetWeights<int8_t>(weights);
  model.SetBias({1, 2, 3});

  const std::vector<float> input = {
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  };
  model.SetInput<int8_t>(input);

  // Expected values come from the CPU kernel, run before delegation.
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  const std::vector<float> expected = model.GetDequantizedOutput<int8_t>();

  model.ApplyDelegateAndInvoke();
  EXPECT_THAT(model.GetDequantizedOutput<int8_t>(),
              ElementsAreArray(ArrayFloatNear(expected)));
}
246 
// Uint8 fully-connected op whose weights stay writable (non-constant) while
// the bias is constant: the Hexagon delegate must reproduce the CPU result.
TEST_P(QuantizedFullyConnectedOpTest, TestQuantizedUint8_NonConstWeights) {
  const TensorData input_spec{TensorType_UINT8, {2, 10}, -63.5, 64};
  const TensorData output_spec{TensorType_UINT8, {}, -127, 128};
  FullyConnectedOpModel model(/*units=*/3, /*batches=*/2, input_spec,
                              output_spec, /*optional_bias=*/false,
                              /*const_weights=*/false, GetParam());

  // Every unit shares the same weight row 1..10.
  const std::vector<float> weights = {
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  };
  model.SetWeights<uint8_t>(weights);
  model.SetBias({1, 2, 3});

  const std::vector<float> input = {
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  };
  model.SetInput<uint8_t>(input);

  // Expected values come from the CPU kernel, run before delegation.
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  const std::vector<float> expected = model.GetDequantizedOutput<uint8_t>();

  model.ApplyDelegateAndInvoke();
  EXPECT_THAT(model.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear(expected)));
}
274 
275 INSTANTIATE_TEST_SUITE_P(QuantizedFullyConnectedOpTest,
276                          QuantizedFullyConnectedOpTest,
277                          testing::Values(ActivationFunctionType_NONE,
278                                          ActivationFunctionType_RELU));
279 
// Uint8 fully-connected op with non-constant weights and a fused RELU6.
TEST(QuantizedFullyConnected, TestQuantizedUint8_NonConstWeights_Relu6) {
  // The output min/max are set to (0, 6) so that the output range by itself
  // guarantees the activation's result. With that range, the op should
  // behave the same as applying relu6.
  const TensorData input_spec{TensorType_UINT8, {2, 10}, -63.5, 64};
  const TensorData output_spec{TensorType_UINT8, {}, 0, 6};
  FullyConnectedOpModel model(/*units=*/3, /*batches=*/2, input_spec,
                              output_spec, /*optional_bias=*/false,
                              /*const_weights=*/false,
                              ActivationFunctionType_RELU6);

  // Every unit shares the same weight row 1..10.
  const std::vector<float> weights = {
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  };
  model.SetWeights<uint8_t>(weights);
  model.SetBias({1, 2, 3});

  const std::vector<float> input = {
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  };
  model.SetInput<uint8_t>(input);

  // Expected values come from the CPU kernel, run before delegation.
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  const std::vector<float> expected = model.GetDequantizedOutput<uint8_t>();

  model.ApplyDelegateAndInvoke();
  EXPECT_THAT(model.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear(expected)));
}
310 
311 }  // namespace tflite
312