xref: /aosp_15_r20/external/executorch/kernels/optimized/vec/vec_base.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 #pragma once
2 
#include <algorithm>
#include <bitset>
#include <cassert>
#include <climits>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <functional>
#include <type_traits>
11 
// These macros helped us unify vec_base.h
// __at_align__  : aligns a scalar array so it can be loaded/stored as one
//                 full-width SIMD register.
// VECTOR_WIDTH  : width of the emulated vector, in bytes.
// int_vector    : the whole-width integer register type for this build.
#ifdef CPU_CAPABILITY_AVX512
// AVX-512 build: 512-bit (64-byte) vectors.
#if defined(__GNUC__)
#define __at_align__ __attribute__((aligned(64)))
#elif defined(_WIN32)
#define __at_align__ __declspec(align(64))
#else
#define __at_align__
#endif
#define VECTOR_WIDTH 64
#define int_vector __m512i
#else // CPU_CAPABILITY_AVX512
// Default (incl. AVX2) build: 256-bit (32-byte) vectors.
#if defined(__GNUC__)
#define __at_align__ __attribute__((aligned(32)))
#elif defined(_WIN32)
#define __at_align__ __declspec(align(32))
#else
#define __at_align__
#endif
#define VECTOR_WIDTH 32
#define int_vector __m256i
#endif // CPU_CAPABILITY_AVX512
34 
35 namespace executorch {
36 namespace vec {
37 
38 // See Note [CPU_CAPABILITY namespace]
39 inline namespace CPU_CAPABILITY {
40 
// Maps a byte width to the signed integer type of exactly that width.
// Primary template is left undefined so an unsupported width fails to compile.
template<size_t n> struct int_of_size;

template<> struct int_of_size<8> { using type = int64_t; };
template<> struct int_of_size<4> { using type = int32_t; };
template<> struct int_of_size<2> { using type = int16_t; };
template<> struct int_of_size<1> { using type = int8_t; };

// Signed integer type with the same width as T (e.g. float -> int32_t).
template <typename T>
using int_same_size_t = typename int_of_size<sizeof(T)>::type;
55 
// NOTE: If you specialize on a type, you must define all operations!

// emulates Vectorized types
//
// Scalar fallback for the SIMD Vectorized abstraction: a fixed-size,
// register-width-aligned array of T with element-wise operations. Element
// access goes through the implicit T* conversion operators below.
#if defined(__s390x__)
template <class T, class TEMP=void>
#else
template <class T>
#endif
struct Vectorized {
private:
  // Backing store; aligned so whole-register loads/stores of it are legal.
  __at_align__ T values[VECTOR_WIDTH / sizeof(T)];
public:
  using value_type = T;
  using size_type = int;
  // Note [constexpr static function to avoid odr-usage compiler bug]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Why, you might ask, is size defined to be a static constexpr function,
  // rather than a more ordinary 'static constexpr int size;' variable?
  // The problem lies within ODR rules for static constexpr members versus
  // static constexpr functions.  First, recall that this class (along with all
  // of its derivations) live in an anonymous namespace: they are intended to be
  // *completely* inlined at their use-sites, because we need to compile it
  // multiple times for different instruction sets.
  //
  // Because of this constraint, we CANNOT provide a single definition for
  // any static members in this class; since we want to compile the class
  // multiple times, there wouldn't actually be any good place to put the
  // definition.  Now here is the problem: if we ODR-use a static constexpr
  // member, we are *obligated* to provide a definition.  Without the
  // definition, you get a compile error like:
  //
  //    relocation R_X86_64_PC32 against undefined symbol
  //    `_ZN2at6vec25612_GLOBAL__N_16VectorizedIdE4sizeE' can not be used when making
  //    a shared object; recompile with -fPIC
  //
  // If this were C++17, we could replace a static constexpr variable with
  // an inline variable which doesn't require one definition. But we are not
  // C++17.  So the next best thing is to replace the member with a static
  // constexpr (and therefore inline) function, which does not require ODR
  // either.
  //
  // Also, technically according to the C++ standard, we don't have to define
  // a constexpr variable if we never odr-use it.  But it seems that some
  // versions GCC/Clang have buggy determinations on whether or not an
  // identifier is odr-used or not, and in any case it's hard to tell if
  // a variable is odr-used or not.  So best to just cut the problem at the root.
  static constexpr size_type size_T = sizeof(T);  // Workaround to compile with VS2022.
  // Number of elements (lanes) in this vector.
  static constexpr size_type size() {
    return VECTOR_WIDTH / size_T;
  }
  // Zero-initializes every lane.
  Vectorized() : values{static_cast<T>(0)} {}
  // Broadcasts `val` into every lane.
  Vectorized(T val) {
    for (size_t i = 0; i != size(); i++) {
      values[i] = val;
    }
  }
  // Per-lane construction; requires exactly size() arguments.
  template<typename... Args,
           typename = std::enable_if_t<(sizeof...(Args) == size())>>
  Vectorized(Args... vals) : values{vals...}{
  }
  // This also implies const T& operator[](int idx) const
  inline operator const T*() const {
    return values;
  }
  // This also implies T& operator[](int idx)
  inline operator T*() {
    return values;
  }
  // Return the values as char* for type punning
  auto as_bytes() const -> const char* {
    return reinterpret_cast<const char*>(values);
  }
  // Compile-time select: lane i comes from `b` when bit i of `mask_` is 1,
  // from `a` otherwise.
  template <int64_t mask_>
  static Vectorized<T> blend(const Vectorized<T>& a, const Vectorized<T>& b) {
    int64_t mask = mask_;
    Vectorized vector;
    for (size_t i = 0; i < size(); ++i) {
      if (mask & 0x01) {
        vector[i] = b[i];
      } else {
        vector[i] = a[i];
      }
      mask = mask >> 1;
    }
    return vector;
  }
  // Runtime select: lane i comes from `b` when the low bit of mask lane i
  // (viewed as a same-width integer) is set, from `a` otherwise.
  static Vectorized<T> blendv(const Vectorized<T>& a, const Vectorized<T>& b,
                          const Vectorized<T>& mask) {
    Vectorized vector;
    int_same_size_t<T> buffer[size()];
    mask.store(buffer);
    for (size_t i = 0; i < size(); ++i) {
      if (buffer[i] & 0x01) {
        vector[i] = b[i];
      } else {
        vector[i] = a[i];
      }
    }
    return vector;
  }
  // Lane i = base + i * step.
  template<typename step_t>  // step sometimes requires a higher precision type (e.g., T=int, step_t=double)
  static Vectorized<T> arange(T base = static_cast<T>(0), step_t step = static_cast<step_t>(1)) {
    Vectorized vector;
    for (size_t i = 0; i < size(); ++i) {
      vector.values[i] = base + i * step;
    }
    return vector;
  }
  // First `count` lanes come from `b`, the remaining lanes from `a`.
  static Vectorized<T> set(const Vectorized<T>& a, const Vectorized<T>& b, int64_t count = size()) {
    Vectorized vector;
    for (size_t i = 0; i < size(); ++i) {
      if (i < count) {
        vector[i] = b[i];
      } else {
        vector[i] = a[i];
      }
    }
    return vector;
  }
  // Loads a full vector from (possibly unaligned) memory.
  static Vectorized<T> loadu(const void* ptr) {
    Vectorized vector;
    std::memcpy(vector.values, ptr, VECTOR_WIDTH);
    return vector;
  }
  // Loads only the first `count` elements; the remaining lanes keep their
  // default (zero-initialized) values.
  static Vectorized<T> loadu(const void* ptr, int64_t count) {
    Vectorized vector;
    std::memcpy(vector.values, ptr, count * sizeof(T));
    return vector;
  }
  // Stores the first `count` elements (all of them by default) to `ptr`.
  void store(void* ptr, int count = size()) const {
    std::memcpy(ptr, values, count * sizeof(T));
  }
  int zero_mask() const {
    // returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit
    int mask = 0;
    for (size_t i = 0; i < size(); ++ i) {
      if (values[i] == static_cast<T>(0)) {
        mask |= (1 << i);
      }
    }
    return mask;
  }
  // Per-lane NaN test: all-ones bit pattern where NaN, all-zeros otherwise
  // (matches the SIMD convention of full-width boolean lanes).
  Vectorized<T> isnan() const {
    Vectorized<T> vector;
    for (size_t i = 0; i != size(); i++) {
      if (std::isnan(values[i])) {
        std::memset(static_cast<void*>(vector.values + i), 0xFF, sizeof(T));
      } else {
        std::memset(static_cast<void*>(vector.values + i), 0, sizeof(T));
      }
    }
    return vector;
  }
  // Applies `f` to each lane (by-value argument overload).
  Vectorized<T> map(T (*const f)(T)) const {
    Vectorized<T> ret;
    for (size_t i = 0; i != size(); i++) {
      ret[i] = f(values[i]);
    }
    return ret;
  }
  // Applies `f` to each lane (const-reference argument overload).
  Vectorized<T> map(T (*const f)(const T &)) const {
    Vectorized<T> ret;
    for (size_t i = 0; i != size(); i++) {
      ret[i] = f(values[i]);
    }
    return ret;
  }
  // abs() for non-floating-point T.
  template <typename other_t_abs = T,
            typename std::enable_if<!std::is_floating_point<other_t_abs>::value, int>::type = 0>
  Vectorized<T> abs() const {
    // other_t_abs is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<other_t_abs, T>::value, "other_t_abs must be T");
    return map([](T x) -> T { return x < static_cast<T>(0) ? -x : x; });
  }
  // abs() for floating-point T.
  template <typename float_t_abs = T,
            typename std::enable_if<std::is_floating_point<float_t_abs>::value, int>::type = 0>
  Vectorized<T> abs() const {
    // float_t_abs is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<float_t_abs, T>::value, "float_t_abs must be T");
    // Specifically deal with floating-point because the generic code above won't handle -0.0 (which should result in
    // 0.0) properly.
    return map([](T x) -> T { return std::abs(x); });
  }

  // Lane-wise <cmath> wrappers.
  Vectorized<T> acos() const {
    return map(std::acos);
  }
  Vectorized<T> asin() const {
    return map(std::asin);
  }
  Vectorized<T> atan() const {
    return map(std::atan);
  }
  // Lane-wise atan2(this, exp).
  Vectorized<T> atan2(const Vectorized<T> &exp) const {
    Vectorized<T> ret;
    for (size_t i = 0; i < size(); ++i) {
      ret[i] = std::atan2(values[i], exp[i]);
    }
    return ret;
  }
  // Lane-wise copysign(this, sign); floating-point only.
  template <
    typename U = T,
    typename std::enable_if_t<std::is_floating_point<U>::value, int> = 0>
  Vectorized<T> copysign(const Vectorized<T> &sign) const {
    Vectorized<T> ret;
    for (size_t i = 0; i < size(); i++) {
      ret[i] = std::copysign(values[i], sign[i]);
    }
    return ret;
  }
  Vectorized<T> erf() const {
    return map(std::erf);
  }
  Vectorized<T> erfc() const {
    return map(std::erfc);
  }
  Vectorized<T> exp() const {
    return map(std::exp);
  }
  Vectorized<T> exp2() const {
    return map(std::exp2);
  }
  Vectorized<T> expm1() const {
    return map(std::expm1);
  }
  // Fractional part: x - trunc(x).
  Vectorized<T> frac() const {
    return *this - this->trunc();
  }
  // Lane-wise fmod(this, q); floating-point only.
  template <
    typename U = T,
    typename std::enable_if_t<std::is_floating_point<U>::value, int> = 0>
  Vectorized<T> fmod(const Vectorized<T>& q) const {
    // U is for SFINAE purposes only. Make sure it is not changed.
    static_assert(std::is_same<U, T>::value, "U must be T");
    Vectorized<T> ret;
    for (size_t i = 0; i < size(); ++i) {
      ret[i] = std::fmod(values[i], q[i]);
    }
    return ret;
  }
  Vectorized<T> log() const {
    return map(std::log);
  }
  Vectorized<T> log10() const {
    return map(std::log10);
  }
  Vectorized<T> log1p() const {
    return map(std::log1p);
  }
  Vectorized<T> log2() const {
    return map(std::log2);
  }
  Vectorized<T> ceil() const {
    return map(std::ceil);
  }
  Vectorized<T> cos() const {
    return map(std::cos);
  }
  Vectorized<T> cosh() const {
    return map(std::cosh);
  }
  Vectorized<T> floor() const {
    return map(std::floor);
  }
  // Lane-wise hypot(this, b).
  Vectorized<T> hypot(const Vectorized<T> &b) const {
    Vectorized<T> ret;
    for (size_t i = 0; i < size(); ++i) {
      ret[i] = std::hypot(values[i], b[i]);
    }
    return ret;
  }
  Vectorized<T> neg() const {
    // NB: the trailing return type is needed because we need to coerce the
    // return value back to T in the case of unary operator- incuring a
    // promotion
    return map([](T x) -> T { return -x; });
  }
  // Lane-wise nextafter(this, b).
  Vectorized<T> nextafter(const Vectorized<T> &b) const {
    Vectorized<T> ret;
    for (size_t i = 0; i < size(); ++i) {
      ret[i] = std::nextafter(values[i], b[i]);
    }
    return ret;
  }
  Vectorized<T> round() const {
    // TODO(T149257433): implement custom round that rounds midway numbers to
    // the nearest even integer.
    return map(std::round);
  }
  Vectorized<T> sin() const {
    return map(std::sin);
  }
  Vectorized<T> sinh() const {
    return map(std::sinh);
  }
  Vectorized<T> tan() const {
    return map(std::tan);
  }
  Vectorized<T> tanh() const {
    return map(std::tanh);
  }
  Vectorized<T> trunc() const {
    return map(std::trunc);
  }
  Vectorized<T> lgamma() const {
    return map(std::lgamma);
  }
  Vectorized<T> sqrt() const {
    return map(std::sqrt);
  }
  // Lane-wise 1/x.
  Vectorized<T> reciprocal() const {
    return map([](T x) { return (T)(1) / x; });
  }
  // Lane-wise 1/sqrt(x).
  Vectorized<T> rsqrt() const {
    return map([](T x) { return (T)1 / std::sqrt(x); });
  }
  // Lane-wise pow(this, exp).
  Vectorized<T> pow(const Vectorized<T> &exp) const {
    Vectorized<T> ret;
    for (size_t i = 0; i < size(); ++i) {
      ret[i] = std::pow(values[i], exp[i]);
    }
    return ret;
  }
private:
  template <typename Op>
  inline Vectorized<T> binary_pred(const Vectorized<T>& other, Op op) const {
    // All bits are set to 1 if the pred is true, otherwise 0.
    Vectorized<T> vector;
    for (size_t i = 0; i != size(); i++) {
      if (op(values[i], other.values[i])) {
        std::memset(static_cast<void*>(vector.values + i), 0xFF, sizeof(T));
      } else {
        std::memset(static_cast<void*>(vector.values + i), 0, sizeof(T));
      }
    }
    return vector;
  }

public:
  // SIMD-style comparisons: full-width boolean lanes (all-ones / all-zeros).
  Vectorized<T> operator==(const Vectorized<T>& other) const { return binary_pred(other, std::equal_to<T>()); }
  Vectorized<T> operator!=(const Vectorized<T>& other) const { return binary_pred(other, std::not_equal_to<T>()); }
  Vectorized<T> operator>=(const Vectorized<T>& other) const { return binary_pred(other, std::greater_equal<T>()); }
  Vectorized<T> operator<=(const Vectorized<T>& other) const { return binary_pred(other, std::less_equal<T>()); }
  Vectorized<T> operator>(const Vectorized<T>& other) const { return binary_pred(other, std::greater<T>()); }
  Vectorized<T> operator<(const Vectorized<T>& other) const { return binary_pred(other, std::less<T>()); }

private:
  template <typename Op>
  inline Vectorized<T> binary_pred_bool(const Vectorized<T>& other, Op op) const {
    // 1 if the pred is true, otherwise 0.
    Vectorized<T> vector;
    for (size_t i = 0; i != size(); ++ i) {
      vector[i] = static_cast<T>(op(values[i], other.values[i]));
    }
    return vector;
  }

public:
  // Numeric comparisons: lanes hold T(1) / T(0) rather than bit masks.
  Vectorized<T> eq(const Vectorized<T>& other) const { return binary_pred_bool(other, std::equal_to<T>()); }
  Vectorized<T> ne(const Vectorized<T>& other) const { return binary_pred_bool(other, std::not_equal_to<T>()); }
  Vectorized<T> gt(const Vectorized<T>& other) const { return binary_pred_bool(other, std::greater<T>()); }
  Vectorized<T> ge(const Vectorized<T>& other) const { return binary_pred_bool(other, std::greater_equal<T>()); }
  Vectorized<T> lt(const Vectorized<T>& other) const { return binary_pred_bool(other, std::less<T>()); }
  Vectorized<T> le(const Vectorized<T>& other) const { return binary_pred_bool(other, std::less_equal<T>()); }
};
422 
423 template <class T> Vectorized<T> inline operator+(const Vectorized<T> &a, const Vectorized<T> &b) {
424   Vectorized<T> c;
425   for (size_t i = 0; i != Vectorized<T>::size(); i++) {
426     c[i] = a[i] + b[i];
427   }
428   return c;
429 }
430 
431 template <class T> Vectorized<T> inline operator-(const Vectorized<T> &a, const Vectorized<T> &b) {
432   Vectorized<T> c;
433   for (size_t i = 0; i != Vectorized<T>::size(); i++) {
434     c[i] = a[i] - b[i];
435   }
436   return c;
437 }
438 
439 template <class T> Vectorized<T> inline operator*(const Vectorized<T> &a, const Vectorized<T> &b) {
440   Vectorized<T> c;
441   for (size_t i = 0; i != Vectorized<T>::size(); i++) {
442     c[i] = a[i] * b[i];
443   }
444   return c;
445 }
446 
447 template <class T> Vectorized<T> inline operator/(const Vectorized<T> &a, const Vectorized<T> &b) {
448   Vectorized<T> c;
449   for (size_t i = 0; i != Vectorized<T>::size(); i++) {
450     c[i] = a[i] / b[i];
451   }
452   return c;
453 }
454 
455 template <class T> Vectorized<T> inline operator||(
456     const Vectorized<T> &a, const Vectorized<T> &b) {
457   Vectorized<T> c;
458   for (size_t i = 0; i != Vectorized<T>::size(); i++) {
459     c[i] = a[i] || b[i];
460   }
461   return c;
462 }
463 
464 // Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
465 // either input is a NaN.
466 template <class T>
maximum(const Vectorized<T> & a,const Vectorized<T> & b)467 Vectorized<T> inline maximum(const Vectorized<T> &a, const Vectorized<T> &b) {
468   Vectorized<T> c;
469   for (size_t i = 0; i != Vectorized<T>::size(); i++) {
470     c[i] = (a[i] > b[i]) ? a[i] : b[i];
471     if (std::isnan(a[i])) {
472       // If either input is NaN, propagate a NaN.
473       // NOTE: The case where b[i] was NaN is handled correctly by the naive
474       // ternary operator above.
475       c[i] = a[i];
476     }
477   }
478   return c;
479 }
480 
481 // Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
482 // either input is a NaN.
483 template <class T>
minimum(const Vectorized<T> & a,const Vectorized<T> & b)484 Vectorized<T> inline minimum(const Vectorized<T> &a, const Vectorized<T> &b) {
485   Vectorized<T> c;
486   for (size_t i = 0; i != Vectorized<T>::size(); i++) {
487     c[i] = (a[i] < b[i]) ? a[i] : b[i];
488     if (std::isnan(a[i])) {
489       // If either input is NaN, propagate a NaN.
490       // NOTE: The case where b[i] was NaN is handled correctly by the naive
491       // ternary operator above.
492       c[i] = a[i];
493     }
494   }
495   return c;
496 }
497 
498 template <class T>
clamp(const Vectorized<T> & a,const Vectorized<T> & min_vec,const Vectorized<T> & max_vec)499 Vectorized<T> inline clamp(const Vectorized<T> &a, const Vectorized<T> &min_vec, const Vectorized<T> &max_vec) {
500   Vectorized<T> c;
501   for (size_t i = 0; i != Vectorized<T>::size(); i++) {
502     c[i] = std::min(std::max(a[i], min_vec[i]), max_vec[i]);
503   }
504   return c;
505 }
506 
507 template <class T>
clamp_max(const Vectorized<T> & a,const Vectorized<T> & max_vec)508 Vectorized<T> inline clamp_max(const Vectorized<T> &a, const Vectorized<T> &max_vec) {
509   Vectorized<T> c;
510   for (size_t i = 0; i != Vectorized<T>::size(); i++) {
511     c[i] = a[i] > max_vec[i] ? max_vec[i] : a[i];
512   }
513   return c;
514 }
515 
516 template <class T>
clamp_min(const Vectorized<T> & a,const Vectorized<T> & min_vec)517 Vectorized<T> inline clamp_min(const Vectorized<T> &a, const Vectorized<T> &min_vec) {
518   Vectorized<T> c;
519   for (size_t i = 0; i != Vectorized<T>::size(); i++) {
520     c[i] = a[i] < min_vec[i] ? min_vec[i] : a[i];
521   }
522   return c;
523 }
524 
// Tag base class: the generic bitwise operators below are SFINAE-disabled for
// Vectorized specializations deriving from this (see the is_base_of checks).
struct Vectorizedi;
526 
527 #if defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)
// Applies `op` to the raw bits of a and b using a single whole-width AVX2 /
// AVX-512 integer register, then reloads the result as a Vectorized<T>.
// The aligned loads/stores are safe because Vectorized's storage and the
// `results` scratch buffer are both __at_align__-aligned.
template <class T, typename Op>
static inline Vectorized<T> bitwise_binary_op(const Vectorized<T> &a, const Vectorized<T> &b, Op op) {
  int_vector buffer;
#if defined(CPU_CAPABILITY_AVX2)
  // (const T*)a uses Vectorized's implicit conversion to its element array.
  int_vector a_buffer = _mm256_load_si256(reinterpret_cast<const int_vector*>((const T*)a));
  int_vector b_buffer = _mm256_load_si256(reinterpret_cast<const int_vector*>((const T*)b));
#elif defined(CPU_CAPABILITY_AVX512)
  int_vector a_buffer = _mm512_load_si512(reinterpret_cast<const int_vector*>((const T*)a));
  int_vector b_buffer = _mm512_load_si512(reinterpret_cast<const int_vector*>((const T*)b));
#endif
  buffer = op(a_buffer, b_buffer);
  __at_align__ T results[Vectorized<T>::size()];

#if defined(CPU_CAPABILITY_AVX2)
  _mm256_store_si256(reinterpret_cast<int_vector*>(results), buffer);
#elif defined(CPU_CAPABILITY_AVX512)
  _mm512_store_si512(reinterpret_cast<int_vector*>(results), buffer);
#endif
  return Vectorized<T>::loadu(results);
}
548 
// Bitwise AND via the whole-register helper; disabled for integer Vectorized
// specializations (they provide their own).
template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
inline Vectorized<T> operator&(const Vectorized<T>& a, const Vectorized<T>& b) {
  // We enclose _mm512_and_si512 or _mm256_and_si256 with lambda because it is always_inline
#if defined(CPU_CAPABILITY_AVX2)
  return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_and_si256(a, b); });
#elif defined(CPU_CAPABILITY_AVX512)
  return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_and_si512(a, b); });
