1 #include <cmath>
2 #include <cstring>
3 #include <limits>
4 #include <type_traits>
5 #include <utility>
6 #if defined(__clang__)
7 #include <sleef.h>
8 #elif defined(__GNUC__) || defined(__GNUG__)
9 #include <sleef.h>
10 #include <vecintrin.h>
11 #endif
12 #include <ATen/cpu/vec/intrinsics.h>
13 #include <ATen/cpu/vec/vec_base.h>
14 #include <c10/util/complex.h>
15
16 namespace at {
17 namespace vec {
18
19 // See Note [CPU_CAPABILITY namespace]
20 inline namespace CPU_CAPABILITY {
21
// Compile-time predicate: true for the element types that have a
// Vectorized<T> implementation in this header.
template <typename T>
constexpr bool is_zarch_implemented() {
  // Floating-point element types.
  constexpr bool floating =
      std::is_same<T, float>::value || std::is_same<T, double>::value;
  // Integer element types (note: uint32_t and uint64_t are not listed).
  constexpr bool integral =
      std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value ||
      std::is_same<T, int16_t>::value || std::is_same<T, uint16_t>::value ||
      std::is_same<T, int32_t>::value || std::is_same<T, int64_t>::value;
  return floating || integral;
}
30
31 template <typename T>
is_zarch_implemented_quant()32 constexpr bool is_zarch_implemented_quant() {
33 return (
34 std::is_same<T, c10::qint32>::value ||
35 std::is_same<T, c10::qint8>::value ||
36 std::is_same<T, c10::quint8>::value);
37 }
38
39 template <typename T>
is_zarch_implemented_complex()40 constexpr bool is_zarch_implemented_complex() {
41 return std::is_same<T, c10::complex<float>>::value ||
42 std::is_same<T, c10::complex<double>>::value;
43 }
44
// Byte offsets of the first and second 16-byte halves, passed to
// vec_xl / vec_xst when loading or storing a two-half Vectorized.
constexpr int offset0{0};
constexpr int offset16{16};
47
// VecBinaryType<N>::type is a 16-byte unsigned vector whose lanes are N
// bytes wide. It is used as the bitwise/mask counterpart of a data vector
// (e.g. as the selector argument of vec_sel).
// The primary template is the fallback for widths without a specialization.
template <int N>
struct VecBinaryType {
  using type __attribute__((vector_size(16))) = uintmax_t;
};

// 1-byte lanes: 16 x unsigned char.
template <>
struct VecBinaryType<1> {
  using type = __attribute__((vector_size(16))) unsigned char;
};

// 2-byte lanes: 8 x unsigned short.
template <>
struct VecBinaryType<2> {
  using type = __attribute__((vector_size(16))) unsigned short;
};

// 4-byte lanes: 4 x unsigned int.
template <>
struct VecBinaryType<4> {
  using type = __attribute__((vector_size(16))) unsigned int;
};

// 8-byte lanes: 2 x unsigned long long.
template <>
struct VecBinaryType<8> {
  using type = __attribute__((vector_size(16))) unsigned long long;
};
72
73 template <typename T>
74 struct VecInnerType {
75 using Type __attribute__((vector_size(16))) = T;
76 using BinaryType = typename VecBinaryType<sizeof(T)>::type;
77 using ElementType = T;
78 static constexpr int size = 16 / sizeof(T);
79 };
80
81 // define for int64_t properly for load
82 template <>
83 struct VecInnerType<int64_t> {
84 using Type = __attribute__((vector_size(16))) signed long long;
85 using ElementType = signed long long;
86 using BinaryType = typename VecBinaryType<sizeof(signed long long)>::type;
87 static constexpr int size = 16 / sizeof(signed long long);
88 };
89
90 template <typename T>
91 using ZSimdVect = typename VecInnerType<T>::Type;
92 template <typename T>
93 using ZSimdVectBinary = typename VecInnerType<T>::BinaryType;
94 template <typename T>
95 using ZSimdVectElement = typename VecInnerType<T>::ElementType;
96
// Classifies a blend mask by how it covers the two halves of a vector.
// `half1` / `half2` are the mask bits belonging to the first / second half.
// Bits outside half1 | half2 are ignored. Returned codes:
//   0: no bits            -> (a._vec0, a._vec1)
//   1: all bits           -> (b._vec0, b._vec1)
//   2: exactly half1      -> (b._vec0, a._vec1)
//   3: exactly half2      -> (a._vec0, b._vec1)
//   4: part of half1 only -> (*_vec0,  a._vec1)
//   5: part of half1, all of half2 -> (*_vec0, b._vec1)
//   6: none of half1, part of half2 -> (a._vec0, *_vec1)
//   7: all of half1, part of half2  -> (b._vec0, *_vec1)
//   8: part of both halves          -> (*_vec0,  *_vec1)
// where * means the half needs a per-lane select.
constexpr int blendChoiceInner(
    const uint64_t mask,
    const uint64_t half1 = 0xF,
    const uint64_t half2 = 0xF0) {
  const uint64_t full = half1 | half2;
  const uint64_t m = mask & full;
  const uint64_t low = m & half1;
  const uint64_t high = m & half2;

  if (m == 0) return 0;
  if (m == full) return 1;
  if (m == half1) return 2;
  if (m == half2) return 3;
  // m is non-zero here, so this is "some but not all of half1".
  if (m < half1) return 4;
  if (high == half2) return 5;
  if (m > half1) {
    if (low == 0) return 6;
    if (low == half1) return 7;
  }
  return 8;
}
132
133 // it can be used to emulate blend faster
134 template <int Z>
135 constexpr int blendChoice(const uint64_t mask) {
136 static_assert(Z < 1 || Z > 8, "not implemented");
137 return blendChoiceInner(mask);
138 }
139
140 template <>
141 constexpr int blendChoice<1>(const uint64_t mask) {
142 return blendChoiceInner(mask, 0x0000FFFF, 0xFFFF0000);
143 }
144
145 template <>
146 constexpr int blendChoice<2>(const uint64_t mask) {
147 return blendChoiceInner(mask, 0x00FF, 0xFF00);
148 }
149
150 template <>
151 constexpr int blendChoice<4>(const uint64_t mask) {
152 return blendChoiceInner(mask, 0xF, 0xF0);
153 }
154
155 template <>
156 constexpr int blendChoice<8>(const uint64_t mask) {
157 // clamp it 0 and 0xF
158 return blendChoiceInner(mask, 0x3, 0xC);
159 }
160
161 template <int N>
162 constexpr auto GetMask1(const uint64_t mask) {
163 return typename VecBinaryType<N>::type{};
164 }
165
166 template <int N>
167 constexpr auto GetMask2(const uint64_t mask) {
168 return typename VecBinaryType<N>::type{};
169 }
170
171 template <>
172 constexpr auto GetMask1<1>(const uint64_t mask) {
173 constexpr uint8_t t = (int)0xFF;
174 uint8_t g0 = (mask & 1) * t;
175 uint8_t g1 = ((mask & 2) >> 1) * t;
176 uint8_t g2 = ((mask & 4) >> 2) * t;
177 uint8_t g3 = ((mask & 8) >> 3) * t;
178 uint8_t g4 = ((mask & 16) >> 4) * t;
179 uint8_t g5 = ((mask & 32) >> 5) * t;
180 uint8_t g6 = ((mask & 64) >> 6) * t;
181 uint8_t g7 = ((mask & 128) >> 7) * t;
182 uint8_t g8 = ((mask & 256) >> 8) * t;
183 uint8_t g9 = ((mask & 512) >> 9) * t;
184 uint8_t g10 = ((mask & 1024) >> 10) * t;
185 uint8_t g11 = ((mask & 2048) >> 11) * t;
186 uint8_t g12 = ((mask & 4096) >> 12) * t;
187 uint8_t g13 = ((mask & 8192) >> 13) * t;
188 uint8_t g14 = ((mask & 16384) >> 14) * t;
189 uint8_t g15 = ((mask & 32768) >> 15) * t;
190 return (typename VecBinaryType<1>::type){
191 g0, g1, g2, g3, g4, g5, g6, g7, g8, g9, g10, g11, g12, g13, g14, g15};
192 }
193
194 template <>
195 constexpr auto GetMask2<1>(const uint64_t mask) {
196 uint64_t mask2 = (mask & 0xFFFFFFFF) >> 16;
197 return GetMask1<1>(mask2);
198 }
199
200 template <>
201 constexpr auto GetMask1<2>(const uint64_t mask) {
202 constexpr uint16_t t = (int)0xFFFF;
203 uint16_t g0 = (mask & 1) * t;
204 uint16_t g1 = ((mask & 2) >> 1) * t;
205 uint16_t g2 = ((mask & 4) >> 2) * t;
206 uint16_t g3 = ((mask & 8) >> 3) * t;
207 uint16_t g4 = ((mask & 16) >> 4) * t;
208 uint16_t g5 = ((mask & 32) >> 5) * t;
209 uint16_t g6 = ((mask & 64) >> 6) * t;
210 uint16_t g7 = ((mask & 128) >> 7) * t;
211 return (typename VecBinaryType<2>::type){g0, g1, g2, g3, g4, g5, g6, g7};
212 }
213
214 template <>
215 constexpr auto GetMask2<2>(const uint64_t mask) {
216 uint64_t mask2 = (mask & 0xFFFF) >> 8;
217 return GetMask1<2>(mask2);
218 }
219
220 template <>
221 constexpr auto GetMask1<4>(const uint64_t mask) {
222 uint32_t g0 = (mask & 1) * 0xffffffff;
223 uint32_t g1 = ((mask & 2) >> 1) * 0xffffffff;
224 uint32_t g2 = ((mask & 4) >> 2) * 0xffffffff;
225 uint32_t g3 = ((mask & 8) >> 3) * 0xffffffff;
226 return (typename VecBinaryType<4>::type){g0, g1, g2, g3};
227 }
228
229 template <>
230 constexpr auto GetMask2<4>(const uint64_t mask) {
231 uint64_t mask2 = (mask & 0xFF) >> 4;
232 return GetMask1<4>(mask2);
233 }
234
235 template <>
236 constexpr auto GetMask1<8>(const uint64_t mask) {
237 uint64_t g0 = (mask & 1) * 0xffffffffffffffff;
238 uint64_t g1 = ((mask & 2) >> 1) * 0xffffffffffffffff;
239 return (typename VecBinaryType<8>::type){g0, g1};
240 }
241
242 template <>
243 constexpr auto GetMask2<8>(const uint64_t mask) {
244 uint64_t mask2 = (mask & 0xF) >> 2;
245 return GetMask1<8>(mask2);
246 }
247
// Expands a blend mask over complex lanes into a mask over their scalar
// components: each set bit k becomes the adjacent bit pair 2k, 2k+1.
// Z is the byte size of the complex element; sizes without a
// specialization yield 0.
template <int Z>
constexpr int maskForComplex(uint32_t mask) {
  return 0;
}

// 8-byte complex elements (presumably c10::complex<float>): mask bits 0..3
// expand to bits 0..7.
template <>
constexpr int maskForComplex<8>(uint32_t mask) {
  int expanded = 0;
  for (int lane = 0; lane < 4; ++lane) {
    if ((mask >> lane) & 1u) {
      expanded |= 3 << (2 * lane);
    }
  }
  return expanded;
}

// 16-byte complex elements (presumably c10::complex<double>): mask bits
// 0..1 expand to bits 0..3.
template <>
constexpr int maskForComplex<16>(uint32_t mask) {
  int expanded = 0;
  for (int lane = 0; lane < 2; ++lane) {
    if ((mask >> lane) & 1u) {
      expanded |= 3 << (2 * lane);
    }
  }
  return expanded;
}
278
279 template <typename T = c10::complex<float>>
280 constexpr int blend_choice() {
281 return 0xAA;
282 }
283
284 template <>
285 constexpr int blend_choice<c10::complex<double>>() {
286 return 0x0A;
287 }
288
// Returns an int64 mask with the low `x` bits set, e.g. allbitset(3) == 0b111.
// set_inner() uses it to build the compile-time blend mask that takes the
// first `x` lanes from the second operand.
// The value is computed in unsigned arithmetic so that every x in [0, 63] is
// well defined (the signed form `(1 << 63) - 1` overflows). Out-of-range
// inputs are clamped: x <= 0 yields 0, and x >= 64 yields all 64 bits set (-1).
constexpr int64_t allbitset(int16_t x) {
  if (x <= 0) {
    return 0;
  }
  if (x >= 64) {
    return -1;
  }
  return static_cast<int64_t>((uint64_t{1} << x) - 1);
}
293
294 namespace { /* unnamed namespace */
295
296 ZSimdVect<float> vec_mergee(ZSimdVect<float> x, ZSimdVect<float> y) {
297 constexpr ZSimdVectBinary<uint8_t> mergee_mask{
298 0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 24, 25, 26, 27};
299 return vec_perm(x, y, mergee_mask);
300 }
301
302 ZSimdVect<double> vec_mergee(ZSimdVect<double> x, ZSimdVect<double> y) {
303 return vec_mergeh(x, y);
304 }
305
306 ZSimdVect<float> vec_mergeo(ZSimdVect<float> x, ZSimdVect<float> y) {
307 constexpr ZSimdVectBinary<uint8_t> mergeo_mask{
308 4, 5, 6, 7, 20, 21, 22, 23, 12, 13, 14, 15, 28, 29, 30, 31};
309 return vec_perm(x, y, mergeo_mask);
310 }
311
312 ZSimdVect<double> vec_mergeo(ZSimdVect<double> x, ZSimdVect<double> y) {
313 return vec_mergel(x, y);
314 }
315
316 } /* unnamed namespace */
317
318 //
319 template <typename T>
320 constexpr auto GetBpermZeroMask() {
321 return ZSimdVectBinary<uint8_t>{
322 128,
323 128,
324 128,
325 128,
326 128,
327 128,
328 128,
329 128,
330 128,
331 128,
332 128,
333 128,
334 96,
335 64,
336 32,
337 0};
338 }
339
340 template <>
341 constexpr auto GetBpermZeroMask<double>() {
342 return ZSimdVectBinary<uint8_t>{
343 128,
344 128,
345 128,
346 128,
347 128,
348 128,
349 128,
350 128,
351 128,
352 128,
353 128,
354 128,
355 128,
356 128,
357 64,
358 0};
359 }
360
361 constexpr auto GetSwapMaskFloat() {
362 return ZSimdVectBinary<uint8_t>{
363 4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11};
364 }
365
366 template <typename T>
367 struct Vectorized<T, std::enable_if_t<is_zarch_implemented<T>()>> {
368 public:
369 using value_type = T;
370 using vtype = ZSimdVect<T>;
371 using vmaskType = ZSimdVectBinary<T>;
372 using size_type = int;
373 // because of gcc inconsistency for int64_t we are obliged to use this, not
374 // value_type
375 using ElementType = ZSimdVectElement<T>;
376 using vinner_data = std::pair<vtype, vtype>;
377
378 private:
379 vtype _vec0;
380 vtype _vec1;
381
382 public:
383 static constexpr size_type size() {
384 return VECTOR_WIDTH / sizeof(ElementType);
385 }
386 Vectorized() {}
387
388 C10_ALWAYS_INLINE Vectorized(vtype v) : _vec0{v}, _vec1{v} {}
389 C10_ALWAYS_INLINE Vectorized(const vinner_data &v) : _vec0{v.first}, _vec1{v.second} {}
390 C10_ALWAYS_INLINE Vectorized(vtype v1, vtype v2) : _vec0{v1}, _vec1{v2} {}
391 C10_ALWAYS_INLINE Vectorized(T s)
392 : _vec0{vec_splats((ElementType)s)}, _vec1{vec_splats((ElementType)s)} {}
393
394 template <typename U, typename DUMMY = void>
395 struct LoaduHelper {
396 static Vectorized<T> C10_ALWAYS_INLINE
397 loadu(const U* ptr, int count = size()) {
398 __at_align__ ElementType tmp_values[size()] = {};
399 std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(ElementType));
400
401 return {
402 vec_xl(offset0, &(tmp_values[0])),
403 vec_xl(offset16, &(tmp_values[0]))};
404 }
405 };
406
407 template <typename DUMMY>
408 struct LoaduHelper<ElementType, DUMMY> {
409 static Vectorized<T> C10_ALWAYS_INLINE
410 loadu(const ElementType* ptr, int count = size()) {
411 if (count == size()) {
412 return {
413 vec_xl(offset0, ptr),
414 vec_xl(offset16, ptr)};
415 }
416
417 __at_align__ ElementType tmp_values[size()] = {};
418 std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(ElementType));
419
420 return {
421 vec_xl(offset0, &(tmp_values[0])),
422 vec_xl(offset16, &(tmp_values[0]))};
423 }
424 };
425
426 template <typename U>
427 static Vectorized<T> C10_ALWAYS_INLINE
428 loadu(const U* ptr, int count = size()) {
429 return LoaduHelper<U>::loadu(ptr, count);
430 }
431
432 template <typename U>
433 static Vectorized<T> C10_ALWAYS_INLINE
434 loadu_one_fourth(const U* ptr) {
435 // load only first 8 bytes
436 // only intended to be used with uint8_t
437 return loadu(ptr, 8 / sizeof(ElementType));
438 }
439
440 template <typename U, typename DUMMY = void>
441 struct StoreHelper {
442 static void C10_ALWAYS_INLINE store(const Vectorized<T> &vec, U* ptr, int count = size()) {
443 if (count > 0) {
444 __at_align__ ElementType tmp_values[size()];
445 vec_xst(vec._vec0, offset0, &(tmp_values[0]));
446 vec_xst(vec._vec1, offset16, &(tmp_values[0]));
447 std::memcpy(
448 ptr, tmp_values, std::min(count, size()) * sizeof(ElementType));
449 }
450 }
451 };
452
453 template <typename DUMMY>
454 struct StoreHelper<ElementType, DUMMY> {
455 static void C10_ALWAYS_INLINE store(const Vectorized<T> &vec, ElementType* ptr, int count = size()) {
456 if (count == size()) {
457 vec_xst(vec._vec0, offset0, ptr);
458 vec_xst(vec._vec1, offset16, ptr);
459 } else if (count > 0) {
460 __at_align__ ElementType tmp_values[size()];
461 vec_xst(vec._vec0, offset0, &(tmp_values[0]));
462 vec_xst(vec._vec1, offset16, &(tmp_values[0]));
463 std::memcpy(
464 ptr, tmp_values, std::min(count, size()) * sizeof(ElementType));
465 }
466 }
467 };
468
469 template <typename U>
470 void C10_ALWAYS_INLINE store(U* ptr, int count = size()) const {
471 return StoreHelper<U>::store(*this, ptr, count);
472 }
473
474 C10_ALWAYS_INLINE const vtype& vec0() const {
475 return _vec0;
476 }
477
478 C10_ALWAYS_INLINE const vtype& vec1() const {
479 return _vec1;
480 }
481
482 C10_ALWAYS_INLINE vinner_data data() const {
483 return std::make_pair<>(_vec0, _vec1);
484 }
485
486 C10_ALWAYS_INLINE operator vinner_data() const {
487 return data();
488 }
489
490 C10_ALWAYS_INLINE const vmaskType vecb0() const {
491 return (vmaskType)_vec0;
492 }
493 C10_ALWAYS_INLINE const vmaskType vecb1() const {
494 return (vmaskType)_vec1;
495 }
496
497 static Vectorized<T> C10_ALWAYS_INLINE blendv(
498 const Vectorized<T>& a,
499 const Vectorized<T>& b,
500 const Vectorized<T>& mask) {
501 return {
502 vec_sel(a._vec0, b._vec0, mask.vecb0()),
503 vec_sel(a._vec1, b._vec1, mask.vecb1())};
504 }
505
506 template <typename U = T, std::enable_if_t<(sizeof(U) == 8), int> = 0>
507 C10_ALWAYS_INLINE Vectorized(T s1, T s2, T s3, T s4)
508 : _vec0{s1, s2}, _vec1{s3, s4} {}
509
510 template <typename U = T, std::enable_if_t<(sizeof(U) == 4), int> = 0>
511 C10_ALWAYS_INLINE Vectorized(T s1, T s2, T s3, T s4, T s5, T s6, T s7, T s8)
512 : _vec0{s1, s2, s3, s4}, _vec1{s5, s6, s7, s8} {}
513
514 template <typename U = T, std::enable_if_t<(sizeof(U) == 2), int> = 0>
515 C10_ALWAYS_INLINE Vectorized(
516 T s1,
517 T s2,
518 T s3,
519 T s4,
520 T s5,
521 T s6,
522 T s7,
523 T s8,
524 T s9,
525 T s10,
526 T s11,
527 T s12,
528 T s13,
529 T s14,
530 T s15,
531 T s16)
532 : _vec0{s1, s2, s3, s4, s5, s6, s7, s8},
533 _vec1{s9, s10, s11, s12, s13, s14, s15, s16} {}
534
535 template <typename U = T, std::enable_if_t<(sizeof(U) == 1), int> = 0>
536 C10_ALWAYS_INLINE Vectorized(
537 T s1,
538 T s2,
539 T s3,
540 T s4,
541 T s5,
542 T s6,
543 T s7,
544 T s8,
545 T s9,
546 T s10,
547 T s11,
548 T s12,
549 T s13,
550 T s14,
551 T s15,
552 T s16,
553 T s17,
554 T s18,
555 T s19,
556 T s20,
557 T s21,
558 T s22,
559 T s23,
560 T s24,
561 T s25,
562 T s26,
563 T s27,
564 T s28,
565 T s29,
566 T s30,
567 T s31,
568 T s32)
569 : _vec0{s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14, s15, s16},
570 _vec1{
571 s17,
572 s18,
573 s19,
574 s20,
575 s21,
576 s22,
577 s23,
578 s24,
579 s25,
580 s26,
581 s27,
582 s28,
583 s29,
584 s30,
585 s31,
586 s32} {}
587
588 template <typename step_t, typename U = T>
589 static std::enable_if_t<sizeof(U) == 8, Vectorized<T>> arange(
590 T base = 0,
591 step_t step = static_cast<step_t>(1)) {
592 return Vectorized<T>(base, base + step, base + 2 * step, base + 3 * step);
593 }
594
595 template <typename step_t, typename U = T>
596 static std::enable_if_t<sizeof(U) == 4, Vectorized<T>> arange(
597 T base = 0,
598 step_t step = static_cast<step_t>(1)) {
599 return Vectorized<T>(
600 base,
601 base + step,
602 base + 2 * step,
603 base + 3 * step,
604 base + 4 * step,
605 base + 5 * step,
606 base + 6 * step,
607 base + 7 * step);
608 }
609
610 template <typename step_t, typename U = T>
611 static std::enable_if_t<sizeof(U) == 2, Vectorized<T>> arange(
612 T base = 0,
613 step_t step = static_cast<step_t>(1)) {
614 return Vectorized<T>(
615 base,
616 base + step,
617 base + 2 * step,
618 base + 3 * step,
619 base + 4 * step,
620 base + 5 * step,
621 base + 6 * step,
622 base + 7 * step,
623 base + 8 * step,
624 base + 9 * step,
625 base + 10 * step,
626 base + 11 * step,
627 base + 12 * step,
628 base + 13 * step,
629 base + 14 * step,
630 base + 15 * step);
631 }
632
633 template <typename step_t, typename U = T>
634 static std::enable_if_t<sizeof(U) == 1, Vectorized<T>> arange(
635 T base = 0,
636 step_t step = static_cast<step_t>(1)) {
637 return Vectorized<T>(
638 base,
639 base + step,
640 base + 2 * step,
641 base + 3 * step,
642 base + 4 * step,
643 base + 5 * step,
644 base + 6 * step,
645 base + 7 * step,
646 base + 8 * step,
647 base + 9 * step,
648 base + 10 * step,
649 base + 11 * step,
650 base + 12 * step,
651 base + 13 * step,
652 base + 14 * step,
653 base + 15 * step,
654 base + 16 * step,
655 base + 17 * step,
656 base + 18 * step,
657 base + 19 * step,
658 base + 20 * step,
659 base + 21 * step,
660 base + 22 * step,
661 base + 23 * step,
662 base + 24 * step,
663 base + 25 * step,
664 base + 26 * step,
665 base + 27 * step,
666 base + 28 * step,
667 base + 29 * step,
668 base + 30 * step,
669 base + 31 * step);
670 }
671
672 // blend section
673 template <int64_t mask>
674 static std::enable_if_t<blendChoice<sizeof(T)>(mask) == 0, Vectorized<T>>
675 C10_ALWAYS_INLINE blend(const Vectorized<T>& a, const Vectorized<T>& b) {
676 return a;
677 }
678
679 template <int64_t mask>
680 static std::enable_if_t<blendChoice<sizeof(T)>(mask) == 1, Vectorized<T>>
681 C10_ALWAYS_INLINE blend(const Vectorized<T>& a, const Vectorized<T>& b) {
682 return b;
683 }
684
685 template <int64_t mask>
686 static std::enable_if_t<blendChoice<sizeof(T)>(mask) == 2, Vectorized<T>>
687 C10_ALWAYS_INLINE blend(const Vectorized<T>& a, const Vectorized<T>& b) {
688 return {b._vec0, a._vec1};
689 }
690
691 template <int64_t mask>
692 static std::enable_if_t<blendChoice<sizeof(T)>(mask) == 3, Vectorized<T>>
693 C10_ALWAYS_INLINE blend(const Vectorized<T>& a, const Vectorized<T>& b) {
694 return {a._vec0, b._vec1};
695 }
696
697 template <int64_t mask>
698 static std::enable_if_t<blendChoice<sizeof(T)>(mask) == 4, Vectorized<T>>
699 C10_ALWAYS_INLINE blend(const Vectorized<T>& a, const Vectorized<T>& b) {
700 const vmaskType mask_1st = GetMask1<sizeof(T)>(mask);
701 return {(vtype)vec_sel(a._vec0, b._vec0, mask_1st), a._vec1};
702 }
703
704 template <int64_t mask>
705 static std::enable_if_t<blendChoice<sizeof(T)>(mask) == 5, Vectorized<T>>
706 C10_ALWAYS_INLINE blend(const Vectorized<T>& a, const Vectorized<T>& b) {
707 const vmaskType mask_1st = GetMask1<sizeof(T)>(mask);
708 return {(vtype)vec_sel(a._vec0, b._vec0, mask_1st), b._vec1};
709 }
710
711 template <int64_t mask>
712 static std::enable_if_t<blendChoice<sizeof(T)>(mask) == 6, Vectorized<T>>
713 C10_ALWAYS_INLINE blend(const Vectorized<T>& a, const Vectorized<T>& b) {
714 const vmaskType mask_2nd = GetMask2<sizeof(T)>(mask);
715 // generated masks
716 return {a._vec0, (vtype)vec_sel(a._vec1, b._vec1, mask_2nd)};
717 }
718
719 template <int64_t mask>
720 static std::enable_if_t<blendChoice<sizeof(T)>(mask) == 7, Vectorized<T>>
721 C10_ALWAYS_INLINE blend(const Vectorized<T>& a, const Vectorized<T>& b) {
722 const vmaskType mask_2nd = GetMask2<sizeof(T)>(mask);
723 // generated masks
724 return {b._vec0, (vtype)vec_sel(a._vec1, b._vec1, mask_2nd)};
725 }
726
727 template <int64_t mask>
728 static std::enable_if_t<blendChoice<sizeof(T)>(mask) == 8, Vectorized<T>>
729 C10_ALWAYS_INLINE blend(const Vectorized<T>& a, const Vectorized<T>& b) {
730 const vmaskType mask_1st = GetMask1<sizeof(T)>(mask);
731 const vmaskType mask_2nd = GetMask2<sizeof(T)>(mask);
732 return {
733 (vtype)vec_sel(a._vec0, b._vec0, mask_1st),
734 (vtype)vec_sel(a._vec1, b._vec1, mask_2nd)};
735 }
736
737 template <int16_t Z, int16_t C>
738 static inline std::enable_if_t<(Z >= C), Vectorized<T>> set_inner(
739 const Vectorized<T>& a,
740 const Vectorized<T>& b,
741 size_t count) {
742 return b;
743 }
744
745 template <int16_t Z, int16_t C>
746 static inline std::enable_if_t<(Z < C), Vectorized<T>> set_inner(
747 const Vectorized<T>& a,
748 const Vectorized<T>& b,
749 size_t count) {
750 if (count == Z)
751 return blend<allbitset(Z)>(a, b);
752 else
753 return set_inner<Z + 1, C>(a, b, count);
754 }
755
756 static Vectorized<T> set(
757 const Vectorized<T>& a,
758 const Vectorized<T>& b,
759 size_t count = size()) {
760 if (count == 0)
761 return a;
762 return set_inner<1, size()>(a, b, count);
763 }
764
765 const ElementType& operator[](int idx) const = delete;
766 ElementType& operator[](int idx) = delete;
767
768 Vectorized<T> _not() const {
769 return {(vtype)vec_nor(vecb0(), vecb0()), (vtype)vec_nor(vecb1(), vecb1())};
770 }
771
772 Vectorized<T> C10_ALWAYS_INLINE eq(const Vectorized<T>& other) const {
773 return (*this == other) & Vectorized<T>((T)1.0);
774 }
775 Vectorized<T> C10_ALWAYS_INLINE ne(const Vectorized<T>& other) const {
776 return (*this != other) & Vectorized<T>((T)1.0);
777 }
778 Vectorized<T> C10_ALWAYS_INLINE gt(const Vectorized<T>& other) const {
779 return (*this > other) & Vectorized<T>((T)1.0);
780 }
781 Vectorized<T> C10_ALWAYS_INLINE ge(const Vectorized<T>& other) const {
782 return (*this >= other) & Vectorized<T>((T)1.0);
783 }
784 Vectorized<T> C10_ALWAYS_INLINE lt(const Vectorized<T>& other) const {
785 return (*this < other) & Vectorized<T>((T)1.0);
786 }
787 Vectorized<T> C10_ALWAYS_INLINE le(const Vectorized<T>& other) const {
788 return (*this <= other) & Vectorized<T>((T)1.0);
789 }
790
791 template <
792 typename U = T,
793 std::enable_if_t<!std::is_unsigned<U>::value, int> = 0>
794 Vectorized<U> C10_ALWAYS_INLINE abs() const {
795 return {vec_abs(_vec0), vec_abs(_vec1)};
796 }
797
798 template <
799 typename U = T,
800 std::enable_if_t<std::is_unsigned<U>::value, int> = 0>
801 Vectorized<U> C10_ALWAYS_INLINE abs() const {
802 return {_vec0, _vec1};
803 }
804
805 Vectorized<T> C10_ALWAYS_INLINE neg() const {
806 return {-_vec0, -_vec1};
807 }
808
809 Vectorized<T> isnan() const {
810 auto x = *this;
811 auto ret = (x == x);
812 return ret._not();
813 }
814
815 bool has_inf_nan() const {
816 for (const auto i : c10::irange(size()/2)) {
817 if(_isnan(_vec0[i]) || _isinf(_vec0[i])) {
818 return true;
819 }
820 }
821 for (const auto i : c10::irange(size()/2)) {
822 if(_isnan(_vec1[i]) || _isinf(_vec1[i])) {
823 return true;
824 }
825 }
826 return false;
827 }
828
829 template <
830 typename U = T,
831 std::enable_if_t<std::is_floating_point<U>::value, int> = 0>
832 Vectorized<U> angle() const {
833 auto tmp = blendv(
834 Vectorized<U>(0), Vectorized<U>(c10::pi<U>), *this < Vectorized<U>(0));
835 return blendv(tmp, *this, isnan());
836 }
837
838 template <
839 typename U = T,
840 std::enable_if_t<!std::is_floating_point<U>::value, int> = 0>
841 Vectorized<U> angle() const {
842 return blendv(
843 Vectorized<U>(0), Vectorized<U>(c10::pi<U>), *this < Vectorized<U>(0));
844 }
845
846 Vectorized<T> real() const {
847 return *this;
848 }
849 Vectorized<T> imag() const {
850 return Vectorized<T>{0};
851 }
852 Vectorized<T> conj() const {
853 return *this;
854 }
855
856 template <
857 typename U = T,
858 std::enable_if_t<std::is_floating_point<U>::value, int> = 0>
859 int zero_mask() const {
860 auto cmp = (*this == Vectorized<U>(0));
861 constexpr auto mask_zero_bits = GetBpermZeroMask<U>();
862 ZSimdVectBinary<uint64_t> result0 =
863 vec_bperm_u128((ZSimdVectBinary<uint8_t>)cmp.vecb0(), mask_zero_bits);
864 ZSimdVectBinary<uint64_t> result1 =
865 vec_bperm_u128((ZSimdVectBinary<uint8_t>)cmp.vecb1(), mask_zero_bits);
866 return (result0[0] | (result1[0] << (size() / 2)));
867 }
868
869 Vectorized<T> C10_ALWAYS_INLINE floor() const {
870 return {vec_floor(_vec0), vec_floor(_vec1)};
871 }
872
873 Vectorized<T> C10_ALWAYS_INLINE ceil() const {
874 return {vec_ceil(_vec0), vec_ceil(_vec1)};
875 }
876
877 Vectorized<T> C10_ALWAYS_INLINE round() const {
878 return {vec_round(_vec0), vec_round(_vec1)};
879 }
880
881 Vectorized<T> C10_ALWAYS_INLINE rint() const {
882 return {vec_rint(_vec0), vec_rint(_vec1)};
883 }
884
885 Vectorized<T> C10_ALWAYS_INLINE trunc() const {
886 return {vec_trunc(_vec0), vec_trunc(_vec1)};
887 }
888
889 Vectorized<T> C10_ALWAYS_INLINE frac() const {
890 return *this - trunc();
891 }
892
893 Vectorized<T> C10_ALWAYS_INLINE sqrt() const {
894 return {vec_sqrt(_vec0), vec_sqrt(_vec1)};
895 }
896 Vectorized<T> C10_ALWAYS_INLINE reciprocal() const {
897 return Vectorized<T>((T)1) / (*this);
898 }
899 Vectorized<T> C10_ALWAYS_INLINE rsqrt() const {
900 return sqrt().reciprocal();
901 }
902
903 template <
904 typename U = T,
905 std::enable_if_t<std::is_same<U, float>::value, int> = 0>
906 inline Vectorized<T> mapOrdinary(float (*const f)(float)) const {
907 float a00 = f(_vec0[0]);
908 float a01 = f(_vec0[1]);
909 float a02 = f(_vec0[2]);
910 float a03 = f(_vec0[3]);
911 float a10 = f(_vec1[0]);
912 float a11 = f(_vec1[1]);
913 float a12 = f(_vec1[2]);
914 float a13 = f(_vec1[3]);
915 return Vectorized<T>{a00, a01, a02, a03, a10, a11, a12, a13};
916 }
917
918 template <
919 typename U = T,
920 std::enable_if_t<std::is_same<U, double>::value, int> = 0>
921 inline Vectorized<T> mapOrdinary(double (*const f)(double)) const {
922 return Vectorized<T>(f(_vec0[0]), f(_vec0[1]), f(_vec1[0]), f(_vec1[1]));
923 }
924
925 template <
926 typename U = T,
927 std::enable_if_t<std::is_same<U, float>::value, int> = 0>
928 inline Vectorized<T> mapOrdinary(
929 float (*const f)(float, float),
930 const Vectorized<T>& b) const {
931 float a00 = f(_vec0[0], b._vec0[0]);
932 float a01 = f(_vec0[1], b._vec0[1]);
933 float a02 = f(_vec0[2], b._vec0[2]);
934 float a03 = f(_vec0[3], b._vec0[3]);
935 float a10 = f(_vec1[0], b._vec1[0]);
936 float a11 = f(_vec1[1], b._vec1[1]);
937 float a12 = f(_vec1[2], b._vec1[2]);
938 float a13 = f(_vec1[3], b._vec1[3]);
939 return Vectorized<T>{a00, a01, a02, a03, a10, a11, a12, a13};
940 }
941
942 template <
943 typename U = T,
944 std::enable_if_t<std::is_same<U, double>::value, int> = 0>
945 inline Vectorized<T> mapOrdinary(
946 double (*const f)(double, double),
947 const Vectorized<T>& b) const {
948 return Vectorized<T>(
949 f(_vec0[0], b._vec0[0]),
950 f(_vec0[1], b._vec0[1]),
951 f(_vec1[0], b._vec1[0]),
952 f(_vec1[1], b._vec1[1]));
953 }
954
955 template <
956 typename FloatOp,
957 typename DoubleOp,
958 typename U = T,
959 std::enable_if_t<std::is_same<U, float>::value, int> = 0>
960 inline Vectorized<T> mapSleef(FloatOp f, DoubleOp d) const {
961 vtype a0 = f(_vec0);
962 vtype a1 = f(_vec1);
963 return Vectorized<T>{a0, a1};
964 }
965
966 template <
967 typename FloatOp,
968 typename DoubleOp,
969 typename U = T,
970 std::enable_if_t<std::is_same<U, double>::value, int> = 0>
971 inline Vectorized<T> mapSleef(FloatOp f, DoubleOp d) const {
972 return Vectorized<T>(d(_vec0), d(_vec1));
973 }
974
975 template <
976 typename FloatOp,
977 typename DoubleOp,
978 typename U = T,
979 std::enable_if_t<std::is_same<U, float>::value, int> = 0>
980 inline Vectorized<T> mapSleef(FloatOp f, DoubleOp d, const Vectorized<T>& b)
981 const {
982 vtype a0 = f(_vec0, b._vec0);
983 vtype a1 = f(_vec1, b._vec1);
984 return Vectorized<T>{a0, a1};
985 }
986
987 template <
988 typename FloatOp,
989 typename DoubleOp,
990 typename U = T,
991 std::enable_if_t<std::is_same<U, double>::value, int> = 0>
992 inline Vectorized<T> mapSleef(FloatOp f, DoubleOp d, const Vectorized<T>& b)
993 const {
994 return Vectorized<T>(d(_vec0, b._vec0), d(_vec1, b._vec1));
995 }
996
997 Vectorized<T> acos() const {
998 return mapSleef(Sleef_acosf4_u10, Sleef_acosd2_u10);
999 }
1000 Vectorized<T> asin() const {
1001 return mapSleef(Sleef_asinf4_u10, Sleef_asind2_u10);
1002 }
1003 Vectorized<T> atan() const {
1004 return mapSleef(Sleef_atanf4_u10, Sleef_atand2_u10);
1005 }
1006 Vectorized<T> atanh() const {
1007 return mapSleef(Sleef_atanhf4_u10, Sleef_atanhd2_u10);
1008 }
1009
1010 Vectorized<T> erf() const {
1011 return mapSleef(Sleef_erff4_u10, Sleef_erfd2_u10);
1012 }
1013 Vectorized<T> erfc() const {
1014 return mapSleef(Sleef_erfcf4_u15, Sleef_erfcd2_u15);
1015 }
1016
1017 Vectorized<T> exp() const {
1018 return mapSleef(Sleef_expf4_u10, Sleef_expd2_u10);
1019 }
1020 Vectorized<T> exp2() const {
1021 return mapSleef(Sleef_exp2f4_u10, Sleef_exp2d2_u10);
1022 }
1023 Vectorized<T> expm1() const {
1024 return mapSleef(Sleef_expm1f4_u10, Sleef_expm1d2_u10);
1025 }
1026 Vectorized<T> exp_u20() const {
1027 return exp();
1028 }
1029
1030 Vectorized<T> log() const {
1031 return mapSleef(Sleef_logf4_u10, Sleef_logd2_u10);
1032 }
1033 Vectorized<T> log2() const {
1034 return mapSleef(Sleef_log2f4_u10, Sleef_log2d2_u10);
1035 }
1036 Vectorized<T> log10() const {
1037 return mapSleef(Sleef_log10f4_u10, Sleef_log10d2_u10);
1038 }
1039 Vectorized<T> log1p() const {
1040 return mapSleef(Sleef_log1pf4_u10, Sleef_log1pd2_u10);
1041 }
1042
1043 Vectorized<T> sin() const {
1044 return mapSleef(Sleef_sinf4_u10, Sleef_sind2_u10);
1045 }
1046 Vectorized<T> sinh() const {
1047 return mapSleef(Sleef_sinhf4_u10, Sleef_sinhd2_u10);
1048 }
1049 Vectorized<T> cos() const {
1050 return mapSleef(Sleef_cosf4_u10, Sleef_cosd2_u10);
1051 }
1052 Vectorized<T> cosh() const {
1053 return mapSleef(Sleef_coshf4_u10, Sleef_coshd2_u10);
1054 }
1055
1056 Vectorized<T> tan() const {
1057 return mapSleef(Sleef_tanf4_u10, Sleef_tand2_u10);
1058 }
1059 Vectorized<T> tanh() const {
1060 return mapSleef(Sleef_tanhf4_u10, Sleef_tanhd2_u10);
1061 }
1062
1063 Vectorized<T> lgamma() const {
1064 return mapSleef(Sleef_lgammaf4_u10, Sleef_lgammad2_u10);
1065 }
1066
1067 Vectorized<T> atan2(const Vectorized<T>& b) const {
1068 return mapSleef(Sleef_atan2f4_u10, Sleef_atan2d2_u10, b);
1069 }
1070 Vectorized<T> copysign(const Vectorized<T>& sign) const {
1071 return mapSleef(Sleef_copysignf4, Sleef_copysignd2, sign);
1072 }
1073 Vectorized<T> fmod(const Vectorized<T>& q) const {
1074 return mapSleef(Sleef_fmodf4, Sleef_fmodd2, q);
1075 }
1076
1077 Vectorized<T> hypot(const Vectorized<T>& b) const {
1078 return mapSleef(Sleef_hypotf4_u05, Sleef_hypotd2_u05, b);
1079 }
1080
1081 Vectorized<T> pow(const Vectorized<T>& b) const {
1082 return mapSleef(Sleef_powf4_u10, Sleef_powd2_u10, b);
1083 }
1084
1085 Vectorized<T> nextafter(const Vectorized<T>& b) const {
1086 return mapSleef(Sleef_nextafterf4, Sleef_nextafterd2, b);
1087 }
1088
1089 Vectorized<T> erfinv() const {
1090 return mapOrdinary(calc_erfinv);
1091 }
1092
1093 Vectorized<T> digamma() const {
1094 return mapOrdinary(calc_digamma);
1095 }
1096
1097 Vectorized<T> igamma(const Vectorized<T>& x) const {
1098 return mapOrdinary(calc_igamma, x);
1099 }
1100
1101 Vectorized<T> igammac(const Vectorized<T>& x) const {
1102 return mapOrdinary(calc_igammac, x);
1103 }
1104
1105 Vectorized<T> i0() const {
1106 return mapOrdinary(calc_i0);
1107 }
1108
1109 Vectorized<T> i0e() const {
1110 return mapOrdinary(calc_i0e);
1111 }
1112
1113 template <
1114 typename U = T,
1115 std::enable_if_t<!std::is_floating_point<U>::value, int> = 0>
1116 Vectorized<T> minimum(const Vectorized<T>& other) const {
1117 return {vec_min(_vec0, other._vec0), vec_min(_vec1, other._vec1)};
1118 }
1119
1120 /* Propagates NaN if either input is a NaN. */
1121 template <
1122 typename U = T,
1123 std::enable_if_t<std::is_floating_point<U>::value, int> = 0>
1124 Vectorized<T> minimum(const Vectorized<T>& other) const {
1125 Vectorized<T> tmp = {vec_min(_vec0, other._vec0), vec_min(_vec1, other._vec1)};
1126 tmp = blendv(tmp, *this, isnan());
1127 return blendv(tmp, other, other.isnan());
1128 }
1129
1130 template <
1131 typename U = T,
1132 std::enable_if_t<!std::is_floating_point<U>::value, int> = 0>
1133 Vectorized<T> maximum(const Vectorized<T>& other) const {
1134 return {vec_max(_vec0, other._vec0), vec_max(_vec1, other._vec1)};
1135 }
1136
1137 /* Propagates NaN if either input is a NaN. */
1138 template <
1139 typename U = T,
1140 std::enable_if_t<std::is_floating_point<U>::value, int> = 0>
1141 Vectorized<T> maximum(const Vectorized<T>& other) const {
1142 Vectorized<T> tmp = {vec_max(_vec0, other._vec0), vec_max(_vec1, other._vec1)};
1143 tmp = blendv(tmp, *this, isnan());
1144 return blendv(tmp, other, other.isnan());
1145 }
1146
1147 template <
1148 typename U = T,
1149 std::enable_if_t<!std::is_floating_point<U>::value, int> = 0>
1150 Vectorized<T> clamp_min(const Vectorized<T>& min) const {
1151 return {vec_max(_vec0, min._vec0), vec_max(_vec1, min._vec1)};
1152 }
1153
1154 /* Keeps NaN if actual value is NaN */
1155 template <
1156 typename U = T,
1157 std::enable_if_t<std::is_floating_point<U>::value, int> = 0>
1158 Vectorized<T> clamp_min(const Vectorized<T>& min) const {
1159 Vectorized<T> tmp = {vec_max(_vec0, min._vec0), vec_max(_vec1, min._vec1)};
1160 return blendv(tmp, *this, isnan());
1161 }
1162
1163 template <
1164 typename U = T,
1165 std::enable_if_t<!std::is_floating_point<U>::value, int> = 0>
1166 Vectorized<T> clamp_max(const Vectorized<T>& max) const {
1167 return {vec_min(_vec0, max._vec0), vec_min(_vec1, max._vec1)};
1168 }
1169
1170 /* Keeps NaN if actual value is NaN */
1171 template <
1172 typename U = T,
1173 std::enable_if_t<std::is_floating_point<U>::value, int> = 0>
1174 Vectorized<T> clamp_max(const Vectorized<T>& max) const {
1175 Vectorized<T> tmp = {vec_min(_vec0, max._vec0), vec_min(_vec1, max._vec1)};
1176 return blendv(tmp, *this, isnan());
1177 }
1178
1179 template <
1180 typename U = T,
1181 std::enable_if_t<std::is_same<U, float>::value, int> = 0>
1182 Vectorized<T> swapped() const {
1183 auto swap_mask = GetSwapMaskFloat();
1184 vtype v0 = vec_perm(_vec0, _vec0, swap_mask);
1185 vtype v1 = vec_perm(_vec1, _vec1, swap_mask);
1186 return {v0, v1};
1187 }
1188
1189 template <
1190 typename U = T,
1191 std::enable_if_t<std::is_same<U, double>::value, int> = 0>
1192 Vectorized<T> swapped() const {
1193 vtype v0 = vec_permi(_vec0, _vec0, 2);
1194 vtype v1 = vec_permi(_vec1, _vec1, 2);
1195 return {v0, v1};
1196 }
1197
1198 template <
1199 typename U = T,
1200 std::enable_if_t<std::is_floating_point<U>::value, int> = 0>
1201 static Vectorized<T> mergee(Vectorized<T>& first, Vectorized<T>& second) {
1202 return {
1203 vec_mergee(first._vec0, second._vec0),
1204 vec_mergee(first._vec1, second._vec1)};
1205 }
1206
1207 template <
1208 typename U = T,
1209 std::enable_if_t<std::is_floating_point<U>::value, int> = 0>
1210 static Vectorized<T> mergeo(Vectorized<T>& first, Vectorized<T>& second) {
1211 return {
1212 vec_mergeo(first._vec0, second._vec0),
1213 vec_mergeo(first._vec1, second._vec1)};
1214 }
1215
1216 static Vectorized<T> horizontal_add_perm(
1217 Vectorized<T>& first,
1218 Vectorized<T>& second) {
1219 // we will simulate it differently with 6 instructions total
1220 // lets permute second so that we can add it getting horizontal sums
1221 auto first_perm = first.swapped(); // 2perm
1222 auto second_perm = second.swapped(); // 2perm
1223 // summ
1224 auto first_ret = first + first_perm; // 2add
1225 auto second_ret = second + second_perm; // 2 add
1226 // now lets choose evens
1227 return mergee(first_ret, second_ret); // 2 mergee's
1228 }
1229
1230 static Vectorized<T> horizontal_sub_perm(
1231 Vectorized<T>& first,
1232 Vectorized<T>& second) {
1233 // we will simulate it differently with 6 instructions total
1234 // lets permute second so that we can add it getting horizontal sums
1235 auto first_perm = first.swapped(); // 2perm
1236 auto second_perm = second.swapped(); // 2perm
1237 // summ
1238 auto first_ret = first - first_perm; // 2sub
1239 auto second_ret = second - second_perm; // 2 sub
1240 // now lets choose evens
1241 return mergee(first_ret, second_ret); // 2 mergee's
1242 }
1243
1244 template <
1245 typename U = T,
1246 std::enable_if_t<std::is_floating_point<U>::value, int> = 0>
1247 Vectorized<T> mergee() const {
1248 return {vec_mergee(_vec0, _vec0), vec_mergee(_vec1, _vec1)};
1249 }
1250
1251 template <
1252 typename U = T,
1253 std::enable_if_t<std::is_floating_point<U>::value, int> = 0>
1254 Vectorized<T> mergeo() const {
1255 return {vec_mergeo(_vec0, _vec0), vec_mergeo(_vec1, _vec1)};
1256 }
1257
1258 template <
1259 typename U = T,
1260 std::enable_if_t<std::is_same<U, uint8_t>::value, int> = 0>
1261 Vectorized<int32_t> to_vec_float_helper() const {
1262 int32_t values[8] = {
1263 _vec0[0],
1264 _vec0[1],
1265 _vec0[2],
1266 _vec0[3],
1267 _vec0[4],
1268 _vec0[5],
1269 _vec0[6],
1270 _vec0[7],
1271 };
1272
1273 return Vectorized<int32_t>{
1274 values[0], values[1], values[2], values[3],
1275 values[4], values[5], values[6], values[7]
1276 };
1277 }
1278
1279 template <
1280 typename U = T,
1281 std::enable_if_t<std::is_same<U, int32_t>::value, int> = 0>
1282 Vectorized<uint8_t> to_vec_uint8_helper() const {
1283 // helper function for float to uint8_t conversion
1284 uint8_t values[8] = {
1285 static_cast<uint8_t>(_vec0[0]),
1286 static_cast<uint8_t>(_vec0[1]),
1287 static_cast<uint8_t>(_vec0[2]),
1288 static_cast<uint8_t>(_vec0[3]),
1289 static_cast<uint8_t>(_vec1[0]),
1290 static_cast<uint8_t>(_vec1[1]),
1291 static_cast<uint8_t>(_vec1[2]),
1292 static_cast<uint8_t>(_vec1[3]),
1293 };
1294
1295 return Vectorized<uint8_t>{
1296 values[0], values[1], values[2], values[3],
1297 values[4], values[5], values[6], values[7],
1298 0, 0, 0, 0,
1299 0, 0, 0, 0,
1300 0, 0, 0, 0,
1301 0, 0, 0, 0,
1302 0, 0, 0, 0,
1303 0, 0, 0, 0,
1304 };
1305 }
1306 };
1307
// Element-wise operators for the plain (non-quantized) Vectorized types.
// Every operator processes the two 16-byte halves (vec0/vec1) independently.
// - + - * / : GCC vector-extension arithmetic on each half.
// - & | ^   : done on the binary view vecb0()/vecb1() and cast back to vtype,
//             presumably because GCC's vector extension does not allow
//             bitwise operators on floating-point vectors.
// - == > >= < <= : vec_cmp* yield per-lane masks that are wrapped back into
//             Vectorized<typex> (assumes a constructor accepting the mask
//             type, defined outside this view).
// - !=      : computed as ._not() of the == mask, so it is set in lanes where
//             either operand is NaN.
// The arithmetic/bitwise operators are explicit specializations of templates
// declared elsewhere; the comparison operators are plain overloads.
#define ZVECTOR_OPERATORS(typex) \
  template <> \
  Vectorized<typex> C10_ALWAYS_INLINE operator+(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    return Vectorized<typex>{a.vec0() + b.vec0(), a.vec1() + b.vec1()}; \
  } \
  \
  template <> \
  Vectorized<typex> C10_ALWAYS_INLINE operator-(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    return Vectorized<typex>{a.vec0() - b.vec0(), a.vec1() - b.vec1()}; \
  } \
  \
  template <> \
  Vectorized<typex> C10_ALWAYS_INLINE operator*(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    return Vectorized<typex>{a.vec0() * b.vec0(), a.vec1() * b.vec1()}; \
  } \
  \
  template <> \
  Vectorized<typex> C10_ALWAYS_INLINE operator/(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    return Vectorized<typex>{a.vec0() / b.vec0(), a.vec1() / b.vec1()}; \
  } \
  \
  template <> \
  Vectorized<typex> C10_ALWAYS_INLINE operator&(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    return Vectorized<typex>{ \
        (Vectorized<typex>::vtype)(a.vecb0() & b.vecb0()), \
        (Vectorized<typex>::vtype)(a.vecb1() & b.vecb1())}; \
  } \
  \
  template <> \
  Vectorized<typex> C10_ALWAYS_INLINE operator|(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    return Vectorized<typex>{ \
        (Vectorized<typex>::vtype)(a.vecb0() | b.vecb0()), \
        (Vectorized<typex>::vtype)(a.vecb1() | b.vecb1())}; \
  } \
  \
  template <> \
  Vectorized<typex> C10_ALWAYS_INLINE operator^(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    return Vectorized<typex>{ \
        (Vectorized<typex>::vtype)(a.vecb0() ^ b.vecb0()), \
        (Vectorized<typex>::vtype)(a.vecb1() ^ b.vecb1())}; \
  } \
  \
  Vectorized<typex> C10_ALWAYS_INLINE operator==(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    return Vectorized<typex>{ \
        vec_cmpeq(a.vec0(), b.vec0()), vec_cmpeq(a.vec1(), b.vec1())}; \
  } \
  \
  Vectorized<typex> C10_ALWAYS_INLINE operator!=(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    return Vectorized<typex>{ \
        vec_cmpeq(a.vec0(), b.vec0()), vec_cmpeq(a.vec1(), b.vec1())} \
        ._not(); \
  } \
  \
  Vectorized<typex> C10_ALWAYS_INLINE operator>(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    return Vectorized<typex>{ \
        vec_cmpgt(a.vec0(), b.vec0()), vec_cmpgt(a.vec1(), b.vec1())}; \
  } \
  \
  Vectorized<typex> C10_ALWAYS_INLINE operator>=(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    return Vectorized<typex>{ \
        vec_cmpge(a.vec0(), b.vec0()), vec_cmpge(a.vec1(), b.vec1())}; \
  } \
  \
  Vectorized<typex> C10_ALWAYS_INLINE operator<(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    return Vectorized<typex>{ \
        vec_cmplt(a.vec0(), b.vec0()), vec_cmplt(a.vec1(), b.vec1())}; \
  } \
  \
  Vectorized<typex> C10_ALWAYS_INLINE operator<=(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    return Vectorized<typex>{ \
        vec_cmple(a.vec0(), b.vec0()), vec_cmple(a.vec1(), b.vec1())}; \
  }

// Instantiate for every non-quantized lane type of this backend.
ZVECTOR_OPERATORS(float)
ZVECTOR_OPERATORS(double)
ZVECTOR_OPERATORS(int8_t)
ZVECTOR_OPERATORS(uint8_t)
ZVECTOR_OPERATORS(uint16_t)
ZVECTOR_OPERATORS(int16_t)
ZVECTOR_OPERATORS(int32_t)
ZVECTOR_OPERATORS(int64_t)

#undef ZVECTOR_OPERATORS
1391
// Shift and bitwise-NOT operators for the integer Vectorized types.
// Shifts have no vectorized path here. Both operands are stored to stack
// arrays, each lane is shifted with scalar code, and the result is reloaded
// with loadu.
// operator<< : a lane whose shift amount is negative (read as the signed type)
//   or >= the lane bit width yields 0. Otherwise the value is shifted as its
//   unsigned counterpart, which avoids the UB of left-shifting a negative
//   signed value.
// operator>> : out-of-range shift amounts are clamped to max_shift. That is
//   width-1 for signed types (the result is all sign bits) and width for
//   unsigned types (the result is 0). The unsigned types instantiated below are
//   uint8_t/uint16_t, which promote to int before shifting, so a shift by the
//   full lane width stays well-defined.
// operator~  : delegates to _not().
#define ZVECTOR_OPERATORS(typex) \
  template <> \
  Vectorized<typex> C10_ALWAYS_INLINE operator<<(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    constexpr Vectorized<typex>::ElementType max_shift \
      = sizeof(Vectorized<typex>::ElementType) * CHAR_BIT; \
    \
    Vectorized<typex>::ElementType a_array[Vectorized<typex>::size()]; \
    Vectorized<typex>::ElementType b_array[Vectorized<typex>::size()]; \
    Vectorized<typex>::ElementType c_array[Vectorized<typex>::size()]; \
    \
    a.store(a_array); \
    b.store(b_array); \
    \
    for (int i = 0; i != Vectorized<typex>::size(); i++) { \
      typex shift = b_array[i]; \
      if ((static_cast<std::make_signed_t<typex>>(shift) < 0) || (shift >= max_shift)) { \
        c_array[i] = 0; \
      } else { \
        c_array[i] = static_cast<std::make_unsigned_t<typex>>(a_array[i]) << shift; \
      } \
    } \
    \
    return Vectorized<typex>::loadu(c_array); \
  } \
  \
  template <> \
  Vectorized<typex> C10_ALWAYS_INLINE operator>>(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    /* clamp target for out-of-range shifts: width-1 keeps only sign bits for signed types, width yields 0 for unsigned types */ \
    constexpr Vectorized<typex>::ElementType max_shift \
      = sizeof(typex) * CHAR_BIT - std::is_signed_v<typex>; \
    \
    Vectorized<typex>::ElementType a_array[Vectorized<typex>::size()]; \
    Vectorized<typex>::ElementType b_array[Vectorized<typex>::size()]; \
    Vectorized<typex>::ElementType c_array[Vectorized<typex>::size()]; \
    \
    a.store(a_array); \
    b.store(b_array); \
    \
    for (int i = 0; i != Vectorized<typex>::size(); i++) { \
      typex shift = b_array[i]; \
      if ((static_cast<std::make_signed_t<typex>>(shift) < 0) || (shift >= max_shift)) { \
        c_array[i] = a_array[i] >> max_shift; \
      } else { \
        c_array[i] = a_array[i] >> shift; \
      } \
    } \
    \
    return Vectorized<typex>::loadu(c_array); \
  } \
  \
  template <> \
  inline Vectorized<typex> operator~(const Vectorized<typex>& a) { \
    return a._not(); \
  }

// Instantiate for the integer lane types only (shifts and ~ are not defined
// for float/double).
ZVECTOR_OPERATORS(int8_t)
ZVECTOR_OPERATORS(uint8_t)
ZVECTOR_OPERATORS(uint16_t)
ZVECTOR_OPERATORS(int16_t)
ZVECTOR_OPERATORS(int32_t)
ZVECTOR_OPERATORS(int64_t)

#undef ZVECTOR_OPERATORS
1455
// Free-function maximum/minimum specializations that forward to the
// Vectorized member functions of the same name. For floating types, the
// members in the class above propagate NaN.
#define DEFINE_MAXMIN_FUNCS(operand_type) \
  template <> \
  Vectorized<operand_type> inline maximum( \
      const Vectorized<operand_type>& a, const Vectorized<operand_type>& b) { \
    return a.maximum(b); \
  } \
  template <> \
  Vectorized<operand_type> inline minimum( \
      const Vectorized<operand_type>& a, const Vectorized<operand_type>& b) { \
    return a.minimum(b); \
  }

// Adds clamp_min / clamp_max / clamp on top of DEFINE_MAXMIN_FUNCS.
// clamp(a, min, max) is clamp_max(clamp_min(a, min), max), so in lanes where
// min > max the result is max. These macros are intentionally not #undef'd:
// they are reused later in this file for the quantized types.
#define DEFINE_CLAMP_MAXMIN_FUNCS(typex) \
  DEFINE_MAXMIN_FUNCS(typex) \
  template <> \
  Vectorized<typex> C10_ALWAYS_INLINE clamp_min( \
      const Vectorized<typex>& a, const Vectorized<typex>& min) { \
    return a.clamp_min(min); \
  } \
  template <> \
  Vectorized<typex> C10_ALWAYS_INLINE clamp_max( \
      const Vectorized<typex>& a, const Vectorized<typex>& max) { \
    return a.clamp_max(max); \
  } \
  template <> \
  Vectorized<typex> C10_ALWAYS_INLINE clamp( \
      const Vectorized<typex>& a, \
      const Vectorized<typex>& min, \
      const Vectorized<typex>& max) { \
    return clamp_max(clamp_min(a, min), max); \
  }

// NOTE(review): uint16_t is not listed here even though it gets the operator
// macros above. It therefore uses whatever generic maximum/minimum/clamp
// templates are declared outside this view. Confirm this is intended.
DEFINE_CLAMP_MAXMIN_FUNCS(int8_t)
DEFINE_CLAMP_MAXMIN_FUNCS(uint8_t)
DEFINE_CLAMP_MAXMIN_FUNCS(int16_t)
DEFINE_CLAMP_MAXMIN_FUNCS(int32_t)
DEFINE_CLAMP_MAXMIN_FUNCS(int64_t)
DEFINE_CLAMP_MAXMIN_FUNCS(float)
DEFINE_CLAMP_MAXMIN_FUNCS(double)
1495
1496 namespace { /* unnamed namespace */
1497
1498 #if !defined(vec_float) || __ARCH__ < 13
1499 #warning \
1500 "float->int and int->float conversion is simulated. compile for z15 for improved performance"
1501 inline ZSimdVect<float> vec_int_flt(const ZSimdVect<int> x) {
1502 return ZSimdVect<float>{float(x[0]), float(x[1]), float(x[2]), float(x[3])};
1503 }
1504 inline ZSimdVect<int> vec_flt_int(const ZSimdVect<float> x) {
1505 return ZSimdVect<int>{int(x[0]), int(x[1]), int(x[2]), int(x[3])};
1506 }
1507 #else
1508 #define vec_int_flt vec_float
1509 #define vec_flt_int vec_signed
1510 #endif
1511
1512 Vectorized<float> zvec_convert_to_float(const Vectorized<int32_t>& x) {
1513 return {vec_int_flt(x.vec0()), vec_int_flt(x.vec1())};
1514 }
1515
1516 Vectorized<int32_t> zvec_convert_to_int(const Vectorized<float>& x) {
1517 return {vec_flt_int(x.vec0()), vec_flt_int(x.vec1())};
1518 }
1519
1520 Vectorized<double> zvec_convert_to_float(const Vectorized<int64_t>& x) {
1521 return {vec_double(x.vec0()), vec_double(x.vec1())};
1522 }
1523
1524 Vectorized<int64_t> zvec_convert_to_int(const Vectorized<double>& x) {
1525 return {vec_signed(x.vec0()), vec_signed(x.vec1())};
1526 }
1527
1528 } /* unnamed namespace */
1529
1530 template <typename T, typename V>
1531 Vectorized<V> cast_zvector(const Vectorized<T>& x) {
1532 using cast_type = typename Vectorized<V>::vtype;
1533 return Vectorized<V>{(cast_type)x.vec0(), (cast_type)x.vec1()};
1534 }
1535
1536 template <>
1537 Vectorized<float> C10_ALWAYS_INLINE fmadd(
1538 const Vectorized<float>& a,
1539 const Vectorized<float>& b,
1540 const Vectorized<float>& c) {
1541 return Vectorized<float>{
1542 __builtin_s390_vfmasb(a.vec0(), b.vec0(), c.vec0()),
1543 __builtin_s390_vfmasb(a.vec1(), b.vec1(), c.vec1())};
1544 }
1545 template <>
1546 Vectorized<double> C10_ALWAYS_INLINE fmadd(
1547 const Vectorized<double>& a,
1548 const Vectorized<double>& b,
1549 const Vectorized<double>& c) {
1550 return Vectorized<double>{
1551 __builtin_s390_vfmadb(a.vec0(), b.vec0(), c.vec0()),
1552 __builtin_s390_vfmadb(a.vec1(), b.vec1(), c.vec1())};
1553 }
1554 template <>
1555 Vectorized<int16_t> C10_ALWAYS_INLINE fmadd(
1556 const Vectorized<int16_t>& a,
1557 const Vectorized<int16_t>& b,
1558 const Vectorized<int16_t>& c) {
1559 return Vectorized<int16_t>{
1560 a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()};
1561 }
1562 template <>
1563 Vectorized<int32_t> C10_ALWAYS_INLINE fmadd(
1564 const Vectorized<int32_t>& a,
1565 const Vectorized<int32_t>& b,
1566 const Vectorized<int32_t>& c) {
1567 return Vectorized<int32_t>{
1568 a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()};
1569 }
1570 template <>
1571 Vectorized<int64_t> C10_ALWAYS_INLINE fmadd(
1572 const Vectorized<int64_t>& a,
1573 const Vectorized<int64_t>& b,
1574 const Vectorized<int64_t>& c) {
1575 return Vectorized<int64_t>{
1576 a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()};
1577 }
1578
1579 template <>
1580 Vectorized<int64_t> C10_ALWAYS_INLINE
1581 convert_to_int_of_same_size<double>(const Vectorized<double>& src) {
1582 return zvec_convert_to_int(src);
1583 }
1584
1585 template <>
1586 Vectorized<int32_t> C10_ALWAYS_INLINE
1587 convert_to_int_of_same_size<float>(const Vectorized<float>& src) {
1588 return zvec_convert_to_int(src);
1589 }
1590
1591 template <>
1592 inline void convert(const int32_t* src, float* dst, int64_t n) {
1593 // int32_t and float have same size
1594 int64_t i;
1595 for (i = 0; i <= (n - Vectorized<float>::size());
1596 i += Vectorized<float>::size()) {
1597 const int32_t* src_a = src + i;
1598 float* dst_a = dst + i;
1599 auto input_vec = Vectorized<int32_t>::loadu(src_a);
1600 auto output_vec = zvec_convert_to_float(input_vec);
1601 output_vec.store(dst_a);
1602 }
1603
1604 for (; i < n; i++) {
1605 dst[i] = static_cast<float>(src[i]);
1606 }
1607 }
1608
1609 template <>
1610 inline void convert(const int64_t* src, double* dst, int64_t n) {
1611 int64_t i;
1612 for (i = 0; i <= (n - Vectorized<double>::size());
1613 i += Vectorized<double>::size()) {
1614 const int64_t* src_a = src + i;
1615 double* dst_a = dst + i;
1616 auto input_vec = Vectorized<int64_t>::loadu(src_a);
1617 auto output_vec = zvec_convert_to_float(input_vec);
1618 output_vec.store(dst_a);
1619 }
1620 for (; i < n; i++) {
1621 dst[i] = static_cast<double>(src[i]);
1622 }
1623 }
1624
1625 #define DEFINE_REINTERPRET_CAST_FUNCS(Fst, Cst) \
1626 template <> \
1627 C10_ALWAYS_INLINE Vectorized<Cst> cast<Cst, Fst>( \
1628 const Vectorized<Fst>& src) { \
1629 return cast_zvector<Fst, Cst>(src); \
1630 }
1631
1632 #define DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(Fst) \
1633 DEFINE_REINTERPRET_CAST_FUNCS(Fst, double) \
1634 DEFINE_REINTERPRET_CAST_FUNCS(Fst, float) \
1635 DEFINE_REINTERPRET_CAST_FUNCS(Fst, int64_t) \
1636 DEFINE_REINTERPRET_CAST_FUNCS(Fst, int32_t) \
1637 DEFINE_REINTERPRET_CAST_FUNCS(Fst, int16_t)
1638
1639 DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(float)
1640 DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(double)
1641 DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int64_t)
1642 DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int32_t)
1643 DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int16_t)
1644
1645 #undef DEFINE_REINTERPRET_CAST_FUNCS
1646
// unpack_type<T>::type is the lane type produced when unpack() widens T:
// int8_t and uint8_t widen to int16_t, and int16_t widens to int32_t. Any
// other type maps to itself.
template <typename T>
struct unpack_type {
  using type = std::conditional_t<
      std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value,
      int16_t,
      std::conditional_t<std::is_same<T, int16_t>::value, int32_t, T>>;
};

