xref: /aosp_15_r20/external/executorch/kernels/quantized/cpu/op_dequantize.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/portable/cpu/util/reduce_util.h>
10 #include <executorch/runtime/kernel/kernel_includes.h>
11 #include <algorithm>
12 #include <cinttypes>
13 #include <cmath>
14 
/**
 * For an input tensor, use the scale and zero_point arguments to dequantize
 * it.
 */
18 namespace torch {
19 namespace executor {
20 namespace native {
21 
22 using Tensor = exec_aten::Tensor;
23 using Scalar = exec_aten::Scalar;
24 using ScalarType = exec_aten::ScalarType;
25 
26 namespace {
27 
28 /**
29  * Asserts that the parameters are valid.
30  */
check_dequantize_per_tensor_args(const Tensor & input,int64_t quant_min,int64_t quant_max,ScalarType dtype,exec_aten::optional<ScalarType> & out_dtype,Tensor & out)31 void check_dequantize_per_tensor_args(
32     const Tensor& input,
33     int64_t quant_min,
34     int64_t quant_max,
35     ScalarType dtype,
36     exec_aten::optional<ScalarType>& out_dtype,
37     Tensor& out) {
38   ET_CHECK_MSG(
39       input.scalar_type() == ScalarType::Byte ||
40           input.scalar_type() == ScalarType::Char ||
41           input.scalar_type() == ScalarType::Bits16 ||
42           input.scalar_type() == ScalarType::UInt16 ||
43           input.scalar_type() == ScalarType::Short ||
44           input.scalar_type() == ScalarType::Int,
45       "input.scalar_type() %" PRId8 " is not supported:",
46       static_cast<int8_t>(input.scalar_type()));
47 
48   ET_CHECK_MSG(
49       input.scalar_type() == dtype,
50       "input.scalar_type() %" PRId8 " is not matching dtype argumenta:",
51       static_cast<int8_t>(input.scalar_type()));
52 
53   if (out_dtype.has_value()) {
54     ET_CHECK_MSG(
55         out.scalar_type() == out_dtype.value(),
56         "output_dtype must match the dtype of the out tensor");
57   }
58 
59   ET_CHECK_MSG(
60       quant_min <= quant_max,
61       "quant min: %" PRId64 " is greater than quant max: %" PRId64,
62       quant_min,
63       quant_max);
64 }
65 
66 } // namespace
67 
68 /**
69  * Dequantizes the input tensor according to the formula (input - zero_point) *
70  * scale
71  *
72  * NOTE: quant_min and quant_max are not used in computation, but rather
73  * metadata that is passed around which can be useful for pattern matching. See
74  * https://github.com/pytorch/pytorch/pull/87093#discussion_r1000841181 for more
75  * info.
76  */
Tensor& dequantize_per_tensor_out(
    const Tensor& input,
    double scale,
    int64_t zero_point,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    exec_aten::optional<ScalarType> out_dtype,
    Tensor& out) {
  // Make out the same shape as input before writing into it.
  torch::executor::Error err = resize_tensor(out, input.sizes());
  ET_CHECK_MSG(
      err == torch::executor::Error::Ok,
      "Failed to resize out Tensor in dequantize_per_tensor_out");

  check_dequantize_per_tensor_args(
      input, quant_min, quant_max, dtype, out_dtype, out);

  // calculate the dequantized output, cast scale to float to match fbgemm
  // behavior
  // DEQUANTIZE_IMPL emits one `case` of the output-dtype switch: it walks the
  // flat input buffer and writes (input[i] - zero_point) * scale per element.
#define DEQUANTIZE_IMPL(IN_CTYPE, OUT_CTYPE, out_dtype)                        \
  case ScalarType::out_dtype: {                                                \
    /* Hoist these function calls out of our inner loop because they might not \
     * get inlined without LTO, particularly in ATen mode. */                  \
    auto* out_data_ptr = out.mutable_data_ptr<OUT_CTYPE>();                    \
    const auto* input_data_ptr = input.const_data_ptr<IN_CTYPE>();             \
    const auto input_numel = input.numel();                                    \
    for (size_t i = 0; i < input_numel; i++) {                                 \
      out_data_ptr[i] = static_cast<OUT_CTYPE>(                                \
          (input_data_ptr[i] - static_cast<int32_t>(zero_point)) *             \
          static_cast<float>(scale));                                          \
    }                                                                          \
  } break;
  // CALCULATE_INT_TYPE emits one `case` of the input-dtype switch: for a
  // given input C type it expands DEQUANTIZE_IMPL for every float output
  // type, with a hard failure for any other output dtype.
#define CALCULATE_INT_TYPE(IN_CTYPE, in_dtype)               \
  case ScalarType::in_dtype:                                 \
    switch (out.scalar_type()) {                             \
      ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, DEQUANTIZE_IMPL); \
      default:                                               \
        ET_CHECK_MSG(                                        \
            false,                                           \
            "Unhandled output dtype %" PRId8,                \
            static_cast<int8_t>(out.scalar_type()));         \
    }                                                        \
    break;

  // Nested dispatch on (input dtype, output dtype). Bits16/UInt16 are not
  // covered by ET_FORALL_INT_TYPES, so they get explicit cases.
  switch (input.scalar_type()) {
    ET_FORALL_INT_TYPES(CALCULATE_INT_TYPE);
    CALCULATE_INT_TYPE(uint16_t, Bits16);
    CALCULATE_INT_TYPE(uint16_t, UInt16);
    default:
      ET_CHECK_MSG(
          false,
          "Unhandled input dtype %" PRId8,
          static_cast<int8_t>(input.scalar_type()));
  }

  // NOTE(review): CALCULATE_FLOAT_TYPE is never defined in this function, so
  // the #undef below is a no-op — it appears to be a misspelling of
  // CALCULATE_INT_TYPE. As a consequence CALCULATE_INT_TYPE leaks past this
  // function, and dequantize_per_channel_out below currently uses the leaked
  // macro for its Bits16/UInt16 cases; confirm that before correcting the
  // spelling here.
#undef CALCULATE_FLOAT_TYPE
#undef DEQUANTIZE_IMPL
  return out;
}
136 
/**
 * Dequantizes per tensor with the scale and zero_point supplied as
 * single-element tensors (Double scale, Long zero_point) instead of scalars.
 * Validates and unpacks them, then delegates to dequantize_per_tensor_out.
 */
Tensor& dequantize_per_tensor_tensor_args_out(
    const Tensor& input,
    const Tensor& scale,
    const Tensor& zero_point,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    exec_aten::optional<ScalarType> out_dtype,
    Tensor& out) {
  ET_CHECK_MSG(
      scale.scalar_type() == ScalarType::Double,
      "Expected scale to be Double tensor received: %" PRId8,
      static_cast<int8_t>(scale.scalar_type()));
  ET_CHECK_MSG(
      zero_point.scalar_type() == ScalarType::Long,
      "Expected zero_point to be Long tensor received: %" PRId8,
      static_cast<int8_t>(zero_point.scalar_type()));
  ET_CHECK_MSG(
      scale.numel() == 1,
      "Expected scale to only have one element received: %zd",
      ssize_t(scale.numel()));
  ET_CHECK_MSG(
      zero_point.numel() == 1,
      "Expected zero_point to only have one element received: %zd",
      ssize_t(zero_point.numel()));

  dequantize_per_tensor_out(
      input,
      scale.const_data_ptr<double>()[0],
      zero_point.const_data_ptr<int64_t>()[0],
      quant_min,
      quant_max,
      dtype,
      out_dtype,
      out);
  return out;
}
174 
get_scale(const Tensor & scale,size_t channel_ix)175 float get_scale(const Tensor& scale, size_t channel_ix) {
176   ET_CHECK_MSG(
177       (scale.scalar_type() == ScalarType::Double) ||
178           (scale.scalar_type() == ScalarType::Float),
179       "scale.scalar_type() %" PRId8 " is not double or float type",
180       static_cast<int8_t>(scale.scalar_type()));
181   if (scale.scalar_type() == ScalarType::Double) {
182     return static_cast<float>(scale.const_data_ptr<double>()[channel_ix]);
183   } else {
184     return scale.const_data_ptr<float>()[channel_ix];
185   }
186 }
187 
dequantize_per_channel_out(const Tensor & input,const Tensor & scale,const exec_aten::optional<Tensor> & opt_zero_points,int64_t axis,int64_t quant_min,int64_t quant_max,ScalarType dtype,exec_aten::optional<ScalarType> out_dtype,Tensor & out)188 Tensor& dequantize_per_channel_out(
189     const Tensor& input,
190     const Tensor& scale,
191     const exec_aten::optional<Tensor>& opt_zero_points,
192     int64_t axis,
193     int64_t quant_min,
194     int64_t quant_max,
195     ScalarType dtype,
196     exec_aten::optional<ScalarType> out_dtype,
197     Tensor& out) {
198   // normalize axis
199   ET_CHECK_MSG(
200       tensor_has_dim(input, axis),
201       "axis %zd is not legal it should be -input.dim() <= axis < input.dim() %zd",
202       ssize_t(axis),
203       ssize_t(input.dim()));
204 
205   if (axis < 0) {
206     axis += nonzero_dim(input);
207   }
208 
209   ET_CHECK_MSG(
210       scale.numel() == input.size(axis),
211       "scale.numel() %zd != input.size(axis) %zd",
212       ssize_t(scale.numel()),
213       ssize_t(input.size(axis)));
214 
215   if (opt_zero_points.has_value()) {
216     auto zero_point = opt_zero_points.value();
217     ET_CHECK_MSG(
218         zero_point.scalar_type() == ScalarType::Long,
219         "zero_point.scalar_type() %" PRId8 " is not integer type",
220         static_cast<int8_t>(zero_point.scalar_type()));
221 
222     ET_CHECK_MSG(
223         zero_point.numel() == input.size(axis),
224         "zero_point.numel() %zd != input.size(axis) %zd",
225         ssize_t(zero_point.numel()),
226         ssize_t(input.size(axis)));
227   }
228 
229   check_dequantize_per_tensor_args(
230       input, quant_min, quant_max, dtype, out_dtype, out);
231 
232   // a list contains all dimensions except axis
233   int64_t dims[kTensorDimensionLimit];
234   for (int64_t i = 0; i < input.dim() - 1; i++) {
235     if (i < axis) {
236       dims[i] = i;
237     } else {
238       dims[i] = i + 1;
239     }
240   }
241   const int64_t* zero_point_data;
242   if (opt_zero_points.has_value()) {
243     zero_point_data = opt_zero_points.value().const_data_ptr<int64_t>();
244   } else {
245     zero_point_data = nullptr;
246   }
247 
248   exec_aten::optional<exec_aten::ArrayRef<int64_t>> optional_dim_list{
249       exec_aten::ArrayRef<int64_t>{dims, size_t(input.dim() - 1)}};
250 
251   // Actual dequantization logic
252   // input, out are the input and output tensors
253   // channel_ix is the index along the axis dimension. 0 <= channel_ix <
254   // input.size(axis).
255   //   i.e. if the tensor has shape (N,C,H,W), axis being 1, then channel_ix
256   //   will be 0, 1, 2, ... C-1
257   // in_ix is the flat index of the element you are dequantizing.
258   //   in other words you are dequantizing in_data[in_ix]
259 #define DEQUANTIZE_IMPL(CTYPE_IN, CTYPE_OUT, out_dtype)                        \
260   case ScalarType::out_dtype:                                                  \
261     if (input.dim() == 1) {                                                    \
262       auto* out_data_ptr = out.mutable_data_ptr<CTYPE_OUT>();                  \
263       const auto* input_data_ptr = input.const_data_ptr<CTYPE_IN>();           \
264       ET_CHECK_MSG(                                                            \
265           axis == 0, "Axis must be 0 for a single dimensional tensors");       \
266       const exec_aten::optional<int64_t> dim;                                  \
267       apply_over_dim(                                                          \
268           [input_data_ptr, out_data_ptr, zero_point_data, &scale](             \
269               size_t numel, size_t stride, size_t base_ix) {                   \
270             for (size_t i = 0; i < numel; i++) {                               \
271               size_t current_ix = base_ix * stride + i;                        \
272               float _scale = get_scale(scale, current_ix);                     \
273               int64_t zero_point = 0;                                          \
274               if (zero_point_data != nullptr) {                                \
275                 zero_point = zero_point_data[current_ix];                      \
276               }                                                                \
277               out_data_ptr[current_ix] =                                       \
278                   static_cast<CTYPE_OUT>(                                      \
279                       input_data_ptr[current_ix] - zero_point) *               \
280                   _scale;                                                      \
281             }                                                                  \
282           },                                                                   \
283           input,                                                               \
284           dim);                                                                \
285       break;                                                                   \
286     }                                                                          \
287     for (size_t channel_ix = 0; channel_ix < input.size(axis); ++channel_ix) { \
288       float _scale = get_scale(scale, channel_ix);                             \
289       int64_t _zero_point = 0;                                                 \
290       if (zero_point_data != nullptr) {                                        \
291         _zero_point = zero_point_data[channel_ix];                             \
292       }                                                                        \
293       auto* out_data_ptr = out.mutable_data_ptr<CTYPE_OUT>();                  \
294       const auto* input_data_ptr = input.const_data_ptr<CTYPE_IN>();           \
295       apply_over_dim_list(                                                     \
296           [input_data_ptr, out_data_ptr, _scale, _zero_point](size_t in_ix) {  \
297             out_data_ptr[in_ix] = static_cast<CTYPE_OUT>(                      \
298                 (input_data_ptr[in_ix] - _zero_point) * _scale);               \
299           },                                                                   \
300           input,                                                               \
301           optional_dim_list,                                                   \
302           channel_ix);                                                         \
303     }                                                                          \
304     break;
305 #define CALCULATE_FLOAT_TYPE(CTYPE_IN, in_dtype)             \
306   case ScalarType::in_dtype:                                 \
307     switch (out.scalar_type()) {                             \
308       ET_FORALL_FLOAT_TYPES_WITH(CTYPE_IN, DEQUANTIZE_IMPL); \
309       default:                                               \
310         ET_CHECK_MSG(                                        \
311             false,                                           \
312             "Unhandled output dtype %" PRId8,                \
313             static_cast<int8_t>(out.scalar_type()));         \
314     }                                                        \
315     break;
316 
317   switch (input.scalar_type()) {
318     ET_FORALL_INT_TYPES(CALCULATE_FLOAT_TYPE);
319     CALCULATE_INT_TYPE(uint16_t, Bits16);
320     CALCULATE_INT_TYPE(uint16_t, UInt16);
321     default:
322       ET_CHECK_MSG(
323           false,
324           "Unhandled input dtype %" PRId8,
325           static_cast<int8_t>(input.scalar_type()));
326   }
327 #undef CALCULATE_FLOAT_TYPE
328 #undef QUANTIZE_IMPL
329 
330   return out;
331 }
332 
dequantize_per_channel_out(KernelRuntimeContext & context,const Tensor & input,const Tensor & scale,const exec_aten::optional<Tensor> & opt_zero_points,int64_t axis,int64_t quant_min,int64_t quant_max,ScalarType dtype,exec_aten::optional<ScalarType> out_dtype,Tensor & out)333 Tensor& dequantize_per_channel_out(
334     KernelRuntimeContext& context,
335     const Tensor& input,
336     const Tensor& scale,
337     const exec_aten::optional<Tensor>& opt_zero_points,
338     int64_t axis,
339     int64_t quant_min,
340     int64_t quant_max,
341     ScalarType dtype,
342     exec_aten::optional<ScalarType> out_dtype,
343     Tensor& out) {
344   (void)context;
345   torch::executor::Error err = resize_tensor(out, input.sizes());
346   ET_CHECK_MSG(
347       err == torch::executor::Error::Ok,
348       "Failed to resize out Tensor in dequantize_per_channel_out");
349 
350   return dequantize_per_channel_out(
351       input,
352       scale,
353       opt_zero_points,
354       axis,
355       quant_min,
356       quant_max,
357       dtype,
358       out_dtype,
359       out);
360 }
361 
dequantize_per_tensor_out(KernelRuntimeContext & context,const Tensor & input,double scale,int64_t zero_point,int64_t quant_min,int64_t quant_max,ScalarType dtype,exec_aten::optional<ScalarType> out_dtype,Tensor & out)362 Tensor& dequantize_per_tensor_out(
363     KernelRuntimeContext& context,
364     const Tensor& input,
365     double scale,
366     int64_t zero_point,
367     int64_t quant_min,
368     int64_t quant_max,
369     ScalarType dtype,
370     exec_aten::optional<ScalarType> out_dtype,
371     Tensor& out) {
372   // TODO(larryliu): Add a context arg to the real op function and remove this
373   // wrapper
374   (void)context;
375   return dequantize_per_tensor_out(
376       input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out);
377 }
378 
dequantize_per_tensor_tensor_args_out(KernelRuntimeContext & context,const Tensor & input,const Tensor & scale,const Tensor & zero_point,int64_t quant_min,int64_t quant_max,ScalarType dtype,exec_aten::optional<ScalarType> out_dtype,Tensor & out)379 Tensor& dequantize_per_tensor_tensor_args_out(
380     KernelRuntimeContext& context,
381     const Tensor& input,
382     const Tensor& scale,
383     const Tensor& zero_point,
384     int64_t quant_min,
385     int64_t quant_max,
386     ScalarType dtype,
387     exec_aten::optional<ScalarType> out_dtype,
388     Tensor& out) {
389   // TODO(larryliu): Add a context arg to the real op function and remove this
390   // wrapper
391   (void)context;
392   return dequantize_per_tensor_tensor_args_out(
393       input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out);
394 }
395 
/**
 * Dequantizes the input per token: the input is viewed as a 2-D tensor of
 * shape [num_tokens, last_dim] (all leading dims collapsed into the token
 * dim) and each token row gets its own scale/zero_point, implemented by
 * delegating to dequantize_per_channel_out with axis 0.
 */
Tensor& dequantize_per_token_out(
    const Tensor& input,
    const Tensor& scale,
    const Tensor& zero_points,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    ScalarType out_dtype,
    Tensor& out) {
  // Refactor this into a util
  // Collapse all but the last dim into one "channel" (token) count. The
  // signed index avoids an unsigned underflow of `dim() - 1` for a 0-dim
  // input, which would otherwise spin the loop over a huge bound.
  size_t num_channels = 1;
  for (int64_t i = 0; i < static_cast<int64_t>(input.dim()) - 1; i++) {
    num_channels *= input.size(i);
  }
  // This unfortunate change is needed because we compile op_quantize for aten
  // mode as well
  std::array<exec_aten::SizesType, 2> input_sizes;
  input_sizes[0] = static_cast<exec_aten::SizesType>(num_channels);
  input_sizes[1] =
      static_cast<exec_aten::SizesType>(input.size(input.dim() - 1));
#ifdef USE_ATEN_LIB
  // NOTE(review): the ATen path does not resize `out` here — confirm callers
  // pre-size it in ATen mode.
  Tensor reshaped_input = at::from_blob(
      input.mutable_data_ptr(),
      input_sizes,
      at::TensorOptions(input.scalar_type()));
#else
  // Build a non-owning 2-D view over input's data buffer.
  std::array<exec_aten::DimOrderType, 2> input_dim_order{0, 1};
  std::array<exec_aten::StridesType, 2> input_strides;
  dim_order_to_stride_nocheck(
      input_sizes.data(), input_dim_order.data(), 2, input_strides.data());
  void* input_data = input.mutable_data_ptr();
  TensorImpl reshaped_input_impl = TensorImpl(
      input.scalar_type(),
      2,
      input_sizes.data(),
      input_data,
      input_dim_order.data(),
      input_strides.data(),
      TensorShapeDynamism::STATIC);
  Tensor reshaped_input(&reshaped_input_impl);
  torch::executor::Error err = resize_tensor(out, input.sizes());
  ET_CHECK_MSG(
      err == torch::executor::Error::Ok,
      "Failed to resize out Tensor in dequantize_per_token_out");
#endif

  return dequantize_per_channel_out(
      reshaped_input,
      scale,
      zero_points,
      0, /* axis */
      quant_min,
      quant_max,
      dtype,
      out_dtype,
      out);
}
453 
dequantize_per_token_out(RuntimeContext & context,const Tensor & input,const Tensor & scale,const Tensor & zero_points,int64_t quant_min,int64_t quant_max,ScalarType dtype,ScalarType out_dtype,Tensor & out)454 Tensor& dequantize_per_token_out(
455     RuntimeContext& context,
456     const Tensor& input,
457     const Tensor& scale,
458     const Tensor& zero_points,
459     int64_t quant_min,
460     int64_t quant_max,
461     ScalarType dtype,
462     ScalarType out_dtype,
463     Tensor& out) {
464   (void)context;
465   return dequantize_per_token_out(
466       input, scale, zero_points, quant_min, quant_max, dtype, out_dtype, out);
467 }
468 
469 } // namespace native
470 } // namespace executor
471 } // namespace torch
472