1 #pragma once
2
3 // DO NOT DEFINE STATIC DATA IN THIS HEADER!
4 // See Note [Do not compile initializers with AVX]
5 //
6 // Note [Do not compile initializers with AVX]
7 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
8 // If you define a static initializer in this file, the initialization will use
9 // AVX instructions because these object files are compiled with AVX enabled.
10 // We need to avoid non-trivial global data in these architecture specific files
11 // because there's no way to guard the global initializers with CPU capability
12 // detection.
13 //
14 // See https://github.com/pytorch/pytorch/issues/37577 for an instance
15 // of this bug in the past.
16
17 #include <array>
18 #include <algorithm>
19 #include <cassert>
20 #include <cstring>
21 #include <functional>
22 #include <cmath>
23 #include <type_traits>
24 #include <climits>
25
26 #include <ATen/cpu/vec/intrinsics.h>
27 #include <ATen/native/Math.h>
28 #include <ATen/NumericUtils.h>
29 #include <c10/util/Half.h>
30 #include <c10/util/BFloat16.h>
31 #include <c10/util/BFloat16-math.h>
32 #include <c10/util/copysign.h>
33 #include <ATen/native/cpu/zmath.h>
34 #include <c10/util/TypeCast.h>
35 #include <c10/macros/Macros.h>
36 #include <c10/util/irange.h>
37 #include <c10/util/Load.h>
38
// Force-inline hint for hot helpers. GNU-compatible compilers and MSVC get
// their native spelling; any other toolchain falls back to a plain `inline`
// so that the macro is always defined and users of it still compile.
#if defined(__GNUC__)
#define __FORCE_INLINE __attribute__((always_inline)) inline
#elif defined(_MSC_VER)
#define __FORCE_INLINE __forceinline
#else
#define __FORCE_INLINE inline
#endif
44
45 #if defined(_MSC_FULL_VER)
46 /*
47 https://learn.microsoft.com/en-us/cpp/overview/compiler-versions?view=msvc-170
48 Use _MSC_FULL_VER to identify whether the current compiler is msvc;
49 Windows llvm will not have this definition.
50 */
51 #define __msvc_cl__
52 #endif
53
// These macros let a single vec_base.h serve every CPU capability:
//   VECTOR_WIDTH  - vector size in bytes
//   int_vector    - raw integer intrinsic type of that size
//   __at_align__  - alignment specifier equal to VECTOR_WIDTH (expands to
//                   nothing when neither the GNU nor the Windows spelling
//                   is available)
#ifdef CPU_CAPABILITY_AVX512
#define VECTOR_WIDTH 64
#define int_vector __m512i
#if defined(__GNUC__)
#define __at_align__ __attribute__((aligned(64)))
#elif defined(_WIN32)
#define __at_align__ __declspec(align(64))
#else
#define __at_align__
#endif
#else // CPU_CAPABILITY_AVX512
#define VECTOR_WIDTH 32
#define int_vector __m256i
#if defined(__GNUC__)
#define __at_align__ __attribute__((aligned(32)))
#elif defined(_WIN32)
#define __at_align__ __declspec(align(32))
#else
#define __at_align__
#endif
#endif // CPU_CAPABILITY_AVX512
76
77 namespace at::vec {
78 // See Note [CPU_CAPABILITY namespace]
79 inline namespace CPU_CAPABILITY {
80 // at::Half and at::BFloat16 should be treated as floating point
81 template <typename T>
82 struct is_floating_point:
83 std::integral_constant<bool,
84 std::is_floating_point_v<T> ||
85 std::is_same_v<T, at::Half> ||
86 std::is_same_v<T, at::BFloat16>> {
87 };
88
89 template<typename T>
90 constexpr bool is_floating_point_v = is_floating_point<T>::value;
91
92 template <typename T>
93 struct is_reduced_floating_point:
94 std::integral_constant<bool,
95 std::is_same_v<T, at::Half> ||
96 std::is_same_v<T, at::BFloat16>> {
97 };
98
99 template <typename T>
100 constexpr bool is_reduced_floating_point_v = is_reduced_floating_point<T>::value;
101
// True for `unsigned char` and `signed char`. Plain `char` is a distinct type
// and is therefore not matched.
template <typename T>
struct is_8bit_integer
    : std::bool_constant<
          std::is_same_v<T, unsigned char> ||
          std::is_same_v<T, signed char>> {};

// Variable-template shorthand for is_8bit_integer<T>::value.
template <typename T>
constexpr bool is_8bit_integer_v = is_8bit_integer<T>::value;
111
// Maps a size in bytes to the signed fixed-width integer type of that size.
// The primary template is intentionally left undefined, so asking for a size
// that matches none of int8_t/int16_t/int32_t/int64_t is a compile error.
template <size_t n>
struct int_of_size;

template <>
struct int_of_size<sizeof(int64_t)> {
  using type = int64_t;
};

template <>
struct int_of_size<sizeof(int32_t)> {
  using type = int32_t;
};

template <>
struct int_of_size<sizeof(int16_t)> {
  using type = int16_t;
};

template <>
struct int_of_size<sizeof(int8_t)> {
  using type = int8_t;
};

// Signed integer type occupying exactly as many bytes as T.
template <typename T>
using int_same_size_t = typename int_of_size<sizeof(T)>::type;
126
127 // NOTE: If you specialize on a type, you must define all operations!
128
129 // emulates Vectorized types
130 #if defined(__s390x__)
131 template <class T, class TEMP=void>
132 #else
133 template <class T>
134 #endif
135 struct Vectorized {
136 private:
137 __at_align__ T values[VECTOR_WIDTH / sizeof(T)];
138 public:
139 using value_type = T;
140 using size_type = int;
141 // Note [constexpr static function to avoid odr-usage compiler bug]
142 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
143 // Why, you might ask, is size defined to be a static constexpr function,
144 // rather than a more ordinary 'static constexpr int size;' variable?
145 // The problem lies within ODR rules for static constexpr members versus
146 // static constexpr functions. First, recall that this class (along with all
147 // of its derivations) live in an anonymous namespace: they are intended to be
148 // *completely* inlined at their use-sites, because we need to compile it
149 // multiple times for different instruction sets.
150 //
151 // Because of this constraint, we CANNOT provide a single definition for
152 // any static members in this class; since we want to compile the class
153 // multiple times, there wouldn't actually be any good place to put the
154 // definition. Now here is the problem: if we ODR-use a static constexpr
155 // member, we are *obligated* to provide a definition. Without the
156 // definition, you get a compile error like:
157 //
158 // relocation R_X86_64_PC32 against undefined symbol
159 // `_ZN2at6vec25612_GLOBAL__N_16VectorizedIdE4sizeE' can not be used when making
160 // a shared object; recompile with -fPIC
161 //
162 // With C++17 (which this file now relies on, e.g. std::is_same_v) an
163 // inline variable would also avoid the single-definition requirement. This
164 // note predates that option: the member was replaced with a static
165 // constexpr (and therefore inline) function, which does not require a
166 // separate definition either.
167 //
168 // Also, technically according to the C++ standard, we don't have to define
169 // a constexpr variable if we never odr-use it. But it seems that some
170 // versions of GCC/Clang have buggy determinations on whether or not an
171 // identifier is odr-used, and in any case it's hard to tell if
172 // a variable is odr-used or not. So best to just cut the problem at the root.
sizeVectorized173 static constexpr size_type size() {
174 return VECTOR_WIDTH / sizeof(T);
175 }
VectorizedVectorized176 Vectorized() : values{static_cast<T>(0)} {}
VectorizedVectorized177 Vectorized(T val) {
178 for (int i = 0; i != size(); i++) {
179 values[i] = val;
180 }
181 }
182 template<typename... Args,
183 typename = std::enable_if_t<(sizeof...(Args) == size())>>
VectorizedVectorized184 Vectorized(Args... vals) : values{vals...}{
185 }
186 // This also implies const T& operator[](int idx) const
187 inline operator const T*() const {
188 return values;
189 }
190 // This also implies T& operator[](int idx)
191 inline operator T*() {
192 return values;
193 }
194 // Return the values as char* for type punning
195 auto as_bytes() const -> const char* {
196 return reinterpret_cast<const char*>(values);
197 }
198 template <int64_t mask_>
blendVectorized199 static Vectorized<T> blend(const Vectorized<T>& a, const Vectorized<T>& b) {
200 int64_t mask = mask_;
201 Vectorized vector;
202 for (const auto i : c10::irange(size())) {
203 if (mask & 0x01) {
204 vector[i] = b[i];
205 } else {
206 vector[i] = a[i];
207 }
208 mask = mask >> 1;
209 }
210 return vector;
211 }
blendvVectorized212 static Vectorized<T> blendv(const Vectorized<T>& a, const Vectorized<T>& b,
213 const Vectorized<T>& mask) {
214 Vectorized vector;
215 int_same_size_t<T> buffer[size()];
216 mask.store(buffer);
217 for (const auto i : c10::irange(size())) {
218 if (buffer[i] & 0x01)
219 {
220 vector[i] = b[i];
221 } else {
222 vector[i] = a[i];
223 }
224 }
225 return vector;
226 }
227 template<typename step_t> // step sometimes requires a higher precision type (e.g., T=int, step_t=double)
228 static Vectorized<T> arange(T base = static_cast<T>(0), step_t step = static_cast<step_t>(1)) {
229 Vectorized vector;
230 for (const auto i : c10::irange(size())) {
231 vector.values[i] = base + i * step;
232 }
233 return vector;
234 }
235 static Vectorized<T> set(const Vectorized<T>& a, const Vectorized<T>& b, int64_t count = size()) {
236 Vectorized vector;
237 for (const auto i : c10::irange(size())) {
238 if (i < count) {
239 vector[i] = b[i];
240 } else {
241 vector[i] = a[i];
242 }
243 }
244 return vector;
245 }
loaduVectorized246 static Vectorized<T> loadu(const void* ptr) {
247 Vectorized vector;
248 std::memcpy(vector.values, ptr, VECTOR_WIDTH);
249 return vector;
250 }
loaduVectorized251 static Vectorized<T> loadu(const void* ptr, int64_t count) {
252 Vectorized vector;
253 std::memcpy(vector.values, ptr, count * sizeof(T));
254 return vector;
255 }
loadu_one_fourthVectorized256 static Vectorized<T> loadu_one_fourth(const void* ptr) {
257 static_assert(std::is_same_v<T, signed char> || std::is_same_v<T, unsigned char>, "For byte types only");
258 return Vectorized::loadu(ptr, 8);
259 }
260
261 void store(void* ptr, int count = size()) const {
262 std::memcpy(ptr, values, count * sizeof(T));
263 }
zero_maskVectorized264 int zero_mask() const {
265 // returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit
266 int mask = 0;
267 for (int i = 0; i < size(); ++ i) {
268 if (values[i] == static_cast<T>(0)) {
269 mask |= (1 << i);
270 }
271 }
272 return mask;
273 }
isnanVectorized274 Vectorized<T> isnan() const {
275 Vectorized<T> vector;
276 for (int64_t i = 0; i != size(); i++) {
277 if (_isnan(values[i])) {
278 std::memset(static_cast<void*>(vector.values + i), 0xFF, sizeof(T));
279 } else {
280 std::memset(static_cast<void*>(vector.values + i), 0, sizeof(T));
281 }
282 }
283 return vector;
284 }
has_inf_nanVectorized285 bool has_inf_nan() const {
286 for (int64_t i = 0; i != size(); i++) {
287 if(_isnan(values[i]) || _isinf(values[i])) {
288 return true;
289 }
290 }
291 return false;
292 }
mapVectorized293 Vectorized<T> map(T (*const f)(T)) const {
294 Vectorized<T> ret;
295 for (int64_t i = 0; i != size(); i++) {
296 ret[i] = f(values[i]);
297 }
298 return ret;
299 }
mapVectorized300 Vectorized<T> map(T (*const f)(const T &)) const {
301 Vectorized<T> ret;
302 for (int64_t i = 0; i != size(); i++) {
303 ret[i] = f(values[i]);
304 }
305 return ret;
306 }
307 template <typename other_t_abs = T,
308 typename std::enable_if_t<!is_floating_point_v<other_t_abs> && !c10::is_complex<other_t_abs>::value, int> = 0>
absVectorized309 Vectorized<T> abs() const {
310 // other_t_abs is for SFINAE and clarity. Make sure it is not changed.
311 static_assert(std::is_same_v<other_t_abs, T>, "other_t_abs must be T");
312 return map([](T x) -> T { return x < static_cast<T>(0) ? -x : x; });
313 }
314 template <typename float_t_abs = T,
315 typename std::enable_if_t<is_floating_point_v<float_t_abs>, int> = 0>
absVectorized316 Vectorized<T> abs() const {
317 // float_t_abs is for SFINAE and clarity. Make sure it is not changed.
318 static_assert(std::is_same_v<float_t_abs, T>, "float_t_abs must be T");
319 // Specifically deal with floating-point because the generic code above won't handle -0.0 (which should result in
320 // 0.0) properly.
321 return map([](T x) -> T { return std::abs(x); });
322 }
323 template <typename complex_t_abs = T,
324 typename std::enable_if_t<c10::is_complex<complex_t_abs>::value, int> = 0>
absVectorized325 Vectorized<T> abs() const {
326 // complex_t_abs is for SFINAE and clarity. Make sure it is not changed.
327 static_assert(std::is_same_v<complex_t_abs, T>, "complex_t_abs must be T");
328 // Specifically map() does not perform the type conversion needed by abs.
329 return map([](T x) { return static_cast<T>(std::abs(x)); });
330 }
331
332 template <typename other_t_sgn = T,
333 typename std::enable_if_t<c10::is_complex<other_t_sgn>::value, int> = 0>
sgnVectorized334 Vectorized<T> sgn() const {
335 return map(at::native::sgn_impl);
336 }
337
338 template <typename other_t_angle = T,
339 typename std::enable_if_t<!c10::is_complex<other_t_angle>::value, int> = 0>
angleVectorized340 Vectorized<T> angle() const {
341 // other_t_angle is for SFINAE and clarity. Make sure it is not changed.
342 static_assert(std::is_same_v<other_t_angle, T>, "other_t_angle must be T");
343 return map(at::native::angle_impl<T>); // compiler is unable to resolve the overload without <T>
344 }
345 template <typename complex_t_angle = T,
346 typename std::enable_if_t<c10::is_complex<complex_t_angle>::value, int> = 0>
angleVectorized347 Vectorized<T> angle() const {
348 // complex_t_angle is for SFINAE and clarity. Make sure it is not changed.
349 static_assert(std::is_same_v<complex_t_angle, T>, "complex_t_angle must be T");
350 return map([](T x) { return static_cast<T>(std::arg(x)); });
351 }
352 template <typename other_t_real = T,
353 typename std::enable_if_t<!c10::is_complex<other_t_real>::value, int> = 0>
realVectorized354 Vectorized<T> real() const {
355 // other_t_real is for SFINAE and clarity. Make sure it is not changed.
356 static_assert(std::is_same_v<other_t_real, T>, "other_t_real must be T");
357 return *this;
358 }
359 template <typename complex_t_real = T,
360 typename std::enable_if_t<c10::is_complex<complex_t_real>::value, int> = 0>
realVectorized361 Vectorized<T> real() const {
362 // complex_t_real is for SFINAE and clarity. Make sure it is not changed.
363 static_assert(std::is_same_v<complex_t_real, T>, "complex_t_real must be T");
364 return map([](T x) { return static_cast<T>(x.real()); });
365 }
366 template <typename other_t_imag = T,
367 typename std::enable_if_t<!c10::is_complex<other_t_imag>::value, int> = 0>
imagVectorized368 Vectorized<T> imag() const {
369 // other_t_imag is for SFINAE and clarity. Make sure it is not changed.
370 static_assert(std::is_same_v<other_t_imag, T>, "other_t_imag must be T");
371 return Vectorized(0);
372 }
373 template <typename complex_t_imag = T,
374 typename std::enable_if_t<c10::is_complex<complex_t_imag>::value, int> = 0>
imagVectorized375 Vectorized<T> imag() const {
376 // complex_t_imag is for SFINAE and clarity. Make sure it is not changed.
377 static_assert(std::is_same_v<complex_t_imag, T>, "complex_t_imag must be T");
378 return map([](T x) { return static_cast<T>(x.imag()); });
379 }
380 template <typename other_t_conj = T,
381 typename std::enable_if_t<!c10::is_complex<other_t_conj>::value, int> = 0>
conjVectorized382 Vectorized<T> conj() const {
383 // other_t_conj is for SFINAE and clarity. Make sure it is not changed.
384 static_assert(std::is_same_v<other_t_conj, T>, "other_t_conj must be T");
385 return *this;
386 }
387 template <typename complex_t_conj = T,
388 typename std::enable_if_t<c10::is_complex<complex_t_conj>::value, int> = 0>
conjVectorized389 Vectorized<T> conj() const {
390 // complex_t_conj is for SFINAE and clarity. Make sure it is not changed.
391 static_assert(std::is_same_v<complex_t_conj, T>, "complex_t_conj must be T");
392 return map([](T x) { return static_cast<T>(std::conj(x)); });
393 }
acosVectorized394 Vectorized<T> acos() const {
395 return map(std::acos);
396 }
acoshVectorized397 Vectorized<T> acosh() const {
398 return map(std::acosh);
399 }
asinVectorized400 Vectorized<T> asin() const {
401 return map(std::asin);
402 }
atanVectorized403 Vectorized<T> atan() const {
404 return map(std::atan);
405 }
atanhVectorized406 Vectorized<T> atanh() const {
407 return map(std::atanh);
408 }
atan2Vectorized409 Vectorized<T> atan2(const Vectorized<T> &exp) const {
410 Vectorized<T> ret;
411 for (const auto i : c10::irange(size())) {
412 ret[i] = std::atan2(values[i], exp[i]);
413 }
414 return ret;
415 }
416 template <
417 typename U = T,
418 typename std::enable_if_t<is_floating_point_v<U>, int> = 0>
copysignVectorized419 Vectorized<T> copysign(const Vectorized<T> &sign) const {
420 Vectorized<T> ret;
421 for (size_type i = 0; i < size(); i++) {
422 ret[i] = c10::copysign(values[i], sign[i]);
423 }
424 return ret;
425 }
erfVectorized426 Vectorized<T> erf() const {
427 return map(std::erf);
428 }
erfcVectorized429 Vectorized<T> erfc() const {
430 return map(std::erfc);
431 }
erfinvVectorized432 Vectorized<T> erfinv() const {
433 return map(calc_erfinv);
434 }
expVectorized435 Vectorized<T> exp() const {
436 return map(std::exp);
437 }
exp2Vectorized438 Vectorized<T> exp2() const {
439 return map(exp2_impl);
440 }
expm1Vectorized441 Vectorized<T> expm1() const {
442 return map(std::expm1);
443 }
exp_u20Vectorized444 Vectorized<T> exp_u20() const {
445 return map(std::exp);
446 }
fracVectorized447 Vectorized<T> frac() const {
448 return *this - this->trunc();
449 }
450 template <
451 typename U = T,
452 typename std::enable_if_t<is_floating_point_v<U>, int> = 0>
fmodVectorized453 Vectorized<T> fmod(const Vectorized<T>& q) const {
454 // U is for SFINAE purposes only. Make sure it is not changed.
455 static_assert(std::is_same_v<U, T>, "U must be T");
456 Vectorized<T> ret;
457 for (const auto i : c10::irange(size())) {
458 ret[i] = std::fmod(values[i], q[i]);
459 }
460 return ret;
461 }
logVectorized462 Vectorized<T> log() const {
463 return map(std::log);
464 }
log10Vectorized465 Vectorized<T> log10() const {
466 return map(std::log10);
467 }
log1pVectorized468 Vectorized<T> log1p() const {
469 return map(std::log1p);
470 }
471 template <typename other_t_log2 = T,
472 typename std::enable_if_t<!c10::is_complex<other_t_log2>::value, int> = 0>
log2Vectorized473 Vectorized<T> log2() const {
474 // other_t_log2 is for SFINAE and clarity. Make sure it is not changed.
475 static_assert(std::is_same_v<other_t_log2, T>, "other_t_log2 must be T");
476 return map(std::log2);
477 }
478 template <typename complex_t_log2 = T,
479 typename std::enable_if_t<c10::is_complex<complex_t_log2>::value, int> = 0>
log2Vectorized480 Vectorized<T> log2() const {
481 // complex_t_log2 is for SFINAE and clarity. Make sure it is not changed.
482 static_assert(std::is_same_v<complex_t_log2, T>, "complex_t_log2 must be T");
483 const T log_2 = T(std::log(2.0));
484 return Vectorized(map(std::log))/Vectorized(log_2);
485 }
ceilVectorized486 Vectorized<T> ceil() const {
487 return map(at::native::ceil_impl);
488 }
cosVectorized489 Vectorized<T> cos() const {
490 return map(std::cos);
491 }
coshVectorized492 Vectorized<T> cosh() const {
493 return map(std::cosh);
494 }
floorVectorized495 Vectorized<T> floor() const {
496 return map(at::native::floor_impl);
497 }
hypotVectorized498 Vectorized<T> hypot(const Vectorized<T> &b) const {
499 Vectorized<T> ret;
500 for (const auto i : c10::irange(size())) {
501 ret[i] = std::hypot(values[i], b[i]);
502 }
503 return ret;
504 }
i0Vectorized505 Vectorized<T> i0() const {
506 return map(calc_i0);
507 }
i0eVectorized508 Vectorized<T> i0e() const {
509 return map(calc_i0e);
510 }
digammaVectorized511 Vectorized<T> digamma() const {
512 return map(calc_digamma);
513 }
igammaVectorized514 Vectorized<T> igamma(const Vectorized<T> &x) const {
515 Vectorized<T> ret;
516 for (const auto i : c10::irange(size())) {
517 ret[i] = calc_igamma(values[i], x[i]);
518 }
519 return ret;
520 }
igammacVectorized521 Vectorized<T> igammac(const Vectorized<T> &x) const {
522 Vectorized<T> ret;
523 for (const auto i : c10::irange(size())) {
524 ret[i] = calc_igammac(values[i], x[i]);
525 }
526 return ret;
527 }
negVectorized528 Vectorized<T> neg() const {
529 // NB: the trailing return type is needed because we need to coerce the
530 // return value back to T in the case of unary operator- incuring a
531 // promotion
532 return map([](T x) -> T { return -x; });
533 }
nextafterVectorized534 Vectorized<T> nextafter(const Vectorized<T> &b) const {
535 Vectorized<T> ret;
536 for (const auto i : c10::irange(size())) {
537 ret[i] = std::nextafter(values[i], b[i]);
538 }
539 return ret;
540 }
roundVectorized541 Vectorized<T> round() const {
542 // We do not use std::round because we would like to round midway numbers to the nearest even integer.
543 return map(at::native::round_impl);
544 }
sinVectorized545 Vectorized<T> sin() const {
546 return map(std::sin);
547 }
sinhVectorized548 Vectorized<T> sinh() const {
549 return map(std::sinh);
550 }
tanVectorized551 Vectorized<T> tan() const {
552 return map(std::tan);
553 }
tanhVectorized554 Vectorized<T> tanh() const {
555 return map(std::tanh);
556 }
truncVectorized557 Vectorized<T> trunc() const {
558 return map(at::native::trunc_impl);
559 }
lgammaVectorized560 Vectorized<T> lgamma() const {
561 return map(std::lgamma);
562 }
sqrtVectorized563 Vectorized<T> sqrt() const {
564 return map(std::sqrt);
565 }
reciprocalVectorized566 Vectorized<T> reciprocal() const {
567 return map([](T x) { return (T)(1) / x; });
568 }
rsqrtVectorized569 Vectorized<T> rsqrt() const {
570 return map([](T x) { return (T)1 / std::sqrt(x); });
571 }
powVectorized572 Vectorized<T> pow(const Vectorized<T> &exp) const {
573 Vectorized<T> ret;
574 for (const auto i : c10::irange(size())) {
575 ret[i] = std::pow(values[i], exp[i]);
576 }
577 return ret;
578 }
579 private:
580 template <typename Op>
binary_predVectorized581 inline Vectorized<T> binary_pred(const Vectorized<T>& other, Op op) const {
582 // All bits are set to 1 if the pred is true, otherwise 0.
583 Vectorized<T> vector;
584 for (int64_t i = 0; i != size(); i++) {
585 if (op(values[i], other.values[i])) {
586 std::memset(static_cast<void*>(vector.values + i), 0xFF, sizeof(T));
587 } else {
588 std::memset(static_cast<void*>(vector.values + i), 0, sizeof(T));
589 }
590 }
591 return vector;
592 }
593
594 public:
595 Vectorized<T> operator==(const Vectorized<T>& other) const { return binary_pred(other, std::equal_to<T>()); }
596 Vectorized<T> operator!=(const Vectorized<T>& other) const { return binary_pred(other, std::not_equal_to<T>()); }
597 Vectorized<T> operator>=(const Vectorized<T>& other) const { return binary_pred(other, std::greater_equal<T>()); }
598 Vectorized<T> operator<=(const Vectorized<T>& other) const { return binary_pred(other, std::less_equal<T>()); }
599 Vectorized<T> operator>(const Vectorized<T>& other) const { return binary_pred(other, std::greater<T>()); }
600 Vectorized<T> operator<(const Vectorized<T>& other) const { return binary_pred(other, std::less<T>()); }
601
602 private:
603 template <typename Op>
binary_pred_boolVectorized604 inline Vectorized<T> binary_pred_bool(const Vectorized<T>& other, Op op) const {
605 // 1 if the pred is true, otherwise 0.
606 Vectorized<T> vector;
607 for (int i = 0; i != size(); ++ i) {
608 vector[i] = static_cast<T>(op(values[i], other.values[i]));
609 }
610 return vector;
611 }
612
613 public:
eqVectorized614 Vectorized<T> eq(const Vectorized<T>& other) const { return binary_pred_bool(other, std::equal_to<T>()); }
neVectorized615 Vectorized<T> ne(const Vectorized<T>& other) const { return binary_pred_bool(other, std::not_equal_to<T>()); }
gtVectorized616 Vectorized<T> gt(const Vectorized<T>& other) const { return binary_pred_bool(other, std::greater<T>()); }
geVectorized617 Vectorized<T> ge(const Vectorized<T>& other) const { return binary_pred_bool(other, std::greater_equal<T>()); }
ltVectorized618 Vectorized<T> lt(const Vectorized<T>& other) const { return binary_pred_bool(other, std::less<T>()); }
leVectorized619 Vectorized<T> le(const Vectorized<T>& other) const { return binary_pred_bool(other, std::less_equal<T>()); }
620 };
621
622 template <class T> Vectorized<T> inline operator+(const Vectorized<T> &a, const Vectorized<T> &b) {
623 Vectorized<T> c;
624 for (int i = 0; i != Vectorized<T>::size(); i++) {
625 c[i] = a[i] + b[i];
626 }
627 return c;
628 }
629
630 template <class T> Vectorized<T> inline operator-(const Vectorized<T> &a, const Vectorized<T> &b) {
631 Vectorized<T> c;
632 for (int i = 0; i != Vectorized<T>::size(); i++) {
633 c[i] = a[i] - b[i];
634 }
635 return c;
636 }
637
638 template <class T> Vectorized<T> inline operator*(const Vectorized<T> &a, const Vectorized<T> &b) {
639 Vectorized<T> c;
640 for (int i = 0; i != Vectorized<T>::size(); i++) {
641 c[i] = a[i] * b[i];
642 }
643 return c;
644 }
645
646 template <class T> Vectorized<T> inline operator/(const Vectorized<T> &a, const Vectorized<T> &b) __ubsan_ignore_float_divide_by_zero__ {
647 Vectorized<T> c;
648 for (int i = 0; i != Vectorized<T>::size(); i++) {
649 c[i] = a[i] / b[i];
650 }
651 return c;
652 }
653
654 template <class T,
655 typename std::enable_if_t<!is_floating_point_v<T>, int> = 0>
656 Vectorized<T> inline operator%(const Vectorized<T> &a, const Vectorized<T> &b) __ubsan_ignore_float_divide_by_zero__ {
657 return a - a / b * b;
658 }
659
660 template <class T> Vectorized<T> inline operator||(
661 const Vectorized<T> &a, const Vectorized<T> &b) {
662 Vectorized<T> c;
663 for (int i = 0; i != Vectorized<T>::size(); i++) {
664 c[i] = a[i] || b[i];
665 }
666 return c;
667 }
668
669 // Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
670 // either input is a NaN.
671 template <class T,
672 typename std::enable_if_t<!c10::is_complex<T>::value, int> = 0>
maximum(const Vectorized<T> & a,const Vectorized<T> & b)673 Vectorized<T> inline maximum(const Vectorized<T> &a, const Vectorized<T> &b) {
674 Vectorized<T> c;
675 for (int i = 0; i != Vectorized<T>::size(); i++) {
676 c[i] = (a[i] > b[i]) ? a[i] : b[i];
677 if (_isnan(a[i])) {
678 // If either input is NaN, propagate a NaN.
679 // NOTE: The case where b[i] was NaN is handled correctly by the naive
680 // ternary operator above.
681 c[i] = a[i];
682 }
683 }
684 return c;
685 }
686
687 template <class T,
688 typename std::enable_if_t<c10::is_complex<T>::value, int> = 0>
maximum(const Vectorized<T> & a,const Vectorized<T> & b)689 Vectorized<T> inline maximum(const Vectorized<T> &a, const Vectorized<T> &b) {
690 Vectorized<T> c;
691 for (int i = 0; i != Vectorized<T>::size(); i++) {
692 c[i] = (std::abs(a[i]) > std::abs(b[i])) ? a[i] : b[i];
693 if (_isnan(a[i])) {
694 // If either input is NaN, propagate a NaN.
695 // NOTE: The case where b[i] was NaN is handled correctly by the naive
696 // ternary operator above.
697 c[i] = a[i];
698 }
699 }
700 return c;
701 }
702
703 // Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
704 // either input is a NaN.
705 template <class T,
706 typename std::enable_if_t<!c10::is_complex<T>::value, int> = 0>
minimum(const Vectorized<T> & a,const Vectorized<T> & b)707 Vectorized<T> inline minimum(const Vectorized<T> &a, const Vectorized<T> &b) {
708 Vectorized<T> c;
709 for (int i = 0; i != Vectorized<T>::size(); i++) {
710 c[i] = (a[i] < b[i]) ? a[i] : b[i];
711 if (_isnan(a[i])) {
712 // If either input is NaN, propagate a NaN.
713 // NOTE: The case where b[i] was NaN is handled correctly by the naive
714 // ternary operator above.
715 c[i] = a[i];
716 }
717 }
718 return c;
719 }
720
721 template <class T,
722 typename std::enable_if_t<c10::is_complex<T>::value, int> = 0>
minimum(const Vectorized<T> & a,const Vectorized<T> & b)723 Vectorized<T> inline minimum(const Vectorized<T> &a, const Vectorized<T> &b) {
724 Vectorized<T> c;
725 for (int i = 0; i != Vectorized<T>::size(); i++) {
726 c[i] = (std::abs(a[i]) < std::abs(b[i])) ? a[i] : b[i];
727 if (_isnan(a[i])) {
728 // If either input is NaN, propagate a NaN.
729 // NOTE: The case where b[i] was NaN is handled correctly by the naive
730 // ternary operator above.
731 c[i] = a[i];
732 }
733 }
734 return c;
735 }
736
737 template <class T,
738 typename std::enable_if_t<!c10::is_complex<T>::value, int> = 0>
clamp(const Vectorized<T> & a,const Vectorized<T> & min_vec,const Vectorized<T> & max_vec)739 Vectorized<T> inline clamp(const Vectorized<T> &a, const Vectorized<T> &min_vec, const Vectorized<T> &max_vec) {
740 Vectorized<T> c;
741 for (int i = 0; i != Vectorized<T>::size(); i++) {
742 c[i] = std::min(std::max(a[i], min_vec[i]), max_vec[i]);
743 }
744 return c;
745 }
746
747 template <class T,
748 typename std::enable_if_t<!c10::is_complex<T>::value, int> = 0>
clamp_max(const Vectorized<T> & a,const Vectorized<T> & max_vec)749 Vectorized<T> inline clamp_max(const Vectorized<T> &a, const Vectorized<T> &max_vec) {
750 Vectorized<T> c;
751 for (int i = 0; i != Vectorized<T>::size(); i++) {
752 c[i] = a[i] > max_vec[i] ? max_vec[i] : a[i];
753 }
754 return c;
755 }
756
757 template <class T,
758 typename std::enable_if_t<!c10::is_complex<T>::value, int> = 0>
clamp_min(const Vectorized<T> & a,const Vectorized<T> & min_vec)759 Vectorized<T> inline clamp_min(const Vectorized<T> &a, const Vectorized<T> &min_vec) {
760 Vectorized<T> c;
761 for (int i = 0; i != Vectorized<T>::size(); i++) {
762 c[i] = a[i] < min_vec[i] ? min_vec[i] : a[i];
763 }
764 return c;
765 }
766
767 struct Vectorizedi;
768
769 #if defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)
770 template <class T, typename Op>
bitwise_binary_op(const Vectorized<T> & a,const Vectorized<T> & b,Op op)771 static inline Vectorized<T> bitwise_binary_op(const Vectorized<T> &a, const Vectorized<T> &b, Op op) {
772 int_vector buffer;
773 #if defined(CPU_CAPABILITY_AVX2)
774 int_vector a_buffer = _mm256_load_si256(reinterpret_cast<const int_vector*>((const T*)a));
775 int_vector b_buffer = _mm256_load_si256(reinterpret_cast<const int_vector*>((const T*)b));
776 #elif defined(CPU_CAPABILITY_AVX512)
777 int_vector a_buffer = _mm512_load_si512(reinterpret_cast<const int_vector*>((const T*)a));
778 int_vector b_buffer = _mm512_load_si512(reinterpret_cast<const int_vector*>((const T*)b));
779 #endif
780 buffer = op(a_buffer, b_buffer);
781 __at_align__ T results[Vectorized<T>::size()];
782
783 #if defined(CPU_CAPABILITY_AVX2)
784 _mm256_store_si256(reinterpret_cast<int_vector*>(results), buffer);
785 #elif defined(CPU_CAPABILITY_AVX512)
786 _mm512_store_si512(reinterpret_cast<int_vector*>(results), buffer);
787 #endif
788 return Vectorized<T>::loadu(results);
789 }
790
791 template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
792 inline Vectorized<T> operator&(const Vectorized<T>& a, const Vectorized<T>& b) {
793 // We enclose _mm512_and_si512 or _mm256_and_si256 with lambda because it is always_inline
794 #if defined(CPU_CAPABILITY_AVX2)
795 return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_and_si256(a, b); });
796 #elif defined(CPU_CAPABILITY_AVX512)
797 return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_and_si512(a, b); });
798 #endif
799 }
800 template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
801 inline Vectorized<T> operator|(const Vectorized<T>& a, const Vectorized<T>& b) {
802 // We enclose _mm512_or_si512 or _mm256_or_si256 with lambda because it is always_inline
803 #if defined(CPU_CAPABILITY_AVX2)
804 return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_or_si256(a, b); });
805 #elif defined(CPU_CAPABILITY_AVX512)
806 return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_or_si512(a, b); });
807 #endif
808 }
809 template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
810 inline Vectorized<T> operator^(const Vectorized<T>& a, const Vectorized<T>& b) {
811 // We enclose _mm512_xor_si512 or _mm256_xor_si256 with lambda because it is always_inline
812 #if defined(CPU_CAPABILITY_AVX2)
813 return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_xor_si256(a, b); });
814 #elif defined(CPU_CAPABILITY_AVX512)
815 return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_xor_si512(a, b); });
816 #endif
817 }
818
819 #else
820
// Reads a T out of raw bytes by copying sizeof(T) bytes from `data` into a
// local object, which avoids dereferencing a reinterpreted pointer.
template <typename T>
T load(const char* data) {
  T value;
  std::memcpy(&value, data, sizeof(T));
  return value;
}
827
828 template<class T, typename Op>
bitwise_binary_op(const Vectorized<T> & a,const Vectorized<T> & b,Op op)829 static inline Vectorized<T> bitwise_binary_op(const Vectorized<T> &a, const Vectorized<T> &b, Op op) {
830 static constexpr uint32_t element_no = VECTOR_WIDTH / sizeof(intmax_t);
831 __at_align__ intmax_t buffer[element_no];
832 static_assert(VECTOR_WIDTH % sizeof(intmax_t) == 0, "VECTOR_WIDTH not a multiple of sizeof(intmax_t)");
833 static_assert(sizeof(buffer) == sizeof(Vectorized<T>), "sizeof(buffer) must match sizeof(Vectorized<T>)");
834 // We should be using memcpy in order to respect the strict aliasing rule
835 // see: https://github.com/pytorch/pytorch/issues/66119
836 // Using char* is defined in the C11 standard 6.5 Expression paragraph 7
837 // (http://www.open-std.org/jtc1/sc22/wg14/www/docs/n1570.pdf)
838 const auto* a_data = a.as_bytes();
839 const auto* b_data = b.as_bytes();
840 // load each intmax_t chunk and process; increase pointers by sizeof(intmax_t)
841 for (auto& out : buffer) {
842 out = op(load<intmax_t>(a_data), load<intmax_t>(b_data));
843 a_data += sizeof(intmax_t);
844 b_data += sizeof(intmax_t);
845 }
846 assert(a_data == a.as_bytes() + sizeof(a));
847 assert(b_data == b.as_bytes() + sizeof(b));
848 return Vectorized<T>::loadu(buffer);
849 }
850
851 template<class T, typename std::enable_if_t<!std::is_base_of_v<Vectorizedi, Vectorized<T>>, int> = 0>
852 inline Vectorized<T> operator&(const Vectorized<T>& a, const Vectorized<T>& b) {
853 return bitwise_binary_op(a, b, std::bit_and<intmax_t>());
854 }
855 template<class T, typename std::enable_if_t<!std::is_base_of_v<Vectorizedi, Vectorized<T>>, int> = 0>
856 inline Vectorized<T> operator|(const Vectorized<T>& a, const Vectorized<T>& b) {
857 return bitwise_binary_op(a, b, std::bit_or<intmax_t>());
858 }
859 template<class T, typename std::enable_if_t<!std::is_base_of_v<Vectorizedi, Vectorized<T>>, int> = 0>
860 inline Vectorized<T> operator^(const Vectorized<T>& a, const Vectorized<T>& b) {
861 return bitwise_binary_op(a, b, std::bit_xor<intmax_t>());
862 }
863
864 #endif // defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)
865
866 template<class T, typename std::enable_if_t<!std::is_base_of_v<Vectorizedi, Vectorized<T>>, int> = 0>
867 inline Vectorized<T> operator~(const Vectorized<T>& a) {
868 using int_t = int_same_size_t<T>;
869 Vectorized<T> ones(c10::bit_cast<T>((int_t)(~(int_t)0))); // All bits are 1
870 return a ^ ones;
871 }
872
873 template <class T> Vectorized<T> inline operator<<(const Vectorized<T> &a, const Vectorized<T> &b) {
874 constexpr T max_shift = sizeof(T) * CHAR_BIT;
875 Vectorized<T> c;
876 for (int i = 0; i != Vectorized<T>::size(); i++) {
877 T shift = b[i];
878 if ((static_cast<std::make_signed_t<T>>(shift) < 0) || (shift >= max_shift)) {
879 c[i] = 0;
880 } else {
881 c[i] = static_cast<std::make_unsigned_t<T>>(a[i]) << shift;
882 }
883 }
884 return c;
885 }
886
887 template <class T> Vectorized<T> inline operator>>(const Vectorized<T> &a, const Vectorized<T> &b) {
888 // right shift value to retain sign bit for signed and no bits for unsigned
889 constexpr T max_shift = sizeof(T) * CHAR_BIT - std::is_signed_v<T>;
890 Vectorized<T> c;
891 for (int i = 0; i != Vectorized<T>::size(); i++) {
892 T shift = b[i];
893 if ((static_cast<std::make_signed_t<T>>(shift) < 0) || (shift >= max_shift)) {
894 c[i] = a[i] >> max_shift;
895 } else {
896 c[i] = a[i] >> shift;
897 }
898 }
899 return c;
900 }
901
902 template <typename T>
903 inline Vectorized<T>& operator += (Vectorized<T>& a, const Vectorized<T>& b) {
904 a = a + b;
905 return a;
906 }
907 template <typename T>
908 inline Vectorized<T>& operator -= (Vectorized<T>& a, const Vectorized<T>& b) {
909 a = a - b;
910 return a;
911 }
912 template <typename T>
913 inline Vectorized<T>& operator /= (Vectorized<T>& a, const Vectorized<T>& b) {
914 a = a / b;
915 return a;
916 }
917 template <typename T>
918 inline Vectorized<T>& operator %= (Vectorized<T>& a, const Vectorized<T>& b) {
919 a = a % b;
920 return a;
921 }
922 template <typename T>
923 inline Vectorized<T>& operator *= (Vectorized<T>& a, const Vectorized<T>& b) {
924 a = a * b;
925 return a;
926 }
927
928 template <typename T>
929 inline Vectorized<T>& operator <<= (Vectorized<T>& a, const Vectorized<T>& b) {
930 a = a << b;
931 return a;
932 }
933
934 template <typename T>
935 inline Vectorized<T>& operator >>= (Vectorized<T>& a, const Vectorized<T>& b) {
936 a = a >> b;
937 return a;
938 }
939
940 template <typename T>
fmadd(const Vectorized<T> & a,const Vectorized<T> & b,const Vectorized<T> & c)941 inline Vectorized<T> fmadd(const Vectorized<T>& a, const Vectorized<T>& b, const Vectorized<T>& c) {
942 return a * b + c;
943 }
944
945 template <typename T>
fmsub(const Vectorized<T> & a,const Vectorized<T> & b,const Vectorized<T> & c)946 inline Vectorized<T> fmsub(const Vectorized<T>& a, const Vectorized<T>& b, const Vectorized<T>& c) {
947 return a * b - c;
948 }
949
950 template <typename T>
951 Vectorized<T> inline operator&&(
952 const Vectorized<T>& a,
953 const Vectorized<T>& b) {
954 Vectorized<T> ret;
955 for (int i = 0; i != Vectorized<T>::size(); i++) {
956 ret[i] = a[i] && b[i];
957 }
958 return ret;
959 }
960
961 template <int64_t scale = 1, typename T = void>
962 std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<T>>
gather(T const * base_addr,const Vectorized<int_same_size_t<T>> & vindex)963 inline gather(T const* base_addr, const Vectorized<int_same_size_t<T>>& vindex) {
964 static constexpr int size = Vectorized<T>::size();
965 int_same_size_t<T> index_arr[size];
966 vindex.store(static_cast<void*>(index_arr));
967 T buffer[size];
968 for (const auto i : c10::irange(size)) {
969 buffer[i] = base_addr[index_arr[i] * scale / sizeof(T)];
970 }
971 return Vectorized<T>::loadu(static_cast<void*>(buffer));
972 }
973
974 template <int64_t scale = 1, typename T = void>
975 std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<T>>
mask_gather(const Vectorized<T> & src,T const * base_addr,const Vectorized<int_same_size_t<T>> & vindex,Vectorized<T> & mask)976 inline mask_gather(const Vectorized<T>& src, T const* base_addr,
977 const Vectorized<int_same_size_t<T>>& vindex, Vectorized<T>& mask) {
978 static constexpr int size = Vectorized<T>::size();
979 T src_arr[size];
980 int_same_size_t<T> mask_arr[size]; // use int type so we can logical and
981 int_same_size_t<T> index_arr[size];
982 src.store(static_cast<void*>(src_arr));
983 mask.store(static_cast<void*>(mask_arr));
984 vindex.store(static_cast<void*>(index_arr));
985 T buffer[size];
986 for (const auto i : c10::irange(size)) {
987 if (mask_arr[i] & 0x01) { // check highest bit
988 buffer[i] = base_addr[index_arr[i] * scale / sizeof(T)];
989 } else {
990 buffer[i] = src_arr[i];
991 }
992 }
993 mask = Vectorized<T>(); // "zero out" mask
994 return Vectorized<T>::loadu(static_cast<void*>(buffer));
995 }
996
997 // Cast a given vector to another type without changing the bits representation.
998 // So a Vectorized<double> of 512 bits containing all ones can be cast to a
999 // Vectorized<int64_t> of 512 bits containing all ones (i.e., eight negative 1s).
1000 // A Vec<double> of 256 bits containing all ones can be cast to a
1001 // Vec<int64_t> of 256 bits containing all ones (i.e., four negative 1s).
1002 // There is a struct here because we don't have static_if and I can't
1003 // partially specialize a templated function.
1004 template<typename dst_t, typename src_t>
1005 struct CastImpl {
applyCastImpl1006 static inline Vectorized<dst_t> apply(const Vectorized<src_t>& src) {
1007 src_t src_arr[Vectorized<src_t>::size()];
1008 src.store(static_cast<void*>(src_arr));
1009 return Vectorized<dst_t>::loadu(static_cast<const void*>(src_arr));
1010 }
1011 };
1012
1013 template<typename scalar_t>
1014 struct CastImpl<scalar_t, scalar_t> {
1015 static inline Vectorized<scalar_t> apply(const Vectorized<scalar_t>& src) {
1016 return src;
1017 }
1018 };
1019
1020 template<typename dst_t, typename src_t>
1021 inline Vectorized<dst_t> cast(const Vectorized<src_t>& src) {
1022 return CastImpl<dst_t, src_t>::apply(src);
1023 }
1024
1025 template <typename T, typename IntType = int_same_size_t<T>>
1026 inline Vectorized<IntType> convert_to_int_of_same_size(const Vectorized<T>& src) {
1027 static_assert(sizeof(T) == sizeof(IntType));
1028 static constexpr int size = Vectorized<T>::size();
1029
1030 std::array<T, size> src_arr;
1031 src.store(static_cast<void*>(src_arr.data()));
1032 std::array<IntType, size> buffer;
1033 std::transform(src_arr.cbegin(), src_arr.cend(), buffer.begin(),
1034 [](const T& x) { return static_cast<IntType>(x); });
1035 return Vectorized<IntType>::loadu(static_cast<const void*>(buffer.data()));
1036 }
1037
1038 template <typename T, typename IntType = int_same_size_t<T>>
1039 inline Vectorized<T> convert_to_fp_of_same_size(const Vectorized<IntType>& src) {
1040 static_assert(sizeof(T) == sizeof(IntType));
1041 static constexpr int size = Vectorized<T>::size();
1042
1043 std::array<IntType, size> src_arr;
1044 src.store(static_cast<void*>(src_arr.data()));
1045 std::array<T, size> buffer;
1046 std::transform(src_arr.cbegin(), src_arr.cend(), buffer.begin(),
1047 [](const IntType& x) { return static_cast<T>(x); });
1048 return Vectorized<T>::loadu(static_cast<const void*>(buffer.data()));
1049 }
1050
1051 // Example inputs for AVX512:
1052 // a Vectorized<float> = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7}
1053 // b Vectorized<float> = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15}
1054 // returns:
1055 // Vectorized<float> = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15}
1056 // Vectorized<float> = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15}
1057 // Example inputs for AVX2: a Vectorized<float> = {a0, b0, a1, b1, a2, b2, a3, b3}
1058 // b Vectorized<float> = {a4, b4, a5, b5, a6, b6, a7, b7}
1059 // returns: Vectorized<float> = {a0, a1, a2, a3, a4, a5, a6, a7}
1060 // Vectorized<float> = {b0, b1, b2, b3, b4, b5, b6, b7}
1061 template <typename T>
1062 inline std::enable_if_t<Vectorized<T>::size() % 2 == 0, std::pair<Vectorized<T>, Vectorized<T>>>
1063 deinterleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
1064 static constexpr int size = Vectorized<T>::size();
1065 static constexpr int half_size = size / 2;
1066 T a_arr[size];
1067 T b_arr[size];
1068 T buffer1[size];
1069 T buffer2[size];
1070 a.store(static_cast<void*>(a_arr));
1071 b.store(static_cast<void*>(b_arr));
1072 for (const auto i : c10::irange(half_size)) {
1073 buffer1[i] = a_arr[i * 2];
1074 buffer1[half_size + i] = b_arr[i * 2];
1075 buffer2[i] = a_arr[i * 2 + 1];
1076 buffer2[half_size + i] = b_arr[i * 2 + 1];
1077 }
1078 return std::make_pair(Vectorized<T>::loadu(static_cast<void*>(buffer1)),
1079 Vectorized<T>::loadu(static_cast<void*>(buffer2)));
1080 }
1081
1082 // inverse operation of deinterleave2
1083 // Example inputs for AVX512:
1084 // a Vectorized<float> = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15}
1085 // b Vectorized<float> = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15}
1086 // returns, for AVX512:
1087 // Vectorized<float> = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7}
1088 // Vectorized<float> = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15}
1089 // Example inputs for AVX2 : a Vectorized<float> = {a0, a1, a2, a3, a4, a5, a6, a7}
1090 // b Vectorized<float> = {b0, b1, b2, b3, b4, b5, b6, b7}
1091 // returns: Vectorized<float> = {a0, b0, a1, b1, a2, b2, a3, b3}
1092 // Vectorized<float> = {a4, b4, a5, b5, a6, b6, a7, b7}
1093 template <typename T>
1094 inline std::enable_if_t<Vectorized<T>::size() % 2 == 0, std::pair<Vectorized<T>, Vectorized<T>>>
1095 interleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
1096 static constexpr int size = Vectorized<T>::size();
1097 static constexpr int half_size = size / 2;
1098 T a_arr[size];
1099 T b_arr[size];
1100 T buffer1[size];
1101 T buffer2[size];
1102 a.store(static_cast<void*>(a_arr));
1103 b.store(static_cast<void*>(b_arr));
1104 for (const auto i : c10::irange(half_size)) {
1105 buffer1[i * 2] = a_arr[i];
1106 buffer1[i * 2 + 1] = b_arr[i];
1107 buffer2[i * 2] = a_arr[half_size + i];
1108 buffer2[i * 2 + 1] = b_arr[half_size + i];
1109 }
1110 return std::make_pair(Vectorized<T>::loadu(static_cast<void*>(buffer1)),
1111 Vectorized<T>::loadu(static_cast<void*>(buffer2)));
1112 }
1113
1114 template <typename src_T, typename dst_T>
1115 inline void convert(const src_T *src, dst_T *dst, int64_t n) {
1116 #ifndef _MSC_VER
1117 # pragma unroll
1118 #endif
1119 for (C10_UNUSED const auto i : c10::irange(n)) {
1120 *dst = c10::convert<dst_T>(c10::load(src));
1121 src++;
1122 dst++;
1123 }
1124 }
1125
1126 template <typename T>
1127 inline Vectorized<T> flip(const Vectorized<T> & data) {
1128 static constexpr int size = Vectorized<T>::size();
1129 T output[size];
1130 T buffer[size];
1131 data.store(static_cast<void*>(buffer));
1132 for (const auto i : c10::irange(size)) {
1133 output[i] = buffer[size - i - 1];
1134 }
1135 return Vectorized<T>::loadu(static_cast<void*>(output));
1136 }
1137
// Writes the transpose of the (M, N) matrix stored in `src` into `dst`.
// `ld_src` and `ld_dst` are the leading dimensions (row strides, in elements)
// of `src` and `dst` respectively: element (i, j) of `src` lands at
// element (j, i) of `dst`.
template <typename T>
inline void transpose_mxn(const T* src, int64_t ld_src, T* dst, int64_t ld_dst, int M, int N) {
  for (int row = 0; row < M; ++row) {
    for (int col = 0; col < N; ++col) {
      const T& element = src[row * ld_src + col];
      dst[col * ld_dst + row] = element;
    }
  }
}
1148
// Variant of transpose_mxn with the matrix shape given as template arguments;
// forwards M and N to the overload that takes them as function arguments.
template <typename T, int M, int N>
inline void transpose_mxn(const T* src, int64_t ld_src, T* dst, int64_t ld_dst) {
  constexpr int rows = M;
  constexpr int cols = N;
  transpose_mxn<T>(src, ld_src, dst, ld_dst, rows, cols);
}
1153
1154 }} // namespace at::vec::CPU_CAPABILITY
1155
1156 // additional headers for more operations that depend on vec_base
1157 #include <ATen/cpu/vec/vec_n.h>
1158 #include <ATen/cpu/vec/vec_mask.h>
1159 #include <ATen/cpu/vec/vec_convert.h>
1160