/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]

#include <executorch/kernels/optimized/vec/intrinsics.h>
#include <executorch/kernels/optimized/vec/vec_base.h>

#include <iostream>

namespace executorch {
namespace vec {
inline namespace CPU_CAPABILITY {

#ifdef CPU_CAPABILITY_AVX2

struct Vectorizedi {
protected:
  __m256i values;

  static inline __m256i invert(const __m256i& v) {
    const auto ones = _mm256_set1_epi64x(-1);
    return _mm256_xor_si256(ones, v);
  }
public:
  Vectorizedi() {}
  Vectorizedi(__m256i v) : values(v) {}
  operator __m256i() const {
    return values;
  }
};

#else

struct Vectorizedi {}; // dummy definition to make Vectorizedi always defined

#endif // CPU_CAPABILITY_AVX2

#ifdef CPU_CAPABILITY_AVX2

template <>
class Vectorized<int64_t> : public Vectorizedi {
private:
  static const Vectorized<int64_t> ones;
public:
  using value_type = int64_t;
  using size_type = int;
  static constexpr size_type size() {
    return 4;
  }
  using Vectorizedi::Vectorizedi;
  Vectorized() {}
  Vectorized(int64_t v) { values = _mm256_set1_epi64x(v); }
  Vectorized(int64_t val1, int64_t val2, int64_t val3, int64_t val4) {
    values = _mm256_setr_epi64x(val1, val2, val3, val4);
  }
  template <int64_t mask>
  static Vectorized<int64_t> blend(Vectorized<int64_t> a, Vectorized<int64_t> b) {
    __at_align__ int64_t tmp_values[size()];
    a.store(tmp_values);
    if (mask & 0x01)
      tmp_values[0] = _mm256_extract_epi64(b.values, 0);
    if (mask & 0x02)
      tmp_values[1] = _mm256_extract_epi64(b.values, 1);
    if (mask & 0x04)
      tmp_values[2] = _mm256_extract_epi64(b.values, 2);
    if (mask & 0x08)
      tmp_values[3] = _mm256_extract_epi64(b.values, 3);
    return loadu(tmp_values);
  }
  static Vectorized<int64_t> blendv(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b,
                                    const Vectorized<int64_t>& mask) {
    return _mm256_blendv_epi8(a.values, b.values, mask.values);
  }
  template <typename step_t>
  static Vectorized<int64_t> arange(int64_t base = 0, step_t step = static_cast<step_t>(1)) {
    return Vectorized<int64_t>(base, base + step, base + 2 * step, base + 3 * step);
  }
  static Vectorized<int64_t>
  set(Vectorized<int64_t> a, Vectorized<int64_t> b, int64_t count = size()) {
    switch (count) {
      case 0:
        return a;
      case 1:
        return blend<1>(a, b);
      case 2:
        return blend<3>(a, b);
      case 3:
        return blend<7>(a, b);
    }
    return b;
  }
  static Vectorized<int64_t> loadu(const void* ptr) {
    return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
  }
  static Vectorized<int64_t> loadu(const void* ptr, int64_t count) {
    __at_align__ int64_t tmp_values[size()];
    // Ensure uninitialized memory does not change the output value. See
    // https://github.com/pytorch/pytorch/issues/32502 for more details. We do
    // not initialize arrays to zero using "={0}" because gcc would compile it
    // to two instructions while a loop would be compiled to one instruction.
    for (size_t i = 0; i < size(); ++i) {
      tmp_values[i] = 0;
    }
    std::memcpy(tmp_values, ptr, count * sizeof(int64_t));
    return loadu(tmp_values);
  }
  void store(void* ptr, int count = size()) const {
    if (count == size()) {
      // ptr need not be aligned here. See
      // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
    } else if (count > 0) {
      __at_align__ int64_t tmp_values[size()];
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
      std::memcpy(ptr, tmp_values, count * sizeof(int64_t));
    }
  }
  const int64_t& operator[](int idx) const = delete;
  int64_t& operator[](int idx) = delete;
  Vectorized<int64_t> abs() const {
    // AVX2 has no 64-bit abs intrinsic, so use the two's-complement identity:
    // where x < 0 the comparison mask is all ones and (x ^ mask) - mask equals
    // ~x + 1 == -x; where x >= 0 the mask is zero and x passes through.
    auto zero = _mm256_set1_epi64x(0);
    auto is_larger = _mm256_cmpgt_epi64(zero, values);
    auto inverse = _mm256_xor_si256(values, is_larger);
    return _mm256_sub_epi64(inverse, is_larger);
  }
  Vectorized<int64_t> real() const {
    return *this;
  }
  Vectorized<int64_t> imag() const {
    return _mm256_set1_epi64x(0);
  }
  Vectorized<int64_t> conj() const {
    return *this;
  }
  Vectorized<int64_t> neg() const;
  Vectorized<int64_t> operator==(const Vectorized<int64_t>& other) const {
    return _mm256_cmpeq_epi64(values, other.values);
  }
  Vectorized<int64_t> operator!=(const Vectorized<int64_t>& other) const {
    return invert(_mm256_cmpeq_epi64(values, other.values));
  }
  Vectorized<int64_t> operator<(const Vectorized<int64_t>& other) const {
    return _mm256_cmpgt_epi64(other.values, values);
  }
  Vectorized<int64_t> operator<=(const Vectorized<int64_t>& other) const {
    return invert(_mm256_cmpgt_epi64(values, other.values));
  }
  Vectorized<int64_t> operator>(const Vectorized<int64_t>& other) const {
    return _mm256_cmpgt_epi64(values, other.values);
  }
  Vectorized<int64_t> operator>=(const Vectorized<int64_t>& other) const {
    return invert(_mm256_cmpgt_epi64(other.values, values));
  }

  Vectorized<int64_t> eq(const Vectorized<int64_t>& other) const;
  Vectorized<int64_t> ne(const Vectorized<int64_t>& other) const;
  Vectorized<int64_t> gt(const Vectorized<int64_t>& other) const;
  Vectorized<int64_t> ge(const Vectorized<int64_t>& other) const;
  Vectorized<int64_t> lt(const Vectorized<int64_t>& other) const;
  Vectorized<int64_t> le(const Vectorized<int64_t>& other) const;
};
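
// Example (illustrative sketch only; `src` and `dst` are placeholder buffers,
// not identifiers defined in this header). It shows a typical load/compute/
// store round trip with Vectorized<int64_t>:
//
//   int64_t src[4] = {1, -2, 3, -4};
//   auto v = Vectorized<int64_t>::loadu(src);   // {1, -2, 3, -4}
//   auto w = v.abs() + Vectorized<int64_t>(10); // {11, 12, 13, 14}
//   int64_t dst[4];
//   w.store(dst);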

template <>
class Vectorized<int32_t> : public Vectorizedi {
private:
  static const Vectorized<int32_t> ones;
public:
  using value_type = int32_t;
  using size_type = int;
  static constexpr int size() {
    return 8;
  }
  using Vectorizedi::Vectorizedi;
  Vectorized() {}
  Vectorized(int32_t v) { values = _mm256_set1_epi32(v); }
  Vectorized(int32_t val1, int32_t val2, int32_t val3, int32_t val4,
             int32_t val5, int32_t val6, int32_t val7, int32_t val8) {
    values = _mm256_setr_epi32(val1, val2, val3, val4, val5, val6, val7, val8);
  }
  template <int64_t mask>
  static Vectorized<int32_t> blend(Vectorized<int32_t> a, Vectorized<int32_t> b) {
    return _mm256_blend_epi32(a, b, mask);
  }
  static Vectorized<int32_t> blendv(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b,
                                    const Vectorized<int32_t>& mask) {
    return _mm256_blendv_epi8(a.values, b.values, mask.values);
  }
  template <typename step_t>
  static Vectorized<int32_t> arange(int32_t base = 0, step_t step = static_cast<step_t>(1)) {
    return Vectorized<int32_t>(
      base, base + step, base + 2 * step, base + 3 * step,
      base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step);
  }
  static Vectorized<int32_t>
  set(Vectorized<int32_t> a, Vectorized<int32_t> b, int32_t count = size()) {
    switch (count) {
      case 0:
        return a;
      case 1:
        return blend<1>(a, b);
      case 2:
        return blend<3>(a, b);
      case 3:
        return blend<7>(a, b);
      case 4:
        return blend<15>(a, b);
      case 5:
        return blend<31>(a, b);
      case 6:
        return blend<63>(a, b);
      case 7:
        return blend<127>(a, b);
    }
    return b;
  }
  static Vectorized<int32_t> loadu(const void* ptr) {
    return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
  }
  static Vectorized<int32_t> loadu(const void* ptr, int32_t count) {
    __at_align__ int32_t tmp_values[size()];
    // Ensure uninitialized memory does not change the output value. See
    // https://github.com/pytorch/pytorch/issues/32502 for more details. We do
    // not initialize arrays to zero using "={0}" because gcc would compile it
    // to two instructions while a loop would be compiled to one instruction.
    for (size_t i = 0; i < size(); ++i) {
      tmp_values[i] = 0;
    }
    std::memcpy(tmp_values, ptr, count * sizeof(int32_t));
    return loadu(tmp_values);
  }
  void store(void* ptr, int count = size()) const {
    if (count == size()) {
      // ptr need not be aligned here. See
      // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
    } else if (count > 0) {
      __at_align__ int32_t tmp_values[size()];
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
      std::memcpy(ptr, tmp_values, count * sizeof(int32_t));
    }
  }
  const int32_t& operator[](int idx) const = delete;
  int32_t& operator[](int idx) = delete;
  Vectorized<int32_t> abs() const {
    return _mm256_abs_epi32(values);
  }
  Vectorized<int32_t> real() const {
    return *this;
  }
  Vectorized<int32_t> imag() const {
    return _mm256_set1_epi32(0);
  }
  Vectorized<int32_t> conj() const {
    return *this;
  }
  Vectorized<int32_t> neg() const;
  Vectorized<int32_t> operator==(const Vectorized<int32_t>& other) const {
    return _mm256_cmpeq_epi32(values, other.values);
  }
  Vectorized<int32_t> operator!=(const Vectorized<int32_t>& other) const {
    return invert(_mm256_cmpeq_epi32(values, other.values));
  }
  Vectorized<int32_t> operator<(const Vectorized<int32_t>& other) const {
    return _mm256_cmpgt_epi32(other.values, values);
  }
  Vectorized<int32_t> operator<=(const Vectorized<int32_t>& other) const {
    return invert(_mm256_cmpgt_epi32(values, other.values));
  }
  Vectorized<int32_t> operator>(const Vectorized<int32_t>& other) const {
    return _mm256_cmpgt_epi32(values, other.values);
  }
  Vectorized<int32_t> operator>=(const Vectorized<int32_t>& other) const {
    return invert(_mm256_cmpgt_epi32(other.values, values));
  }
  Vectorized<int32_t> eq(const Vectorized<int32_t>& other) const;
  Vectorized<int32_t> ne(const Vectorized<int32_t>& other) const;
  Vectorized<int32_t> gt(const Vectorized<int32_t>& other) const;
  Vectorized<int32_t> ge(const Vectorized<int32_t>& other) const;
  Vectorized<int32_t> lt(const Vectorized<int32_t>& other) const;
  Vectorized<int32_t> le(const Vectorized<int32_t>& other) const;
};
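
// Example (illustrative sketch): blend selects lanes from `b` where the
// corresponding bit of the compile-time mask is 1, and from `a` elsewhere.
//
//   auto a = Vectorized<int32_t>(0);
//   auto b = Vectorized<int32_t>(1);
//   auto c = Vectorized<int32_t>::blend<0x0F>(a, b); // lanes 0-3 from b, 4-7 from a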

template <>
inline void convert(const int32_t *src, float *dst, int64_t n) {
  int64_t i;
  // int32_t and float have same size
#ifndef _MSC_VER
# pragma unroll
#endif
  for (i = 0; i <= (n - Vectorized<int32_t>::size()); i += Vectorized<int32_t>::size()) {
    auto input_vec = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + i));
    auto output_vec = _mm256_cvtepi32_ps(input_vec);
    _mm256_storeu_ps(reinterpret_cast<float*>(dst + i), output_vec);
  }
#ifndef _MSC_VER
# pragma unroll
#endif
  for (; i < n; i++) {
    dst[i] = static_cast<float>(src[i]);
  }
}

template <>
inline void convert(const int32_t *src, double *dst, int64_t n) {
  int64_t i;
  // int32_t has half the size of double
#ifndef _MSC_VER
# pragma unroll
#endif
  for (i = 0; i <= (n - Vectorized<double>::size()); i += Vectorized<double>::size()) {
    auto input_128_vec = _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i));
    auto output_vec = _mm256_cvtepi32_pd(input_128_vec);
    _mm256_storeu_pd(reinterpret_cast<double*>(dst + i), output_vec);
  }
#ifndef _MSC_VER
# pragma unroll
#endif
  for (; i < n; i++) {
    dst[i] = static_cast<double>(src[i]);
  }
}
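
// Example (illustrative sketch; `src` and `dst` are placeholder buffers):
// widening int32 -> float conversion over a buffer. The vectorized loop
// handles full 8-lane chunks and the tail falls back to scalar code.
//
//   int32_t src[10];
//   float dst[10];
//   // ... fill src ...
//   convert(src, dst, 10);  // first 8 elements via AVX2, last 2 scalar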

template <>
class Vectorized<int16_t> : public Vectorizedi {
private:
  static const Vectorized<int16_t> ones;
public:
  using value_type = int16_t;
  using size_type = int;
  static constexpr int size() {
    return 16;
  }
  using Vectorizedi::Vectorizedi;
  Vectorized() {}
  Vectorized(int16_t v) { values = _mm256_set1_epi16(v); }
  Vectorized(int16_t val1, int16_t val2, int16_t val3, int16_t val4,
             int16_t val5, int16_t val6, int16_t val7, int16_t val8,
             int16_t val9, int16_t val10, int16_t val11, int16_t val12,
             int16_t val13, int16_t val14, int16_t val15, int16_t val16) {
    values = _mm256_setr_epi16(val1, val2, val3, val4, val5, val6, val7, val8,
                               val9, val10, val11, val12, val13, val14, val15, val16);
  }
  template <int64_t mask>
  static Vectorized<int16_t> blend(Vectorized<int16_t> a, Vectorized<int16_t> b) {
    __at_align__ int16_t tmp_values[size()];
    a.store(tmp_values);
    if (mask & 0x01)
      tmp_values[0] = _mm256_extract_epi16(b.values, 0);
    if (mask & 0x02)
      tmp_values[1] = _mm256_extract_epi16(b.values, 1);
    if (mask & 0x04)
      tmp_values[2] = _mm256_extract_epi16(b.values, 2);
    if (mask & 0x08)
      tmp_values[3] = _mm256_extract_epi16(b.values, 3);
    if (mask & 0x10)
      tmp_values[4] = _mm256_extract_epi16(b.values, 4);
    if (mask & 0x20)
      tmp_values[5] = _mm256_extract_epi16(b.values, 5);
    if (mask & 0x40)
      tmp_values[6] = _mm256_extract_epi16(b.values, 6);
    if (mask & 0x80)
      tmp_values[7] = _mm256_extract_epi16(b.values, 7);
    if (mask & 0x100)
      tmp_values[8] = _mm256_extract_epi16(b.values, 8);
    if (mask & 0x200)
      tmp_values[9] = _mm256_extract_epi16(b.values, 9);
    if (mask & 0x400)
      tmp_values[10] = _mm256_extract_epi16(b.values, 10);
    if (mask & 0x800)
      tmp_values[11] = _mm256_extract_epi16(b.values, 11);
    if (mask & 0x1000)
      tmp_values[12] = _mm256_extract_epi16(b.values, 12);
    if (mask & 0x2000)
      tmp_values[13] = _mm256_extract_epi16(b.values, 13);
    if (mask & 0x4000)
      tmp_values[14] = _mm256_extract_epi16(b.values, 14);
    if (mask & 0x8000)
      tmp_values[15] = _mm256_extract_epi16(b.values, 15);
    return loadu(tmp_values);
  }
  static Vectorized<int16_t> blendv(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b,
                                    const Vectorized<int16_t>& mask) {
    return _mm256_blendv_epi8(a.values, b.values, mask.values);
  }
  template <typename step_t>
  static Vectorized<int16_t> arange(int16_t base = 0, step_t step = static_cast<step_t>(1)) {
    return Vectorized<int16_t>(
      base, base + step, base + 2 * step, base + 3 * step,
      base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step,
      base + 8 * step, base + 9 * step, base + 10 * step, base + 11 * step,
      base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step);
  }
  static Vectorized<int16_t>
  set(Vectorized<int16_t> a, Vectorized<int16_t> b, int16_t count = size()) {
    switch (count) {
      case 0:
        return a;
      case 1:
        return blend<1>(a, b);
      case 2:
        return blend<3>(a, b);
      case 3:
        return blend<7>(a, b);
      case 4:
        return blend<15>(a, b);
      case 5:
        return blend<31>(a, b);
      case 6:
        return blend<63>(a, b);
      case 7:
        return blend<127>(a, b);
      case 8:
        return blend<255>(a, b);
      case 9:
        return blend<511>(a, b);
      case 10:
        return blend<1023>(a, b);
      case 11:
        return blend<2047>(a, b);
      case 12:
        return blend<4095>(a, b);
      case 13:
        return blend<8191>(a, b);
      case 14:
        return blend<16383>(a, b);
      case 15:
        return blend<32767>(a, b);
    }
    return b;
  }
  static Vectorized<int16_t> loadu(const void* ptr) {
    return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
  }
  static Vectorized<int16_t> loadu(const void* ptr, int16_t count) {
    __at_align__ int16_t tmp_values[size()];
    // Ensure uninitialized memory does not change the output value. See
    // https://github.com/pytorch/pytorch/issues/32502 for more details. We do
    // not initialize arrays to zero using "={0}" because gcc would compile it
    // to two instructions while a loop would be compiled to one instruction.
    for (size_t i = 0; i < size(); ++i) {
      tmp_values[i] = 0;
    }
    std::memcpy(tmp_values, ptr, count * sizeof(int16_t));
    return loadu(tmp_values);
  }
  void store(void* ptr, int count = size()) const {
    if (count == size()) {
      // ptr need not be aligned here. See
      // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
    } else if (count > 0) {
      __at_align__ int16_t tmp_values[size()];
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
      std::memcpy(ptr, tmp_values, count * sizeof(int16_t));
    }
  }
  const int16_t& operator[](int idx) const = delete;
  int16_t& operator[](int idx) = delete;
  Vectorized<int16_t> abs() const {
    return _mm256_abs_epi16(values);
  }
  Vectorized<int16_t> real() const {
    return *this;
  }
  Vectorized<int16_t> imag() const {
    return _mm256_set1_epi16(0);
  }
  Vectorized<int16_t> conj() const {
    return *this;
  }
  Vectorized<int16_t> neg() const;
  Vectorized<int16_t> operator==(const Vectorized<int16_t>& other) const {
    return _mm256_cmpeq_epi16(values, other.values);
  }
  Vectorized<int16_t> operator!=(const Vectorized<int16_t>& other) const {
    return invert(_mm256_cmpeq_epi16(values, other.values));
  }
  Vectorized<int16_t> operator<(const Vectorized<int16_t>& other) const {
    return _mm256_cmpgt_epi16(other.values, values);
  }
  Vectorized<int16_t> operator<=(const Vectorized<int16_t>& other) const {
    return invert(_mm256_cmpgt_epi16(values, other.values));
  }
  Vectorized<int16_t> operator>(const Vectorized<int16_t>& other) const {
    return _mm256_cmpgt_epi16(values, other.values);
  }
  Vectorized<int16_t> operator>=(const Vectorized<int16_t>& other) const {
    return invert(_mm256_cmpgt_epi16(other.values, values));
  }

  Vectorized<int16_t> eq(const Vectorized<int16_t>& other) const;
  Vectorized<int16_t> ne(const Vectorized<int16_t>& other) const;
  Vectorized<int16_t> gt(const Vectorized<int16_t>& other) const;
  Vectorized<int16_t> ge(const Vectorized<int16_t>& other) const;
  Vectorized<int16_t> lt(const Vectorized<int16_t>& other) const;
  Vectorized<int16_t> le(const Vectorized<int16_t>& other) const;
};

template <typename T>
class Vectorized8 : public Vectorizedi {
  static_assert(
    std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value,
    "Only int8_t/uint8_t are supported");
protected:
  static const Vectorized<T> ones;
public:
  using value_type = T;
  using size_type = int;
  static constexpr int size() {
    return 32;
  }
  using Vectorizedi::Vectorizedi;
  Vectorized8() {}
  Vectorized8(T v) { values = _mm256_set1_epi8(v); }
  Vectorized8(T val1, T val2, T val3, T val4,
              T val5, T val6, T val7, T val8,
              T val9, T val10, T val11, T val12,
              T val13, T val14, T val15, T val16,
              T val17, T val18, T val19, T val20,
              T val21, T val22, T val23, T val24,
              T val25, T val26, T val27, T val28,
              T val29, T val30, T val31, T val32) {
    values = _mm256_setr_epi8(val1, val2, val3, val4, val5, val6, val7, val8,
                              val9, val10, val11, val12, val13, val14, val15, val16,
                              val17, val18, val19, val20, val21, val22, val23, val24,
                              val25, val26, val27, val28, val29, val30, val31, val32);
  }
  template <int64_t mask>
  static Vectorized<T> blend(Vectorized<T> a, Vectorized<T> b) {
    __at_align__ T tmp_values[size()];
    a.store(tmp_values);
    if (mask & 0x01)
      tmp_values[0] = _mm256_extract_epi8(b.values, 0);
    if (mask & 0x02)
      tmp_values[1] = _mm256_extract_epi8(b.values, 1);
    if (mask & 0x04)
      tmp_values[2] = _mm256_extract_epi8(b.values, 2);
    if (mask & 0x08)
      tmp_values[3] = _mm256_extract_epi8(b.values, 3);
    if (mask & 0x10)
      tmp_values[4] = _mm256_extract_epi8(b.values, 4);
    if (mask & 0x20)
      tmp_values[5] = _mm256_extract_epi8(b.values, 5);
    if (mask & 0x40)
      tmp_values[6] = _mm256_extract_epi8(b.values, 6);
    if (mask & 0x80)
      tmp_values[7] = _mm256_extract_epi8(b.values, 7);
    if (mask & 0x100)
      tmp_values[8] = _mm256_extract_epi8(b.values, 8);
    if (mask & 0x200)
      tmp_values[9] = _mm256_extract_epi8(b.values, 9);
    if (mask & 0x400)
      tmp_values[10] = _mm256_extract_epi8(b.values, 10);
    if (mask & 0x800)
      tmp_values[11] = _mm256_extract_epi8(b.values, 11);
    if (mask & 0x1000)
      tmp_values[12] = _mm256_extract_epi8(b.values, 12);
    if (mask & 0x2000)
      tmp_values[13] = _mm256_extract_epi8(b.values, 13);
    if (mask & 0x4000)
      tmp_values[14] = _mm256_extract_epi8(b.values, 14);
    if (mask & 0x8000)
      tmp_values[15] = _mm256_extract_epi8(b.values, 15);
    if (mask & 0x010000)
      tmp_values[16] = _mm256_extract_epi8(b.values, 16);
    if (mask & 0x020000)
      tmp_values[17] = _mm256_extract_epi8(b.values, 17);
    if (mask & 0x040000)
      tmp_values[18] = _mm256_extract_epi8(b.values, 18);
    if (mask & 0x080000)
      tmp_values[19] = _mm256_extract_epi8(b.values, 19);
    if (mask & 0x100000)
      tmp_values[20] = _mm256_extract_epi8(b.values, 20);
    if (mask & 0x200000)
      tmp_values[21] = _mm256_extract_epi8(b.values, 21);
    if (mask & 0x400000)
      tmp_values[22] = _mm256_extract_epi8(b.values, 22);
    if (mask & 0x800000)
      tmp_values[23] = _mm256_extract_epi8(b.values, 23);
    if (mask & 0x1000000)
      tmp_values[24] = _mm256_extract_epi8(b.values, 24);
    if (mask & 0x2000000)
      tmp_values[25] = _mm256_extract_epi8(b.values, 25);
    if (mask & 0x4000000)
      tmp_values[26] = _mm256_extract_epi8(b.values, 26);
    if (mask & 0x8000000)
      tmp_values[27] = _mm256_extract_epi8(b.values, 27);
    if (mask & 0x10000000)
      tmp_values[28] = _mm256_extract_epi8(b.values, 28);
    if (mask & 0x20000000)
      tmp_values[29] = _mm256_extract_epi8(b.values, 29);
    if (mask & 0x40000000)
      tmp_values[30] = _mm256_extract_epi8(b.values, 30);
    if (mask & 0x80000000)
      tmp_values[31] = _mm256_extract_epi8(b.values, 31);
    return loadu(tmp_values);
  }
  static Vectorized<T> blendv(const Vectorized<T>& a, const Vectorized<T>& b,
                              const Vectorized<T>& mask) {
    return _mm256_blendv_epi8(a.values, b.values, mask.values);
  }
  template <typename step_t>
  static Vectorized<T> arange(T base = 0, step_t step = static_cast<step_t>(1)) {
    return Vectorized<T>(
      base, base + step, base + 2 * step, base + 3 * step,
      base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step,
      base + 8 * step, base + 9 * step, base + 10 * step, base + 11 * step,
      base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step,
      base + 16 * step, base + 17 * step, base + 18 * step, base + 19 * step,
      base + 20 * step, base + 21 * step, base + 22 * step, base + 23 * step,
      base + 24 * step, base + 25 * step, base + 26 * step, base + 27 * step,
      base + 28 * step, base + 29 * step, base + 30 * step, base + 31 * step);
  }
  static Vectorized<T>
  set(Vectorized<T> a, Vectorized<T> b, T count = size()) {
    switch (count) {
      case 0:
        return a;
      case 1:
        return blend<0x1>(a, b);
      case 2:
        return blend<0x3>(a, b);
      case 3:
        return blend<0x7>(a, b);
      case 4:
        return blend<0xF>(a, b);
      case 5:
        return blend<0x1F>(a, b);
      case 6:
        return blend<0x3F>(a, b);
      case 7:
        return blend<0x7F>(a, b);
      case 8:
        return blend<0xFF>(a, b);
      case 9:
        return blend<0x1FF>(a, b);
      case 10:
        return blend<0x3FF>(a, b);
      case 11:
        return blend<0x7FF>(a, b);
      case 12:
        return blend<0xFFF>(a, b);
      case 13:
        return blend<0x1FFF>(a, b);
      case 14:
        return blend<0x3FFF>(a, b);
      case 15:
        return blend<0x7FFF>(a, b);
      case 16:
        return blend<0xFFFF>(a, b);
      case 17:
        return blend<0x1FFFF>(a, b);
      case 18:
        return blend<0x3FFFF>(a, b);
      case 19:
        return blend<0x7FFFF>(a, b);
      case 20:
        return blend<0xFFFFF>(a, b);
      case 21:
        return blend<0x1FFFFF>(a, b);
      case 22:
        return blend<0x3FFFFF>(a, b);
      case 23:
        return blend<0x7FFFFF>(a, b);
      case 24:
        return blend<0xFFFFFF>(a, b);
      case 25:
        return blend<0x1FFFFFF>(a, b);
      case 26:
        return blend<0x3FFFFFF>(a, b);
      case 27:
        return blend<0x7FFFFFF>(a, b);
      case 28:
        return blend<0xFFFFFFF>(a, b);
      case 29:
        return blend<0x1FFFFFFF>(a, b);
      case 30:
        return blend<0x3FFFFFFF>(a, b);
      case 31:
        return blend<0x7FFFFFFF>(a, b);
    }
    return b;
  }
  static Vectorized<T> loadu(const void* ptr) {
    return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
  }
  static Vectorized<T> loadu(const void* ptr, T count) {
    __at_align__ T tmp_values[size()];
    // Ensure uninitialized memory does not change the output value. See
    // https://github.com/pytorch/pytorch/issues/32502 for more details. We do
    // not initialize arrays to zero using "={0}" because gcc would compile it
    // to two instructions while a loop would be compiled to one instruction.
    for (size_t i = 0; i < size(); ++i) {
      tmp_values[i] = 0;
    }
    std::memcpy(tmp_values, ptr, count * sizeof(T));
    return loadu(tmp_values);
  }
  void store(void* ptr, int count = size()) const {
    if (count == size()) {
      // ptr need not be aligned here. See
      // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
    } else if (count > 0) {
      __at_align__ T tmp_values[size()];
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
      std::memcpy(ptr, tmp_values, count * sizeof(T));
    }
  }
  const T& operator[](int idx) const = delete;
  T& operator[](int idx) = delete;
  Vectorized<T> real() const {
    return *this;
  }
  Vectorized<T> imag() const {
    return _mm256_set1_epi8(0);
  }
  Vectorized<T> conj() const {
    return *this;
  }
};

template<>
class Vectorized<int8_t>: public Vectorized8<int8_t> {
public:
  using Vectorized8::Vectorized8;

  Vectorized<int8_t> neg() const;

  Vectorized<int8_t> abs() const {
    return _mm256_abs_epi8(values);
  }

  Vectorized<int8_t> operator==(const Vectorized<int8_t>& other) const {
    return _mm256_cmpeq_epi8(values, other.values);
  }
  Vectorized<int8_t> operator!=(const Vectorized<int8_t>& other) const {
    return invert(_mm256_cmpeq_epi8(values, other.values));
  }
  Vectorized<int8_t> operator<(const Vectorized<int8_t>& other) const {
    return _mm256_cmpgt_epi8(other.values, values);
  }
  Vectorized<int8_t> operator<=(const Vectorized<int8_t>& other) const {
    return invert(_mm256_cmpgt_epi8(values, other.values));
  }
  Vectorized<int8_t> operator>(const Vectorized<int8_t>& other) const {
    return other < *this;
  }
  Vectorized<int8_t> operator>=(const Vectorized<int8_t>& other) const {
    return other <= *this;
  }

  Vectorized<int8_t> eq(const Vectorized<int8_t>& other) const;
  Vectorized<int8_t> ne(const Vectorized<int8_t>& other) const;
  Vectorized<int8_t> gt(const Vectorized<int8_t>& other) const;
  Vectorized<int8_t> ge(const Vectorized<int8_t>& other) const;
  Vectorized<int8_t> lt(const Vectorized<int8_t>& other) const;
  Vectorized<int8_t> le(const Vectorized<int8_t>& other) const;
};

template<>
class Vectorized<uint8_t>: public Vectorized8<uint8_t> {
public:
  using Vectorized8::Vectorized8;

  Vectorized<uint8_t> neg() const;

  Vectorized<uint8_t> abs() const {
    return *this;
  }

  Vectorized<uint8_t> operator==(const Vectorized<uint8_t>& other) const {
    return _mm256_cmpeq_epi8(values, other.values);
  }
  Vectorized<uint8_t> operator!=(const Vectorized<uint8_t>& other) const {
    return invert(_mm256_cmpeq_epi8(values, other.values));
  }
  Vectorized<uint8_t> operator<(const Vectorized<uint8_t>& other) const {
    __m256i max = _mm256_max_epu8(values, other.values);
    return invert(_mm256_cmpeq_epi8(max, values));
  }
  Vectorized<uint8_t> operator<=(const Vectorized<uint8_t>& other) const {
    __m256i max = _mm256_max_epu8(values, other.values);
    return _mm256_cmpeq_epi8(max, other.values);
  }
  Vectorized<uint8_t> operator>(const Vectorized<uint8_t>& other) const {
    return other < *this;
  }
  Vectorized<uint8_t> operator>=(const Vectorized<uint8_t>& other) const {
    return other <= *this;
  }

  Vectorized<uint8_t> eq(const Vectorized<uint8_t>& other) const;
  Vectorized<uint8_t> ne(const Vectorized<uint8_t>& other) const;
  Vectorized<uint8_t> gt(const Vectorized<uint8_t>& other) const;
  Vectorized<uint8_t> ge(const Vectorized<uint8_t>& other) const;
  Vectorized<uint8_t> lt(const Vectorized<uint8_t>& other) const;
  Vectorized<uint8_t> le(const Vectorized<uint8_t>& other) const;
};
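
// Example (illustrative sketch): signed and unsigned byte comparisons differ.
// int8_t compares with _mm256_cmpgt_epi8 directly, while uint8_t has no
// unsigned greater-than intrinsic and is emulated via _mm256_max_epu8 plus an
// equality check.
//
//   auto x = Vectorized<uint8_t>(0x80);  // 128 in every lane
//   auto y = Vectorized<uint8_t>(0x7F);  // 127 in every lane
//   auto m = (x > y);                    // all-ones mask: 128 > 127 unsigned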

template <>
Vectorized<int64_t> inline operator+(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
  return _mm256_add_epi64(a, b);
}

template <>
Vectorized<int32_t> inline operator+(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
  return _mm256_add_epi32(a, b);
}

template <>
Vectorized<int16_t> inline operator+(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
  return _mm256_add_epi16(a, b);
}

template <>
Vectorized<int8_t> inline operator+(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) {
  return _mm256_add_epi8(a, b);
}

template <>
Vectorized<uint8_t> inline operator+(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& b) {
  return _mm256_add_epi8(a, b);
}

template <>
Vectorized<int64_t> inline operator-(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
  return _mm256_sub_epi64(a, b);
}

template <>
Vectorized<int32_t> inline operator-(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
  return _mm256_sub_epi32(a, b);
}

template <>
Vectorized<int16_t> inline operator-(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
  return _mm256_sub_epi16(a, b);
}

template <>
Vectorized<int8_t> inline operator-(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) {
  return _mm256_sub_epi8(a, b);
}

template <>
Vectorized<uint8_t> inline operator-(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& b) {
  return _mm256_sub_epi8(a, b);
}

// Negation. Defined here so we can utilize operator-
inline Vectorized<int64_t> Vectorized<int64_t>::neg() const {
  return Vectorized<int64_t>(0) - *this;
}

inline Vectorized<int32_t> Vectorized<int32_t>::neg() const {
  return Vectorized<int32_t>(0) - *this;
}

inline Vectorized<int16_t> Vectorized<int16_t>::neg() const {
  return Vectorized<int16_t>(0) - *this;
}

inline Vectorized<int8_t> Vectorized<int8_t>::neg() const {
  return Vectorized<int8_t>(0) - *this;
}

inline Vectorized<uint8_t> Vectorized<uint8_t>::neg() const {
  return Vectorized<uint8_t>(0) - *this;
}

// Emulate operations with no native 64-bit support in AVX,
// by extracting each element, performing the operation pointwise,
// then combining the results into a vector.
template <typename op_t>
Vectorized<int64_t> inline emulate(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b, const op_t& op) {
  int64_t a0 = _mm256_extract_epi64(a, 0);
  int64_t a1 = _mm256_extract_epi64(a, 1);
  int64_t a2 = _mm256_extract_epi64(a, 2);
  int64_t a3 = _mm256_extract_epi64(a, 3);

  int64_t b0 = _mm256_extract_epi64(b, 0);
  int64_t b1 = _mm256_extract_epi64(b, 1);
  int64_t b2 = _mm256_extract_epi64(b, 2);
  int64_t b3 = _mm256_extract_epi64(b, 3);

  int64_t c0 = op(a0, b0);
  int64_t c1 = op(a1, b1);
  int64_t c2 = op(a2, b2);
  int64_t c3 = op(a3, b3);

  return _mm256_set_epi64x(c3, c2, c1, c0);
}

template <typename op_t>
Vectorized<int64_t> inline emulate(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b, const Vectorized<int64_t>& c, const op_t& op) {
  int64_t a0 = _mm256_extract_epi64(a, 0);
  int64_t a1 = _mm256_extract_epi64(a, 1);
  int64_t a2 = _mm256_extract_epi64(a, 2);
  int64_t a3 = _mm256_extract_epi64(a, 3);

  int64_t b0 = _mm256_extract_epi64(b, 0);
  int64_t b1 = _mm256_extract_epi64(b, 1);
  int64_t b2 = _mm256_extract_epi64(b, 2);
  int64_t b3 = _mm256_extract_epi64(b, 3);

  int64_t c0 = _mm256_extract_epi64(c, 0);
  int64_t c1 = _mm256_extract_epi64(c, 1);
  int64_t c2 = _mm256_extract_epi64(c, 2);
  int64_t c3 = _mm256_extract_epi64(c, 3);

  int64_t d0 = op(a0, b0, c0);
  int64_t d1 = op(a1, b1, c1);
  int64_t d2 = op(a2, b2, c2);
  int64_t d3 = op(a3, b3, c3);

  return _mm256_set_epi64x(d3, d2, d1, d0);
}
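
// Example (illustrative sketch; `a` and `b` stand for arbitrary
// Vectorized<int64_t> values): emulate() lifts a scalar lambda over the four
// 64-bit lanes, which is how operator*, minimum, maximum and clamp below are
// built for int64_t.
//
//   auto r = emulate(a, b, [](int64_t x, int64_t y) { return x ^ y; });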

// AVX2 has no intrinsic for int64_t multiply, so it needs to be emulated.
// This could be implemented more efficiently using epi32 instructions.
// This is also technically AVX compatible, but then we'll need AVX
// code for add as well.
// Note: intentionally ignores undefined behavior like (-lowest * -1).
template <>
Vectorized<int64_t> inline operator*(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
  return emulate(a, b, [](int64_t a_point, int64_t b_point) {return a_point * b_point;});
}

template <>
Vectorized<int32_t> inline operator*(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
  return _mm256_mullo_epi32(a, b);
}

template <>
Vectorized<int16_t> inline operator*(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
  return _mm256_mullo_epi16(a, b);
}

template <typename T, typename Op>
Vectorized<T> inline int_elementwise_binary_256(const Vectorized<T>& a, const Vectorized<T>& b, Op op) {
  T values_a[Vectorized<T>::size()];
  T values_b[Vectorized<T>::size()];
  a.store(values_a);
  b.store(values_b);
  for (size_t i = 0; i != Vectorized<T>::size(); i++) {
    values_a[i] = op(values_a[i], values_b[i]);
  }
  return Vectorized<T>::loadu(values_a);
}

template <>
Vectorized<int8_t> inline operator*(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) {
  // We don't have an instruction for multiplying int8_t
  return int_elementwise_binary_256(a, b, std::multiplies<int8_t>());
}

template <>
Vectorized<uint8_t> inline operator*(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& b) {
  // We don't have an instruction for multiplying uint8_t
  return int_elementwise_binary_256(a, b, std::multiplies<uint8_t>());
}

template <>
Vectorized<int64_t> inline minimum(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
  return emulate(a, b, [](int64_t a_point, int64_t b_point) {return std::min(a_point, b_point);});
}

template <>
Vectorized<int32_t> inline minimum(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
  return _mm256_min_epi32(a, b);
}

template <>
Vectorized<int16_t> inline minimum(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
  return _mm256_min_epi16(a, b);
}

template <>
Vectorized<int8_t> inline minimum(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) {
  return _mm256_min_epi8(a, b);
}

template <>
Vectorized<uint8_t> inline minimum(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& b) {
  return _mm256_min_epu8(a, b);
}

template <>
Vectorized<int64_t> inline maximum(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
  return emulate(a, b, [](int64_t a_point, int64_t b_point) {return std::max(a_point, b_point);});
}

template <>
Vectorized<int32_t> inline maximum(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
  return _mm256_max_epi32(a, b);
}

template <>
Vectorized<int16_t> inline maximum(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
  return _mm256_max_epi16(a, b);
}

template <>
Vectorized<int8_t> inline maximum(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) {
  return _mm256_max_epi8(a, b);
}

template <>
Vectorized<uint8_t> inline maximum(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& b) {
  return _mm256_max_epu8(a, b);
}

template <>
Vectorized<int64_t> inline clamp(const Vectorized<int64_t>& a, const Vectorized<int64_t>& min_val, const Vectorized<int64_t>& max_val) {
  return emulate(a, min_val, max_val, [](int64_t a_point, int64_t min_point, int64_t max_point) {return std::min(max_point, std::max(a_point, min_point));});
}

template <>
Vectorized<int32_t> inline clamp(const Vectorized<int32_t>& a, const Vectorized<int32_t>& min_val, const Vectorized<int32_t>& max_val) {
  return _mm256_min_epi32(max_val, _mm256_max_epi32(a, min_val));
}

template <>
Vectorized<int16_t> inline clamp(const Vectorized<int16_t>& a, const Vectorized<int16_t>& min_val, const Vectorized<int16_t>& max_val) {
  return _mm256_min_epi16(max_val, _mm256_max_epi16(a, min_val));
}

template <>
Vectorized<int8_t> inline clamp(const Vectorized<int8_t>& a, const Vectorized<int8_t>& min_val, const Vectorized<int8_t>& max_val) {
  return _mm256_min_epi8(max_val, _mm256_max_epi8(a, min_val));
}

template <>
Vectorized<uint8_t> inline clamp(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& min_val, const Vectorized<uint8_t>& max_val) {
  return _mm256_min_epu8(max_val, _mm256_max_epu8(a, min_val));
}

template <>
Vectorized<int64_t> inline clamp_max(const Vectorized<int64_t>& a, const Vectorized<int64_t>& max_val) {
  return emulate(a, max_val, [](int64_t a_point, int64_t max_point) {return std::min(max_point, a_point);});
}

template <>
Vectorized<int32_t> inline clamp_max(const Vectorized<int32_t>& a, const Vectorized<int32_t>& max_val) {
  return _mm256_min_epi32(max_val, a);
}

template <>
Vectorized<int16_t> inline clamp_max(const Vectorized<int16_t>& a, const Vectorized<int16_t>& max_val) {
  return _mm256_min_epi16(max_val, a);
}

template <>
Vectorized<int8_t> inline clamp_max(const Vectorized<int8_t>& a, const Vectorized<int8_t>& max_val) {
  return _mm256_min_epi8(max_val, a);
}

template <>
Vectorized<uint8_t> inline clamp_max(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& max_val) {
  return _mm256_min_epu8(max_val, a);
}

template <>
Vectorized<int64_t> inline clamp_min(const Vectorized<int64_t>& a, const Vectorized<int64_t>& min_val) {
  return emulate(a, min_val, [](int64_t a_point, int64_t min_point) {return std::max(min_point, a_point);});
}

template <>
Vectorized<int32_t> inline clamp_min(const Vectorized<int32_t>& a, const Vectorized<int32_t>& min_val) {
  return _mm256_max_epi32(min_val, a);
}

template <>
Vectorized<int16_t> inline clamp_min(const Vectorized<int16_t>& a, const Vectorized<int16_t>& min_val) {
  return _mm256_max_epi16(min_val, a);
}

template <>
Vectorized<int8_t> inline clamp_min(const Vectorized<int8_t>& a, const Vectorized<int8_t>& min_val) {
  return _mm256_max_epi8(min_val, a);
}

template <>
Vectorized<uint8_t> inline clamp_min(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& min_val) {
  return _mm256_max_epu8(min_val, a);
}
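
// Example (illustrative sketch): clamp(a, min_val, max_val) computes
// min(max_val, max(a, min_val)) per lane, so values below min_val are raised
// and values above max_val are lowered.
//
//   auto v = Vectorized<int32_t>::arange(-3, 2);  // {-3, -1, 1, 3, 5, 7, 9, 11}
//   auto c = clamp(v, Vectorized<int32_t>(0), Vectorized<int32_t>(6));
//   // c = {0, 0, 1, 3, 5, 6, 6, 6}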

template<typename T>
Vectorized<int32_t> inline convert_to_int32(const T* ptr) {
  return Vectorized<int32_t>::loadu(ptr);
}

template<>
Vectorized<int32_t> inline convert_to_int32<int8_t>(const int8_t* ptr) {
  return _mm256_cvtepi8_epi32(_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ptr)));
}

template<>
Vectorized<int32_t> inline convert_to_int32<uint8_t>(const uint8_t* ptr) {
  return _mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ptr)));
}

template <>
Vectorized<int64_t> inline operator/(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
  return int_elementwise_binary_256(a, b, std::divides<int64_t>());
}
template <>
Vectorized<int32_t> inline operator/(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
  return int_elementwise_binary_256(a, b, std::divides<int32_t>());
}
template <>
Vectorized<int16_t> inline operator/(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
  return int_elementwise_binary_256(a, b, std::divides<int16_t>());
}
template <>
Vectorized<int8_t> inline operator/(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) {
  return int_elementwise_binary_256(a, b, std::divides<int8_t>());
}
template <>
Vectorized<uint8_t> inline operator/(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& b) {
  return int_elementwise_binary_256(a, b, std::divides<uint8_t>());
}

template<class T, typename std::enable_if_t<std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
inline Vectorized<T> operator&(const Vectorized<T>& a, const Vectorized<T>& b) {
  return _mm256_and_si256(a, b);
}
template<class T, typename std::enable_if_t<std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
inline Vectorized<T> operator|(const Vectorized<T>& a, const Vectorized<T>& b) {
  return _mm256_or_si256(a, b);
}
template<class T, typename std::enable_if_t<std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
inline Vectorized<T> operator^(const Vectorized<T>& a, const Vectorized<T>& b) {
  return _mm256_xor_si256(a, b);
}
template<class T, typename std::enable_if_t<std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
inline Vectorized<T> operator~(const Vectorized<T>& a) {
  return _mm256_xor_si256(a, _mm256_set1_epi32(-1));
}
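
// Example (illustrative sketch; `a` and `b` stand for arbitrary
// Vectorized<int32_t> values): the comparison operators produce all-ones /
// all-zero lane masks, which compose with the bitwise operators above and
// with blendv.
//
//   auto mask = (a > b);                                    // 0xFFFFFFFF or 0 per lane
//   auto larger = Vectorized<int32_t>::blendv(b, a, mask);  // per-lane max(a, b)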

inline Vectorized<int64_t> Vectorized<int64_t>::eq(const Vectorized<int64_t>& other) const {
  return (*this == other) & Vectorized<int64_t>(1);
}

inline Vectorized<int64_t> Vectorized<int64_t>::ne(const Vectorized<int64_t>& other) const {
  return (*this != other) & Vectorized<int64_t>(1);
}

inline Vectorized<int64_t> Vectorized<int64_t>::gt(const Vectorized<int64_t>& other) const {
  return (*this > other) & Vectorized<int64_t>(1);
}

inline Vectorized<int64_t> Vectorized<int64_t>::ge(const Vectorized<int64_t>& other) const {
  return (*this >= other) & Vectorized<int64_t>(1);
}

inline Vectorized<int64_t> Vectorized<int64_t>::lt(const Vectorized<int64_t>& other) const {
  return (*this < other) & Vectorized<int64_t>(1);
}

inline Vectorized<int64_t> Vectorized<int64_t>::le(const Vectorized<int64_t>& other) const {
  return (*this <= other) & Vectorized<int64_t>(1);
}

inline Vectorized<int32_t> Vectorized<int32_t>::eq(const Vectorized<int32_t>& other) const {
  return (*this == other) & Vectorized<int32_t>(1);
}

inline Vectorized<int32_t> Vectorized<int32_t>::ne(const Vectorized<int32_t>& other) const {
  return (*this != other) & Vectorized<int32_t>(1);
}

inline Vectorized<int32_t> Vectorized<int32_t>::gt(const Vectorized<int32_t>& other) const {
  return (*this > other) & Vectorized<int32_t>(1);
}

inline Vectorized<int32_t> Vectorized<int32_t>::ge(const Vectorized<int32_t>& other) const {
  return (*this >= other) & Vectorized<int32_t>(1);
}

inline Vectorized<int32_t> Vectorized<int32_t>::lt(const Vectorized<int32_t>& other) const {
  return (*this < other) & Vectorized<int32_t>(1);
}

inline Vectorized<int32_t> Vectorized<int32_t>::le(const Vectorized<int32_t>& other) const {
  return (*this <= other) & Vectorized<int32_t>(1);
}

inline Vectorized<int16_t> Vectorized<int16_t>::eq(const Vectorized<int16_t>& other) const {
  return (*this == other) & Vectorized<int16_t>(1);
}

inline Vectorized<int16_t> Vectorized<int16_t>::ne(const Vectorized<int16_t>& other) const {
  return (*this != other) & Vectorized<int16_t>(1);
}

inline Vectorized<int16_t> Vectorized<int16_t>::gt(const Vectorized<int16_t>& other) const {
  return (*this > other) & Vectorized<int16_t>(1);
}

inline Vectorized<int16_t> Vectorized<int16_t>::ge(const Vectorized<int16_t>& other) const {
  return (*this >= other) & Vectorized<int16_t>(1);
}

inline Vectorized<int16_t> Vectorized<int16_t>::lt(const Vectorized<int16_t>& other) const {
  return (*this < other) & Vectorized<int16_t>(1);
}

inline Vectorized<int16_t> Vectorized<int16_t>::le(const Vectorized<int16_t>& other) const {
  return (*this <= other) & Vectorized<int16_t>(1);
}

inline Vectorized<int8_t> Vectorized<int8_t>::eq(const Vectorized<int8_t>& other) const {
  return (*this == other) & Vectorized<int8_t>(1);
}

inline Vectorized<int8_t> Vectorized<int8_t>::ne(const Vectorized<int8_t>& other) const {
  return (*this != other) & Vectorized<int8_t>(1);
}

inline Vectorized<int8_t> Vectorized<int8_t>::gt(const Vectorized<int8_t>& other) const {
  return (*this > other) & Vectorized<int8_t>(1);
}

inline Vectorized<int8_t> Vectorized<int8_t>::ge(const Vectorized<int8_t>& other) const {
  return (*this >= other) & Vectorized<int8_t>(1);
}

inline Vectorized<int8_t> Vectorized<int8_t>::lt(const Vectorized<int8_t>& other) const {
  return (*this < other) & Vectorized<int8_t>(1);
}

inline Vectorized<int8_t> Vectorized<int8_t>::le(const Vectorized<int8_t>& other) const {
  return (*this <= other) & Vectorized<int8_t>(1);
}

inline Vectorized<uint8_t> Vectorized<uint8_t>::eq(const Vectorized<uint8_t>& other) const {
  return (*this == other) & Vectorized<uint8_t>(1);
}

inline Vectorized<uint8_t> Vectorized<uint8_t>::ne(const Vectorized<uint8_t>& other) const {
  return (*this != other) & Vectorized<uint8_t>(1);
}

inline Vectorized<uint8_t> Vectorized<uint8_t>::gt(const Vectorized<uint8_t>& other) const {
  return (*this > other) & Vectorized<uint8_t>(1);
}

inline Vectorized<uint8_t> Vectorized<uint8_t>::ge(const Vectorized<uint8_t>& other) const {
  return (*this >= other) & Vectorized<uint8_t>(1);
}

inline Vectorized<uint8_t> Vectorized<uint8_t>::lt(const Vectorized<uint8_t>& other) const {
  return (*this < other) & Vectorized<uint8_t>(1);
}

inline Vectorized<uint8_t> Vectorized<uint8_t>::le(const Vectorized<uint8_t>& other) const {
  return (*this <= other) & Vectorized<uint8_t>(1);
}
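
// Example (illustrative sketch; `a` and `b` stand for arbitrary
// Vectorized<int32_t> values): unlike operator==, which yields all-ones lane
// masks, eq()/ne()/gt()/ge()/lt()/le() normalize the result to 0 or 1 per lane
// by AND-ing the mask with Vectorized<T>(1).
//
//   auto m = (a == b);  // lanes are 0xFFFFFFFF or 0
//   auto n = a.eq(b);   // lanes are 1 or 0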

template <bool left_shift>
Vectorized<int16_t> inline shift_256_16(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
  // No vector instruction for shifting int16_t, so emulating it instead.

  // Control masks for shuffle operation, treating 256 bits as an
  // array of 16-bit elements, and considering pairs of neighboring
  // elements. Specifically, a mask named "ctl_M_N" (M,N in [0,1], and
  // M!=N) is set so that shuffle will move element with index M from
  // input pair into element with index N in output pair, and element
  // with index M in output pair will be set to all 0s.
  __m256i ctl_0_1 = _mm256_set_epi8(29, 28, 0x80, 0x80, 25, 24, 0x80, 0x80,
                                    21, 20, 0x80, 0x80, 17, 16, 0x80, 0x80,
                                    13, 12, 0x80, 0x80, 9, 8, 0x80, 0x80,
                                    5, 4, 0x80, 0x80, 1, 0, 0x80, 0x80);
  __m256i ctl_1_0 = _mm256_set_epi8(0x80, 0x80, 31, 30, 0x80, 0x80, 27, 26,
                                    0x80, 0x80, 23, 22, 0x80, 0x80, 19, 18,
                                    0x80, 0x80, 15, 14, 0x80, 0x80, 11, 10,
                                    0x80, 0x80, 7, 6, 0x80, 0x80, 3, 2);

  // Masks for bitwise and operation, treating 256 bits as an array of
  // 16-bit elements, and considering them in pairs of neighboring
  // elements. A mask named "keep_M" (M in [0,1]) is set so that
  // bitwise and will copy element with index M from input pair into
  // element with the same index in output pair, while the other
  // element in output pair will be set to all 0s.
  __m256i keep_0 = _mm256_set1_epi32(0xFFFF);
  __m256i keep_1 = _mm256_set1_epi32(0xFFFF0000);

  // Take each 16-bit element with idx%2==0 from input array to be
  // shifted and extend it to 32 bits so that 0s are added to the
  // right. Then, perform shifting on this 32-bit number. Upper 16
  // bits will be proper result of shifting original 16-bit number, so
  // write them to result array, into the same position from which
  // corresponding input element is taken. Also, make sure that
  // result array elements with idx%2!=0 are set to all 0s.
  //
  // Note that the number of bits to shift is extended to 32 bits by
  // adding 0s to the left. That means this number is not properly
  // sign-extended for negative values. However, the number of bits to
  // shift is treated as an unsigned integer by respective shift
  // intrinsics anyway, so if negative then either with or without
  // proper sign extension, it will be interpreted as a number greater
  // than 32, and the shifting result will be the same.
  __m256i a0 = _mm256_shuffle_epi8(a, ctl_0_1);
  __m256i b0 = _mm256_and_si256(b, keep_0);
  __m256i c0;
  if (left_shift)
    c0 = _mm256_sllv_epi32(a0, b0);
  else
    c0 = _mm256_srav_epi32(a0, b0);
  c0 = _mm256_shuffle_epi8(c0, ctl_1_0);

  // Perform shifting the same way for input array elements with
  // idx%2==1.
  __m256i a1 = _mm256_and_si256(a, keep_1);
  __m256i b1 = _mm256_shuffle_epi8(b, ctl_1_0);
  __m256i c1;
  if (left_shift)
    c1 = _mm256_sllv_epi32(a1, b1);
  else
    c1 = _mm256_srav_epi32(a1, b1);
  c1 = _mm256_and_si256(c1, keep_1);

  // Merge partial results into the final result.
  __m256i c = _mm256_or_si256(c0, c1);

  return c;
}
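
// Example (illustrative sketch): per-lane variable shifts for int16_t are
// emulated via 32-bit shifts on the even/odd lane pairs, as implemented above.
//
//   auto x = Vectorized<int16_t>(4);
//   auto s = Vectorized<int16_t>(1);
//   auto y = shift_256_16<true>(x, s);  // every lane becomes 8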

template <bool left_shift, typename T, typename std::enable_if_t<std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value, int> = 0>
Vectorized<T> inline shift_256_8(const Vectorized<T>& a, const Vectorized<T>& b) {
  // No vector instruction for shifting int8_t/uint8_t, so emulating
  // it instead.

  // Control masks for shuffle operation, treating 256 bits as an
  // array of 8-bit elements, and considering quadruples of
  // neighboring elements. Specifically, a mask named "ctl_M_N" (M,N
  // in [0,1,2,3], and M!=N) is set so that shuffle will move element
  // with index M from input quadruple into element with index N in
  // output quadruple, and other elements in output quadruple will be
  // set to all 0s.
  __m256i ctl_0_3 = _mm256_set_epi8(28, 0x80, 0x80, 0x80, 24, 0x80, 0x80, 0x80,
                                    20, 0x80, 0x80, 0x80, 16, 0x80, 0x80, 0x80,
                                    12, 0x80, 0x80, 0x80, 8, 0x80, 0x80, 0x80,
                                    4, 0x80, 0x80, 0x80, 0, 0x80, 0x80, 0x80);
  __m256i ctl_1_0 = _mm256_set_epi8(0x80, 0x80, 0x80, 29, 0x80, 0x80, 0x80, 25,
                                    0x80, 0x80, 0x80, 21, 0x80, 0x80, 0x80, 17,
                                    0x80, 0x80, 0x80, 13, 0x80, 0x80, 0x80, 9,
                                    0x80, 0x80, 0x80, 5, 0x80, 0x80, 0x80, 1);
  __m256i ctl_1_3 = _mm256_set_epi8(29, 0x80, 0x80, 0x80, 25, 0x80, 0x80, 0x80,
                                    21, 0x80, 0x80, 0x80, 17, 0x80, 0x80, 0x80,
                                    13, 0x80, 0x80, 0x80, 9, 0x80, 0x80, 0x80,
                                    5, 0x80, 0x80, 0x80, 1, 0x80, 0x80, 0x80);
  __m256i ctl_2_0 = _mm256_set_epi8(0x80, 0x80, 0x80, 30, 0x80, 0x80, 0x80, 26,
                                    0x80, 0x80, 0x80, 22, 0x80, 0x80, 0x80, 18,
                                    0x80, 0x80, 0x80, 14, 0x80, 0x80, 0x80, 10,
                                    0x80, 0x80, 0x80, 6, 0x80, 0x80, 0x80, 2);
  __m256i ctl_2_3 = _mm256_set_epi8(30, 0x80, 0x80, 0x80, 26, 0x80, 0x80, 0x80,
                                    22, 0x80, 0x80, 0x80, 18, 0x80, 0x80, 0x80,
                                    14, 0x80, 0x80, 0x80, 10, 0x80, 0x80, 0x80,
                                    6, 0x80, 0x80, 0x80, 2, 0x80, 0x80, 0x80);
  __m256i ctl_3_0 = _mm256_set_epi8(0x80, 0x80, 0x80, 31, 0x80, 0x80, 0x80, 27,
                                    0x80, 0x80, 0x80, 23, 0x80, 0x80, 0x80, 19,
                                    0x80, 0x80, 0x80, 15, 0x80, 0x80, 0x80, 11,
                                    0x80, 0x80, 0x80, 7, 0x80, 0x80, 0x80, 3);
  __m256i ctl_3_1 = _mm256_set_epi8(0x80, 0x80, 31, 0x80, 0x80, 0x80, 27, 0x80,
                                    0x80, 0x80, 23, 0x80, 0x80, 0x80, 19, 0x80,
                                    0x80, 0x80, 15, 0x80, 0x80, 0x80, 11, 0x80,
                                    0x80, 0x80, 7, 0x80, 0x80, 0x80, 3, 0x80);
  __m256i ctl_3_2 = _mm256_set_epi8(0x80, 31, 0x80, 0x80, 0x80, 27, 0x80, 0x80,
                                    0x80, 23, 0x80, 0x80, 0x80, 19, 0x80, 0x80,
                                    0x80, 15, 0x80, 0x80, 0x80, 11, 0x80, 0x80,
                                    0x80, 7, 0x80, 0x80, 0x80, 3, 0x80, 0x80);

  // Masks for bitwise and operation, treating 256 bits as an array of
  // 8-bit elements, and considering them in quadruples of neighboring
  // elements. A mask named "keep_M" (M in [0,1,2,3]) is set so that
  // bitwise and will copy element with index M from input quadruple
  // into element with the same index in output quadruple, while the
  // other elements in output quadruple will be set to all 0s.
  __m256i keep_0 = _mm256_set1_epi32(0xFF);
  __m256i keep_3 = _mm256_set1_epi32(0xFF000000);

  // Take each 8-bit element with idx%4==0 from the input array to be
  // shifted and extend it to 32 bits so that 0s are added to the
  // right. Then, perform shifting on this 32-bit number. The upper 8
  // bits will be the proper result of shifting the original 8-bit
  // number, so write them to the result array, into the same position
  // from which the corresponding input element was taken. Also, make
  // sure that result array elements with idx%4!=0 are set to all 0s.
  //
  // Note that the number of bits to shift by is extended to 32 bits by
  // adding 0s to the left, which means it is not properly sign-extended
  // for negative values. However, the shift count is treated as an
  // unsigned integer by the respective shift intrinsics anyway, so with
  // or without proper sign extension a negative count will be
  // interpreted as a number greater than 32, and the shifting result
  // will be the same.
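  //
  // Worked example (added illustration): to arithmetically right-shift the
  // even-indexed int8_t element 0x84 (-124) by 2, the shuffle produces the
  // 32-bit lane 0x84000000 and the masked shift count is 0x00000002;
  // _mm256_srav_epi32 yields 0xE1000000, whose upper 8 bits (0xE1 == -31)
  // are the expected result. For uint8_t the logical _mm256_srlv_epi32 is
  // used instead, so 0x84 >> 2 gives 0x21 (33).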
  __m256i a0 = _mm256_shuffle_epi8(a, ctl_0_3);
  __m256i b0 = _mm256_and_si256(b, keep_0);
  __m256i c0;
  if (left_shift)
    c0 = _mm256_sllv_epi32(a0, b0);
  else
    if (std::is_same<T, int8_t>::value)
      c0 = _mm256_srav_epi32(a0, b0);
    else
      c0 = _mm256_srlv_epi32(a0, b0);
  c0 = _mm256_shuffle_epi8(c0, ctl_3_0);

  // Perform shifting the same way for input array elements with
  // idx%4==1.
  __m256i a1 = _mm256_shuffle_epi8(a, ctl_1_3);
  __m256i b1 = _mm256_shuffle_epi8(b, ctl_1_0);
  __m256i c1;
  if (left_shift)
    c1 = _mm256_sllv_epi32(a1, b1);
  else
    if (std::is_same<T, int8_t>::value)
      c1 = _mm256_srav_epi32(a1, b1);
    else
      c1 = _mm256_srlv_epi32(a1, b1);
  c1 = _mm256_shuffle_epi8(c1, ctl_3_1);

  // Perform shifting the same way for input array elements with
  // idx%4==2.
  __m256i a2 = _mm256_shuffle_epi8(a, ctl_2_3);
  __m256i b2 = _mm256_shuffle_epi8(b, ctl_2_0);
  __m256i c2;
  if (left_shift)
    c2 = _mm256_sllv_epi32(a2, b2);
  else
    if (std::is_same<T, int8_t>::value)
      c2 = _mm256_srav_epi32(a2, b2);
    else
      c2 = _mm256_srlv_epi32(a2, b2);
  c2 = _mm256_shuffle_epi8(c2, ctl_3_2);

  // Perform shifting the same way for input array elements with
  // idx%4==3.
  __m256i a3 = _mm256_and_si256(a, keep_3);
  __m256i b3 = _mm256_shuffle_epi8(b, ctl_3_0);
  __m256i c3;
  if (left_shift)
    c3 = _mm256_sllv_epi32(a3, b3);
  else
    if (std::is_same<T, int8_t>::value)
      c3 = _mm256_srav_epi32(a3, b3);
    else
      c3 = _mm256_srlv_epi32(a3, b3);
  c3 = _mm256_and_si256(c3, keep_3);

  // Merge partial results into the final result.
  __m256i c01 = _mm256_or_si256(c0, c1);
  __m256i c23 = _mm256_or_si256(c2, c3);
  __m256i c = _mm256_or_si256(c01, c23);

  return c;
}

template <>
Vectorized<int64_t> inline operator<<(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
  return _mm256_sllv_epi64(a, b);
}

template <>
Vectorized<int32_t> inline operator<<(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
  return _mm256_sllv_epi32(a, b);
}

template <>
Vectorized<int16_t> inline operator<<(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
  return shift_256_16<true>(a, b);
}

template <>
Vectorized<int8_t> inline operator<<(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) {
  return shift_256_8<true>(a, b);
}

template <>
Vectorized<uint8_t> inline operator<<(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& b) {
  return shift_256_8<true>(a, b);
}

template <>
Vectorized<int64_t> inline operator>>(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
  // No vector instruction for arithmetic right shifting of int64_t, so
  // emulating it instead.

  // Clamp the shift values such that values < 0 or > 64 are changed to 64,
  // which results in -1 for negative input and 0 for non-negative input.
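  //
  // For instance (added illustration), a shift count of -1 or 100 is
  // replaced by 64 here; _mm256_srlv_epi64 then produces 0, and the final
  // OR below yields all 1s for negative a and 0 otherwise.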
  __m256i zero = _mm256_set1_epi64x(0);
  __m256i max_shift = _mm256_set1_epi64x(64);
  __m256i mask = _mm256_or_si256(_mm256_cmpgt_epi64(zero, b), _mm256_cmpgt_epi64(b, max_shift));
  __m256i shift = _mm256_blendv_epi8(b, max_shift, mask);
  // Shift the number logically to the right, thus filling the most
  // significant bits with 0s. Then, replace these bits with the sign
  // bit.
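  //
  // Worked example (added illustration): for a = -8 and shift = 1, the
  // logical shift gives 0x7FFFFFFFFFFFFFFC, sign_bits is all 1s, and
  // sign_ext = (-1) << (64 - 1) = 0x8000000000000000; OR-ing the two
  // reproduces the arithmetic result 0xFFFFFFFFFFFFFFFC == -4.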
  __m256i sign_bits = _mm256_cmpgt_epi64(zero, a);
  __m256i sign_shift = _mm256_sub_epi64(max_shift, shift);
  __m256i sign_ext = _mm256_sllv_epi64(sign_bits, sign_shift);
  __m256i c = _mm256_srlv_epi64(a, shift);
  c = _mm256_or_si256(c, sign_ext);

  return c;
}

template <>
Vectorized<int32_t> inline operator>>(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
  return _mm256_srav_epi32(a, b);
}

template <>
Vectorized<int16_t> inline operator>>(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
  return shift_256_16<false>(a, b);
}

template <>
Vectorized<int8_t> inline operator>>(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) {
  return shift_256_8<false>(a, b);
}

template <>
Vectorized<uint8_t> inline operator>>(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& b) {
  return shift_256_8<false>(a, b);
}
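
// Minimal usage sketch (illustrative only; assumes AVX2 is enabled and the
// broadcast constructor provided by the integer specializations):
//
//   Vectorized<int16_t> vals(-16);            // all 16 lanes hold -16
//   Vectorized<int16_t> amounts(2);           // all lanes hold shift count 2
//   Vectorized<int16_t> l = vals << amounts;  // each lane becomes -64
//   Vectorized<int16_t> r = vals >> amounts;  // each lane becomes -4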

#endif // CPU_CAPABILITY_AVX2

}}}