// pack_type<T>::type is the lane type produced when pack() narrows T:
// int16_t narrows to int8_t, and int32_t narrows to int16_t. Any other type
// maps to itself.
template <typename T>
struct pack_type {
  using type = std::conditional_t<
      std::is_same<T, int16_t>::value,
      int8_t,
      std::conditional_t<std::is_same<T, int32_t>::value, int16_t, T>>;
};
1676
1677 namespace { /* unnamed namespace */
1678
1679 template <typename T, typename V = typename unpack_type<T>::type>
1680 std::pair<Vectorized<V>, Vectorized<V>> unpack(const Vectorized<T>& x) {
1681 auto vec0 = vec_unpackh(x.vec0());
1682 auto vec1 = vec_unpackl(x.vec0());
1683 auto vec2 = vec_unpackh(x.vec1());
1684 auto vec3 = vec_unpackl(x.vec1());
1685 return {Vectorized<V>{vec0, vec1}, Vectorized<V>{vec2, vec3}};
1686 }
1687
1688 template <>
1689 std::pair<Vectorized<int16_t>, Vectorized<int16_t>> unpack<uint8_t, int16_t>(
1690 const Vectorized<uint8_t>& x) {
1691 using typeX = typename Vectorized<uint16_t>::vtype;
1692 typeX vec0 = vec_unpackh(x.vec0());
1693 typeX vec1 = vec_unpackl(x.vec0());
1694 typeX vec2 = vec_unpackh(x.vec1());
1695 typeX vec3 = vec_unpackl(x.vec1());
1696 // auto mask = Vectorized<uint16_t>(0xFF);
1697 // vec0 = vec0 & mask;
1698 // vec1 = vec1 & mask;
1699 // vec2 = vec2 & mask;
1700 // vec3 = vec3 & mask;
1701 return {
1702 cast_zvector<uint16_t, int16_t>(Vectorized<uint16_t>{vec0, vec1}),
1703 cast_zvector<uint16_t, int16_t>(Vectorized<uint16_t>{vec2, vec3})};
1704 }
1705
1706 template <typename T, typename V = typename pack_type<T>::type>
1707 Vectorized<V> pack(const Vectorized<T>& first, const Vectorized<T>& second) {
1708 auto vec0 = vec_packs(first.vec0(), first.vec1());
1709 auto vec1 = vec_packs(second.vec0(), second.vec1());
1710 return Vectorized<V>{vec0, vec1};
1711 }
1712
1713 template <>
1714 Vectorized<uint8_t> pack(
1715 const Vectorized<int16_t>& first,
1716 const Vectorized<int16_t>& second) {
1717 auto vec0 = vec_packsu(first.vec0(), first.vec1());
1718 auto vec1 = vec_packsu(second.vec0(), second.vec1());
1719 return Vectorized<uint8_t>{vec0, vec1};
1720 }
1721
1722 } /* unnamed namespace */
1723
1724 //////////////////////////////////QUANT///////////////////////////////////////////
// Vectorized specialization for the quantized types qint32 / qint8 / quint8.
// It wraps one Vectorized<underlying> (`_vec`) holding the raw integer
// representation, and element-wise operations forward to that inner vector.
// float_num_vecs() / int_num_vecs() count how many Vectorized<float> /
// Vectorized<c10::qint32> are needed to cover size() lanes. The member
// templates below are overloaded on that count (== 1 or == 4) via enable_if.
// NOTE(review): presumably the count is 1 for qint32 and 4 for the 8-bit
// types, assuming VECTOR_WIDTH (defined outside this view) is shared with
// Vectorized<float>.
template <typename T>
struct Vectorized<T, std::enable_if_t<is_zarch_implemented_quant<T>()>> {
 public:
  using value_type = typename T::underlying;
  using vtype = ZSimdVect<value_type>;
  using vmaskType = ZSimdVectBinary<value_type>;
  using vinner_type = Vectorized<value_type>;
  using size_type = int;

