/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/NativeFunctions.h> // Declares the aten operator
#include <executorch/kernels/quantized/NativeFunctions.h> // Declares the quantized operator
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/platform/runtime.h>
#include <executorch/test/utils/DeathTest.h>

#include <gtest/gtest.h>
#include <limits>

using namespace ::testing;
using exec_aten::ArrayRef;
using exec_aten::optional;
using exec_aten::Scalar;
using exec_aten::ScalarType;
using exec_aten::Tensor;
using executorch::runtime::KernelRuntimeContext;
using torch::executor::native::add_out;
using torch::executor::native::dequantize_per_tensor_out;
using torch::executor::native::quantize_per_tensor_out;
using torch::executor::native::quantized_add_out;

using torch::executor::testing::TensorFactory;

/// A generic smoke test that works for any dtype that supports ones() and
/// zeros().
template <exec_aten::ScalarType DTYPE>
void test_dtype() {
  TensorFactory<ScalarType::Float> tf;

  Tensor input1 = tf.full({3, 5}, 3.5);
  Tensor input2 = tf.full({3, 5}, 3.5);
  double scale = 0.5;

  int64_t zero_point = 1;
  int64_t quant_min = 0;
  int64_t quant_max = 255;

  TensorFactory<DTYPE> tfo;
  Tensor qinput1 = tfo.zeros({3, 5});
  Tensor qinput2 = tfo.zeros({3, 5});
  Tensor qoutput = tfo.zeros({3, 5});
  // 3.5 / 0.5 + 1 = 8
  quantize_per_tensor_out(
      input1,
      scale,
      zero_point,
      quant_min,
      quant_max,
      ScalarType::Byte,
      qinput1);

  quantize_per_tensor_out(
      input2,
      scale,
      zero_point,
      quant_min,
      quant_max,
      ScalarType::Byte,
      qinput2);

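  // quantized_add_out takes (tensor, scale, zero_point, quant_min, quant_max)
  // for input a, the same five values for input b, then the output's scale,
  // zero_point, quant_min, and quant_max, followed by the out tensor.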
  quantized_add_out(
      qinput1,
      scale,
      zero_point,
      quant_min,
      quant_max,
      qinput2,
      scale,
      zero_point,
      quant_min,
      quant_max,
      scale,
      zero_point,
      quant_min,
      quant_max,
      qoutput);

  // We can losslessly dequantize here, so the op sees the full 3.5 values:
  // (3.5 + 3.5) / 0.5 + 1 = 15
  Tensor expected = tfo.full({3, 5}, 15.0);

  EXPECT_TENSOR_EQ(qoutput, expected);
}

TEST(OpQuantizeAddTest, AllDtypesSupported) {
  test_dtype<ScalarType::Byte>();
}

TEST(OpQuantizeAddTest, DifferentQParams) {
  TensorFactory<ScalarType::Float> tf;

  Tensor input1 = tf.full({3, 5}, 3.5);
  Tensor input2 = tf.full({3, 5}, 3.5);
  double a_scale = 0.5;
  int64_t a_zero_point = 1;

  double b_scale = 0.25;
  int64_t b_zero_point = 2;

  double out_scale = 0.1;
  int64_t out_zero_point = 5;

  int64_t quant_min = 0;
  int64_t quant_max = 255;

  TensorFactory<ScalarType::Byte> tfo;
  Tensor qinput1 = tfo.zeros({3, 5});
  Tensor qinput2 = tfo.zeros({3, 5});
  Tensor qoutput = tfo.zeros({3, 5});
  // 3.5 / 0.5 + 1 = 8
  quantize_per_tensor_out(
      input1,
      a_scale,
      a_zero_point,
      quant_min,
      quant_max,
      ScalarType::Byte,
      qinput1);

  // 3.5 / 0.25 + 2 = 16
  quantize_per_tensor_out(
      input2,
      b_scale,
      b_zero_point,
      quant_min,
      quant_max,
      ScalarType::Byte,
      qinput2);

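  // Each input is interpreted with its own scale/zero_point, and the sum is
  // written back out using out_scale/out_zero_point.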
  quantized_add_out(
      qinput1,
      a_scale,
      a_zero_point,
      quant_min,
      quant_max,
      qinput2,
      b_scale,
      b_zero_point,
      quant_min,
      quant_max,
      out_scale,
      out_zero_point,
      quant_min,
      quant_max,
      qoutput);

  // We can losslessly dequantize here, so the op sees the full 3.5 values:
  // (3.5 + 3.5) / 0.1 + 5 = 75
  Tensor expected = tfo.full({3, 5}, 75.0);

  EXPECT_TENSOR_EQ(qoutput, expected);
}

// Q -> DQ -> FP ADD -> Q -> DQ should be == to Q -> QADD -> DQ
TEST(OpQuantizeAddTest, ConsistencyWithReferencePattern) {
  TensorFactory<ScalarType::Float> tf;

  Tensor input1 = tf.full({3, 5}, 3.5);
  Tensor input2 = tf.full({3, 5}, 3.5);
  Tensor dq_input1 = tf.zeros({3, 5});
  Tensor dq_input2 = tf.zeros({3, 5});
  Tensor reference_op_output = tf.zeros({3, 5});
  Tensor reference_pattern_output = tf.zeros({3, 5});
  Tensor fp_output = tf.zeros({3, 5});

  double a_scale = 0.5;
  int64_t a_zero_point = 1;

  double b_scale = 0.25;
  int64_t b_zero_point = 2;

  double out_scale = 0.1;
  int64_t out_zero_point = 5;

  int64_t quant_min = 0;
  int64_t quant_max = 255;

  TensorFactory<ScalarType::Byte> tfo;
  Tensor qinput1 = tfo.zeros({3, 5});
  Tensor qinput2 = tfo.zeros({3, 5});
  Tensor qoutput = tfo.zeros({3, 5});

  optional<ScalarType> out_dtype = optional<ScalarType>();
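  // out_dtype is left unset; the dequantized results are written into the
  // Float out tensors declared above.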

  KernelRuntimeContext context{};
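  // The portable add_out (the aten operator declared in NativeFunctions.h)
  // takes a KernelRuntimeContext as its first argument.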
  // q -> qadd -> dq
  // 3.5 / 0.5 + 1 = 8
  quantize_per_tensor_out(
      input1,
      a_scale,
      a_zero_point,
      quant_min,
      quant_max,
      ScalarType::Byte,
      qinput1);

  // 3.5 / 0.25 + 2 = 16
  quantize_per_tensor_out(
      input2,
      b_scale,
      b_zero_point,
      quant_min,
      quant_max,
      ScalarType::Byte,
      qinput2);

  quantized_add_out(
      qinput1,
      a_scale,
      a_zero_point,
      quant_min,
      quant_max,
      qinput2,
      b_scale,
      b_zero_point,
      quant_min,
      quant_max,
      out_scale,
      out_zero_point,
      quant_min,
      quant_max,
      qoutput);
  dequantize_per_tensor_out(
      qoutput,
      out_scale,
      out_zero_point,
      quant_min,
      quant_max,
      ScalarType::Byte,
      out_dtype,
      reference_op_output);

  // now get results for q -> dq -> fp add -> q -> dq
  dequantize_per_tensor_out(
      qinput1,
      a_scale,
      a_zero_point,
      quant_min,
      quant_max,
      ScalarType::Byte,
      out_dtype,
      dq_input1);

  dequantize_per_tensor_out(
      qinput2,
      b_scale,
      b_zero_point,
      quant_min,
      quant_max,
      ScalarType::Byte,
      out_dtype,
      dq_input2);

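  // The portable add_out computes dq_input1 + alpha * dq_input2, with alpha = 1.0.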
  add_out(context, dq_input1, dq_input2, 1.0, fp_output);
  // reuse 'qoutput' tensor as an intermediate
  quantize_per_tensor_out(
      fp_output,
      out_scale,
      out_zero_point,
      quant_min,
      quant_max,
      ScalarType::Byte,
      qoutput);

  dequantize_per_tensor_out(
      qoutput,
      out_scale,
      out_zero_point,
      quant_min,
      quant_max,
      ScalarType::Byte,
      out_dtype,
      reference_pattern_output);

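  // 3.5 + 3.5 = 7.0; with out_scale = 0.1 and out_zero_point = 5 this
  // quantizes to 75 and dequantizes back to exactly 7.0.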
  Tensor expected = tf.full({3, 5}, 7.0);

  // Pattern and op results should both be equal to expected and to each
  // other; check all cases explicitly instead of relying on transitivity.
  EXPECT_TENSOR_EQ(reference_op_output, expected);
  EXPECT_TENSOR_EQ(reference_pattern_output, expected);
  EXPECT_TENSOR_EQ(reference_op_output, reference_pattern_output);
}

TEST(OpQuantizeAddTest, InvalidMinMaxDies) {
  TensorFactory<ScalarType::Float> tf;

  Tensor input1 = tf.full({3, 5}, 3.5);
  Tensor input2 = tf.full({3, 5}, 3.5);
  double scale = 0.5;
  int64_t zero_point = 1;

  int64_t quant_min = 0;
  int64_t quant_max = 255;
  int64_t out_quant_min = -1;
  int64_t out_quant_max = 256;
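  // -1 and 256 fall outside the representable range of Byte (uint8, [0, 255]),
  // so quantized_add_out is expected to abort.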

  TensorFactory<ScalarType::Byte> tfo;
  Tensor qinput1 = tfo.zeros({3, 5});
  Tensor qinput2 = tfo.zeros({3, 5});
  Tensor qoutput = tfo.zeros({3, 5});
  // 3.5 / 0.5 + 1 = 8
  quantize_per_tensor_out(
      input1,
      scale,
      zero_point,
      quant_min,
      quant_max,
      ScalarType::Byte,
      qinput1);

  // 3.5 / 0.5 + 1 = 8
  quantize_per_tensor_out(
      input2,
      scale,
      zero_point,
      quant_min,
      quant_max,
      ScalarType::Byte,
      qinput2);

  ET_EXPECT_DEATH(
      quantized_add_out(
          qinput1,
          scale,
          zero_point,
          quant_min,
          quant_max,
          qinput2,
          scale,
          zero_point,
          quant_min,
          quant_max,
          scale,
          zero_point,
          out_quant_min,
          out_quant_max,
          qoutput),
      "");
}

TEST(OpQuantizeAddTest, TopOfRangeTest) {
  TensorFactory<ScalarType::Float> tf;

  Tensor input1 = tf.full({3, 5}, 255);
  Tensor input2 = tf.full({3, 5}, 255);
  double a_scale = 1;
  int64_t a_zero_point = 0;

  double b_scale = 1;
  int64_t b_zero_point = 0;

  double out_scale = 1;
  int64_t out_zero_point = 0;

  int64_t quant_min = 0;
  int64_t quant_max = 255;

  TensorFactory<ScalarType::Byte> tfo;
  Tensor qinput1 = tfo.zeros({3, 5});
  Tensor qinput2 = tfo.zeros({3, 5});
  Tensor qoutput = tfo.zeros({3, 5});

  quantize_per_tensor_out(
      input1,
      a_scale,
      a_zero_point,
      quant_min,
      quant_max,
      ScalarType::Byte,
      qinput1);

  // 255 / 1 + 0 = 255
  quantize_per_tensor_out(
      input2,
      b_scale,
      b_zero_point,
      quant_min,
      quant_max,
      ScalarType::Byte,
      qinput2);

  quantized_add_out(
      qinput1,
      a_scale,
      a_zero_point,
      quant_min,
      quant_max,
      qinput2,
      b_scale,
      b_zero_point,
      quant_min,
      quant_max,
      out_scale,
      out_zero_point,
      quant_min,
      quant_max,
      qoutput);

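  // 255 + 255 exceeds the quantized range, so the output saturates at
  // quant_max = 255.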
  Tensor expected = tfo.full({3, 5}, 255);

  EXPECT_TENSOR_EQ(qoutput, expected);
}
414