/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/NativeFunctions.h> // Declares the aten operator
#include <executorch/kernels/quantized/NativeFunctions.h> // Declares the quantized operator
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/platform/runtime.h>
#include <executorch/test/utils/DeathTest.h>

#include <gtest/gtest.h>
#include <limits>

using namespace ::testing;
using exec_aten::ArrayRef;
using exec_aten::optional;
using exec_aten::Scalar;
using exec_aten::ScalarType;
using exec_aten::Tensor;
using executorch::runtime::KernelRuntimeContext;
using torch::executor::native::add_out;
using torch::executor::native::dequantize_per_tensor_out;
using torch::executor::native::quantize_per_tensor_out;
using torch::executor::native::quantized_add_out;

using torch::executor::testing::TensorFactory;

/// A generic smoke test that works for any dtype that supports ones() and
/// zeros().
template <exec_aten::ScalarType DTYPE>
void test_dtype() {
  TensorFactory<ScalarType::Float> tf;

  Tensor input1 = tf.full({3, 5}, 3.5);
  Tensor input2 = tf.full({3, 5}, 3.5);
  double scale = 0.5;

  int64_t zero_point = 1;
  int64_t quant_min = 0;
  int64_t quant_max = 255;

  TensorFactory<DTYPE> tfo;
  Tensor qinput1 = tfo.zeros({3, 5});
  Tensor qinput2 = tfo.zeros({3, 5});
  Tensor qoutput = tfo.zeros({3, 5});
  // 3.5 / 0.5 + 1 = 8
  quantize_per_tensor_out(
      input1,
      scale,
      zero_point,
      quant_min,
      quant_max,
      ScalarType::Byte,
      qinput1);

  quantize_per_tensor_out(
      input2,
      scale,
      zero_point,
      quant_min,
      quant_max,
      ScalarType::Byte,
      qinput2);

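  // quantized_add_out takes the first operand with its qparams, the second
  // operand with its qparams, and then the output qparams and tensor:
  // (a, a_scale, a_zp, a_qmin, a_qmax,
  //  b, b_scale, b_zp, b_qmin, b_qmax,
  //  out_scale, out_zp, out_qmin, out_qmax, out).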
  quantized_add_out(
      qinput1,
      scale,
      zero_point,
      quant_min,
      quant_max,
      qinput2,
      scale,
      zero_point,
      quant_min,
      quant_max,
      scale,
      zero_point,
      quant_min,
      quant_max,
      qoutput);

  // We can losslessly dequantize here, so the operation recovers the full 3.5:
  // (3.5 + 3.5) / 0.5 + 1 = 15
  Tensor expected = tfo.full({3, 5}, 15.0);

  EXPECT_TENSOR_EQ(qoutput, expected);
}

TEST(OpQuantizeAddTest, AllDtypesSupported) {
  test_dtype<ScalarType::Byte>();
}

TEST(OpQuantizeAddTest, DifferentQParams) {
  TensorFactory<ScalarType::Float> tf;

  Tensor input1 = tf.full({3, 5}, 3.5);
  Tensor input2 = tf.full({3, 5}, 3.5);
  double a_scale = 0.5;
  int64_t a_zero_point = 1;

  double b_scale = 0.25;
  int64_t b_zero_point = 2;

  double out_scale = 0.1;
  int64_t out_zero_point = 5;

  int64_t quant_min = 0;
  int64_t quant_max = 255;

  TensorFactory<ScalarType::Byte> tfo;
  Tensor qinput1 = tfo.zeros({3, 5});
  Tensor qinput2 = tfo.zeros({3, 5});
  Tensor qoutput = tfo.zeros({3, 5});
  // 3.5 / 0.5 + 1 = 8
  quantize_per_tensor_out(
      input1,
      a_scale,
      a_zero_point,
      quant_min,
      quant_max,
      ScalarType::Byte,
      qinput1);

  // 3.5 / 0.25 + 2 = 16
  quantize_per_tensor_out(
      input2,
      b_scale,
      b_zero_point,
      quant_min,
      quant_max,
      ScalarType::Byte,
      qinput2);

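  // Each input carries its own qparams, so 8 (under a_scale/a_zero_point) and
  // 16 (under b_scale/b_zero_point) both represent the real value 3.5.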
  quantized_add_out(
      qinput1,
      a_scale,
      a_zero_point,
      quant_min,
      quant_max,
      qinput2,
      b_scale,
      b_zero_point,
      quant_min,
      quant_max,
      out_scale,
      out_zero_point,
      quant_min,
      quant_max,
      qoutput);

  // We can losslessly dequantize here, so the operation recovers the full 3.5:
  // (3.5 + 3.5) / 0.1 + 5 = 75
  Tensor expected = tfo.full({3, 5}, 75.0);

  EXPECT_TENSOR_EQ(qoutput, expected);
}

// Q -> DQ -> FP ADD -> Q -> DQ should be == to Q -> QADD -> DQ
TEST(OpQuantizeAddTest, ConsistencyWithReferencePattern) {
  TensorFactory<ScalarType::Float> tf;

  Tensor input1 = tf.full({3, 5}, 3.5);
  Tensor input2 = tf.full({3, 5}, 3.5);
  Tensor dq_input1 = tf.zeros({3, 5});
  Tensor dq_input2 = tf.zeros({3, 5});
  Tensor reference_op_output = tf.zeros({3, 5});
  Tensor reference_pattern_output = tf.zeros({3, 5});
  Tensor fp_output = tf.zeros({3, 5});

  double a_scale = 0.5;
  int64_t a_zero_point = 1;

  double b_scale = 0.25;
  int64_t b_zero_point = 2;

  double out_scale = 0.1;
  int64_t out_zero_point = 5;

  int64_t quant_min = 0;
  int64_t quant_max = 255;

  TensorFactory<ScalarType::Byte> tfo;
  Tensor qinput1 = tfo.zeros({3, 5});
  Tensor qinput2 = tfo.zeros({3, 5});
  Tensor qoutput = tfo.zeros({3, 5});

  optional<ScalarType> out_dtype = optional<ScalarType>();
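  // out_dtype is left unset; the dequantized results go into the float
  // tensors declared above.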

  KernelRuntimeContext context{};
  // q -> qadd -> dq
  // 3.5 / 0.5 + 1 = 8
  quantize_per_tensor_out(
      input1,
      a_scale,
      a_zero_point,
      quant_min,
      quant_max,
      ScalarType::Byte,
      qinput1);

  // 3.5 / 0.25 + 2 = 16
  quantize_per_tensor_out(
      input2,
      b_scale,
      b_zero_point,
      quant_min,
      quant_max,
      ScalarType::Byte,
      qinput2);

  quantized_add_out(
      qinput1,
      a_scale,
      a_zero_point,
      quant_min,
      quant_max,
      qinput2,
      b_scale,
      b_zero_point,
      quant_min,
      quant_max,
      out_scale,
      out_zero_point,
      quant_min,
      quant_max,
      qoutput);
  dequantize_per_tensor_out(
      qoutput,
      out_scale,
      out_zero_point,
      quant_min,
      quant_max,
      ScalarType::Byte,
      out_dtype,
      reference_op_output);

  // now get results for q -> dq -> fp add -> q -> dq
  dequantize_per_tensor_out(
      qinput1,
      a_scale,
      a_zero_point,
      quant_min,
      quant_max,
      ScalarType::Byte,
      out_dtype,
      dq_input1);

  dequantize_per_tensor_out(
      qinput2,
      b_scale,
      b_zero_point,
      quant_min,
      quant_max,
      ScalarType::Byte,
      out_dtype,
      dq_input2);

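  // Floating-point add of the dequantized inputs: 3.5 + 3.5 = 7.0.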
  add_out(context, dq_input1, dq_input2, 1.0, fp_output);
  // reuse 'qoutput' tensor as an intermediate
  quantize_per_tensor_out(
      fp_output,
      out_scale,
      out_zero_point,
      quant_min,
      quant_max,
      ScalarType::Byte,
      qoutput);

  dequantize_per_tensor_out(
      qoutput,
      out_scale,
      out_zero_point,
      quant_min,
      quant_max,
      ScalarType::Byte,
      out_dtype,
      reference_pattern_output);

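  // 7.0 survives the round trip exactly: 7.0 / 0.1 + 5 = 75, and
  // (75 - 5) * 0.1 = 7.0.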
  Tensor expected = tf.full({3, 5}, 7.0);

  // The pattern and op results should both equal the expected tensor and each
  // other; check all cases explicitly instead of relying on transitivity.
  EXPECT_TENSOR_EQ(reference_op_output, expected);
  EXPECT_TENSOR_EQ(reference_pattern_output, expected);
  EXPECT_TENSOR_EQ(reference_op_output, reference_pattern_output);
}

TEST(OpQuantizeAddTest, InvalidMinMaxDies) {
  TensorFactory<ScalarType::Float> tf;

  Tensor input1 = tf.full({3, 5}, 3.5);
  Tensor input2 = tf.full({3, 5}, 3.5);
  double scale = 0.5;
  int64_t zero_point = 1;

  int64_t quant_min = 0;
  int64_t quant_max = 255;
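  // Output bounds that fall outside the uint8 range [0, 255]; the op is
  // expected to abort on them below.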
  int64_t out_quant_min = -1;
  int64_t out_quant_max = 256;

  TensorFactory<ScalarType::Byte> tfo;
  Tensor qinput1 = tfo.zeros({3, 5});
  Tensor qinput2 = tfo.zeros({3, 5});
  Tensor qoutput = tfo.zeros({3, 5});
  // 3.5 / 0.5 + 1 = 8
  quantize_per_tensor_out(
      input1,
      scale,
      zero_point,
      quant_min,
      quant_max,
      ScalarType::Byte,
      qinput1);

  // 3.5 / 0.5 + 1 = 8
  quantize_per_tensor_out(
      input2,
      scale,
      zero_point,
      quant_min,
      quant_max,
      ScalarType::Byte,
      qinput2);

  ET_EXPECT_DEATH(
      quantized_add_out(
          qinput1,
          scale,
          zero_point,
          quant_min,
          quant_max,
          qinput2,
          scale,
          zero_point,
          quant_min,
          quant_max,
          scale,
          zero_point,
          out_quant_min,
          out_quant_max,
          qoutput),
      "");
}

TEST(OpQuantizeAddTest, TopOfRangeTest) {
  TensorFactory<ScalarType::Float> tf;

  Tensor input1 = tf.full({3, 5}, 255);
  Tensor input2 = tf.full({3, 5}, 255);
  double a_scale = 1;
  int64_t a_zero_point = 0;

  double b_scale = 1;
  int64_t b_zero_point = 0;

  double out_scale = 1;
  int64_t out_zero_point = 0;

  int64_t quant_min = 0;
  int64_t quant_max = 255;

  TensorFactory<ScalarType::Byte> tfo;
  Tensor qinput1 = tfo.zeros({3, 5});
  Tensor qinput2 = tfo.zeros({3, 5});
  Tensor qoutput = tfo.zeros({3, 5});

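  // 255 / 1 + 0 = 255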
  quantize_per_tensor_out(
      input1,
      a_scale,
      a_zero_point,
      quant_min,
      quant_max,
      ScalarType::Byte,
      qinput1);

  // 255 / 1 + 0 = 255
  quantize_per_tensor_out(
      input2,
      b_scale,
      b_zero_point,
      quant_min,
      quant_max,
      ScalarType::Byte,
      qinput2);

  quantized_add_out(
      qinput1,
      a_scale,
      a_zero_point,
      quant_min,
      quant_max,
      qinput2,
      b_scale,
      b_zero_point,
      quant_min,
      quant_max,
      out_scale,
      out_zero_point,
      quant_min,
      quant_max,
      qoutput);

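  // The real-valued sum is 255 + 255 = 510, which requantizes to
  // 510 / 1 + 0 = 510 and is clamped to quant_max, so the output saturates
  // at 255.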
  Tensor expected = tfo.full({3, 5}, 255);

  EXPECT_TENSOR_EQ(qoutput, expected);
}