  // Number of quantized lanes.
  static constexpr size_type size() {
    return VECTOR_WIDTH / sizeof(value_type);
  }

  // Number of Vectorized<float> needed to hold all size() lanes.
  static constexpr size_t float_num_vecs() {
    return size() / Vectorized<float>::size();
  }
  // Same count, used for the qint32 vectors of int_vec_return_type.
  static constexpr int int_num_vecs() {
    return float_num_vecs();
  }
  using float_vec_return_type = std::array<Vectorized<float>, float_num_vecs()>;
  using int_vec_return_type =
      std::array<Vectorized<c10::qint32>, int_num_vecs()>;

 private:
  vinner_type _vec;

 public:
  Vectorized() {}

  // Wraps an existing vector of raw underlying values.
  explicit C10_ALWAYS_INLINE Vectorized(vinner_type v) : _vec{v} {}
  // Builds the inner vector from the raw value of one quantized scalar
  // (presumably a broadcast via Vectorized<underlying>'s scalar constructor).
  Vectorized(const T& val) : _vec(val.val_) {}

  C10_ALWAYS_INLINE const vinner_type& vec() const {
    return _vec;
  }

  // loadu/store forward to the inner vector's loadu/store with the same
  // pointer and count.
  template <typename U>
  static Vectorized<T> C10_ALWAYS_INLINE
  loadu(const U* ptr, int count = size()) {
    return Vectorized<T>{vinner_type::loadu(ptr, count)};
  }