#endif
}
// Bitwise OR via the whole-register helper; disabled for integer Vectorized
// specializations (they provide their own).
template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
inline Vectorized<T> operator|(const Vectorized<T>& a, const Vectorized<T>& b) {
  // We enclose _mm512_or_si512 or _mm256_or_si256 with lambda because it is always_inline
#if defined(CPU_CAPABILITY_AVX2)
  return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_or_si256(a, b); });
#elif defined(CPU_CAPABILITY_AVX512)
  return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_or_si512(a, b); });
#endif
}
// Bitwise XOR via the whole-register helper; disabled for integer Vectorized
// specializations (they provide their own).
template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
inline Vectorized<T> operator^(const Vectorized<T>& a, const Vectorized<T>& b) {
  // We enclose _mm512_xor_si512 or _mm256_xor_si256 with lambda because it is always_inline
#if defined(CPU_CAPABILITY_AVX2)
  return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_xor_si256(a, b); });
#elif defined(CPU_CAPABILITY_AVX512)
  return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_xor_si512(a, b); });
#endif
}
576 
577 #else
578 
// Bit-copies sizeof(T) bytes from `data` into a fresh T. memcpy keeps the
// read strict-aliasing safe.
template <typename T>
auto load(char const* data) -> T {
  T result;
  std::memcpy(&result, data, sizeof(T));
  return result;
}
585 
// Portable fallback: applies `op` to the vectors' raw bits one intmax_t-sized
// chunk at a time, then reloads the chunks as a Vectorized<T>.
template<class T, typename Op>
static inline Vectorized<T> bitwise_binary_op(const Vectorized<T> &a, const Vectorized<T> &b, Op op) {
  static constexpr uint32_t element_no = VECTOR_WIDTH / sizeof(intmax_t);
  __at_align__ intmax_t buffer[element_no];
  static_assert(VECTOR_WIDTH % sizeof(intmax_t) == 0, "VECTOR_WIDTH not a multiple of sizeof(intmax_t)");
  static_assert(sizeof(buffer) == sizeof(Vectorized<T>), "sizeof(buffer) must match sizeof(Vectorized<T>)");
  // We should be using memcpy in order to respect the strict aliasing rule
  // see: https://github.com/pytorch/pytorch/issues/66119
  // Using char* is defined in the C11 standard 6.5 Expression paragraph 7
  // (http://www.open-std.org/jtc1/sc22/wg14/www/docs/n1570.pdf)
  const auto* a_data = a.as_bytes();
  const auto* b_data = b.as_bytes();
  // load each intmax_t chunk and process; increase pointers by sizeof(intmax_t)
  for (auto& out : buffer) {
    out = op(load<intmax_t>(a_data), load<intmax_t>(b_data));
    a_data += sizeof(intmax_t);
    b_data += sizeof(intmax_t);
  }
  // Sanity check: every byte of both inputs was consumed exactly once.
  assert(a_data == a.as_bytes() + sizeof(a));
  assert(b_data == b.as_bytes() + sizeof(b));
  return Vectorized<T>::loadu(buffer);
}
608 
609 template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
610 inline Vectorized<T> operator&(const Vectorized<T>& a, const Vectorized<T>& b) {
611   return bitwise_binary_op(a, b, std::bit_and<intmax_t>());
612 }
613 template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
614 inline Vectorized<T> operator|(const Vectorized<T>& a, const Vectorized<T>& b) {
615   return bitwise_binary_op(a, b, std::bit_or<intmax_t>());
616 }
617 template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
618 inline Vectorized<T> operator^(const Vectorized<T>& a, const Vectorized<T>& b) {
619   return bitwise_binary_op(a, b, std::bit_xor<intmax_t>());
620 }
621 
622 #endif // defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)
623 
624 template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
625 inline Vectorized<T> operator~(const Vectorized<T>& a) {
626   Vectorized<T> ones;  // All bits are 1
627   memset((T*) ones, 0xFF, VECTOR_WIDTH);
628   return a ^ ones;
629 }
630 
631 template <class T> Vectorized<T> inline operator<<(const Vectorized<T> &a, const Vectorized<T> &b) {
632   constexpr T max_shift = sizeof(T) * CHAR_BIT;
633   Vectorized<T> c;
634   for (size_t i = 0; i != Vectorized<T>::size(); i++) {
635     T shift = b[i];
636     if ((static_cast<std::make_signed_t<T>>(shift) < 0) || (shift >= max_shift)) {
637       c[i] = 0;
638     } else {
639       c[i] = static_cast<std::make_unsigned_t<T>>(a[i]) << shift;
640     }
641   }
642   return c;
643 }
644 
645 template <class T> Vectorized<T> inline operator>>(const Vectorized<T> &a, const Vectorized<T> &b) {
646   // right shift value to retain sign bit for signed and no bits for unsigned
647   constexpr T max_shift = sizeof(T) * CHAR_BIT - std::is_signed_v<T>;
648   Vectorized<T> c;
649   for (size_t i = 0; i != Vectorized<T>::size(); i++) {
650     T shift = b[i];
651     if ((static_cast<std::make_signed_t<T>>(shift) < 0) || (shift >= max_shift)) {
652       c[i] = a[i] >> max_shift;
653     } else {
654       c[i] = a[i] >> shift;
655     }
656   }
657   return c;
658 }
659 
660 template <typename T>
661 inline Vectorized<T>& operator += (Vectorized<T>& a, const Vectorized<T>& b) {
662   a = a + b;
663   return a;
664 }
665 template <typename T>
666 inline Vectorized<T>& operator -= (Vectorized<T>& a, const Vectorized<T>& b) {
667   a = a - b;
668   return a;
669 }
670 template <typename T>
671 inline Vectorized<T>& operator /= (Vectorized<T>& a, const Vectorized<T>& b) {
672   a = a / b;
673   return a;
674 }
675 template <typename T>
676 inline Vectorized<T>& operator %= (Vectorized<T>& a, const Vectorized<T>& b) {
677   a = a % b;
678   return a;
679 }
680 template <typename T>
681 inline Vectorized<T>& operator *= (Vectorized<T>& a, const Vectorized<T>& b) {
682   a = a * b;
683   return a;
684 }
685 
686 template <typename T>
687 inline Vectorized<T>& operator <<= (Vectorized<T>& a, const Vectorized<T>& b) {
688   a = a << b;
689   return a;
690 }
691 
692 template <typename T>
693 inline Vectorized<T>& operator >>= (Vectorized<T>& a, const Vectorized<T>& b) {
694   a = a >> b;
695   return a;
696 }
697 
698 template <typename T>
fmadd(const Vectorized<T> & a,const Vectorized<T> & b,const Vectorized<T> & c)699 inline Vectorized<T> fmadd(const Vectorized<T>& a, const Vectorized<T>& b, const Vectorized<T>& c) {
700   return a * b + c;
701 }
702 
703 template <typename T>
fmsub(const Vectorized<T> & a,const Vectorized<T> & b,const Vectorized<T> & c)704 inline Vectorized<T> fmsub(const Vectorized<T>& a, const Vectorized<T>& b, const Vectorized<T>& c) {
705   return a * b - c;
706 }
707 
// Emulated gather: out[i] = base_addr[vindex[i] * scale / sizeof(T)].
// `scale` is a byte multiplier restricted to 1/2/4/8, mirroring the x86
// gather intrinsics' addressing (index * scale bytes from base).
template <int64_t scale = 1, typename T = void>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<T>>
inline gather(T const* base_addr, const Vectorized<int_same_size_t<T>>& vindex) {
  static constexpr int size = Vectorized<T>::size();
  int_same_size_t<T> index_arr[size];
  vindex.store(static_cast<void*>(index_arr));
  T buffer[size];
  for (size_t i = 0; i < size; ++i) {
    // Indices address bytes (index * scale), hence the division by sizeof(T)
    // to convert back to an element offset.
    buffer[i] = base_addr[index_arr[i] * scale / sizeof(T)];
  }
  return Vectorized<T>::loadu(static_cast<void*>(buffer));
}
720 
// Emulated masked gather: lane i reads base_addr[vindex[i] * scale / sizeof(T)]
// when the low bit of mask lane i is set, otherwise keeps src's lane. The
// mask is zeroed on return, matching the intrinsic's clobbering of the mask.
template <int64_t scale = 1, typename T = void>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<T>>
inline mask_gather(const Vectorized<T>& src, T const* base_addr,
                   const Vectorized<int_same_size_t<T>>& vindex, Vectorized<T>& mask) {
  static constexpr int size = Vectorized<T>::size();
  T src_arr[size];
  int_same_size_t<T> mask_arr[size];  // use int type so we can logical and
  int_same_size_t<T> index_arr[size];
  src.store(static_cast<void*>(src_arr));
  mask.store(static_cast<void*>(mask_arr));
  vindex.store(static_cast<void*>(index_arr));
  T buffer[size];
  for (size_t i = 0; i < size; ++i) {
    // NOTE(review): this tests the LOWEST bit of the mask lane; the original
    // comment claimed "check highest bit", which did not match the code.
    if (mask_arr[i] & 0x01) {
      buffer[i] = base_addr[index_arr[i] * scale / sizeof(T)];
    } else {
      buffer[i] = src_arr[i];
    }
  }
  mask = Vectorized<T>();  // "zero out" mask
  return Vectorized<T>::loadu(static_cast<void*>(buffer));
}
743 
744 // Cast a given vector to another type without changing the bits representation.
745 // So a Vectorized<double> of 512 bits containing all ones can be cast to a
746 // Vectorized<int64_t> of 512 bits containing all ones (i.e., eight negative 1s).
747 // A Vec<double> of 256 bits containing all ones can be cast to a
748 // Vec<int64_t> of 256 bits containing all ones (i.e., four negative 1s).
749 // There is a struct here because we don't have static_if and I can't
750 // partially specialize a templated function.
751 template<typename dst_t, typename src_t>
752 struct CastImpl {
applyCastImpl753   static inline Vectorized<dst_t> apply(const Vectorized<src_t>& src) {
754     src_t src_arr[Vectorized<src_t>::size()];
755     src.store(static_cast<void*>(src_arr));
756     return Vectorized<dst_t>::loadu(static_cast<const void*>(src_arr));
757   }
758 };
759 
// Identity specialization: casting to the same element type is a no-op copy.
template<typename scalar_t>
struct CastImpl<scalar_t, scalar_t> {
  static inline Vectorized<scalar_t> apply(const Vectorized<scalar_t>& src) {
    return src;
  }
};
766 
// Public entry point for the bit-preserving cast described above; dispatches
// through CastImpl so the same-type case can be partially specialized.
template<typename dst_t, typename src_t>
inline Vectorized<dst_t> cast(const Vectorized<src_t>& src) {
  return CastImpl<dst_t, src_t>::apply(src);
}
771 
772 template <typename T>
773 inline Vectorized<int_same_size_t<T>> convert_to_int_of_same_size(const Vectorized<T>& src) {
774   static constexpr int size = Vectorized<T>::size();
775   T src_arr[size];
776   src.store(static_cast<void*>(src_arr));
777   int_same_size_t<T> buffer[size];
778   for (size_t i = 0; i < size; ++i) {
779     buffer[i] = static_cast<int_same_size_t<T>>(src_arr[i]);
780   }
781   return Vectorized<int_same_size_t<T>>::loadu(static_cast<void*>(buffer));
782 }
783 
784 // Example inputs for AVX512:
785 // a   Vectorized<float>   = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7}
786 // b   Vectorized<float>   = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15}
787 // returns:
788 //           Vectorized<float>   = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15}
789 //           Vectorized<float>   = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15}
790 // Example inputs for AVX2: a           Vectorized<float>   = {a0, b0, a1, b1, a2, b2, a3, b3}
791 //               b                      Vectorized<float>   = {a4, b4, a5, b5, a6, b6, a7, b7}
792 //       returns:                       Vectorized<float>   = {a0, a1, a2, a3, a4, a5, a6, a7}
793 //                                      Vectorized<float>   = {b0, b1, b2, b3, b4, b5, b6, b7}
794 template <typename T>
795 inline std::enable_if_t<Vectorized<T>::size() % 2 == 0, std::pair<Vectorized<T>, Vectorized<T>>>
796 deinterleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
797   static constexpr int size = Vectorized<T>::size();
798   static constexpr int half_size = size / 2;
799   T a_arr[size];
800   T b_arr[size];
801   T buffer1[size];
802   T buffer2[size];
803   a.store(static_cast<void*>(a_arr));
804   b.store(static_cast<void*>(b_arr));
805   for (size_t i = 0; i < half_size; ++i) {
806     buffer1[i] = a_arr[i * 2];
807     buffer1[half_size + i] = b_arr[i * 2];
808     buffer2[i] = a_arr[i * 2 + 1];
809     buffer2[half_size + i] = b_arr[i * 2 + 1];
810   }
811   return std::make_pair(Vectorized<T>::loadu(static_cast<void*>(buffer1)),
812                         Vectorized<T>::loadu(static_cast<void*>(buffer2)));
813 }
814 
815 // inverse operation of deinterleave2
816 // Example inputs for AVX512:
817 //  a       Vectorized<float>   = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15}
818 //  b       Vectorized<float>   = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15}
819 // returns, for AVX512:
820 //          Vectorized<float>   = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7}
821 //          Vectorized<float>   = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15}
822 // Example inputs for AVX2 : a           Vectorized<float>   = {a0, a1, a2, a3, a4, a5, a6, a7}
823 //                   b                   Vectorized<float>   = {b0, b1, b2, b3, b4, b5, b6, b7}
824 //       returns:            Vectorized<float>   = {a0, b0, a1, b1, a2, b2, a3, b3}
825 //                           Vectorized<float>   = {a4, b4, a5, b5, a6, b6, a7, b7}
826 template <typename T>
827 inline std::enable_if_t<Vectorized<T>::size() % 2 == 0, std::pair<Vectorized<T>, Vectorized<T>>>
828 interleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
829   static constexpr int size = Vectorized<T>::size();
830   static constexpr int half_size = size / 2;
831   T a_arr[size];
832   T b_arr[size];
833   T buffer1[size];
834   T buffer2[size];
835   a.store(static_cast<void*>(a_arr));
836   b.store(static_cast<void*>(b_arr));
837   for (size_t i = 0; i < half_size; ++i) {
838     buffer1[i * 2] = a_arr[i];
839     buffer1[i * 2 + 1] = b_arr[i];
840     buffer2[i * 2] = a_arr[half_size + i];
841     buffer2[i * 2 + 1] = b_arr[half_size + i];
842   }
843   return std::make_pair(Vectorized<T>::loadu(static_cast<void*>(buffer1)),
844                         Vectorized<T>::loadu(static_cast<void*>(buffer2)));
845 }
846 
// Element-wise static_cast of `n` contiguous elements from `src` into
// `dst`. The two buffers must not overlap.
template <typename src_T, typename dst_T>
inline void convert(const src_T *src, dst_T *dst, int64_t n) {
#ifndef _MSC_VER
# pragma unroll
#endif
  for (int64_t i = 0; i < n; ++i) {
    dst[i] = static_cast<dst_T>(src[i]);
  }
}
859 
860 template <typename T>
861 inline Vectorized<T> flip(const Vectorized<T> & data) {
862   static constexpr int size = Vectorized<T>::size();
863   T output[size];
864   T buffer[size];
865   data.store(static_cast<void*>(buffer));
866   for (size_t i = 0; i < size; ++i) {
867     output[i] = buffer[size - i - 1];
868   }
869   return Vectorized<T>::loadu(static_cast<void*>(output));
870 }
871 
// Transpose the `src` buffer of type `T` and size (M,N) into the `dst` buffer. `ld_src` is the leading
// dimension of `src` and `ld_dst` is the leading dimension of `dst`.
template <typename T, int M, int N>
inline void transpose_mxn(const T* src, int64_t ld_src, T* dst, int64_t ld_dst) {
  // Both indices are int to match the int template parameters M and N;
  // the original mixed size_t/int loops triggered -Wsign-compare and
  // were inconsistent with each other.
  for (int i = 0; i < M; i++) {
    for (int j = 0; j < N; j++) {
      dst[j * ld_dst + i] = src[i * ld_src + j];
    }
  }
}
882 
883 } // namespace CPU_CAPABILITY
884 
885 } // namespace vec
886 } // namespace executorch
887