#pragma once

// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]

#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <ATen/native/quantized/AffineQuantizerBase.h>

#include <c10/util/irange.h>
#include <c10/util/qint32.h>
#include <c10/util/qint8.h>
#include <c10/util/quint8.h>

#include <array>
#include <cmath>

// This file defines Vectorized<> for the quantized types.
//
//
// Currently, we simply use these classes as efficient converters between
// the quantized types and Vectorized<float>, usually in bandwidth-bound cases
// where doing the arithmetic in full-precision is acceptable (e.g.
// elementwise operators).
//
//
// Conversions are as follows:
//  Vectorized<qint8> -> 4x Vectorized<float>
//  Vectorized<quint8> -> 4x Vectorized<float>
//  Vectorized<qint32> -> 1x Vectorized<float>
//
// The size of the returned float vector is specified by the special
// constexpr function float_num_vecs. The type of the value returned
// from dequantize (and expected as an argument to quantize) is
// specified by float_vec_return_type.
//
// When writing kernels with these vectors, it is expected that floating-
// point operations will be carried out in a loop over Vectorized<T>::float_num_vecs
// iterations.
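//
// A minimal usage sketch (illustrative only; `qptr`, the scale/zero-point
// vectors, and `inverse_scale` are hypothetical inputs, not part of this
// header):
//
//   auto qv = Vectorized<c10::quint8>::loadu(qptr);
//   auto fv = qv.dequantize(scale_vec, zp_vec, scale_zp_premul_vec);
//   for (int i = 0; i < Vectorized<c10::quint8>::float_num_vecs(); ++i) {
//     fv[i] = fv[i] * fv[i]; // arbitrary full-precision math per sub-vector
//   }
//   auto out =
//       Vectorized<c10::quint8>::quantize(fv, scale, zero_point, inverse_scale);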

namespace at::vec {
inline namespace CPU_CAPABILITY {

#if defined(CPU_CAPABILITY_AVX2)

#ifdef _MSC_VER
__declspec(align(64)) struct Vectorizedqi {
 protected:
  __m256i vals;
#else
struct Vectorizedqi {
 protected:
  __m256i vals __attribute__((aligned(64)));
#endif

 public:
  Vectorizedqi() {}
  Vectorizedqi(__m256i v) : vals(v) {}
  operator __m256i() const {
    return vals;
  }
};

template <typename T>
__m256i pack_saturate_and_clamp(
    __m256i first,
    __m256i second,
    T min_val,
    T max_val);

template <>
inline __m256i pack_saturate_and_clamp<int32_t>(
    __m256i /*first*/,
    __m256i /*second*/,
    int32_t /*min_val*/,
    int32_t /*max_val*/) {
  // This function is for linkage only, will not be used
  AT_ERROR("pack_saturate_and_clamp<int32_t> is not supported");
}

template <>
inline __m256i pack_saturate_and_clamp<int8_t>(
    __m256i first,
    __m256i second,
    int8_t min_val,
    int8_t max_val) {
  __m256i packed_and_sat = _mm256_packs_epi16(first, second);
  return _mm256_max_epi8(
      _mm256_set1_epi8(min_val),
      _mm256_min_epi8(packed_and_sat, _mm256_set1_epi8(max_val)));
}

template <>
inline __m256i pack_saturate_and_clamp<uint8_t>(
    __m256i first,
    __m256i second,
    uint8_t min_val,
    uint8_t max_val) {
  __m256i packed_and_sat = _mm256_packus_epi16(first, second);
  return _mm256_max_epu8(
      _mm256_set1_epi8(min_val),
      _mm256_min_epu8(packed_and_sat, _mm256_set1_epi8(max_val)));
}
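
// Note: like all AVX2 pack instructions, _mm256_packs_epi16 /
// _mm256_packus_epi16 operate independently on the two 128-bit lanes, so the
// packed bytes come out interleaved by lane. Callers below undo this with a
// _mm256_permutevar8x32_epi32 shuffle after packing.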

template <typename T>
typename std::enable_if_t<std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>, at::vec::Vectorized<float>>
inline convert_int8_to_float(at::vec::Vectorized<T> src) {
  // Note: this function only converts a number of input elements equal to
  // at::vec::Vectorized<float>::size(), i.e. only the first 8 8-bit values
  // are handled.
  __m128i input_128 = _mm256_castsi256_si128(src);
  // Convert from 8*uint8/int8 to 8*int32
  __m256i input_256_int32;
  if constexpr (std::is_same_v<T, uint8_t>)
    input_256_int32 = _mm256_cvtepu8_epi32(input_128);
  else
    input_256_int32 = _mm256_cvtepi8_epi32(input_128);
  // Convert from 8*int32 to 8*float
  return _mm256_cvtepi32_ps(input_256_int32);
}

template <typename T>
typename std::enable_if_t<std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>, at::vec::Vectorized<T>>
inline convert_float_to_int8(at::vec::Vectorized<float> src) {
  // Convert from float32 to int32 with truncation
  __m256i x_values_int32 = _mm256_cvttps_epi32(src);

  // Convert from int32 to int16 using signed saturation
  __m256i xy_packed_v = _mm256_packs_epi32(x_values_int32, x_values_int32);

  constexpr auto min_val = std::numeric_limits<T>::min();
  constexpr auto max_val = std::numeric_limits<T>::max();

  // Convert from int16 to uint8/int8 using saturation, then clamp to
  // [min_val, max_val]
  __m256i xyzw_clamped_v = pack_saturate_and_clamp<T>(
      xy_packed_v, xy_packed_v, min_val, max_val);
  __m256i permute_mask_v =
      _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
  return _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);
}
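
// A hedged round-trip sketch for the two helpers above (`bytes` is a
// hypothetical buffer; only the first Vectorized<float>::size() elements
// participate):
//
//   auto u8 = at::vec::Vectorized<uint8_t>::loadu(bytes);
//   at::vec::Vectorized<float> f = convert_int8_to_float<uint8_t>(u8);
//   auto back = convert_float_to_int8<uint8_t>(f); // truncating float->int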

template <typename T>
__FORCE_INLINE void QuantizeAvx2(
    const float* src,
    T* dst,
    int len,
    float inverse_scale,
    int64_t zero_point) {
  constexpr int VLEN = 8;
  constexpr auto min_val = std::numeric_limits<T>::min();
  constexpr auto max_val = std::numeric_limits<T>::max();
  const __m256i min_v = _mm256_set1_epi32(min_val);
  const __m256i max_v = _mm256_set1_epi32(max_val);
  // This is the largest int32 value < int32_max exactly representable in float
  constexpr int32_t int32_float_max_val =
      std::numeric_limits<int32_t>::max() - 127;
  int i = 0;
  __m256 inverse_scale_v = _mm256_set1_ps(inverse_scale);
  // clang-format off
  static const __m256i shuffle_mask_v = _mm256_set_epi8(
      0xff, 0xff, 0xff, 0xff,
      0xff, 0xff, 0xff, 0xff,
      0xff, 0xff, 0xff, 0xff,
      0x0c, 0x08, 0x04, 0x00,
      0xff, 0xff, 0xff, 0xff,
      0xff, 0xff, 0xff, 0xff,
      0xff, 0xff, 0xff, 0xff,
      0x0c, 0x08, 0x04, 0x00);
  // clang-format on
  __m256i permute_mask_v =
      _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
  __m256i permute_mask_l8_v =
      _mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00);
  int len_aligned = len / (VLEN * 4) * (VLEN * 4);
  for (; i < len_aligned; i += 4 * VLEN) {
    // x
    __m256 x_vals = _mm256_load_ps(src + i);
    __m256 x_transformed_v = _mm256_mul_ps(x_vals, inverse_scale_v);
    // If the floating-point value is greater than int32_max,
    // _mm256_cvtps_epi32 converts it to a negative value. Clip at
    // int32_float_max_val to avoid this.
    x_transformed_v =
        _mm256_min_ps(x_transformed_v, _mm256_set1_ps(int32_float_max_val));
    // y
    __m256 y_vals = _mm256_load_ps(src + i + VLEN);
    __m256 y_transformed_v = _mm256_mul_ps(y_vals, inverse_scale_v);
    y_transformed_v =
        _mm256_min_ps(y_transformed_v, _mm256_set1_ps(int32_float_max_val));
    // z
    __m256 z_vals = _mm256_load_ps(src + i + 2 * VLEN);
    __m256 z_transformed_v = _mm256_mul_ps(z_vals, inverse_scale_v);
    z_transformed_v =
        _mm256_min_ps(z_transformed_v, _mm256_set1_ps(int32_float_max_val));
    // w
    __m256 w_vals = _mm256_load_ps(src + i + 3 * VLEN);
    __m256 w_transformed_v = _mm256_mul_ps(w_vals, inverse_scale_v);
    w_transformed_v =
        _mm256_min_ps(w_transformed_v, _mm256_set1_ps(int32_float_max_val));

    __m256i x_rounded_v = _mm256_cvtps_epi32(x_transformed_v);
    __m256i y_rounded_v = _mm256_cvtps_epi32(y_transformed_v);
    __m256i z_rounded_v = _mm256_cvtps_epi32(z_transformed_v);
    __m256i w_rounded_v = _mm256_cvtps_epi32(w_transformed_v);

    // add zero point
    x_rounded_v = _mm256_add_epi32(x_rounded_v, _mm256_set1_epi32(zero_point));
    y_rounded_v = _mm256_add_epi32(y_rounded_v, _mm256_set1_epi32(zero_point));
    z_rounded_v = _mm256_add_epi32(z_rounded_v, _mm256_set1_epi32(zero_point));
    w_rounded_v = _mm256_add_epi32(w_rounded_v, _mm256_set1_epi32(zero_point));

    __m256i xy_packed_v = _mm256_packs_epi32(x_rounded_v, y_rounded_v);
    __m256i zw_packed_v = _mm256_packs_epi32(z_rounded_v, w_rounded_v);
    __m256i xyzw_clamped_v =
        pack_saturate_and_clamp<T>(xy_packed_v, zw_packed_v, min_val, max_val);

    xyzw_clamped_v =
        _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst + i), xyzw_clamped_v);
  }

  // Additional 8-lane AVX2 version to take advantage of vectorization when
  // len is smaller; based on fbgemm::QuantizeAvx2
  // (https://github.com/pytorch/FBGEMM)
  for (; i < len / VLEN * VLEN; i += VLEN) {
    __m256 x_vals = _mm256_load_ps(src + i);
    __m256 x_transformed_v = _mm256_mul_ps(x_vals, inverse_scale_v);
    x_transformed_v =
        _mm256_min_ps(x_transformed_v, _mm256_set1_ps(int32_float_max_val));
    __m256i x_rounded_v = _mm256_cvtps_epi32(x_transformed_v);
    x_rounded_v = _mm256_add_epi32(x_rounded_v, _mm256_set1_epi32(zero_point));
    __m256i x_clipped_v =
        _mm256_max_epi32(min_v, _mm256_min_epi32(max_v, x_rounded_v));

    x_clipped_v = _mm256_shuffle_epi8(x_clipped_v, shuffle_mask_v);
    x_clipped_v = _mm256_permutevar8x32_epi32(x_clipped_v, permute_mask_l8_v);
    _mm_storel_epi64(
        reinterpret_cast<__m128i*>(dst + i),
        _mm256_castsi256_si128(x_clipped_v));
  }

  for (; i < len; ++i) {
    float transformed = src[i] * inverse_scale;

    // Not exactly the same behavior as the vectorized code.
    // The vectorized code above always rounds to even in halfway cases
    // (https://software.intel.com/en-us/node/523819), but std::nearbyint
    // does the same only when the current rounding mode is FE_TONEAREST.
    // However, in practice, this should not be a problem because most cases
    // use the default rounding mode FE_TONEAREST.
    // Note that we cannot implement the same behavior as the vectorized code
    // using std::round because it does rounding away from zero in halfway
    // cases.
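    // For example, under FE_TONEAREST both paths round 2.5 -> 2 and 3.5 -> 4
    // (ties to even), whereas std::round would give 2.5 -> 3.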
    transformed = zero_point + std::nearbyint(transformed);
    float clipped =
        std::min(std::max(transformed, float(min_val)), float(max_val));
    dst[i] = clipped;
  }
}

template<>
struct Vectorized<c10::qint32> : public Vectorizedqi {
  using size_type = int;
  static constexpr size_type size() {
    return 8;
  }

  static constexpr int float_num_vecs() {
    return 1;
  }

  static constexpr int int_num_vecs() {
    return 1;
  }

  using float_vec_return_type = std::array<Vectorized<float>, 1>;
  using int_vec_return_type = std::array<Vectorized<c10::qint32>, 1>;
  using value_type = c10::qint32::underlying;

 public:
  using Vectorizedqi::Vectorizedqi;
  Vectorized() {}

  Vectorized(__m256i vals_) { vals = vals_; }

  // Broadcast constructor
  Vectorized(const c10::qint32& val) {
    value_type uw = val.val_;
    vals = _mm256_set1_epi32(uw);
  }

  void store(void* ptr, int count = size()) const {
    if (count != size()) {
      memcpy(ptr, &vals, count * sizeof(value_type));
    } else {
      _mm256_storeu_si256((__m256i*)ptr, vals);
    }
  }

  static Vectorized<c10::qint32> loadu(const void* ptr) {
    return Vectorized<c10::qint32>(ptr);
  }

  static Vectorized<c10::qint32> loadu(const void* ptr, int64_t count) {
    __at_align__ value_type tmp_values[size()];
    // Ensure uninitialized memory does not change the output value. See
    // https://github.com/pytorch/pytorch/issues/32502 for more details. We do
    // not initialize arrays to zero using "={0}" because gcc would compile it
    // to two instructions while a loop would be compiled to one instruction.
    for (const auto i : c10::irange(size())) {
      tmp_values[i] = 0;
    }
    std::memcpy(
        tmp_values, reinterpret_cast<const value_type*>(ptr), count * sizeof(value_type));
    return _mm256_loadu_si256((const __m256i*)tmp_values);
  }

  float_vec_return_type dequantize(
      Vectorized<float> scale,
      Vectorized<float> /*zero_point*/,
      Vectorized<float> scale_zp_premul) const {
    __m256 float_vals = _mm256_cvtepi32_ps(vals);
    return {vec::fmadd(scale, Vectorized<float>(float_vals), scale_zp_premul)};
  }

  float_vec_return_type dequantize(
      Vectorized<float> scale,
      Vectorized<float> zero_point) const {
    __m256 float_vals = _mm256_cvtepi32_ps(vals);
    return {(Vectorized<float>(float_vals) - zero_point) * scale};
  }

  static Vectorized<c10::qint32> quantize(
      const float_vec_return_type& rhs,
      float scale,
      int32_t zero_point,
      float /*inverse_scale*/) {
    Vectorized<c10::qint32> retval;
    auto rhs_data = (__m256)rhs[0];
    at::native::quantize_vec<c10::qint32, /*precision=*/32>(
        scale, zero_point, (float*)&rhs_data, (c10::qint32*)&retval.vals, 8);
    return retval;
  }

  Vectorized<c10::qint32> maximum(Vectorized<c10::qint32> b) const {
    return _mm256_max_epi32(vals, b.vals);
  }

  Vectorized<c10::qint32> minimum(Vectorized<c10::qint32> b) const {
    return _mm256_min_epi32(vals, b.vals);
  }

  Vectorized<c10::qint32> relu(Vectorized<c10::qint32> zero_point) const {
    return maximum(zero_point);
  }

  Vectorized<c10::qint32> relu6(
      Vectorized<c10::qint32> zero_point,
      Vectorized<c10::qint32> q_six) {
    return _mm256_min_epi32(
        _mm256_max_epi32(vals, zero_point.vals), q_six.vals);
  }

  int_vec_return_type widening_subtract(Vectorized<c10::qint32> b) const {
    return {_mm256_sub_epi32(vals, b)};
  }

  static Vectorized<c10::qint32> requantize_from_int(
      const int_vec_return_type& inp,
      float multiplier,
      int32_t zero_point) {
    __m256 multiplier_v = _mm256_set1_ps(multiplier);
    __m256i zero_point_v = _mm256_set1_epi32(zero_point);

    __m256 scaled = _mm256_mul_ps(_mm256_cvtepi32_ps(inp[0]), multiplier_v);
    __m256i rounded = _mm256_cvtps_epi32(scaled);
    return _mm256_add_epi32(rounded, zero_point_v);
  }

 private:
  // Load from memory constructor
  Vectorized(const void* ptr) {
    vals = _mm256_loadu_si256((const __m256i*)ptr);
  }
};

template <>
Vectorized<c10::qint32> inline maximum(const Vectorized<c10::qint32>& a, const Vectorized<c10::qint32>& b) {
  return a.maximum(b);
}

template <>
Vectorized<c10::qint32> inline operator*(
    const Vectorized<c10::qint32>& a,
    const Vectorized<c10::qint32>& b) {
  return _mm256_mullo_epi32(a, b);
}

template <>
Vectorized<c10::qint32> inline operator+(
    const Vectorized<c10::qint32>& a,
    const Vectorized<c10::qint32>& b) {
  return _mm256_add_epi32(a, b);
}

/*
 * Convert values from int32 back to int8/uint8
 */
template <typename T>
__m256i RequantizeAvx2(
    const std::array<Vectorized<c10::qint32>, 4>& inp,
    __m256 multiplier,
    __m256i zp) {
  static_assert(
      std::is_same_v<T, int8_t> || std::is_same_v<T, uint8_t>,
      "Only int8_t/uint8_t are supported");
  constexpr auto min_val = std::numeric_limits<T>::min();
  constexpr auto max_val = std::numeric_limits<T>::max();
  __m256i permute_mask_v =
      _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
  __m256 x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(inp[0]), multiplier);
  __m256 y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(inp[1]), multiplier);
  __m256 z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(inp[2]), multiplier);
  __m256 w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(inp[3]), multiplier);

  __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
  __m256i y_rounded_v = _mm256_cvtps_epi32(y_scaled_v);
  __m256i z_rounded_v = _mm256_cvtps_epi32(z_scaled_v);
  __m256i w_rounded_v = _mm256_cvtps_epi32(w_scaled_v);

  /* Add zero point */
  __m256i x_v = _mm256_add_epi32(x_rounded_v, zp);
  __m256i y_v = _mm256_add_epi32(y_rounded_v, zp);
  __m256i z_v = _mm256_add_epi32(z_rounded_v, zp);
  __m256i w_v = _mm256_add_epi32(w_rounded_v, zp);

  /* Pack to int16_t and saturate */
  __m256i xy_packed_v = _mm256_packs_epi32(x_v, y_v);
  __m256i zw_packed_v = _mm256_packs_epi32(z_v, w_v);

  __m256i xyzw_clamped_v =
      pack_saturate_and_clamp<T>(xy_packed_v, zw_packed_v, min_val, max_val);

  /*
   * xyzw_clamped_v has results in the following layout so we need to
   * permute: x0-3 y0-3 z0-3 w0-3 x4-7 y4-7 z4-7 w4-7
   */
  xyzw_clamped_v = _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);
  return xyzw_clamped_v;
}
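
// In scalar terms, RequantizeAvx2 computes, per lane:
//   dst = clamp(round(int32_value * multiplier) + zero_point, min_val, max_val)
// with round-half-to-even (the _mm256_cvtps_epi32 default rounding mode).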

template<>
struct Vectorized<c10::qint8> : public Vectorizedqi {
  static constexpr int size() {
    return 32;
  }

  static constexpr int float_num_vecs() {
    return 4;
  }

  static constexpr int int_num_vecs() {
    return 4;
  }

  using float_vec_return_type = std::array<Vectorized<float>, 4>;
  using int_vec_return_type = std::array<Vectorized<c10::qint32>, 4>;
  using value_type = typename c10::qint8::underlying;

 public:
  using Vectorizedqi::Vectorizedqi;

  Vectorized() {}
  Vectorized(__m256i vals_) { vals = vals_; }

  // Broadcast constructor
  Vectorized(const c10::qint8& val) {
    value_type uw = val.val_;
    vals = _mm256_set1_epi8(uw);
  }

  // This is needed because the compiler emits awful code for the default
  // constructor for moving the enum
  // NOLINTNEXTLINE(clang-diagnostic-deprecated-copy)
  C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wdeprecated-copy")
  C10_CLANG_DIAGNOSTIC_IGNORE("-Wdeprecated-copy")
#endif
  Vectorized(const Vectorized<c10::qint8>& other) : Vectorizedqi(other.vals) { }
  C10_CLANG_DIAGNOSTIC_POP()

  void store(void* ptr, int count = size()) const {
    if (count != size()) {
      memcpy(ptr, &vals, count * sizeof(value_type));
    } else {
      _mm256_storeu_si256((__m256i*)ptr, vals);
    }
  }

  static Vectorized<c10::qint8> loadu(const void* ptr) {
    return Vectorized<c10::qint8>(ptr);
  }

  static Vectorized<c10::qint8> loadu(const void* ptr, int64_t count) {
    __at_align__ value_type tmp_values[size()];
    // Ensure uninitialized memory does not change the output value. See
    // https://github.com/pytorch/pytorch/issues/32502 for more details. We do
    // not initialize arrays to zero using "={0}" because gcc would compile it
    // to two instructions while a loop would be compiled to one instruction.
    for (const auto i : c10::irange(size())) {
      tmp_values[i] = 0;
    }
    std::memcpy(
        tmp_values, reinterpret_cast<const value_type*>(ptr), count * sizeof(value_type));
    return _mm256_loadu_si256((const __m256i*)tmp_values);
  }

 private:
  __m256i cvtepi8_epi32(__m128i epi8_vals) const {
    return _mm256_cvtepi8_epi32(epi8_vals);
  }

 public:
  float_vec_return_type dequantize(
      Vectorized<float> scale,
      Vectorized<float> /*zero_point*/,
      Vectorized<float> scale_neg_zp_premul) const {
    __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0));
    __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1));
    __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2));
    __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3));

    __m256 float_val0 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val0));
    __m256 float_val1 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val1));
    __m256 float_val2 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val2));
    __m256 float_val3 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val3));

    auto val0 =
        vec::fmadd(scale, Vectorized<float>(float_val0), scale_neg_zp_premul);
    auto val1 =
        vec::fmadd(scale, Vectorized<float>(float_val1), scale_neg_zp_premul);
    auto val2 =
        vec::fmadd(scale, Vectorized<float>(float_val2), scale_neg_zp_premul);
    auto val3 =
        vec::fmadd(scale, Vectorized<float>(float_val3), scale_neg_zp_premul);
    return {val0, val1, val2, val3};
  }

  float_vec_return_type dequantize(
      Vectorized<float> scale,
      Vectorized<float> zero_point) const {
    __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0));
    __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1));
    __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2));
    __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3));

    __m256 float_val0 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val0));
    __m256 float_val1 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val1));
    __m256 float_val2 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val2));
    __m256 float_val3 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val3));

    auto val0 = (Vectorized<float>(float_val0) - zero_point) * scale;
    auto val1 = (Vectorized<float>(float_val1) - zero_point) * scale;
    auto val2 = (Vectorized<float>(float_val2) - zero_point) * scale;
    auto val3 = (Vectorized<float>(float_val3) - zero_point) * scale;
    return {val0, val1, val2, val3};
  }

  static Vectorized<c10::qint8> quantize(
      const float_vec_return_type& rhs,
      float /*scale*/,
      int32_t zero_point,
      float inverse_scale) {
    auto* rhs_data = (float*)rhs.data();
    int8_t quantized_values[32];
    QuantizeAvx2<value_type>(
        rhs_data, quantized_values, 32, inverse_scale, zero_point);
    return Vectorized<c10::qint8>::loadu(quantized_values);
  }

  Vectorized<c10::qint8> maximum(Vectorized<c10::qint8> b) const {
    return _mm256_max_epi8(vals, b.vals);
  }

  Vectorized<c10::qint8> minimum(Vectorized<c10::qint8> b) const {
    return _mm256_min_epi8(vals, b.vals);
  }

  Vectorized<c10::qint8> relu(Vectorized<c10::qint8> zero_point) const {
    return maximum(zero_point);
  }

  Vectorized<c10::qint8> relu6(
      Vectorized<c10::qint8> zero_point,
      Vectorized<c10::qint8> q_six) {
    return _mm256_min_epi8(
        _mm256_max_epi8(vals, zero_point.vals), q_six.vals);
  }

  int_vec_return_type widening_subtract(Vectorized<c10::qint8> b) const {
    __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0));
    __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1));
    __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2));
    __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3));

    __m256i int32_val0 = cvtepi8_epi32(int_val0);
    __m256i int32_val1 = cvtepi8_epi32(int_val1);
    __m256i int32_val2 = cvtepi8_epi32(int_val2);
    __m256i int32_val3 = cvtepi8_epi32(int_val3);

    __m128i int_b0 = _mm_set1_epi64x(_mm256_extract_epi64(b, 0));
    __m128i int_b1 = _mm_set1_epi64x(_mm256_extract_epi64(b, 1));
    __m128i int_b2 = _mm_set1_epi64x(_mm256_extract_epi64(b, 2));
    __m128i int_b3 = _mm_set1_epi64x(_mm256_extract_epi64(b, 3));

    __m256i int32_b0 = cvtepi8_epi32(int_b0);
    __m256i int32_b1 = cvtepi8_epi32(int_b1);
    __m256i int32_b2 = cvtepi8_epi32(int_b2);
    __m256i int32_b3 = cvtepi8_epi32(int_b3);

    __m256i res_0 = _mm256_sub_epi32(int32_val0, int32_b0);
    __m256i res_1 = _mm256_sub_epi32(int32_val1, int32_b1);
    __m256i res_2 = _mm256_sub_epi32(int32_val2, int32_b2);
    __m256i res_3 = _mm256_sub_epi32(int32_val3, int32_b3);

    return {Vectorized<c10::qint32>(res_0),
            Vectorized<c10::qint32>(res_1),
            Vectorized<c10::qint32>(res_2),
            Vectorized<c10::qint32>(res_3)};
  }

  static Vectorized<c10::qint8> requantize_from_int(
      const int_vec_return_type& inp,
      float multiplier,
      int32_t zero_point) {
    __m256 multiplier_v = _mm256_set1_ps(multiplier);
    __m256i zero_point_v = _mm256_set1_epi32(zero_point);
    return RequantizeAvx2<value_type>(inp, multiplier_v, zero_point_v);
  }

 private:
  // Load from memory constructor
  Vectorized(const void* ptr) {
    vals = _mm256_loadu_si256((const __m256i*)ptr);
  }
};

template <>
Vectorized<c10::qint8> inline maximum(const Vectorized<c10::qint8>& a, const Vectorized<c10::qint8>& b) {
  return a.maximum(b);
}

template<>
struct Vectorized<c10::quint8> : public Vectorizedqi {
  static constexpr int size() {
    return 32;
  }

  static constexpr int float_num_vecs() {
    return 4;
  }

  static constexpr int int_num_vecs() {
    return 4;
  }

  using float_vec_return_type = std::array<Vectorized<float>, 4>;
  using int_vec_return_type = std::array<Vectorized<c10::qint32>, 4>;
  using value_type = typename c10::quint8::underlying;

 public:
  using Vectorizedqi::Vectorizedqi;
  Vectorized() {}

  Vectorized(__m256i vals_) { vals = vals_; }

  // Broadcast constructor
  Vectorized(const c10::quint8& val) {
    value_type uw = val.val_;
    vals = _mm256_set1_epi8(uw);
  }

  // NOLINTNEXTLINE(clang-diagnostic-deprecated-copy)
  C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wdeprecated-copy")
  C10_CLANG_DIAGNOSTIC_IGNORE("-Wdeprecated-copy")
#endif
  Vectorized(const Vectorized<c10::quint8>& other) : Vectorizedqi(other.vals) { }
  C10_CLANG_DIAGNOSTIC_POP()

  void store(void* ptr, int count = size()) const {
    if (count != size()) {
      memcpy(ptr, &vals, count * sizeof(value_type));
    } else {
      _mm256_storeu_si256((__m256i*)ptr, vals);
    }
  }

  static Vectorized<c10::quint8> loadu(const void* ptr) {
    return Vectorized<c10::quint8>(ptr);
  }

  static Vectorized<c10::quint8> loadu(const void* ptr, int64_t count) {
    __at_align__ value_type tmp_values[size()];
    // Ensure uninitialized memory does not change the output value. See
    // https://github.com/pytorch/pytorch/issues/32502 for more details. We do
    // not initialize arrays to zero using "={0}" because gcc would compile it
    // to two instructions while a loop would be compiled to one instruction.
    for (const auto i : c10::irange(size())) {
      tmp_values[i] = 0;
    }
    std::memcpy(
        tmp_values, reinterpret_cast<const value_type*>(ptr), count * sizeof(value_type));
    return _mm256_loadu_si256((const __m256i*)tmp_values);
  }

 private:
  __m256i cvtepu8_epi32(__m128i epu8_vals) const {
    return _mm256_cvtepu8_epi32(epu8_vals);
  }

 public:
  float_vec_return_type dequantize(
      Vectorized<float> scale,
      Vectorized<float> /*zero_point*/,
      Vectorized<float> scale_zp_premul) const {
    __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0));
    __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1));
    __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2));
    __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3));

    __m256 float_val0 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val0));
    __m256 float_val1 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val1));
    __m256 float_val2 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val2));
    __m256 float_val3 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val3));

    auto val0 =
        vec::fmadd(scale, Vectorized<float>(float_val0), scale_zp_premul);
    auto val1 =
        vec::fmadd(scale, Vectorized<float>(float_val1), scale_zp_premul);
    auto val2 =
        vec::fmadd(scale, Vectorized<float>(float_val2), scale_zp_premul);
    auto val3 =
        vec::fmadd(scale, Vectorized<float>(float_val3), scale_zp_premul);
    return {val0, val1, val2, val3};
  }

  float_vec_return_type dequantize(
      Vectorized<float> scale,
      Vectorized<float> zero_point) const {
    __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0));
    __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1));
    __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2));
    __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3));

    __m256 float_val0 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val0));
    __m256 float_val1 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val1));
    __m256 float_val2 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val2));
    __m256 float_val3 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val3));

    auto val0 = (Vectorized<float>(float_val0) - zero_point) * scale;
    auto val1 = (Vectorized<float>(float_val1) - zero_point) * scale;
    auto val2 = (Vectorized<float>(float_val2) - zero_point) * scale;
    auto val3 = (Vectorized<float>(float_val3) - zero_point) * scale;
    return {val0, val1, val2, val3};
  }

  static Vectorized<c10::quint8> quantize(
      const float_vec_return_type& rhs,
      float /*scale*/,
      int32_t zero_point,
      float inverse_scale) {
    auto* rhs_data = (float*)rhs.data();
    uint8_t quantized_values[32];
    QuantizeAvx2<value_type>(
        rhs_data, quantized_values, 32, inverse_scale, zero_point);
    return Vectorized<c10::quint8>::loadu(quantized_values);
  }

  Vectorized<c10::quint8> maximum(Vectorized<c10::quint8> b) const {
    return _mm256_max_epu8(vals, b.vals);
  }

  Vectorized<c10::quint8> minimum(Vectorized<c10::quint8> b) const {
    return _mm256_min_epu8(vals, b.vals);
  }

  Vectorized<c10::quint8> relu(Vectorized<c10::quint8> zero_point) const {
    return maximum(zero_point);
  }

  Vectorized<c10::quint8> relu6(
      Vectorized<c10::quint8> zero_point,
      Vectorized<c10::quint8> q_six) {
    return _mm256_min_epu8(
        _mm256_max_epu8(vals, zero_point.vals), q_six.vals);
  }

  int_vec_return_type widening_subtract(Vectorized<c10::quint8> b) const {
    __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0));
    __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1));
    __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2));
    __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3));

    __m256i int32_val0 = cvtepu8_epi32(int_val0);
    __m256i int32_val1 = cvtepu8_epi32(int_val1);
    __m256i int32_val2 = cvtepu8_epi32(int_val2);
    __m256i int32_val3 = cvtepu8_epi32(int_val3);

    __m128i int_b0 = _mm_set1_epi64x(_mm256_extract_epi64(b, 0));
    __m128i int_b1 = _mm_set1_epi64x(_mm256_extract_epi64(b, 1));
    __m128i int_b2 = _mm_set1_epi64x(_mm256_extract_epi64(b, 2));
    __m128i int_b3 = _mm_set1_epi64x(_mm256_extract_epi64(b, 3));

    __m256i int32_b0 = cvtepu8_epi32(int_b0);
    __m256i int32_b1 = cvtepu8_epi32(int_b1);
    __m256i int32_b2 = cvtepu8_epi32(int_b2);
    __m256i int32_b3 = cvtepu8_epi32(int_b3);

    __m256i res_0 = _mm256_sub_epi32(int32_val0, int32_b0);
    __m256i res_1 = _mm256_sub_epi32(int32_val1, int32_b1);
    __m256i res_2 = _mm256_sub_epi32(int32_val2, int32_b2);
    __m256i res_3 = _mm256_sub_epi32(int32_val3, int32_b3);
    return {Vectorized<c10::qint32>(res_0),
            Vectorized<c10::qint32>(res_1),
            Vectorized<c10::qint32>(res_2),
            Vectorized<c10::qint32>(res_3)};
  }

  static Vectorized<c10::quint8> requantize_from_int(
      const int_vec_return_type& inp,
      float multiplier,
      int32_t zero_point) {
    __m256 multiplier_v = _mm256_set1_ps(multiplier);
    __m256i zero_point_v = _mm256_set1_epi32(zero_point);
    return RequantizeAvx2<value_type>(inp, multiplier_v, zero_point_v);
  }

 private:
  // Load from memory constructor
  Vectorized(const void* ptr) {
    vals = _mm256_loadu_si256((const __m256i*)ptr);
  }
};

template <>
Vectorized<c10::quint8> inline maximum(const Vectorized<c10::quint8>& a, const Vectorized<c10::quint8>& b) {
  return a.maximum(b);
}

#else

// NOTE: These are low-performance implementations that we fall back on
// if we are not building with AVX2. This may not be an issue, because
// currently for quantization we assume the user has at least AVX512
// available, so these can simply act as a reference implementation.
//
// If in the future we relax this requirement (AVX2+), we should probably
// revisit these implementations

template <
    typename T,
    typename float_vec_return_type_,
    typename int_vec_return_type_,
    int size_>
struct VectorizedQuantizedConverter {
  static constexpr int size() {
    return size_;
  }

  static constexpr int float_num_vecs() {
    return size() / 8;
  }

  static constexpr int int_num_vecs() {
    return size() / 8;
  }

  using float_vec_return_type = float_vec_return_type_;
  using int_vec_return_type = int_vec_return_type_;

  using value_type = typename T::underlying;
  std::array<value_type, size_> vals;

  VectorizedQuantizedConverter(T val) {
    for (const auto i : c10::irange(size())) {
      vals[i] = val.val_;
    }
  }

  VectorizedQuantizedConverter(const void* ptr) {
    memcpy(vals.data(), ptr, sizeof(value_type) * size());
  }

  void store(void* ptr, int count = size()) const {
    memcpy(ptr, vals.data(), count * sizeof(value_type));
  }

  float_vec_return_type dequantize(
      Vectorized<float> scale,
      Vectorized<float> zero_point,
      Vectorized<float> /*scale_zp_premul*/) const {
    float_vec_return_type rv;
    for (const auto i : c10::irange(float_num_vecs())) {
      float tmp_vals[8];
      for (const auto j : c10::irange(8)) {
        tmp_vals[j] = at::native::dequantize_val<T>(
            scale[j], zero_point[j], T(vals[8 * i + j]));
      }
      rv[i] = Vectorized<float>(tmp_vals[0],
          tmp_vals[1],
          tmp_vals[2],
          tmp_vals[3],
          tmp_vals[4],
          tmp_vals[5],
          tmp_vals[6],
          tmp_vals[7]);
    }
    return rv;
  }

  float_vec_return_type dequantize(
      Vectorized<float> scale,
      Vectorized<float> zero_point) const {
    Vectorized<float> scale_zp_premul;
    return dequantize(scale, zero_point, scale_zp_premul);
  }

 protected:
  VectorizedQuantizedConverter() {}
};

template <>
struct Vectorized<c10::qint32> : public VectorizedQuantizedConverter<
                                     c10::qint32,
                                     std::array<Vectorized<float>, 1>,
                                     std::array<Vectorized<c10::qint32>, 1>,
                                     8> {
  Vectorized()
      : VectorizedQuantizedConverter<
            c10::qint32,
            std::array<Vectorized<float>, 1>,
            std::array<Vectorized<c10::qint32>, 1>,
            8>() {}
  Vectorized(c10::qint32 val)
      : VectorizedQuantizedConverter<
            c10::qint32,
            std::array<Vectorized<float>, 1>,
            std::array<Vectorized<c10::qint32>, 1>,
            8>(val) {}
  Vectorized(const void* ptr)
      : VectorizedQuantizedConverter<
            c10::qint32,
            std::array<Vectorized<float>, 1>,
            std::array<Vectorized<c10::qint32>, 1>,
            8>(ptr) {}

  static Vectorized<c10::qint32> loadu(const void* ptr) {
    return Vectorized<c10::qint32>(ptr);
  }

  static Vectorized<c10::qint32> loadu(const void* ptr, int64_t count) {
    __at_align__ value_type tmp_values[size()];
    // Ensure uninitialized memory does not change the output value. See
    // https://github.com/pytorch/pytorch/issues/32502 for more details. We do
    // not initialize arrays to zero using "={0}" because gcc would compile it
    // to two instructions while a loop would be compiled to one instruction.
    for (const auto i : c10::irange(size())) {
      tmp_values[i] = 0;
    }
    std::memcpy(
        tmp_values, reinterpret_cast<const value_type*>(ptr), count * sizeof(value_type));
    return Vectorized<c10::qint32>(tmp_values);
  }

  static Vectorized<c10::qint32> quantize(
      const float_vec_return_type& rhs,
      float scale,
      int32_t zero_point,
      float /*inverse_scale*/) {
    std::array<value_type, size()> qvals;
    std::array<float, float_num_vecs() * 8> float_vals;

    for (const auto i : c10::irange(float_num_vecs())) {
      rhs[i].store(&float_vals[i * 8], 8);
    }

    at::native::quantize_vec<c10::qint32, /*precision=*/32>(
        scale,
        zero_point,
        float_vals.data(),
        (c10::qint32*)qvals.data(),
        8 * float_num_vecs());

    return Vectorized<c10::qint32>::loadu(qvals.data());
  }

  Vectorized<c10::qint32> maximum(Vectorized<c10::qint32> b) const {
    Vectorized<c10::qint32> retval;
    for (const auto i : c10::irange(size())) {
      retval.vals[i] = std::max<value_type>(vals[i], b.vals[i]);
    }
    return retval;
  }

  Vectorized<c10::qint32> minimum(Vectorized<c10::qint32> b) const {
    Vectorized<c10::qint32> retval;
    for (const auto i : c10::irange(size())) {
      retval.vals[i] = std::min<value_type>(vals[i], b.vals[i]);
    }
    return retval;
  }

  Vectorized<c10::qint32> relu(Vectorized<c10::qint32> zero_point) const {
    return maximum(zero_point);
  }

  Vectorized<c10::qint32> relu6(
      Vectorized<c10::qint32> zero_point,
      Vectorized<c10::qint32> q_six) {
    Vectorized<c10::qint32> retval;
    for (const auto i : c10::irange(size())) {
      retval.vals[i] = std::min<value_type>(
          std::max<value_type>(vals[i], zero_point.vals[i]), q_six.vals[i]);
    }
    return retval;
  }

  int_vec_return_type widening_subtract(Vectorized<c10::qint32> b) const {
    int_vec_return_type retval;
    for (const auto i : c10::irange(size())) {
      retval[0].vals[i] = vals[i] - b.vals[i];
    }
    return retval;
  }

  static Vectorized<c10::qint32> requantize_from_int(
      const int_vec_return_type& inp,
      float multiplier,
      int32_t zero_point) {
    Vectorized<c10::qint32> retval;
    for (const auto i : c10::irange(size())) {
      retval.vals[i] =
          std::nearbyint(static_cast<float>(inp[0].vals[i]) * multiplier) +
          zero_point;
    }
    return retval;
  }
};

template <>
Vectorized<c10::qint32> inline maximum(const Vectorized<c10::qint32>& a, const Vectorized<c10::qint32>& b) {
  return a.maximum(b);
}

template <>
Vectorized<c10::qint32> inline operator*(
    const Vectorized<c10::qint32>& a,
    const Vectorized<c10::qint32>& b) {
  Vectorized<c10::qint32> retval;
  for (const auto i : c10::irange(std::decay_t<decltype(a)>::size())) {
    retval.vals[i] = a.vals[i] * b.vals[i];
  }
  return retval;
}

template <>
Vectorized<c10::qint32> inline operator+(
    const Vectorized<c10::qint32>& a,
    const Vectorized<c10::qint32>& b) {
  Vectorized<c10::qint32> retval;
  for (const auto i : c10::irange(std::decay_t<decltype(a)>::size())) {
    retval.vals[i] = a.vals[i] + b.vals[i];
  }
  return retval;
}

template <>
struct Vectorized<c10::qint8> : public VectorizedQuantizedConverter<
                                    c10::qint8,
                                    std::array<Vectorized<float>, 4>,
                                    std::array<Vectorized<c10::qint32>, 4>,
                                    32> {
  Vectorized()
      : VectorizedQuantizedConverter<
            c10::qint8,
            std::array<Vectorized<float>, 4>,
            std::array<Vectorized<c10::qint32>, 4>,
            32>() {}
  Vectorized(c10::qint8 val)
      : VectorizedQuantizedConverter<
            c10::qint8,
            std::array<Vectorized<float>, 4>,
            std::array<Vectorized<c10::qint32>, 4>,
            32>(val) {}
  Vectorized(const void* ptr)
      : VectorizedQuantizedConverter<
            c10::qint8,
            std::array<Vectorized<float>, 4>,
            std::array<Vectorized<c10::qint32>, 4>,
            32>(ptr) {}

  static Vectorized<c10::qint8> loadu(const void* ptr) {
    return Vectorized<c10::qint8>(ptr);
  }

  static Vectorized<c10::qint8> loadu(const void* ptr, int64_t count) {
    __at_align__ value_type tmp_values[size()];
    // Ensure uninitialized memory does not change the output value. See
    // https://github.com/pytorch/pytorch/issues/32502 for more details. We do
    // not initialize arrays to zero using "={0}" because gcc would compile it
    // to two instructions while a loop would be compiled to one instruction.
    for (const auto i : c10::irange(size())) {
      tmp_values[i] = 0;
    }
    std::memcpy(
        tmp_values, reinterpret_cast<const value_type*>(ptr), count * sizeof(value_type));
    return Vectorized<c10::qint8>(tmp_values);
  }

  static Vectorized<c10::qint8> quantize(
      const float_vec_return_type& rhs,
      float scale,
      int32_t zero_point,
      float /*inverse_scale*/) {
    std::array<value_type, size()> qvals;
    std::array<float, float_num_vecs() * 8> float_vals;

    for (const auto i : c10::irange(float_num_vecs())) {
      rhs[i].store(&float_vals[i * 8], 8);
    }

    at::native::quantize_vec<c10::qint8>(
        scale,
        zero_point,
        float_vals.data(),
        (c10::qint8*)qvals.data(),
        8 * float_num_vecs());

    return Vectorized<c10::qint8>::loadu(qvals.data());
  }

  Vectorized<c10::qint8> maximum(Vectorized<c10::qint8> b) const {
    Vectorized<c10::qint8> retval;
    for (const auto i : c10::irange(size())) {
      retval.vals[i] = std::max<value_type>(vals[i], b.vals[i]);
    }
    return retval;
  }

  Vectorized<c10::qint8> minimum(Vectorized<c10::qint8> b) const {
    Vectorized<c10::qint8> retval;
    for (const auto i : c10::irange(size())) {
      retval.vals[i] = std::min<value_type>(vals[i], b.vals[i]);
    }
    return retval;
  }

  Vectorized<c10::qint8> relu(Vectorized<c10::qint8> zero_point) const {
    return maximum(zero_point);
  }

  Vectorized<c10::qint8> relu6(
      Vectorized<c10::qint8> zero_point,
      Vectorized<c10::qint8> q_six) {
    Vectorized<c10::qint8> retval;
    for (const auto i : c10::irange(size())) {
      retval.vals[i] = std::min<value_type>(
          std::max<value_type>(vals[i], zero_point.vals[i]), q_six.vals[i]);
    }
    return retval;
  }

  int_vec_return_type widening_subtract(Vectorized<c10::qint8> b) const {
    int_vec_return_type retval;
    constexpr int elem_per_int_vec = size() / int_num_vecs();
    for (const auto i : c10::irange(int_num_vecs())) {
      for (const auto j : c10::irange(elem_per_int_vec)) {
        retval[i].vals[j] =
            static_cast<int32_t>(vals[i * elem_per_int_vec + j]) -
            static_cast<int32_t>(b.vals[i * elem_per_int_vec + j]);
      }
    }
    return retval;
  }

  static Vectorized<c10::qint8> requantize_from_int(
      const int_vec_return_type& inp,
      float multiplier,
      int32_t zero_point) {
    constexpr int elem_per_int_vec = size() / int_num_vecs();
    constexpr auto min_val = std::numeric_limits<value_type>::min();
    constexpr auto max_val = std::numeric_limits<value_type>::max();
    Vectorized<c10::qint8> retval;
    for (const auto i : c10::irange(int_num_vecs())) {
      for (const auto j : c10::irange(elem_per_int_vec)) {
        int32_t rounded =
            std::nearbyint(static_cast<float>(inp[i].vals[j]) * multiplier) +
            zero_point;
        retval.vals[i * elem_per_int_vec + j] =
            std::min<int32_t>(std::max<int32_t>(rounded, min_val), max_val);
      }
    }
    return retval;
  }
};

template <>
Vectorized<c10::qint8> inline maximum(const Vectorized<c10::qint8>& a, const Vectorized<c10::qint8>& b) {
  return a.maximum(b);
}

template <>
struct Vectorized<c10::quint8> : public VectorizedQuantizedConverter<
                                     c10::quint8,
                                     std::array<Vectorized<float>, 4>,
                                     std::array<Vectorized<c10::qint32>, 4>,
                                     32> {
  Vectorized()
      : VectorizedQuantizedConverter<
            c10::quint8,
            std::array<Vectorized<float>, 4>,
            std::array<Vectorized<c10::qint32>, 4>,
            32>() {}
  Vectorized(c10::quint8 val)
      : VectorizedQuantizedConverter<
            c10::quint8,
            std::array<Vectorized<float>, 4>,
            std::array<Vectorized<c10::qint32>, 4>,
            32>(val) {}
  Vectorized(const void* ptr)
      : VectorizedQuantizedConverter<
            c10::quint8,
            std::array<Vectorized<float>, 4>,
            std::array<Vectorized<c10::qint32>, 4>,
            32>(ptr) {}

  static Vectorized<c10::quint8> loadu(const void* ptr) {
    return Vectorized<c10::quint8>(ptr);
  }

  static Vectorized<c10::quint8> loadu(const void* ptr, int64_t count) {
    __at_align__ value_type tmp_values[size()];
    // Ensure uninitialized memory does not change the output value. See
    // https://github.com/pytorch/pytorch/issues/32502 for more details. We do
    // not initialize arrays to zero using "={0}" because gcc would compile it
    // to two instructions while a loop would be compiled to one instruction.
    for (const auto i : c10::irange(size())) {
      tmp_values[i] = 0;
    }
    std::memcpy(
        tmp_values, reinterpret_cast<const value_type*>(ptr), count * sizeof(value_type));
    return Vectorized<c10::quint8>(tmp_values);
  }

  static Vectorized<c10::quint8> quantize(
      const float_vec_return_type& rhs,
      float scale,
      int32_t zero_point,
      float /*inverse_scale*/) {
    std::array<value_type, size()> qvals;
    std::array<float, float_num_vecs() * 8> float_vals;

    for (const auto i : c10::irange(float_num_vecs())) {
      rhs[i].store(&float_vals[i * 8], 8);
    }

    at::native::quantize_vec<c10::quint8>(
        scale,
        zero_point,
        float_vals.data(),
        (c10::quint8*)qvals.data(),
        8 * float_num_vecs());

    return Vectorized<c10::quint8>::loadu(qvals.data());
  }

  Vectorized<c10::quint8> maximum(Vectorized<c10::quint8> b) const {
    Vectorized<c10::quint8> retval;
    for (const auto i : c10::irange(size())) {
      retval.vals[i] = std::max<value_type>(vals[i], b.vals[i]);
    }
    return retval;
  }

  Vectorized<c10::quint8> minimum(Vectorized<c10::quint8> b) const {
    Vectorized<c10::quint8> retval;
    for (const auto i : c10::irange(size())) {
      retval.vals[i] = std::min<value_type>(vals[i], b.vals[i]);
    }
    return retval;
  }

  Vectorized<c10::quint8> relu(Vectorized<c10::quint8> zero_point) const {
    return maximum(zero_point);
  }

  Vectorized<c10::quint8> relu6(
      Vectorized<c10::quint8> zero_point,
      Vectorized<c10::quint8> q_six) {
    Vectorized<c10::quint8> retval;
    for (const auto i : c10::irange(size())) {
      retval.vals[i] = std::min<value_type>(
          std::max<value_type>(vals[i], zero_point.vals[i]), q_six.vals[i]);
    }
    return retval;
  }

  int_vec_return_type widening_subtract(Vectorized<c10::quint8> b) const {
    int_vec_return_type retval;
    constexpr int elem_per_int_vec = size() / int_num_vecs();
    for (const auto i : c10::irange(int_num_vecs())) {
      for (const auto j : c10::irange(elem_per_int_vec)) {
        retval[i].vals[j] =
            static_cast<int32_t>(vals[i * elem_per_int_vec + j]) -
            static_cast<int32_t>(b.vals[i * elem_per_int_vec + j]);
      }
    }
    return retval;
  }

  static Vectorized<c10::quint8> requantize_from_int(
      const int_vec_return_type& inp,
      float multiplier,
      int32_t zero_point) {
    constexpr int elem_per_int_vec = size() / int_num_vecs();
    constexpr auto min_val = std::numeric_limits<value_type>::min();
    constexpr auto max_val = std::numeric_limits<value_type>::max();
    Vectorized<c10::quint8> retval;
    for (const auto i : c10::irange(int_num_vecs())) {
      for (const auto j : c10::irange(elem_per_int_vec)) {
        int32_t rounded =
            std::nearbyint(static_cast<float>(inp[i].vals[j]) * multiplier) +
            zero_point;
        retval.vals[i * elem_per_int_vec + j] =
            std::min<int32_t>(std::max<int32_t>(rounded, min_val), max_val);
      }
    }
    return retval;
  }
};

template <>
Vectorized<c10::quint8> inline maximum(const Vectorized<c10::quint8>& a, const Vectorized<c10::quint8>& b) {
  return a.maximum(b);
}

#endif // if defined(CPU_CAPABILITY_AVX2)

#if defined(CPU_CAPABILITY_NEON)
template <typename T>
typename std::enable_if_t<std::is_same_v<T, int8_t>, at::vec::Vectorized<float>>
inline convert_int8_to_float(at::vec::Vectorized<T> src) {
  // Note: this function only converts a number of input elements equal to
  // at::vec::Vectorized<float>::size()
  auto s8x8 = vld1_s8(src.operator const int8_t*());
  auto s16x8 = vmovl_s8(s8x8);

  auto s32x4_hi = vmovl_s16(vget_high_s16(s16x8));
  auto s32x4_lo = vmovl_s16(vget_low_s16(s16x8));

  return Vectorized<float>(vcvtq_f32_s32(s32x4_lo), vcvtq_f32_s32(s32x4_hi));
}

template <typename T>
typename std::enable_if_t<std::is_same_v<T, uint8_t>, at::vec::Vectorized<float>>
inline convert_int8_to_float(at::vec::Vectorized<T> src) {
  // Note: this function only converts a number of input elements equal to
  // at::vec::Vectorized<float>::size()
  auto u8x8 = vld1_u8(src.operator const uint8_t*());
  auto u16x8 = vmovl_u8(u8x8);
  auto u32x4_hi = vmovl_u16(vget_high_u16(u16x8));
  auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8));

  return Vectorized<float>(vcvtq_f32_u32(u32x4_lo), vcvtq_f32_u32(u32x4_hi));
}

#endif // if defined(CPU_CAPABILITY_NEON)

} // namespace CPU_CAPABILITY
} // namespace at::vec