  template <typename U>
  void C10_ALWAYS_INLINE store(U* ptr, int count = size()) const {
    _vec.store(ptr, count);
  }

  // Quantized ReLU: lane-wise max of the raw values with zero_point.
  Vectorized<T> relu(Vectorized<T> zero_point) const {
    return Vectorized<T>{_vec.maximum(zero_point._vec)};
  }

  // Quantized ReLU6: raw values clamped to [zero_point, q_six] lane-wise
  // (max with zero_point first, then min with q_six).
  Vectorized<T> relu6(Vectorized<T> zero_point, Vectorized<T> q_six) const {
    auto ret_max = _vec.maximum(zero_point._vec);
    auto ret_min = ret_max.minimum(q_six._vec);
    return Vectorized<T>{ret_min};
  }

  // ---- Overloads for float_num_vecs() == 1 / int_num_vecs() == 1 ----

  // The lanes are already 32-bit, so the difference needs no widening and is
  // returned as the single element of the array.
  template <
      typename U = T,
      std::enable_if_t<Vectorized<U>::float_num_vecs() == 1, int> = 0>
  int_vec_return_type widening_subtract(Vectorized<T> b) const {
    return {*this - b};
  }

  // Dequantize as scale * x + scale_zp_premul (one fmadd). `zero_point` is
  // unused in this overload; presumably scale_zp_premul is precomputed from
  // it by the caller (e.g. -zero_point * scale). TODO confirm.
  template <
      typename U = T,
      std::enable_if_t<Vectorized<U>::float_num_vecs() == 1, int> = 0>
  float_vec_return_type dequantize(
      Vectorized<float> scale,
      Vectorized<float> zero_point,
      Vectorized<float> scale_zp_premul) const {
    auto float_val = zvec_convert_to_float(_vec);
    return {fmadd(scale, float_val, scale_zp_premul)};
  }

