/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <assert.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>

#include <fp16/bitcasts.h>

#include <qnnpack/params.h>
#include <qnnpack/scalar-utils.h>

21 static inline union pytorch_qnnp_q31_requantization_params
pytorch_qnnp_compute_scalar_requantization_params(float scale,uint8_t zero_point,uint8_t min,uint8_t max)22 pytorch_qnnp_compute_scalar_requantization_params(
23 float scale,
24 uint8_t zero_point,
25 uint8_t min,
26 uint8_t max) {
27 /* Compute requantization parameters */
28 assert(scale < 1.0f);
29 assert(scale >= 0x1.0p-32f);
30 const uint32_t scale_bits = fp32_to_bits(scale);
31
32 /* Multiplier is in [0x40000000, 0x7FFFFF80] range */
33 const int32_t multiplier = (int32_t)(
34 ((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
35 assert(multiplier >= INT32_C(0x40000000));
36 assert(multiplier <= INT32_C(0x7FFFFF80));
37
38 /* Shift is in [0, 31] range */
39 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
40 assert(shift >= 0);
41 assert(shift < 32);
42
43 union pytorch_qnnp_q31_requantization_params params;
44 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
45 const uint32_t remainder_threshold = remainder_mask >> 1;
46 params.scalar.multiplier = multiplier;
47 params.scalar.remainder_mask = (int32_t)remainder_mask;
48 params.scalar.remainder_threshold = (int32_t)remainder_threshold;
49 params.scalar.shift = (uint32_t)shift;
50 params.scalar.min_less_zero_point =
51 (int32_t)(uint32_t)min - (int32_t)(uint32_t)zero_point;
52 params.scalar.max_less_zero_point =
53 (int32_t)(uint32_t)max - (int32_t)(uint32_t)zero_point;
54 params.scalar.zero_point = (int32_t)(uint32_t)zero_point;
55 return params;
56 }
57
58 static inline union pytorch_qnnp_fp32_requantization_params
pytorch_qnnp_compute_scalar_fp32_requantization_params(float * scales,uint8_t zero_point,uint8_t min,uint8_t max)59 pytorch_qnnp_compute_scalar_fp32_requantization_params(
60 float* scales,
61 uint8_t zero_point,
62 uint8_t min,
63 uint8_t max) {
64
65 union pytorch_qnnp_fp32_requantization_params params;
66 params.scalar.scales = scales;
67 params.scalar.output_zero_point = zero_point;
68 params.scalar.output_max = max;
69 params.scalar.output_min = min;
70 params.scalar.min_less_zero_point = ((float)((int32_t)(uint32_t)min -
71 (int32_t)(uint32_t)zero_point));
72 params.scalar.max_less_zero_point = ((float)((int32_t)(uint32_t)max -
73 (int32_t)(uint32_t)zero_point));
74 params.scalar.magic = 12582912.0f;
75 params.scalar.magic_less_zero_point = (INT32_C(0x4B400000) -
76 (int32_t)(uint32_t)zero_point);
77 return params;
78 }
79
80 static inline union pytorch_qnnp_q31_requantization_params
pytorch_qnnp_compute_requantization_params(float scale,uint8_t zero_point,uint8_t min,uint8_t max)81 pytorch_qnnp_compute_requantization_params(
82 float scale,
83 uint8_t zero_point,
84 uint8_t min,
85 uint8_t max) {
86 /* Compute requantization parameters */
87 const uint32_t scale_bits = fp32_to_bits(scale);
88
89 /* Multiplier is in [0x40000000, 0x7FFFFF80] range */
90 const int32_t multiplier = (int32_t)(
91 ((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
92 assert(multiplier >= INT32_C(0x40000000));
93 assert(multiplier <= INT32_C(0x7FFFFF80));
94
95 /* Shift is in [0, 31] range */
96 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
97 assert(shift >= 0);
98 assert(shift < 32);
99
100 union pytorch_qnnp_q31_requantization_params params;
101 #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
102 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
103 const uint32_t remainder_threshold = remainder_mask >> 1;
104 params.sse2.multiplier[0] = multiplier;
105 params.sse2.multiplier[1] = multiplier;
106 params.sse2.multiplier[2] = multiplier;
107 params.sse2.multiplier[3] = multiplier;
108 params.sse2.rounding[0] = UINT64_C(0x40000000);
109 params.sse2.rounding[1] = UINT64_C(0x40000000);
110 params.sse2.remainder_mask[0] = (int32_t)remainder_mask;
111 params.sse2.remainder_mask[1] = (int32_t)remainder_mask;
112 params.sse2.remainder_mask[2] = (int32_t)remainder_mask;
113 params.sse2.remainder_mask[3] = (int32_t)remainder_mask;
114 params.sse2.remainder_threshold[0] = (int32_t)remainder_threshold;
115 params.sse2.remainder_threshold[1] = (int32_t)remainder_threshold;
116 params.sse2.remainder_threshold[2] = (int32_t)remainder_threshold;
117 params.sse2.remainder_threshold[3] = (int32_t)remainder_threshold;
118 params.sse2.shift[0] = (uint64_t)(uint32_t)shift;
119 params.sse2.shift[1] = (uint64_t)(uint32_t)shift;
120 for (uint32_t i = 0; i < 8; i++) {
121 params.sse2.zero_point[i] = (int16_t)(uint16_t)zero_point;
122 }
123 for (uint32_t i = 0; i < 16; i++) {
124 params.sse2.max[i] = max;
125 params.sse2.min[i] = min;
126 }
127 #elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
128 params.neon.multiplier = multiplier;
129 params.neon.right_shift = -shift;
130 params.neon.zero_point = (int16_t)(uint16_t)zero_point;
131 params.neon.max = max;
132 params.neon.min = min;
133 #else
134 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
135 const uint32_t remainder_threshold = remainder_mask >> 1;
136 params.scalar.multiplier = multiplier;
137 params.scalar.remainder_mask = (int32_t)remainder_mask;
138 params.scalar.remainder_threshold = (int32_t)remainder_threshold;
139 params.scalar.shift = (uint32_t)shift;
140 params.scalar.min_less_zero_point =
141 (int32_t)(uint32_t)min - (int32_t)(uint32_t)zero_point;
142 params.scalar.max_less_zero_point =
143 (int32_t)(uint32_t)max - (int32_t)(uint32_t)zero_point;
144 params.scalar.zero_point = (int32_t)(uint32_t)zero_point;
145 #endif
146 return params;
147 }
148
149 static inline union pytorch_qnnp_conv_quantization_params
pytorch_qnnp_compute_conv_quantization_params(uint8_t input_zero_point,const uint8_t * kernel_zero_points,const float * requantization_scales,uint8_t output_zero_point,uint8_t output_min,uint8_t output_max)150 pytorch_qnnp_compute_conv_quantization_params(
151 uint8_t input_zero_point,
152 const uint8_t* kernel_zero_points,
153 const float* requantization_scales,
154 uint8_t output_zero_point,
155 uint8_t output_min,
156 uint8_t output_max) {
157
158 union pytorch_qnnp_conv_quantization_params params;
159 #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
160 params.sse2.kernel_zero_points = kernel_zero_points;
161 for (uint32_t i = 0; i < 8; i++) {
162 params.sse2.input_zero_point[i] = (int16_t)(uint16_t)input_zero_point;
163 }
164 params.sse2.requantization_scales = requantization_scales;
165 for (uint32_t i = 0; i < 8; i++) {
166 params.sse2.output_zero_point[i] = (int16_t)(uint16_t)output_zero_point;
167 }
168 for (uint32_t i = 0; i < 16; i++) {
169 params.sse2.output_max[i] = output_max;
170 params.sse2.output_min[i] = output_min;
171 }
172 #elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
173 params.neon.input_zero_point = (int16_t)(uint16_t)input_zero_point;
174 params.neon.kernel_zero_points = kernel_zero_points;
175 params.neon.requantization_scales = requantization_scales;
176 params.neon.output_zero_point = (int16_t)(uint16_t)output_zero_point;
177 params.neon.output_max = output_max;
178 params.neon.output_min = output_min;
179 params.neon.vfmin = ((float)((int32_t)(uint32_t)output_min -
180 (int32_t)(uint32_t)output_zero_point));
181 params.neon.vfmax = ((float)((int32_t)(uint32_t)output_max -
182 (int32_t)(uint32_t)output_zero_point));
183 params.neon.vfmagic = 12582912.0f;
184 params.neon.vimagic = (INT32_C(0x4B400000) -
185 (int32_t)(uint32_t)output_zero_point);
186 #else
187 params.scalar.input_zero_point = (int32_t)(uint32_t)input_zero_point;
188 params.scalar.kernel_zero_points = kernel_zero_points;
189 params.scalar.requantization_scales = requantization_scales;
190 params.scalar.output_min_less_zero_point =
191 (int32_t)(uint32_t)output_min - (int32_t)(uint32_t)output_zero_point;
192 params.scalar.output_max_less_zero_point =
193 (int32_t)(uint32_t)output_max - (int32_t)(uint32_t)output_zero_point;
194 params.scalar.output_zero_point = (int32_t)(uint32_t)output_zero_point;
195 #endif
196 return params;
197 }
198
199 static inline union pytorch_qnnp_avgpool_quantization_params
pytorch_qnnp_compute_avgpool_quantization_params(int32_t bias,float scale,uint8_t output_zero_point,uint8_t output_min,uint8_t output_max)200 pytorch_qnnp_compute_avgpool_quantization_params(
201 int32_t bias,
202 float scale,
203 uint8_t output_zero_point,
204 uint8_t output_min,
205 uint8_t output_max) {
206 /* Compute requantization parameters */
207 assert(scale >= 0x1.0p-32f);
208 assert(scale < 256.0f);
209
210 union pytorch_qnnp_avgpool_quantization_params params;
211 #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
212 params.sse2.bias[0] = bias;
213 params.sse2.bias[1] = bias;
214 params.sse2.bias[2] = bias;
215 params.sse2.bias[3] = bias;
216 params.sse2.scale[0] = scale;
217 params.sse2.scale[1] = scale;
218 params.sse2.scale[2] = scale;
219 params.sse2.scale[3] = scale;
220 for (uint32_t i = 0; i < 8; i++) {
221 params.sse2.output_zero_point[i] = (int16_t)(uint16_t)output_zero_point;
222 }
223 for (uint32_t i = 0; i < 16; i++) {
224 params.sse2.output_max[i] = output_max;
225 params.sse2.output_min[i] = output_min;
226 }
227 #elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
228 params.neon.bias = bias;
229 params.neon.scale = scale;
230 params.neon.output_zero_point = (int16_t)(uint16_t)output_zero_point;
231 params.neon.output_max = output_max;
232 params.neon.output_min = output_min;
233 params.neon.vfmin = ((float)((int32_t)(uint32_t)output_min -
234 (int32_t)(uint32_t)output_zero_point));
235 params.neon.vfmax = ((float)((int32_t)(uint32_t)output_max -
236 (int32_t)(uint32_t)output_zero_point));
237 params.neon.vfmagic = 12582912.0f;
238 params.neon.vimagic = (INT32_C(0x4B400000) -
239 (int32_t)(uint32_t)output_zero_point);
240 #else
241 params.scalar.bias = bias;
242 params.scalar.scale = scale;
243 params.scalar.output_zero_point = (int32_t)(uint32_t)output_zero_point;
244 params.scalar.output_max = (int32_t)(uint32_t)output_max;
245 params.scalar.output_min = (int32_t)(uint32_t)output_min;
246 #endif
247 return params;
248 }
249
250 static inline union pytorch_qnnp_avgpool_quantization_params
pytorch_qnnp_compute_scalar_avgpool_quantization_params(int32_t bias,float scale,uint8_t output_zero_point,uint8_t output_min,uint8_t output_max)251 pytorch_qnnp_compute_scalar_avgpool_quantization_params(
252 int32_t bias,
253 float scale,
254 uint8_t output_zero_point,
255 uint8_t output_min,
256 uint8_t output_max) {
257 /* Compute requantization parameters */
258 assert(scale >= 0x1.0p-32f);
259 assert(scale < 256.0f);
260
261 union pytorch_qnnp_avgpool_quantization_params params;
262 params.scalar.bias = bias;
263 params.scalar.scale = scale;
264 params.scalar.output_zero_point = (int32_t)(uint32_t)output_zero_point;
265 params.scalar.output_max = (int32_t)(uint32_t)output_max;
266 params.scalar.output_min = (int32_t)(uint32_t)output_min;
267 return params;
268 }
269
270 static inline union pytorch_qnnp_u8_clamping_params
pytorch_qnnp_compute_u8_clamping_params(uint8_t output_min,uint8_t output_max)271 pytorch_qnnp_compute_u8_clamping_params(
272 uint8_t output_min,
273 uint8_t output_max) {
274 assert(output_min <= output_max);
275
276 union pytorch_qnnp_u8_clamping_params params;
277 #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
278 for (uint32_t i = 0; i < 16; i++) {
279 params.sse2.output_max[i] = output_max;
280 params.sse2.output_min[i] = output_min;
281 }
282 #elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
283 params.neon.output_max = output_max;
284 params.neon.output_min = output_min;
285 #else
286 params.scalar.output_min = (int32_t)(uint32_t)output_min;
287 params.scalar.output_max = (int32_t)(uint32_t)output_max;
288 #endif
289 return params;
290 }
291
292 static inline union pytorch_qnnp_add_quantization_params
pytorch_qnnp_compute_add_quantization_params(uint8_t a_zero_point,uint8_t b_zero_point,uint8_t output_zero_point,float a_output_scale,float b_output_scale,uint8_t output_min,uint8_t output_max)293 pytorch_qnnp_compute_add_quantization_params(
294 uint8_t a_zero_point,
295 uint8_t b_zero_point,
296 uint8_t output_zero_point,
297 float a_output_scale,
298 float b_output_scale,
299 uint8_t output_min,
300 uint8_t output_max) {
301 assert(a_output_scale >= 0x1.0p-14f);
302 assert(b_output_scale >= 0x1.0p-14f);
303 assert(a_output_scale < 0x1.0p+8f);
304 assert(b_output_scale < 0x1.0p+8f);
305
306 /* Compute requantization parameters */
307 const float max_output_scale =
308 a_output_scale > b_output_scale ? a_output_scale : b_output_scale;
309 assert(max_output_scale >= 0x1.0p-14f);
310 assert(max_output_scale < 0x1.0p+8f);
311 const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
312 const int32_t max_scale_exponent = (int32_t)(max_scale_bits >> 23) - 127;
313 /* Shift is in [13, 31] range */
314 const uint32_t shift = (uint32_t)(21 - max_scale_exponent);
315 assert(shift < 32);
316 assert(shift >= 13);
317
318 const float scale_multiplier =
319 fp32_from_bits((uint32_t)(21 - max_scale_exponent + 127) << 23);
320
321 /* Multipliers are in [0, 2**22) range, largest multiplier is in [2**21,
322 * 2**22) range */
323 const uint32_t a_multiplier =
324 (uint32_t)(int32_t)lrintf(a_output_scale * scale_multiplier);
325 const uint32_t b_multiplier =
326 (uint32_t)(int32_t)lrintf(b_output_scale * scale_multiplier);
327 assert(
328 (a_multiplier > b_multiplier ? a_multiplier : b_multiplier) >=
329 UINT32_C(0x00200000));
330 assert(a_multiplier < UINT32_C(0x00400000));
331 assert(b_multiplier < UINT32_C(0x00400000));
332
333 union pytorch_qnnp_add_quantization_params params;
334 #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
335 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
336 const uint32_t remainder_threshold = remainder_mask >> 1;
337 const int32_t zero_point_product = (int32_t) -
338 (a_multiplier * (uint32_t)a_zero_point +
339 b_multiplier * (uint32_t)b_zero_point);
340 for (uint32_t i = 0; i < 4; i++) {
341 params.sse2.zero_point_product[i] = zero_point_product;
342 }
343 for (uint32_t i = 0; i < 8; i++) {
344 params.sse2.y_zero_point[i] = (int16_t)(uint16_t)output_zero_point;
345 }
346 for (uint32_t i = 0; i < 8; i++) {
347 params.sse2.a_multiplier_lo[i] = (uint16_t)(uint32_t)a_multiplier;
348 params.sse2.a_multiplier_hi[i] = (uint16_t)((uint32_t)a_multiplier >> 16);
349 params.sse2.b_multiplier_lo[i] = (uint16_t)(uint32_t)b_multiplier;
350 params.sse2.b_multiplier_hi[i] = (uint16_t)((uint32_t)b_multiplier >> 16);
351 }
352 params.sse2.a_multiplier = a_multiplier;
353 params.sse2.b_multiplier = b_multiplier;
354 for (uint32_t i = 0; i < 4; i++) {
355 params.sse2.remainder_mask[i] = remainder_mask;
356 params.sse2.remainder_threshold[i] = remainder_threshold;
357 }
358 params.sse2.shift = shift;
359 for (uint32_t i = 0; i < 16; i++) {
360 params.sse2.y_max[i] = output_max;
361 params.sse2.y_min[i] = output_min;
362 }
363 #elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64
364 params.neon.a_zero_point = a_zero_point;
365 params.neon.b_zero_point = b_zero_point;
366 params.neon.y_zero_point = (int16_t)(uint16_t)output_zero_point;
367 params.neon.a_multiplier = (int32_t)a_multiplier;
368 params.neon.b_multiplier = (int32_t)b_multiplier;
369 params.neon.right_shift = (int32_t)-shift;
370 params.neon.y_max = output_max;
371 params.neon.y_min = output_min;
372 #else
373 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
374 const uint32_t remainder_threshold = remainder_mask >> 1;
375 params.scalar.zero_point_product = (int32_t) -
376 (a_multiplier * (uint32_t)a_zero_point +
377 b_multiplier * (uint32_t)b_zero_point);
378 params.scalar.a_multiplier = a_multiplier;
379 params.scalar.b_multiplier = b_multiplier;
380 params.scalar.remainder_mask = (int32_t)remainder_mask;
381 params.scalar.remainder_threshold = (int32_t)remainder_threshold;
382 params.scalar.shift = shift;
383 params.scalar.y_zero_point = (int32_t)(uint32_t)output_zero_point;
384 params.scalar.y_max = (int32_t)(uint32_t)output_max;
385 params.scalar.y_min = (int32_t)(uint32_t)output_min;
386 #endif
387 return params;
388 }
389
390 static inline union pytorch_qnnp_add_quantization_params
pytorch_qnnp_compute_scalar_add_quantization_params(uint8_t a_zero_point,uint8_t b_zero_point,uint8_t output_zero_point,float a_output_scale,float b_output_scale,uint8_t output_min,uint8_t output_max)391 pytorch_qnnp_compute_scalar_add_quantization_params(
392 uint8_t a_zero_point,
393 uint8_t b_zero_point,
394 uint8_t output_zero_point,
395 float a_output_scale,
396 float b_output_scale,
397 uint8_t output_min,
398 uint8_t output_max) {
399 assert(a_output_scale >= 0x1.0p-10f);
400 assert(b_output_scale >= 0x1.0p-10f);
401 assert(a_output_scale < 0x1.0p+8f);
402 assert(b_output_scale < 0x1.0p+8f);
403
404 /* Compute requantization parameters */
405 const float max_output_scale =
406 a_output_scale > b_output_scale ? a_output_scale : b_output_scale;
407 assert(max_output_scale >= 0x1.0p-10f);
408 assert(max_output_scale < 0x1.0p+8f);
409 const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
410 const int32_t max_scale_exponent = (int32_t)(max_scale_bits >> 23) - 127;
411 /* Shift is in [13, 31] range */
412 const uint32_t shift = (uint32_t)(21 - max_scale_exponent);
413 assert(shift < 32);
414 assert(shift >= 13);
415
416 /* Multipliers are in [0, 2**22) range, largest multiplier is in [2**21,
417 * 2**22) range */
418 const uint32_t a_multiplier = (uint32_t)(int32_t)lrintf(
419 fp32_from_bits(fp32_to_bits(a_output_scale) + (shift << 23)));
420 const uint32_t b_multiplier = (uint32_t)(int32_t)lrintf(
421 fp32_from_bits(fp32_to_bits(b_output_scale) + (shift << 23)));
422 assert(
423 (a_multiplier > b_multiplier ? a_multiplier : b_multiplier) >=
424 UINT32_C(0x00200000));
425 assert(a_multiplier < UINT32_C(0x00400000));
426 assert(b_multiplier < UINT32_C(0x00400000));
427
428 union pytorch_qnnp_add_quantization_params params;
429 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
430 const uint32_t remainder_threshold = remainder_mask >> 1;
431 params.scalar.zero_point_product = (int32_t) -
432 (a_multiplier * (uint32_t)a_zero_point +
433 b_multiplier * (uint32_t)b_zero_point);
434 params.scalar.a_multiplier = a_multiplier;
435 params.scalar.b_multiplier = b_multiplier;
436 params.scalar.remainder_mask = (int32_t)remainder_mask;
437 params.scalar.remainder_threshold = (int32_t)remainder_threshold;
438 params.scalar.shift = shift;
439 params.scalar.y_zero_point = (int32_t)(uint32_t)output_zero_point;
440 params.scalar.y_max = (int32_t)(uint32_t)output_max;
441 params.scalar.y_min = (int32_t)(uint32_t)output_min;
442 return params;
443 }
444
static inline uint8_t pytorch_qnnp_q31_requantize(
    int32_t n,
    union pytorch_qnnp_q31_requantization_params params) {
  /*
   * Scalar reference requantization: multiply by the Q31 multiplier with
   * rounding, arithmetic-shift right with a round-to-nearest correction,
   * clamp (bounds are relative to the zero point), and re-bias by the
   * output zero point.
   */
  const int64_t product = (int64_t)n * (int64_t)params.scalar.multiplier;
  const int32_t q31product =
      (int32_t)(uint32_t)((uint64_t)(product + INT64_C(0x40000000)) >> 31);
  /* Remainder correction makes the right shift round to nearest */
  const int32_t remainder =
      (q31product & params.scalar.remainder_mask) - (int32_t)(n < 0);
  int32_t result = asr_s32(q31product, params.scalar.shift) +
      (int32_t)(remainder > params.scalar.remainder_threshold);
  if (result < params.scalar.min_less_zero_point) {
    result = params.scalar.min_less_zero_point;
  }
  if (result > params.scalar.max_less_zero_point) {
    result = params.scalar.max_less_zero_point;
  }
  return (uint8_t)(result + params.scalar.zero_point);
}

static inline uint8_t pytorch_qnnp_fp32_requantize(
    int32_t n,
    union pytorch_qnnp_fp32_requantization_params params,
    int32_t output_channel_index) {
  /*
   * Scale the accumulator by the per-channel FP32 scale, round to nearest
   * with lrintf, clamp to the output range expressed relative to the zero
   * point, then add the output zero point back.
   */
  const int32_t zero_point =
      (int32_t)(uint32_t)params.scalar.output_zero_point;
  const long lo =
      (long)((int32_t)(uint32_t)params.scalar.output_min - zero_point);
  const long hi =
      (long)((int32_t)(uint32_t)params.scalar.output_max - zero_point);

  const float scaled = (float)n * params.scalar.scales[output_channel_index];
  const long rounded = lrintf(scaled);
  const int32_t clamped =
      (int32_t)(rounded < lo ? lo : rounded > hi ? hi : rounded);
  return (uint8_t)(clamped + zero_point);
}

static inline uint8_t pytorch_qnnp_fp32_requantize_magic(
    int32_t n,
    union pytorch_qnnp_fp32_requantization_params params,
    int32_t output_channel_index) {
  /*
   * FP32 requantization via the "magic number" trick: after clamping, adding
   * the magic constant (2**23 + 2**22) places the rounded integer in the low
   * mantissa bits of the float; subtracting the integer magic value (with
   * the zero point already folded in) recovers the quantized result.
   */
  const float lo = params.scalar.min_less_zero_point;
  const float hi = params.scalar.max_less_zero_point;

  const float scaled = (float)n * params.scalar.scales[output_channel_index];
  const float clamped = scaled < lo ? lo : scaled > hi ? hi : scaled;
  const int32_t biased =
      (int32_t)fp32_to_bits(clamped + params.scalar.magic) -
      params.scalar.magic_less_zero_point;
  return (uint8_t)biased;
}

static inline uint8_t pytorch_qnnp_avgpool_quantize(
    int32_t n,
    union pytorch_qnnp_avgpool_quantization_params params) {
  /*
   * Scale the pooled accumulator, round to nearest with lrintf, add the
   * output zero point, and clamp to the output range.
   */
  const float scaled = (float)n * params.scalar.scale;
  int32_t result = (int32_t)lrintf(scaled) + params.scalar.output_zero_point;

  const int32_t lo = (int32_t)(uint32_t)params.scalar.output_min;
  const int32_t hi = (int32_t)(uint32_t)params.scalar.output_max;
  if (result < lo) {
    result = lo;
  }
  if (result > hi) {
    result = hi;
  }
  return (uint8_t)result;
}

static inline uint8_t pytorch_qnnp_add_quantize(
    uint8_t a,
    uint8_t b,
    union pytorch_qnnp_add_quantization_params params) {
  /*
   * Scalar reference path for quantized addition: accumulate the two
   * fixed-point products (the zero-point contribution is precomputed in
   * zero_point_product), shift right with a round-to-nearest correction,
   * re-bias by the output zero point, and clamp.
   */
  int32_t acc = params.scalar.zero_point_product +
      (int32_t)((uint32_t)a * params.scalar.a_multiplier) +
      (int32_t)((uint32_t)b * params.scalar.b_multiplier);

  /* Round-to-nearest correction for the arithmetic right shift */
  const int32_t remainder =
      (acc & params.scalar.remainder_mask) - (int32_t)(acc < 0);
  acc = asr_s32(acc, params.scalar.shift) +
      (int32_t)(remainder > params.scalar.remainder_threshold);

  /* Re-bias and clamp to the output range */
  int32_t result = acc + params.scalar.y_zero_point;
  result = result >= params.scalar.y_max ? params.scalar.y_max : result;
  result = result <= params.scalar.y_min ? params.scalar.y_min : result;
  return (uint8_t)result;
}
