#pragma once

#include <cassert>
#include <cstdint>
#include <cstring>
#include <functional>
#include <cmath>
#include <type_traits>
#include <bitset>
#include <climits>

// These macros helped us unify vec_base.h
#ifdef CPU_CAPABILITY_AVX512
#if defined(__GNUC__)
#define __at_align__ __attribute__((aligned(64)))
#elif defined(_WIN32)
#define __at_align__ __declspec(align(64))
#else
#define __at_align__
#endif
#define VECTOR_WIDTH 64
#define int_vector __m512i
#else // CPU_CAPABILITY_AVX512
#if defined(__GNUC__)
#define __at_align__ __attribute__((aligned(32)))
#elif defined(_WIN32)
#define __at_align__ __declspec(align(32))
#else
#define __at_align__
#endif
#define VECTOR_WIDTH 32
#define int_vector __m256i
#endif // CPU_CAPABILITY_AVX512

namespace executorch {
namespace vec {

// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {

template<size_t n> struct int_of_size;

#define DEFINE_INT_OF_SIZE(int_t) \
template<> struct int_of_size<sizeof(int_t)> { using type = int_t; }

DEFINE_INT_OF_SIZE(int64_t);
DEFINE_INT_OF_SIZE(int32_t);
DEFINE_INT_OF_SIZE(int16_t);
DEFINE_INT_OF_SIZE(int8_t);

#undef DEFINE_INT_OF_SIZE

template <typename T>
using int_same_size_t = typename int_of_size<sizeof(T)>::type;

// NOTE: If you specialize on a type, you must define all operations!

// emulates Vectorized types
#if defined(__s390x__)
template <class T, class TEMP=void>
#else
template <class T>
#endif
struct Vectorized {
 private:
  __at_align__ T values[VECTOR_WIDTH / sizeof(T)];
 public:
  using value_type = T;
  using size_type = int;
  // Note [constexpr static function to avoid odr-usage compiler bug]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Why, you might ask, is size defined to be a static constexpr function,
  // rather than a more ordinary 'static constexpr int size;' variable?
  // The problem lies within ODR rules for static constexpr members versus
  // static constexpr functions. First, recall that this class (along with all
  // of its derivations) lives in an anonymous namespace: they are intended to be
  // *completely* inlined at their use-sites, because we need to compile it
  // multiple times for different instruction sets.
  //
  // Because of this constraint, we CANNOT provide a single definition for
  // any static members in this class; since we want to compile the class
  // multiple times, there wouldn't actually be any good place to put the
  // definition. Now here is the problem: if we ODR-use a static constexpr
  // member, we are *obligated* to provide a definition. Without the
  // definition, you get a compile error like:
  //
  //    relocation R_X86_64_PC32 against undefined symbol
  //    `_ZN2at6vec25612_GLOBAL__N_16VectorizedIdE4sizeE' can not be used when making
  //    a shared object; recompile with -fPIC
  //
  // If this were C++17, we could replace the static constexpr variable with
  // an inline variable, which doesn't require a definition. But we are not
  // on C++17. So the next best thing is to replace the member with a static
  // constexpr (and therefore inline) function, which does not require ODR
  // either.
  //
  // Also, technically according to the C++ standard, we don't have to define
  // a constexpr variable if we never odr-use it. But it seems that some
  // versions of GCC/Clang have buggy determinations on whether or not an
  // identifier is odr-used or not, and in any case it's hard to tell if
  // a variable is odr-used or not. So best to just cut the problem at the root.
  static constexpr size_type size_T = sizeof(T); // Workaround to compile with VS2022.
  static constexpr size_type size() {
    return VECTOR_WIDTH / size_T;
  }
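  // For example, with 32-byte vectors (the non-AVX512 configuration above),
  // Vectorized<float>::size() == 32 / 4 == 8 and Vectorized<int64_t>::size() == 32 / 8 == 4.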
  Vectorized() : values{static_cast<T>(0)} {}
  Vectorized(T val) {
    for (size_t i = 0; i != size(); i++) {
      values[i] = val;
    }
  }
  template<typename... Args,
           typename = std::enable_if_t<(sizeof...(Args) == size())>>
  Vectorized(Args... vals) : values{vals...} {
  }
  // This also implies const T& operator[](int idx) const
  inline operator const T*() const {
    return values;
  }
  // This also implies T& operator[](int idx)
  inline operator T*() {
    return values;
  }
  // Return the values as char* for type punning
  auto as_bytes() const -> const char* {
    return reinterpret_cast<const char*>(values);
  }
  template <int64_t mask_>
  static Vectorized<T> blend(const Vectorized<T>& a, const Vectorized<T>& b) {
    int64_t mask = mask_;
    Vectorized vector;
    for (size_t i = 0; i < size(); ++i) {
      if (mask & 0x01) {
        vector[i] = b[i];
      } else {
        vector[i] = a[i];
      }
      mask = mask >> 1;
    }
    return vector;
  }
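  // blend example (illustrative): bit i of mask_ selects b[i] when set and a[i]
  // otherwise, so blend<0b0011>(a, b) == {b[0], b[1], a[2], a[3], ...}.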
  static Vectorized<T> blendv(const Vectorized<T>& a, const Vectorized<T>& b,
                              const Vectorized<T>& mask) {
    Vectorized vector;
    int_same_size_t<T> buffer[size()];
    mask.store(buffer);
    for (size_t i = 0; i < size(); ++i) {
      if (buffer[i] & 0x01) {
        vector[i] = b[i];
      } else {
        vector[i] = a[i];
      }
    }
    return vector;
  }
  // step sometimes requires a higher precision type (e.g., T=int, step_t=double)
  template<typename step_t>
  static Vectorized<T> arange(T base = static_cast<T>(0), step_t step = static_cast<step_t>(1)) {
    Vectorized vector;
    for (size_t i = 0; i < size(); ++i) {
      vector.values[i] = base + i * step;
    }
    return vector;
  }
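  // arange example (illustrative): Vectorized<int>::arange(5, 2) produces
  // {5, 7, 9, ...} with size() elements.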
  static Vectorized<T> set(const Vectorized<T>& a, const Vectorized<T>& b, int64_t count = size()) {
    Vectorized vector;
    for (size_t i = 0; i < size(); ++i) {
      if (i < count) {
        vector[i] = b[i];
      } else {
        vector[i] = a[i];
      }
    }
    return vector;
  }
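  // set example (illustrative): set(a, b, 3) takes the first 3 elements from b
  // and the rest from a, i.e. {b[0], b[1], b[2], a[3], a[4], ...}.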
  static Vectorized<T> loadu(const void* ptr) {
    Vectorized vector;
    std::memcpy(vector.values, ptr, VECTOR_WIDTH);
    return vector;
  }
  static Vectorized<T> loadu(const void* ptr, int64_t count) {
    Vectorized vector;
    std::memcpy(vector.values, ptr, count * sizeof(T));
    return vector;
  }
  void store(void* ptr, int count = size()) const {
    std::memcpy(ptr, values, count * sizeof(T));
  }
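  // loadu/store sketch (illustrative; `src` and `dst` are hypothetical buffers):
  //   auto v = Vectorized<float>::loadu(src, 3);  // copy only 3 floats in
  //   v.store(dst, 3);                            // copy only 3 floats out
  // The single-argument loadu always reads a full VECTOR_WIDTH bytes.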
  int zero_mask() const {
    // Returns an integer mask where each zero element maps to a 1-bit and each
    // non-zero element maps to a 0-bit.
    int mask = 0;
    for (size_t i = 0; i < size(); ++i) {
      if (values[i] == static_cast<T>(0)) {
        mask |= (1 << i);
      }
    }
    return mask;
  }
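  // zero_mask example (illustrative): if values == {0, 3, 0, 7, ...} then bits 0
  // and 2 of the result are set, i.e. the mask starts with 0b...0101.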
  Vectorized<T> isnan() const {
    Vectorized<T> vector;
    for (size_t i = 0; i != size(); i++) {
      if (std::isnan(values[i])) {
        std::memset(static_cast<void*>(vector.values + i), 0xFF, sizeof(T));
      } else {
        std::memset(static_cast<void*>(vector.values + i), 0, sizeof(T));
      }
    }
    return vector;
  }
  Vectorized<T> map(T (*const f)(T)) const {
    Vectorized<T> ret;
    for (size_t i = 0; i != size(); i++) {
      ret[i] = f(values[i]);
    }
    return ret;
  }
  Vectorized<T> map(T (*const f)(const T &)) const {
    Vectorized<T> ret;
    for (size_t i = 0; i != size(); i++) {
      ret[i] = f(values[i]);
    }
    return ret;
  }
  template <typename other_t_abs = T,
            typename std::enable_if<!std::is_floating_point<other_t_abs>::value, int>::type = 0>
  Vectorized<T> abs() const {
    // other_t_abs is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<other_t_abs, T>::value, "other_t_abs must be T");
    return map([](T x) -> T { return x < static_cast<T>(0) ? -x : x; });
  }
  template <typename float_t_abs = T,
            typename std::enable_if<std::is_floating_point<float_t_abs>::value, int>::type = 0>
  Vectorized<T> abs() const {
    // float_t_abs is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<float_t_abs, T>::value, "float_t_abs must be T");
    // Specifically deal with floating-point because the generic code above won't
    // handle -0.0 (which should result in 0.0) properly.
    return map([](T x) -> T { return std::abs(x); });
  }

  Vectorized<T> acos() const {
    return map(std::acos);
  }
  Vectorized<T> asin() const {
    return map(std::asin);
  }
  Vectorized<T> atan() const {
    return map(std::atan);
  }
  Vectorized<T> atan2(const Vectorized<T> &exp) const {
    Vectorized<T> ret;
    for (size_t i = 0; i < size(); ++i) {
      ret[i] = std::atan2(values[i], exp[i]);
    }
    return ret;
  }
  template <
      typename U = T,
      typename std::enable_if_t<std::is_floating_point<U>::value, int> = 0>
  Vectorized<T> copysign(const Vectorized<T> &sign) const {
    Vectorized<T> ret;
    for (size_t i = 0; i < size(); i++) {
      ret[i] = std::copysign(values[i], sign[i]);
    }
    return ret;
  }
  Vectorized<T> erf() const {
    return map(std::erf);
  }
  Vectorized<T> erfc() const {
    return map(std::erfc);
  }
  Vectorized<T> exp() const {
    return map(std::exp);
  }
  Vectorized<T> exp2() const {
    return map(std::exp2);
  }
  Vectorized<T> expm1() const {
    return map(std::expm1);
  }
  Vectorized<T> frac() const {
    return *this - this->trunc();
  }
  template <
      typename U = T,
      typename std::enable_if_t<std::is_floating_point<U>::value, int> = 0>
  Vectorized<T> fmod(const Vectorized<T>& q) const {
    // U is for SFINAE purposes only. Make sure it is not changed.
    static_assert(std::is_same<U, T>::value, "U must be T");
    Vectorized<T> ret;
    for (size_t i = 0; i < size(); ++i) {
      ret[i] = std::fmod(values[i], q[i]);
    }
    return ret;
  }
  Vectorized<T> log() const {
    return map(std::log);
  }
  Vectorized<T> log10() const {
    return map(std::log10);
  }
  Vectorized<T> log1p() const {
    return map(std::log1p);
  }
  Vectorized<T> log2() const {
    return map(std::log2);
  }
  Vectorized<T> ceil() const {
    return map(std::ceil);
  }
  Vectorized<T> cos() const {
    return map(std::cos);
  }
  Vectorized<T> cosh() const {
    return map(std::cosh);
  }
  Vectorized<T> floor() const {
    return map(std::floor);
  }
  Vectorized<T> hypot(const Vectorized<T> &b) const {
    Vectorized<T> ret;
    for (size_t i = 0; i < size(); ++i) {
      ret[i] = std::hypot(values[i], b[i]);
    }
    return ret;
  }
  Vectorized<T> neg() const {
    // NB: the trailing return type is needed because we need to coerce the
    // return value back to T in the case of unary operator- incurring a
    // promotion
    return map([](T x) -> T { return -x; });
  }
  Vectorized<T> nextafter(const Vectorized<T> &b) const {
    Vectorized<T> ret;
    for (size_t i = 0; i < size(); ++i) {
      ret[i] = std::nextafter(values[i], b[i]);
    }
    return ret;
  }
  Vectorized<T> round() const {
    // TODO(T149257433): implement custom round that rounds midway numbers to
    // the nearest even integer.
    return map(std::round);
  }
  Vectorized<T> sin() const {
    return map(std::sin);
  }
  Vectorized<T> sinh() const {
    return map(std::sinh);
  }
  Vectorized<T> tan() const {
    return map(std::tan);
  }
  Vectorized<T> tanh() const {
    return map(std::tanh);
  }
  Vectorized<T> trunc() const {
    return map(std::trunc);
  }
  Vectorized<T> lgamma() const {
    return map(std::lgamma);
  }
  Vectorized<T> sqrt() const {
    return map(std::sqrt);
  }
  Vectorized<T> reciprocal() const {
    return map([](T x) { return (T)(1) / x; });
  }
  Vectorized<T> rsqrt() const {
    return map([](T x) { return (T)1 / std::sqrt(x); });
  }
  Vectorized<T> pow(const Vectorized<T> &exp) const {
    Vectorized<T> ret;
    for (size_t i = 0; i < size(); ++i) {
      ret[i] = std::pow(values[i], exp[i]);
    }
    return ret;
  }
 private:
  template <typename Op>
  inline Vectorized<T> binary_pred(const Vectorized<T>& other, Op op) const {
    // All bits are set to 1 if the pred is true, otherwise 0.
    Vectorized<T> vector;
    for (size_t i = 0; i != size(); i++) {
      if (op(values[i], other.values[i])) {
        std::memset(static_cast<void*>(vector.values + i), 0xFF, sizeof(T));
      } else {
        std::memset(static_cast<void*>(vector.values + i), 0, sizeof(T));
      }
    }
    return vector;
  }

 public:
  Vectorized<T> operator==(const Vectorized<T>& other) const { return binary_pred(other, std::equal_to<T>()); }
  Vectorized<T> operator!=(const Vectorized<T>& other) const { return binary_pred(other, std::not_equal_to<T>()); }
  Vectorized<T> operator>=(const Vectorized<T>& other) const { return binary_pred(other, std::greater_equal<T>()); }
  Vectorized<T> operator<=(const Vectorized<T>& other) const { return binary_pred(other, std::less_equal<T>()); }
  Vectorized<T> operator>(const Vectorized<T>& other) const { return binary_pred(other, std::greater<T>()); }
  Vectorized<T> operator<(const Vectorized<T>& other) const { return binary_pred(other, std::less<T>()); }

 private:
  template <typename Op>
  inline Vectorized<T> binary_pred_bool(const Vectorized<T>& other, Op op) const {
    // 1 if the pred is true, otherwise 0.
    Vectorized<T> vector;
    for (size_t i = 0; i != size(); ++i) {
      vector[i] = static_cast<T>(op(values[i], other.values[i]));
    }
    return vector;
  }

 public:
  Vectorized<T> eq(const Vectorized<T>& other) const { return binary_pred_bool(other, std::equal_to<T>()); }
  Vectorized<T> ne(const Vectorized<T>& other) const { return binary_pred_bool(other, std::not_equal_to<T>()); }
  Vectorized<T> gt(const Vectorized<T>& other) const { return binary_pred_bool(other, std::greater<T>()); }
  Vectorized<T> ge(const Vectorized<T>& other) const { return binary_pred_bool(other, std::greater_equal<T>()); }
  Vectorized<T> lt(const Vectorized<T>& other) const { return binary_pred_bool(other, std::less<T>()); }
  Vectorized<T> le(const Vectorized<T>& other) const { return binary_pred_bool(other, std::less_equal<T>()); }
};

template <class T> Vectorized<T> inline operator+(const Vectorized<T> &a, const Vectorized<T> &b) {
  Vectorized<T> c;
  for (size_t i = 0; i != Vectorized<T>::size(); i++) {
    c[i] = a[i] + b[i];
  }
  return c;
}

template <class T> Vectorized<T> inline operator-(const Vectorized<T> &a, const Vectorized<T> &b) {
  Vectorized<T> c;
  for (size_t i = 0; i != Vectorized<T>::size(); i++) {
    c[i] = a[i] - b[i];
  }
  return c;
}

template <class T> Vectorized<T> inline operator*(const Vectorized<T> &a, const Vectorized<T> &b) {
  Vectorized<T> c;
  for (size_t i = 0; i != Vectorized<T>::size(); i++) {
    c[i] = a[i] * b[i];
  }
  return c;
}

template <class T> Vectorized<T> inline operator/(const Vectorized<T> &a, const Vectorized<T> &b) {
  Vectorized<T> c;
  for (size_t i = 0; i != Vectorized<T>::size(); i++) {
    c[i] = a[i] / b[i];
  }
  return c;
}

template <class T> Vectorized<T> inline operator||(
    const Vectorized<T> &a, const Vectorized<T> &b) {
  Vectorized<T> c;
  for (size_t i = 0; i != Vectorized<T>::size(); i++) {
    c[i] = a[i] || b[i];
  }
  return c;
}

// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
// either input is a NaN.
template <class T>
Vectorized<T> inline maximum(const Vectorized<T> &a, const Vectorized<T> &b) {
  Vectorized<T> c;
  for (size_t i = 0; i != Vectorized<T>::size(); i++) {
    c[i] = (a[i] > b[i]) ? a[i] : b[i];
    if (std::isnan(a[i])) {
      // If either input is NaN, propagate a NaN.
      // NOTE: The case where b[i] was NaN is handled correctly by the naive
      // ternary operator above.
      c[i] = a[i];
    }
  }
  return c;
}

// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
template <class T>
Vectorized<T> inline minimum(const Vectorized<T> &a, const Vectorized<T> &b) {
  Vectorized<T> c;
  for (size_t i = 0; i != Vectorized<T>::size(); i++) {
    c[i] = (a[i] < b[i]) ? a[i] : b[i];
    if (std::isnan(a[i])) {
      // If either input is NaN, propagate a NaN.
      // NOTE: The case where b[i] was NaN is handled correctly by the naive
      // ternary operator above.
      c[i] = a[i];
    }
  }
  return c;
}

template <class T>
Vectorized<T> inline clamp(const Vectorized<T> &a, const Vectorized<T> &min_vec, const Vectorized<T> &max_vec) {
  Vectorized<T> c;
  for (size_t i = 0; i != Vectorized<T>::size(); i++) {
    c[i] = std::min(std::max(a[i], min_vec[i]), max_vec[i]);
  }
  return c;
}

template <class T>
Vectorized<T> inline clamp_max(const Vectorized<T> &a, const Vectorized<T> &max_vec) {
  Vectorized<T> c;
  for (size_t i = 0; i != Vectorized<T>::size(); i++) {
    c[i] = a[i] > max_vec[i] ? max_vec[i] : a[i];
  }
  return c;
}

template <class T>
Vectorized<T> inline clamp_min(const Vectorized<T> &a, const Vectorized<T> &min_vec) {
  Vectorized<T> c;
  for (size_t i = 0; i != Vectorized<T>::size(); i++) {
    c[i] = a[i] < min_vec[i] ? min_vec[i] : a[i];
  }
  return c;
}

struct Vectorizedi;

#if defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)
template <class T, typename Op>
static inline Vectorized<T> bitwise_binary_op(const Vectorized<T> &a, const Vectorized<T> &b, Op op) {
  int_vector buffer;
#if defined(CPU_CAPABILITY_AVX2)
  int_vector a_buffer = _mm256_load_si256(reinterpret_cast<const int_vector*>((const T*)a));
  int_vector b_buffer = _mm256_load_si256(reinterpret_cast<const int_vector*>((const T*)b));
#elif defined(CPU_CAPABILITY_AVX512)
  int_vector a_buffer = _mm512_load_si512(reinterpret_cast<const int_vector*>((const T*)a));
  int_vector b_buffer = _mm512_load_si512(reinterpret_cast<const int_vector*>((const T*)b));
#endif
  buffer = op(a_buffer, b_buffer);
  __at_align__ T results[Vectorized<T>::size()];

#if defined(CPU_CAPABILITY_AVX2)
  _mm256_store_si256(reinterpret_cast<int_vector*>(results), buffer);
#elif defined(CPU_CAPABILITY_AVX512)
  _mm512_store_si512(reinterpret_cast<int_vector*>(results), buffer);
#endif
  return Vectorized<T>::loadu(results);
}

template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
inline Vectorized<T> operator&(const Vectorized<T>& a, const Vectorized<T>& b) {
  // We enclose _mm512_and_si512 or _mm256_and_si256 with lambda because it is always_inline
#if defined(CPU_CAPABILITY_AVX2)
  return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_and_si256(a, b); });
#elif defined(CPU_CAPABILITY_AVX512)
  return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_and_si512(a, b); });
#endif
}
template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
inline Vectorized<T> operator|(const Vectorized<T>& a, const Vectorized<T>& b) {
  // We enclose _mm512_or_si512 or _mm256_or_si256 with lambda because it is always_inline
#if defined(CPU_CAPABILITY_AVX2)
  return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_or_si256(a, b); });
#elif defined(CPU_CAPABILITY_AVX512)
  return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_or_si512(a, b); });
#endif
}
template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
inline Vectorized<T> operator^(const Vectorized<T>& a, const Vectorized<T>& b) {
  // We enclose _mm512_xor_si512 or _mm256_xor_si256 with lambda because it is always_inline
#if defined(CPU_CAPABILITY_AVX2)
  return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_xor_si256(a, b); });
#elif defined(CPU_CAPABILITY_AVX512)
  return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_xor_si512(a, b); });
#endif
}

#else

template <typename T>
auto load(char const* data) -> T {
  T ret;
  std::memcpy(&ret, data, sizeof(ret));
  return ret;
}

template<class T, typename Op>
static inline Vectorized<T> bitwise_binary_op(const Vectorized<T> &a, const Vectorized<T> &b, Op op) {
  static constexpr uint32_t element_no = VECTOR_WIDTH / sizeof(intmax_t);
  __at_align__ intmax_t buffer[element_no];
  static_assert(VECTOR_WIDTH % sizeof(intmax_t) == 0, "VECTOR_WIDTH not a multiple of sizeof(intmax_t)");
  static_assert(sizeof(buffer) == sizeof(Vectorized<T>), "sizeof(buffer) must match sizeof(Vectorized<T>)");
  // We should be using memcpy in order to respect the strict aliasing rule
  // see: https://github.com/pytorch/pytorch/issues/66119
  // Using char* is defined in the C11 standard 6.5 Expression paragraph 7
  // (http://www.open-std.org/jtc1/sc22/wg14/www/docs/n1570.pdf)
  const auto* a_data = a.as_bytes();
  const auto* b_data = b.as_bytes();
  // load each intmax_t chunk and process; increase pointers by sizeof(intmax_t)
  for (auto& out : buffer) {
    out = op(load<intmax_t>(a_data), load<intmax_t>(b_data));
    a_data += sizeof(intmax_t);
    b_data += sizeof(intmax_t);
  }
  assert(a_data == a.as_bytes() + sizeof(a));
  assert(b_data == b.as_bytes() + sizeof(b));
  return Vectorized<T>::loadu(buffer);
}

template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
inline Vectorized<T> operator&(const Vectorized<T>& a, const Vectorized<T>& b) {
  return bitwise_binary_op(a, b, std::bit_and<intmax_t>());
}
template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
inline Vectorized<T> operator|(const Vectorized<T>& a, const Vectorized<T>& b) {
  return bitwise_binary_op(a, b, std::bit_or<intmax_t>());
}
template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
inline Vectorized<T> operator^(const Vectorized<T>& a, const Vectorized<T>& b) {
  return bitwise_binary_op(a, b, std::bit_xor<intmax_t>());
}

#endif // defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)

template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
inline Vectorized<T> operator~(const Vectorized<T>& a) {
  Vectorized<T> ones; // All bits are 1
  memset((T*) ones, 0xFF, VECTOR_WIDTH);
  return a ^ ones;
}

template <class T> Vectorized<T> inline operator<<(const Vectorized<T> &a, const Vectorized<T> &b) {
  constexpr T max_shift = sizeof(T) * CHAR_BIT;
  Vectorized<T> c;
  for (size_t i = 0; i != Vectorized<T>::size(); i++) {
    T shift = b[i];
    if ((static_cast<std::make_signed_t<T>>(shift) < 0) || (shift >= max_shift)) {
      c[i] = 0;
    } else {
      c[i] = static_cast<std::make_unsigned_t<T>>(a[i]) << shift;
    }
  }
  return c;
}

template <class T> Vectorized<T> inline operator>>(const Vectorized<T> &a, const Vectorized<T> &b) {
  // right shift value to retain sign bit for signed and no bits for unsigned
  constexpr T max_shift = sizeof(T) * CHAR_BIT - std::is_signed_v<T>;
  Vectorized<T> c;
  for (size_t i = 0; i != Vectorized<T>::size(); i++) {
    T shift = b[i];
    if ((static_cast<std::make_signed_t<T>>(shift) < 0) || (shift >= max_shift)) {
      c[i] = a[i] >> max_shift;
    } else {
      c[i] = a[i] >> shift;
    }
  }
  return c;
}

template <typename T>
inline Vectorized<T>& operator += (Vectorized<T>& a, const Vectorized<T>& b) {
  a = a + b;
  return a;
}
template <typename T>
inline Vectorized<T>& operator -= (Vectorized<T>& a, const Vectorized<T>& b) {
  a = a - b;
  return a;
}
template <typename T>
inline Vectorized<T>& operator /= (Vectorized<T>& a, const Vectorized<T>& b) {
  a = a / b;
  return a;
}
template <typename T>
inline Vectorized<T>& operator %= (Vectorized<T>& a, const Vectorized<T>& b) {
  a = a % b;
  return a;
}
template <typename T>
inline Vectorized<T>& operator *= (Vectorized<T>& a, const Vectorized<T>& b) {
  a = a * b;
  return a;
}

template <typename T>
inline Vectorized<T>& operator <<= (Vectorized<T>& a, const Vectorized<T>& b) {
  a = a << b;
  return a;
}

template <typename T>
inline Vectorized<T>& operator >>= (Vectorized<T>& a, const Vectorized<T>& b) {
  a = a >> b;
  return a;
}

template <typename T>
inline Vectorized<T> fmadd(const Vectorized<T>& a, const Vectorized<T>& b, const Vectorized<T>& c) {
  return a * b + c;
}

template <typename T>
inline Vectorized<T> fmsub(const Vectorized<T>& a, const Vectorized<T>& b, const Vectorized<T>& c) {
  return a * b - c;
}

template <int64_t scale = 1, typename T = void>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<T>>
inline gather(T const* base_addr, const Vectorized<int_same_size_t<T>>& vindex) {
  static constexpr int size = Vectorized<T>::size();
  int_same_size_t<T> index_arr[size];
  vindex.store(static_cast<void*>(index_arr));
  T buffer[size];
  for (size_t i = 0; i < size; ++i) {
    buffer[i] = base_addr[index_arr[i] * scale / sizeof(T)];
  }
  return Vectorized<T>::loadu(static_cast<void*>(buffer));
}
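// gather example (illustrative): indices are scaled in bytes, so with T = float
// and scale = 4, gather<4>(ptr, vindex) loads ptr[vindex[i]] into lane i.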

template <int64_t scale = 1, typename T = void>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<T>>
inline mask_gather(const Vectorized<T>& src, T const* base_addr,
                   const Vectorized<int_same_size_t<T>>& vindex, Vectorized<T>& mask) {
  static constexpr int size = Vectorized<T>::size();
  T src_arr[size];
  int_same_size_t<T> mask_arr[size]; // use int type so we can logical and
  int_same_size_t<T> index_arr[size];
  src.store(static_cast<void*>(src_arr));
  mask.store(static_cast<void*>(mask_arr));
  vindex.store(static_cast<void*>(index_arr));
  T buffer[size];
  for (size_t i = 0; i < size; ++i) {
    if (mask_arr[i] & 0x01) { // check lowest bit
      buffer[i] = base_addr[index_arr[i] * scale / sizeof(T)];
    } else {
      buffer[i] = src_arr[i];
    }
  }
  mask = Vectorized<T>(); // "zero out" mask
  return Vectorized<T>::loadu(static_cast<void*>(buffer));
}

// Cast a given vector to another type without changing the bits representation.
// So a Vectorized<double> of 512 bits containing all ones can be cast to a
// Vectorized<int64_t> of 512 bits containing all ones (i.e., eight negative 1s).
// A Vec<double> of 256 bits containing all ones can be cast to a
// Vec<int64_t> of 256 bits containing all ones (i.e., four negative 1s).
// There is a struct here because we don't have static_if and I can't
// partially specialize a templated function.
template<typename dst_t, typename src_t>
struct CastImpl {
  static inline Vectorized<dst_t> apply(const Vectorized<src_t>& src) {
    src_t src_arr[Vectorized<src_t>::size()];
    src.store(static_cast<void*>(src_arr));
    return Vectorized<dst_t>::loadu(static_cast<const void*>(src_arr));
  }
};

template<typename scalar_t>
struct CastImpl<scalar_t, scalar_t> {
  static inline Vectorized<scalar_t> apply(const Vectorized<scalar_t>& src) {
    return src;
  }
};

template<typename dst_t, typename src_t>
inline Vectorized<dst_t> cast(const Vectorized<src_t>& src) {
  return CastImpl<dst_t, src_t>::apply(src);
}
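// cast example (illustrative): cast<int64_t>(Vectorized<double>(1.0)) keeps the
// 8-byte pattern of 1.0 (0x3FF0000000000000) in each lane rather than converting
// the value numerically.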

template <typename T>
inline Vectorized<int_same_size_t<T>> convert_to_int_of_same_size(const Vectorized<T>& src) {
  static constexpr int size = Vectorized<T>::size();
  T src_arr[size];
  src.store(static_cast<void*>(src_arr));
  int_same_size_t<T> buffer[size];
  for (size_t i = 0; i < size; ++i) {
    buffer[i] = static_cast<int_same_size_t<T>>(src_arr[i]);
  }
  return Vectorized<int_same_size_t<T>>::loadu(static_cast<void*>(buffer));
}
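// In contrast to cast<>, convert_to_int_of_same_size converts numerically:
// a Vectorized<float> holding {1.5f, -2.0f, ...} becomes a Vectorized<int32_t>
// holding {1, -2, ...}.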

// Example inputs for AVX512:
//   a = Vectorized<float>{a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7}
//   b = Vectorized<float>{a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15}
// returns:
//   Vectorized<float>{a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15}
//   Vectorized<float>{b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15}
// Example inputs for AVX2:
//   a = Vectorized<float>{a0, b0, a1, b1, a2, b2, a3, b3}
//   b = Vectorized<float>{a4, b4, a5, b5, a6, b6, a7, b7}
// returns:
//   Vectorized<float>{a0, a1, a2, a3, a4, a5, a6, a7}
//   Vectorized<float>{b0, b1, b2, b3, b4, b5, b6, b7}
template <typename T>
inline std::enable_if_t<Vectorized<T>::size() % 2 == 0, std::pair<Vectorized<T>, Vectorized<T>>>
deinterleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
  static constexpr int size = Vectorized<T>::size();
  static constexpr int half_size = size / 2;
  T a_arr[size];
  T b_arr[size];
  T buffer1[size];
  T buffer2[size];
  a.store(static_cast<void*>(a_arr));
  b.store(static_cast<void*>(b_arr));
  for (size_t i = 0; i < half_size; ++i) {
    buffer1[i] = a_arr[i * 2];
    buffer1[half_size + i] = b_arr[i * 2];
    buffer2[i] = a_arr[i * 2 + 1];
    buffer2[half_size + i] = b_arr[i * 2 + 1];
  }
  return std::make_pair(Vectorized<T>::loadu(static_cast<void*>(buffer1)),
                        Vectorized<T>::loadu(static_cast<void*>(buffer2)));
}

// inverse operation of deinterleave2
// Example inputs for AVX512:
//   a = Vectorized<float>{a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15}
//   b = Vectorized<float>{b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15}
// returns:
//   Vectorized<float>{a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7}
//   Vectorized<float>{a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15}
// Example inputs for AVX2:
//   a = Vectorized<float>{a0, a1, a2, a3, a4, a5, a6, a7}
//   b = Vectorized<float>{b0, b1, b2, b3, b4, b5, b6, b7}
// returns:
//   Vectorized<float>{a0, b0, a1, b1, a2, b2, a3, b3}
//   Vectorized<float>{a4, b4, a5, b5, a6, b6, a7, b7}
template <typename T>
inline std::enable_if_t<Vectorized<T>::size() % 2 == 0, std::pair<Vectorized<T>, Vectorized<T>>>
interleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
  static constexpr int size = Vectorized<T>::size();
  static constexpr int half_size = size / 2;
  T a_arr[size];
  T b_arr[size];
  T buffer1[size];
  T buffer2[size];
  a.store(static_cast<void*>(a_arr));
  b.store(static_cast<void*>(b_arr));
  for (size_t i = 0; i < half_size; ++i) {
    buffer1[i * 2] = a_arr[i];
    buffer1[i * 2 + 1] = b_arr[i];
    buffer2[i * 2] = a_arr[half_size + i];
    buffer2[i * 2 + 1] = b_arr[half_size + i];
  }
  return std::make_pair(Vectorized<T>::loadu(static_cast<void*>(buffer1)),
                        Vectorized<T>::loadu(static_cast<void*>(buffer2)));
}

template <typename src_T, typename dst_T>
inline void convert(const src_T *src, dst_T *dst, int64_t n) {
#ifndef _MSC_VER
# pragma unroll
#endif
  for (int64_t i = 0; i < n; ++i) {
    (void)i; // Suppress unused variable warning
    *dst = static_cast<dst_T>(*src);
    src++;
    dst++;
  }
}

template <typename T>
inline Vectorized<T> flip(const Vectorized<T> & data) {
  static constexpr int size = Vectorized<T>::size();
  T output[size];
  T buffer[size];
  data.store(static_cast<void*>(buffer));
  for (size_t i = 0; i < size; ++i) {
    output[i] = buffer[size - i - 1];
  }
  return Vectorized<T>::loadu(static_cast<void*>(output));
}

// Transpose the `src` buffer of type `T` and size (M,N) into the `dst` buffer. `ld_src` is the leading
// dimension of `src` and `ld_dst` is the leading dimension of `dst`.
template <typename T, int M, int N>
inline void transpose_mxn(const T* src, int64_t ld_src, T* dst, int64_t ld_dst) {
  for (size_t i = 0; i < M; i++) {
    for (int j = 0; j < N; j++) {
      dst[j*ld_dst + i] = src[i*ld_src + j];
    }
  }
}
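// transpose_mxn example (illustrative): transpose_mxn<float, 2, 3>(src, 3, dst, 2)
// rewrites the row-major 2x3 block {a, b, c, d, e, f} as the 3x2 block
// {a, d, b, e, c, f}.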

} // namespace CPU_CAPABILITY

} // namespace vec
} // namespace executorch