  // Dequantize as (x - zero_point) * scale.
  template <
      typename U = T,
      std::enable_if_t<Vectorized<U>::float_num_vecs() == 1, int> = 0>
  float_vec_return_type dequantize(
      Vectorized<float> scale,
      Vectorized<float> zero_point) const {
    auto float_val = zvec_convert_to_float(_vec);
    return {(float_val - zero_point) * scale};
  }

  // Quantize as int(rint(x * inverse_scale) + zero_point). `scale` is unused.
  // NOTE(review): there is no explicit clamping to the int32 range before the
  // conversion. Out-of-range behavior depends on zvec_convert_to_int; confirm.
  template <
      typename U = T,
      std::enable_if_t<Vectorized<U>::float_num_vecs() == 1, int> = 0>
  static Vectorized<T> quantize(
      const float_vec_return_type& rhs,
      float scale,
      int32_t zero_point,
      float inverse_scale) {
    Vectorized<float> vecf = rhs[0];
    vecf = vecf * Vectorized<float>(inverse_scale);
    vecf = vecf.rint() + Vectorized<float>((float)(zero_point));
    auto veci = zvec_convert_to_int(vecf);

    return Vectorized<T>{veci};
  }

  // Requantize as int(rint(x * multiplier)) + zero_point. The zero point is
  // added after the float -> int conversion.
  template <
      typename U = T,
      std::enable_if_t<Vectorized<U>::int_num_vecs() == 1, int> = 0>
  static Vectorized<T> requantize_from_int(
      const int_vec_return_type& inp,
      float multiplier,
      int32_t zero_point) {
    Vectorized<T> vi = inp[0];
    auto vecf = zvec_convert_to_float(vi.vec());
    vecf = vecf * Vectorized<float>(multiplier);
    vecf = vecf.rint();
    auto veci = zvec_convert_to_int(vecf) + Vectorized<int>(zero_point);

    return Vectorized<T>{veci};
  }

  // ---- Overloads for float_num_vecs() == 4 / int_num_vecs() == 4 ----

  // Widens both operands to 32 bits with two levels of unpack (8 -> 16 ->
  // 32), then subtracts, yielding four qint32 vectors in the order
  // (low-low, low-high, high-low, high-high) of the unpack results.
  template <
      typename U = T,
      std::enable_if_t<Vectorized<U>::int_num_vecs() == 4, int> = 0>
  int_vec_return_type widening_subtract(Vectorized<U> b) const {
    auto ret16 = unpack(_vec);
    auto ret16B = unpack(b.vec());
    auto ret32_0 = unpack(ret16.first);
    auto ret32_1 = unpack(ret16.second);
    auto ret32B_0 = unpack(ret16B.first);
    auto ret32B_1 = unpack(ret16B.second);

    return {
        Vectorized<c10::qint32>(ret32_0.first - ret32B_0.first),
        Vectorized<c10::qint32>(ret32_0.second - ret32B_0.second),
        Vectorized<c10::qint32>(ret32_1.first - ret32B_1.first),
        Vectorized<c10::qint32>(ret32_1.second - ret32B_1.second)};
  }

