xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cpu/vec/vec_base.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 // DO NOT DEFINE STATIC DATA IN THIS HEADER!
4 // See Note [Do not compile initializers with AVX]
5 //
6 // Note [Do not compile initializers with AVX]
7 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
8 // If you define a static initializer in this file, the initialization will use
9 // AVX instructions because these object files are compiled with AVX enabled.
10 // We need to avoid non-trivial global data in these architecture specific files
11 // because there's no way to guard the global initializers with CPU capability
12 // detection.
13 //
14 // See https://github.com/pytorch/pytorch/issues/37577 for an instance
15 // of this bug in the past.
16 
17 #include <array>
18 #include <algorithm>
19 #include <cassert>
20 #include <cstring>
21 #include <functional>
22 #include <cmath>
23 #include <type_traits>
24 #include <climits>
25 
26 #include <ATen/cpu/vec/intrinsics.h>
27 #include <ATen/native/Math.h>
28 #include <ATen/NumericUtils.h>
29 #include <c10/util/Half.h>
30 #include <c10/util/BFloat16.h>
31 #include <c10/util/BFloat16-math.h>
32 #include <c10/util/copysign.h>
33 #include <ATen/native/cpu/zmath.h>
34 #include <c10/util/TypeCast.h>
35 #include <c10/macros/Macros.h>
36 #include <c10/util/irange.h>
37 #include <c10/util/Load.h>
38 
// __FORCE_INLINE: request inlining regardless of the compiler's own
// heuristics, using whichever spelling the current compiler understands.
#if defined(__GNUC__)
#define __FORCE_INLINE __attribute__((always_inline)) inline
#elif defined(_MSC_VER)
#define __FORCE_INLINE __forceinline
#else
// Neither a GNU-compatible compiler nor MSVC: fall back to a plain `inline`
// hint so that code using __FORCE_INLINE still compiles.
#define __FORCE_INLINE inline
#endif
44 
// __msvc_cl__ marks builds made with the Microsoft compiler. Detection keys on
// _MSC_FULL_VER; Windows LLVM builds are expected not to define that macro
// (NOTE(review): confirm for clang-cl).
// See https://learn.microsoft.com/en-us/cpp/overview/compiler-versions?view=msvc-170
#ifdef _MSC_FULL_VER
#define __msvc_cl__
#endif
53 
// Capability-dependent building blocks shared by the rest of vec_base.h:
//   VECTOR_WIDTH  - vector size in bytes (64 for AVX512, 32 otherwise)
//   int_vector    - matching integer intrinsic type (__m512i / __m256i)
//   __at_align__  - alignment specifier equal to VECTOR_WIDTH; expands to
//                   nothing on compilers that are neither GNU-like nor Windows
#ifdef CPU_CAPABILITY_AVX512
#define VECTOR_WIDTH 64
#define int_vector __m512i
#ifdef __GNUC__
#define __at_align__ __attribute__((aligned(64)))
#elif defined(_WIN32)
#define __at_align__ __declspec(align(64))
#else
#define __at_align__
#endif
#else // CPU_CAPABILITY_AVX512
#define VECTOR_WIDTH 32
#define int_vector __m256i
#ifdef __GNUC__
#define __at_align__ __attribute__((aligned(32)))
#elif defined(_WIN32)
#define __at_align__ __declspec(align(32))
#else
#define __at_align__
#endif
#endif // CPU_CAPABILITY_AVX512
76 
77 namespace at::vec {
78 // See Note [CPU_CAPABILITY namespace]
79 inline namespace CPU_CAPABILITY {
80 // at::Half and at::BFloat16 should be treated as floating point
81 template <typename T>
82 struct is_floating_point:
83     std::integral_constant<bool,
84       std::is_floating_point_v<T> ||
85       std::is_same_v<T, at::Half> ||
86       std::is_same_v<T, at::BFloat16>> {
87 };
88 
89 template<typename T>
90 constexpr bool is_floating_point_v = is_floating_point<T>::value;
91 
92 template <typename T>
93 struct is_reduced_floating_point:
94     std::integral_constant<bool,
95       std::is_same_v<T, at::Half> ||
96       std::is_same_v<T, at::BFloat16>> {
97 };
98 
99 template <typename T>
100 constexpr bool is_reduced_floating_point_v = is_reduced_floating_point<T>::value;
101 
// True for `unsigned char` and `signed char` only; plain `char` is a distinct
// type and therefore does not match.
template <typename T>
struct is_8bit_integer
    : std::bool_constant<
          std::is_same_v<T, unsigned char> ||
          std::is_same_v<T, signed char>> {};

// Shorthand for is_8bit_integer<T>::value.
template <typename T>
constexpr bool is_8bit_integer_v = is_8bit_integer<T>::value;
111 
// Maps a byte count to the signed fixed-width integer type of that size.
// Only the sizes of int64_t, int32_t, int16_t and int8_t are specialized;
// any other size leaves the primary template incomplete.
template<size_t n> struct int_of_size;

template<> struct int_of_size<sizeof(int64_t)> { using type = int64_t; };
template<> struct int_of_size<sizeof(int32_t)> { using type = int32_t; };
template<> struct int_of_size<sizeof(int16_t)> { using type = int16_t; };
template<> struct int_of_size<sizeof(int8_t)> { using type = int8_t; };

// Signed integer type occupying exactly as many bytes as T.
template <typename T>
using int_same_size_t = typename int_of_size<sizeof(T)>::type;
126 
127 // NOTE: If you specialize on a type, you must define all operations!
128 
129 // emulates Vectorized types
130 #if defined(__s390x__)
131 template <class T, class TEMP=void>
132 #else
133 template <class T>
134 #endif
135 struct Vectorized {
136 private:
137   __at_align__ T values[VECTOR_WIDTH / sizeof(T)];
138 public:
139   using value_type = T;
140   using size_type = int;
141   // Note [constexpr static function to avoid odr-usage compiler bug]
142   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
143   // Why, you might ask, is size defined to be a static constexpr function,
144   // rather than a more ordinary 'static constexpr int size;' variable?
145   // The problem lies within ODR rules for static constexpr members versus
146   // static constexpr functions.  First, recall that this class (along with all
147   // of its derivations) live in an anonymous namespace: they are intended to be
148   // *completely* inlined at their use-sites, because we need to compile it
149   // multiple times for different instruction sets.
150   //
151   // Because of this constraint, we CANNOT provide a single definition for
152   // any static members in this class; since we want to compile the class
153   // multiple times, there wouldn't actually be any good place to put the
154   // definition.  Now here is the problem: if we ODR-use a static constexpr
155   // member, we are *obligated* to provide a definition.  Without the
156   // definition, you get a compile error like:
157   //
158   //    relocation R_X86_64_PC32 against undefined symbol
159   //    `_ZN2at6vec25612_GLOBAL__N_16VectorizedIdE4sizeE' can not be used when making
160   //    a shared object; recompile with -fPIC
161   //
162   // If this were C++17, we could replace a static constexpr variable with
163   // an inline variable which doesn't require one definition. But we are not
164   // C++17.  So the next best thing is to replace the member with a static
165   // constexpr (and therefore inline) function, which does not require ODR
166   // either.
167   //
168   // Also, technically according to the C++ standard, we don't have to define
169   // a constexpr variable if we never odr-use it.  But it seems that some
170   // versions GCC/Clang have buggy determinations on whether or not an
171   // identifier is odr-used or not, and in any case it's hard to tell if
172   // a variable is odr-used or not.  So best to just cut the problem at the root.
sizeVectorized173   static constexpr size_type size() {
174     return VECTOR_WIDTH / sizeof(T);
175   }
VectorizedVectorized176   Vectorized() : values{static_cast<T>(0)} {}
VectorizedVectorized177   Vectorized(T val) {
178     for (int i = 0; i != size(); i++) {
179       values[i] = val;
180     }
181   }
182   template<typename... Args,
183            typename = std::enable_if_t<(sizeof...(Args) == size())>>
VectorizedVectorized184   Vectorized(Args... vals) : values{vals...}{
185   }
186   // This also implies const T& operator[](int idx) const
187   inline operator const T*() const {
188     return values;
189   }
190   // This also implies T& operator[](int idx)
191   inline operator T*() {
192     return values;
193   }
194   // Return the values as char* for type punning
195   auto as_bytes() const -> const char* {
196     return reinterpret_cast<const char*>(values);
197   }
198   template <int64_t mask_>
blendVectorized199   static Vectorized<T> blend(const Vectorized<T>& a, const Vectorized<T>& b) {
200     int64_t mask = mask_;
201     Vectorized vector;
202     for (const auto i : c10::irange(size())) {
203       if (mask & 0x01) {
204         vector[i] = b[i];
205       } else {
206         vector[i] = a[i];
207       }
208       mask = mask >> 1;
209     }
210     return vector;
211   }
blendvVectorized212   static Vectorized<T> blendv(const Vectorized<T>& a, const Vectorized<T>& b,
213                           const Vectorized<T>& mask) {
214     Vectorized vector;
215     int_same_size_t<T> buffer[size()];
216     mask.store(buffer);
217     for (const auto i : c10::irange(size())) {
218       if (buffer[i] & 0x01)
219        {
220         vector[i] = b[i];
221       } else {
222         vector[i] = a[i];
223       }
224     }
225     return vector;
226   }
227   template<typename step_t>  // step sometimes requires a higher precision type (e.g., T=int, step_t=double)
228   static Vectorized<T> arange(T base = static_cast<T>(0), step_t step = static_cast<step_t>(1)) {
229     Vectorized vector;
230     for (const auto i : c10::irange(size())) {
231       vector.values[i] = base + i * step;
232     }
233     return vector;
234   }
235   static Vectorized<T> set(const Vectorized<T>& a, const Vectorized<T>& b, int64_t count = size()) {
236     Vectorized vector;
237     for (const auto i : c10::irange(size())) {
238       if (i < count) {
239         vector[i] = b[i];
240       } else {
241         vector[i] = a[i];
242       }
243     }
244     return vector;
245   }
loaduVectorized246   static Vectorized<T> loadu(const void* ptr) {
247     Vectorized vector;
248     std::memcpy(vector.values, ptr, VECTOR_WIDTH);
249     return vector;
250   }
loaduVectorized251   static Vectorized<T> loadu(const void* ptr, int64_t count) {
252     Vectorized vector;
253     std::memcpy(vector.values, ptr, count * sizeof(T));
254     return vector;
255   }
loadu_one_fourthVectorized256   static Vectorized<T> loadu_one_fourth(const void* ptr) {
257     static_assert(std::is_same_v<T, signed char> || std::is_same_v<T, unsigned char>, "For byte types only");
258     return Vectorized::loadu(ptr, 8);
259   }
260 
261   void store(void* ptr, int count = size()) const {
262     std::memcpy(ptr, values, count * sizeof(T));
263   }
zero_maskVectorized264   int zero_mask() const {
265     // returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit
266     int mask = 0;
267     for (int i = 0; i < size(); ++ i) {
268       if (values[i] == static_cast<T>(0)) {
269         mask |= (1 << i);
270       }
271     }
272     return mask;
273   }
isnanVectorized274   Vectorized<T> isnan() const {
275     Vectorized<T> vector;
276     for (int64_t i = 0; i != size(); i++) {
277       if (_isnan(values[i])) {
278         std::memset(static_cast<void*>(vector.values + i), 0xFF, sizeof(T));
279       } else {
280         std::memset(static_cast<void*>(vector.values + i), 0, sizeof(T));
281       }
282     }
283     return vector;
284   }
has_inf_nanVectorized285   bool has_inf_nan() const {
286     for (int64_t i = 0; i != size(); i++) {
287       if(_isnan(values[i]) || _isinf(values[i])) {
288         return true;
289       }
290     }
291     return false;
292   }
mapVectorized293   Vectorized<T> map(T (*const f)(T)) const {
294     Vectorized<T> ret;
295     for (int64_t i = 0; i != size(); i++) {
296       ret[i] = f(values[i]);
297     }
298     return ret;
299   }
mapVectorized300   Vectorized<T> map(T (*const f)(const T &)) const {
301     Vectorized<T> ret;
302     for (int64_t i = 0; i != size(); i++) {
303       ret[i] = f(values[i]);
304     }
305     return ret;
306   }
307   template <typename other_t_abs = T,
308             typename std::enable_if_t<!is_floating_point_v<other_t_abs> && !c10::is_complex<other_t_abs>::value, int> = 0>
absVectorized309   Vectorized<T> abs() const {
310     // other_t_abs is for SFINAE and clarity. Make sure it is not changed.
311     static_assert(std::is_same_v<other_t_abs, T>, "other_t_abs must be T");
312     return map([](T x) -> T { return x < static_cast<T>(0) ? -x : x; });
313   }
314   template <typename float_t_abs = T,
315             typename std::enable_if_t<is_floating_point_v<float_t_abs>, int> = 0>
absVectorized316   Vectorized<T> abs() const {
317     // float_t_abs is for SFINAE and clarity. Make sure it is not changed.
318     static_assert(std::is_same_v<float_t_abs, T>, "float_t_abs must be T");
319     // Specifically deal with floating-point because the generic code above won't handle -0.0 (which should result in
320     // 0.0) properly.
321     return map([](T x) -> T { return std::abs(x); });
322   }
323   template <typename complex_t_abs = T,
324             typename std::enable_if_t<c10::is_complex<complex_t_abs>::value, int> = 0>
absVectorized325   Vectorized<T> abs() const {
326     // complex_t_abs is for SFINAE and clarity. Make sure it is not changed.
327     static_assert(std::is_same_v<complex_t_abs, T>, "complex_t_abs must be T");
328     // Specifically map() does not perform the type conversion needed by abs.
329     return map([](T x) { return static_cast<T>(std::abs(x)); });
330   }
331 
332   template <typename other_t_sgn = T,
333             typename std::enable_if_t<c10::is_complex<other_t_sgn>::value, int> = 0>
sgnVectorized334   Vectorized<T> sgn() const {
335     return map(at::native::sgn_impl);
336   }
337 
338   template <typename other_t_angle = T,
339             typename std::enable_if_t<!c10::is_complex<other_t_angle>::value, int> = 0>
angleVectorized340   Vectorized<T> angle() const {
341     // other_t_angle is for SFINAE and clarity. Make sure it is not changed.
342     static_assert(std::is_same_v<other_t_angle, T>, "other_t_angle must be T");
343     return map(at::native::angle_impl<T>);  // compiler is unable to resolve the overload without <T>
344   }
345   template <typename complex_t_angle = T,
346             typename std::enable_if_t<c10::is_complex<complex_t_angle>::value, int> = 0>
angleVectorized347   Vectorized<T> angle() const {
348     // complex_t_angle is for SFINAE and clarity. Make sure it is not changed.
349     static_assert(std::is_same_v<complex_t_angle, T>, "complex_t_angle must be T");
350     return map([](T x) { return static_cast<T>(std::arg(x)); });
351   }
352   template <typename other_t_real = T,
353             typename std::enable_if_t<!c10::is_complex<other_t_real>::value, int> = 0>
realVectorized354   Vectorized<T> real() const {
355     // other_t_real is for SFINAE and clarity. Make sure it is not changed.
356     static_assert(std::is_same_v<other_t_real, T>, "other_t_real must be T");
357     return *this;
358   }
359   template <typename complex_t_real = T,
360             typename std::enable_if_t<c10::is_complex<complex_t_real>::value, int> = 0>
realVectorized361   Vectorized<T> real() const {
362     // complex_t_real is for SFINAE and clarity. Make sure it is not changed.
363     static_assert(std::is_same_v<complex_t_real, T>, "complex_t_real must be T");
364     return map([](T x) { return static_cast<T>(x.real()); });
365   }
366   template <typename other_t_imag = T,
367             typename std::enable_if_t<!c10::is_complex<other_t_imag>::value, int> = 0>
imagVectorized368   Vectorized<T> imag() const {
369     // other_t_imag is for SFINAE and clarity. Make sure it is not changed.
370     static_assert(std::is_same_v<other_t_imag, T>, "other_t_imag must be T");
371     return Vectorized(0);
372   }
373   template <typename complex_t_imag = T,
374             typename std::enable_if_t<c10::is_complex<complex_t_imag>::value, int> = 0>
imagVectorized375   Vectorized<T> imag() const {
376     // complex_t_imag is for SFINAE and clarity. Make sure it is not changed.
377     static_assert(std::is_same_v<complex_t_imag, T>, "complex_t_imag must be T");
378     return map([](T x) { return static_cast<T>(x.imag()); });
379   }
380   template <typename other_t_conj = T,
381             typename std::enable_if_t<!c10::is_complex<other_t_conj>::value, int> = 0>
conjVectorized382   Vectorized<T> conj() const {
383     // other_t_conj is for SFINAE and clarity. Make sure it is not changed.
384     static_assert(std::is_same_v<other_t_conj, T>, "other_t_conj must be T");
385     return *this;
386   }
387   template <typename complex_t_conj = T,
388             typename std::enable_if_t<c10::is_complex<complex_t_conj>::value, int> = 0>
conjVectorized389   Vectorized<T> conj() const {
390     // complex_t_conj is for SFINAE and clarity. Make sure it is not changed.
391     static_assert(std::is_same_v<complex_t_conj, T>, "complex_t_conj must be T");
392     return map([](T x) { return static_cast<T>(std::conj(x)); });
393   }
acosVectorized394   Vectorized<T> acos() const {
395     return map(std::acos);
396   }
acoshVectorized397   Vectorized<T> acosh() const {
398     return map(std::acosh);
399   }
asinVectorized400   Vectorized<T> asin() const {
401     return map(std::asin);
402   }
atanVectorized403   Vectorized<T> atan() const {
404     return map(std::atan);
405   }
atanhVectorized406   Vectorized<T> atanh() const {
407     return map(std::atanh);
408   }
atan2Vectorized409   Vectorized<T> atan2(const Vectorized<T> &exp) const {
410     Vectorized<T> ret;
411     for (const auto i : c10::irange(size())) {
412       ret[i] = std::atan2(values[i], exp[i]);
413     }
414     return ret;
415   }
416   template <
417     typename U = T,
418     typename std::enable_if_t<is_floating_point_v<U>, int> = 0>
copysignVectorized419   Vectorized<T> copysign(const Vectorized<T> &sign) const {
420     Vectorized<T> ret;
421     for (size_type i = 0; i < size(); i++) {
422       ret[i] = c10::copysign(values[i], sign[i]);
423     }
424     return ret;
425   }
erfVectorized426   Vectorized<T> erf() const {
427     return map(std::erf);
428   }
erfcVectorized429   Vectorized<T> erfc() const {
430     return map(std::erfc);
431   }
erfinvVectorized432   Vectorized<T> erfinv() const {
433     return map(calc_erfinv);
434   }
expVectorized435   Vectorized<T> exp() const {
436     return map(std::exp);
437   }
exp2Vectorized438   Vectorized<T> exp2() const {
439     return map(exp2_impl);
440   }
expm1Vectorized441   Vectorized<T> expm1() const {
442     return map(std::expm1);
443   }
exp_u20Vectorized444   Vectorized<T> exp_u20() const {
445     return map(std::exp);
446   }
fracVectorized447   Vectorized<T> frac() const {
448     return *this - this->trunc();
449   }
450   template <
451     typename U = T,
452     typename std::enable_if_t<is_floating_point_v<U>, int> = 0>
fmodVectorized453   Vectorized<T> fmod(const Vectorized<T>& q) const {
454     // U is for SFINAE purposes only. Make sure it is not changed.
455     static_assert(std::is_same_v<U, T>, "U must be T");
456     Vectorized<T> ret;
457     for (const auto i : c10::irange(size())) {
458       ret[i] = std::fmod(values[i], q[i]);
459     }
460     return ret;
461   }
logVectorized462   Vectorized<T> log() const {
463     return map(std::log);
464   }
log10Vectorized465   Vectorized<T> log10() const {
466     return map(std::log10);
467   }
log1pVectorized468   Vectorized<T> log1p() const {
469     return map(std::log1p);
470   }
471   template <typename other_t_log2 = T,
472             typename std::enable_if_t<!c10::is_complex<other_t_log2>::value, int> = 0>
log2Vectorized473   Vectorized<T> log2() const {
474     // other_t_log2 is for SFINAE and clarity. Make sure it is not changed.
475     static_assert(std::is_same_v<other_t_log2, T>, "other_t_log2 must be T");
476     return map(std::log2);
477   }
478   template <typename complex_t_log2 = T,
479             typename std::enable_if_t<c10::is_complex<complex_t_log2>::value, int> = 0>
log2Vectorized480   Vectorized<T> log2() const {
481     // complex_t_log2 is for SFINAE and clarity. Make sure it is not changed.
482     static_assert(std::is_same_v<complex_t_log2, T>, "complex_t_log2 must be T");
483     const T log_2 = T(std::log(2.0));
484     return Vectorized(map(std::log))/Vectorized(log_2);
485   }
ceilVectorized486   Vectorized<T> ceil() const {
487     return map(at::native::ceil_impl);
488   }
cosVectorized489   Vectorized<T> cos() const {
490     return map(std::cos);
491   }
coshVectorized492   Vectorized<T> cosh() const {
493     return map(std::cosh);
494   }
floorVectorized495   Vectorized<T> floor() const {
496     return map(at::native::floor_impl);
497   }
hypotVectorized498   Vectorized<T> hypot(const Vectorized<T> &b) const {
499     Vectorized<T> ret;
500     for (const auto i : c10::irange(size())) {
501       ret[i] = std::hypot(values[i], b[i]);
502     }
503     return ret;
504   }
i0Vectorized505   Vectorized<T> i0() const {
506     return map(calc_i0);
507   }
i0eVectorized508   Vectorized<T> i0e() const {
509     return map(calc_i0e);
510   }
digammaVectorized511   Vectorized<T> digamma() const {
512     return map(calc_digamma);
513   }
igammaVectorized514   Vectorized<T> igamma(const Vectorized<T> &x) const {
515     Vectorized<T> ret;
516     for (const auto i : c10::irange(size())) {
517       ret[i] = calc_igamma(values[i], x[i]);
518     }
519     return ret;
520   }
igammacVectorized521   Vectorized<T> igammac(const Vectorized<T> &x) const {
522     Vectorized<T> ret;
523     for (const auto i : c10::irange(size())) {
524       ret[i] = calc_igammac(values[i], x[i]);
525     }
526     return ret;
527   }
negVectorized528   Vectorized<T> neg() const {
529     // NB: the trailing return type is needed because we need to coerce the
530     // return value back to T in the case of unary operator- incuring a
531     // promotion
532     return map([](T x) -> T { return -x; });
533   }
nextafterVectorized534   Vectorized<T> nextafter(const Vectorized<T> &b) const {
535     Vectorized<T> ret;
536     for (const auto i : c10::irange(size())) {
537       ret[i] = std::nextafter(values[i], b[i]);
538     }
539     return ret;
540   }
roundVectorized541   Vectorized<T> round() const {
542     // We do not use std::round because we would like to round midway numbers to the nearest even integer.
543     return map(at::native::round_impl);
544   }
sinVectorized545   Vectorized<T> sin() const {
546     return map(std::sin);
547   }
sinhVectorized548   Vectorized<T> sinh() const {
549     return map(std::sinh);
550   }
tanVectorized551   Vectorized<T> tan() const {
552     return map(std::tan);
553   }
tanhVectorized554   Vectorized<T> tanh() const {
555     return map(std::tanh);
556   }
truncVectorized557   Vectorized<T> trunc() const {
558     return map(at::native::trunc_impl);
559   }
lgammaVectorized560   Vectorized<T> lgamma() const {
561     return map(std::lgamma);
562   }
sqrtVectorized563   Vectorized<T> sqrt() const {
564     return map(std::sqrt);
565   }
reciprocalVectorized566   Vectorized<T> reciprocal() const {
567     return map([](T x) { return (T)(1) / x; });
568   }
rsqrtVectorized569   Vectorized<T> rsqrt() const {
570     return map([](T x) { return (T)1 / std::sqrt(x); });
571   }
powVectorized572   Vectorized<T> pow(const Vectorized<T> &exp) const {
573     Vectorized<T> ret;
574     for (const auto i : c10::irange(size())) {
575       ret[i] = std::pow(values[i], exp[i]);
576     }
577     return ret;
578   }
579 private:
580   template <typename Op>
binary_predVectorized581   inline Vectorized<T> binary_pred(const Vectorized<T>& other, Op op) const {
582     // All bits are set to 1 if the pred is true, otherwise 0.
583     Vectorized<T> vector;
584     for (int64_t i = 0; i != size(); i++) {
585       if (op(values[i], other.values[i])) {
586         std::memset(static_cast<void*>(vector.values + i), 0xFF, sizeof(T));
587       } else {
588         std::memset(static_cast<void*>(vector.values + i), 0, sizeof(T));
589       }
590     }
591     return vector;
592   }
593 
594 public:
595   Vectorized<T> operator==(const Vectorized<T>& other) const { return binary_pred(other, std::equal_to<T>()); }
596   Vectorized<T> operator!=(const Vectorized<T>& other) const { return binary_pred(other, std::not_equal_to<T>()); }
597   Vectorized<T> operator>=(const Vectorized<T>& other) const { return binary_pred(other, std::greater_equal<T>()); }
598   Vectorized<T> operator<=(const Vectorized<T>& other) const { return binary_pred(other, std::less_equal<T>()); }
599   Vectorized<T> operator>(const Vectorized<T>& other) const { return binary_pred(other, std::greater<T>()); }
600   Vectorized<T> operator<(const Vectorized<T>& other) const { return binary_pred(other, std::less<T>()); }
601 
602 private:
603   template <typename Op>
binary_pred_boolVectorized604   inline Vectorized<T> binary_pred_bool(const Vectorized<T>& other, Op op) const {
605     // 1 if the pred is true, otherwise 0.
606     Vectorized<T> vector;
607     for (int i = 0; i != size(); ++ i) {
608       vector[i] = static_cast<T>(op(values[i], other.values[i]));
609     }
610     return vector;
611   }
612 
613 public:
eqVectorized614   Vectorized<T> eq(const Vectorized<T>& other) const { return binary_pred_bool(other, std::equal_to<T>()); }
neVectorized615   Vectorized<T> ne(const Vectorized<T>& other) const { return binary_pred_bool(other, std::not_equal_to<T>()); }
gtVectorized616   Vectorized<T> gt(const Vectorized<T>& other) const { return binary_pred_bool(other, std::greater<T>()); }
geVectorized617   Vectorized<T> ge(const Vectorized<T>& other) const { return binary_pred_bool(other, std::greater_equal<T>()); }
ltVectorized618   Vectorized<T> lt(const Vectorized<T>& other) const { return binary_pred_bool(other, std::less<T>()); }
leVectorized619   Vectorized<T> le(const Vectorized<T>& other) const { return binary_pred_bool(other, std::less_equal<T>()); }
620 };
621 
622 template <class T> Vectorized<T> inline operator+(const Vectorized<T> &a, const Vectorized<T> &b) {
623   Vectorized<T> c;
624   for (int i = 0; i != Vectorized<T>::size(); i++) {
625     c[i] = a[i] + b[i];
626   }
627   return c;
628 }
629 
630 template <class T> Vectorized<T> inline operator-(const Vectorized<T> &a, const Vectorized<T> &b) {
631   Vectorized<T> c;
632   for (int i = 0; i != Vectorized<T>::size(); i++) {
633     c[i] = a[i] - b[i];
634   }
635   return c;
636 }
637 
638 template <class T> Vectorized<T> inline operator*(const Vectorized<T> &a, const Vectorized<T> &b) {
639   Vectorized<T> c;
640   for (int i = 0; i != Vectorized<T>::size(); i++) {
641     c[i] = a[i] * b[i];
642   }
643   return c;
644 }
645 
646 template <class T> Vectorized<T> inline operator/(const Vectorized<T> &a, const Vectorized<T> &b) __ubsan_ignore_float_divide_by_zero__ {
647   Vectorized<T> c;
648   for (int i = 0; i != Vectorized<T>::size(); i++) {
649     c[i] = a[i] / b[i];
650   }
651   return c;
652 }
653 
654 template <class T,
655           typename std::enable_if_t<!is_floating_point_v<T>, int> = 0>
656 Vectorized<T> inline operator%(const Vectorized<T> &a, const Vectorized<T> &b) __ubsan_ignore_float_divide_by_zero__ {
657   return a - a / b * b;
658 }
659 
660 template <class T> Vectorized<T> inline operator||(
661     const Vectorized<T> &a, const Vectorized<T> &b) {
662   Vectorized<T> c;
663   for (int i = 0; i != Vectorized<T>::size(); i++) {
664     c[i] = a[i] || b[i];
665   }
666   return c;
667 }
668 
669 // Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
670 // either input is a NaN.
671 template <class T,
672           typename std::enable_if_t<!c10::is_complex<T>::value, int> = 0>
maximum(const Vectorized<T> & a,const Vectorized<T> & b)673 Vectorized<T> inline maximum(const Vectorized<T> &a, const Vectorized<T> &b) {
674   Vectorized<T> c;
675   for (int i = 0; i != Vectorized<T>::size(); i++) {
676     c[i] = (a[i] > b[i]) ? a[i] : b[i];
677     if (_isnan(a[i])) {
678       // If either input is NaN, propagate a NaN.
679       // NOTE: The case where b[i] was NaN is handled correctly by the naive
680       // ternary operator above.
681       c[i] = a[i];
682     }
683   }
684   return c;
685 }
686 
687 template <class T,
688           typename std::enable_if_t<c10::is_complex<T>::value, int> = 0>
maximum(const Vectorized<T> & a,const Vectorized<T> & b)689 Vectorized<T> inline maximum(const Vectorized<T> &a, const Vectorized<T> &b) {
690   Vectorized<T> c;
691   for (int i = 0; i != Vectorized<T>::size(); i++) {
692     c[i] = (std::abs(a[i]) > std::abs(b[i])) ? a[i] : b[i];
693     if (_isnan(a[i])) {
694       // If either input is NaN, propagate a NaN.
695       // NOTE: The case where b[i] was NaN is handled correctly by the naive
696       // ternary operator above.
697       c[i] = a[i];
698     }
699   }
700   return c;
701 }
702 
703 // Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
704 // either input is a NaN.
705 template <class T,
706           typename std::enable_if_t<!c10::is_complex<T>::value, int> = 0>
minimum(const Vectorized<T> & a,const Vectorized<T> & b)707 Vectorized<T> inline minimum(const Vectorized<T> &a, const Vectorized<T> &b) {
708   Vectorized<T> c;
709   for (int i = 0; i != Vectorized<T>::size(); i++) {
710     c[i] = (a[i] < b[i]) ? a[i] : b[i];
711     if (_isnan(a[i])) {
712       // If either input is NaN, propagate a NaN.
713       // NOTE: The case where b[i] was NaN is handled correctly by the naive
714       // ternary operator above.
715       c[i] = a[i];
716     }
717   }
718   return c;
719 }
720 
721 template <class T,
722           typename std::enable_if_t<c10::is_complex<T>::value, int> = 0>
minimum(const Vectorized<T> & a,const Vectorized<T> & b)723 Vectorized<T> inline minimum(const Vectorized<T> &a, const Vectorized<T> &b) {
724   Vectorized<T> c;
725   for (int i = 0; i != Vectorized<T>::size(); i++) {
726     c[i] = (std::abs(a[i]) < std::abs(b[i])) ? a[i] : b[i];
727     if (_isnan(a[i])) {
728       // If either input is NaN, propagate a NaN.
729       // NOTE: The case where b[i] was NaN is handled correctly by the naive
730       // ternary operator above.
731       c[i] = a[i];
732     }
733   }
734   return c;
735 }
736 
737 template <class T,
738           typename std::enable_if_t<!c10::is_complex<T>::value, int> = 0>
clamp(const Vectorized<T> & a,const Vectorized<T> & min_vec,const Vectorized<T> & max_vec)739 Vectorized<T> inline clamp(const Vectorized<T> &a, const Vectorized<T> &min_vec, const Vectorized<T> &max_vec) {
740   Vectorized<T> c;
741   for (int i = 0; i != Vectorized<T>::size(); i++) {
742     c[i] = std::min(std::max(a[i], min_vec[i]), max_vec[i]);
743   }
744   return c;
745 }
746 
747 template <class T,
748           typename std::enable_if_t<!c10::is_complex<T>::value, int> = 0>
clamp_max(const Vectorized<T> & a,const Vectorized<T> & max_vec)749 Vectorized<T> inline clamp_max(const Vectorized<T> &a, const Vectorized<T> &max_vec) {
750   Vectorized<T> c;
751   for (int i = 0; i != Vectorized<T>::size(); i++) {
752     c[i] = a[i] > max_vec[i] ? max_vec[i] : a[i];
753   }
754   return c;
755 }
756 
757 template <class T,
758           typename std::enable_if_t<!c10::is_complex<T>::value, int> = 0>
clamp_min(const Vectorized<T> & a,const Vectorized<T> & min_vec)759 Vectorized<T> inline clamp_min(const Vectorized<T> &a, const Vectorized<T> &min_vec) {
760   Vectorized<T> c;
761   for (int i = 0; i != Vectorized<T>::size(); i++) {
762     c[i] = a[i] < min_vec[i] ? min_vec[i] : a[i];
763   }
764   return c;
765 }
766 
767 struct Vectorizedi;
768 
769 #if defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)
770 template <class T, typename Op>
bitwise_binary_op(const Vectorized<T> & a,const Vectorized<T> & b,Op op)771 static inline Vectorized<T> bitwise_binary_op(const Vectorized<T> &a, const Vectorized<T> &b, Op op) {
772   int_vector buffer;
773 #if defined(CPU_CAPABILITY_AVX2)
774   int_vector a_buffer = _mm256_load_si256(reinterpret_cast<const int_vector*>((const T*)a));
775   int_vector b_buffer = _mm256_load_si256(reinterpret_cast<const int_vector*>((const T*)b));
776 #elif defined(CPU_CAPABILITY_AVX512)
777   int_vector a_buffer = _mm512_load_si512(reinterpret_cast<const int_vector*>((const T*)a));
778   int_vector b_buffer = _mm512_load_si512(reinterpret_cast<const int_vector*>((const T*)b));
779 #endif
780   buffer = op(a_buffer, b_buffer);
781   __at_align__ T results[Vectorized<T>::size()];
782 
783 #if defined(CPU_CAPABILITY_AVX2)
784   _mm256_store_si256(reinterpret_cast<int_vector*>(results), buffer);
785 #elif defined(CPU_CAPABILITY_AVX512)
786   _mm512_store_si512(reinterpret_cast<int_vector*>(results), buffer);
787 #endif
788   return Vectorized<T>::loadu(results);
789 }
790 
791 template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
792 inline Vectorized<T> operator&(const Vectorized<T>& a, const Vectorized<T>& b) {
793   // We enclose _mm512_and_si512 or _mm256_and_si256 with lambda because it is always_inline
794 #if defined(CPU_CAPABILITY_AVX2)
795   return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_and_si256(a, b); });
796 #elif defined(CPU_CAPABILITY_AVX512)
797   return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_and_si512(a, b); });
798 #endif
799 }
800 template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
801 inline Vectorized<T> operator|(const Vectorized<T>& a, const Vectorized<T>& b) {
802   // We enclose _mm512_or_si512 or _mm256_or_si256 with lambda because it is always_inline
803 #if defined(CPU_CAPABILITY_AVX2)
804   return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_or_si256(a, b); });
805 #elif defined(CPU_CAPABILITY_AVX512)
806   return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_or_si512(a, b); });
807 #endif
808 }
809 template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
810 inline Vectorized<T> operator^(const Vectorized<T>& a, const Vectorized<T>& b) {
811   // We enclose _mm512_xor_si512 or _mm256_xor_si256 with lambda because it is always_inline
812 #if defined(CPU_CAPABILITY_AVX2)
813   return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_xor_si256(a, b); });
814 #elif defined(CPU_CAPABILITY_AVX512)
815   return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_xor_si512(a, b); });
816 #endif
817 }
818 
819 #else
820 
// Reads sizeof(T) bytes starting at `data` into a fresh T via memcpy, so the
// source bytes are never accessed through a T-typed pointer.
template <typename T>
auto load(char const* data) -> T {
  T value;
  std::memcpy(static_cast<void*>(&value), static_cast<const void*>(data), sizeof(T));
  return value;
}
827 
828 template<class T, typename Op>
bitwise_binary_op(const Vectorized<T> & a,const Vectorized<T> & b,Op op)829 static inline Vectorized<T> bitwise_binary_op(const Vectorized<T> &a, const Vectorized<T> &b, Op op) {
830   static constexpr uint32_t element_no = VECTOR_WIDTH / sizeof(intmax_t);
831   __at_align__ intmax_t buffer[element_no];
832   static_assert(VECTOR_WIDTH % sizeof(intmax_t) == 0, "VECTOR_WIDTH not a multiple of sizeof(intmax_t)");
833   static_assert(sizeof(buffer) == sizeof(Vectorized<T>), "sizeof(buffer) must match sizeof(Vectorized<T>)");
834   // We should be using memcpy in order to respect the strict aliasing rule
835   // see: https://github.com/pytorch/pytorch/issues/66119
836   // Using char* is defined in the C11 standard 6.5 Expression paragraph 7
837   // (http://www.open-std.org/jtc1/sc22/wg14/www/docs/n1570.pdf)
838   const auto* a_data = a.as_bytes();
839   const auto* b_data = b.as_bytes();
840   // load each intmax_t chunk and process; increase pointers by sizeof(intmax_t)
841   for (auto& out : buffer) {
842     out = op(load<intmax_t>(a_data), load<intmax_t>(b_data));
843     a_data += sizeof(intmax_t);
844     b_data += sizeof(intmax_t);
845   }
846   assert(a_data == a.as_bytes() + sizeof(a));
847   assert(b_data == b.as_bytes() + sizeof(b));
848   return Vectorized<T>::loadu(buffer);
849 }
850 
851 template<class T, typename std::enable_if_t<!std::is_base_of_v<Vectorizedi, Vectorized<T>>, int> = 0>
852 inline Vectorized<T> operator&(const Vectorized<T>& a, const Vectorized<T>& b) {
853   return bitwise_binary_op(a, b, std::bit_and<intmax_t>());
854 }
855 template<class T, typename std::enable_if_t<!std::is_base_of_v<Vectorizedi, Vectorized<T>>, int> = 0>
856 inline Vectorized<T> operator|(const Vectorized<T>& a, const Vectorized<T>& b) {
857   return bitwise_binary_op(a, b, std::bit_or<intmax_t>());
858 }
859 template<class T, typename std::enable_if_t<!std::is_base_of_v<Vectorizedi, Vectorized<T>>, int> = 0>
860 inline Vectorized<T> operator^(const Vectorized<T>& a, const Vectorized<T>& b) {
861   return bitwise_binary_op(a, b, std::bit_xor<intmax_t>());
862 }
863 
864 #endif // defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)
865 
866 template<class T, typename std::enable_if_t<!std::is_base_of_v<Vectorizedi, Vectorized<T>>, int> = 0>
867 inline Vectorized<T> operator~(const Vectorized<T>& a) {
868   using int_t = int_same_size_t<T>;
869   Vectorized<T> ones(c10::bit_cast<T>((int_t)(~(int_t)0)));  // All bits are 1
870   return a ^ ones;
871 }
872 
873 template <class T> Vectorized<T> inline operator<<(const Vectorized<T> &a, const Vectorized<T> &b) {
874   constexpr T max_shift = sizeof(T) * CHAR_BIT;
875   Vectorized<T> c;
876   for (int i = 0; i != Vectorized<T>::size(); i++) {
877     T shift = b[i];
878     if ((static_cast<std::make_signed_t<T>>(shift) < 0) || (shift >= max_shift)) {
879       c[i] = 0;
880     } else {
881       c[i] = static_cast<std::make_unsigned_t<T>>(a[i]) << shift;
882     }
883   }
884   return c;
885 }
886 
887 template <class T> Vectorized<T> inline operator>>(const Vectorized<T> &a, const Vectorized<T> &b) {
888   // right shift value to retain sign bit for signed and no bits for unsigned
889   constexpr T max_shift = sizeof(T) * CHAR_BIT - std::is_signed_v<T>;
890   Vectorized<T> c;
891   for (int i = 0; i != Vectorized<T>::size(); i++) {
892     T shift = b[i];
893     if ((static_cast<std::make_signed_t<T>>(shift) < 0) || (shift >= max_shift)) {
894       c[i] = a[i] >> max_shift;
895     } else {
896       c[i] = a[i] >> shift;
897     }
898   }
899   return c;
900 }
901 
902 template <typename T>
903 inline Vectorized<T>& operator += (Vectorized<T>& a, const Vectorized<T>& b) {
904   a = a + b;
905   return a;
906 }
907 template <typename T>
908 inline Vectorized<T>& operator -= (Vectorized<T>& a, const Vectorized<T>& b) {
909   a = a - b;
910   return a;
911 }
912 template <typename T>
913 inline Vectorized<T>& operator /= (Vectorized<T>& a, const Vectorized<T>& b) {
914   a = a / b;
915   return a;
916 }
917 template <typename T>
918 inline Vectorized<T>& operator %= (Vectorized<T>& a, const Vectorized<T>& b) {
919   a = a % b;
920   return a;
921 }
922 template <typename T>
923 inline Vectorized<T>& operator *= (Vectorized<T>& a, const Vectorized<T>& b) {
924   a = a * b;
925   return a;
926 }
927 
928 template <typename T>
929 inline Vectorized<T>& operator <<= (Vectorized<T>& a, const Vectorized<T>& b) {
930   a = a << b;
931   return a;
932 }
933 
934 template <typename T>
935 inline Vectorized<T>& operator >>= (Vectorized<T>& a, const Vectorized<T>& b) {
936   a = a >> b;
937   return a;
938 }
939 
940 template <typename T>
fmadd(const Vectorized<T> & a,const Vectorized<T> & b,const Vectorized<T> & c)941 inline Vectorized<T> fmadd(const Vectorized<T>& a, const Vectorized<T>& b, const Vectorized<T>& c) {
942   return a * b + c;
943 }
944 
945 template <typename T>
fmsub(const Vectorized<T> & a,const Vectorized<T> & b,const Vectorized<T> & c)946 inline Vectorized<T> fmsub(const Vectorized<T>& a, const Vectorized<T>& b, const Vectorized<T>& c) {
947   return a * b - c;
948 }
949 
950 template <typename T>
951 Vectorized<T> inline operator&&(
952     const Vectorized<T>& a,
953     const Vectorized<T>& b) {
954   Vectorized<T> ret;
955   for (int i = 0; i != Vectorized<T>::size(); i++) {
956     ret[i] = a[i] && b[i];
957   }
958   return ret;
959 }
960 
961 template <int64_t scale = 1, typename T = void>
962 std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<T>>
gather(T const * base_addr,const Vectorized<int_same_size_t<T>> & vindex)963 inline gather(T const* base_addr, const Vectorized<int_same_size_t<T>>& vindex) {
964   static constexpr int size = Vectorized<T>::size();
965   int_same_size_t<T> index_arr[size];
966   vindex.store(static_cast<void*>(index_arr));
967   T buffer[size];
968   for (const auto i : c10::irange(size)) {
969     buffer[i] = base_addr[index_arr[i] * scale / sizeof(T)];
970   }
971   return Vectorized<T>::loadu(static_cast<void*>(buffer));
972 }
973 
974 template <int64_t scale = 1, typename T = void>
975 std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<T>>
mask_gather(const Vectorized<T> & src,T const * base_addr,const Vectorized<int_same_size_t<T>> & vindex,Vectorized<T> & mask)976 inline mask_gather(const Vectorized<T>& src, T const* base_addr,
977                    const Vectorized<int_same_size_t<T>>& vindex, Vectorized<T>& mask) {
978   static constexpr int size = Vectorized<T>::size();
979   T src_arr[size];
980   int_same_size_t<T> mask_arr[size];  // use int type so we can logical and
981   int_same_size_t<T> index_arr[size];
982   src.store(static_cast<void*>(src_arr));
983   mask.store(static_cast<void*>(mask_arr));
984   vindex.store(static_cast<void*>(index_arr));
985   T buffer[size];
986   for (const auto i : c10::irange(size)) {
987     if (mask_arr[i] & 0x01) {  // check highest bit
988       buffer[i] = base_addr[index_arr[i] * scale / sizeof(T)];
989     } else {
990       buffer[i] = src_arr[i];
991     }
992   }
993   mask = Vectorized<T>();  // "zero out" mask
994   return Vectorized<T>::loadu(static_cast<void*>(buffer));
995 }
996 
997 // Cast a given vector to another type without changing the bits representation.
998 // So a Vectorized<double> of 512 bits containing all ones can be cast to a
999 // Vectorized<int64_t> of 512 bits containing all ones (i.e., eight negative 1s).
1000 // A Vec<double> of 256 bits containing all ones can be cast to a
1001 // Vec<int64_t> of 256 bits containing all ones (i.e., four negative 1s).
1002 // There is a struct here because we don't have static_if and I can't
1003 // partially specialize a templated function.
1004 template<typename dst_t, typename src_t>
1005 struct CastImpl {
applyCastImpl1006   static inline Vectorized<dst_t> apply(const Vectorized<src_t>& src) {
1007     src_t src_arr[Vectorized<src_t>::size()];
1008     src.store(static_cast<void*>(src_arr));
1009     return Vectorized<dst_t>::loadu(static_cast<const void*>(src_arr));
1010   }
1011 };
1012 
1013 template<typename scalar_t>
1014 struct CastImpl<scalar_t, scalar_t> {
1015   static inline Vectorized<scalar_t> apply(const Vectorized<scalar_t>& src) {
1016     return src;
1017   }
1018 };
1019 
1020 template<typename dst_t, typename src_t>
1021 inline Vectorized<dst_t> cast(const Vectorized<src_t>& src) {
1022   return CastImpl<dst_t, src_t>::apply(src);
1023 }
1024 
1025 template <typename T, typename IntType = int_same_size_t<T>>
1026 inline Vectorized<IntType> convert_to_int_of_same_size(const Vectorized<T>& src) {
1027   static_assert(sizeof(T) == sizeof(IntType));
1028   static constexpr int size = Vectorized<T>::size();
1029 
1030   std::array<T, size> src_arr;
1031   src.store(static_cast<void*>(src_arr.data()));
1032   std::array<IntType, size> buffer;
1033   std::transform(src_arr.cbegin(), src_arr.cend(), buffer.begin(),
1034                  [](const T& x) { return static_cast<IntType>(x); });
1035   return Vectorized<IntType>::loadu(static_cast<const void*>(buffer.data()));
1036 }
1037 
1038 template <typename T, typename IntType = int_same_size_t<T>>
1039 inline Vectorized<T> convert_to_fp_of_same_size(const Vectorized<IntType>& src) {
1040   static_assert(sizeof(T) == sizeof(IntType));
1041   static constexpr int size = Vectorized<T>::size();
1042 
1043   std::array<IntType, size> src_arr;
1044   src.store(static_cast<void*>(src_arr.data()));
1045   std::array<T, size> buffer;
1046   std::transform(src_arr.cbegin(), src_arr.cend(), buffer.begin(),
1047                  [](const IntType& x) { return static_cast<T>(x); });
1048   return Vectorized<T>::loadu(static_cast<const void*>(buffer.data()));
1049 }
1050 
1051 // Example inputs for AVX512:
1052 // a   Vectorized<float>   = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7}
1053 // b   Vectorized<float>   = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15}
1054 // returns:
1055 //           Vectorized<float>   = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15}
1056 //           Vectorized<float>   = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15}
1057 // Example inputs for AVX2: a           Vectorized<float>   = {a0, b0, a1, b1, a2, b2, a3, b3}
1058 //               b                      Vectorized<float>   = {a4, b4, a5, b5, a6, b6, a7, b7}
1059 //       returns:                       Vectorized<float>   = {a0, a1, a2, a3, a4, a5, a6, a7}
1060 //                                      Vectorized<float>   = {b0, b1, b2, b3, b4, b5, b6, b7}
1061 template <typename T>
1062 inline std::enable_if_t<Vectorized<T>::size() % 2 == 0, std::pair<Vectorized<T>, Vectorized<T>>>
1063 deinterleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
1064   static constexpr int size = Vectorized<T>::size();
1065   static constexpr int half_size = size / 2;
1066   T a_arr[size];
1067   T b_arr[size];
1068   T buffer1[size];
1069   T buffer2[size];
1070   a.store(static_cast<void*>(a_arr));
1071   b.store(static_cast<void*>(b_arr));
1072   for (const auto i : c10::irange(half_size)) {
1073     buffer1[i] = a_arr[i * 2];
1074     buffer1[half_size + i] = b_arr[i * 2];
1075     buffer2[i] = a_arr[i * 2 + 1];
1076     buffer2[half_size + i] = b_arr[i * 2 + 1];
1077   }
1078   return std::make_pair(Vectorized<T>::loadu(static_cast<void*>(buffer1)),
1079                         Vectorized<T>::loadu(static_cast<void*>(buffer2)));
1080 }
1081 
1082 // inverse operation of deinterleave2
1083 // Example inputs for AVX512:
1084 //  a       Vectorized<float>   = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15}
1085 //  b       Vectorized<float>   = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15}
1086 // returns, for AVX512:
1087 //          Vectorized<float>   = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7}
1088 //          Vectorized<float>   = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15}
1089 // Example inputs for AVX2 : a           Vectorized<float>   = {a0, a1, a2, a3, a4, a5, a6, a7}
1090 //                   b                   Vectorized<float>   = {b0, b1, b2, b3, b4, b5, b6, b7}
1091 //       returns:            Vectorized<float>   = {a0, b0, a1, b1, a2, b2, a3, b3}
1092 //                           Vectorized<float>   = {a4, b4, a5, b5, a6, b6, a7, b7}
1093 template <typename T>
1094 inline std::enable_if_t<Vectorized<T>::size() % 2 == 0, std::pair<Vectorized<T>, Vectorized<T>>>
1095 interleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
1096   static constexpr int size = Vectorized<T>::size();
1097   static constexpr int half_size = size / 2;
1098   T a_arr[size];
1099   T b_arr[size];
1100   T buffer1[size];
1101   T buffer2[size];
1102   a.store(static_cast<void*>(a_arr));
1103   b.store(static_cast<void*>(b_arr));
1104   for (const auto i : c10::irange(half_size)) {
1105     buffer1[i * 2] = a_arr[i];
1106     buffer1[i * 2 + 1] = b_arr[i];
1107     buffer2[i * 2] = a_arr[half_size + i];
1108     buffer2[i * 2 + 1] = b_arr[half_size + i];
1109   }
1110   return std::make_pair(Vectorized<T>::loadu(static_cast<void*>(buffer1)),
1111                         Vectorized<T>::loadu(static_cast<void*>(buffer2)));
1112 }
1113 
1114 template <typename src_T, typename dst_T>
1115 inline void convert(const src_T *src, dst_T *dst, int64_t n) {
1116 #ifndef _MSC_VER
1117 # pragma unroll
1118 #endif
1119   for (C10_UNUSED const auto i : c10::irange(n)) {
1120     *dst = c10::convert<dst_T>(c10::load(src));
1121     src++;
1122     dst++;
1123   }
1124 }
1125 
1126 template <typename T>
1127 inline Vectorized<T> flip(const Vectorized<T> & data) {
1128   static constexpr int size = Vectorized<T>::size();
1129   T output[size];
1130   T buffer[size];
1131   data.store(static_cast<void*>(buffer));
1132   for (const auto i : c10::irange(size)) {
1133     output[i] = buffer[size - i - 1];
1134   }
1135   return Vectorized<T>::loadu(static_cast<void*>(output));
1136 }
1137 
1138 // Transpose the `src` buffer of type `T` and size (M,N) into the `dst` buffer. `ld_src` is the leading
1139 // dimension of `src` and `ld_dst` is the leading dimension of `dst`.
// Element-by-element transpose: src(row, col) is written to dst(col, row) for
// every row in [0, M) and col in [0, N), using `ld_src` and `ld_dst` as the
// row strides of the two buffers.
template <typename T>
inline void transpose_mxn(const T* src, int64_t ld_src, T* dst, int64_t ld_dst, int M, int N) {
  for (int row = 0; row < M; ++row) {
    for (int col = 0; col < N; ++col) {
      const int64_t src_offset = row * ld_src + col;
      const int64_t dst_offset = col * ld_dst + row;
      dst[dst_offset] = src[src_offset];
    }
  }
}
1148 
// Compile-time-sized convenience overload: forwards to the runtime-sized
// transpose with M rows and N columns.
template <typename T, int M, int N>
inline void transpose_mxn(const T* src, int64_t ld_src, T* dst, int64_t ld_dst) {
  constexpr int rows = M;
  constexpr int cols = N;
  transpose_mxn<T>(src, ld_src, dst, ld_dst, rows, cols);
}
1153 
1154 }} // namespace at::vec::CPU_CAPABILITY
1155 
1156 // additional headers for more operations that depend on vec_base
1157 #include <ATen/cpu/vec/vec_n.h>
1158 #include <ATen/cpu/vec/vec_mask.h>
1159 #include <ATen/cpu/vec/vec_convert.h>
1160