/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]

#include <executorch/kernels/optimized/vec/intrinsics.h>
#include <executorch/kernels/optimized/vec/vec_base.h>

#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#include <sleef.h>
#endif

namespace executorch {
namespace vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {


#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)

template <> class Vectorized<double> {
private:
  __m256d values;
public:
  using value_type = double;
  using size_type = int;
  static constexpr size_type size() {
    return 4;
  }
  Vectorized() {}
  Vectorized(__m256d v) : values(v) {}
  Vectorized(double val) {
    values = _mm256_set1_pd(val);
  }
  Vectorized(double val1, double val2, double val3, double val4) {
    values = _mm256_setr_pd(val1, val2, val3, val4);
  }
  operator __m256d() const {
    return values;
  }
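  // Compile-time blend: bit i of `mask` selects lane i from `b` when set and
  // from `a` when clear, e.g. blend<0b0011>(a, b) yields {b0, b1, a2, a3}.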
  template <int64_t mask>
  static Vectorized<double> blend(const Vectorized<double>& a, const Vectorized<double>& b) {
    return _mm256_blend_pd(a.values, b.values, mask);
  }
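  // Run-time blend: lane i is taken from `b` when the sign (most significant)
  // bit of lane i of `mask` is set, otherwise from `a`.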
  static Vectorized<double> blendv(const Vectorized<double>& a, const Vectorized<double>& b,
                               const Vectorized<double>& mask) {
    return _mm256_blendv_pd(a.values, b.values, mask.values);
  }
  template<typename step_t>
  static Vectorized<double> arange(double base = 0., step_t step = static_cast<step_t>(1)) {
    return Vectorized<double>(base, base + step, base + 2 * step, base + 3 * step);
  }
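  // Returns a vector whose first `count` lanes come from `b` and whose
  // remaining lanes come from `a`.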
  static Vectorized<double> set(const Vectorized<double>& a, const Vectorized<double>& b,
                            int64_t count = size()) {
    switch (count) {
      case 0:
        return a;
      case 1:
        return blend<1>(a, b);
      case 2:
        return blend<3>(a, b);
      case 3:
        return blend<7>(a, b);
    }
    return b;
  }
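  // Unaligned load of `count` doubles; when count < size(), the remaining
  // lanes are zero-filled via an aligned stack buffer.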
  static Vectorized<double> loadu(const void* ptr, int64_t count = size()) {
    if (count == size())
      return _mm256_loadu_pd(reinterpret_cast<const double*>(ptr));


    __at_align__ double tmp_values[size()];
    // Ensure uninitialized memory does not change the output value. See https://github.com/pytorch/pytorch/issues/32502
    // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
    // instructions while a loop would be compiled to one instruction.
    for (size_t i = 0; i < size(); ++i) {
      tmp_values[i] = 0.0;
    }
    std::memcpy(
        tmp_values,
        reinterpret_cast<const double*>(ptr),
        count * sizeof(double));
    return _mm256_load_pd(tmp_values);
  }
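  // Unaligned store of the first `count` lanes; a partial store spills all
  // lanes to a stack buffer and memcpys only `count` of them.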
  void store(void* ptr, int count = size()) const {
    if (count == size()) {
      _mm256_storeu_pd(reinterpret_cast<double*>(ptr), values);
    } else if (count > 0) {
      double tmp_values[size()];
      _mm256_storeu_pd(reinterpret_cast<double*>(tmp_values), values);
      std::memcpy(ptr, tmp_values, count * sizeof(double));
    }
  }
  const double& operator[](int idx) const = delete;
  double& operator[](int idx) = delete;
  int zero_mask() const {
    // returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit
    __m256d cmp = _mm256_cmp_pd(values, _mm256_set1_pd(0.0), _CMP_EQ_OQ);
    return _mm256_movemask_pd(cmp);
  }
  Vectorized<double> isnan() const {
    return _mm256_cmp_pd(values, _mm256_set1_pd(0.0), _CMP_UNORD_Q);
  }
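  // Applies a scalar double -> double function to each lane through a
  // store/reload round-trip on the stack.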
  Vectorized<double> map(double (*const f)(double)) const {
    __at_align__ double tmp[size()];
    store(tmp);
    for (size_t i = 0; i < size(); ++i) {
      tmp[i] = f(tmp[i]);
    }
    return loadu(tmp);
  }
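  // Clears the sign bit of every lane: and-not with a mask whose only set
  // bit is the sign bit (-0.0).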
  Vectorized<double> abs() const {
    auto mask = _mm256_set1_pd(-0.f);
    return _mm256_andnot_pd(mask, values);
  }
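  // The transcendental functions below dispatch to SLEEF; the numeric suffix
  // encodes the maximum error bound in units of 0.1 ULP (e.g. _u10 is 1.0 ULP,
  // _u05 is 0.5 ULP).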
  Vectorized<double> acos() const {
    return Vectorized<double>(Sleef_acosd4_u10(values));
  }
  Vectorized<double> asin() const {
    return Vectorized<double>(Sleef_asind4_u10(values));
  }
  Vectorized<double> atan() const {
    return Vectorized<double>(Sleef_atand4_u10(values));
  }
  Vectorized<double> atan2(const Vectorized<double> &b) const {
    return Vectorized<double>(Sleef_atan2d4_u10(values, b));
  }
  Vectorized<double> copysign(const Vectorized<double> &sign) const {
    return Vectorized<double>(Sleef_copysignd4(values, sign));
  }
  Vectorized<double> erf() const {
    return Vectorized<double>(Sleef_erfd4_u10(values));
  }
  Vectorized<double> erfc() const {
    return Vectorized<double>(Sleef_erfcd4_u15(values));
  }
  Vectorized<double> exp() const {
    return Vectorized<double>(Sleef_expd4_u10(values));
  }
  Vectorized<double> exp2() const {
    return Vectorized<double>(Sleef_exp2d4_u10(values));
  }
  Vectorized<double> expm1() const {
    return Vectorized<double>(Sleef_expm1d4_u10(values));
  }
  Vectorized<double> fmod(const Vectorized<double>& q) const {
    return Vectorized<double>(Sleef_fmodd4(values, q));
  }
  Vectorized<double> hypot(const Vectorized<double> &b) const {
    return Vectorized<double>(Sleef_hypotd4_u05(values, b));
  }
  Vectorized<double> log() const {
    return Vectorized<double>(Sleef_logd4_u10(values));
  }
  Vectorized<double> log2() const {
    return Vectorized<double>(Sleef_log2d4_u10(values));
  }
  Vectorized<double> log10() const {
    return Vectorized<double>(Sleef_log10d4_u10(values));
  }
  Vectorized<double> log1p() const {
    return Vectorized<double>(Sleef_log1pd4_u10(values));
  }
  Vectorized<double> sin() const {
    return Vectorized<double>(Sleef_sind4_u10(values));
  }
  Vectorized<double> sinh() const {
    return Vectorized<double>(Sleef_sinhd4_u10(values));
  }
  Vectorized<double> cos() const {
    return Vectorized<double>(Sleef_cosd4_u10(values));
  }
  Vectorized<double> cosh() const {
    return Vectorized<double>(Sleef_coshd4_u10(values));
  }
  Vectorized<double> ceil() const {
    return _mm256_ceil_pd(values);
  }
  Vectorized<double> floor() const {
    return _mm256_floor_pd(values);
  }
  Vectorized<double> frac() const;
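  // Flips the sign bit of every lane by XORing with -0.0.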
  Vectorized<double> neg() const {
    return _mm256_xor_pd(_mm256_set1_pd(-0.), values);
  }
  Vectorized<double> nextafter(const Vectorized<double> &b) const {
    return Vectorized<double>(Sleef_nextafterd4(values, b));
  }
  Vectorized<double> round() const {
    return _mm256_round_pd(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
  }
  Vectorized<double> tan() const {
    return Vectorized<double>(Sleef_tand4_u10(values));
  }
  Vectorized<double> tanh() const {
    return Vectorized<double>(Sleef_tanhd4_u10(values));
  }
  Vectorized<double> trunc() const {
    return _mm256_round_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
  }
  Vectorized<double> lgamma() const {
    return Vectorized<double>(Sleef_lgammad4_u10(values));
  }
  Vectorized<double> sqrt() const {
    return _mm256_sqrt_pd(values);
  }
  Vectorized<double> reciprocal() const {
    return _mm256_div_pd(_mm256_set1_pd(1), values);
  }
  Vectorized<double> rsqrt() const {
    return _mm256_div_pd(_mm256_set1_pd(1), _mm256_sqrt_pd(values));
  }
  Vectorized<double> pow(const Vectorized<double> &b) const {
    return Vectorized<double>(Sleef_powd4_u10(values, b));
  }
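  // Each comparison returns a per-lane mask: all bits set where the predicate
  // holds (an all-ones pattern, which reads back as a NaN double) and all bits
  // clear otherwise.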
  // Comparison using the _CMP_**_OQ predicate.
  //   `O`: get false if an operand is NaN
  //   `Q`: do not raise if an operand is NaN
  Vectorized<double> operator==(const Vectorized<double>& other) const {
    return _mm256_cmp_pd(values, other.values, _CMP_EQ_OQ);
  }

  Vectorized<double> operator!=(const Vectorized<double>& other) const {
    return _mm256_cmp_pd(values, other.values, _CMP_NEQ_UQ);
  }

  Vectorized<double> operator<(const Vectorized<double>& other) const {
    return _mm256_cmp_pd(values, other.values, _CMP_LT_OQ);
  }

  Vectorized<double> operator<=(const Vectorized<double>& other) const {
    return _mm256_cmp_pd(values, other.values, _CMP_LE_OQ);
  }

  Vectorized<double> operator>(const Vectorized<double>& other) const {
    return _mm256_cmp_pd(values, other.values, _CMP_GT_OQ);
  }

  Vectorized<double> operator>=(const Vectorized<double>& other) const {
    return _mm256_cmp_pd(values, other.values, _CMP_GE_OQ);
  }

  Vectorized<double> eq(const Vectorized<double>& other) const;
  Vectorized<double> ne(const Vectorized<double>& other) const;
  Vectorized<double> lt(const Vectorized<double>& other) const;
  Vectorized<double> le(const Vectorized<double>& other) const;
  Vectorized<double> gt(const Vectorized<double>& other) const;
  Vectorized<double> ge(const Vectorized<double>& other) const;
};

template <>
Vectorized<double> inline operator+(const Vectorized<double>& a, const Vectorized<double>& b) {
  return _mm256_add_pd(a, b);
}

template <>
Vectorized<double> inline operator-(const Vectorized<double>& a, const Vectorized<double>& b) {
  return _mm256_sub_pd(a, b);
}

template <>
Vectorized<double> inline operator*(const Vectorized<double>& a, const Vectorized<double>& b) {
  return _mm256_mul_pd(a, b);
}

template <>
Vectorized<double> inline operator/(const Vectorized<double>& a, const Vectorized<double>& b) {
  return _mm256_div_pd(a, b);
}

// frac. Implement this here so we can use subtraction.
inline Vectorized<double> Vectorized<double>::frac() const {
  return *this - this->trunc();
}

// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<double> inline maximum(const Vectorized<double>& a, const Vectorized<double>& b) {
  Vectorized<double> max = _mm256_max_pd(a, b);
  Vectorized<double> isnan = _mm256_cmp_pd(a, b, _CMP_UNORD_Q);
  // Exploit the fact that all-ones is a NaN.
  return _mm256_or_pd(max, isnan);
}

// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<double> inline minimum(const Vectorized<double>& a, const Vectorized<double>& b) {
  Vectorized<double> min = _mm256_min_pd(a, b);
  Vectorized<double> isnan = _mm256_cmp_pd(a, b, _CMP_UNORD_Q);
  // Exploit the fact that all-ones is a NaN.
  return _mm256_or_pd(min, isnan);
}

template <>
Vectorized<double> inline clamp(const Vectorized<double>& a, const Vectorized<double>& min, const Vectorized<double>& max) {
  return _mm256_min_pd(max, _mm256_max_pd(min, a));
}

template <>
Vectorized<double> inline clamp_min(const Vectorized<double>& a, const Vectorized<double>& min) {
  return _mm256_max_pd(min, a);
}

template <>
Vectorized<double> inline clamp_max(const Vectorized<double>& a, const Vectorized<double>& max) {
  return _mm256_min_pd(max, a);
}

template <>
Vectorized<double> inline operator&(const Vectorized<double>& a, const Vectorized<double>& b) {
  return _mm256_and_pd(a, b);
}

template <>
Vectorized<double> inline operator|(const Vectorized<double>& a, const Vectorized<double>& b) {
  return _mm256_or_pd(a, b);
}

template <>
Vectorized<double> inline operator^(const Vectorized<double>& a, const Vectorized<double>& b) {
  return _mm256_xor_pd(a, b);
}

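// The named comparisons below return numeric 1.0/0.0 per lane rather than the
// raw all-ones/all-zeros masks, by ANDing each mask with 1.0.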
inline Vectorized<double> Vectorized<double>::eq(const Vectorized<double>& other) const {
  return (*this == other) & Vectorized<double>(1.0);
}

inline Vectorized<double> Vectorized<double>::ne(const Vectorized<double>& other) const {
  return (*this != other) & Vectorized<double>(1.0);
}

inline Vectorized<double> Vectorized<double>::gt(const Vectorized<double>& other) const {
  return (*this > other) & Vectorized<double>(1.0);
}

inline Vectorized<double> Vectorized<double>::ge(const Vectorized<double>& other) const {
  return (*this >= other) & Vectorized<double>(1.0);
}

inline Vectorized<double> Vectorized<double>::lt(const Vectorized<double>& other) const {
  return (*this < other) & Vectorized<double>(1.0);
}

inline Vectorized<double> Vectorized<double>::le(const Vectorized<double>& other) const {
  return (*this <= other) & Vectorized<double>(1.0);
}

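// Copies n doubles from src to dst: a vectorized main loop handles full
// 4-lane chunks and a scalar tail loop handles the remainder.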
template <>
inline void convert(const double* src, double* dst, int64_t n) {
  int64_t i;
#pragma unroll
  for (i = 0; i <= (n - Vectorized<double>::size()); i += Vectorized<double>::size()) {
    _mm256_storeu_pd(dst + i, _mm256_loadu_pd(src + i));
  }
#pragma unroll
  for (; i < n; i++) {
    dst[i] = src[i];
  }
}

#ifdef CPU_CAPABILITY_AVX2
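// Fused multiply-add / multiply-subtract: a * b + c and a * b - c computed
// with a single rounding step via FMA instructions.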
template <>
Vectorized<double> inline fmadd(const Vectorized<double>& a, const Vectorized<double>& b, const Vectorized<double>& c) {
  return _mm256_fmadd_pd(a, b, c);
}

template <>
Vectorized<double> inline fmsub(const Vectorized<double>& a, const Vectorized<double>& b, const Vectorized<double>& c) {
  return _mm256_fmsub_pd(a, b, c);
}
#endif

#endif

}}}