  // Widens to int32 (two unpack levels), converts to float, and computes
  // scale * x + scale_zp_premul for each of the four float vectors.
  // `zero_point` is unused in this overload.
  template <
      typename U = T,
      std::enable_if_t<Vectorized<U>::float_num_vecs() == 4, int> = 0>
  float_vec_return_type C10_ALWAYS_INLINE dequantize(
      Vectorized<float> scale,
      Vectorized<float> zero_point,
      Vectorized<float> scale_zp_premul) const {
    // unpacking unsigned as signed: for quint8 the uint8_t unpack
    // specialization zero-extends into int16_t, so no value is lost
    auto ret16 = unpack(_vec);
    auto ret32_0 = unpack(ret16.first);
    auto ret32_1 = unpack(ret16.second);

    auto vecf_0 = zvec_convert_to_float(ret32_0.first);
    auto vecf_1 = zvec_convert_to_float(ret32_0.second);

    auto vecf_2 = zvec_convert_to_float(ret32_1.first);
    auto vecf_3 = zvec_convert_to_float(ret32_1.second);
    return {
        fmadd(scale, vecf_0, scale_zp_premul),
        fmadd(scale, vecf_1, scale_zp_premul),
        fmadd(scale, vecf_2, scale_zp_premul),
        fmadd(scale, vecf_3, scale_zp_premul)};
  }

  // Same widening as above, and computes (x - zero_point) * scale for each of
  // the four float vectors.
  template <
      typename U = T,
      std::enable_if_t<Vectorized<U>::float_num_vecs() == 4, int> = 0>
  float_vec_return_type dequantize(
      Vectorized<float> scale,
      Vectorized<float> zero_point) const {
    // unpacking unsigned as signed: for quint8 the uint8_t unpack
    // specialization zero-extends into int16_t, so no value is lost
    auto ret16 = unpack(_vec);
    auto ret32_0 = unpack(ret16.first);
    auto ret32_1 = unpack(ret16.second);

    auto vecf_0 = zvec_convert_to_float(ret32_0.first);
    auto vecf_1 = zvec_convert_to_float(ret32_0.second);

    auto vecf_2 = zvec_convert_to_float(ret32_1.first);
    auto vecf_3 = zvec_convert_to_float(ret32_1.second);

    return {
        (vecf_0 - zero_point) * scale,
        (vecf_1 - zero_point) * scale,
        (vecf_2 - zero_point) * scale,
        (vecf_3 - zero_point) * scale };
  }

  // Quantizes four float vectors: each is rint(x * inverse_scale) +
  // zero_point, converted to int32. The results are then narrowed 32 -> 16 ->
  // 8 bits with the saturating pack helpers (vec_packs, or vec_packsu for an
  // unsigned underlying type), so out-of-range values saturate.
  // `scale` is unused.
  template <
      typename U = T,
      std::enable_if_t<Vectorized<U>::float_num_vecs() == 4, int> = 0>
  static Vectorized<T> quantize(
      const float_vec_return_type& rhs,
      float scale,
      int32_t zero_point,
      float inverse_scale) {
    auto vec_inverse = Vectorized<float>(inverse_scale);
    auto vec_zero_point = Vectorized<float>((float)zero_point);

    auto vecf0 = rhs[0];
    auto vecf2 = rhs[1];
    auto vecf4 = rhs[2];
    auto vecf6 = rhs[3];

    vecf0 = vecf0 * vec_inverse;
    vecf2 = vecf2 * vec_inverse;
    vecf4 = vecf4 * vec_inverse;
    vecf6 = vecf6 * vec_inverse;

    vecf0 = vecf0.rint() + vec_zero_point;
    vecf2 = vecf2.rint() + vec_zero_point;
    vecf4 = vecf4.rint() + vec_zero_point;
    vecf6 = vecf6.rint() + vec_zero_point;

    auto veci0 = zvec_convert_to_int(vecf0);
    auto veci2 = zvec_convert_to_int(vecf2);
    auto veci4 = zvec_convert_to_int(vecf4);
    auto veci6 = zvec_convert_to_int(vecf6);

    auto vecshi0 = pack(veci0, veci2);
    auto vecshi2 = pack(veci4, veci6);
    auto ret = pack<int16_t, typename U::underlying>(vecshi0, vecshi2);

    return Vectorized<T>{ret};
  }

  // Requantizes four qint32 vectors: int(rint(x * multiplier)) + zero_point,
  // then saturating packs 32 -> 16 -> 8 bits into the underlying type.
  template <
      typename U = T,
      std::enable_if_t<Vectorized<U>::int_num_vecs() == 4, int> = 0>
  static Vectorized<U> requantize_from_int(
      const int_vec_return_type& inp,
      float multiplier,
      int32_t zero_point) {
    Vectorized<float> vec_multiplier = Vectorized<float>(multiplier);
    Vectorized<int32_t> vec_zero_point = Vectorized<int32_t>(zero_point);

    Vectorized<c10::qint32> vi0 = inp[0];
    Vectorized<c10::qint32> vi1 = inp[1];
    Vectorized<c10::qint32> vi2 = inp[2];
    Vectorized<c10::qint32> vi3 = inp[3];

    auto vecf0 = zvec_convert_to_float(vi0.vec());
    auto vecf2 = zvec_convert_to_float(vi1.vec());

    auto vecf4 = zvec_convert_to_float(vi2.vec());
    auto vecf6 = zvec_convert_to_float(vi3.vec());

    vecf0 = vecf0 * vec_multiplier;
    vecf2 = vecf2 * vec_multiplier;

    vecf4 = vecf4 * vec_multiplier;
    vecf6 = vecf6 * vec_multiplier;

    vecf0 = vecf0.rint();
    vecf2 = vecf2.rint();
    vecf4 = vecf4.rint();
    vecf6 = vecf6.rint();

    auto veci0 = zvec_convert_to_int(vecf0);
    auto veci2 = zvec_convert_to_int(vecf2);
    auto veci4 = zvec_convert_to_int(vecf4);
    auto veci6 = zvec_convert_to_int(vecf6);

    veci0 = veci0 + vec_zero_point;
    veci2 = veci2 + vec_zero_point;

    veci4 = veci4 + vec_zero_point;
    veci6 = veci6 + vec_zero_point;

    auto vecshi0 = pack<int32_t, int16_t>(veci0, veci2);
    auto vecshi2 = pack<int32_t, int16_t>(veci4, veci6);

    auto ret = pack<int16_t, typename U::underlying>(vecshi0, vecshi2);

    return Vectorized<U>{ret};
  }

  // ---- Comparisons and min/max: forwarded to the inner vector ----
  // They operate on the raw underlying values; comparison results are the
  // inner vector's lane masks, re-wrapped as Vectorized<T>.
  Vectorized<T> C10_ALWAYS_INLINE eq(const Vectorized<T>& other) const {
    return Vectorized<T>{_vec.eq(other._vec)};
  }
  Vectorized<T> C10_ALWAYS_INLINE ne(const Vectorized<T>& other) const {
    return Vectorized<T>{_vec.ne(other._vec)};
  }
  Vectorized<T> C10_ALWAYS_INLINE gt(const Vectorized<T>& other) const {
    return Vectorized<T>{_vec.gt(other._vec)};
  }
  Vectorized<T> C10_ALWAYS_INLINE ge(const Vectorized<T>& other) const {
    return Vectorized<T>{_vec.ge(other._vec)};
  }
  Vectorized<T> C10_ALWAYS_INLINE lt(const Vectorized<T>& other) const {
    return Vectorized<T>{_vec.lt(other._vec)};
  }
  Vectorized<T> C10_ALWAYS_INLINE le(const Vectorized<T>& other) const {
    return Vectorized<T>{_vec.le(other._vec)};
  }

  Vectorized<T> clamp_min(const Vectorized<T>& min) const {
    return Vectorized<T>{_vec.clamp_min(min._vec)};
  }

  Vectorized<T> clamp_max(const Vectorized<T>& max) const {
    return Vectorized<T>{_vec.clamp_max(max._vec)};
  }

  Vectorized<T> minimum(const Vectorized<T>& other) const {
    return Vectorized<T>{_vec.minimum(other._vec)};
  }

  Vectorized<T> maximum(const Vectorized<T>& other) const {
    return Vectorized<T>{_vec.maximum(other._vec)};
  }
};
2032
// Element-wise operators for the quantized Vectorized types. Each one applies
// the same operator to the wrapped Vectorized<underlying> (vec()) and re-wraps
// the result. Arithmetic therefore works on the raw integer values, with no
// scale / zero-point handling. Comparison results are the inner lane masks
// wrapped as the quantized type.
#define ZVECTOR_OPERATORS(typex) \
  template <> \
  Vectorized<typex> C10_ALWAYS_INLINE operator+(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    return Vectorized<typex>{a.vec() + b.vec()}; \
  } \
  \
  template <> \
  Vectorized<typex> C10_ALWAYS_INLINE operator-(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    return Vectorized<typex>{a.vec() - b.vec()}; \
  } \
  \
  template <> \
  Vectorized<typex> C10_ALWAYS_INLINE operator*(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    return Vectorized<typex>{a.vec() * b.vec()}; \
  } \
  \
  template <> \
  Vectorized<typex> C10_ALWAYS_INLINE operator/(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    return Vectorized<typex>{a.vec() / b.vec()}; \
  } \
  \
  template <> \
  Vectorized<typex> C10_ALWAYS_INLINE operator&(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    return Vectorized<typex>{a.vec() & b.vec()}; \
  } \
  \
  template <> \
  Vectorized<typex> C10_ALWAYS_INLINE operator|(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    return Vectorized<typex>{a.vec() | b.vec()}; \
  } \
  \
  template <> \
  Vectorized<typex> C10_ALWAYS_INLINE operator^(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    return Vectorized<typex>{a.vec() ^ b.vec()}; \
  } \
  \
  Vectorized<typex> C10_ALWAYS_INLINE operator==(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    return Vectorized<typex>{a.vec() == b.vec()}; \
  } \
  \
  Vectorized<typex> C10_ALWAYS_INLINE operator!=(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    return Vectorized<typex>{a.vec() != b.vec()}; \
  } \
  \
  Vectorized<typex> C10_ALWAYS_INLINE operator>(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    return Vectorized<typex>{a.vec() > b.vec()}; \
  } \
  \
  Vectorized<typex> C10_ALWAYS_INLINE operator>=(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    return Vectorized<typex>{a.vec() >= b.vec()}; \
  } \
  \
  Vectorized<typex> C10_ALWAYS_INLINE operator<(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    return Vectorized<typex>{a.vec() < b.vec()}; \
  } \
  \
  Vectorized<typex> C10_ALWAYS_INLINE operator<=(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    return Vectorized<typex>{a.vec() <= b.vec()}; \
  }

ZVECTOR_OPERATORS(c10::qint32)
ZVECTOR_OPERATORS(c10::qint8)
ZVECTOR_OPERATORS(c10::quint8)

