#pragma once

// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]

#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/util/irange.h>

#if defined(__aarch64__) && defined(AT_BUILD_ARM_VEC256_WITH_SLEEF)
#include <sleef.h>
#endif

// Sleef offers vectorized versions of some transcendentals
// such as sin, cos, and tan. However, for now we opt for the STL,
// since we are not yet building with Sleef for mobile.

namespace at::vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {

// Right now this contains only the aarch64 implementation.
// aarch32 is not currently supported, for two reasons:
// 1. Due to differences in the ISA between aarch32 and aarch64,
//    intrinsics that work for aarch64 don't work for aarch32.
// 2. Android NDK r21 has problems compiling aarch32;
//    Clang seg faults.
//    https://github.com/android/ndk/issues/1248
//    https://bugs.llvm.org/show_bug.cgi?id=45824
// Most likely we will add aarch32 support with inline asm.
#if defined(__aarch64__)

#ifdef __BIG_ENDIAN__
#error "Big endian is not supported."
#endif

#if defined(AT_BUILD_ARM_VEC256_WITH_SLEEF)
#define USE_SLEEF(sleef_code, non_sleef_code) sleef_code
#else
#define USE_SLEEF(sleef_code, non_sleef_code) non_sleef_code
#endif
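// For example, USE_SLEEF(Sleef_sinf4_u10(x), std::sin(x)) expands to the
// Sleef call when AT_BUILD_ARM_VEC256_WITH_SLEEF is defined, and to the
// STL fallback otherwise.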

template<int index, bool mask_val>
struct BlendRegs {
  static float32x4_t impl(
    const float32x4_t& a, const float32x4_t& b, float32x4_t& res);
};
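// The two specializations below select, at compile time, whether lane
// `index` of the result comes from b (mask_val == true) or from a
// (mask_val == false).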

template<int index>
struct BlendRegs<index, true>{
  static float32x4_t impl(
      const float32x4_t& a, const float32x4_t& b, float32x4_t& res) {
    return vsetq_lane_f32(vgetq_lane_f32(b, index), res, index);
  }
};

template<int index>
struct BlendRegs<index, false>{
  static float32x4_t impl(
      const float32x4_t& a, const float32x4_t& b, float32x4_t& res) {
    return vsetq_lane_f32(vgetq_lane_f32(a, index), res, index);
  }
};

template <> class Vectorized<float> {
private:
  float32x4x2_t values;
public:
  using value_type = float;
  using size_type = int;
  static constexpr size_type size() {
    return 8;
  }
  Vectorized() {}
  Vectorized(float32x4x2_t v) : values(v) {}
  Vectorized(float val) : values{vdupq_n_f32(val), vdupq_n_f32(val)} {}
  Vectorized(float val0, float val1, float val2, float val3,
         float val4, float val5, float val6, float val7) :
         values{val0, val1, val2, val3, val4, val5, val6, val7} {}
  Vectorized(float32x4_t val0, float32x4_t val1) : values{val0, val1} {}
  operator float32x4x2_t() const {
    return values;
  }
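  // Compile-time blend: bit i of `mask` selects lane i of the result from
  // b (bit set) or a (bit clear). Bits 0-3 control values.val[0] and bits
  // 4-7 control values.val[1].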
  template <int64_t mask>
  static Vectorized<float> blend(const Vectorized<float>& a, const Vectorized<float>& b) {
    Vectorized<float> vec;
    // 0.
    vec.values.val[0] =
      BlendRegs<0, (mask & 0x01)!=0>::impl(
          a.values.val[0], b.values.val[0], vec.values.val[0]);
    vec.values.val[0] =
      BlendRegs<1, (mask & 0x02)!=0>::impl(
          a.values.val[0], b.values.val[0], vec.values.val[0]);
    vec.values.val[0] =
      BlendRegs<2, (mask & 0x04)!=0>::impl(
          a.values.val[0], b.values.val[0], vec.values.val[0]);
    vec.values.val[0] =
      BlendRegs<3, (mask & 0x08)!=0>::impl(
          a.values.val[0], b.values.val[0], vec.values.val[0]);
    // 1.
    vec.values.val[1] =
      BlendRegs<0, (mask & 0x10)!=0>::impl(
          a.values.val[1], b.values.val[1], vec.values.val[1]);
    vec.values.val[1] =
      BlendRegs<1, (mask & 0x20)!=0>::impl(
          a.values.val[1], b.values.val[1], vec.values.val[1]);
    vec.values.val[1] =
      BlendRegs<2, (mask & 0x40)!=0>::impl(
          a.values.val[1], b.values.val[1], vec.values.val[1]);
    vec.values.val[1] =
      BlendRegs<3, (mask & 0x80)!=0>::impl(
          a.values.val[1], b.values.val[1], vec.values.val[1]);
    return vec;
  }
  static Vectorized<float> blendv(const Vectorized<float>& a, const Vectorized<float>& b,
                              const Vectorized<float>& mask) {
    // TODO
    // NB: This requires that each lane of the mask be either all zeros or
    // all ones. We perhaps need some kind of an assert?
    // But that would affect performance.
    Vectorized<float> vec(mask.values);
    vec.values.val[0] = vbslq_f32(
        vreinterpretq_u32_f32(vec.values.val[0]),
        b.values.val[0],
        a.values.val[0]);
    vec.values.val[1] = vbslq_f32(
        vreinterpretq_u32_f32(vec.values.val[1]),
        b.values.val[1],
        a.values.val[1]);
    return vec;
  }
  template<typename step_t>
  static Vectorized<float> arange(float base = 0.f, step_t step = static_cast<step_t>(1)) {
    const Vectorized<float> base_vec(base);
    const Vectorized<float> step_vec(step);
    const Vectorized<float> step_sizes(0, 1, 2, 3, 4, 5, 6, 7);
    return fmadd(step_sizes, step_vec, base_vec);
  }
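  // e.g. arange(1.f, 2.f) yields {1, 3, 5, 7, 9, 11, 13, 15}.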
  static Vectorized<float> set(const Vectorized<float>& a, const Vectorized<float>& b,
                           int64_t count = size()) {
    switch (count) {
      case 0:
        return a;
      case 1:
        {
          Vectorized<float> vec;
          static uint32x4_t mask_low = {0xFFFFFFFF, 0x0, 0x0, 0x0};
          vec.values.val[0] = vreinterpretq_f32_u32(mask_low);
          vec.values.val[1] = a.values.val[1];
          vec.values.val[0] = vbslq_f32(
              vreinterpretq_u32_f32(vec.values.val[0]),
              b.values.val[0],
              a.values.val[0]);
          return vec;
        }
      case 2:
        {
          Vectorized<float> vec;
          static uint32x4_t mask_low = {0xFFFFFFFF, 0xFFFFFFFF, 0x0, 0x0};
          vec.values.val[0] = vreinterpretq_f32_u32(mask_low);
          vec.values.val[1] = a.values.val[1];
          vec.values.val[0] = vbslq_f32(
              vreinterpretq_u32_f32(vec.values.val[0]),
              b.values.val[0],
              a.values.val[0]);
          return vec;
        }
      case 3:
        {
          Vectorized<float> vec;
          static uint32x4_t mask_low = {0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x0};
          vec.values.val[0] = vreinterpretq_f32_u32(mask_low);
          vec.values.val[1] = a.values.val[1];
          vec.values.val[0] = vbslq_f32(
              vreinterpretq_u32_f32(vec.values.val[0]),
              b.values.val[0],
              a.values.val[0]);
          return vec;
        }
      case 4:
        return Vectorized<float>(b.values.val[0], a.values.val[1]);
      case 5:
        {
          Vectorized<float> vec;
          static uint32x4_t mask_high = {0xFFFFFFFF, 0x0, 0x0, 0x0};
          vec.values.val[0] = b.values.val[0];
          vec.values.val[1] = vreinterpretq_f32_u32(mask_high);
          vec.values.val[1] = vbslq_f32(
              vreinterpretq_u32_f32(vec.values.val[1]),
              b.values.val[1],
              a.values.val[1]);
          return vec;
        }
      case 6:
        {
          Vectorized<float> vec;
          static uint32x4_t mask_high = {0xFFFFFFFF, 0xFFFFFFFF, 0x0, 0x0};
          vec.values.val[0] = b.values.val[0];
          vec.values.val[1] = vreinterpretq_f32_u32(mask_high);
          vec.values.val[1] = vbslq_f32(
              vreinterpretq_u32_f32(vec.values.val[1]),
              b.values.val[1],
              a.values.val[1]);
          return vec;
        }
      case 7:
        {
          Vectorized<float> vec;
          static uint32x4_t mask_high = {0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x0};
          vec.values.val[0] = b.values.val[0];
          vec.values.val[1] = vreinterpretq_f32_u32(mask_high);
          vec.values.val[1] = vbslq_f32(
              vreinterpretq_u32_f32(vec.values.val[1]),
              b.values.val[1],
              a.values.val[1]);
          return vec;
        }
    }
    return b;
  }
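  // set(a, b, count) keeps b in the first `count` lanes and a in the rest;
  // count == 0 returns a and the default count == size() returns b.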
  static Vectorized<float> loadu(const void* ptr, int64_t count = size()) {
    if (count == size()) {
      return vld1q_f32_x2(reinterpret_cast<const float*>(ptr));
    }
    else if (count == (size() >> 1)) {
      Vectorized<float> res;
      res.values.val[0] = vld1q_f32(reinterpret_cast<const float*>(ptr));
      res.values.val[1] = vdupq_n_f32(0.f);
      return res;
    }
    else {
      __at_align__ float tmp_values[size()];
      for (const auto i : c10::irange(size())) {
        tmp_values[i] = 0.0;
      }
      std::memcpy(
          tmp_values,
          reinterpret_cast<const float*>(ptr),
          count * sizeof(float));
      return vld1q_f32_x2(reinterpret_cast<const float*>(tmp_values));
    }
  }
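  // Note: partial loads (count < size()) zero-fill the remaining lanes.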
  void store(void* ptr, int64_t count = size()) const {
    if (count == size()) {
      vst1q_f32_x2(reinterpret_cast<float*>(ptr), values);
    }
    else if (count == (size() >> 1)) {
      vst1q_f32(reinterpret_cast<float*>(ptr), values.val[0]);
    }
    else {
      float tmp_values[size()];
      vst1q_f32_x2(reinterpret_cast<float*>(tmp_values), values);
      std::memcpy(ptr, tmp_values, count * sizeof(float));
    }
  }
  inline const float32x4_t& get_low() const {
    return values.val[0];
  }
  inline float32x4_t& get_low() {
    return values.val[0];
  }
  inline const float32x4_t& get_high() const {
    return values.val[1];
  }
  inline float32x4_t& get_high() {
    return values.val[1];
  }
  // Very slow implementation of indexing.
  // Only required because vec256_qint refers to this.
  // Once we specialize that implementation for ARM
  // this should be removed. TODO (kimishpatel)
  float operator[](int idx) const {
    __at_align__ float tmp[size()];
    store(tmp);
    return tmp[idx];
  }
  float operator[](int idx) {
    __at_align__ float tmp[size()];
    store(tmp);
    return tmp[idx];
  }
  // For a boolean version, checks such as "any ones" / "all zeros"
  // could be done faster in a different way.
  int zero_mask() const {
    __at_align__ float tmp[size()];
    store(tmp);
    int mask = 0;
    for (int i = 0; i < size(); ++i) {
      if (tmp[i] == 0.f) {
        mask |= (1 << i);
      }
    }
    return mask;
  }
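  // e.g. for lanes {0, 1, 0, 2, 0, 0, 3, 4}, zero_mask() returns 0b00110101.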
  Vectorized<float> isnan() const {
    __at_align__ float tmp[size()];
    __at_align__ float res[size()];
    store(tmp);
    for (const auto i : c10::irange(size())) {
      if (_isnan(tmp[i])) {
        std::memset(static_cast<void*>(&res[i]), 0xFF, sizeof(float));
      } else {
        std::memset(static_cast<void*>(&res[i]), 0, sizeof(float));
      }
    }
    return loadu(res);
  }
  bool has_inf_nan() const {
    __at_align__ float tmp[size()];
    store(tmp);
    for (const auto i : c10::irange(size())) {
      if (_isnan(tmp[i]) || _isinf(tmp[i])) {
        return true;
      }
    }
    return false;
  }
  Vectorized<float> map(float (*const f)(float)) const {
    __at_align__ float tmp[size()];
    store(tmp);
    for (const auto i : c10::irange(size())) {
      tmp[i] = f(tmp[i]);
    }
    return loadu(tmp);
  }
  Vectorized<float> abs() const {
    return Vectorized<float>(vabsq_f32(values.val[0]), vabsq_f32(values.val[1]));
  }
  Vectorized<float> angle() const {
    auto zero = Vectorized<float>(0);
    auto pi = Vectorized<float>(c10::pi<float>);
    auto tmp = blendv(zero, pi, *this < zero);
    return blendv(tmp, *this, isnan());
  }
  Vectorized<float> real() const {
    return *this;
  }
  Vectorized<float> imag() const {
    return Vectorized<float>(0.f);
  }
  Vectorized<float> conj() const {
    return *this;
  }
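  // In the Sleef calls below, the _uNN suffix encodes the accuracy bound in
  // ULPs: _u10 = 1.0 ULP, _u05 = 0.5 ULP, _u15 = 1.5 ULP.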
  Vectorized<float> acos() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_acosf4_u10(values.val[0]), Sleef_acosf4_u10(values.val[1])),
      map(std::acos)
    );
  }
  Vectorized<float> acosh() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_acoshf4_u10(values.val[0]), Sleef_acoshf4_u10(values.val[1])),
      map(std::acosh)
    );
  }
  Vectorized<float> asin() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_asinf4_u10(values.val[0]), Sleef_asinf4_u10(values.val[1])),
      map(std::asin)
    );
  }
  Vectorized<float> atan() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_atanf4_u10(values.val[0]), Sleef_atanf4_u10(values.val[1])),
      map(std::atan)
    );
  }
  Vectorized<float> atanh() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_atanhf4_u10(values.val[0]), Sleef_atanhf4_u10(values.val[1])),
      map(std::atanh)
    );
  }
  Vectorized<float> atan2(const Vectorized<float> &exp) const {
    USE_SLEEF(
      {
        return Vectorized<float>(Sleef_atan2f4_u10(values.val[0], exp.values.val[0]),
                                 Sleef_atan2f4_u10(values.val[1], exp.values.val[1]));
      },
      {
        __at_align__ float tmp[size()];
        __at_align__ float tmp_exp[size()];
        store(tmp);
        exp.store(tmp_exp);
        for (const auto i : c10::irange(size())) {
          tmp[i] = std::atan2(tmp[i], tmp_exp[i]);
        }
        return loadu(tmp);
      }
    )
  }
  Vectorized<float> copysign(const Vectorized<float> &sign) const {
    USE_SLEEF(
      {
        return Vectorized<float>(Sleef_copysignf4(values.val[0], sign.values.val[0]),
                                 Sleef_copysignf4(values.val[1], sign.values.val[1]));
      },
      {
        __at_align__ float tmp[size()];
        __at_align__ float tmp_sign[size()];
        store(tmp);
        sign.store(tmp_sign);
        for (size_type i = 0; i < size(); i++) {
          tmp[i] = std::copysign(tmp[i], tmp_sign[i]);
        }
        return loadu(tmp);
      }
    )
  }
  Vectorized<float> erf() const;
  Vectorized<float> erfc() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_erfcf4_u15(values.val[0]), Sleef_erfcf4_u15(values.val[1])),
      map(std::erfc)
    );
  }
  Vectorized<float> erfinv() const {
    return map(calc_erfinv);
  }
  Vectorized<float> exp() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_expf4_u10(values.val[0]), Sleef_expf4_u10(values.val[1])),
      map(std::exp)
    );
  }
  Vectorized<float> exp2() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_exp2f4_u10(values.val[0]), Sleef_exp2f4_u10(values.val[1])),
      map(std::exp2)
    );
  }
  Vectorized<float> expm1() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_expm1f4_u10(values.val[0]), Sleef_expm1f4_u10(values.val[1])),
      map(std::expm1)
    );
  }
  Vectorized<float> exp_u20() const {
    return exp();
  }
  Vectorized<float> fmod(const Vectorized<float>& q) const {
    USE_SLEEF(
      {
        return Vectorized<float>(Sleef_fmodf4(values.val[0], q.values.val[0]),
                                 Sleef_fmodf4(values.val[1], q.values.val[1]));
      },
      {
        __at_align__ float tmp[size()];
        __at_align__ float tmp_q[size()];
        store(tmp);
        q.store(tmp_q);
        for (const auto i : c10::irange(size())) {
          tmp[i] = std::fmod(tmp[i], tmp_q[i]);
        }
        return loadu(tmp);
      }
    )
  }
  Vectorized<float> hypot(const Vectorized<float> &b) const {
    USE_SLEEF(
      {
        return Vectorized<float>(Sleef_hypotf4_u05(values.val[0], b.values.val[0]),
                                 Sleef_hypotf4_u05(values.val[1], b.values.val[1]));
      },
      {
        __at_align__ float tmp[size()];
        __at_align__ float tmp_b[size()];
        store(tmp);
        b.store(tmp_b);
        for (const auto i : c10::irange(size())) {
          tmp[i] = std::hypot(tmp[i], tmp_b[i]);
        }
        return loadu(tmp);
      }
    )
  }
  Vectorized<float> i0() const {
    return map(calc_i0);
  }
  Vectorized<float> i0e() const {
    return map(calc_i0e);
  }
  Vectorized<float> digamma() const {
    return map(calc_digamma);
  }
  Vectorized<float> igamma(const Vectorized<float> &x) const {
    __at_align__ float tmp[size()];
    __at_align__ float tmp_x[size()];
    store(tmp);
    x.store(tmp_x);
    for (const auto i : c10::irange(size())) {
      tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
    }
    return loadu(tmp);
  }
  Vectorized<float> igammac(const Vectorized<float> &x) const {
    __at_align__ float tmp[size()];
    __at_align__ float tmp_x[size()];
    store(tmp);
    x.store(tmp_x);
    for (const auto i : c10::irange(size())) {
      tmp[i] = calc_igammac(tmp[i], tmp_x[i]);
    }
    return loadu(tmp);
  }
  Vectorized<float> log() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_logf4_u10(values.val[0]), Sleef_logf4_u10(values.val[1])),
      map(std::log)
    );
  }
  Vectorized<float> log10() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_log10f4_u10(values.val[0]), Sleef_log10f4_u10(values.val[1])),
      map(std::log10)
    );
  }
  Vectorized<float> log1p() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_log1pf4_u10(values.val[0]), Sleef_log1pf4_u10(values.val[1])),
      map(std::log1p)
    );
  }
  Vectorized<float> log2() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_log2f4_u10(values.val[0]), Sleef_log2f4_u10(values.val[1])),
      map(std::log2)
    );
  }
  Vectorized<float> nextafter(const Vectorized<float> &b) const {
    USE_SLEEF(
      {
        return Vectorized<float>(Sleef_nextafterf4(values.val[0], b.values.val[0]),
                                 Sleef_nextafterf4(values.val[1], b.values.val[1]));
      },
      {
        __at_align__ float tmp[size()];
        __at_align__ float tmp_b[size()];
        store(tmp);
        b.store(tmp_b);
        for (const auto i : c10::irange(size())) {
          tmp[i] = std::nextafter(tmp[i], tmp_b[i]);
        }
        return loadu(tmp);
      }
    )
  }
  Vectorized<float> frac() const;
  Vectorized<float> sin() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_sinf4_u10(values.val[0]), Sleef_sinf4_u10(values.val[1])),
      map(std::sin)
    );
  }
  Vectorized<float> sinh() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_sinhf4_u10(values.val[0]), Sleef_sinhf4_u10(values.val[1])),
      map(std::sinh)
    );
  }
  Vectorized<float> cos() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_cosf4_u10(values.val[0]), Sleef_cosf4_u10(values.val[1])),
      map(std::cos)
    );
  }
  Vectorized<float> cosh() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_coshf4_u10(values.val[0]), Sleef_coshf4_u10(values.val[1])),
      map(std::cosh)
    );
  }
  Vectorized<float> ceil() const {
    return map(at::native::ceil_impl);
  }
  Vectorized<float> floor() const {
    return map(at::native::floor_impl);
  }
  Vectorized<float> neg() const {
    return Vectorized<float>(
        vnegq_f32(values.val[0]),
        vnegq_f32(values.val[1]));
  }
  Vectorized<float> round() const {
    // We do not use std::round because we would like to round midway numbers to the nearest even integer.
    return map(at::native::round_impl);
  }
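  // e.g. round() maps 0.5 -> 0 and 1.5 -> 2 (ties go to the nearest even).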
  Vectorized<float> tan() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_tanf4_u10(values.val[0]), Sleef_tanf4_u10(values.val[1])),
      map(std::tan)
    );
  }
  Vectorized<float> tanh() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_tanhf4_u10(values.val[0]), Sleef_tanhf4_u10(values.val[1])),
      map(std::tanh)
    );
  }
  Vectorized<float> trunc() const {
    float32x4_t r0 = vrndq_f32(values.val[0]);
    float32x4_t r1 = vrndq_f32(values.val[1]);
    return Vectorized<float>(r0, r1);
  }
  Vectorized<float> lgamma() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_lgammaf4_u10(values.val[0]), Sleef_lgammaf4_u10(values.val[1])),
      map(std::lgamma)
    );
  }
  Vectorized<float> sqrt() const {
    return Vectorized<float>(
        vsqrtq_f32(values.val[0]),
        vsqrtq_f32(values.val[1]));
  }
  Vectorized<float> reciprocal() const {
    auto r0 = vdivq_f32(vdupq_n_f32(1.0f), values.val[0]);
    auto r1 = vdivq_f32(vdupq_n_f32(1.0f), values.val[1]);
    return Vectorized<float>(r0, r1);
  }
  Vectorized<float> rsqrt() const {
    return this->sqrt().reciprocal();
  }
  Vectorized<float> pow(const Vectorized<float> &exp) const {
    USE_SLEEF(
      {
        return Vectorized<float>(Sleef_powf4_u10(values.val[0], exp.values.val[0]),
                                 Sleef_powf4_u10(values.val[1], exp.values.val[1]));
      },
      {
        __at_align__ float tmp[size()];
        __at_align__ float tmp_exp[size()];
        store(tmp);
        exp.store(tmp_exp);
        for (const auto i : c10::irange(size())) {
          tmp[i] = std::pow(tmp[i], tmp_exp[i]);
        }
        return loadu(tmp);
      }
    )
  }
  Vectorized<float> operator==(const Vectorized<float>& other) const {
    float32x4_t r0 =
      vreinterpretq_f32_u32(vceqq_f32(values.val[0], other.values.val[0]));
    float32x4_t r1 =
      vreinterpretq_f32_u32(vceqq_f32(values.val[1], other.values.val[1]));
    return Vectorized<float>(r0, r1);
  }

  Vectorized<float> operator!=(const Vectorized<float>& other) const {
    float32x4_t r0 = vreinterpretq_f32_u32(
        vmvnq_u32(vceqq_f32(values.val[0], other.values.val[0])));
    float32x4_t r1 = vreinterpretq_f32_u32(
        vmvnq_u32(vceqq_f32(values.val[1], other.values.val[1])));
    return Vectorized<float>(r0, r1);
  }

  Vectorized<float> operator<(const Vectorized<float>& other) const {
    float32x4_t r0 =
      vreinterpretq_f32_u32(vcltq_f32(values.val[0], other.values.val[0]));
    float32x4_t r1 =
      vreinterpretq_f32_u32(vcltq_f32(values.val[1], other.values.val[1]));
    return Vectorized<float>(r0, r1);
  }

  Vectorized<float> operator<=(const Vectorized<float>& other) const {
    float32x4_t r0 =
      vreinterpretq_f32_u32(vcleq_f32(values.val[0], other.values.val[0]));
    float32x4_t r1 =
      vreinterpretq_f32_u32(vcleq_f32(values.val[1], other.values.val[1]));
    return Vectorized<float>(r0, r1);
  }

  Vectorized<float> operator>(const Vectorized<float>& other) const {
    float32x4_t r0 =
      vreinterpretq_f32_u32(vcgtq_f32(values.val[0], other.values.val[0]));
    float32x4_t r1 =
      vreinterpretq_f32_u32(vcgtq_f32(values.val[1], other.values.val[1]));
    return Vectorized<float>(r0, r1);
  }

  Vectorized<float> operator>=(const Vectorized<float>& other) const {
    float32x4_t r0 =
      vreinterpretq_f32_u32(vcgeq_f32(values.val[0], other.values.val[0]));
    float32x4_t r1 =
      vreinterpretq_f32_u32(vcgeq_f32(values.val[1], other.values.val[1]));
    return Vectorized<float>(r0, r1);
  }

  Vectorized<float> eq(const Vectorized<float>& other) const;
  Vectorized<float> ne(const Vectorized<float>& other) const;
  Vectorized<float> gt(const Vectorized<float>& other) const;
  Vectorized<float> ge(const Vectorized<float>& other) const;
  Vectorized<float> lt(const Vectorized<float>& other) const;
  Vectorized<float> le(const Vectorized<float>& other) const;
};

template <>
Vectorized<float> inline operator+(const Vectorized<float>& a, const Vectorized<float>& b) {
  float32x4_t r0 = vaddq_f32(a.get_low(), b.get_low());
  float32x4_t r1 = vaddq_f32(a.get_high(), b.get_high());
  return Vectorized<float>(r0, r1);
}

template <>
Vectorized<float> inline operator-(const Vectorized<float>& a, const Vectorized<float>& b) {
  float32x4_t r0 = vsubq_f32(a.get_low(), b.get_low());
  float32x4_t r1 = vsubq_f32(a.get_high(), b.get_high());
  return Vectorized<float>(r0, r1);
}

template <>
Vectorized<float> inline operator*(const Vectorized<float>& a, const Vectorized<float>& b) {
  float32x4_t r0 = vmulq_f32(a.get_low(), b.get_low());
  float32x4_t r1 = vmulq_f32(a.get_high(), b.get_high());
  return Vectorized<float>(r0, r1);
}

template <>
Vectorized<float> inline operator/(const Vectorized<float>& a, const Vectorized<float>& b) {
  float32x4_t r0 = vdivq_f32(a.get_low(), b.get_low());
  float32x4_t r1 = vdivq_f32(a.get_high(), b.get_high());
  return Vectorized<float>(r0, r1);
}

// frac. Implement this here so we can use subtraction
inline Vectorized<float> Vectorized<float>::frac() const {
  return *this - this->trunc();
}
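// e.g. frac() maps 1.75 -> 0.75 and -2.5 -> -0.5 (truncation keeps the sign).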

// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
// either input is a NaN. Uses the Sleef implementation of fmax when the
// inputs are known to be free of NaN/Inf.
template <>
Vectorized<float> inline maximum(const Vectorized<float>& a, const Vectorized<float>& b) {
  if (!a.has_inf_nan() && !b.has_inf_nan()) {
    return USE_SLEEF(
        Vectorized<float>(Sleef_fmaxf4(a.get_low(), b.get_low()), Sleef_fmaxf4(a.get_high(), b.get_high())),
        Vectorized<float>(vmaxq_f32(a.get_low(), b.get_low()), vmaxq_f32(a.get_high(), b.get_high())));
  } else {
    return Vectorized<float>(vmaxq_f32(a.get_low(), b.get_low()), vmaxq_f32(a.get_high(), b.get_high()));
  }
}

// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<float> inline minimum(const Vectorized<float>& a, const Vectorized<float>& b) {
  float32x4_t r0 = vminq_f32(a.get_low(), b.get_low());
  float32x4_t r1 = vminq_f32(a.get_high(), b.get_high());
  return Vectorized<float>(r0, r1);
}

template <>
Vectorized<float> inline clamp(const Vectorized<float>& a, const Vectorized<float>& min, const Vectorized<float>& max) {
  return minimum(max, maximum(min, a));
}

template <>
Vectorized<float> inline clamp_max(const Vectorized<float>& a, const Vectorized<float>& max) {
  return minimum(max, a);
}

template <>
Vectorized<float> inline clamp_min(const Vectorized<float>& a, const Vectorized<float>& min) {
  return maximum(min, a);
}

template <>
Vectorized<float> inline operator&(const Vectorized<float>& a, const Vectorized<float>& b) {
  float32x4_t r0 = vreinterpretq_f32_u32(vandq_u32(
      vreinterpretq_u32_f32(a.get_low()),
      vreinterpretq_u32_f32(b.get_low())));
  float32x4_t r1 = vreinterpretq_f32_u32(vandq_u32(
      vreinterpretq_u32_f32(a.get_high()),
      vreinterpretq_u32_f32(b.get_high())));
  return Vectorized<float>(r0, r1);
}

template <>
Vectorized<float> inline operator|(const Vectorized<float>& a, const Vectorized<float>& b) {
  float32x4_t r0 = vreinterpretq_f32_u32(vorrq_u32(
      vreinterpretq_u32_f32(a.get_low()),
      vreinterpretq_u32_f32(b.get_low())));
  float32x4_t r1 = vreinterpretq_f32_u32(vorrq_u32(
      vreinterpretq_u32_f32(a.get_high()),
      vreinterpretq_u32_f32(b.get_high())));
  return Vectorized<float>(r0, r1);
}

template <>
Vectorized<float> inline operator^(const Vectorized<float>& a, const Vectorized<float>& b) {
  float32x4_t r0 = vreinterpretq_f32_u32(veorq_u32(
      vreinterpretq_u32_f32(a.get_low()),
      vreinterpretq_u32_f32(b.get_low())));
  float32x4_t r1 = vreinterpretq_f32_u32(veorq_u32(
      vreinterpretq_u32_f32(a.get_high()),
      vreinterpretq_u32_f32(b.get_high())));
  return Vectorized<float>(r0, r1);
}

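// eq/ne/gt/ge/lt/le return numeric 1.0f where the predicate holds and 0.0f
// elsewhere: ANDing the all-ones comparison mask with the bit pattern of
// 1.0f keeps 1.0f in matching lanes and zeroes the rest.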
inline Vectorized<float> Vectorized<float>::eq(const Vectorized<float>& other) const {
  return (*this == other) & Vectorized<float>(1.0f);
}

inline Vectorized<float> Vectorized<float>::ne(const Vectorized<float>& other) const {
  return (*this != other) & Vectorized<float>(1.0f);
}

inline Vectorized<float> Vectorized<float>::gt(const Vectorized<float>& other) const {
  return (*this > other) & Vectorized<float>(1.0f);
}

inline Vectorized<float> Vectorized<float>::ge(const Vectorized<float>& other) const {
  return (*this >= other) & Vectorized<float>(1.0f);
}

inline Vectorized<float> Vectorized<float>::lt(const Vectorized<float>& other) const {
  return (*this < other) & Vectorized<float>(1.0f);
}

inline Vectorized<float> Vectorized<float>::le(const Vectorized<float>& other) const {
  return (*this <= other) & Vectorized<float>(1.0f);
}

template <>
inline void convert(const float* src, int32_t* dst, int64_t n) {
  int64_t i;
#ifndef __msvc_cl__
#pragma unroll
#endif
  for (i = 0; i <= (n - Vectorized<float>::size()); i += Vectorized<float>::size()) {
    vst1q_s32(dst + i, vcvtq_s32_f32(vld1q_f32(src + i)));
    vst1q_s32(dst + i + 4, vcvtq_s32_f32(vld1q_f32(src + i + 4)));
  }
#ifndef __msvc_cl__
#pragma unroll
#endif
  for (; i < n; i++) {
    dst[i] = static_cast<int32_t>(src[i]);
  }
}
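
// The vectorized loops above and below process Vectorized<float>::size() == 8
// elements per iteration as two 4-lane NEON operations; the scalar loop
// handles the remaining tail elements.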

template <>
inline void convert(const int32_t* src, float* dst, int64_t n) {
  int64_t i;
#ifndef __msvc_cl__
#pragma unroll
#endif
  for (i = 0; i <= (n - Vectorized<float>::size()); i += Vectorized<float>::size()) {
    vst1q_f32(dst + i, vcvtq_f32_s32(vld1q_s32(src + i)));
    vst1q_f32(dst + i + 4, vcvtq_f32_s32(vld1q_s32(src + i + 4)));
  }
#ifndef __msvc_cl__
#pragma unroll
#endif
  for (; i < n; i++) {
    dst[i] = static_cast<float>(src[i]);
  }
}

template <>
Vectorized<float> inline fmadd(const Vectorized<float>& a, const Vectorized<float>& b, const Vectorized<float>& c) {
  float32x4_t r0 = vfmaq_f32(c.get_low(), a.get_low(), b.get_low());
  float32x4_t r1 = vfmaq_f32(c.get_high(), a.get_high(), b.get_high());
  return Vectorized<float>(r0, r1);
}

template <>
Vectorized<float> inline fmsub(const Vectorized<float>& a, const Vectorized<float>& b, const Vectorized<float>& c) {
  // fmsub must compute a * b - c (see vec_base.h), but vfmsq_f32(c, a, b)
  // computes c - a * b. Negate c and use a fused multiply-add instead.
  float32x4_t r0 = vfmaq_f32(vnegq_f32(c.get_low()), a.get_low(), b.get_low());
  float32x4_t r1 = vfmaq_f32(vnegq_f32(c.get_high()), a.get_high(), b.get_high());
  return Vectorized<float>(r0, r1);
}

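// Polynomial approximation of erf following Abramowitz & Stegun formula
// 7.1.26 (maximum absolute error about 1.5e-7).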
inline Vectorized<float> Vectorized<float>::erf() const {
  // constants
  const Vectorized<float> neg_zero_vec(-0.f);
  const Vectorized<float> one_vec(1.0f);
  const Vectorized<float> p(0.3275911f);
  const Vectorized<float> p1(0.254829592f);
  const Vectorized<float> p2(-0.284496736f);
  const Vectorized<float> p3(1.421413741f);
  const Vectorized<float> p4(-1.453152027f);
  const Vectorized<float> p5(1.061405429f);
  // sign(x)
  auto sign_mask = neg_zero_vec & *this;
  auto abs_vec = this->abs();
  // t = 1 / (p * abs(x) + 1)
  auto tmp0 = fmadd(p, abs_vec, one_vec);
  auto t = one_vec / tmp0;
  // r = p5 * t ^ 4 + p4 * t ^ 3 + p3 * t ^ 2 + p2 * t + p1
  auto tmp1 = fmadd(p5, t, p4);
  auto tmp2 = fmadd(tmp1, t, p3);
  auto tmp3 = fmadd(tmp2, t, p2);
  auto r = fmadd(tmp3, t, p1);
  // - exp(- x * x)
  auto pow_2 = (*this) * (*this);
  auto neg_pow_2 = pow_2 ^ neg_zero_vec;
  auto tmp4 = neg_pow_2.map(std::exp); // This can be swapped for a faster implementation of exp.
  auto tmp5 = tmp4 ^ neg_zero_vec;
  // erf(x) = sign(x) * (1 - r * t * exp(- x * x))
  auto tmp6 = t * tmp5;
  auto tmp7 = fmadd(tmp6, r, one_vec);
  return tmp7 ^ sign_mask;
}
#endif /* defined(__aarch64__) */

}} // namespace at::vec::CPU_CAPABILITY