/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/util/reduce_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <algorithm>
#include <array>
#include <cinttypes>
#include <cmath>

/**
 * For an input tensor, use the scale and zero_point arguments to dequantize it.
 */
namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;
using Scalar = exec_aten::Scalar;
using ScalarType = exec_aten::ScalarType;

namespace {

/**
 * Asserts that the parameters are valid.
 */
void check_dequantize_per_tensor_args(
    const Tensor& input,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    exec_aten::optional<ScalarType>& out_dtype,
    Tensor& out) {
  ET_CHECK_MSG(
      input.scalar_type() == ScalarType::Byte ||
          input.scalar_type() == ScalarType::Char ||
          input.scalar_type() == ScalarType::Bits16 ||
          input.scalar_type() == ScalarType::UInt16 ||
          input.scalar_type() == ScalarType::Short ||
          input.scalar_type() == ScalarType::Int,
      "input.scalar_type() %" PRId8 " is not supported",
      static_cast<int8_t>(input.scalar_type()));

  ET_CHECK_MSG(
      input.scalar_type() == dtype,
      "input.scalar_type() %" PRId8 " does not match the dtype argument",
      static_cast<int8_t>(input.scalar_type()));

  if (out_dtype.has_value()) {
    ET_CHECK_MSG(
        out.scalar_type() == out_dtype.value(),
        "output_dtype must match the dtype of the out tensor");
  }

  ET_CHECK_MSG(
      quant_min <= quant_max,
      "quant min: %" PRId64 " is greater than quant max: %" PRId64,
      quant_min,
      quant_max);
}

} // namespace

/**
 * Dequantizes the input tensor according to the formula (input - zero_point) *
 * scale
 *
 * NOTE: quant_min and quant_max are not used in computation, but rather
 * metadata that is passed around which can be useful for pattern matching. See
 * https://github.com/pytorch/pytorch/pull/87093#discussion_r1000841181 for more
 * info.
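 *
 * For example, with scale = 0.5 and zero_point = 128, a quantized input value
 * of 130 dequantizes to (130 - 128) * 0.5 = 1.0.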
 */
Tensor& dequantize_per_tensor_out(
    const Tensor& input,
    double scale,
    int64_t zero_point,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    exec_aten::optional<ScalarType> out_dtype,
    Tensor& out) {
  torch::executor::Error err = resize_tensor(out, input.sizes());
  ET_CHECK_MSG(
      err == torch::executor::Error::Ok,
      "Failed to resize out Tensor in dequantize_per_tensor_out");

  check_dequantize_per_tensor_args(
      input, quant_min, quant_max, dtype, out_dtype, out);

  // calculate the dequantized output, cast scale to float to match fbgemm
  // behavior
#define DEQUANTIZE_IMPL(IN_CTYPE, OUT_CTYPE, out_dtype) \
  case ScalarType::out_dtype: { \
    /* Hoist these function calls out of our inner loop because they might not \
     * get inlined without LTO, particularly in ATen mode. */ \
    auto* out_data_ptr = out.mutable_data_ptr<OUT_CTYPE>(); \
    const auto* input_data_ptr = input.const_data_ptr<IN_CTYPE>(); \
    const auto input_numel = input.numel(); \
    for (size_t i = 0; i < input_numel; i++) { \
      out_data_ptr[i] = static_cast<OUT_CTYPE>( \
          (input_data_ptr[i] - static_cast<int32_t>(zero_point)) * \
          static_cast<float>(scale)); \
    } \
  } break;
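// For a given integral input ctype, CALCULATE_INT_TYPE switches on the out
// tensor's dtype and expands DEQUANTIZE_IMPL once per supported floating-point
// output type.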
#define CALCULATE_INT_TYPE(IN_CTYPE, in_dtype) \
  case ScalarType::in_dtype: \
    switch (out.scalar_type()) { \
      ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, DEQUANTIZE_IMPL); \
      default: \
        ET_CHECK_MSG( \
            false, \
            "Unhandled output dtype %" PRId8, \
            static_cast<int8_t>(out.scalar_type())); \
    } \
    break;

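  // Dispatch on the input dtype. Bits16 and UInt16 are listed explicitly
  // because ET_FORALL_INT_TYPES does not cover the unsigned 16-bit types.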
  switch (input.scalar_type()) {
    ET_FORALL_INT_TYPES(CALCULATE_INT_TYPE);
    CALCULATE_INT_TYPE(uint16_t, Bits16);
    CALCULATE_INT_TYPE(uint16_t, UInt16);
    default:
      ET_CHECK_MSG(
          false,
          "Unhandled input dtype %" PRId8,
          static_cast<int8_t>(input.scalar_type()));
  }

#undef CALCULATE_INT_TYPE
#undef DEQUANTIZE_IMPL
  return out;
}

Tensor& dequantize_per_tensor_tensor_args_out(
    const Tensor& input,
    const Tensor& scale,
    const Tensor& zero_point,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    exec_aten::optional<ScalarType> out_dtype,
    Tensor& out) {
  ET_CHECK_MSG(
      scale.scalar_type() == ScalarType::Double,
      "Expected scale to be a Double tensor, received: %" PRId8,
      static_cast<int8_t>(scale.scalar_type()));
  ET_CHECK_MSG(
      zero_point.scalar_type() == ScalarType::Long,
      "Expected zero_point to be a Long tensor, received: %" PRId8,
      static_cast<int8_t>(zero_point.scalar_type()));
  ET_CHECK_MSG(
      scale.numel() == 1,
      "Expected scale to have exactly one element, received: %zd",
      ssize_t(scale.numel()));
  ET_CHECK_MSG(
      zero_point.numel() == 1,
      "Expected zero_point to have exactly one element, received: %zd",
      ssize_t(zero_point.numel()));

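  // scale and zero_point are 1-element tensors here; extract their scalar
  // values and forward to the scalar per-tensor overload above.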
  dequantize_per_tensor_out(
      input,
      scale.const_data_ptr<double>()[0],
      zero_point.const_data_ptr<int64_t>()[0],
      quant_min,
      quant_max,
      dtype,
      out_dtype,
      out);
  return out;
}

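// Returns the scale for the given channel as a float, accepting either a
// Double or a Float scale tensor.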
float get_scale(const Tensor& scale, size_t channel_ix) {
  ET_CHECK_MSG(
      (scale.scalar_type() == ScalarType::Double) ||
          (scale.scalar_type() == ScalarType::Float),
      "scale.scalar_type() %" PRId8 " is not double or float type",
      static_cast<int8_t>(scale.scalar_type()));
  if (scale.scalar_type() == ScalarType::Double) {
    return static_cast<float>(scale.const_data_ptr<double>()[channel_ix]);
  } else {
    return scale.const_data_ptr<float>()[channel_ix];
  }
}

Tensor& dequantize_per_channel_out(
    const Tensor& input,
    const Tensor& scale,
    const exec_aten::optional<Tensor>& opt_zero_points,
    int64_t axis,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    exec_aten::optional<ScalarType> out_dtype,
    Tensor& out) {
  // normalize axis
  ET_CHECK_MSG(
      tensor_has_dim(input, axis),
      "axis %zd is out of range; it must satisfy -input.dim() <= axis < input.dim(), where input.dim() is %zd",
      ssize_t(axis),
      ssize_t(input.dim()));

  if (axis < 0) {
    axis += nonzero_dim(input);
  }

  ET_CHECK_MSG(
      scale.numel() == input.size(axis),
      "scale.numel() %zd != input.size(axis) %zd",
      ssize_t(scale.numel()),
      ssize_t(input.size(axis)));

  if (opt_zero_points.has_value()) {
    auto zero_point = opt_zero_points.value();
    ET_CHECK_MSG(
        zero_point.scalar_type() == ScalarType::Long,
        "zero_point.scalar_type() %" PRId8 " is not a Long (int64) type",
        static_cast<int8_t>(zero_point.scalar_type()));

    ET_CHECK_MSG(
        zero_point.numel() == input.size(axis),
        "zero_point.numel() %zd != input.size(axis) %zd",
        ssize_t(zero_point.numel()),
        ssize_t(input.size(axis)));
  }

  check_dequantize_per_tensor_args(
      input, quant_min, quant_max, dtype, out_dtype, out);

  // a list containing all dimensions except axis
  int64_t dims[kTensorDimensionLimit];
  for (int64_t i = 0; i < input.dim() - 1; i++) {
    if (i < axis) {
      dims[i] = i;
    } else {
      dims[i] = i + 1;
    }
  }
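  // If no zero points are provided, a zero point of 0 is used for every
  // channel.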
  const int64_t* zero_point_data;
  if (opt_zero_points.has_value()) {
    zero_point_data = opt_zero_points.value().const_data_ptr<int64_t>();
  } else {
    zero_point_data = nullptr;
  }

  exec_aten::optional<exec_aten::ArrayRef<int64_t>> optional_dim_list{
      exec_aten::ArrayRef<int64_t>{dims, size_t(input.dim() - 1)}};

  // Actual dequantization logic
  // input, out are the input and output tensors
  // channel_ix is the index along the axis dimension. 0 <= channel_ix <
  // input.size(axis).
  // i.e. if the tensor has shape (N,C,H,W), axis being 1, then channel_ix
  // will be 0, 1, 2, ... C-1
  // in_ix is the flat index of the element you are dequantizing.
  // in other words you are dequantizing in_data[in_ix]
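  // In the 1-D case (handled first), axis must be 0 and scale/zero_point have
  // one entry per element, so each element is dequantized with its own scale
  // and zero point.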
#define DEQUANTIZE_IMPL(CTYPE_IN, CTYPE_OUT, out_dtype) \
  case ScalarType::out_dtype: \
    if (input.dim() == 1) { \
      auto* out_data_ptr = out.mutable_data_ptr<CTYPE_OUT>(); \
      const auto* input_data_ptr = input.const_data_ptr<CTYPE_IN>(); \
      ET_CHECK_MSG( \
          axis == 0, "Axis must be 0 for single dimensional tensors"); \
      const exec_aten::optional<int64_t> dim; \
      apply_over_dim( \
          [input_data_ptr, out_data_ptr, zero_point_data, &scale]( \
              size_t numel, size_t stride, size_t base_ix) { \
            for (size_t i = 0; i < numel; i++) { \
              size_t current_ix = base_ix * stride + i; \
              float _scale = get_scale(scale, current_ix); \
              int64_t zero_point = 0; \
              if (zero_point_data != nullptr) { \
                zero_point = zero_point_data[current_ix]; \
              } \
              out_data_ptr[current_ix] = \
                  static_cast<CTYPE_OUT>( \
                      input_data_ptr[current_ix] - zero_point) * \
                  _scale; \
            } \
          }, \
          input, \
          dim); \
      break; \
    } \
    for (size_t channel_ix = 0; channel_ix < input.size(axis); ++channel_ix) { \
      float _scale = get_scale(scale, channel_ix); \
      int64_t _zero_point = 0; \
      if (zero_point_data != nullptr) { \
        _zero_point = zero_point_data[channel_ix]; \
      } \
      auto* out_data_ptr = out.mutable_data_ptr<CTYPE_OUT>(); \
      const auto* input_data_ptr = input.const_data_ptr<CTYPE_IN>(); \
      apply_over_dim_list( \
          [input_data_ptr, out_data_ptr, _scale, _zero_point](size_t in_ix) { \
            out_data_ptr[in_ix] = static_cast<CTYPE_OUT>( \
                (input_data_ptr[in_ix] - _zero_point) * _scale); \
          }, \
          input, \
          optional_dim_list, \
          channel_ix); \
    } \
    break;
#define CALCULATE_FLOAT_TYPE(CTYPE_IN, in_dtype) \
  case ScalarType::in_dtype: \
    switch (out.scalar_type()) { \
      ET_FORALL_FLOAT_TYPES_WITH(CTYPE_IN, DEQUANTIZE_IMPL); \
      default: \
        ET_CHECK_MSG( \
            false, \
            "Unhandled output dtype %" PRId8, \
            static_cast<int8_t>(out.scalar_type())); \
    } \
    break;

  switch (input.scalar_type()) {
    ET_FORALL_INT_TYPES(CALCULATE_FLOAT_TYPE);
    CALCULATE_FLOAT_TYPE(uint16_t, Bits16);
    CALCULATE_FLOAT_TYPE(uint16_t, UInt16);
    default:
      ET_CHECK_MSG(
          false,
          "Unhandled input dtype %" PRId8,
          static_cast<int8_t>(input.scalar_type()));
  }
#undef CALCULATE_FLOAT_TYPE
#undef DEQUANTIZE_IMPL

  return out;
}

Tensor& dequantize_per_channel_out(
    KernelRuntimeContext& context,
    const Tensor& input,
    const Tensor& scale,
    const exec_aten::optional<Tensor>& opt_zero_points,
    int64_t axis,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    exec_aten::optional<ScalarType> out_dtype,
    Tensor& out) {
  (void)context;
  torch::executor::Error err = resize_tensor(out, input.sizes());
  ET_CHECK_MSG(
      err == torch::executor::Error::Ok,
      "Failed to resize out Tensor in dequantize_per_channel_out");

  return dequantize_per_channel_out(
      input,
      scale,
      opt_zero_points,
      axis,
      quant_min,
      quant_max,
      dtype,
      out_dtype,
      out);
}

Tensor& dequantize_per_tensor_out(
    KernelRuntimeContext& context,
    const Tensor& input,
    double scale,
    int64_t zero_point,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    exec_aten::optional<ScalarType> out_dtype,
    Tensor& out) {
  // TODO(larryliu): Add a context arg to the real op function and remove this
  // wrapper
  (void)context;
  return dequantize_per_tensor_out(
      input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out);
}

Tensor& dequantize_per_tensor_tensor_args_out(
    KernelRuntimeContext& context,
    const Tensor& input,
    const Tensor& scale,
    const Tensor& zero_point,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    exec_aten::optional<ScalarType> out_dtype,
    Tensor& out) {
  // TODO(larryliu): Add a context arg to the real op function and remove this
  // wrapper
  (void)context;
  return dequantize_per_tensor_tensor_args_out(
      input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out);
}

Tensor& dequantize_per_token_out(
    const Tensor& input,
    const Tensor& scale,
    const Tensor& zero_points,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    ScalarType out_dtype,
    Tensor& out) {
  // Refactor this into a util
  size_t num_channels = 1;
  for (size_t i = 0; i < input.dim() - 1; i++) {
    num_channels *= input.size(i);
  }
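  // View the input as a 2-D (num_channels x last_dim) tensor so that each
  // token (row) can be dequantized with its own scale and zero point.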
  // This unfortunate change is needed because we compile this op for ATen mode
  // as well
  std::array<exec_aten::SizesType, 2> input_sizes;
  input_sizes[0] = static_cast<exec_aten::SizesType>(num_channels);
  input_sizes[1] =
      static_cast<exec_aten::SizesType>(input.size(input.dim() - 1));
#ifdef USE_ATEN_LIB
  Tensor reshaped_input = at::from_blob(
      input.mutable_data_ptr(),
      input_sizes,
      at::TensorOptions(input.scalar_type()));
#else
  std::array<exec_aten::DimOrderType, 2> input_dim_order{0, 1};
  std::array<exec_aten::StridesType, 2> input_strides;
  dim_order_to_stride_nocheck(
      input_sizes.data(), input_dim_order.data(), 2, input_strides.data());
  void* input_data = input.mutable_data_ptr();
  TensorImpl reshaped_input_impl = TensorImpl(
      input.scalar_type(),
      2,
      input_sizes.data(),
      input_data,
      input_dim_order.data(),
      input_strides.data(),
      TensorShapeDynamism::STATIC);
  Tensor reshaped_input(&reshaped_input_impl);
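  // reshaped_input aliases the original input's storage; no data is copied.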
  torch::executor::Error err = resize_tensor(out, input.sizes());
  ET_CHECK_MSG(
      err == torch::executor::Error::Ok,
      "Failed to resize out Tensor in dequantize_per_token_out");
#endif

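  // With the 2-D view in place, per-token dequantization is just per-channel
  // dequantization along axis 0.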
  return dequantize_per_channel_out(
      reshaped_input,
      scale,
      zero_points,
      0, /* axis */
      quant_min,
      quant_max,
      dtype,
      out_dtype,
      out);
}

Tensor& dequantize_per_token_out(
    RuntimeContext& context,
    const Tensor& input,
    const Tensor& scale,
    const Tensor& zero_points,
    int64_t quant_min,
    int64_t quant_max,
    ScalarType dtype,
    ScalarType out_dtype,
    Tensor& out) {
  (void)context;
  return dequantize_per_token_out(
      input, scale, zero_points, quant_min, quant_max, dtype, out_dtype, out);
}

} // namespace native
} // namespace executor
} // namespace torch