#undef ZVECTOR_OPERATORS
2098
2099 DEFINE_CLAMP_MAXMIN_FUNCS(c10::quint8)
2100 DEFINE_CLAMP_MAXMIN_FUNCS(c10::qint8)
2101 DEFINE_CLAMP_MAXMIN_FUNCS(c10::qint32)
2102
2103 template <typename U = float>
2104 constexpr auto real_mask() {
2105 return (ZSimdVect<U>)ZSimdVectBinary<float>{0xFFFFFFFF, 0, 0xFFFFFFFF, 0};
2106 }
2107
2108 template <>
2109 constexpr auto real_mask<double>() {
2110 return (ZSimdVect<double>)ZSimdVectBinary<double>{0xFFFFFFFFFFFFFFFF, 0};
2111 }
2112
2113 template <typename U = float>
2114 constexpr auto image_mask() {
2115 return (ZSimdVect<U>)ZSimdVectBinary<U>{0, 0xFFFFFFFF, 0, 0xFFFFFFFF};
2116 }
2117
2118 template <>
2119 constexpr auto image_mask<double>() {
2120 return (ZSimdVect<double>)ZSimdVectBinary<double>{0, 0xFFFFFFFFFFFFFFFF};
2121 }
2122
2123 template <typename U = float>
2124 constexpr auto rsign_mask() {
2125 return ZSimdVect<U>{-0.f, 0.f, -0.f, 0.f};
2126 }
2127
2128 template <>
2129 constexpr auto rsign_mask<double>() {
2130 return ZSimdVect<double>{-0.0, 0.f};
2131 }
2132
2133 template <typename U = float>
2134 constexpr auto isign_mask() {
2135 return ZSimdVect<U>{0.0, -0.f, 0.0, -0.f};
2136 }
2137
2138 template <>
2139 constexpr auto isign_mask<double>() {
2140 return ZSimdVect<double>{0.0, -0.0};
2141 }
2142
2143 template <typename U = float>
2144 constexpr auto image_one() {
2145 return ZSimdVect<U>{0, 1.f, 0, 1.f};
2146 }
2147
2148 template <>
2149 constexpr auto image_one<double>() {
2150 return ZSimdVect<double>{0.0, 1.0};
2151 }
2152
2153 template <typename U = float>
2154 constexpr auto pi_half() {
2155 return ZSimdVect<U>{(float)(M_PI / 2.0), 0.f, (float)(M_PI / 2.0), 0.f};
2156 }
2157
2158 template <>
2159 constexpr auto pi_half<double>() {
2160 return ZSimdVect<double>{M_PI / 2.0, 0.0};
2161 }
2162
2163 template <typename U = float>
2164 constexpr auto image_half() {
2165 return ZSimdVect<U>{0, 0.5f, 0, 0.5f};
2166 }
2167
2168 template <>
2169 constexpr auto image_half<double>() {
2170 return ZSimdVect<double>{0.0, 0.5};
2171 }
2172
// log2(e) == 1 / ln(2). Despite the "_inv" suffix, this is the factor that
// turns a natural log into a base-2 log: log2(x) = ln(x) * log2e_inv<U>().
template <typename U>
constexpr U log2e_inv() {
  constexpr double kLog2E = 1.4426950408889634;
  return static_cast<U>(kLog2E);
}
2177
// log10(e) == 1 / ln(10). Despite the "_inv" suffix, this is the factor that
// turns a natural log into a base-10 log: log10(x) = ln(x) * log10e_inv<U>().
template <typename U>
constexpr U log10e_inv() {
  constexpr double kLog10E = 0.43429448190325176;
  return static_cast<U>(kLog10E);
}
2182
2183 template <typename T>
2184 struct Vectorized<T, std::enable_if_t<is_zarch_implemented_complex<T>()>> {
2185 public:
2186 using underline_type = decltype(std::declval<T>().imag());
2187 using value_type = T;
2188 using vtype = ZSimdVect<underline_type>;
2189 using vmaskType = ZSimdVectBinary<underline_type>;
2190 using vinner_type = Vectorized<underline_type>;
2191 using size_type = int;
2192 using vinner_data = typename Vectorized<underline_type>::vinner_data;
2193
2194 static constexpr size_type size() {
2195 return VECTOR_WIDTH / sizeof(value_type);
2196 }
2197
2198 private:
2199 vinner_type _vec;
2200
2201 public:
2202 Vectorized() {}
2203
2204 C10_ALWAYS_INLINE Vectorized(const vinner_data &v) : _vec{v.first, v.second} {}
2205
2206 template <typename U = T, std::enable_if_t<(sizeof(U) == 16), int> = 0>
2207 C10_ALWAYS_INLINE Vectorized(T s1, T s2)
2208 : _vec{s1.real(), s1.imag(), s2.real(), s2.imag()} {}
2209
2210 template <typename U = T, std::enable_if_t<(sizeof(U) == 8), int> = 0>
2211 C10_ALWAYS_INLINE Vectorized(T s1, T s2, T s3, T s4)
2212 : _vec{
2213 s1.real(),
2214 s1.imag(),
2215 s2.real(),
2216 s2.imag(),
2217 s3.real(),
2218 s3.imag(),
2219 s4.real(),
2220 s4.imag()} {}
2221
2222 template <typename U = T, std::enable_if_t<(sizeof(U) == 16), int> = 0>
2223 C10_ALWAYS_INLINE Vectorized(T s) : Vectorized<T>(s, s) {}
2224
2225 template <typename U = T, std::enable_if_t<(sizeof(U) == 8), int> = 0>
2226 C10_ALWAYS_INLINE Vectorized(T s) : Vectorized<T>(s, s, s, s) {}
2227
2228 C10_ALWAYS_INLINE operator vinner_type() const {
2229 return _vec;
2230 }
2231
2232 C10_ALWAYS_INLINE const vinner_type& vec() const {
2233 return _vec;
2234 }
2235
2236 C10_ALWAYS_INLINE operator vinner_data() const {
2237 return _vec.data();
2238 }
2239
2240 C10_ALWAYS_INLINE vinner_data data() const {
2241 return _vec.data();
2242 }
2243
2244 template <typename U>
2245 static Vectorized<T> C10_ALWAYS_INLINE
2246 loadu(const U* ptr, int count = size()) {
2247 return Vectorized<T>{vinner_type::loadu(ptr, 2 * count)};
2248 }
2249
2250 template <typename U>
2251 void C10_ALWAYS_INLINE store(U* ptr, int count = size()) const {
2252 return _vec.store(ptr, 2 * count);
2253 }
2254
2255 static Vectorized<T> blendv(
2256 const Vectorized<T>& a,
2257 const Vectorized<T>& b,
2258 const Vectorized<T>& mask) {
2259 // convert std::complex<V> index mask to V index mask: xy -> xxyy
2260 vinner_type vmask = mask.vec();
2261 auto mask_complex = vinner_type(
2262 vec_mergeh(vmask.vec0(), vmask.vec0()),
2263 vec_mergeh(vmask.vec1(), vmask.vec1()));
2264 return Vectorized<T>{vinner_type::blendv(a.vec(), b.vec(), mask_complex)};
2265 }
2266
2267 template <int64_t mask>
2268 static auto C10_ALWAYS_INLINE
2269 blend(const Vectorized<T>& a, const Vectorized<T>& b) {
2270 constexpr int mask_complex = maskForComplex<sizeof(T)>(mask);
2271 return Vectorized<T>{
2272 vinner_type::template blend<mask_complex>(a.vec(), b.vec())};
2273 }
2274
2275 template <typename step_t, typename U = T>
2276 static std::enable_if_t<sizeof(U) == 16, Vectorized<T>> arange(
2277 T base = 0,
2278 step_t step = static_cast<step_t>(1)) {
2279 return Vectorized<T>(base, base + step);
2280 }
2281
2282 template <typename step_t, typename U = T>
2283 static std::enable_if_t<sizeof(U) == 8, Vectorized<T>> arange(
2284 T base = 0,
2285 step_t step = static_cast<step_t>(1)) {
2286 return Vectorized<T>(
2287 base,
2288 base + step,
2289 base + value_type(2) * step,
2290 base + value_type(3) * step);
2291 }
2292
2293 template <int16_t Z, int16_t C>
2294 static inline std::enable_if_t<(Z >= C), Vectorized<T>> set_inner(
2295 const Vectorized<T>& a,
2296 const Vectorized<T>& b,
2297 size_t count) {
2298 return b;
2299 }
2300
2301 template <int16_t Z, int16_t C>
2302 static inline std::enable_if_t<(Z < C), Vectorized<T>> set_inner(
2303 const Vectorized<T>& a,
2304 const Vectorized<T>& b,
2305 size_t count) {
2306 if (count == Z)
2307 return blend<allbitset(Z)>(a, b);
2308 else
2309 return set_inner<Z + 1, C>(a, b, count);
2310 }
2311
2312 static Vectorized<T> set(
2313 const Vectorized<T>& a,
2314 const Vectorized<T>& b,
2315 size_t count = size()) {
2316 if (count == 0)
2317 return a;
2318 return set_inner<1, size()>(a, b, count);
2319 }
2320
2321 const T& operator[](int idx) const = delete;
2322 T& operator[](int idx) = delete;
2323
2324 template <
2325 typename U = T,
2326 std::enable_if_t<std::is_same<U, c10::complex<float>>::value, int> = 0>
2327 Vectorized<T> mapOrdinary(T (*const f)(const T&)) const {
2328 auto v0 = _vec.vec0();
2329 auto v1 = _vec.vec1();
2330 return Vectorized<T>{
2331 f(T(v0[0], v0[1])),
2332 f(T(v0[2], v0[3])),
2333 f(T(v1[0], v1[1])),
2334 f(T(v1[2], v1[3]))};
2335 }
2336
2337 template <
2338 typename U = T,
2339 std::enable_if_t<std::is_same<U, c10::complex<double>>::value, int> = 0>
2340 Vectorized<U> mapOrdinary(T (*const f)(const T&)) const {
2341 auto v0 = _vec.vec0();
2342 auto v1 = _vec.vec1();
2343 return Vectorized<T>{f(T(v0[0], v0[1])), f(T(v1[0], v1[1]))};
2344 }
2345
2346 template <
2347 typename U = T,
2348 std::enable_if_t<std::is_same<U, c10::complex<float>>::value, int> = 0>
2349 Vectorized<T> mapOrdinary(T (*const f)(T)) const {
2350 auto v0 = _vec.vec0();
2351 auto v1 = _vec.vec1();
2352 return Vectorized<T>{
2353 f(T(v0[0], v0[1])),
2354 f(T(v0[2], v0[3])),
2355 f(T(v1[0], v1[1])),
2356 f(T(v1[2], v1[3]))};
2357 }
2358
2359 template <
2360 typename U = T,
2361 std::enable_if_t<std::is_same<U, c10::complex<double>>::value, int> = 0>
2362 Vectorized<T> mapOrdinary(T (*const f)(T)) const {
2363 auto v0 = _vec.vec0();
2364 auto v1 = _vec.vec1();
2365 return Vectorized<T>{f(T(v0[0], v0[1])), f(T(v1[0], v1[1]))};
2366 }
2367
2368 template <
2369 typename U = T,
2370 std::enable_if_t<std::is_same<U, c10::complex<float>>::value, int> = 0>
2371 inline Vectorized<T> mapOrdinary(
2372 T (*const f)(const T&, const T&),
2373 const Vectorized<T>& b) const {
2374 auto v0 = _vec.vec0();
2375 auto v1 = _vec.vec1();
2376 auto bvec = b.vec();
2377 auto b0 = bvec.vec0();
2378 auto b1 = bvec.vec1();
2379 T a00 = f(T(v0[0], v0[1]), T(b0[0], b0[1]));
2380 T a01 = f(T(v0[2], v0[3]), T(b0[2], b0[3]));
2381 T a02 = f(T(v1[0], v1[1]), T(b1[0], b1[1]));
2382 T a03 = f(T(v1[2], v1[3]), T(b1[2], b1[3]));
2383 return Vectorized<T>{a00, a01, a02, a03};
2384 }
2385
2386 template <
2387 typename U = T,
2388 std::enable_if_t<std::is_same<U, c10::complex<double>>::value, int> = 0>
2389 inline Vectorized<T> mapOrdinary(
2390 T (*const f)(const T&, const T&),
2391 const Vectorized<T>& b) const {
2392 auto v0 = _vec.vec0();
2393 auto v1 = _vec.vec1();
2394 auto bvec = b.vec();
2395 auto b0 = bvec.vec0();
2396 auto b1 = bvec.vec1();
2397 U a00 = f(U(v0[0], v0[1]), U(b0[0], b0[1]));
2398 U a01 = f(U(v1[0], v1[1]), U(b1[0], b1[1]));
2399 return Vectorized<T>{a00, a01};
2400 }
2401
2402 template <
2403 typename U = T,
2404 std::enable_if_t<std::is_same<U, c10::complex<float>>::value, int> = 0>
2405 static typename Vectorized<T>::vinner_type real_neg(const typename Vectorized<T>::vinner_type &a)
2406 {
2407 const auto swap_mask = ZSimdVectBinary<uint8_t>{
2408 0, 1, 2, 3, 20, 21, 22, 23, 8, 9, 10, 11, 28, 29, 30, 31};
2409
2410 auto a_neg = a.neg();
2411 vtype v0 = vec_perm(a_neg.vec0(), a.vec0(), swap_mask);
2412 vtype v1 = vec_perm(a_neg.vec1(), a.vec1(), swap_mask);
2413 return {v0, v1};
2414 }
2415
2416 template <
2417 typename U = T,
2418 std::enable_if_t<std::is_same<U, c10::complex<double>>::value, int> = 0>
2419 static typename Vectorized<T>::vinner_type real_neg(const typename Vectorized<T>::vinner_type &a)
2420 {
2421 auto a_neg = a.neg();
2422 auto v0 = vec_permi(a_neg.vec0(), a.vec0(), 1);
2423 auto v1 = vec_permi(a_neg.vec1(), a.vec1(), 1);
2424 return { v0, v1 };
2425 }
2426
2427 Vectorized<T> angle2_() const {
2428 auto b_a = _vec.swapped(); // b a
2429 return Vectorized<T>{_vec.atan2(b_a).swapped()};
2430 }
2431
2432 Vectorized<T> angle() const {
2433 return angle2_().real();
2434 }
2435
2436 Vectorized<T> atan() const {
2437 // atan(x) = i/2 * ln((i + z)/(i - z))
2438 auto ione = Vectorized<T>{vinner_type(image_one<underline_type>())};
2439 auto sum = ione + *this;
2440 auto sub = ione - *this;
2441 auto ln = (sum / sub).log(); // ln((i + z)/(i - z))
2442 return ln *
2443 Vectorized<T>{vinner_type(image_half<underline_type>())}; // i/2*ln()
2444 }
2445
2446 Vectorized<T> atanh() const {
2447 return mapOrdinary(std::atanh);
2448 }
2449
2450 Vectorized<T> asin() const {
2451 // asin(x)
2452 // = -i*ln(iz + sqrt(1 -z^2))
2453 // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi)))
2454 // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi))
2455 #if 1
2456 vinner_type cnj = conj().vec();
2457 vinner_type b_a = cnj.swapped();
2458 vinner_type ab = cnj * b_a;
2459 vinner_type im = ab + ab;
2460 vinner_type val_2 = _vec * _vec;
2461 vinner_type val_2_swapped = val_2.swapped();
2462 vinner_type re = vinner_type::horizontal_sub_perm(val_2, val_2_swapped);
2463 re = vinner_type(static_cast<underline_type>(1)) - re;
2464 constexpr int blend_mask =
2465 blend_choice<T>(); // 0x0A for complex<double> , 0xAA for complex<float>
2466 vinner_type blendx = vinner_type::template blend<blend_mask>(re, im);
2467 auto root = Vectorized<T>(blendx).sqrt();
2468 auto ln = Vectorized<T>(Vectorized<T>(b_a) + root).log();
2469 return Vectorized<T>(ln.vec().swapped()).conj();
2470 #else
2471 return mapOrdinary(std::asin);
2472 #endif
2473 }
2474
2475 Vectorized<T> acos() const {
2476 // acos(x) = pi/2 - asin(x)
2477 return Vectorized<T>(vinner_type(pi_half<underline_type>())) - asin();
2478 }
2479
2480 Vectorized<T> sin() const {
2481 return mapOrdinary(std::sin);
2482 }
2483 Vectorized<T> sinh() const {
2484 return mapOrdinary(std::sinh);
2485 }
2486 Vectorized<T> cos() const {
2487 return mapOrdinary(std::cos);
2488 }
2489 Vectorized<T> cosh() const {
2490 return mapOrdinary(std::cosh);
2491 }
2492 Vectorized<T> ceil() const {
2493 return Vectorized<T>{_vec.ceil()};
2494 }
2495 Vectorized<T> floor() const {
2496 return Vectorized<T>{_vec.floor()};
2497 }
2498 Vectorized<T> neg() const {
2499 return Vectorized<T>(_vec.neg());
2500 }
2501 Vectorized<T> round() const {
2502 return Vectorized<T>{_vec.round()};
2503 }
2504 Vectorized<T> tan() const {
2505 return mapOrdinary(std::tan);
2506 }
2507 Vectorized<T> tanh() const {
2508 return mapOrdinary(std::tanh);
2509 }
2510 Vectorized<T> trunc() const {
2511 return Vectorized<T>{_vec.trunc()};
2512 }
2513
2514 Vectorized<T> C10_ALWAYS_INLINE eq(const Vectorized<T>& other) const {
2515 auto eq = _vec.eq(other._vec); // compares real and imag individually
2516 // If both real numbers and imag numbers are equal, then the complex numbers are equal
2517 auto real = eq & vinner_type(real_mask<underline_type>());
2518 auto imag = (eq & vinner_type(image_mask<underline_type>())).swapped();
2519 return Vectorized<T>{real & imag};
2520 }
2521 Vectorized<T> C10_ALWAYS_INLINE ne(const Vectorized<T>& other) const {
2522 auto ne = _vec.ne(other._vec); // compares real and imag individually
2523 // If either real numbers or imag numbers are not equal, then the complex numbers are not equal
2524 auto real = ne & vinner_type(real_mask<underline_type>());
2525 auto imag = (ne & vinner_type(image_mask<underline_type>())).swapped();
2526 return Vectorized<T>{real | imag};
2527 }
2528
2529 Vectorized<T> real() const {
2530 return Vectorized<T>(_vec & vinner_type(real_mask<underline_type>()));
2531 }
2532 Vectorized<T> imag_() const {
2533 return Vectorized<T>(_vec & vinner_type(image_mask<underline_type>()));
2534 }
2535 Vectorized<T> imag() const {
2536 return Vectorized<T>{
2537 (_vec & vinner_type(image_mask<underline_type>())).swapped()};
2538 }
2539
2540 Vectorized<T> conj() const {
2541 return Vectorized<T>(_vec ^ vinner_type(isign_mask<underline_type>()));
2542 }
2543
2544 vinner_data abs_2_() const {
2545 auto a = _vec * _vec;
2546 a = a + a.swapped();
2547 return a.mergee().data();
2548 }
2549
2550 static T abs_helper(const T &value)
2551 {
2552 return T(std::abs(value));
2553 }
2554
2555 Vectorized<T> abs() const {
2556 return mapOrdinary(abs_helper);
2557 }
2558
2559 Vectorized<T> exp() const {
2560 return mapOrdinary(std::exp);
2561 }
2562
2563 Vectorized<T> exp2() const {
2564 return mapOrdinary(exp2_impl);
2565 }
2566
2567 Vectorized<T> expm1() const {
2568 return mapOrdinary(std::expm1);
2569 }
2570
2571 Vectorized<T> log() const {
2572 return mapOrdinary(std::log);
2573 }
2574
2575 Vectorized<T> log2() const {
2576 // log2eB_inv
2577 auto ret = log();
2578 return Vectorized<T>{ret._vec * vinner_type(log2e_inv<underline_type>())};
2579 }
2580
2581 Vectorized<T> log10() const {
2582 auto ret = log();
2583 return Vectorized<T>{ret._vec * vinner_type(log10e_inv<underline_type>())};
2584 }
2585
2586 Vectorized<T> log1p() const {
2587 return mapOrdinary(std::log1p);
2588 }
2589
2590 Vectorized<T> sgn() const {
2591 return mapOrdinary(at::native::sgn_impl);
2592 }
2593
2594 Vectorized<T> pow(const Vectorized<T>& exp) const {
2595 return mapOrdinary(std::pow, exp);
2596 }
2597
2598 Vectorized<T> sqrt() const {
2599 return mapOrdinary(std::sqrt);
2600 }
2601
2602 Vectorized<T> reciprocal() const {
2603 // re + im*i = (a + bi) / (c + di)
2604 // re = (ac + bd)/abs_2() = c/abs_2()
2605 // im = (bc - ad)/abs_2() = d/abs_2()
2606 vinner_type c_d = _vec ^ vinner_type(isign_mask<underline_type>());
2607 vinner_type abs = abs_2_();
2608 return Vectorized<T>{c_d / abs};
2609 }
2610
2611 Vectorized<T> rsqrt() const {
2612 return sqrt().reciprocal();
2613 }
2614
2615 Vectorized<T> lt(const Vectorized<T>& other) const {
2616 TORCH_CHECK(false, "not supported for complex numbers");
2617 }
2618
2619 Vectorized<T> le(const Vectorized<T>& other) const {
2620 TORCH_CHECK(false, "not supported for complex numbers");
2621 }
2622
2623 Vectorized<T> gt(const Vectorized<T>& other) const {
2624 TORCH_CHECK(false, "not supported for complex numbers");
2625 }
2626
2627 Vectorized<T> ge(const Vectorized<T>& other) const {
2628 TORCH_CHECK(false, "not supported for complex numbers");
2629 }
2630 };
2631
// Operators for the complex Vectorized types. `+`, `-`, the bitwise operators
// and `==` / `!=` act lane-wise on the interleaved (real, imag) inner vector;
// `*` and `/` implement complex multiplication and division; the ordering
// operators always fail via TORCH_CHECK.
#define ZVECTOR_OPERATORS(typex) \
  template <> \
  Vectorized<typex> C10_ALWAYS_INLINE operator+(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    return Vectorized<typex>{a.vec() + b.vec()}; \
  } \
  \
  template <> \
  Vectorized<typex> C10_ALWAYS_INLINE operator-(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    return Vectorized<typex>{a.vec() - b.vec()}; \
  } \
  \
  template <> \
  Vectorized<typex> inline operator*(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    /* (a + bi) * (c + di) = (ac - bd) + (ad + bc)i */ \
    Vectorized<typex>::vinner_type bv = b.vec(); \
    \
    /* this is more z arch friendly than simulating horizontal from x86 */ \
    /* presumably vi = (d, d) and vr = (c, c) per element (mergeo/mergee) */ \
    Vectorized<typex>::vinner_type vi = bv.mergeo(); \
    Vectorized<typex>::vinner_type vr = bv.mergee(); \
    /* rsign_mask flips the sign bit of the real lanes: vi = (-d, d) */ \
    vi = vi ^ Vectorized<typex>::vinner_type(rsign_mask<Vectorized<typex>::underline_type>()); \
    /* ret = (ac, bc) */ \
    Vectorized<typex>::vinner_type ret = a.vec() * vr; \
    /* vx_swapped = (b, a) */ \
    Vectorized<typex>::vinner_type vx_swapped = a.vec().swapped(); \
    /* ret = (b*(-d) + ac, a*d + bc) = (ac - bd, ad + bc) */ \
    ret = fmadd(vx_swapped, vi, ret); \
    \
    return Vectorized<typex>{ret}; \
  } \
  \
  template <> \
  Vectorized<typex> inline operator/(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    /* (a + bi) / (c + di) = ((ac + bd) + (bc - ad)i) / (c^2 + d^2) */ \
    /* A more direct formulation broke some tests, so this follows the avx2 */ \
    /* implementation: both operands are first scaled by sc = max(|c|, |d|) */ \
    /* (presumably to keep c^2 + d^2 from overflowing or underflowing). */ \
    auto fabs_cd = b.vec().abs(); /* |c| |d| */ \
    auto fabs_dc = fabs_cd.swapped(); /* |d| |c| */ \
    auto scale = Vectorized<typex>::vinner_type {1.0} / maximum(fabs_cd, fabs_dc); /* 1/sc 1/sc */ \
    auto a2 = a.vec() * scale; /* a/sc b/sc */ \
    auto b2 = b.vec() * scale; /* c/sc d/sc */ \
    auto acbd2 = a2 * b2; /* ac/sc^2 bd/sc^2 */ \
    \
    auto dc2 = b2.swapped(); /* d/sc c/sc */ \
    dc2 = Vectorized<typex>::real_neg(dc2); /* -d/sc c/sc */ \
    auto adbc2 = a2 * dc2; /* -ad/sc^2 bc/sc^2 */ \
    auto sum1 = acbd2 + acbd2.swapped(); /* (ac+bd)/sc^2 (ac+bd)/sc^2 */ \
    auto sum2 = adbc2 + adbc2.swapped(); /* (bc-ad)/sc^2 (bc-ad)/sc^2 */ \
    auto res2 = Vectorized<typex>::vinner_type::mergee(sum1, sum2); /* (ac+bd)/sc^2 (bc-ad)/sc^2 */ \
    \
    /* get the denominator */ \
    Vectorized<typex>::vinner_type denom2 = Vectorized<typex>{b2}.abs_2_(); /* (c^2+d^2)/sc^2 (c^2+d^2)/sc^2 */ \
    res2 = res2 / denom2; \
    return Vectorized<typex>{ res2 }; \
  } \
  \
  template <> \
  Vectorized<typex> C10_ALWAYS_INLINE operator&(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    return Vectorized<typex>{a.vec() & b.vec()}; \
  } \
  \
  template <> \
  Vectorized<typex> C10_ALWAYS_INLINE operator|(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    return Vectorized<typex>{a.vec() | b.vec()}; \
  } \
  \
  template <> \
  Vectorized<typex> C10_ALWAYS_INLINE operator^(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    return Vectorized<typex>{a.vec() ^ b.vec()}; \
  } \
  \
  /* Lane-wise (per real/imag component) comparison, not per complex element. */ \
  Vectorized<typex> C10_ALWAYS_INLINE operator==(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    return Vectorized<typex>{a.vec() == b.vec()}; \
  } \
  \
  Vectorized<typex> C10_ALWAYS_INLINE operator!=(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    return Vectorized<typex>{a.vec() != b.vec()}; \
  } \
  \
  /* Complex numbers have no ordering: the operators below always fail. */ \
  Vectorized<typex> C10_ALWAYS_INLINE operator<(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    TORCH_CHECK(false, "not supported for complex numbers"); \
  } \
  \
  Vectorized<typex> C10_ALWAYS_INLINE operator<=(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    TORCH_CHECK(false, "not supported for complex numbers"); \
  } \
  \
  Vectorized<typex> C10_ALWAYS_INLINE operator>(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    TORCH_CHECK(false, "not supported for complex numbers"); \
  } \
  \
  Vectorized<typex> C10_ALWAYS_INLINE operator>=(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
    TORCH_CHECK(false, "not supported for complex numbers"); \
  }
