xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/QuantUtils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/Tensor.h>
4 #include <ATen/core/List.h>
5 #include <ATen/TensorOperators.h>
6 #include <c10/util/irange.h>
7 #include <algorithm>
8 #include <cmath>
9 
10 #ifndef AT_PER_OPERATOR_HEADERS
11 #include <ATen/Functions.h>
12 #include <ATen/NativeFunctions.h>
13 #else
14 #include <ATen/ops/quantize_per_tensor_native.h>
15 #include <ATen/ops/quantize_per_channel_native.h>
16 #include <ATen/ops/zeros.h>
17 #endif
18 
19 namespace quant_utils {
20 namespace {
// Decodes an IEEE 754 binary16 (half precision) number, given as its raw
// 16-bit pattern, into a single precision float.
//   bit 15      : sign
//   bits 14..10 : biased exponent (bias 15)
//   bits 9..0   : fraction
// Exponent 0 (zero / subnormal) decodes to fraction * 2^-24, exponent 31
// decodes to +/-infinity (fraction 0) or NaN (fraction != 0), and every other
// pattern decodes to (1 + fraction * 2^-10) * 2^(exponent - 15).
float RawUint16ToFp16(unsigned short value) {
  const unsigned short sign_bits = value >> 15;
  const unsigned short exponent_bits = (value >> 10) & 0x1f;
  const unsigned short significand_bits = value & 0x3ff;

  const float sign = sign_bits ? -1.0f : 1.0f;

  if (exponent_bits == 0x1f) {
    // All-ones exponent encodes infinity (zero fraction) or NaN.
    return significand_bits == 0 ? sign * INFINITY : NAN;
  }
  if (exponent_bits == 0) {
    // Zero / subnormal: no implicit leading 1 and a fixed exponent of -14,
    // i.e. fraction * 2^-10 * 2^-14 == fraction * 2^-24.
    return sign * std::ldexp(static_cast<float>(significand_bits), -24);
  }

  const float significand =
      1 + significand_bits * 0.0009765625f; // 0.0009765625f = 0x1p-10 = 2^-10;
  return sign *
      std::ldexp(significand, static_cast<int>(exponent_bits) - 0xf);
}
35 
// Clamps *element into [-max_val, max_val] in place.
// Returns true when the value lay outside that interval and was overwritten
// (with max_val if it was above, -max_val if it was below); returns false and
// leaves the value untouched otherwise. A value that compares false against
// both bounds (e.g. NaN) is left as-is and reported as false.
template <typename T>
bool CheckAndSaturate(T max_val, T* element) {
  T& current = *element;
  const auto neg_max = -max_val;
  const bool above = current > max_val;
  if (!above && !(current < neg_max)) {
    return false;
  }
  current = above ? max_val : neg_max;
  return true;
}
48 }
// NOTE(review): a namespace-scope using-directive in a header exposes every
// std name inside quant_utils for all includers. It is kept because code in
// this namespace may call std math functions (e.g. fabs) unqualified and rely
// on the overloads this directive makes visible.
using namespace std;
// Quantization parameters of a tensor. 'scale' and 'zero_point' are the
// constants of the affine quantization equation
//
//   real_value = scale * (quantized_value - zero_point)
//
// i.e. 'zero_point' is the quantized value that represents the real value 0,
// and 'scale' is the real-valued step between consecutive quantized values.
// Every member has a default initializer, so a default-constructed instance,
// or one built by code that never assigns 'precision', holds no
// indeterminate values. The struct remains an aggregate.
struct TensorQuantizationParams {
  double scale{0.0};
  std::int32_t zero_point{0};
  // NOTE(review): not computed by ChooseQuantizationParams in this file;
  // presumably a bit width of the quantized type — TODO confirm with users.
  int precision{0};
};
63 
// Lower cutoff applied to chosen quantization scales: anything smaller is
// raised to this value. It sits at roughly the smallest normal fp16 number
// (2^-14 ~= 6.1035e-5), so scales never land in the fp16 subnormal range,
// which keeps results consistent with the Glow and FakeLowP implementations
// for NNPI.
constexpr float SMALL_SCALE_THRESHOLD{6.1e-5f};
68 
69 // Following implementation should be identical to fbgemm::ChooseQuantizationParams
70 inline TensorQuantizationParams ChooseQuantizationParams(
71     float min,
72     float max,
73     int32_t qmin,
74     int32_t qmax,
75     bool preserve_sparsity = false,
76     bool force_scale_power_of_two = false,
77     bool reduce_range = false) {
78   TORCH_CHECK(
79       min <= max,
80       "In ChooseQuantizationParams, min should be less than or equal to max");
81 
82   if (reduce_range) {
83     qmin = qmin/2;
84     qmax = qmax/2;
85   }
86   if (min < 0 && max > 0 && preserve_sparsity) {
87     int symmetric_qmin = -((qmax - qmin) / 2 + 1);
88     int symmetric_qmax = (qmax - qmin) / 2;
89     double max_scale =
90         std::max(fabs(min / symmetric_qmin), fabs(max / symmetric_qmax));
91     min = max_scale * symmetric_qmin;
92     max = max_scale * symmetric_qmax;
93   }
94 
95   // We extend the [min, max] interval to ensure that it contains 0.
96   // Otherwise, we would not meet the requirement that 0 be an exactly
97   // representable value.
98   min = std::min(min, 0.f);
99   max = std::max(max, 0.f);
100 
101   TORCH_CHECK(
102       qmin < qmax,
103       "In ChooseQuantizationParams, qmin should be less than qmax");
104 
105   // Use double precision for intermediate computation but use single precision
106   // in final number to reflect the actual number used during quantization.
107   double scale = (static_cast<double>(max) - min) / (qmax - qmin);
108   // If scale is 0 or too small so its reciprocal is infinity, we arbitrary
109   // adjust the scale to 0.1 . We want to avoid scale's reciprocal being
110   // infinity because some of fbgemm code pre-computes scale's reciprocal to do
111   // multiplication instead of division in the time critical part of code.
112   if (float(scale) == 0.0f || std::isinf(1.0f / float(scale))) {
113     scale = 0.1;
114   }
115   TORCH_CHECK(scale > 0, "quantization scale should be > 0");
116 
117   if (force_scale_power_of_two) {
118     if (scale < 1) {
119       scale = 1.0 / (1 << static_cast<int>(floor(log(1.0 / scale) / log(2))));
120     } else {
121       scale = 1 << static_cast<int>(ceil(log(scale) / log(2)));
122     }
123   }
124 
125   // Cut off small scale
126   if (scale < SMALL_SCALE_THRESHOLD) {
127     float org_scale = scale;
128     scale = SMALL_SCALE_THRESHOLD;
129     // Adjust the min and max based on the new scale
130     if (min == 0.0f) {
131       max = SMALL_SCALE_THRESHOLD * (qmax - qmin);
132     } else if (max == 0.0f) {
133       min = -SMALL_SCALE_THRESHOLD * (qmax - qmin);
134     } else {
135       float amplifier = SMALL_SCALE_THRESHOLD / org_scale;
136       min *= amplifier;
137       max *= amplifier;
138     }
139   }
140 
141   // Zero-point computation.
142   // First the initial floating-point computation. The zero-point can be
143   // determined from solving an affine equation for any known pair
144   // (real value, corresponding quantized value).
145   // We know two such pairs: (rmin, qmin) and (rmax, qmax).
146   // The arithmetic error on the zero point computed from either pair
147   // will be roughly machine_epsilon * (sum of absolute values of terms)
148   // so we want to use the variant that adds the smaller terms.
149   double zero_point_from_min = qmin - min / static_cast<double>(scale);
150   double zero_point_from_max = qmax - max / static_cast<double>(scale);
151   double zero_point_from_min_error =
152       std::abs(qmin) - std::abs(min / static_cast<double>(scale));
153   double zero_point_from_max_error =
154       std::abs(qmax) - std::abs(max / static_cast<double>(scale));
155   double initial_zero_point =
156       zero_point_from_min_error < zero_point_from_max_error
157       ? zero_point_from_min
158       : zero_point_from_max;
159 
160   // for symmetric quantization (preserve_sparsity == true), we force zero_point
161   // to be a middle value between qmin and qmax.
162   // If either min or max is 0, then we just use 0 as zero_point.
163   if (min < 0 && max > 0 && preserve_sparsity) {
164     initial_zero_point = static_cast<double>(qmin + qmax) / 2;
165   }
166 
167   // Now we need to nudge the zero point to be an integer
168   // (our zero points are integer, and this is motivated by the requirement
169   // to be able to represent the real value "0" exactly as a quantized value,
170   // which is required in multiple places, for example in Im2col with zero
171   // padding).
172   int32_t nudged_zero_point = 0;
173   if (initial_zero_point < qmin) {
174     nudged_zero_point = qmin;
175   } else if (initial_zero_point > qmax) {
176     nudged_zero_point = qmax;
177   } else {
178     nudged_zero_point = nearbyint(initial_zero_point);
179   }
180 
181   TensorQuantizationParams result;
182   result.scale = scale;
183   result.zero_point = nudged_zero_point;
184   return result;
185 }
186 
187 // This function helps to convert the Conv1D dimensions usable by the Conv2d op.
188 constexpr int64_t kConv1dSqueezeDim = 0;
MakeArgForConv1d(const torch::List<int64_t> & arg,int64_t base_value)189 static C10_UNUSED torch::List<int64_t> MakeArgForConv1d(const torch::List<int64_t>& arg,
190                                              int64_t base_value) {
191   TORCH_CHECK(!arg.empty(), "Argument must have elements.");
192   torch::List<int64_t> result({arg.get(0), base_value});
193   if (arg.size() == 1) {
194     result[1] = arg.get(0);
195   } else {
196     result[1] = arg.get(1);
197   }
198   result[kConv1dSqueezeDim] = base_value;
199   return result;
200 }
201 
202 // The range for using FP16 quantization of weights requires that the elements
203 // should be in the range of [5.96e-8, 65504]. If it is out of range, then the
204 // number will be saturated to max or min representable values by FP16.
HandleWeightsSaturation(int64_t N,float * weight)205 inline void HandleWeightsSaturation(int64_t N, float* weight) {
206   const float kFp16Max = RawUint16ToFp16(0x7BFF);
207   bool found_out_of_range = false;
208   for (const auto i : c10::irange(N)) {
209     bool saturate = CheckAndSaturate<float>(kFp16Max, weight + i);
210     if (saturate) {
211       found_out_of_range = true;
212     }
213   }
214   if (found_out_of_range) {
215     TORCH_WARN("FOUND weight out of range ");
216   }
217 }
218 
219 // Util function for quantizing bias.
QuantizeBias(bool is_per_channel,const at::Tensor & bias,const at::Tensor & weight_contig,double input_scale)220 inline at::Tensor QuantizeBias(
221     bool is_per_channel,
222     const at::Tensor& bias,
223     const at::Tensor& weight_contig,
224     double input_scale) {
225   at::Tensor qbias;
226   if (is_per_channel) {
227     auto bias_quant_scales =
228         weight_contig.q_per_channel_scales() * input_scale;
229     auto bias_zp = at::zeros(bias_quant_scales.sizes(), c10::kInt);
230     qbias = at::native::quantize_per_channel(
231         bias, bias_quant_scales, bias_zp, 0, c10::kQInt32);
232   } else {
233     qbias = at::native::quantize_per_tensor(
234         bias, weight_contig.q_scale() * input_scale, 0, c10::kQInt32);
235   }
236   return qbias;
237 }
238 
239 } // namespace quant_utils
240