xref: /aosp_15_r20/external/executorch/kernels/quantized/cpu/op_choose_qparams.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/portable/cpu/vec_ops.h>
10 #include <executorch/runtime/kernel/kernel_includes.h>
11 #include <algorithm>
12 #include <cinttypes>
13 #include <cmath>
14 #include <tuple>
/**
 * For an input tensor, compute the scale and zero_point needed to quantize it.
 */
18 namespace torch {
19 namespace executor {
20 namespace native {
21 
22 using Tensor = exec_aten::Tensor;
23 using Scalar = exec_aten::Scalar;
24 using ScalarType = exec_aten::ScalarType;
25 
26 namespace {
27 
// Smallest scale calculate_scale_and_zero_point will emit; smaller computed
// scales are raised to this value.
// NOTE(review): presumably chosen just below the smallest normal float16
// magnitude (2^-14 ~= 6.1035e-5) — confirm the intended rationale.
constexpr float SMALL_SCALE_THRESHOLD{6.1e-5f};
29 
30 /**
31  * Asserts that the parameters are valid.
32  */
check_quantize_per_tensor_args(const Tensor & input,int64_t qmin,int64_t qmax,ScalarType dtype,Tensor & scale_out,Tensor & zero_point_out,bool is_per_token=false)33 void check_quantize_per_tensor_args(
34     const Tensor& input,
35     int64_t qmin,
36     int64_t qmax,
37     ScalarType dtype,
38     Tensor& scale_out,
39     Tensor& zero_point_out,
40     bool is_per_token = false) {
41   (void)dtype;
42   ET_CHECK_MSG(
43       qmin < qmax,
44       "qmin should be less than qmax, but received min: %" PRId64
45       ", max %" PRId64,
46       qmin,
47       qmax);
48   ET_CHECK_MSG(
49       input.scalar_type() == ScalarType::Float,
50       "Expected input to be Float tensor received: %" PRId8,
51       static_cast<int8_t>(input.scalar_type()));
52   ET_CHECK_MSG(
53       scale_out.scalar_type() == ScalarType::Double,
54       "Expected scale to be Double tensor received: %" PRId8,
55       static_cast<int8_t>(scale_out.scalar_type()));
56   ET_CHECK_MSG(
57       zero_point_out.scalar_type() == ScalarType::Long,
58       "Expected scale to be Long tensor received: %" PRId8,
59       static_cast<int8_t>(zero_point_out.scalar_type()));
60 
61   if (is_per_token) {
62     for (auto i = 0; i < input.dim() - 1; i++) {
63       ET_CHECK_MSG(
64           scale_out.size(i) == input.size(i),
65           "Exepcted scale to have the same number of elements at dimentions %d got %zd",
66           i,
67           scale_out.size(i));
68       ET_CHECK_MSG(
69           zero_point_out.size(i) == input.size(i),
70           "Exepcted zero pont to have the same number of elements at dimentions %d got %zd",
71           i,
72           zero_point_out.size(i));
73     }
74     ET_CHECK_MSG(
75         scale_out.size(input.dim() - 1) == 1,
76         "Exepcted scale to have only one element at dimentions %zd but got %zd",
77         input.dim() - 1,
78         scale_out.size(input.dim() - 1));
79     ET_CHECK_MSG(
80         zero_point_out.size(input.dim() - 1) == 1,
81         "Exepcted zero point to have only one element at dimentions %zd but got %zd",
82         input.dim() - 1,
83         zero_point_out.size(input.dim() - 1));
84   } else {
85     ET_CHECK_MSG(
86         scale_out.numel() == 1,
87         "Exepcted scale to only have one element received: %zd",
88         ssize_t(scale_out.numel()));
89     ET_CHECK_MSG(
90         zero_point_out.numel() == 1,
91         "Exepcted zero_point to only have one element received: %zd",
92         ssize_t(zero_point_out.numel()));
93   }
94 }
95 
calculate_scale_and_zero_point(float min,float max,int32_t qmin,int32_t qmax,double & scale,int32_t & zero_point)96 void calculate_scale_and_zero_point(
97     float min,
98     float max,
99     int32_t qmin,
100     int32_t qmax,
101     double& scale,
102     int32_t& zero_point) {
103   // We extend the [min, max] interval to ensure that it contains 0.
104   // Otherwise, we would not meet the requirement that 0 be an exactly
105   // representable value.
106   min = std::min(min, 0.f);
107   max = std::max(max, 0.f);
108 
109   // Use double precision for intermediate computation but use single precision
110   // in final number to reflect the actual number used during quantization.
111   scale = (static_cast<double>(max) - min) / (qmax - qmin);
112   // If scale is 0 or too small so its reciprocal is infinity, we arbitrary
113   // adjust the scale to 0.1 . We want to avoid scale's reciprocal being
114   // infinity because some of fbgemm code pre-computes scale's reciprocal to do
115   // multiplication instead of division in the time critical part of code.
116   if (float(scale) == 0.0f || std::isinf(1.0f / float(scale))) {
117     scale = 0.1;
118   }
119   ET_CHECK_MSG(scale > 0, "quantization scale should be > 0");
120 
121   // Cut off small scale
122   if (scale < SMALL_SCALE_THRESHOLD) {
123     float org_scale = scale;
124     scale = SMALL_SCALE_THRESHOLD;
125     // Adjust the min and max based on the new scale
126     if (min == 0.0f) {
127       max = SMALL_SCALE_THRESHOLD * (qmax - qmin);
128     } else if (max == 0.0f) {
129       min = -SMALL_SCALE_THRESHOLD * (qmax - qmin);
130     } else {
131       float amplifier = SMALL_SCALE_THRESHOLD / org_scale;
132       min *= amplifier;
133       max *= amplifier;
134     }
135   }
136 
137   // Zero-point computation.
138   // First the initial floating-point computation. The zero-point can be
139   // determined from solving an affine equation for any known pair
140   // (real value, corresponding quantized value).
141   // We know two such pairs: (rmin, qmin) and (rmax, qmax).
142   // The arithmetic error on the zero point computed from either pair
143   // will be roughly machine_epsilon * (sum of absolute values of terms)
144   // so we want to use the variant that adds the smaller terms.
145   double zero_point_from_min = qmin - min / static_cast<double>(scale);
146   double zero_point_from_max = qmax - max / static_cast<double>(scale);
147   double zero_point_from_min_error =
148       std::abs(qmin) - std::abs(min / static_cast<double>(scale));
149   double zero_point_from_max_error =
150       std::abs(qmax) - std::abs(max / static_cast<double>(scale));
151   double initial_zero_point =
152       zero_point_from_min_error < zero_point_from_max_error
153       ? zero_point_from_min
154       : zero_point_from_max;
155 
156   // Now we need to nudge the zero point to be an integer
157   // (our zero points are integer, and this is motivated by the requirement
158   // to be able to represent the real value "0" exactly as a quantized value,
159   // which is required in multiple places, for example in Im2col with zero
160   // padding).
161   int32_t nudged_zero_point = 0;
162   if (initial_zero_point < qmin) {
163     nudged_zero_point = qmin;
164   } else if (initial_zero_point > qmax) {
165     nudged_zero_point = qmax;
166   } else {
167     nudged_zero_point = nearbyint(static_cast<float>(initial_zero_point));
168   }
169   zero_point = nudged_zero_point;
170   return;
171 }
172 
choose_qparams(const Tensor & input,int32_t qmin,int32_t qmax,Tensor & scale_out,Tensor & zero_point_out)173 void choose_qparams(
174     const Tensor& input,
175     int32_t qmin,
176     int32_t qmax,
177     Tensor& scale_out,
178     Tensor& zero_point_out) {
179   const float* x_fp32 = input.const_data_ptr<float>();
180   // Compute x_min, x_max and q_params (scale, zero_point)
181   float min = torch::executor::vec_minf(x_fp32, input.numel());
182   float max = torch::executor::vec_maxf(x_fp32, input.numel());
183 
184   double scale;
185   int32_t zero_point;
186   calculate_scale_and_zero_point(min, max, qmin, qmax, scale, zero_point);
187 
188   scale_out.mutable_data_ptr<double>()[0] = scale;
189   zero_point_out.mutable_data_ptr<int64_t>()[0] = zero_point;
190 }
191 
choose_qparams_per_token(const Tensor & input,int32_t qmin,int32_t qmax,Tensor & scale_out,Tensor & zero_point_out)192 void choose_qparams_per_token(
193     const Tensor& input,
194     int32_t qmin,
195     int32_t qmax,
196     Tensor& scale_out,
197     Tensor& zero_point_out) {
198   const float* x_fp32 = input.const_data_ptr<float>();
199   // Compute x_min, x_max and q_params (scale, zero_point)
200   auto num_tokens = 1;
201   for (auto i = 0; i < input.dim() - 1; i++) {
202     num_tokens *= input.size(i);
203   }
204   auto token_dim_size = input.size(input.dim() - 1);
205   for (auto i = 0; i < num_tokens; i++) {
206     // vec_minf uses std::min_element. Check if it actually
207     // gets vectorized.
208     float min = torch::executor::vec_minf(x_fp32, token_dim_size);
209     float max = torch::executor::vec_maxf(x_fp32, token_dim_size);
210     double scale;
211     int32_t zero_point;
212     calculate_scale_and_zero_point(min, max, qmin, qmax, scale, zero_point);
213     scale_out.mutable_data_ptr<double>()[i] = scale;
214     zero_point_out.mutable_data_ptr<int64_t>()[i] = zero_point;
215     x_fp32 += token_dim_size;
216   }
217 }
218 } // namespace
219 
choose_qparams_tensor_out(const Tensor & input,int64_t quant_min,int64_t quant_max,ET_UNUSED double eps,ScalarType dtype,Tensor & scale_out,Tensor & zero_point_out)220 std::tuple<Tensor&, Tensor&> choose_qparams_tensor_out(
221     const Tensor& input,
222     int64_t quant_min,
223     int64_t quant_max,
224     ET_UNUSED double eps,
225     ScalarType dtype,
226     Tensor& scale_out,
227     Tensor& zero_point_out) {
228   check_quantize_per_tensor_args(
229       input, quant_min, quant_max, dtype, scale_out, zero_point_out);
230 
231   choose_qparams(input, quant_min, quant_max, scale_out, zero_point_out);
232   return {scale_out, zero_point_out};
233 }
234 
choose_qparams_tensor_out(KernelRuntimeContext & context,const Tensor & input,int64_t quant_min,int64_t quant_max,double eps,ScalarType dtype,Tensor & scale_out,Tensor & zero_point_out)235 ::std::tuple<Tensor&, Tensor&> choose_qparams_tensor_out(
236     KernelRuntimeContext& context,
237     const Tensor& input,
238     int64_t quant_min,
239     int64_t quant_max,
240     double eps,
241     ScalarType dtype,
242     Tensor& scale_out,
243     Tensor& zero_point_out) {
244   // TODO(larryliu): Add a context arg to the real op function and remove this
245   // wrapper
246   (void)context;
247   return choose_qparams_tensor_out(
248       input, quant_min, quant_max, eps, dtype, scale_out, zero_point_out);
249 }
250 
choose_qparams_per_token_asymmetric_out(const Tensor & input,ScalarType dtype,Tensor & scale_out,Tensor & zero_point_out)251 std::tuple<Tensor&, Tensor&> choose_qparams_per_token_asymmetric_out(
252     const Tensor& input,
253     ScalarType dtype,
254     Tensor& scale_out,
255     Tensor& zero_point_out) {
256   int64_t quant_min = -128;
257   int64_t quant_max = 127;
258   exec_aten::SizesType output_sizes[kTensorDimensionLimit];
259   for (ssize_t i = 0; i < input.dim() - 1; i++) {
260     output_sizes[i] = input.size(i);
261   }
262   output_sizes[input.dim() - 1] = 1;
263   size_t output_dim = input.dim();
264   torch::executor::Error err =
265       resize_tensor(scale_out, {output_sizes, output_dim});
266   ET_CHECK_MSG(
267       err == torch::executor::Error::Ok,
268       "Failed to resize scale_out Tensor in choose_qparams");
269   err = resize_tensor(zero_point_out, {output_sizes, output_dim});
270   ET_CHECK_MSG(
271       err == torch::executor::Error::Ok,
272       "Failed to resize zero_point_out Tensor in choose_qparams");
273 
274   check_quantize_per_tensor_args(
275       input,
276       quant_min,
277       quant_max,
278       dtype,
279       scale_out,
280       zero_point_out,
281       true /* is_per_token*/);
282 
283   choose_qparams_per_token(
284       input, quant_min, quant_max, scale_out, zero_point_out);
285   return {scale_out, zero_point_out};
286 }
287 
choose_qparams_per_token_asymmetric_out(RuntimeContext & context,const Tensor & input,ScalarType dtype,Tensor & scale_out,Tensor & zero_point_out)288 ::std::tuple<Tensor&, Tensor&> choose_qparams_per_token_asymmetric_out(
289     RuntimeContext& context,
290     const Tensor& input,
291     ScalarType dtype,
292     Tensor& scale_out,
293     Tensor& zero_point_out) {
294   (void)context;
295   return choose_qparams_per_token_asymmetric_out(
296       input, dtype, scale_out, zero_point_out);
297 }
298 
299 } // namespace native
300 } // namespace executor
301 } // namespace torch
302