1 #pragma once
2 
3 // DO NOT DEFINE STATIC DATA IN THIS HEADER!
4 // See Note [Do not compile initializers with AVX]
5 
6 #include <ATen/cpu/vec/intrinsics.h>
7 #include <ATen/cpu/vec/vec_base.h>
8 #include <ATen/native/quantized/AffineQuantizerBase.h>
9 
10 #include <c10/util/irange.h>
11 #include <c10/util/qint32.h>
12 #include <c10/util/qint8.h>
13 #include <c10/util/quint8.h>
14 
15 #include <array>
16 #include <cmath>
17 
18 // This file defines Vectorized<> for the quantized types.
19 //
20 //
21 // Currently, we simply use these classes as efficient converters between
22 // the quantized types and Vectorized<float>, usually in bandwidth-bound cases
23 // where doing the arithmetic in full-precision is acceptable (e.g.
24 // elementwise operators).
25 //
26 //
27 // Conversions are as follows:
28 //  Vectorized<qint8> -> 4x Vectorized<float>
29 //  Vectorized<quint8> -> 4x Vectorized<float>
30 //  Vectorized<qint32> -> 1x Vectorized<float>
31 //
32 // The size of the returned float vector is specified by the special
33 // constexpr function float_num_vecs. The type of the value returned
34 // from dequantize (and expected as an argument to quantize) is
35 // specified by float_vec_return_type.
36 //
37 // When writing kernels with these vectors, it is expected that floating-
38 // point operations will be carried out in a loop over Vectorized<T>::float_num_vecs
39 // iterations.
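//
// A minimal usage sketch (illustrative only: `src`, `dst`, `scale` and
// `zero_point` below are assumed to come from the caller and are not part
// of this header):
//
//   using qVec = Vectorized<c10::quint8>;
//   using fVec = Vectorized<float>;
//   fVec scale_v(scale);
//   fVec zp_v(static_cast<float>(zero_point));
//   fVec premul_v = scale_v * fVec(-static_cast<float>(zero_point));
//   qVec q = qVec::loadu(src);
//   auto floats = q.dequantize(scale_v, zp_v, premul_v);  // float_num_vecs() float vectors
//   for (int i = 0; i < qVec::float_num_vecs(); ++i) {
//     floats[i] = floats[i] * fVec(0.5f);                  // arbitrary float-domain math
//   }
//   qVec::quantize(floats, scale, zero_point, 1.0f / scale).store(dst);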
40 
41 namespace at::vec {
42 inline namespace CPU_CAPABILITY {
43 
44 #if defined(CPU_CAPABILITY_AVX2)
45 
46 #ifdef _MSC_VER
47 __declspec(align(64)) struct Vectorizedqi {
48  protected:
49   __m256i vals;
50 #else
51 struct Vectorizedqi {
52  protected:
53   __m256i vals __attribute__((aligned(64)));
54 #endif
55 
56  public:
57   Vectorizedqi() {}
58   Vectorizedqi(__m256i v) : vals(v) {}
59   operator __m256i() const {
60     return vals;
61   }
62 };
63 
64 template <typename T>
65 __m256i pack_saturate_and_clamp(
66     __m256i first,
67     __m256i second,
68     T min_val,
69     T max_val);
70 
71 template <>
72 inline __m256i pack_saturate_and_clamp<int32_t>(
73     __m256i /*first*/,
74     __m256i /*second*/,
75     int32_t /*min_val*/,
76     int32_t /*max_val*/) {
77   // This function is for linkage only, will not be used
78   AT_ERROR("pack_saturate_and_clamp<int32_t> is not supported");
79 }
80 
81 template <>
82 inline __m256i pack_saturate_and_clamp<int8_t>(
83     __m256i first,
84     __m256i second,
85     int8_t min_val,
86     int8_t max_val) {
87   __m256i packed_and_sat = _mm256_packs_epi16(first, second);
88   return _mm256_max_epi8(
89       _mm256_set1_epi8(min_val),
90       _mm256_min_epi8(packed_and_sat, _mm256_set1_epi8(max_val)));
91 }
92 
93 template <>
94 inline __m256i pack_saturate_and_clamp<uint8_t>(
95     __m256i first,
96     __m256i second,
97     uint8_t min_val,
98     uint8_t max_val) {
99   __m256i packed_and_sat = _mm256_packus_epi16(first, second);
100   return _mm256_max_epu8(
101       _mm256_set1_epi8(min_val),
102       _mm256_min_epu8(packed_and_sat, _mm256_set1_epi8(max_val)));
103 }
104 
105 template <typename T>
106 typename std::enable_if_t<std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>, at::vec::Vectorized<float>>
107 inline convert_int8_to_float(at::vec::Vectorized<T> src) {
108   // Note: this function only converts a number of input elements equal to at::vec::Vectorized<float>::size(),
109   // i.e. only the first 8 elements (8*8 bits) of the input vector are handled.
110   __m128i input_128 = _mm256_castsi256_si128(src);
111   // Convert from 8*uint8/int8 to 8*int32
112   __m256i input_256_int32;
113   if constexpr (std::is_same_v<T, uint8_t>)
114     input_256_int32 = _mm256_cvtepu8_epi32(input_128);
115   else
116     input_256_int32 = _mm256_cvtepi8_epi32(input_128);
117   // Convert from 8*int32 to 8*float
118   return _mm256_cvtepi32_ps(input_256_int32);
119 }
120 
121 template <typename T>
122 typename std::enable_if_t<std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>, at::vec::Vectorized<T>>
123 inline convert_float_to_int8(at::vec::Vectorized<float> src) {
124   // Convert from float32 to int32 with truncation
125   __m256i x_values_int32 = _mm256_cvttps_epi32(src);
126 
127   // Convert from int32 to int16 using signed saturation
128   __m256i xy_packed_v = _mm256_packs_epi32(x_values_int32, x_values_int32);
129 
130   constexpr auto min_val = std::numeric_limits<T>::min();
131   constexpr auto max_val = std::numeric_limits<T>::max();
132 
133   // Convert from int16 to uint8/int8 using saturation (unsigned saturation for uint8, signed for int8)
134   __m256i xyzw_clamped_v = pack_saturate_and_clamp<T>(
135       xy_packed_v, xy_packed_v, min_val, max_val);
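  // The pack instructions operate per 128-bit lane, so the packed results are
  // interleaved across lanes; a cross-lane permute restores element order.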
136   __m256i permute_mask_v =
137     _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
138   return _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);
139 }
140 
141 template <typename T>
142 __FORCE_INLINE void QuantizeAvx2(
143     const float* src,
144     T* dst,
145     int len,
146     float inverse_scale,
147     int64_t zero_point) {
148   constexpr int VLEN = 8;
149   constexpr auto min_val = std::numeric_limits<T>::min();
150   constexpr auto max_val = std::numeric_limits<T>::max();
151   const __m256i min_v = _mm256_set1_epi32(min_val);
152   const __m256i max_v = _mm256_set1_epi32(max_val);
153   // This is the largest int32 value < int32_max exactly representable in float
154   constexpr int32_t int32_float_max_val =
155       std::numeric_limits<int32_t>::max() - 127;
156   int i = 0;
157   __m256 inverse_scale_v = _mm256_set1_ps(inverse_scale);
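  // shuffle_mask_v picks the low byte of each 32-bit element within each
  // 128-bit lane (byte indices 0x00, 0x04, 0x08, 0x0c) and zeroes the rest;
  // it is used by the 8-element tail loop below.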
158   // clang-format off
159   static const __m256i shuffle_mask_v = _mm256_set_epi8(
160       0xff, 0xff, 0xff, 0xff,
161       0xff, 0xff, 0xff, 0xff,
162       0xff, 0xff, 0xff, 0xff,
163       0x0c, 0x08, 0x04, 0x00,
164       0xff, 0xff, 0xff, 0xff,
165       0xff, 0xff, 0xff, 0xff,
166       0xff, 0xff, 0xff, 0xff,
167       0x0c, 0x08, 0x04, 0x00);
168   // clang-format on
169   __m256i permute_mask_v =
170       _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
171   __m256i permute_mask_l8_v =
172       _mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00);
173   int len_aligned = len / (VLEN * 4) * (VLEN * 4);
174   for (; i < len_aligned; i += 4 * VLEN) {
175     // x
176     __m256 x_vals = _mm256_load_ps(src + i);
177     __m256 x_transformed_v = _mm256_mul_ps(x_vals, inverse_scale_v);
178     // If a floating point value is greater than int32_max,
179     // _mm256_cvtps_epi32 converts it to a negative number. Clip at
180     // int32_float_max_val to avoid this.
181     x_transformed_v =
182         _mm256_min_ps(x_transformed_v, _mm256_set1_ps(int32_float_max_val));
183     // y
184     __m256 y_vals = _mm256_load_ps(src + i + VLEN);
185     __m256 y_transformed_v = _mm256_mul_ps(y_vals, inverse_scale_v);
186     y_transformed_v =
187         _mm256_min_ps(y_transformed_v, _mm256_set1_ps(int32_float_max_val));
188     // z
189     __m256 z_vals = _mm256_load_ps(src + i + 2 * VLEN);
190     __m256 z_transformed_v = _mm256_mul_ps(z_vals, inverse_scale_v);
191     z_transformed_v =
192         _mm256_min_ps(z_transformed_v, _mm256_set1_ps(int32_float_max_val));
193     // w
194     __m256 w_vals = _mm256_load_ps(src + i + 3 * VLEN);
195     __m256 w_transformed_v = _mm256_mul_ps(w_vals, inverse_scale_v);
196     w_transformed_v =
197         _mm256_min_ps(w_transformed_v, _mm256_set1_ps(int32_float_max_val));
198 
199     __m256i x_rounded_v = _mm256_cvtps_epi32(x_transformed_v);
200     __m256i y_rounded_v = _mm256_cvtps_epi32(y_transformed_v);
201     __m256i z_rounded_v = _mm256_cvtps_epi32(z_transformed_v);
202     __m256i w_rounded_v = _mm256_cvtps_epi32(w_transformed_v);
203 
204     // add zero point
205     x_rounded_v = _mm256_add_epi32(x_rounded_v, _mm256_set1_epi32(zero_point));
206     y_rounded_v = _mm256_add_epi32(y_rounded_v, _mm256_set1_epi32(zero_point));
207     z_rounded_v = _mm256_add_epi32(z_rounded_v, _mm256_set1_epi32(zero_point));
208     w_rounded_v = _mm256_add_epi32(w_rounded_v, _mm256_set1_epi32(zero_point));
209 
210     __m256i xy_packed_v = _mm256_packs_epi32(x_rounded_v, y_rounded_v);
211     __m256i zw_packed_v = _mm256_packs_epi32(z_rounded_v, w_rounded_v);
212     __m256i xyzw_clamped_v =
213         pack_saturate_and_clamp<T>(xy_packed_v, zw_packed_v, min_val, max_val);
214 
215     xyzw_clamped_v =
216         _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);
217     _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst + i), xyzw_clamped_v);
218   }
219 
220   // Additional 8-lane AVX2 path that handles the tail when fewer than 32 elements remain,
221   // based on fbgemm::QuantizeAvx2 (https://github.com/pytorch/FBGEMM)
222   for (; i < len / VLEN * VLEN; i += VLEN) {
223     __m256 x_vals = _mm256_load_ps(src + i);
224     __m256 x_transformed_v = _mm256_mul_ps(x_vals, inverse_scale_v);
225     x_transformed_v =
226         _mm256_min_ps(x_transformed_v, _mm256_set1_ps(int32_float_max_val));
227     __m256i x_rounded_v = _mm256_cvtps_epi32(x_transformed_v);
228     x_rounded_v = _mm256_add_epi32(x_rounded_v, _mm256_set1_epi32(zero_point));
229     __m256i x_clipped_v =
230         _mm256_max_epi32(min_v, _mm256_min_epi32(max_v, x_rounded_v));
231 
232     x_clipped_v = _mm256_shuffle_epi8(x_clipped_v, shuffle_mask_v);
233     x_clipped_v = _mm256_permutevar8x32_epi32(x_clipped_v, permute_mask_l8_v);
234     _mm_storel_epi64(
235         reinterpret_cast<__m128i*>(dst + i),
236         _mm256_castsi256_si128(x_clipped_v));
237   }
238 
239   for (; i < len; ++i) {
240     float transformed = src[i] * inverse_scale;
241 
242     // Not exactly the same behavior as the vectorized code.
243     // The vectorized code above always rounds to even in halfway cases
244     // (https://software.intel.com/en-us/node/523819), but std::nearbyint
245     // does the same only when the current rounding mode is FE_TONEAREST.
246     // However, in practice, this should not be a problem because most cases
247     // use the default rounding mode FE_TONEAREST.
248     // Note that we cannot implement the same behavior as the vectorized code
249     // using std::round because it does rounding away from zero in halfway
250     // cases.
251     transformed = zero_point + std::nearbyint(transformed);
252     float clipped =
253         std::min(std::max(transformed, float(min_val)), float(max_val));
254     dst[i] = clipped;
255   }
256 }
257 
258 template<>
259 struct Vectorized<c10::qint32> : public Vectorizedqi {
260     using size_type = int;
261     static constexpr size_type size() {
262         return 8;
263     }
264 
265     static constexpr int float_num_vecs() {
266         return 1;
267     }
268 
269     static constexpr int int_num_vecs() {
270         return 1;
271     }
272 
273     using float_vec_return_type = std::array<Vectorized<float>, 1>;
274     using int_vec_return_type = std::array<Vectorized<c10::qint32>, 1>;
275     using value_type = c10::qint32::underlying;
276 
277  public:
278     using Vectorizedqi::Vectorizedqi;
279     Vectorized() {}
280 
281     Vectorized(__m256i vals_) { vals = vals_;}
282 
283     // Broadcast constructor
284     Vectorized(const c10::qint32& val) {
285         value_type uw = val.val_;
286         vals = _mm256_set1_epi32(uw);
287     }
288 
289     void store(void* ptr, int count = size()) const {
290       if (count != size()) {
291         memcpy(ptr, &vals, count * sizeof(value_type));
292       } else {
293         _mm256_storeu_si256((__m256i*)ptr, vals);
294       }
295     }
296 
297     static Vectorized<c10::qint32> loadu(const void* ptr) {
298         return Vectorized<c10::qint32>(ptr);
299     }
300 
301     static Vectorized<c10::qint32> loadu(const void* ptr, int64_t count) {
302         __at_align__ value_type tmp_values[size()];
303         // Ensure uninitialized memory does not change the output value. See https://github.com/pytorch/pytorch/issues/32502
304         // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
305         // instructions while a loop would be compiled to one instruction.
306         for (const auto i : c10::irange(size())) {
307           tmp_values[i] = 0;
308         }
309         std::memcpy(
310             tmp_values, reinterpret_cast<const value_type*>(ptr), count * sizeof(value_type));
311         return _mm256_loadu_si256((const __m256i*)tmp_values);
312     }
313 
314     float_vec_return_type dequantize(
315         Vectorized<float> scale,
316         Vectorized<float> /*zero_point*/,
317         Vectorized<float> scale_zp_premul) const {
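      // scale_zp_premul is expected to hold scale * -zero_point, so the fmadd
      // below computes (value - zero_point) * scale in a single step.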
318       __m256 float_vals = _mm256_cvtepi32_ps(vals);
319       return {vec::fmadd(scale, Vectorized<float>(float_vals), scale_zp_premul)};
320     }
321 
322     float_vec_return_type dequantize(
323         Vectorized<float> scale,
324         Vectorized<float> zero_point) const {
325       __m256 float_vals = _mm256_cvtepi32_ps(vals);
326       return {(Vectorized<float>(float_vals) - zero_point) * scale};
327     }
328 
329     static Vectorized<c10::qint32> quantize(
330         const float_vec_return_type& rhs,
331         float scale,
332         int32_t zero_point,
333         float /*inverse_scale*/) {
334       Vectorized<c10::qint32> retval;
335       auto rhs_data = (__m256)rhs[0];
336       at::native::quantize_vec<c10::qint32, /*precision=*/32>(
337           scale, zero_point, (float*)&rhs_data, (c10::qint32*)&retval.vals, 8);
338       return retval;
339     }
340 
341     Vectorized<c10::qint32> maximum(Vectorized<c10::qint32> b) const {
342       return _mm256_max_epi32(vals, b.vals);
343     }
344 
345     Vectorized<c10::qint32> minimum(Vectorized<c10::qint32> b) const {
346       return _mm256_min_epi32(vals, b.vals);
347     }
348 
349     Vectorized<c10::qint32> relu(Vectorized<c10::qint32> zero_point) const {
350         return maximum(zero_point);
351     }
352 
353     Vectorized<c10::qint32> relu6(
354         Vectorized<c10::qint32> zero_point,
355         Vectorized<c10::qint32> q_six) {
356       return _mm256_min_epi32(
357           _mm256_max_epi32(vals, zero_point.vals), q_six.vals);
358     }
359 
360     int_vec_return_type widening_subtract(Vectorized<c10::qint32> b) const {
361       return {_mm256_sub_epi32(vals, b)};
362     }
363 
364     static Vectorized<c10::qint32> requantize_from_int(
365         const int_vec_return_type& inp,
366         float multiplier,
367         int32_t zero_point) {
368       __m256 multiplier_v = _mm256_set1_ps(multiplier);
369       __m256i zero_point_v = _mm256_set1_epi32(zero_point);
370 
371       __m256 scaled = _mm256_mul_ps(_mm256_cvtepi32_ps(inp[0]), multiplier_v);
372       __m256i rounded = _mm256_cvtps_epi32(scaled);
373       return _mm256_add_epi32(rounded, zero_point_v);
374     }
375 
376  private:
377     // Load from memory constructor
378     Vectorized(const void* ptr) {
379       vals = _mm256_loadu_si256((const __m256i*)ptr);
380     }
381 };
382 
383 template <>
384 Vectorized<c10::qint32> inline maximum(const Vectorized<c10::qint32>& a, const Vectorized<c10::qint32>& b) {
385   return a.maximum(b);
386 }
387 
388 template <>
389 Vectorized<c10::qint32> inline operator*(
390     const Vectorized<c10::qint32>& a,
391     const Vectorized<c10::qint32>& b) {
392   return _mm256_mullo_epi32(a, b);
393 }
394 
395 template <>
396 Vectorized<c10::qint32> inline operator+(
397     const Vectorized<c10::qint32>& a,
398     const Vectorized<c10::qint32>& b) {
399   return _mm256_add_epi32(a, b);
400 }
401 
402 /*
403  * Convert values from int32 back to int8/uint8
404  */
405 template <typename T>
406 __m256i RequantizeAvx2(
407     const std::array<Vectorized<c10::qint32>, 4>& inp,
408     __m256 multiplier,
409     __m256i zp) {
410   static_assert(
411       std::is_same_v<T, int8_t> || std::is_same_v<T, uint8_t>,
412       "Only int8_t/uint8_t are supported");
413   constexpr auto min_val = std::numeric_limits<T>::min();
414   constexpr auto max_val = std::numeric_limits<T>::max();
415   __m256i permute_mask_v =
416       _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
417   __m256 x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(inp[0]), multiplier);
418   __m256 y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(inp[1]), multiplier);
419   __m256 z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(inp[2]), multiplier);
420   __m256 w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(inp[3]), multiplier);
421 
422   __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
423   __m256i y_rounded_v = _mm256_cvtps_epi32(y_scaled_v);
424   __m256i z_rounded_v = _mm256_cvtps_epi32(z_scaled_v);
425   __m256i w_rounded_v = _mm256_cvtps_epi32(w_scaled_v);
426 
427   /* Add zero point */
428   __m256i x_v = _mm256_add_epi32(x_rounded_v, zp);
429   __m256i y_v = _mm256_add_epi32(y_rounded_v, zp);
430   __m256i z_v = _mm256_add_epi32(z_rounded_v, zp);
431   __m256i w_v = _mm256_add_epi32(w_rounded_v, zp);
432 
433   /* Pack to int16_t and saturate */
434   __m256i xy_packed_v = _mm256_packs_epi32(x_v, y_v);
435   __m256i zw_packed_v = _mm256_packs_epi32(z_v, w_v);
436 
437   __m256i xyzw_clamped_v =
438       pack_saturate_and_clamp<T>(xy_packed_v, zw_packed_v, min_val, max_val);
439 
440   /*
441    * xyzw_clamped_v has results in the following layout so we need to
442    * permute: x0-3 y0-3 z0-3 w0-3 x4-7 y4-7 z4-7 w4-7
443    */
444   xyzw_clamped_v = _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);
445   return xyzw_clamped_v;
446 }
447 
448 template<>
449 struct Vectorized<c10::qint8> : public Vectorizedqi {
450     static constexpr int size() {
451         return 32;
452     }
453 
454     static constexpr int float_num_vecs() {
455         return 4;
456     }
457 
458     static constexpr int int_num_vecs() {
459         return 4;
460     }
461 
462     using float_vec_return_type = std::array<Vectorized<float>, 4>;
463     using int_vec_return_type = std::array<Vectorized<c10::qint32>, 4>;
464     using value_type = typename c10::qint8::underlying;
465 
466  public:
467     using Vectorizedqi::Vectorizedqi;
468 
469     Vectorized() {}
470     Vectorized(__m256i vals_) { vals = vals_;}
471 
472     // Broadcast constructor
473     Vectorized(const c10::qint8& val) {
474         value_type uw = val.val_;
475         vals = _mm256_set1_epi8(uw);
476     }
477 
478     // This is needed because the compiler emits awful code for the default
479     // constructor for moving the enum
480     // NOLINTNEXTLINE(clang-diagnostic-deprecated-copy)
481     C10_CLANG_DIAGNOSTIC_PUSH()
482     #if C10_CLANG_HAS_WARNING("-Wdeprecated-copy")
483     C10_CLANG_DIAGNOSTIC_IGNORE("-Wdeprecated-copy")
484     #endif
485     Vectorized(const Vectorized<c10::qint8>& other) : Vectorizedqi(other.vals) { }
486     C10_CLANG_DIAGNOSTIC_POP()
487 
488     void store(void* ptr, int count = size()) const {
489         if (count != size()) {
490             memcpy(ptr, &vals, count * sizeof(value_type));
491         } else {
492             _mm256_storeu_si256((__m256i*)ptr, vals);
493         }
494     }
495 
496     static Vectorized<c10::qint8> loadu(const void* ptr) {
497         return Vectorized<c10::qint8>(ptr);
498     }
499 
500     static Vectorized<c10::qint8> loadu(const void* ptr, int64_t count) {
501         __at_align__ value_type tmp_values[size()];
502         // Ensure uninitialized memory does not change the output value. See https://github.com/pytorch/pytorch/issues/32502
503         // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
504         // instructions while a loop would be compiled to one instruction.
505         for (const auto i : c10::irange(size())) {
506           tmp_values[i] = 0;
507         }
508         std::memcpy(
509             tmp_values, reinterpret_cast<const value_type*>(ptr), count * sizeof(value_type));
510         return _mm256_loadu_si256((const __m256i*)tmp_values);
511     }
512 
513  private:
514     __m256i cvtepi8_epi32(__m128i epi8_vals) const {
515         return _mm256_cvtepi8_epi32(epi8_vals);
516     }
517 
518  public:
519   float_vec_return_type dequantize(
520       Vectorized<float> scale,
521       Vectorized<float> /*zero_point*/,
522       Vectorized<float> scale_neg_zp_premul) const {
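    // Split the 32 int8 values into four groups of 8 (one 64-bit extract each),
    // widen each group to int32, then convert to float.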
523     __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0));
524     __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1));
525     __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2));
526     __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3));
527 
528     __m256 float_val0 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val0));
529     __m256 float_val1 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val1));
530     __m256 float_val2 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val2));
531     __m256 float_val3 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val3));
532 
533     auto val0 =
534         vec::fmadd(scale, Vectorized<float>(float_val0), scale_neg_zp_premul);
535     auto val1 =
536         vec::fmadd(scale, Vectorized<float>(float_val1), scale_neg_zp_premul);
537     auto val2 =
538         vec::fmadd(scale, Vectorized<float>(float_val2), scale_neg_zp_premul);
539     auto val3 =
540         vec::fmadd(scale, Vectorized<float>(float_val3), scale_neg_zp_premul);
541     return {val0, val1, val2, val3};
542   }
543 
544   float_vec_return_type dequantize(
545       Vectorized<float> scale,
546       Vectorized<float> zero_point) const {
547     __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0));
548     __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1));
549     __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2));
550     __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3));
551 
552     __m256 float_val0 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val0));
553     __m256 float_val1 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val1));
554     __m256 float_val2 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val2));
555     __m256 float_val3 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val3));
556 
557     auto val0 = (Vectorized<float>(float_val0) - zero_point) * scale;
558     auto val1 = (Vectorized<float>(float_val1) - zero_point) * scale;
559     auto val2 = (Vectorized<float>(float_val2) - zero_point) * scale;
560     auto val3 = (Vectorized<float>(float_val3) - zero_point) * scale;
561     return {val0, val1, val2, val3};
562   }
563 
564   static Vectorized<c10::qint8> quantize(
565       const float_vec_return_type& rhs,
566       float /*scale*/,
567       int32_t zero_point,
568       float inverse_scale) {
569     auto* rhs_data = (float*)rhs.data();
570     int8_t quantized_values[32];
571     QuantizeAvx2<value_type>(
572         rhs_data, quantized_values, 32, inverse_scale, zero_point);
573     return Vectorized<c10::qint8>::loadu(quantized_values);
574   }
575 
576   Vectorized<c10::qint8> maximum(Vectorized<c10::qint8> b) const {
577       return _mm256_max_epi8(vals, b.vals);
578     }
579 
580   Vectorized<c10::qint8> minimum(Vectorized<c10::qint8> b) const {
581       return _mm256_min_epi8(vals, b.vals);
582     }
583 
584     Vectorized<c10::qint8> relu(Vectorized<c10::qint8> zero_point) const {
585         return maximum(zero_point);
586     }
587 
588     Vectorized<c10::qint8> relu6(
589         Vectorized<c10::qint8> zero_point,
590         Vectorized<c10::qint8> q_six) {
591       return _mm256_min_epi8(
592           _mm256_max_epi8(vals, zero_point.vals), q_six.vals);
593     }
594 
595     int_vec_return_type widening_subtract(Vectorized<c10::qint8> b) const {
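      // Widen both operands from int8 to int32 in groups of 8 before
      // subtracting, so the difference cannot overflow.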
596       __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0));
597       __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1));
598       __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2));
599       __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3));
600 
601       __m256i int32_val0 = cvtepi8_epi32(int_val0);
602       __m256i int32_val1 = cvtepi8_epi32(int_val1);
603       __m256i int32_val2 = cvtepi8_epi32(int_val2);
604       __m256i int32_val3 = cvtepi8_epi32(int_val3);
605 
606       __m128i int_b0 = _mm_set1_epi64x(_mm256_extract_epi64(b, 0));
607       __m128i int_b1 = _mm_set1_epi64x(_mm256_extract_epi64(b, 1));
608       __m128i int_b2 = _mm_set1_epi64x(_mm256_extract_epi64(b, 2));
609       __m128i int_b3 = _mm_set1_epi64x(_mm256_extract_epi64(b, 3));
610 
611       __m256i int32_b0 = cvtepi8_epi32(int_b0);
612       __m256i int32_b1 = cvtepi8_epi32(int_b1);
613       __m256i int32_b2 = cvtepi8_epi32(int_b2);
614       __m256i int32_b3 = cvtepi8_epi32(int_b3);
615 
616       __m256i res_0 = _mm256_sub_epi32(int32_val0, int32_b0);
617       __m256i res_1 = _mm256_sub_epi32(int32_val1, int32_b1);
618       __m256i res_2 = _mm256_sub_epi32(int32_val2, int32_b2);
619       __m256i res_3 = _mm256_sub_epi32(int32_val3, int32_b3);
620 
621       return {Vectorized<c10::qint32>(res_0),
622               Vectorized<c10::qint32>(res_1),
623               Vectorized<c10::qint32>(res_2),
624               Vectorized<c10::qint32>(res_3)};
625     }
626 
627     static Vectorized<c10::qint8> requantize_from_int(
628         const int_vec_return_type& inp,
629         float multiplier,
630         int32_t zero_point) {
631       __m256 multiplier_v = _mm256_set1_ps(multiplier);
632       __m256i zero_point_v = _mm256_set1_epi32(zero_point);
633       return RequantizeAvx2<value_type>(inp, multiplier_v, zero_point_v);
634     }
635 
636  private:
637     // Load from memory constructor
638     Vectorized(const void* ptr) {
639         vals = _mm256_loadu_si256((const __m256i*)ptr);
640     }
641 };
642 
643 template <>
644 Vectorized<c10::qint8> inline maximum(const Vectorized<c10::qint8>& a, const Vectorized<c10::qint8>& b) {
645   return a.maximum(b);
646 }
647 
648 template<>
649 struct Vectorized<c10::quint8> : public Vectorizedqi {
650     static constexpr int size() {
651         return 32;
652     }
653 
654     static constexpr int float_num_vecs() {
655         return 4;
656     }
657 
658     static constexpr int int_num_vecs() {
659         return 4;
660     }
661 
662     using float_vec_return_type = std::array<Vectorized<float>, 4>;
663     using int_vec_return_type = std::array<Vectorized<c10::qint32>, 4>;
664     using value_type = typename c10::quint8::underlying;
665 
666  public:
667     using Vectorizedqi::Vectorizedqi;
668     Vectorized() {}
669 
670     Vectorized(__m256i vals_) { vals = vals_;}
671 
672     // Broadcast constructor
673     Vectorized(const c10::quint8& val) {
674         value_type uw = val.val_;
675         vals = _mm256_set1_epi8(uw);
676     }
677 
678     // NOLINTNEXTLINE(clang-diagnostic-deprecated-copy)
679     C10_CLANG_DIAGNOSTIC_PUSH()
680     #if C10_CLANG_HAS_WARNING("-Wdeprecated-copy")
681     C10_CLANG_DIAGNOSTIC_IGNORE("-Wdeprecated-copy")
682     #endif
683     Vectorized(const Vectorized<c10::quint8>& other) : Vectorizedqi(other.vals) { }
684     C10_CLANG_DIAGNOSTIC_POP()
685 
686     void store(void* ptr, int count = size()) const {
687         if (count != size()) {
688             memcpy(ptr, &vals, count * sizeof(value_type));
689         } else {
690             _mm256_storeu_si256((__m256i*)ptr, vals);
691         }
692     }
693 
694     static Vectorized<c10::quint8> loadu(const void* ptr) {
695         return Vectorized<c10::quint8>(ptr);
696     }
697 
698     static Vectorized<c10::quint8> loadu(const void* ptr, int64_t count) {
699         __at_align__ value_type tmp_values[size()];
700         // Ensure uninitialized memory does not change the output value. See https://github.com/pytorch/pytorch/issues/32502
701         // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
702         // instructions while a loop would be compiled to one instruction.
703         for (const auto i : c10::irange(size())) {
704           tmp_values[i] = 0;
705         }
706         std::memcpy(
707             tmp_values, reinterpret_cast<const value_type*>(ptr), count * sizeof(value_type));
708         return _mm256_loadu_si256((const __m256i*)tmp_values);
709     }
710 
711  private:
712     __m256i cvtepu8_epi32(__m128i epu8_vals) const {
713         return _mm256_cvtepu8_epi32(epu8_vals);
714     }
715 
716  public:
717   float_vec_return_type dequantize(
718       Vectorized<float> scale,
719       Vectorized<float> /*zero_point*/,
720       Vectorized<float> scale_zp_premul) const {
721     __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0));
722     __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1));
723     __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2));
724     __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3));
725 
726     __m256 float_val0 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val0));
727     __m256 float_val1 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val1));
728     __m256 float_val2 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val2));
729     __m256 float_val3 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val3));
730 
731     auto val0 =
732         vec::fmadd(scale, Vectorized<float>(float_val0), scale_zp_premul);
733     auto val1 =
734         vec::fmadd(scale, Vectorized<float>(float_val1), scale_zp_premul);
735     auto val2 =
736         vec::fmadd(scale, Vectorized<float>(float_val2), scale_zp_premul);
737     auto val3 =
738         vec::fmadd(scale, Vectorized<float>(float_val3), scale_zp_premul);
739     return {val0, val1, val2, val3};
740   }
741 
742   float_vec_return_type dequantize(
743       Vectorized<float> scale,
744       Vectorized<float> zero_point) const {
745     __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0));
746     __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1));
747     __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2));
748     __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3));
749 
750     __m256 float_val0 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val0));
751     __m256 float_val1 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val1));
752     __m256 float_val2 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val2));
753     __m256 float_val3 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val3));
754 
755     auto val0 = (Vectorized<float>(float_val0) - zero_point) * scale;
756     auto val1 = (Vectorized<float>(float_val1) - zero_point) * scale;
757     auto val2 = (Vectorized<float>(float_val2) - zero_point) * scale;
758     auto val3 = (Vectorized<float>(float_val3) - zero_point) * scale;
759     return {val0, val1, val2, val3};
760   }
761 
762   static Vectorized<c10::quint8> quantize(
763       const float_vec_return_type& rhs,
764       float /*scale*/,
765       int32_t zero_point,
766       float inverse_scale) {
767     auto* rhs_data = (float*)rhs.data();
768     uint8_t quantized_values[32];
769     QuantizeAvx2<value_type>(
770         rhs_data, quantized_values, 32, inverse_scale, zero_point);
771     return Vectorized<c10::quint8>::loadu(quantized_values);
772   }
773 
774   Vectorized<c10::quint8> maximum(Vectorized<c10::quint8> b) const {
775       return _mm256_max_epu8(vals, b.vals);
776     }
777 
778   Vectorized<c10::quint8> minimum(Vectorized<c10::quint8> b) const {
779       return _mm256_min_epu8(vals, b.vals);
780     }
781 
782     Vectorized<c10::quint8> relu(Vectorized<c10::quint8> zero_point) const {
783         return maximum(zero_point);
784     }
785 
786     Vectorized<c10::quint8> relu6(
787         Vectorized<c10::quint8> zero_point,
788         Vectorized<c10::quint8> q_six) {
789       return _mm256_min_epu8(
790           _mm256_max_epu8(vals, zero_point.vals), q_six.vals);
791     }
792 
793     int_vec_return_type widening_subtract(Vectorized<c10::quint8> b) const {
794       __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0));
795       __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1));
796       __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2));
797       __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3));
798 
799       __m256i int32_val0 = cvtepu8_epi32(int_val0);
800       __m256i int32_val1 = cvtepu8_epi32(int_val1);
801       __m256i int32_val2 = cvtepu8_epi32(int_val2);
802       __m256i int32_val3 = cvtepu8_epi32(int_val3);
803 
804       __m128i int_b0 = _mm_set1_epi64x(_mm256_extract_epi64(b, 0));
805       __m128i int_b1 = _mm_set1_epi64x(_mm256_extract_epi64(b, 1));
806       __m128i int_b2 = _mm_set1_epi64x(_mm256_extract_epi64(b, 2));
807       __m128i int_b3 = _mm_set1_epi64x(_mm256_extract_epi64(b, 3));
808 
809       __m256i int32_b0 = cvtepu8_epi32(int_b0);
810       __m256i int32_b1 = cvtepu8_epi32(int_b1);
811       __m256i int32_b2 = cvtepu8_epi32(int_b2);
812       __m256i int32_b3 = cvtepu8_epi32(int_b3);
813 
814       __m256i res_0 = _mm256_sub_epi32(int32_val0, int32_b0);
815       __m256i res_1 = _mm256_sub_epi32(int32_val1, int32_b1);
816       __m256i res_2 = _mm256_sub_epi32(int32_val2, int32_b2);
817       __m256i res_3 = _mm256_sub_epi32(int32_val3, int32_b3);
818       return {Vectorized<c10::qint32>(res_0),
819               Vectorized<c10::qint32>(res_1),
820               Vectorized<c10::qint32>(res_2),
821               Vectorized<c10::qint32>(res_3)};
822     }
823 
824     static Vectorized<c10::quint8> requantize_from_int(
825         const int_vec_return_type& inp,
826         float multiplier,
827         int32_t zero_point) {
828       __m256 multiplier_v = _mm256_set1_ps(multiplier);
829       __m256i zero_point_v = _mm256_set1_epi32(zero_point);
830       return RequantizeAvx2<value_type>(inp, multiplier_v, zero_point_v);
831     }
832 
833  private:
834 
835     // Load from memory constructor
836     Vectorized(const void* ptr) {
837         vals = _mm256_loadu_si256((const __m256i*)ptr);
838     }
839 };
840 
841 template <>
842 Vectorized<c10::quint8> inline maximum(const Vectorized<c10::quint8>& a, const Vectorized<c10::quint8>& b) {
843   return a.maximum(b);
844 }
845 
846 #else
847 
848 // NOTE: These are low-performance implementations that we fall back on
849 // if we are not building with AVX2. This may not be an issue, because
850 // currently for quantization we assume the user has at least AVX512
851 // installed, so these can simply act as a reference implementation.
852 //
853 // If in the future we relax this requirement (AVX2+), we should probably
854 // revisit these implementations
855 
856 template <
857     typename T,
858     typename float_vec_return_type_,
859     typename int_vec_return_type_,
860     int size_>
861 struct VectorizedQuantizedConverter {
862   static constexpr int size() {
863     return size_;
864   }
865 
866   static constexpr int float_num_vecs() {
867     return size() / 8;
868   }
869 
870   static constexpr int int_num_vecs() {
871     return size() / 8;
872   }
873 
874   using float_vec_return_type = float_vec_return_type_;
875   using int_vec_return_type = int_vec_return_type_;
876 
877   using value_type = typename T::underlying;
878   std::array<value_type, size_> vals;
879 
880   VectorizedQuantizedConverter(T val) {
881     for (const auto i : c10::irange(size())) {
882       vals[i] = val.val_;
883     }
884   }
885 
886   VectorizedQuantizedConverter(const void* ptr) {
887     memcpy(vals.data(), ptr, sizeof(value_type) * size());
888   }
889 
890   void store(void* ptr, int count = size()) const {
891     memcpy(ptr, vals.data(), count * sizeof(value_type));
892   }
893 
894   float_vec_return_type dequantize(
895       Vectorized<float> scale,
896       Vectorized<float> zero_point,
897       Vectorized<float> /*scale_zp_premul*/) const {
898     float_vec_return_type rv;
899     for (const auto i : c10::irange(float_num_vecs())) {
900       float tmp_vals[8];
901       for (const auto j : c10::irange(8)) {
902         tmp_vals[j] = at::native::dequantize_val<T>(
903             scale[j], zero_point[j], T(vals[8 * i + j]));
904       }
905       rv[i] = Vectorized<float>(tmp_vals[0],
906           tmp_vals[1],
907           tmp_vals[2],
908           tmp_vals[3],
909           tmp_vals[4],
910           tmp_vals[5],
911           tmp_vals[6],
912           tmp_vals[7]);
913     }
914     return rv;
915   }
916 
917   float_vec_return_type dequantize(
918       Vectorized<float> scale,
919       Vectorized<float> zero_point) const {
920     Vectorized<float> scale_zp_premul;
921     return dequantize(scale, zero_point, scale_zp_premul);
922   }
923 
924  protected:
925   VectorizedQuantizedConverter() {}
926 };
927 
928 template <>
929 struct Vectorized<c10::qint32> : public VectorizedQuantizedConverter<
930                                  c10::qint32,
931                                  std::array<Vectorized<float>, 1>,
932                                  std::array<Vectorized<c10::qint32>, 1>,
933                                  8> {
934   Vectorized()
935       : VectorizedQuantizedConverter<
936             c10::qint32,
937             std::array<Vectorized<float>, 1>,
938             std::array<Vectorized<c10::qint32>, 1>,
939             8>() {}
940   Vectorized(c10::qint32 val)
941       : VectorizedQuantizedConverter<
942             c10::qint32,
943             std::array<Vectorized<float>, 1>,
944             std::array<Vectorized<c10::qint32>, 1>,
945             8>(val) {}
946   Vectorized(const void* ptr)
947       : VectorizedQuantizedConverter<
948             c10::qint32,
949             std::array<Vectorized<float>, 1>,
950             std::array<Vectorized<c10::qint32>, 1>,
951             8>(ptr) {}
952 
953   static Vectorized<c10::qint32> loadu(const void* ptr) {
954     return Vectorized<c10::qint32>(ptr);
955   }
956 
957   static Vectorized<c10::qint32> loadu(const void* ptr, int64_t count) {
958     __at_align__ value_type tmp_values[size()];
959     // Ensure uninitialized memory does not change the output value. See https://github.com/pytorch/pytorch/issues/32502
960     // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
961     // instructions while a loop would be compiled to one instruction.
962     for (const auto i : c10::irange(size())) {
963       tmp_values[i] = 0;
964     }
965     std::memcpy(
966         tmp_values, reinterpret_cast<const value_type*>(ptr), count * sizeof(value_type));
967     return Vectorized<c10::qint32>(tmp_values);
968   }
969 
970   static Vectorized<c10::qint32> quantize(
971       const float_vec_return_type& rhs,
972       float scale,
973       int32_t zero_point,
974       float /*inverse_scale*/) {
975     std::array<value_type, size()> qvals;
976     std::array<float, float_num_vecs() * 8> float_vals;
977 
978     for (const auto i : c10::irange(float_num_vecs())) {
979       rhs[i].store(&float_vals[i * 8], 8);
980     }
981 
982     at::native::quantize_vec<c10::qint32, /*precision=*/32>(
983         scale,
984         zero_point,
985         float_vals.data(),
986         (c10::qint32*)qvals.data(),
987         8 * float_num_vecs());
988 
989     return Vectorized<c10::qint32>::loadu(qvals.data());
990   }
991 
992   Vectorized<c10::qint32> maximum(Vectorized<c10::qint32> b) const {
993     Vectorized<c10::qint32> retval;
994     for (const auto i : c10::irange(size())) {
995       retval.vals[i] = std::max<value_type>(vals[i], b.vals[i]);
996     }
997     return retval;
998   }
999 
1000   Vectorized<c10::qint32> minimum(Vectorized<c10::qint32> b) const {
1001     Vectorized<c10::qint32> retval;
1002     for (const auto i : c10::irange(size())) {
1003       retval.vals[i] = std::min<value_type>(vals[i], b.vals[i]);
1004     }
1005     return retval;
1006   }
1007 
1008   Vectorized<c10::qint32> relu(Vectorized<c10::qint32> zero_point) const  {
1009     return maximum(zero_point);
1010   }
1011 
1012 
1013   Vectorized<c10::qint32> relu6(
1014       Vectorized<c10::qint32> zero_point,
1015       Vectorized<c10::qint32> q_six) {
1016     Vectorized<c10::qint32> retval;
1017     for (const auto i : c10::irange(size())) {
1018       retval.vals[i] = std::min<value_type>(
1019           std::max<value_type>(vals[i], zero_point.vals[i]), q_six.vals[i]);
1020     }
1021     return retval;
1022   }
1023 
1024   int_vec_return_type widening_subtract(Vectorized<c10::qint32> b) const {
1025     int_vec_return_type retval;
1026     for (const auto i : c10::irange(size())) {
1027       retval[0].vals[i] = vals[i] - b.vals[i];
1028     }
1029     return retval;
1030   }
1031 
1032   static Vectorized<c10::qint32> requantize_from_int(
1033       const int_vec_return_type& inp,
1034       float multiplier,
1035       int32_t zero_point) {
1036     Vectorized<c10::qint32> retval;
1037     for (const auto i : c10::irange(size())) {
1038       retval.vals[i] =
1039           std::nearbyint(static_cast<float>(inp[0].vals[i]) * multiplier) +
1040           zero_point;
1041     }
1042     return retval;
1043   }
1044 };
1045 
1046 template <>
1047 Vectorized<c10::qint32> inline maximum(const Vectorized<c10::qint32>& a, const Vectorized<c10::qint32>& b) {
1048   return a.maximum(b);
1049 }
1050 
1051 template <>
1052 Vectorized<c10::qint32> inline operator*(
1053     const Vectorized<c10::qint32>& a,
1054     const Vectorized<c10::qint32>& b) {
1055   Vectorized<c10::qint32> retval;
1056   for (const auto i : c10::irange(std::decay_t<decltype(a)>::size())) {
1057     retval.vals[i] = a.vals[i] * b.vals[i];
1058   }
1059   return retval;
1060 }
1061 
1062 template <>
1063 Vectorized<c10::qint32> inline operator+(
1064     const Vectorized<c10::qint32>& a,
1065     const Vectorized<c10::qint32>& b) {
1066   Vectorized<c10::qint32> retval;
1067   for (const auto i : c10::irange(std::decay_t<decltype(a)>::size())) {
1068     retval.vals[i] = a.vals[i] + b.vals[i];
1069   }
1070   return retval;
1071 }
1072 
1073 template <>
1074 struct Vectorized<c10::qint8> : public VectorizedQuantizedConverter<
1075                                 c10::qint8,
1076                                 std::array<Vectorized<float>, 4>,
1077                                 std::array<Vectorized<c10::qint32>, 4>,
1078                                 32> {
1079   Vectorized()
1080       : VectorizedQuantizedConverter<
1081             c10::qint8,
1082             std::array<Vectorized<float>, 4>,
1083             std::array<Vectorized<c10::qint32>, 4>,
1084             32>() {}
1085   Vectorized(c10::qint8 val)
1086       : VectorizedQuantizedConverter<
1087             c10::qint8,
1088             std::array<Vectorized<float>, 4>,
1089             std::array<Vectorized<c10::qint32>, 4>,
1090             32>(val) {}
1091   Vectorized(const void* ptr)
1092       : VectorizedQuantizedConverter<
1093             c10::qint8,
1094             std::array<Vectorized<float>, 4>,
1095             std::array<Vectorized<c10::qint32>, 4>,
1096             32>(ptr) {}
1097 
1098   static Vectorized<c10::qint8> loadu(const void* ptr) {
1099     return Vectorized<c10::qint8>(ptr);
1100   }
1101 
1102   static Vectorized<c10::qint8> loadu(const void* ptr, int64_t count) {
1103     __at_align__ value_type tmp_values[size()];
1104     // Ensure uninitialized memory does not change the output value. See https://github.com/pytorch/pytorch/issues/32502
1105     // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
1106     // instructions while a loop would be compiled to one instruction.
1107     for (const auto i : c10::irange(size())) {
1108       tmp_values[i] = 0;
1109     }
1110     std::memcpy(
1111         tmp_values, reinterpret_cast<const value_type*>(ptr), count * sizeof(value_type));
1112     return Vectorized<c10::qint8>(tmp_values);
1113   }
1114 
1115   static Vectorized<c10::qint8> quantize(
1116       const float_vec_return_type& rhs,
1117       float scale,
1118       int32_t zero_point,
1119       float /*inverse_scale*/) {
1120     std::array<value_type, size()> qvals;
1121     std::array<float, float_num_vecs() * 8> float_vals;
1122 
1123     for (const auto i : c10::irange(float_num_vecs())) {
1124       rhs[i].store(&float_vals[i * 8], 8);
1125     }
1126 
1127     at::native::quantize_vec<c10::qint8>(
1128         scale,
1129         zero_point,
1130         float_vals.data(),
1131         (c10::qint8*)qvals.data(),
1132         8 * float_num_vecs());
1133 
1134     return Vectorized<c10::qint8>::loadu(qvals.data());
1135   }
1136 
1137   Vectorized<c10::qint8> maximum(Vectorized<c10::qint8> b) const {
1138     Vectorized<c10::qint8> retval;
1139     for (const auto i : c10::irange(size())) {
1140       retval.vals[i] = std::max<value_type>(vals[i], b.vals[i]);
1141     }
1142     return retval;
1143   }
1144 
1145   Vectorized<c10::qint8> minimum(Vectorized<c10::qint8> b) const {
1146     Vectorized<c10::qint8> retval;
1147     for (const auto i : c10::irange(size())) {
1148       retval.vals[i] = std::min<value_type>(vals[i], b.vals[i]);
1149     }
1150     return retval;
1151   }
1152 
1153   Vectorized<c10::qint8> relu(Vectorized<c10::qint8> zero_point) const {
1154     return maximum(zero_point);
1155   }
1156 
1157   Vectorized<c10::qint8> relu6(
1158       Vectorized<c10::qint8> zero_point,
1159       Vectorized<c10::qint8> q_six) {
1160     Vectorized<c10::qint8> retval;
1161     for (const auto i : c10::irange(size())) {
1162       retval.vals[i] = std::min<value_type>(
1163           std::max<value_type>(vals[i], zero_point.vals[i]), q_six.vals[i]);
1164     }
1165     return retval;
1166   }
1167 
1168   int_vec_return_type widening_subtract(Vectorized<c10::qint8> b) const {
1169     int_vec_return_type retval;
1170     constexpr int elem_per_int_vec = size() / int_num_vecs();
1171     for (const auto i : c10::irange(int_num_vecs())) {
1172       for (const auto j : c10::irange(elem_per_int_vec)) {
1173         retval[i].vals[j] =
1174             static_cast<int32_t>(vals[i * elem_per_int_vec + j]) -
1175             static_cast<int32_t>(b.vals[i * elem_per_int_vec + j]);
1176       }
1177     }
1178     return retval;
1179   }
1180   static Vectorized<c10::qint8> requantize_from_int(
1181       const int_vec_return_type& inp,
1182       float multiplier,
1183       int32_t zero_point) {
1184     constexpr int elem_per_int_vec = size() / int_num_vecs();
1185     constexpr auto min_val = std::numeric_limits<value_type>::min();
1186     constexpr auto max_val = std::numeric_limits<value_type>::max();
1187     Vectorized<c10::qint8> retval;
1188     for (const auto i : c10::irange(int_num_vecs())) {
1189       for (const auto j : c10::irange(elem_per_int_vec)) {
1190         int32_t rounded =
1191             std::nearbyint(static_cast<float>(inp[i].vals[j]) * multiplier) +
1192             zero_point;
1193         retval.vals[i * elem_per_int_vec + j] =
1194             std::min<int32_t>(std::max<int32_t>(rounded, min_val), max_val);
1195       }
1196     }
1197     return retval;
1198   }
1199 };
1200 
1201 template <>
1202 Vectorized<c10::qint8> inline maximum(const Vectorized<c10::qint8>& a, const Vectorized<c10::qint8>& b) {
1203   return a.maximum(b);
1204 }
1205 
1206 template <>
1207 struct Vectorized<c10::quint8> : public VectorizedQuantizedConverter<
1208                                  c10::quint8,
1209                                  std::array<Vectorized<float>, 4>,
1210                                  std::array<Vectorized<c10::qint32>, 4>,
1211                                  32> {
1212   Vectorized()
1213       : VectorizedQuantizedConverter<
1214             c10::quint8,
1215             std::array<Vectorized<float>, 4>,
1216             std::array<Vectorized<c10::qint32>, 4>,
1217             32>() {}
1218   Vectorized(c10::quint8 val)
1219       : VectorizedQuantizedConverter<
1220             c10::quint8,
1221             std::array<Vectorized<float>, 4>,
1222             std::array<Vectorized<c10::qint32>, 4>,
1223             32>(val) {}
1224   Vectorized(const void* ptr)
1225       : VectorizedQuantizedConverter<
1226             c10::quint8,
1227             std::array<Vectorized<float>, 4>,
1228             std::array<Vectorized<c10::qint32>, 4>,
1229             32>(ptr) {}
1230 
1231   static Vectorized<c10::quint8> loadu(const void* ptr) {
1232     return Vectorized<c10::quint8>(ptr);
1233   }
1234 
1235   static Vectorized<c10::quint8> loadu(const void* ptr, int64_t count) {
1236     __at_align__ value_type tmp_values[size()];
1237     // Ensure uninitialized memory does not change the output value. See https://github.com/pytorch/pytorch/issues/32502
1238     // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
1239     // instructions while a loop would be compiled to one instruction.
1240     for (const auto i : c10::irange(size())) {
1241       tmp_values[i] = 0;
1242     }
1243     std::memcpy(
1244         tmp_values, reinterpret_cast<const value_type*>(ptr), count * sizeof(value_type));
1245     return Vectorized<c10::quint8>(tmp_values);
1246   }
1247 
1248   static Vectorized<c10::quint8> quantize(
1249       const float_vec_return_type& rhs,
1250       float scale,
1251       int32_t zero_point,
1252       float /*inverse_scale*/) {
1253     std::array<value_type, size()> qvals;
1254     std::array<float, float_num_vecs() * 8> float_vals;
1255 
1256     for (const auto i : c10::irange(float_num_vecs())) {
1257       rhs[i].store(&float_vals[i * 8], 8);
1258     }
1259 
1260     at::native::quantize_vec<c10::quint8>(
1261         scale,
1262         zero_point,
1263         float_vals.data(),
1264         (c10::quint8*)qvals.data(),
1265         8 * float_num_vecs());
1266 
1267     return Vectorized<c10::quint8>::loadu(qvals.data());
1268   }
1269 
1270   Vectorized<c10::quint8> maximum(Vectorized<c10::quint8> b) const {
1271     Vectorized<c10::quint8> retval;
1272     for (const auto i : c10::irange(size())) {
1273       retval.vals[i] = std::max<value_type>(vals[i], b.vals[i]);
1274     }
1275     return retval;
1276   }
1277 
1278   Vectorized<c10::quint8> minimum(Vectorized<c10::quint8> b) const {
1279     Vectorized<c10::quint8> retval;
1280     for (const auto i : c10::irange(size())) {
1281       retval.vals[i] = std::min<value_type>(vals[i], b.vals[i]);
1282     }
1283     return retval;
1284   }
1285 
1286   Vectorized<c10::quint8> relu(Vectorized<c10::quint8> zero_point) const {
1287     return maximum(zero_point);
1288   }
1289 
1290 
1291   Vectorized<c10::quint8> relu6(
1292       Vectorized<c10::quint8> zero_point,
1293       Vectorized<c10::quint8> q_six) {
1294     Vectorized<c10::quint8> retval;
1295     for (const auto i : c10::irange(size())) {
1296       retval.vals[i] = std::min<value_type>(
1297           std::max<value_type>(vals[i], zero_point.vals[i]), q_six.vals[i]);
1298     }
1299     return retval;
1300   }
1301 
1302   int_vec_return_type widening_subtract(Vectorized<c10::quint8> b) const {
1303     int_vec_return_type retval;
1304     constexpr int elem_per_int_vec = size() / int_num_vecs();
1305     for (const auto i : c10::irange(int_num_vecs())) {
1306       for (const auto j : c10::irange(elem_per_int_vec)) {
1307         retval[i].vals[j] =
1308             static_cast<int32_t>(vals[i * elem_per_int_vec + j]) -
1309             static_cast<int32_t>(b.vals[i * elem_per_int_vec + j]);
1310       }
1311     }
1312     return retval;
1313   }
1314   static Vectorized<c10::quint8> requantize_from_int(
1315       const int_vec_return_type& inp,
1316       float multiplier,
1317       int32_t zero_point) {
1318     constexpr int elem_per_int_vec = size() / int_num_vecs();
1319     constexpr auto min_val = std::numeric_limits<value_type>::min();
1320     constexpr auto max_val = std::numeric_limits<value_type>::max();
1321     Vectorized<c10::quint8> retval;
1322     for (const auto i : c10::irange(int_num_vecs())) {
1323       for (const auto j : c10::irange(elem_per_int_vec)) {
1324         int32_t rounded =
1325             std::nearbyint(static_cast<float>(inp[i].vals[j]) * multiplier) +
1326             zero_point;
1327         retval.vals[i * elem_per_int_vec + j] =
1328             std::min<int32_t>(std::max<int32_t>(rounded, min_val), max_val);
1329       }
1330     }
1331     return retval;
1332   }
1333 };
1334 
1335 template <>
1336 Vectorized<c10::quint8> inline maximum(const Vectorized<c10::quint8>& a, const Vectorized<c10::quint8>& b) {
1337   return a.maximum(b);
1338 }
1339 
1340 #endif // if defined(CPU_CAPABILITY_AVX2)
1341 
1342 #if defined(CPU_CAPABILITY_NEON)
1343 template <typename T>
1344 typename std::enable_if_t<std::is_same_v<T, int8_t>, at::vec::Vectorized<float>>
1345 inline convert_int8_to_float(at::vec::Vectorized<T> src) {
1346   // Note: this function only converts a number of input elements equal to at::vec::Vectorized<float>::size()
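  // Widen 8 x int8 -> int16 -> int32, then convert both 4-lane halves to float.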
1347     auto s8x8 = vld1_s8(src.operator const int8_t*());
1348     auto s16x8 = vmovl_s8(s8x8);
1349 
1350     auto s32x4_hi = vmovl_s16(vget_high_s16(s16x8));
1351     auto s32x4_lo = vmovl_s16(vget_low_s16(s16x8));
1352 
1353     return Vectorized<float>(vcvtq_f32_s32(s32x4_lo), vcvtq_f32_s32(s32x4_hi));
1354 }
1355 
1356 template <typename T>
1357 typename std::enable_if_t<std::is_same_v<T, uint8_t>, at::vec::Vectorized<float>>
1358 inline convert_int8_to_float(at::vec::Vectorized<T> src) {
1359   // Note: this function only converts a number of input elements equal to at::vec::Vectorized<float>::size()
1360     auto u8x8 = vld1_u8(src.operator const uint8_t*());
1361     auto u16x8 = vmovl_u8(u8x8);
1362     auto u32x4_hi = vmovl_u16(vget_high_u16(u16x8));
1363     auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8));
1364 
1365     return Vectorized<float>(vcvtq_f32_u32(u32x4_lo), vcvtq_f32_u32(u32x4_hi));
1366 }
1367 
1368 #endif
1369 }} // namespace at::vec::CPU_CAPABILITY
1370