/* xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/requantization.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45) */
/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <assert.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>

#include <fp16/bitcasts.h>

#include <qnnpack/params.h>
#include <qnnpack/scalar-utils.h>

21 static inline union pytorch_qnnp_q31_requantization_params
pytorch_qnnp_compute_scalar_requantization_params(float scale,uint8_t zero_point,uint8_t min,uint8_t max)22 pytorch_qnnp_compute_scalar_requantization_params(
23     float scale,
24     uint8_t zero_point,
25     uint8_t min,
26     uint8_t max) {
27   /* Compute requantization parameters */
28   assert(scale < 1.0f);
29   assert(scale >= 0x1.0p-32f);
30   const uint32_t scale_bits = fp32_to_bits(scale);
31 
32   /* Multiplier is in [0x40000000, 0x7FFFFF80] range */
33   const int32_t multiplier = (int32_t)(
34       ((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
35   assert(multiplier >= INT32_C(0x40000000));
36   assert(multiplier <= INT32_C(0x7FFFFF80));
37 
38   /* Shift is in [0, 31] range */
39   const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
40   assert(shift >= 0);
41   assert(shift < 32);
42 
43   union pytorch_qnnp_q31_requantization_params params;
44   const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
45   const uint32_t remainder_threshold = remainder_mask >> 1;
46   params.scalar.multiplier = multiplier;
47   params.scalar.remainder_mask = (int32_t)remainder_mask;
48   params.scalar.remainder_threshold = (int32_t)remainder_threshold;
49   params.scalar.shift = (uint32_t)shift;
50   params.scalar.min_less_zero_point =
51       (int32_t)(uint32_t)min - (int32_t)(uint32_t)zero_point;
52   params.scalar.max_less_zero_point =
53       (int32_t)(uint32_t)max - (int32_t)(uint32_t)zero_point;
54   params.scalar.zero_point = (int32_t)(uint32_t)zero_point;
55   return params;
56 }
57 
58 static inline union pytorch_qnnp_fp32_requantization_params
pytorch_qnnp_compute_scalar_fp32_requantization_params(float * scales,uint8_t zero_point,uint8_t min,uint8_t max)59 pytorch_qnnp_compute_scalar_fp32_requantization_params(
60     float* scales,
61     uint8_t zero_point,
62     uint8_t min,
63     uint8_t max) {
64 
65   union pytorch_qnnp_fp32_requantization_params params;
66   params.scalar.scales = scales;
67   params.scalar.output_zero_point = zero_point;
68   params.scalar.output_max = max;
69   params.scalar.output_min = min;
70   params.scalar.min_less_zero_point = ((float)((int32_t)(uint32_t)min -
71       (int32_t)(uint32_t)zero_point));
72   params.scalar.max_less_zero_point = ((float)((int32_t)(uint32_t)max -
73       (int32_t)(uint32_t)zero_point));
74   params.scalar.magic = 12582912.0f;
75   params.scalar.magic_less_zero_point = (INT32_C(0x4B400000) -
76       (int32_t)(uint32_t)zero_point);
77   return params;
78 }
79 
80 static inline union pytorch_qnnp_q31_requantization_params
pytorch_qnnp_compute_requantization_params(float scale,uint8_t zero_point,uint8_t min,uint8_t max)81 pytorch_qnnp_compute_requantization_params(
82     float scale,
83     uint8_t zero_point,
84     uint8_t min,
85     uint8_t max) {
86   /* Compute requantization parameters */
87   const uint32_t scale_bits = fp32_to_bits(scale);
88 
89   /* Multiplier is in [0x40000000, 0x7FFFFF80] range */
90   const int32_t multiplier = (int32_t)(
91       ((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
92   assert(multiplier >= INT32_C(0x40000000));
93   assert(multiplier <= INT32_C(0x7FFFFF80));
94 
95   /* Shift is in [0, 31] range */
96   const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
97   assert(shift >= 0);
98   assert(shift < 32);
99 
100   union pytorch_qnnp_q31_requantization_params params;
101 #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
102   const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
103   const uint32_t remainder_threshold = remainder_mask >> 1;
104   params.sse2.multiplier[0] = multiplier;
105   params.sse2.multiplier[1] = multiplier;
106   params.sse2.multiplier[2] = multiplier;
107   params.sse2.multiplier[3] = multiplier;
108   params.sse2.rounding[0] = UINT64_C(0x40000000);
109   params.sse2.rounding[1] = UINT64_C(0x40000000);
110   params.sse2.remainder_mask[0] = (int32_t)remainder_mask;
111   params.sse2.remainder_mask[1] = (int32_t)remainder_mask;
112   params.sse2.remainder_mask[2] = (int32_t)remainder_mask;
113   params.sse2.remainder_mask[3] = (int32_t)remainder_mask;
114   params.sse2.remainder_threshold[0] = (int32_t)remainder_threshold;
115   params.sse2.remainder_threshold[1] = (int32_t)remainder_threshold;
116   params.sse2.remainder_threshold[2] = (int32_t)remainder_threshold;
117   params.sse2.remainder_threshold[3] = (int32_t)remainder_threshold;
118   params.sse2.shift[0] = (uint64_t)(uint32_t)shift;
119   params.sse2.shift[1] = (uint64_t)(uint32_t)shift;
120   for (uint32_t i = 0; i < 8; i++) {
121     params.sse2.zero_point[i] = (int16_t)(uint16_t)zero_point;
122   }
123   for (uint32_t i = 0; i < 16; i++) {
124     params.sse2.max[i] = max;
125     params.sse2.min[i] = min;
126   }
127 #elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
128   params.neon.multiplier = multiplier;
129   params.neon.right_shift = -shift;
130   params.neon.zero_point = (int16_t)(uint16_t)zero_point;
131   params.neon.max = max;
132   params.neon.min = min;
133 #else
134   const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
135   const uint32_t remainder_threshold = remainder_mask >> 1;
136   params.scalar.multiplier = multiplier;
137   params.scalar.remainder_mask = (int32_t)remainder_mask;
138   params.scalar.remainder_threshold = (int32_t)remainder_threshold;
139   params.scalar.shift = (uint32_t)shift;
140   params.scalar.min_less_zero_point =
141       (int32_t)(uint32_t)min - (int32_t)(uint32_t)zero_point;
142   params.scalar.max_less_zero_point =
143       (int32_t)(uint32_t)max - (int32_t)(uint32_t)zero_point;
144   params.scalar.zero_point = (int32_t)(uint32_t)zero_point;
145 #endif
146   return params;
147 }
148 
149 static inline union pytorch_qnnp_conv_quantization_params
pytorch_qnnp_compute_conv_quantization_params(uint8_t input_zero_point,const uint8_t * kernel_zero_points,const float * requantization_scales,uint8_t output_zero_point,uint8_t output_min,uint8_t output_max)150 pytorch_qnnp_compute_conv_quantization_params(
151     uint8_t input_zero_point,
152     const uint8_t* kernel_zero_points,
153     const float* requantization_scales,
154     uint8_t output_zero_point,
155     uint8_t output_min,
156     uint8_t output_max) {
157 
158   union pytorch_qnnp_conv_quantization_params params;
159 #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
160   params.sse2.kernel_zero_points = kernel_zero_points;
161   for (uint32_t i = 0; i < 8; i++) {
162     params.sse2.input_zero_point[i] = (int16_t)(uint16_t)input_zero_point;
163   }
164   params.sse2.requantization_scales = requantization_scales;
165   for (uint32_t i = 0; i < 8; i++) {
166     params.sse2.output_zero_point[i] = (int16_t)(uint16_t)output_zero_point;
167   }
168   for (uint32_t i = 0; i < 16; i++) {
169     params.sse2.output_max[i] = output_max;
170     params.sse2.output_min[i] = output_min;
171   }
172 #elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
173   params.neon.input_zero_point = (int16_t)(uint16_t)input_zero_point;
174   params.neon.kernel_zero_points = kernel_zero_points;
175   params.neon.requantization_scales = requantization_scales;
176   params.neon.output_zero_point = (int16_t)(uint16_t)output_zero_point;
177   params.neon.output_max = output_max;
178   params.neon.output_min = output_min;
179   params.neon.vfmin = ((float)((int32_t)(uint32_t)output_min -
180       (int32_t)(uint32_t)output_zero_point));
181   params.neon.vfmax = ((float)((int32_t)(uint32_t)output_max -
182       (int32_t)(uint32_t)output_zero_point));
183   params.neon.vfmagic = 12582912.0f;
184   params.neon.vimagic = (INT32_C(0x4B400000) -
185       (int32_t)(uint32_t)output_zero_point);
186 #else
187   params.scalar.input_zero_point = (int32_t)(uint32_t)input_zero_point;
188   params.scalar.kernel_zero_points = kernel_zero_points;
189   params.scalar.requantization_scales = requantization_scales;
190   params.scalar.output_min_less_zero_point =
191       (int32_t)(uint32_t)output_min - (int32_t)(uint32_t)output_zero_point;
192   params.scalar.output_max_less_zero_point =
193       (int32_t)(uint32_t)output_max - (int32_t)(uint32_t)output_zero_point;
194   params.scalar.output_zero_point = (int32_t)(uint32_t)output_zero_point;
195 #endif
196   return params;
197 }
198 
199 static inline union pytorch_qnnp_avgpool_quantization_params
pytorch_qnnp_compute_avgpool_quantization_params(int32_t bias,float scale,uint8_t output_zero_point,uint8_t output_min,uint8_t output_max)200 pytorch_qnnp_compute_avgpool_quantization_params(
201     int32_t bias,
202     float scale,
203     uint8_t output_zero_point,
204     uint8_t output_min,
205     uint8_t output_max) {
206   /* Compute requantization parameters */
207   assert(scale >= 0x1.0p-32f);
208   assert(scale < 256.0f);
209 
210   union pytorch_qnnp_avgpool_quantization_params params;
211 #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
212   params.sse2.bias[0] = bias;
213   params.sse2.bias[1] = bias;
214   params.sse2.bias[2] = bias;
215   params.sse2.bias[3] = bias;
216   params.sse2.scale[0] = scale;
217   params.sse2.scale[1] = scale;
218   params.sse2.scale[2] = scale;
219   params.sse2.scale[3] = scale;
220   for (uint32_t i = 0; i < 8; i++) {
221     params.sse2.output_zero_point[i] = (int16_t)(uint16_t)output_zero_point;
222   }
223   for (uint32_t i = 0; i < 16; i++) {
224     params.sse2.output_max[i] = output_max;
225     params.sse2.output_min[i] = output_min;
226   }
227 #elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
228   params.neon.bias = bias;
229   params.neon.scale = scale;
230   params.neon.output_zero_point = (int16_t)(uint16_t)output_zero_point;
231   params.neon.output_max = output_max;
232   params.neon.output_min = output_min;
233   params.neon.vfmin = ((float)((int32_t)(uint32_t)output_min -
234       (int32_t)(uint32_t)output_zero_point));
235   params.neon.vfmax = ((float)((int32_t)(uint32_t)output_max -
236       (int32_t)(uint32_t)output_zero_point));
237   params.neon.vfmagic = 12582912.0f;
238   params.neon.vimagic = (INT32_C(0x4B400000) -
239       (int32_t)(uint32_t)output_zero_point);
240 #else
241   params.scalar.bias = bias;
242   params.scalar.scale = scale;
243   params.scalar.output_zero_point = (int32_t)(uint32_t)output_zero_point;
244   params.scalar.output_max = (int32_t)(uint32_t)output_max;
245   params.scalar.output_min = (int32_t)(uint32_t)output_min;
246 #endif
247   return params;
248 }
249 
250 static inline union pytorch_qnnp_avgpool_quantization_params
pytorch_qnnp_compute_scalar_avgpool_quantization_params(int32_t bias,float scale,uint8_t output_zero_point,uint8_t output_min,uint8_t output_max)251 pytorch_qnnp_compute_scalar_avgpool_quantization_params(
252     int32_t bias,
253     float scale,
254     uint8_t output_zero_point,
255     uint8_t output_min,
256     uint8_t output_max) {
257   /* Compute requantization parameters */
258   assert(scale >= 0x1.0p-32f);
259   assert(scale < 256.0f);
260 
261   union pytorch_qnnp_avgpool_quantization_params params;
262   params.scalar.bias = bias;
263   params.scalar.scale = scale;
264   params.scalar.output_zero_point = (int32_t)(uint32_t)output_zero_point;
265   params.scalar.output_max = (int32_t)(uint32_t)output_max;
266   params.scalar.output_min = (int32_t)(uint32_t)output_min;
267   return params;
268 }
269 
270 static inline union pytorch_qnnp_u8_clamping_params
pytorch_qnnp_compute_u8_clamping_params(uint8_t output_min,uint8_t output_max)271 pytorch_qnnp_compute_u8_clamping_params(
272     uint8_t output_min,
273     uint8_t output_max) {
274   assert(output_min <= output_max);
275 
276   union pytorch_qnnp_u8_clamping_params params;
277 #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
278   for (uint32_t i = 0; i < 16; i++) {
279     params.sse2.output_max[i] = output_max;
280     params.sse2.output_min[i] = output_min;
281   }
282 #elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
283   params.neon.output_max = output_max;
284   params.neon.output_min = output_min;
285 #else
286   params.scalar.output_min = (int32_t)(uint32_t)output_min;
287   params.scalar.output_max = (int32_t)(uint32_t)output_max;
288 #endif
289   return params;
290 }
291 
292 static inline union pytorch_qnnp_add_quantization_params
pytorch_qnnp_compute_add_quantization_params(uint8_t a_zero_point,uint8_t b_zero_point,uint8_t output_zero_point,float a_output_scale,float b_output_scale,uint8_t output_min,uint8_t output_max)293 pytorch_qnnp_compute_add_quantization_params(
294     uint8_t a_zero_point,
295     uint8_t b_zero_point,
296     uint8_t output_zero_point,
297     float a_output_scale,
298     float b_output_scale,
299     uint8_t output_min,
300     uint8_t output_max) {
301   assert(a_output_scale >= 0x1.0p-14f);
302   assert(b_output_scale >= 0x1.0p-14f);
303   assert(a_output_scale < 0x1.0p+8f);
304   assert(b_output_scale < 0x1.0p+8f);
305 
306   /* Compute requantization parameters */
307   const float max_output_scale =
308       a_output_scale > b_output_scale ? a_output_scale : b_output_scale;
309   assert(max_output_scale >= 0x1.0p-14f);
310   assert(max_output_scale < 0x1.0p+8f);
311   const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
312   const int32_t max_scale_exponent = (int32_t)(max_scale_bits >> 23) - 127;
313   /* Shift is in [13, 31] range */
314   const uint32_t shift = (uint32_t)(21 - max_scale_exponent);
315   assert(shift < 32);
316   assert(shift >= 13);
317 
318   const float scale_multiplier =
319       fp32_from_bits((uint32_t)(21 - max_scale_exponent + 127) << 23);
320 
321   /* Multipliers are in [0, 2**22) range, largest multiplier is in [2**21,
322    * 2**22) range */
323   const uint32_t a_multiplier =
324       (uint32_t)(int32_t)lrintf(a_output_scale * scale_multiplier);
325   const uint32_t b_multiplier =
326       (uint32_t)(int32_t)lrintf(b_output_scale * scale_multiplier);
327   assert(
328       (a_multiplier > b_multiplier ? a_multiplier : b_multiplier) >=
329       UINT32_C(0x00200000));
330   assert(a_multiplier < UINT32_C(0x00400000));
331   assert(b_multiplier < UINT32_C(0x00400000));
332 
333   union pytorch_qnnp_add_quantization_params params;
334 #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
335   const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
336   const uint32_t remainder_threshold = remainder_mask >> 1;
337   const int32_t zero_point_product = (int32_t) -
338       (a_multiplier * (uint32_t)a_zero_point +
339        b_multiplier * (uint32_t)b_zero_point);
340   for (uint32_t i = 0; i < 4; i++) {
341     params.sse2.zero_point_product[i] = zero_point_product;
342   }
343   for (uint32_t i = 0; i < 8; i++) {
344     params.sse2.y_zero_point[i] = (int16_t)(uint16_t)output_zero_point;
345   }
346   for (uint32_t i = 0; i < 8; i++) {
347     params.sse2.a_multiplier_lo[i] = (uint16_t)(uint32_t)a_multiplier;
348     params.sse2.a_multiplier_hi[i] = (uint16_t)((uint32_t)a_multiplier >> 16);
349     params.sse2.b_multiplier_lo[i] = (uint16_t)(uint32_t)b_multiplier;
350     params.sse2.b_multiplier_hi[i] = (uint16_t)((uint32_t)b_multiplier >> 16);
351   }
352   params.sse2.a_multiplier = a_multiplier;
353   params.sse2.b_multiplier = b_multiplier;
354   for (uint32_t i = 0; i < 4; i++) {
355     params.sse2.remainder_mask[i] = remainder_mask;
356     params.sse2.remainder_threshold[i] = remainder_threshold;
357   }
358   params.sse2.shift = shift;
359   for (uint32_t i = 0; i < 16; i++) {
360     params.sse2.y_max[i] = output_max;
361     params.sse2.y_min[i] = output_min;
362   }
363 #elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
364   params.neon.a_zero_point = a_zero_point;
365   params.neon.b_zero_point = b_zero_point;
366   params.neon.y_zero_point = (int16_t)(uint16_t)output_zero_point;
367   params.neon.a_multiplier = (int32_t)a_multiplier;
368   params.neon.b_multiplier = (int32_t)b_multiplier;
369   params.neon.right_shift = (int32_t)-shift;
370   params.neon.y_max = output_max;
371   params.neon.y_min = output_min;
372 #else
373   const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
374   const uint32_t remainder_threshold = remainder_mask >> 1;
375   params.scalar.zero_point_product = (int32_t) -
376       (a_multiplier * (uint32_t)a_zero_point +
377        b_multiplier * (uint32_t)b_zero_point);
378   params.scalar.a_multiplier = a_multiplier;
379   params.scalar.b_multiplier = b_multiplier;
380   params.scalar.remainder_mask = (int32_t)remainder_mask;
381   params.scalar.remainder_threshold = (int32_t)remainder_threshold;
382   params.scalar.shift = shift;
383   params.scalar.y_zero_point = (int32_t)(uint32_t)output_zero_point;
384   params.scalar.y_max = (int32_t)(uint32_t)output_max;
385   params.scalar.y_min = (int32_t)(uint32_t)output_min;
386 #endif
387   return params;
388 }
389 
390 static inline union pytorch_qnnp_add_quantization_params
pytorch_qnnp_compute_scalar_add_quantization_params(uint8_t a_zero_point,uint8_t b_zero_point,uint8_t output_zero_point,float a_output_scale,float b_output_scale,uint8_t output_min,uint8_t output_max)391 pytorch_qnnp_compute_scalar_add_quantization_params(
392     uint8_t a_zero_point,
393     uint8_t b_zero_point,
394     uint8_t output_zero_point,
395     float a_output_scale,
396     float b_output_scale,
397     uint8_t output_min,
398     uint8_t output_max) {
399   assert(a_output_scale >= 0x1.0p-10f);
400   assert(b_output_scale >= 0x1.0p-10f);
401   assert(a_output_scale < 0x1.0p+8f);
402   assert(b_output_scale < 0x1.0p+8f);
403 
404   /* Compute requantization parameters */
405   const float max_output_scale =
406       a_output_scale > b_output_scale ? a_output_scale : b_output_scale;
407   assert(max_output_scale >= 0x1.0p-10f);
408   assert(max_output_scale < 0x1.0p+8f);
409   const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
410   const int32_t max_scale_exponent = (int32_t)(max_scale_bits >> 23) - 127;
411   /* Shift is in [13, 31] range */
412   const uint32_t shift = (uint32_t)(21 - max_scale_exponent);
413   assert(shift < 32);
414   assert(shift >= 13);
415 
416   /* Multipliers are in [0, 2**22) range, largest multiplier is in [2**21,
417    * 2**22) range */
418   const uint32_t a_multiplier = (uint32_t)(int32_t)lrintf(
419       fp32_from_bits(fp32_to_bits(a_output_scale) + (shift << 23)));
420   const uint32_t b_multiplier = (uint32_t)(int32_t)lrintf(
421       fp32_from_bits(fp32_to_bits(b_output_scale) + (shift << 23)));
422   assert(
423       (a_multiplier > b_multiplier ? a_multiplier : b_multiplier) >=
424       UINT32_C(0x00200000));
425   assert(a_multiplier < UINT32_C(0x00400000));
426   assert(b_multiplier < UINT32_C(0x00400000));
427 
428   union pytorch_qnnp_add_quantization_params params;
429   const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
430   const uint32_t remainder_threshold = remainder_mask >> 1;
431   params.scalar.zero_point_product = (int32_t) -
432       (a_multiplier * (uint32_t)a_zero_point +
433        b_multiplier * (uint32_t)b_zero_point);
434   params.scalar.a_multiplier = a_multiplier;
435   params.scalar.b_multiplier = b_multiplier;
436   params.scalar.remainder_mask = (int32_t)remainder_mask;
437   params.scalar.remainder_threshold = (int32_t)remainder_threshold;
438   params.scalar.shift = shift;
439   params.scalar.y_zero_point = (int32_t)(uint32_t)output_zero_point;
440   params.scalar.y_max = (int32_t)(uint32_t)output_max;
441   params.scalar.y_min = (int32_t)(uint32_t)output_min;
442   return params;
443 }
444 
/*
 * Requantizes a single int32 accumulator to uint8 using the scalar Q31
 * parameter layout: Q31 fixed-point multiply with rounding, rounding
 * arithmetic right shift, clamp to [min, max] (stored pre-biased by the zero
 * point), then add the output zero point.
 */
static inline uint8_t pytorch_qnnp_q31_requantize(
    int32_t n,
    union pytorch_qnnp_q31_requantization_params params) {
  /* Q31 multiply with rounding: (n * multiplier + 2**30) >> 31. */
  const int64_t product = (int64_t)n * (int64_t)params.scalar.multiplier;
  const int32_t q31product =
      (int32_t)(uint32_t)((uint64_t)(product + INT64_C(0x40000000)) >> 31);
  /* Rounding correction for the arithmetic right shift; the (n < 0) term
   * adjusts tie-breaking for negative inputs. */
  const int32_t remainder =
      (q31product & params.scalar.remainder_mask) - (int32_t)(n < 0);
  n = asr_s32(q31product, params.scalar.shift) +
      (int32_t)(remainder > params.scalar.remainder_threshold);
  if (n < params.scalar.min_less_zero_point) {
    n = params.scalar.min_less_zero_point;
  }
  if (n > params.scalar.max_less_zero_point) {
    n = params.scalar.max_less_zero_point;
  }

  return (uint8_t)(n + params.scalar.zero_point);
}

/*
 * Requantizes a single int32 accumulator to uint8 via float arithmetic:
 * scale by the per-channel factor, round to nearest (lrintf), clamp to the
 * zero-point-relative [min, max] range, then add the output zero point.
 */
static inline uint8_t pytorch_qnnp_fp32_requantize(
    int32_t n,
    union pytorch_qnnp_fp32_requantization_params params,
    int32_t output_channel_index) {
  /* Clamping bounds relative to the output zero point. */
  const long lmin =
      (long)((int32_t)(uint32_t)params.scalar.output_min -
          (int32_t)(uint32_t)params.scalar.output_zero_point);
  const long lmax =
      (long)((int32_t)(uint32_t)params.scalar.output_max -
          (int32_t)(uint32_t)params.scalar.output_zero_point);

  const float n_scaled = (float)n * params.scalar.scales[output_channel_index];
  const long n_rounded = lrintf(n_scaled);
  const int32_t n_clamped = (int32_t)(
      n_rounded < lmin ? lmin : n_rounded > lmax ? lmax : n_rounded);
  const int32_t n_biased =
      n_clamped + (int32_t)(uint32_t)params.scalar.output_zero_point;

  return (uint8_t)n_biased;
}

/*
 * Requantizes a single int32 accumulator to uint8 via the "magic number"
 * float->int trick: after clamping in float, adding magic (2**23 + 2**22)
 * forces the value into a fixed binary exponent so its rounded integer part
 * can be read out of the low mantissa bits of the bit pattern; subtracting
 * imagic removes the magic bias and folds in the output zero point.
 */
static inline uint8_t pytorch_qnnp_fp32_requantize_magic(
    int32_t n,
    union pytorch_qnnp_fp32_requantization_params params,
    int32_t output_channel_index) {
  const float fmin = params.scalar.min_less_zero_point;
  const float fmax = params.scalar.max_less_zero_point;
  const float fmagic = params.scalar.magic;
  const int32_t imagic = params.scalar.magic_less_zero_point;

  const float n_scaled = (float)n * params.scalar.scales[output_channel_index];
  const float n_clamped =
      n_scaled < fmin ? fmin : n_scaled > fmax ? fmax : n_scaled;
  const int32_t n_biased = (int32_t)fp32_to_bits(n_clamped + fmagic) - imagic;

  return (uint8_t)n_biased;
}

/*
 * Quantizes a single average-pooling accumulator to uint8 using the scalar
 * parameter layout: scale in float, round to nearest (lrintf), add the
 * output zero point, then clamp to [output_min, output_max].
 */
static inline uint8_t pytorch_qnnp_avgpool_quantize(
    int32_t n,
    union pytorch_qnnp_avgpool_quantization_params params) {
  const float scaled_n = ((float)n) * params.scalar.scale;
  int32_t n_rounded =
      (int32_t)lrintf(scaled_n) + params.scalar.output_zero_point;

  const int32_t lmin =
      (int32_t)(uint32_t)params.scalar.output_min;
  const int32_t lmax =
      (int32_t)(uint32_t)params.scalar.output_max;

  n_rounded = (
      n_rounded < lmin ? lmin : n_rounded > lmax ? lmax : n_rounded);

  return (uint8_t)n_rounded;
}

/*
 * Computes one quantized elementwise addition a + b -> uint8 using the
 * scalar parameter layout: accumulate the pre-scaled products plus the
 * precomputed zero-point product, perform a rounding arithmetic right shift,
 * add the output zero point, and clamp to [y_min, y_max].
 */
static inline uint8_t pytorch_qnnp_add_quantize(
    uint8_t a,
    uint8_t b,
    union pytorch_qnnp_add_quantization_params params) {
  /* Multiply by factors and accumulate products */
  int32_t acc = params.scalar.zero_point_product +
      (int32_t)((uint32_t)a * params.scalar.a_multiplier) +
      (int32_t)((uint32_t)b * params.scalar.b_multiplier);

  /* Shift right and round; the (acc < 0) term adjusts tie-breaking for
   * negative accumulators. */
  const int32_t rem = (acc & params.scalar.remainder_mask) - (int32_t)(acc < 0);
  acc = asr_s32(acc, params.scalar.shift) +
      (int32_t)(rem > params.scalar.remainder_threshold);

  /* Clamp and add output zero point */
  int32_t y = acc + params.scalar.y_zero_point;
  if (y >= params.scalar.y_max) {
    y = params.scalar.y_max;
  }
  if (y <= params.scalar.y_min) {
    y = params.scalar.y_min;
  }
  return (uint8_t)y;
}