2721
2722 ZVECTOR_OPERATORS(c10::complex<float>)
2723 ZVECTOR_OPERATORS(c10::complex<double>)
2724
2725 #undef ZVECTOR_OPERATORS
2726
2727 template <typename T, std::enable_if_t<(sizeof(T) == 8), int> = 0>
2728 std::pair<Vectorized<T>, Vectorized<T>> inline inner_interleave2(
2729 const Vectorized<T>& a,
2730 const Vectorized<T>& b) {
2731 // inputs:
2732 // a = {a0, a1, a2, a3}
2733 // b = {b0, b1, b2, b3}
2734 using vtype = typename Vectorized<T>::vtype;
2735 vtype ab00 = vec_permi(a.vec0(), b.vec0(), 0);
2736 vtype ab11 = vec_permi(a.vec0(), b.vec0(), 3);
2737 vtype ab2_00 = vec_permi(a.vec1(), b.vec1(), 0);
2738 vtype ab2_11 = vec_permi(a.vec1(), b.vec1(), 3);
2739 // return {a0, b0, a1, b1}
2740 // {a2, b2, a3, b3}
2741 return std::make_pair(
2742 Vectorized<T>{ab00, ab11}, Vectorized<T>{ab2_00, ab2_11});
2743 }
2744
2745 template <typename T, std::enable_if_t<(sizeof(T) == 8), int> = 0>
2746 std::pair<Vectorized<T>, Vectorized<T>> inline inner_deinterleave2(
2747 const Vectorized<T>& a,
2748 const Vectorized<T>& b) {
2749 // inputs:
2750 // a = {a0, b0, a1, b1}
2751 // b = {a2, b2, a3, b3}
2752 using vtype = typename Vectorized<T>::vtype;
2753 vtype aa01 = vec_permi(a.vec0(), a.vec1(), 0);
2754 vtype aa23 = vec_permi(b.vec0(), b.vec1(), 0);
2755
2756 vtype bb_01 = vec_permi(a.vec0(), a.vec1(), 3);
2757 vtype bb_23 = vec_permi(b.vec0(), b.vec1(), 3);
2758
2759 // swap lanes:
2760 // return {a0, a1, a2, a3}
2761 // {b0, b1, b2, b3}
2762 return std::make_pair(Vectorized<T>{aa01, aa23}, Vectorized<T>{bb_01, bb_23});
2763 }
2764
2765 template <typename T, std::enable_if_t<(sizeof(T) == 4), int> = 0>
2766 std::pair<Vectorized<T>, Vectorized<T>> inline inner_interleave2(
2767 const Vectorized<T>& a,
2768 const Vectorized<T>& b) {
2769 // inputs:
2770 // a = {a0, a1, a2, a3,, a4, a5, a6, a7}
2771 // b = {b0, b1, b2, b3,, b4, b5, b6, b7}
2772 using vtype = typename Vectorized<T>::vtype;
2773 vtype ab0011 = vec_mergeh(a.vec0(), b.vec0());
2774 vtype ab2233 = vec_mergel(a.vec0(), b.vec0());
2775
2776 vtype ab2_0011 = vec_mergeh(a.vec1(), b.vec1());
2777 vtype ab2_2233 = vec_mergel(a.vec1(), b.vec1());
2778 // group cols crossing lanes:
2779 // return {a0, b0, a1, b1,, a2, b2, a3, b3}
2780 // {a4, b4, a5, b5,, a6, b6, a7, b7}
2781
2782 return std::make_pair(
2783 Vectorized<T>{ab0011, ab2233}, Vectorized<T>{ab2_0011, ab2_2233});
2784 }
2785
2786 template <typename T, std::enable_if_t<(sizeof(T) == 4), int> = 0>
2787 std::pair<Vectorized<T>, Vectorized<T>> inline inner_deinterleave2(
2788 const Vectorized<T>& a,
2789 const Vectorized<T>& b) {
2790 // inputs:
2791 // a = {a0, b0, a1, b1,, a2, b2, a3, b3}
2792 // b = {a4, b4, a5, b5,, a6, b6, a7, b7}
2793 using vtype = typename Vectorized<T>::vtype;
2794 // {a0,a2,b0,b2} {a1,a3,b1,b3}
2795 vtype a0a2b0b2 = vec_mergeh(a.vec0(), a.vec1());
2796 vtype a1a3b1b3 = vec_mergel(a.vec0(), a.vec1());
2797
2798 vtype aa0123 = vec_mergeh(a0a2b0b2, a1a3b1b3);
2799 vtype bb0123 = vec_mergel(a0a2b0b2, a1a3b1b3);
2800
2801 vtype a0a2b0b2_2 = vec_mergeh(b.vec0(), b.vec1());
2802 vtype a1a3b1b3_2 = vec_mergel(b.vec0(), b.vec1());
2803
2804 vtype aa0123_2 = vec_mergeh(a0a2b0b2_2, a1a3b1b3_2);
2805 vtype bb0123_2 = vec_mergel(a0a2b0b2_2, a1a3b1b3_2);
2806
2807 // it could be done with vec_perm ,too
2808 // swap lanes:
2809 // return {a0, a1, a2, a3,, a4, a5, a6, a7}
2810 // {b0, b1, b2, b3,, b4, b5, b6, b7}
2811
2812 return std::make_pair(
2813 Vectorized<T>{aa0123, aa0123_2}, Vectorized<T>{bb0123, bb0123_2});
2814 }
2815
2816 template <>
2817 std::pair<Vectorized<float>, Vectorized<float>> inline interleave2<float>(
2818 const Vectorized<float>& a,
2819 const Vectorized<float>& b) {
2820 return inner_interleave2<float>(a, b);
2821 }
2822
2823 template <>
2824 std::pair<Vectorized<int32_t>, Vectorized<int32_t>> inline interleave2<int32_t>(
2825 const Vectorized<int32_t>& a,
2826 const Vectorized<int32_t>& b) {
2827 return inner_interleave2<int32_t>(a, b);
2828 }
2829
2830 template <>
2831 std::pair<Vectorized<double>, Vectorized<double>> inline interleave2<double>(
2832 const Vectorized<double>& a,
2833 const Vectorized<double>& b) {
2834 return inner_interleave2<double>(a, b);
2835 }
2836
2837 template <>
2838 std::pair<Vectorized<int64_t>, Vectorized<int64_t>> inline interleave2<int64_t>(
2839 const Vectorized<int64_t>& a,
2840 const Vectorized<int64_t>& b) {
2841 return inner_interleave2<int64_t>(a, b);
2842 }
2843
2844 template <>
2845 std::pair<Vectorized<float>, Vectorized<float>> inline deinterleave2<float>(
2846 const Vectorized<float>& a,
2847 const Vectorized<float>& b) {
2848 return inner_deinterleave2<float>(a, b);
2849 }
2850
2851 template <>
2852 std::pair<Vectorized<int32_t>, Vectorized<int32_t>> inline deinterleave2<
2853 int32_t>(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
2854 return inner_deinterleave2<int32_t>(a, b);
2855 }
2856
2857 template <>
2858 std::pair<Vectorized<double>, Vectorized<double>> inline deinterleave2<double>(
2859 const Vectorized<double>& a,
2860 const Vectorized<double>& b) {
2861 return inner_deinterleave2<double>(a, b);
2862 }
2863
2864 template <>
2865 std::pair<Vectorized<int64_t>, Vectorized<int64_t>> inline deinterleave2<
2866 int64_t>(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
2867 return inner_deinterleave2<int64_t>(a, b);
2868 }
2869
2870 template <typename T>
2871 typename std::enable_if<std::is_same<T, uint8_t>::value, at::vec::Vectorized<float>>::type
2872 inline convert_int8_to_float(const Vectorized<T> &src) {
2873 // Note: this function only convert inputs number of elements equal to at::vec::Vectorized<float>.size()
2874 // Only handle first 64 bits
2875 auto vec_int = src.to_vec_float_helper();
2876
2877 return zvec_convert_to_float(vec_int);
2878 }
2879
2880 template <typename T>
2881 typename std::enable_if<std::is_same<T, uint8_t>::value, at::vec::Vectorized<T>>::type
2882 inline convert_float_to_int8(const Vectorized<float> &src) {
2883 constexpr auto min_val = std::numeric_limits<T>::min();
2884 constexpr auto max_val = std::numeric_limits<T>::max();
2885
2886 auto vec_int = clamp(zvec_convert_to_int(src), Vectorized<int32_t>(min_val), Vectorized<int32_t>(max_val));
2887
2888 return vec_int.to_vec_uint8_helper();
2889 }
2890
2891 #undef DEFINE_CLAMP_MAXMIN_FUNCS
2892 #undef DEFINE_MAXMIN_FUNCS
2893 } // namespace
2894 } // namespace vec
2895 } // namespace at
2896