/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]
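//
// For example (illustrative, not from the original header), a namespace-scope
// definition such as
//   static const Vectorized<double> kOnes(1.0);
// would be initialized by AVX-compiled code and must not appear in this file.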

#include <cstring> // std::memcpy, used by loadu() and store() below

#include <executorch/kernels/optimized/vec/intrinsics.h>
#include <executorch/kernels/optimized/vec/vec_base.h>

#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#include <sleef.h>
#endif

namespace executorch {
namespace vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {

#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)

template <> class Vectorized<double> {
private:
  __m256d values;
public:
  using value_type = double;
  using size_type = int;
  static constexpr size_type size() {
    return 4;
  }
  Vectorized() {}
  Vectorized(__m256d v) : values(v) {}
  Vectorized(double val) {
    values = _mm256_set1_pd(val);
  }
  Vectorized(double val1, double val2, double val3, double val4) {
    values = _mm256_setr_pd(val1, val2, val3, val4);
  }
  operator __m256d() const {
    return values;
  }
  template <int64_t mask>
  static Vectorized<double> blend(const Vectorized<double>& a, const Vectorized<double>& b) {
    return _mm256_blend_pd(a.values, b.values, mask);
  }
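  // Illustrative example (not part of the original header): with
  //   Vectorized<double> a(0., 1., 2., 3.), b(10., 11., 12., 13.);
  //   blend<0b0101>(a, b) == {10., 1., 12., 3.}
  // Bit i of `mask` selects lane i from `b`; a clear bit keeps lane i of `a`.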
  static Vectorized<double> blendv(const Vectorized<double>& a, const Vectorized<double>& b,
                                   const Vectorized<double>& mask) {
    return _mm256_blendv_pd(a.values, b.values, mask.values);
  }
  template<typename step_t>
  static Vectorized<double> arange(double base = 0., step_t step = static_cast<step_t>(1)) {
    return Vectorized<double>(base, base + step, base + 2 * step, base + 3 * step);
  }
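  // E.g. (illustrative) arange(10., 2) yields {10., 12., 14., 16.}.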
  static Vectorized<double> set(const Vectorized<double>& a, const Vectorized<double>& b,
                                int64_t count = size()) {
    switch (count) {
      case 0:
        return a;
      case 1:
        return blend<1>(a, b);
      case 2:
        return blend<3>(a, b);
      case 3:
        return blend<7>(a, b);
    }
    return b;
  }
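  // E.g. (illustrative) set(a, b, 2) == blend<0b0011>(a, b): the two low
  // lanes come from b and the remaining lanes from a; count >= size()
  // returns b unchanged.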
  static Vectorized<double> loadu(const void* ptr, int64_t count = size()) {
    if (count == size())
      return _mm256_loadu_pd(reinterpret_cast<const double*>(ptr));

    __at_align__ double tmp_values[size()];
    // Ensure uninitialized memory does not change the output value. See
    // https://github.com/pytorch/pytorch/issues/32502 for more details. We do
    // not initialize arrays to zero using "={0}" because gcc would compile it
    // to two instructions while a loop would be compiled to one instruction.
    for (size_type i = 0; i < size(); ++i) {
      tmp_values[i] = 0.0;
    }
    std::memcpy(
        tmp_values,
        reinterpret_cast<const double*>(ptr),
        count * sizeof(double));
    return _mm256_load_pd(tmp_values);
  }
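  // Illustrative tail-load pattern (assumed caller code, not from this header):
  //   double src[3] = {1., 2., 3.};
  //   auto v = Vectorized<double>::loadu(src, 3);  // 4th lane is zero-filled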
  void store(void* ptr, int count = size()) const {
    if (count == size()) {
      _mm256_storeu_pd(reinterpret_cast<double*>(ptr), values);
    } else if (count > 0) {
      double tmp_values[size()];
      _mm256_storeu_pd(reinterpret_cast<double*>(tmp_values), values);
      std::memcpy(ptr, tmp_values, count * sizeof(double));
    }
  }
  const double& operator[](int idx) const = delete;
  double& operator[](int idx) = delete;
  int zero_mask() const {
    // Returns a bitmask in which bit i is set iff lane i equals 0.0.
    __m256d cmp = _mm256_cmp_pd(values, _mm256_set1_pd(0.0), _CMP_EQ_OQ);
    return _mm256_movemask_pd(cmp);
  }
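  // E.g. (illustrative) Vectorized<double>(0., 1., 0., 5.).zero_mask() == 0b0101.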
  Vectorized<double> isnan() const {
    return _mm256_cmp_pd(values, _mm256_set1_pd(0.0), _CMP_UNORD_Q);
  }
  Vectorized<double> map(double (*const f)(double)) const {
    __at_align__ double tmp[size()];
    store(tmp);
    for (size_type i = 0; i < size(); ++i) {
      tmp[i] = f(tmp[i]);
    }
    return loadu(tmp);
  }
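  // Illustrative scalar-fallback use (the `square` helper is hypothetical):
  //   double square(double x) { return x * x; }
  //   ...
  //   w = v.map(square);  // applies f lane by lane via store/compute/loadu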
  Vectorized<double> abs() const {
    auto mask = _mm256_set1_pd(-0.);  // sign-bit mask; clearing it yields |x|
    return _mm256_andnot_pd(mask, values);
  }
  Vectorized<double> acos() const {
    return Vectorized<double>(Sleef_acosd4_u10(values));
  }
  Vectorized<double> asin() const {
    return Vectorized<double>(Sleef_asind4_u10(values));
  }
  Vectorized<double> atan() const {
    return Vectorized<double>(Sleef_atand4_u10(values));
  }
  Vectorized<double> atan2(const Vectorized<double> &b) const {
    return Vectorized<double>(Sleef_atan2d4_u10(values, b));
  }
  Vectorized<double> copysign(const Vectorized<double> &sign) const {
    return Vectorized<double>(Sleef_copysignd4(values, sign));
  }
  Vectorized<double> erf() const {
    return Vectorized<double>(Sleef_erfd4_u10(values));
  }
  Vectorized<double> erfc() const {
    return Vectorized<double>(Sleef_erfcd4_u15(values));
  }
  Vectorized<double> exp() const {
    return Vectorized<double>(Sleef_expd4_u10(values));
  }
  Vectorized<double> exp2() const {
    return Vectorized<double>(Sleef_exp2d4_u10(values));
  }
  Vectorized<double> expm1() const {
    return Vectorized<double>(Sleef_expm1d4_u10(values));
  }
  Vectorized<double> fmod(const Vectorized<double>& q) const {
    return Vectorized<double>(Sleef_fmodd4(values, q));
  }
  Vectorized<double> hypot(const Vectorized<double> &b) const {
    return Vectorized<double>(Sleef_hypotd4_u05(values, b));
  }
  Vectorized<double> log() const {
    return Vectorized<double>(Sleef_logd4_u10(values));
  }
  Vectorized<double> log2() const {
    return Vectorized<double>(Sleef_log2d4_u10(values));
  }
  Vectorized<double> log10() const {
    return Vectorized<double>(Sleef_log10d4_u10(values));
  }
  Vectorized<double> log1p() const {
    return Vectorized<double>(Sleef_log1pd4_u10(values));
  }
  Vectorized<double> sin() const {
    return Vectorized<double>(Sleef_sind4_u10(values));
  }
  Vectorized<double> sinh() const {
    return Vectorized<double>(Sleef_sinhd4_u10(values));
  }
  Vectorized<double> cos() const {
    return Vectorized<double>(Sleef_cosd4_u10(values));
  }
  Vectorized<double> cosh() const {
    return Vectorized<double>(Sleef_coshd4_u10(values));
  }
  Vectorized<double> ceil() const {
    return _mm256_ceil_pd(values);
  }
  Vectorized<double> floor() const {
    return _mm256_floor_pd(values);
  }
  Vectorized<double> frac() const;
  Vectorized<double> neg() const {
    return _mm256_xor_pd(_mm256_set1_pd(-0.), values);
  }
  Vectorized<double> nextafter(const Vectorized<double> &b) const {
    return Vectorized<double>(Sleef_nextafterd4(values, b));
  }
  Vectorized<double> round() const {
    return _mm256_round_pd(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
  }
  Vectorized<double> tan() const {
    return Vectorized<double>(Sleef_tand4_u10(values));
  }
  Vectorized<double> tanh() const {
    return Vectorized<double>(Sleef_tanhd4_u10(values));
  }
  Vectorized<double> trunc() const {
    return _mm256_round_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
  }
  Vectorized<double> lgamma() const {
    return Vectorized<double>(Sleef_lgammad4_u10(values));
  }
  Vectorized<double> sqrt() const {
    return _mm256_sqrt_pd(values);
  }
  Vectorized<double> reciprocal() const {
    return _mm256_div_pd(_mm256_set1_pd(1), values);
  }
  Vectorized<double> rsqrt() const {
    return _mm256_div_pd(_mm256_set1_pd(1), _mm256_sqrt_pd(values));
  }
  Vectorized<double> pow(const Vectorized<double> &b) const {
    return Vectorized<double>(Sleef_powd4_u10(values, b));
  }
  // Comparisons use the _CMP_**_OQ / _CMP_NEQ_UQ predicates:
  // `O`: ordered - the lane is false if either operand is NaN
  //      (operator!= uses the unordered `U` variant, so NaN != x is true)
  // `Q`: quiet - do not raise an FP exception on a NaN operand
  Vectorized<double> operator==(const Vectorized<double>& other) const {
    return _mm256_cmp_pd(values, other.values, _CMP_EQ_OQ);
  }

  Vectorized<double> operator!=(const Vectorized<double>& other) const {
    return _mm256_cmp_pd(values, other.values, _CMP_NEQ_UQ);
  }

  Vectorized<double> operator<(const Vectorized<double>& other) const {
    return _mm256_cmp_pd(values, other.values, _CMP_LT_OQ);
  }

  Vectorized<double> operator<=(const Vectorized<double>& other) const {
    return _mm256_cmp_pd(values, other.values, _CMP_LE_OQ);
  }

  Vectorized<double> operator>(const Vectorized<double>& other) const {
    return _mm256_cmp_pd(values, other.values, _CMP_GT_OQ);
  }

  Vectorized<double> operator>=(const Vectorized<double>& other) const {
    return _mm256_cmp_pd(values, other.values, _CMP_GE_OQ);
  }
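
  // Note (derived from _mm256_cmp_pd semantics): each true lane is all-ones
  // (which reads back as a NaN double) and each false lane is 0.0, so these
  // comparison results are bitmasks for blendv/and/or, not arithmetic values;
  // the eq()/ne()/lt()/le()/gt()/ge() helpers below return numeric 0.0/1.0.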

  Vectorized<double> eq(const Vectorized<double>& other) const;
  Vectorized<double> ne(const Vectorized<double>& other) const;
  Vectorized<double> lt(const Vectorized<double>& other) const;
  Vectorized<double> le(const Vectorized<double>& other) const;
  Vectorized<double> gt(const Vectorized<double>& other) const;
  Vectorized<double> ge(const Vectorized<double>& other) const;
};

template <>
Vectorized<double> inline operator+(const Vectorized<double>& a, const Vectorized<double>& b) {
  return _mm256_add_pd(a, b);
}

template <>
Vectorized<double> inline operator-(const Vectorized<double>& a, const Vectorized<double>& b) {
  return _mm256_sub_pd(a, b);
}

template <>
Vectorized<double> inline operator*(const Vectorized<double>& a, const Vectorized<double>& b) {
  return _mm256_mul_pd(a, b);
}

template <>
Vectorized<double> inline operator/(const Vectorized<double>& a, const Vectorized<double>& b) {
  return _mm256_div_pd(a, b);
}

// frac. Implement this here so we can use subtraction.
inline Vectorized<double> Vectorized<double>::frac() const {
  return *this - this->trunc();
}
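
// E.g. (illustrative) Vectorized<double>(1.75).frac() == 0.75 and
// Vectorized<double>(-1.75).frac() == -0.75, since trunc() rounds toward zero.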

// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<double> inline maximum(const Vectorized<double>& a, const Vectorized<double>& b) {
  Vectorized<double> max = _mm256_max_pd(a, b);
  Vectorized<double> isnan = _mm256_cmp_pd(a, b, _CMP_UNORD_Q);
  // Exploit the fact that all-ones is a NaN.
  return _mm256_or_pd(max, isnan);
}
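
// E.g. (illustrative) maximum({0., NAN, 2., 3.}, {1., 1., NAN, 0.}) yields
// {1., NaN, NaN, 3.}: _mm256_max_pd alone would not propagate a NaN in `a`,
// so the unordered-compare mask is OR'd in to force such lanes to all-ones.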

// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<double> inline minimum(const Vectorized<double>& a, const Vectorized<double>& b) {
  Vectorized<double> min = _mm256_min_pd(a, b);
  Vectorized<double> isnan = _mm256_cmp_pd(a, b, _CMP_UNORD_Q);
  // Exploit the fact that all-ones is a NaN.
  return _mm256_or_pd(min, isnan);
}

template <>
Vectorized<double> inline clamp(const Vectorized<double>& a, const Vectorized<double>& min, const Vectorized<double>& max) {
  return _mm256_min_pd(max, _mm256_max_pd(min, a));
}
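
// E.g. (illustrative) clamp(5., /*min=*/0., /*max=*/3.) == 3. Note that a NaN
// lane in `a` propagates, because vminpd/vmaxpd return the second source
// operand whenever either input is NaN.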

template <>
Vectorized<double> inline clamp_min(const Vectorized<double>& a, const Vectorized<double>& min) {
  return _mm256_max_pd(min, a);
}

template <>
Vectorized<double> inline clamp_max(const Vectorized<double>& a, const Vectorized<double>& max) {
  return _mm256_min_pd(max, a);
}

template <>
Vectorized<double> inline operator&(const Vectorized<double>& a, const Vectorized<double>& b) {
  return _mm256_and_pd(a, b);
}

template <>
Vectorized<double> inline operator|(const Vectorized<double>& a, const Vectorized<double>& b) {
  return _mm256_or_pd(a, b);
}

template <>
Vectorized<double> inline operator^(const Vectorized<double>& a, const Vectorized<double>& b) {
  return _mm256_xor_pd(a, b);
}

inline Vectorized<double> Vectorized<double>::eq(const Vectorized<double>& other) const {
  return (*this == other) & Vectorized<double>(1.0);
}

inline Vectorized<double> Vectorized<double>::ne(const Vectorized<double>& other) const {
  return (*this != other) & Vectorized<double>(1.0);
}

inline Vectorized<double> Vectorized<double>::gt(const Vectorized<double>& other) const {
  return (*this > other) & Vectorized<double>(1.0);
}

inline Vectorized<double> Vectorized<double>::ge(const Vectorized<double>& other) const {
  return (*this >= other) & Vectorized<double>(1.0);
}

inline Vectorized<double> Vectorized<double>::lt(const Vectorized<double>& other) const {
  return (*this < other) & Vectorized<double>(1.0);
}

inline Vectorized<double> Vectorized<double>::le(const Vectorized<double>& other) const {
  return (*this <= other) & Vectorized<double>(1.0);
}
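
// E.g. (illustrative) Vectorized<double>(1., 2., 3., 4.).eq(Vectorized<double>(1.))
// yields {1., 0., 0., 0.}: ANDing the all-ones comparison mask with 1.0 turns
// it into a numeric per-lane 0.0/1.0 result.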

template <>
inline void convert(const double* src, double* dst, int64_t n) {
  int64_t i;
#pragma unroll
  for (i = 0; i <= (n - Vectorized<double>::size()); i += Vectorized<double>::size()) {
    _mm256_storeu_pd(dst + i, _mm256_loadu_pd(src + i));
  }
#pragma unroll
  for (; i < n; i++) {
    dst[i] = src[i];
  }
}
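
// E.g. (illustrative) for n == 10, the first loop copies elements 0..7 with
// two unaligned 256-bit load/store pairs, and the scalar tail loop copies the
// remaining two elements.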

#ifdef CPU_CAPABILITY_AVX2
template <>
Vectorized<double> inline fmadd(const Vectorized<double>& a, const Vectorized<double>& b, const Vectorized<double>& c) {
  return _mm256_fmadd_pd(a, b, c);
}

template <>
Vectorized<double> inline fmsub(const Vectorized<double>& a, const Vectorized<double>& b, const Vectorized<double>& c) {
  return _mm256_fmsub_pd(a, b, c);
}
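
// Note: fmadd computes a * b + c (and fmsub a * b - c) with a single rounding
// via an FMA3 instruction, which is both faster and more accurate than a
// separate multiply followed by an add.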
#endif

#endif

} // namespace CPU_CAPABILITY
} // namespace vec
} // namespace executorch