/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]

#include <executorch/kernels/optimized/vec/intrinsics.h>
#include <executorch/kernels/optimized/vec/vec_base.h>

#include <algorithm>   // std::min, std::max (scalar fallbacks)
#include <cstring>     // std::memcpy
#include <functional>  // std::multiplies, std::divides
#include <iostream>
#include <type_traits> // std::is_same, std::enable_if_t, std::is_base_of

namespace executorch {
namespace vec {
inline namespace CPU_CAPABILITY {

#ifdef CPU_CAPABILITY_AVX2

struct Vectorizedi {
protected:
  __m256i values;

  static inline __m256i invert(const __m256i& v) {
    const auto ones = _mm256_set1_epi64x(-1);
    return _mm256_xor_si256(ones, v);
  }
public:
  Vectorizedi() {}
  Vectorizedi(__m256i v) : values(v) {}
  operator __m256i() const {
    return values;
  }
};

#else

struct Vectorizedi {};  // dummy definition to make Vectorizedi always defined

#endif // CPU_CAPABILITY_AVX2

#ifdef CPU_CAPABILITY_AVX2

template <>
class Vectorized<int64_t> : public Vectorizedi {
private:
  static const Vectorized<int64_t> ones;
public:
  using value_type = int64_t;
  using size_type = int;
  static constexpr size_type size() {
    return 4;
  }
  using Vectorizedi::Vectorizedi;
  Vectorized() {}
  Vectorized(int64_t v) { values = _mm256_set1_epi64x(v); }
  Vectorized(int64_t val1, int64_t val2, int64_t val3, int64_t val4) {
    values = _mm256_setr_epi64x(val1, val2, val3, val4);
  }
  template <int64_t mask>
  static Vectorized<int64_t> blend(Vectorized<int64_t> a, Vectorized<int64_t> b) {
    __at_align__ int64_t tmp_values[size()];
    a.store(tmp_values);
    if (mask & 0x01)
      tmp_values[0] = _mm256_extract_epi64(b.values, 0);
    if (mask & 0x02)
      tmp_values[1] = _mm256_extract_epi64(b.values, 1);
    if (mask & 0x04)
      tmp_values[2] = _mm256_extract_epi64(b.values, 2);
    if (mask & 0x08)
      tmp_values[3] = _mm256_extract_epi64(b.values, 3);
    return loadu(tmp_values);
  }
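  // Illustrative note on blend() above (not from the original header): bit i
  // of `mask` selects lane i from `b`, while clear bits keep the lane from
  // `a`. For example, a hypothetical blend<0b0101>(a, b) would produce
  // {b[0], a[1], b[2], a[3]}.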
  static Vectorized<int64_t> blendv(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b,
                                const Vectorized<int64_t>& mask) {
    return _mm256_blendv_epi8(a.values, b.values, mask.values);
  }
  template <typename step_t>
  static Vectorized<int64_t> arange(int64_t base = 0, step_t step = static_cast<step_t>(1)) {
    return Vectorized<int64_t>(base, base + step, base + 2 * step, base + 3 * step);
  }
  static Vectorized<int64_t>
  set(Vectorized<int64_t> a, Vectorized<int64_t> b, int64_t count = size()) {
    switch (count) {
      case 0:
        return a;
      case 1:
        return blend<1>(a, b);
      case 2:
        return blend<3>(a, b);
      case 3:
        return blend<7>(a, b);
    }
    return b;
  }
  static Vectorized<int64_t> loadu(const void* ptr) {
    return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
  }
  static Vectorized<int64_t> loadu(const void* ptr, int64_t count) {
    __at_align__ int64_t tmp_values[size()];
    // Ensure uninitialized memory does not change the output value. See
    // https://github.com/pytorch/pytorch/issues/32502 for more details. We do
    // not initialize arrays to zero using "={0}" because gcc would compile it
    // to two instructions while a loop would be compiled to one instruction.
    for (size_t i = 0; i < size(); ++i) {
      tmp_values[i] = 0;
    }
    std::memcpy(tmp_values, ptr, count * sizeof(int64_t));
    return loadu(tmp_values);
  }
  void store(void* ptr, int count = size()) const {
    if (count == size()) {
      // ptr need not be aligned here. See
      // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
    } else if (count > 0) {
      __at_align__ int64_t tmp_values[size()];
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
      std::memcpy(ptr, tmp_values, count * sizeof(int64_t));
    }
  }
  const int64_t& operator[](int idx) const  = delete;
  int64_t& operator[](int idx)  = delete;
  Vectorized<int64_t> abs() const {
    // AVX2 has no _mm256_abs_epi64, so use the sign mask: for negative lanes
    // the xor flips the bits and subtracting -1 adds 1 (two's complement).
    auto zero = _mm256_set1_epi64x(0);
    auto is_larger = _mm256_cmpgt_epi64(zero, values);
    auto inverse = _mm256_xor_si256(values, is_larger);
    return _mm256_sub_epi64(inverse, is_larger);
  }
  Vectorized<int64_t> real() const {
    return *this;
  }
  Vectorized<int64_t> imag() const {
    return _mm256_set1_epi64x(0);
  }
  Vectorized<int64_t> conj() const {
    return *this;
  }
  Vectorized<int64_t> neg() const;
  Vectorized<int64_t> operator==(const Vectorized<int64_t>& other) const {
    return _mm256_cmpeq_epi64(values, other.values);
  }
  Vectorized<int64_t> operator!=(const Vectorized<int64_t>& other) const {
    return invert(_mm256_cmpeq_epi64(values, other.values));
  }
  Vectorized<int64_t> operator<(const Vectorized<int64_t>& other) const {
    return _mm256_cmpgt_epi64(other.values, values);
  }
  Vectorized<int64_t> operator<=(const Vectorized<int64_t>& other) const {
    return invert(_mm256_cmpgt_epi64(values, other.values));
  }
  Vectorized<int64_t> operator>(const Vectorized<int64_t>& other) const {
    return _mm256_cmpgt_epi64(values, other.values);
  }
  Vectorized<int64_t> operator>=(const Vectorized<int64_t>& other) const {
    return invert(_mm256_cmpgt_epi64(other.values, values));
  }

  Vectorized<int64_t> eq(const Vectorized<int64_t>& other) const;
  Vectorized<int64_t> ne(const Vectorized<int64_t>& other) const;
  Vectorized<int64_t> gt(const Vectorized<int64_t>& other) const;
  Vectorized<int64_t> ge(const Vectorized<int64_t>& other) const;
  Vectorized<int64_t> lt(const Vectorized<int64_t>& other) const;
  Vectorized<int64_t> le(const Vectorized<int64_t>& other) const;
};
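
// Usage sketch (illustrative only; not part of the original header): adding a
// scalar to a hypothetical int64_t buffer `data` of length `n`, with the
// partial tail handled by the count-taking loadu()/store() overloads above.
//
//   int64_t i = 0;
//   const Vectorized<int64_t> ones(1);
//   for (; i + Vectorized<int64_t>::size() <= n; i += Vectorized<int64_t>::size()) {
//     auto v = Vectorized<int64_t>::loadu(data + i);
//     (v + ones).store(data + i);
//   }
//   if (i < n) {
//     auto v = Vectorized<int64_t>::loadu(data + i, n - i);
//     (v + ones).store(data + i, n - i);
//   }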

template <>
class Vectorized<int32_t> : public Vectorizedi {
private:
  static const Vectorized<int32_t> ones;
public:
  using value_type = int32_t;
  using size_type = int;
  static constexpr int size() {
    return 8;
  }
  using Vectorizedi::Vectorizedi;
  Vectorized() {}
  Vectorized(int32_t v) { values = _mm256_set1_epi32(v); }
  Vectorized(int32_t val1, int32_t val2, int32_t val3, int32_t val4,
         int32_t val5, int32_t val6, int32_t val7, int32_t val8) {
    values = _mm256_setr_epi32(val1, val2, val3, val4, val5, val6, val7, val8);
  }
  template <int64_t mask>
  static Vectorized<int32_t> blend(Vectorized<int32_t> a, Vectorized<int32_t> b) {
    return _mm256_blend_epi32(a, b, mask);
  }
  static Vectorized<int32_t> blendv(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b,
                                const Vectorized<int32_t>& mask) {
    return _mm256_blendv_epi8(a.values, b.values, mask.values);
  }
  template <typename step_t>
  static Vectorized<int32_t> arange(int32_t base = 0, step_t step = static_cast<step_t>(1)) {
    return Vectorized<int32_t>(
      base,            base +     step, base + 2 * step, base + 3 * step,
      base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step);
  }
  static Vectorized<int32_t>
  set(Vectorized<int32_t> a, Vectorized<int32_t> b, int32_t count = size()) {
    switch (count) {
      case 0:
        return a;
      case 1:
        return blend<1>(a, b);
      case 2:
        return blend<3>(a, b);
      case 3:
        return blend<7>(a, b);
      case 4:
        return blend<15>(a, b);
      case 5:
        return blend<31>(a, b);
      case 6:
        return blend<63>(a, b);
      case 7:
        return blend<127>(a, b);
    }
    return b;
  }
  static Vectorized<int32_t> loadu(const void* ptr) {
    return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
  }
  static Vectorized<int32_t> loadu(const void* ptr, int32_t count) {
    __at_align__ int32_t tmp_values[size()];
    // Ensure uninitialized memory does not change the output value. See
    // https://github.com/pytorch/pytorch/issues/32502 for more details. We do
    // not initialize arrays to zero using "={0}" because gcc would compile it
    // to two instructions while a loop would be compiled to one instruction.
    for (size_t i = 0; i < size(); ++i) {
      tmp_values[i] = 0;
    }
    std::memcpy(tmp_values, ptr, count * sizeof(int32_t));
    return loadu(tmp_values);
  }
  void store(void* ptr, int count = size()) const {
    if (count == size()) {
      // ptr need not be aligned here. See
      // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
    } else if (count > 0) {
      __at_align__ int32_t tmp_values[size()];
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
      std::memcpy(ptr, tmp_values, count * sizeof(int32_t));
    }
  }
  const int32_t& operator[](int idx) const  = delete;
  int32_t& operator[](int idx)  = delete;
  Vectorized<int32_t> abs() const {
    return _mm256_abs_epi32(values);
  }
  Vectorized<int32_t> real() const {
    return *this;
  }
  Vectorized<int32_t> imag() const {
    return _mm256_set1_epi32(0);
  }
  Vectorized<int32_t> conj() const {
    return *this;
  }
  Vectorized<int32_t> neg() const;
  Vectorized<int32_t> operator==(const Vectorized<int32_t>& other) const {
    return _mm256_cmpeq_epi32(values, other.values);
  }
  Vectorized<int32_t> operator!=(const Vectorized<int32_t>& other) const {
    return invert(_mm256_cmpeq_epi32(values, other.values));
  }
  Vectorized<int32_t> operator<(const Vectorized<int32_t>& other) const {
    return _mm256_cmpgt_epi32(other.values, values);
  }
  Vectorized<int32_t> operator<=(const Vectorized<int32_t>& other) const {
    return invert(_mm256_cmpgt_epi32(values, other.values));
  }
  Vectorized<int32_t> operator>(const Vectorized<int32_t>& other) const {
    return _mm256_cmpgt_epi32(values, other.values);
  }
  Vectorized<int32_t> operator>=(const Vectorized<int32_t>& other) const {
    return invert(_mm256_cmpgt_epi32(other.values, values));
  }
  Vectorized<int32_t> eq(const Vectorized<int32_t>& other) const;
  Vectorized<int32_t> ne(const Vectorized<int32_t>& other) const;
  Vectorized<int32_t> gt(const Vectorized<int32_t>& other) const;
  Vectorized<int32_t> ge(const Vectorized<int32_t>& other) const;
  Vectorized<int32_t> lt(const Vectorized<int32_t>& other) const;
  Vectorized<int32_t> le(const Vectorized<int32_t>& other) const;
};

template <>
inline void convert(const int32_t *src, float *dst, int64_t n) {
  int64_t i;
  // int32_t and float have the same size
#ifndef _MSC_VER
# pragma unroll
#endif
  for (i = 0; i <= (n - Vectorized<int32_t>::size()); i += Vectorized<int32_t>::size()) {
    auto input_vec = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + i));
    auto output_vec = _mm256_cvtepi32_ps(input_vec);
    _mm256_storeu_ps(reinterpret_cast<float*>(dst + i), output_vec);
  }
#ifndef _MSC_VER
# pragma unroll
#endif
  for (; i < n; i++) {
    dst[i] = static_cast<float>(src[i]);
  }
}

template <>
inline void convert(const int32_t *src, double *dst, int64_t n) {
  int64_t i;
  // int32_t is half the size of double
#ifndef _MSC_VER
# pragma unroll
#endif
  for (i = 0; i <= (n - Vectorized<double>::size()); i += Vectorized<double>::size()) {
    auto input_128_vec = _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i));
    auto output_vec = _mm256_cvtepi32_pd(input_128_vec);
    _mm256_storeu_pd(reinterpret_cast<double*>(dst + i), output_vec);
  }
#ifndef _MSC_VER
# pragma unroll
#endif
  for (; i < n; i++) {
    dst[i] = static_cast<double>(src[i]);
  }
}
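
// Usage sketch (illustrative only; not part of the original header): the
// specializations above convert a hypothetical int32_t buffer `src` of length
// `n` into a float (or double) buffer `dst`:
//
//   executorch::vec::convert(src, dst, n);  // dst must hold n elements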

template <>
class Vectorized<int16_t> : public Vectorizedi {
private:
  static const Vectorized<int16_t> ones;
public:
  using value_type = int16_t;
  using size_type = int;
  static constexpr int size() {
    return 16;
  }
  using Vectorizedi::Vectorizedi;
  Vectorized() {}
  Vectorized(int16_t v) { values = _mm256_set1_epi16(v); }
  Vectorized(int16_t val1, int16_t val2, int16_t val3, int16_t val4,
         int16_t val5, int16_t val6, int16_t val7, int16_t val8,
         int16_t val9, int16_t val10, int16_t val11, int16_t val12,
         int16_t val13, int16_t val14, int16_t val15, int16_t val16) {
    values = _mm256_setr_epi16(val1, val2, val3, val4, val5, val6, val7, val8,
                               val9, val10, val11, val12, val13, val14, val15, val16);
  }
  template <int64_t mask>
  static Vectorized<int16_t> blend(Vectorized<int16_t> a, Vectorized<int16_t> b) {
    __at_align__ int16_t tmp_values[size()];
    a.store(tmp_values);
    if (mask & 0x01)
      tmp_values[0] = _mm256_extract_epi16(b.values, 0);
    if (mask & 0x02)
      tmp_values[1] = _mm256_extract_epi16(b.values, 1);
    if (mask & 0x04)
      tmp_values[2] = _mm256_extract_epi16(b.values, 2);
    if (mask & 0x08)
      tmp_values[3] = _mm256_extract_epi16(b.values, 3);
    if (mask & 0x10)
      tmp_values[4] = _mm256_extract_epi16(b.values, 4);
    if (mask & 0x20)
      tmp_values[5] = _mm256_extract_epi16(b.values, 5);
    if (mask & 0x40)
      tmp_values[6] = _mm256_extract_epi16(b.values, 6);
    if (mask & 0x80)
      tmp_values[7] = _mm256_extract_epi16(b.values, 7);
    if (mask & 0x100)
      tmp_values[8] = _mm256_extract_epi16(b.values, 8);
    if (mask & 0x200)
      tmp_values[9] = _mm256_extract_epi16(b.values, 9);
    if (mask & 0x400)
      tmp_values[10] = _mm256_extract_epi16(b.values, 10);
    if (mask & 0x800)
      tmp_values[11] = _mm256_extract_epi16(b.values, 11);
    if (mask & 0x1000)
      tmp_values[12] = _mm256_extract_epi16(b.values, 12);
    if (mask & 0x2000)
      tmp_values[13] = _mm256_extract_epi16(b.values, 13);
    if (mask & 0x4000)
      tmp_values[14] = _mm256_extract_epi16(b.values, 14);
    if (mask & 0x8000)
      tmp_values[15] = _mm256_extract_epi16(b.values, 15);
    return loadu(tmp_values);
  }
  static Vectorized<int16_t> blendv(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b,
                                const Vectorized<int16_t>& mask) {
    return _mm256_blendv_epi8(a.values, b.values, mask.values);
  }
  template <typename step_t>
  static Vectorized<int16_t> arange(int16_t base = 0, step_t step = static_cast<step_t>(1)) {
    return Vectorized<int16_t>(
      base,             base +      step, base +  2 * step, base +  3 * step,
      base +  4 * step, base +  5 * step, base +  6 * step, base +  7 * step,
      base +  8 * step, base +  9 * step, base + 10 * step, base + 11 * step,
      base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step);
  }
  static Vectorized<int16_t>
  set(Vectorized<int16_t> a, Vectorized<int16_t> b, int16_t count = size()) {
    switch (count) {
      case 0:
        return a;
      case 1:
        return blend<1>(a, b);
      case 2:
        return blend<3>(a, b);
      case 3:
        return blend<7>(a, b);
      case 4:
        return blend<15>(a, b);
      case 5:
        return blend<31>(a, b);
      case 6:
        return blend<63>(a, b);
      case 7:
        return blend<127>(a, b);
      case 8:
        return blend<255>(a, b);
      case 9:
        return blend<511>(a, b);
      case 10:
        return blend<1023>(a, b);
      case 11:
        return blend<2047>(a, b);
      case 12:
        return blend<4095>(a, b);
      case 13:
        return blend<8191>(a, b);
      case 14:
        return blend<16383>(a, b);
      case 15:
        return blend<32767>(a, b);
    }
    return b;
  }
  static Vectorized<int16_t> loadu(const void* ptr) {
    return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
  }
  static Vectorized<int16_t> loadu(const void* ptr, int16_t count) {
    __at_align__ int16_t tmp_values[size()];
    // Ensure uninitialized memory does not change the output value. See
    // https://github.com/pytorch/pytorch/issues/32502 for more details. We do
    // not initialize arrays to zero using "={0}" because gcc would compile it
    // to two instructions while a loop would be compiled to one instruction.
    for (size_t i = 0; i < size(); ++i) {
      tmp_values[i] = 0;
    }
    std::memcpy(tmp_values, ptr, count * sizeof(int16_t));
    return loadu(tmp_values);
  }
  void store(void* ptr, int count = size()) const {
    if (count == size()) {
      // ptr need not be aligned here. See
      // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
    } else if (count > 0) {
      __at_align__ int16_t tmp_values[size()];
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
      std::memcpy(ptr, tmp_values, count * sizeof(int16_t));
    }
  }
  const int16_t& operator[](int idx) const  = delete;
  int16_t& operator[](int idx)  = delete;
  Vectorized<int16_t> abs() const {
    return _mm256_abs_epi16(values);
  }
  Vectorized<int16_t> real() const {
    return *this;
  }
  Vectorized<int16_t> imag() const {
    return _mm256_set1_epi16(0);
  }
  Vectorized<int16_t> conj() const {
    return *this;
  }
  Vectorized<int16_t> neg() const;
  Vectorized<int16_t> operator==(const Vectorized<int16_t>& other) const {
    return _mm256_cmpeq_epi16(values, other.values);
  }
  Vectorized<int16_t> operator!=(const Vectorized<int16_t>& other) const {
    return invert(_mm256_cmpeq_epi16(values, other.values));
  }
  Vectorized<int16_t> operator<(const Vectorized<int16_t>& other) const {
    return _mm256_cmpgt_epi16(other.values, values);
  }
  Vectorized<int16_t> operator<=(const Vectorized<int16_t>& other) const {
    return invert(_mm256_cmpgt_epi16(values, other.values));
  }
  Vectorized<int16_t> operator>(const Vectorized<int16_t>& other) const {
    return _mm256_cmpgt_epi16(values, other.values);
  }
  Vectorized<int16_t> operator>=(const Vectorized<int16_t>& other) const {
    return invert(_mm256_cmpgt_epi16(other.values, values));
  }

  Vectorized<int16_t> eq(const Vectorized<int16_t>& other) const;
  Vectorized<int16_t> ne(const Vectorized<int16_t>& other) const;
  Vectorized<int16_t> gt(const Vectorized<int16_t>& other) const;
  Vectorized<int16_t> ge(const Vectorized<int16_t>& other) const;
  Vectorized<int16_t> lt(const Vectorized<int16_t>& other) const;
  Vectorized<int16_t> le(const Vectorized<int16_t>& other) const;
};

template <typename T>
class Vectorized8 : public Vectorizedi {
  static_assert(
    std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value,
    "Only int8_t/uint8_t are supported");
protected:
  static const Vectorized<T> ones;
public:
  using value_type = T;
  using size_type = int;
  static constexpr int size() {
    return 32;
  }
  using Vectorizedi::Vectorizedi;
  Vectorized8() {}
  Vectorized8(T v) { values = _mm256_set1_epi8(v); }
  Vectorized8(T val1, T val2, T val3, T val4,
         T val5, T val6, T val7, T val8,
         T val9, T val10, T val11, T val12,
         T val13, T val14, T val15, T val16,
         T val17, T val18, T val19, T val20,
         T val21, T val22, T val23, T val24,
         T val25, T val26, T val27, T val28,
         T val29, T val30, T val31, T val32) {
    values = _mm256_setr_epi8(val1, val2, val3, val4, val5, val6, val7, val8,
                              val9, val10, val11, val12, val13, val14, val15, val16,
                              val17, val18, val19, val20, val21, val22, val23, val24,
                              val25, val26, val27, val28, val29, val30, val31, val32);
  }
  template <int64_t mask>
  static Vectorized<T> blend(Vectorized<T> a, Vectorized<T> b) {
    __at_align__ T tmp_values[size()];
    a.store(tmp_values);
    if (mask & 0x01)
      tmp_values[0] = _mm256_extract_epi8(b.values, 0);
    if (mask & 0x02)
      tmp_values[1] = _mm256_extract_epi8(b.values, 1);
    if (mask & 0x04)
      tmp_values[2] = _mm256_extract_epi8(b.values, 2);
    if (mask & 0x08)
      tmp_values[3] = _mm256_extract_epi8(b.values, 3);
    if (mask & 0x10)
      tmp_values[4] = _mm256_extract_epi8(b.values, 4);
    if (mask & 0x20)
      tmp_values[5] = _mm256_extract_epi8(b.values, 5);
    if (mask & 0x40)
      tmp_values[6] = _mm256_extract_epi8(b.values, 6);
    if (mask & 0x80)
      tmp_values[7] = _mm256_extract_epi8(b.values, 7);
    if (mask & 0x100)
      tmp_values[8] = _mm256_extract_epi8(b.values, 8);
    if (mask & 0x200)
      tmp_values[9] = _mm256_extract_epi8(b.values, 9);
    if (mask & 0x400)
      tmp_values[10] = _mm256_extract_epi8(b.values, 10);
    if (mask & 0x800)
      tmp_values[11] = _mm256_extract_epi8(b.values, 11);
    if (mask & 0x1000)
      tmp_values[12] = _mm256_extract_epi8(b.values, 12);
    if (mask & 0x2000)
      tmp_values[13] = _mm256_extract_epi8(b.values, 13);
    if (mask & 0x4000)
      tmp_values[14] = _mm256_extract_epi8(b.values, 14);
    if (mask & 0x8000)
      tmp_values[15] = _mm256_extract_epi8(b.values, 15);
    if (mask & 0x010000)
      tmp_values[16] = _mm256_extract_epi8(b.values, 16);
    if (mask & 0x020000)
      tmp_values[17] = _mm256_extract_epi8(b.values, 17);
    if (mask & 0x040000)
      tmp_values[18] = _mm256_extract_epi8(b.values, 18);
    if (mask & 0x080000)
      tmp_values[19] = _mm256_extract_epi8(b.values, 19);
    if (mask & 0x100000)
      tmp_values[20] = _mm256_extract_epi8(b.values, 20);
    if (mask & 0x200000)
      tmp_values[21] = _mm256_extract_epi8(b.values, 21);
    if (mask & 0x400000)
      tmp_values[22] = _mm256_extract_epi8(b.values, 22);
    if (mask & 0x800000)
      tmp_values[23] = _mm256_extract_epi8(b.values, 23);
    if (mask & 0x1000000)
      tmp_values[24] = _mm256_extract_epi8(b.values, 24);
    if (mask & 0x2000000)
      tmp_values[25] = _mm256_extract_epi8(b.values, 25);
    if (mask & 0x4000000)
      tmp_values[26] = _mm256_extract_epi8(b.values, 26);
    if (mask & 0x8000000)
      tmp_values[27] = _mm256_extract_epi8(b.values, 27);
    if (mask & 0x10000000)
      tmp_values[28] = _mm256_extract_epi8(b.values, 28);
    if (mask & 0x20000000)
      tmp_values[29] = _mm256_extract_epi8(b.values, 29);
    if (mask & 0x40000000)
      tmp_values[30] = _mm256_extract_epi8(b.values, 30);
    if (mask & 0x80000000)
      tmp_values[31] = _mm256_extract_epi8(b.values, 31);
    return loadu(tmp_values);
  }
  static Vectorized<T> blendv(const Vectorized<T>& a, const Vectorized<T>& b,
                               const Vectorized<T>& mask) {
    return _mm256_blendv_epi8(a.values, b.values, mask.values);
  }
  template <typename step_t>
  static Vectorized<T> arange(T base = 0, step_t step = static_cast<step_t>(1)) {
    return Vectorized<T>(
      base,             base +      step, base +  2 * step, base +  3 * step,
      base +  4 * step, base +  5 * step, base +  6 * step, base +  7 * step,
      base +  8 * step, base +  9 * step, base + 10 * step, base + 11 * step,
      base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step,
      base + 16 * step, base + 17 * step, base + 18 * step, base + 19 * step,
      base + 20 * step, base + 21 * step, base + 22 * step, base + 23 * step,
      base + 24 * step, base + 25 * step, base + 26 * step, base + 27 * step,
      base + 28 * step, base + 29 * step, base + 30 * step, base + 31 * step);
  }
  static Vectorized<T>
  set(Vectorized<T> a, Vectorized<T> b, T count = size()) {
    switch (count) {
      case 0:
        return a;
      case 1:
        return blend<0x1>(a, b);
      case 2:
        return blend<0x3>(a, b);
      case 3:
        return blend<0x7>(a, b);
      case 4:
        return blend<0xF>(a, b);
      case 5:
        return blend<0x1F>(a, b);
      case 6:
        return blend<0x3F>(a, b);
      case 7:
        return blend<0x7F>(a, b);
      case 8:
        return blend<0xFF>(a, b);
      case 9:
        return blend<0x1FF>(a, b);
      case 10:
        return blend<0x3FF>(a, b);
      case 11:
        return blend<0x7FF>(a, b);
      case 12:
        return blend<0xFFF>(a, b);
      case 13:
        return blend<0x1FFF>(a, b);
      case 14:
        return blend<0x3FFF>(a, b);
      case 15:
        return blend<0x7FFF>(a, b);
      case 16:
        return blend<0xFFFF>(a, b);
      case 17:
        return blend<0x1FFFF>(a, b);
      case 18:
        return blend<0x3FFFF>(a, b);
      case 19:
        return blend<0x7FFFF>(a, b);
      case 20:
        return blend<0xFFFFF>(a, b);
      case 21:
        return blend<0x1FFFFF>(a, b);
      case 22:
        return blend<0x3FFFFF>(a, b);
      case 23:
        return blend<0x7FFFFF>(a, b);
      case 24:
        return blend<0xFFFFFF>(a, b);
      case 25:
        return blend<0x1FFFFFF>(a, b);
      case 26:
        return blend<0x3FFFFFF>(a, b);
      case 27:
        return blend<0x7FFFFFF>(a, b);
      case 28:
        return blend<0xFFFFFFF>(a, b);
      case 29:
        return blend<0x1FFFFFFF>(a, b);
      case 30:
        return blend<0x3FFFFFFF>(a, b);
      case 31:
        return blend<0x7FFFFFFF>(a, b);
    }
    return b;
  }
  static Vectorized<T> loadu(const void* ptr) {
    return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
  }
  static Vectorized<T> loadu(const void* ptr, T count) {
    __at_align__ T tmp_values[size()];
    // Ensure uninitialized memory does not change the output value. See
    // https://github.com/pytorch/pytorch/issues/32502 for more details. We do
    // not initialize arrays to zero using "={0}" because gcc would compile it
    // to two instructions while a loop would be compiled to one instruction.
    for (size_t i = 0; i < size(); ++i) {
      tmp_values[i] = 0;
    }
    std::memcpy(tmp_values, ptr, count * sizeof(T));
    return loadu(tmp_values);
  }
  void store(void* ptr, int count = size()) const {
    if (count == size()) {
      // ptr need not be aligned here. See
      // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
    } else if (count > 0) {
      __at_align__ T tmp_values[size()];
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
      std::memcpy(ptr, tmp_values, count * sizeof(T));
    }
  }
  const T& operator[](int idx) const  = delete;
  T& operator[](int idx)  = delete;
  Vectorized<T> real() const {
    return *this;
  }
  Vectorized<T> imag() const {
    return _mm256_set1_epi8(0);
  }
  Vectorized<T> conj() const {
    return *this;
  }
};

template<>
class Vectorized<int8_t>: public Vectorized8<int8_t> {
public:
  using Vectorized8::Vectorized8;

  Vectorized<int8_t> neg() const;

  Vectorized<int8_t> abs() const {
    return _mm256_abs_epi8(values);
  }

  Vectorized<int8_t> operator==(const Vectorized<int8_t>& other) const {
    return _mm256_cmpeq_epi8(values, other.values);
  }
  Vectorized<int8_t> operator!=(const Vectorized<int8_t>& other) const {
    return invert(_mm256_cmpeq_epi8(values, other.values));
  }
  Vectorized<int8_t> operator<(const Vectorized<int8_t>& other) const {
    return _mm256_cmpgt_epi8(other.values, values);
  }
  Vectorized<int8_t> operator<=(const Vectorized<int8_t>& other) const {
    return invert(_mm256_cmpgt_epi8(values, other.values));
  }
  Vectorized<int8_t> operator>(const Vectorized<int8_t>& other) const {
    return other < *this;
  }
  Vectorized<int8_t> operator>=(const Vectorized<int8_t>& other) const {
    return other <= *this;
  }

  Vectorized<int8_t> eq(const Vectorized<int8_t>& other) const;
  Vectorized<int8_t> ne(const Vectorized<int8_t>& other) const;
  Vectorized<int8_t> gt(const Vectorized<int8_t>& other) const;
  Vectorized<int8_t> ge(const Vectorized<int8_t>& other) const;
  Vectorized<int8_t> lt(const Vectorized<int8_t>& other) const;
  Vectorized<int8_t> le(const Vectorized<int8_t>& other) const;
};

template<>
class Vectorized<uint8_t>: public Vectorized8<uint8_t> {
public:
  using Vectorized8::Vectorized8;

  Vectorized<uint8_t> neg() const;

  Vectorized<uint8_t> abs() const {
    return *this;
  }

  Vectorized<uint8_t> operator==(const Vectorized<uint8_t>& other) const {
    return _mm256_cmpeq_epi8(values, other.values);
  }
  Vectorized<uint8_t> operator!=(const Vectorized<uint8_t>& other) const {
    return invert(_mm256_cmpeq_epi8(values, other.values));
  }
  Vectorized<uint8_t> operator<(const Vectorized<uint8_t>& other) const {
    __m256i max = _mm256_max_epu8(values, other.values);
    return invert(_mm256_cmpeq_epi8(max, values));
  }
  Vectorized<uint8_t> operator<=(const Vectorized<uint8_t>& other) const {
    __m256i max = _mm256_max_epu8(values, other.values);
    return _mm256_cmpeq_epi8(max, other.values);
  }
  Vectorized<uint8_t> operator>(const Vectorized<uint8_t>& other) const {
    return other < *this;
  }
  Vectorized<uint8_t> operator>=(const Vectorized<uint8_t>& other) const {
    return other <= *this;
  }

  Vectorized<uint8_t> eq(const Vectorized<uint8_t>& other) const;
  Vectorized<uint8_t> ne(const Vectorized<uint8_t>& other) const;
  Vectorized<uint8_t> gt(const Vectorized<uint8_t>& other) const;
  Vectorized<uint8_t> ge(const Vectorized<uint8_t>& other) const;
  Vectorized<uint8_t> lt(const Vectorized<uint8_t>& other) const;
  Vectorized<uint8_t> le(const Vectorized<uint8_t>& other) const;
};

template <>
Vectorized<int64_t> inline operator+(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
  return _mm256_add_epi64(a, b);
}

template <>
Vectorized<int32_t> inline operator+(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
  return _mm256_add_epi32(a, b);
}

template <>
Vectorized<int16_t> inline operator+(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
  return _mm256_add_epi16(a, b);
}

template <>
Vectorized<int8_t> inline operator+(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) {
  return _mm256_add_epi8(a, b);
}

template <>
Vectorized<uint8_t> inline operator+(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& b) {
  return _mm256_add_epi8(a, b);
}

template <>
Vectorized<int64_t> inline operator-(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
  return _mm256_sub_epi64(a, b);
}

template <>
Vectorized<int32_t> inline operator-(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
  return _mm256_sub_epi32(a, b);
}

template <>
Vectorized<int16_t> inline operator-(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
  return _mm256_sub_epi16(a, b);
}

template <>
Vectorized<int8_t> inline operator-(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) {
  return _mm256_sub_epi8(a, b);
}

template <>
Vectorized<uint8_t> inline operator-(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& b) {
  return _mm256_sub_epi8(a, b);
}

// Negation. Defined here so we can utilize operator-
inline Vectorized<int64_t> Vectorized<int64_t>::neg() const {
  return Vectorized<int64_t>(0) - *this;
}

inline Vectorized<int32_t> Vectorized<int32_t>::neg() const {
  return Vectorized<int32_t>(0) - *this;
}

inline Vectorized<int16_t> Vectorized<int16_t>::neg() const {
  return Vectorized<int16_t>(0) - *this;
}

inline Vectorized<int8_t> Vectorized<int8_t>::neg() const {
  return Vectorized<int8_t>(0) - *this;
}

inline Vectorized<uint8_t> Vectorized<uint8_t>::neg() const {
  return Vectorized<uint8_t>(0) - *this;
}

// Emulate operations with no native 64-bit support in AVX by extracting each
// element, performing the operation pointwise, then combining the results
// into a vector.
template <typename op_t>
Vectorized<int64_t> inline emulate(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b, const op_t& op) {
  int64_t a0 = _mm256_extract_epi64(a, 0);
  int64_t a1 = _mm256_extract_epi64(a, 1);
  int64_t a2 = _mm256_extract_epi64(a, 2);
  int64_t a3 = _mm256_extract_epi64(a, 3);

  int64_t b0 = _mm256_extract_epi64(b, 0);
  int64_t b1 = _mm256_extract_epi64(b, 1);
  int64_t b2 = _mm256_extract_epi64(b, 2);
  int64_t b3 = _mm256_extract_epi64(b, 3);

  int64_t c0 = op(a0, b0);
  int64_t c1 = op(a1, b1);
  int64_t c2 = op(a2, b2);
  int64_t c3 = op(a3, b3);

  return _mm256_set_epi64x(c3, c2, c1, c0);
}
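
// Usage sketch (illustrative only; not part of the original header): emulate()
// expresses any lanewise int64_t binary op as a scalar lambda, e.g. a
// hypothetical absolute difference:
//
//   Vectorized<int64_t> absdiff = emulate(a, b, [](int64_t x, int64_t y) {
//     return x > y ? x - y : y - x;
//   });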

template <typename op_t>
Vectorized<int64_t> inline emulate(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b, const Vectorized<int64_t>& c, const op_t& op) {
  int64_t a0 = _mm256_extract_epi64(a, 0);
  int64_t a1 = _mm256_extract_epi64(a, 1);
  int64_t a2 = _mm256_extract_epi64(a, 2);
  int64_t a3 = _mm256_extract_epi64(a, 3);

  int64_t b0 = _mm256_extract_epi64(b, 0);
  int64_t b1 = _mm256_extract_epi64(b, 1);
  int64_t b2 = _mm256_extract_epi64(b, 2);
  int64_t b3 = _mm256_extract_epi64(b, 3);

  int64_t c0 = _mm256_extract_epi64(c, 0);
  int64_t c1 = _mm256_extract_epi64(c, 1);
  int64_t c2 = _mm256_extract_epi64(c, 2);
  int64_t c3 = _mm256_extract_epi64(c, 3);

  int64_t d0 = op(a0, b0, c0);
  int64_t d1 = op(a1, b1, c1);
  int64_t d2 = op(a2, b2, c2);
  int64_t d3 = op(a3, b3, c3);

  return _mm256_set_epi64x(d3, d2, d1, d0);
}

// AVX2 has no intrinsic for int64_t multiply, so it needs to be emulated.
// This could be implemented more efficiently using epi32 instructions.
// This is also technically AVX compatible, but then we'll need AVX code for
// add as well.
// Note: intentionally ignores undefined behavior like (-lowest * -1).
template <>
Vectorized<int64_t> inline operator*(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
  return emulate(a, b, [](int64_t a_point, int64_t b_point) {return a_point * b_point;});
}

template <>
Vectorized<int32_t> inline operator*(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
  return _mm256_mullo_epi32(a, b);
}

template <>
Vectorized<int16_t> inline operator*(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
  return _mm256_mullo_epi16(a, b);
}

template <typename T, typename Op>
Vectorized<T> inline int_elementwise_binary_256(const Vectorized<T>& a, const Vectorized<T>& b, Op op) {
  T values_a[Vectorized<T>::size()];
  T values_b[Vectorized<T>::size()];
  a.store(values_a);
  b.store(values_b);
  for (size_t i = 0; i != Vectorized<T>::size(); i++) {
    values_a[i] = op(values_a[i], values_b[i]);
  }
  return Vectorized<T>::loadu(values_a);
}
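
// Usage sketch (illustrative only; not part of the original header): the same
// helper works for any scalar functor, e.g. a hypothetical lanewise modulo:
//
//   Vectorized<int32_t> rem =
//       int_elementwise_binary_256(a, b, std::modulus<int32_t>());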

template <>
Vectorized<int8_t> inline operator*(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) {
  // We don't have an instruction for multiplying int8_t
  return int_elementwise_binary_256(a, b, std::multiplies<int8_t>());
}

template <>
Vectorized<uint8_t> inline operator*(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& b) {
  // We don't have an instruction for multiplying uint8_t
  return int_elementwise_binary_256(a, b, std::multiplies<uint8_t>());
}

template <>
Vectorized<int64_t> inline minimum(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
  return emulate(a, b, [](int64_t a_point, int64_t b_point) {return std::min(a_point, b_point);});
}

template <>
Vectorized<int32_t> inline minimum(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
  return _mm256_min_epi32(a, b);
}

template <>
Vectorized<int16_t> inline minimum(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
  return _mm256_min_epi16(a, b);
}

template <>
Vectorized<int8_t> inline minimum(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) {
  return _mm256_min_epi8(a, b);
}

template <>
Vectorized<uint8_t> inline minimum(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& b) {
  return _mm256_min_epu8(a, b);
}

template <>
Vectorized<int64_t> inline maximum(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
  return emulate(a, b, [](int64_t a_point, int64_t b_point) {return std::max(a_point, b_point);});
}

template <>
Vectorized<int32_t> inline maximum(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
  return _mm256_max_epi32(a, b);
}

template <>
Vectorized<int16_t> inline maximum(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
  return _mm256_max_epi16(a, b);
}

template <>
Vectorized<int8_t> inline maximum(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) {
  return _mm256_max_epi8(a, b);
}

template <>
Vectorized<uint8_t> inline maximum(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& b) {
  return _mm256_max_epu8(a, b);
}

template <>
Vectorized<int64_t> inline clamp(const Vectorized<int64_t>& a, const Vectorized<int64_t>& min_val, const Vectorized<int64_t>& max_val) {
  return emulate(a, min_val, max_val, [](int64_t a_point, int64_t min_point, int64_t max_point) {return std::min(max_point, std::max(a_point, min_point));});
}

template <>
Vectorized<int32_t> inline clamp(const Vectorized<int32_t>& a, const Vectorized<int32_t>& min_val, const Vectorized<int32_t>& max_val) {
  return _mm256_min_epi32(max_val, _mm256_max_epi32(a, min_val));
}

template <>
Vectorized<int16_t> inline clamp(const Vectorized<int16_t>& a, const Vectorized<int16_t>& min_val, const Vectorized<int16_t>& max_val) {
  return _mm256_min_epi16(max_val, _mm256_max_epi16(a, min_val));
}

template <>
Vectorized<int8_t> inline clamp(const Vectorized<int8_t>& a, const Vectorized<int8_t>& min_val, const Vectorized<int8_t>& max_val) {
  return _mm256_min_epi8(max_val, _mm256_max_epi8(a, min_val));
}

template <>
Vectorized<uint8_t> inline clamp(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& min_val, const Vectorized<uint8_t>& max_val) {
  return _mm256_min_epu8(max_val, _mm256_max_epu8(a, min_val));
}
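
// Usage sketch (illustrative only; not part of the original header): clamping
// a hypothetical int32_t buffer `data` of length `n` to [lo, hi], with a
// scalar tail for the remainder.
//
//   const Vectorized<int32_t> vlo(lo), vhi(hi);
//   int64_t i = 0;
//   for (; i + Vectorized<int32_t>::size() <= n; i += Vectorized<int32_t>::size()) {
//     auto v = Vectorized<int32_t>::loadu(data + i);
//     clamp(v, vlo, vhi).store(data + i);
//   }
//   for (; i < n; ++i) {
//     data[i] = std::min(hi, std::max(data[i], lo));
//   }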

template <>
Vectorized<int64_t> inline clamp_max(const Vectorized<int64_t>& a, const Vectorized<int64_t>& max_val) {
  return emulate(a, max_val, [](int64_t a_point, int64_t max_point) {return std::min(max_point, a_point);});
}

template <>
Vectorized<int32_t> inline clamp_max(const Vectorized<int32_t>& a, const Vectorized<int32_t>& max_val) {
  return _mm256_min_epi32(max_val, a);
}

template <>
Vectorized<int16_t> inline clamp_max(const Vectorized<int16_t>& a, const Vectorized<int16_t>& max_val) {
  return _mm256_min_epi16(max_val, a);
}

template <>
Vectorized<int8_t> inline clamp_max(const Vectorized<int8_t>& a, const Vectorized<int8_t>& max_val) {
  return _mm256_min_epi8(max_val, a);
}

template <>
Vectorized<uint8_t> inline clamp_max(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& max_val) {
  return _mm256_min_epu8(max_val, a);
}

template <>
Vectorized<int64_t> inline clamp_min(const Vectorized<int64_t>& a, const Vectorized<int64_t>& min_val) {
  return emulate(a, min_val, [](int64_t a_point, int64_t min_point) {return std::max(min_point, a_point);});
}

template <>
Vectorized<int32_t> inline clamp_min(const Vectorized<int32_t>& a, const Vectorized<int32_t>& min_val) {
  return _mm256_max_epi32(min_val, a);
}

template <>
Vectorized<int16_t> inline clamp_min(const Vectorized<int16_t>& a, const Vectorized<int16_t>& min_val) {
  return _mm256_max_epi16(min_val, a);
}

template <>
Vectorized<int8_t> inline clamp_min(const Vectorized<int8_t>& a, const Vectorized<int8_t>& min_val) {
  return _mm256_max_epi8(min_val, a);
}

template <>
Vectorized<uint8_t> inline clamp_min(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& min_val) {
  return _mm256_max_epu8(min_val, a);
}

template<typename T>
Vectorized<int32_t> inline convert_to_int32(const T* ptr) {
  return Vectorized<int32_t>::loadu(ptr);
}

template<>
Vectorized<int32_t> inline convert_to_int32<int8_t>(const int8_t* ptr) {
  return _mm256_cvtepi8_epi32(_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ptr)));
}

template<>
Vectorized<int32_t> inline convert_to_int32<uint8_t>(const uint8_t* ptr) {
  return _mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ptr)));
}
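
// Usage sketch (illustrative only; not part of the original header): widening
// the first eight bytes of a hypothetical uint8_t buffer `src` into the eight
// 32-bit lanes of a Vectorized<int32_t>:
//
//   Vectorized<int32_t> widened = convert_to_int32<uint8_t>(src);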

template <>
Vectorized<int64_t> inline operator/(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
  return int_elementwise_binary_256(a, b, std::divides<int64_t>());
}
template <>
Vectorized<int32_t> inline operator/(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
  return int_elementwise_binary_256(a, b, std::divides<int32_t>());
}
template <>
Vectorized<int16_t> inline operator/(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
  return int_elementwise_binary_256(a, b, std::divides<int16_t>());
}
template <>
Vectorized<int8_t> inline operator/(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) {
  return int_elementwise_binary_256(a, b, std::divides<int8_t>());
}
template <>
Vectorized<uint8_t> inline operator/(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& b) {
  return int_elementwise_binary_256(a, b, std::divides<uint8_t>());
}

template<class T, typename std::enable_if_t<std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
inline Vectorized<T> operator&(const Vectorized<T>& a, const Vectorized<T>& b) {
  return _mm256_and_si256(a, b);
}
template<class T, typename std::enable_if_t<std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
inline Vectorized<T> operator|(const Vectorized<T>& a, const Vectorized<T>& b) {
  return _mm256_or_si256(a, b);
}
template<class T, typename std::enable_if_t<std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
inline Vectorized<T> operator^(const Vectorized<T>& a, const Vectorized<T>& b) {
  return _mm256_xor_si256(a, b);
}
template<class T, typename std::enable_if_t<std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
inline Vectorized<T> operator~(const Vectorized<T>& a) {
  return _mm256_xor_si256(a, _mm256_set1_epi32(-1));
}
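
// Usage sketch (illustrative only; not part of the original header): the
// comparison operators above return all-ones/all-zeros lane masks, which
// compose with blendv() to implement a lanewise select, e.g. a hypothetical
// max via compare-and-blend:
//
//   auto mask = a > b;                                  // all-ones where a > b
//   auto mx = Vectorized<int32_t>::blendv(b, a, mask);  // pick a where mask set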

inline Vectorized<int64_t> Vectorized<int64_t>::eq(const Vectorized<int64_t>& other) const {
  return (*this == other) & Vectorized<int64_t>(1);
}

inline Vectorized<int64_t> Vectorized<int64_t>::ne(const Vectorized<int64_t>& other) const {
  return (*this != other) & Vectorized<int64_t>(1);
}

inline Vectorized<int64_t> Vectorized<int64_t>::gt(const Vectorized<int64_t>& other) const {
  return (*this > other) & Vectorized<int64_t>(1);
}

inline Vectorized<int64_t> Vectorized<int64_t>::ge(const Vectorized<int64_t>& other) const {
  return (*this >= other) & Vectorized<int64_t>(1);
}

inline Vectorized<int64_t> Vectorized<int64_t>::lt(const Vectorized<int64_t>& other) const {
  return (*this < other) & Vectorized<int64_t>(1);
}

inline Vectorized<int64_t> Vectorized<int64_t>::le(const Vectorized<int64_t>& other) const {
  return (*this <= other) & Vectorized<int64_t>(1);
}

inline Vectorized<int32_t> Vectorized<int32_t>::eq(const Vectorized<int32_t>& other) const {
  return (*this == other) & Vectorized<int32_t>(1);
}

inline Vectorized<int32_t> Vectorized<int32_t>::ne(const Vectorized<int32_t>& other) const {
  return (*this != other) & Vectorized<int32_t>(1);
}

inline Vectorized<int32_t> Vectorized<int32_t>::gt(const Vectorized<int32_t>& other) const {
  return (*this > other) & Vectorized<int32_t>(1);
}

inline Vectorized<int32_t> Vectorized<int32_t>::ge(const Vectorized<int32_t>& other) const {
  return (*this >= other) & Vectorized<int32_t>(1);
}

inline Vectorized<int32_t> Vectorized<int32_t>::lt(const Vectorized<int32_t>& other) const {
  return (*this < other) & Vectorized<int32_t>(1);
}

inline Vectorized<int32_t> Vectorized<int32_t>::le(const Vectorized<int32_t>& other) const {
  return (*this <= other) & Vectorized<int32_t>(1);
}

inline Vectorized<int16_t> Vectorized<int16_t>::eq(const Vectorized<int16_t>& other) const {
  return (*this == other) & Vectorized<int16_t>(1);
}

inline Vectorized<int16_t> Vectorized<int16_t>::ne(const Vectorized<int16_t>& other) const {
  return (*this != other) & Vectorized<int16_t>(1);
}

inline Vectorized<int16_t> Vectorized<int16_t>::gt(const Vectorized<int16_t>& other) const {
  return (*this > other) & Vectorized<int16_t>(1);
}

inline Vectorized<int16_t> Vectorized<int16_t>::ge(const Vectorized<int16_t>& other) const {
  return (*this >= other) & Vectorized<int16_t>(1);
}

inline Vectorized<int16_t> Vectorized<int16_t>::lt(const Vectorized<int16_t>& other) const {
  return (*this < other) & Vectorized<int16_t>(1);
}

inline Vectorized<int16_t> Vectorized<int16_t>::le(const Vectorized<int16_t>& other) const {
  return (*this <= other) & Vectorized<int16_t>(1);
}

inline Vectorized<int8_t> Vectorized<int8_t>::eq(const Vectorized<int8_t>& other) const {
  return (*this == other) & Vectorized<int8_t>(1);
}

inline Vectorized<int8_t> Vectorized<int8_t>::ne(const Vectorized<int8_t>& other) const {
  return (*this != other) & Vectorized<int8_t>(1);
}

inline Vectorized<int8_t> Vectorized<int8_t>::gt(const Vectorized<int8_t>& other) const {
  return (*this > other) & Vectorized<int8_t>(1);
}

inline Vectorized<int8_t> Vectorized<int8_t>::ge(const Vectorized<int8_t>& other) const {
  return (*this >= other) & Vectorized<int8_t>(1);
}

inline Vectorized<int8_t> Vectorized<int8_t>::lt(const Vectorized<int8_t>& other) const {
  return (*this < other) & Vectorized<int8_t>(1);
}

inline Vectorized<int8_t> Vectorized<int8_t>::le(const Vectorized<int8_t>& other) const {
  return (*this <= other) & Vectorized<int8_t>(1);
}

inline Vectorized<uint8_t> Vectorized<uint8_t>::eq(const Vectorized<uint8_t>& other) const {
  return (*this == other) & Vectorized<uint8_t>(1);
}

inline Vectorized<uint8_t> Vectorized<uint8_t>::ne(const Vectorized<uint8_t>& other) const {
  return (*this != other) & Vectorized<uint8_t>(1);
}

inline Vectorized<uint8_t> Vectorized<uint8_t>::gt(const Vectorized<uint8_t>& other) const {
  return (*this > other) & Vectorized<uint8_t>(1);
}

inline Vectorized<uint8_t> Vectorized<uint8_t>::ge(const Vectorized<uint8_t>& other) const {
  return (*this >= other) & Vectorized<uint8_t>(1);
}

inline Vectorized<uint8_t> Vectorized<uint8_t>::lt(const Vectorized<uint8_t>& other) const {
  return (*this < other) & Vectorized<uint8_t>(1);
}

inline Vectorized<uint8_t> Vectorized<uint8_t>::le(const Vectorized<uint8_t>& other) const {
  return (*this <= other) & Vectorized<uint8_t>(1);
}
1266 
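
// Note (illustrative, not part of the upstream header): unlike operator==,
// which yields all-ones lanes for "true", the eq/ne/gt/ge/lt/le helpers above
// normalize each lane to 0 or 1.  For example, assuming AVX2 is available:
//
//   Vectorized<int16_t> x(3), y(3);
//   Vectorized<int16_t> m = x.eq(y);  // every 16-bit lane holds 1
//   Vectorized<int16_t> n = x.gt(y);  // every 16-bit lane holds 0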

template <bool left_shift>
Vectorized<int16_t> inline shift_256_16(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
  // No vector instruction for shifting int16_t, so emulating it instead.

  // Control masks for shuffle operation, treating 256 bits as an
  // array of 16-bit elements, and considering pairs of neighboring
  // elements.  Specifically, a mask named "ctl_M_N" (M,N in [0,1], and
  // M!=N) is set so that shuffle will move element with index M from
  // input pair into element with index N in output pair, and element
  // with index M in output pair will be set to all 0s.
  __m256i ctl_0_1 = _mm256_set_epi8(29, 28, 0x80, 0x80, 25, 24, 0x80, 0x80,
                                    21, 20, 0x80, 0x80, 17, 16, 0x80, 0x80,
                                    13, 12, 0x80, 0x80, 9, 8, 0x80, 0x80,
                                    5, 4, 0x80, 0x80, 1, 0, 0x80, 0x80);
  __m256i ctl_1_0 = _mm256_set_epi8(0x80, 0x80, 31, 30, 0x80, 0x80, 27, 26,
                                    0x80, 0x80, 23, 22, 0x80, 0x80, 19, 18,
                                    0x80, 0x80, 15, 14, 0x80, 0x80, 11, 10,
                                    0x80, 0x80, 7, 6, 0x80, 0x80, 3, 2);

  // Masks for bitwise and operation, treating 256 bits as an array of
  // 16-bit elements, and considering them in pairs of neighboring
  // elements.  A mask named "keep_M" (M in [0,1]) is set so that
  // bitwise and will copy element with index M from input pair into
  // element with the same index in output pair, while the other
  // element in output pair will be set to all 0s.
  __m256i keep_0 = _mm256_set1_epi32(0xFFFF);
  __m256i keep_1 = _mm256_set1_epi32(0xFFFF0000);
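
  // For instance (illustrative note, not in the upstream comments):
  // _mm256_and_si256(v, keep_0) keeps the even-indexed 16-bit lanes of v and
  // zeroes the odd-indexed ones, while _mm256_and_si256(v, keep_1) does the
  // opposite.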

  // Take each 16-bit element with idx%2==0 from the input array to be
  // shifted and extend it to 32 bits so that 0s are added to the
  // right.  Then, perform shifting on this 32-bit number.  The upper 16
  // bits will be the proper result of shifting the original 16-bit number,
  // so write them to the result array, into the same position from which
  // the corresponding input element was taken.  Also, make sure that the
  // result array elements with idx%2!=0 are set to all 0s.
  //
  // Note that the number of bits to shift by is extended to 32 bits by
  // adding 0s to the left.  That means this number is not properly
  // sign-extended for negative values.  However, the shift count is
  // treated as an unsigned integer by the respective shift intrinsics
  // anyway, so whether or not it is properly sign-extended, a negative
  // count will be interpreted as a number greater than 32, and the
  // shifting result will be the same.
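  //
  // Illustrative example (assumption, not part of the upstream comment): for
  // a single even lane with a = 0x0005 and b = 3, the first shuffle produces
  // the 32-bit value 0x00050000; a left shift by 3 gives 0x00280000, and the
  // second shuffle moves the upper 16 bits (0x0028 == 5 << 3) back into the
  // even lane, leaving the odd lane zeroed.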
  __m256i a0 = _mm256_shuffle_epi8(a, ctl_0_1);
  __m256i b0 = _mm256_and_si256(b, keep_0);
  __m256i c0;
  if (left_shift)
    c0 = _mm256_sllv_epi32(a0, b0);
  else
    c0 = _mm256_srav_epi32(a0, b0);
  c0 = _mm256_shuffle_epi8(c0, ctl_1_0);

  // Perform shifting the same way for input array elements with
  // idx%2==1.
  __m256i a1 = _mm256_and_si256(a, keep_1);
  __m256i b1 = _mm256_shuffle_epi8(b, ctl_1_0);
  __m256i c1;
  if (left_shift)
    c1 = _mm256_sllv_epi32(a1, b1);
  else
    c1 = _mm256_srav_epi32(a1, b1);
  c1 = _mm256_and_si256(c1, keep_1);

  // Merge partial results into the final result.
  __m256i c = _mm256_or_si256(c0, c1);

  return c;
}

template <bool left_shift, typename T, typename std::enable_if_t<std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value, int> = 0>
Vectorized<T> inline shift_256_8(const Vectorized<T>& a, const Vectorized<T>& b) {
  // No vector instruction for shifting int8_t/uint8_t, so emulating
  // it instead.

  // Control masks for shuffle operation, treating 256 bits as an
  // array of 8-bit elements, and considering quadruples of
  // neighboring elements.  Specifically, a mask named "ctl_M_N" (M,N
  // in [0,1,2,3], and M!=N) is set so that shuffle will move element
  // with index M from input quadruple into element with index N in
  // output quadruple, and other elements in output quadruple will be
  // set to all 0s.
  __m256i ctl_0_3 = _mm256_set_epi8(28, 0x80, 0x80, 0x80, 24, 0x80, 0x80, 0x80,
                                    20, 0x80, 0x80, 0x80, 16, 0x80, 0x80, 0x80,
                                    12, 0x80, 0x80, 0x80, 8, 0x80, 0x80, 0x80,
                                    4, 0x80, 0x80, 0x80, 0, 0x80, 0x80, 0x80);
  __m256i ctl_1_0 = _mm256_set_epi8(0x80, 0x80, 0x80, 29, 0x80, 0x80, 0x80, 25,
                                    0x80, 0x80, 0x80, 21, 0x80, 0x80, 0x80, 17,
                                    0x80, 0x80, 0x80, 13, 0x80, 0x80, 0x80, 9,
                                    0x80, 0x80, 0x80, 5, 0x80, 0x80, 0x80, 1);
  __m256i ctl_1_3 = _mm256_set_epi8(29, 0x80, 0x80, 0x80, 25, 0x80, 0x80, 0x80,
                                    21, 0x80, 0x80, 0x80, 17, 0x80, 0x80, 0x80,
                                    13, 0x80, 0x80, 0x80, 9, 0x80, 0x80, 0x80,
                                    5, 0x80, 0x80, 0x80, 1, 0x80, 0x80, 0x80);
  __m256i ctl_2_0 = _mm256_set_epi8(0x80, 0x80, 0x80, 30, 0x80, 0x80, 0x80, 26,
                                    0x80, 0x80, 0x80, 22, 0x80, 0x80, 0x80, 18,
                                    0x80, 0x80, 0x80, 14, 0x80, 0x80, 0x80, 10,
                                    0x80, 0x80, 0x80, 6, 0x80, 0x80, 0x80, 2);
  __m256i ctl_2_3 = _mm256_set_epi8(30, 0x80, 0x80, 0x80, 26, 0x80, 0x80, 0x80,
                                    22, 0x80, 0x80, 0x80, 18, 0x80, 0x80, 0x80,
                                    14, 0x80, 0x80, 0x80, 10, 0x80, 0x80, 0x80,
                                    6, 0x80, 0x80, 0x80, 2, 0x80, 0x80, 0x80);
  __m256i ctl_3_0 = _mm256_set_epi8(0x80, 0x80, 0x80, 31, 0x80, 0x80, 0x80, 27,
                                    0x80, 0x80, 0x80, 23, 0x80, 0x80, 0x80, 19,
                                    0x80, 0x80, 0x80, 15, 0x80, 0x80, 0x80, 11,
                                    0x80, 0x80, 0x80, 7, 0x80, 0x80, 0x80, 3);
  __m256i ctl_3_1 = _mm256_set_epi8(0x80, 0x80, 31, 0x80, 0x80, 0x80, 27, 0x80,
                                    0x80, 0x80, 23, 0x80, 0x80, 0x80, 19, 0x80,
                                    0x80, 0x80, 15, 0x80, 0x80, 0x80, 11, 0x80,
                                    0x80, 0x80, 7, 0x80, 0x80, 0x80, 3, 0x80);
  __m256i ctl_3_2 = _mm256_set_epi8(0x80, 31, 0x80, 0x80, 0x80, 27, 0x80, 0x80,
                                    0x80, 23, 0x80, 0x80, 0x80, 19, 0x80, 0x80,
                                    0x80, 15, 0x80, 0x80, 0x80, 11, 0x80, 0x80,
                                    0x80, 7, 0x80, 0x80, 0x80, 3, 0x80, 0x80);

  // Masks for bitwise and operation, treating 256 bits as an array of
  // 8-bit elements, and considering them in quadruples of neighboring
  // elements.  A mask named "keep_M" (M in [0,1,2,3]) is set so that
  // bitwise and will copy element with index M from input quadruple
  // into element with the same index in output quadruple, while the
  // other elements in output quadruple will be set to all 0s.
  __m256i keep_0 = _mm256_set1_epi32(0xFF);
  __m256i keep_3 = _mm256_set1_epi32(0xFF000000);

  // Take each 8-bit element with idx%4==0 from the input array to be
  // shifted and extend it to 32 bits so that 0s are added to the
  // right.  Then, perform shifting on this 32-bit number.  The upper 8
  // bits will be the proper result of shifting the original 8-bit number,
  // so write them to the result array, into the same position from which
  // the corresponding input element was taken.  Also, make sure that the
  // result array elements with idx%4!=0 are set to all 0s.
  //
  // Note that the number of bits to shift by is extended to 32 bits by
  // adding 0s to the left.  That means this number is not properly
  // sign-extended for negative values.  However, the shift count is
  // treated as an unsigned integer by the respective shift intrinsics
  // anyway, so whether or not it is properly sign-extended, a negative
  // count will be interpreted as a number greater than 32, and the
  // shifting result will be the same.
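  //
  // Illustrative example (assumption, not part of the upstream comment): for
  // a lane with idx%4==0 holding a = 0x05 and b = 3, the first shuffle
  // produces the 32-bit value 0x05000000; a left shift by 3 gives 0x28000000,
  // and the final shuffle moves the top byte (0x28 == 5 << 3) back into the
  // idx%4==0 position, zeroing the other three bytes.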
  __m256i a0 = _mm256_shuffle_epi8(a, ctl_0_3);
  __m256i b0 = _mm256_and_si256(b, keep_0);
  __m256i c0;
  if (left_shift)
    c0 = _mm256_sllv_epi32(a0, b0);
  else
    if (std::is_same<T, int8_t>::value)
      c0 = _mm256_srav_epi32(a0, b0);
    else
      c0 = _mm256_srlv_epi32(a0, b0);
  c0 = _mm256_shuffle_epi8(c0, ctl_3_0);

  // Perform shifting the same way for input array elements with
  // idx%4==1.
  __m256i a1 = _mm256_shuffle_epi8(a, ctl_1_3);
  __m256i b1 = _mm256_shuffle_epi8(b, ctl_1_0);
  __m256i c1;
  if (left_shift)
    c1 = _mm256_sllv_epi32(a1, b1);
  else
    if (std::is_same<T, int8_t>::value)
      c1 = _mm256_srav_epi32(a1, b1);
    else
      c1 = _mm256_srlv_epi32(a1, b1);
  c1 = _mm256_shuffle_epi8(c1, ctl_3_1);

  // Perform shifting the same way for input array elements with
  // idx%4==2.
  __m256i a2 = _mm256_shuffle_epi8(a, ctl_2_3);
  __m256i b2 = _mm256_shuffle_epi8(b, ctl_2_0);
  __m256i c2;
  if (left_shift)
    c2 = _mm256_sllv_epi32(a2, b2);
  else
    if (std::is_same<T, int8_t>::value)
      c2 = _mm256_srav_epi32(a2, b2);
    else
      c2 = _mm256_srlv_epi32(a2, b2);
  c2 = _mm256_shuffle_epi8(c2, ctl_3_2);

  // Perform shifting the same way for input array elements with
  // idx%4==3.
  __m256i a3 = _mm256_and_si256(a, keep_3);
  __m256i b3 = _mm256_shuffle_epi8(b, ctl_3_0);
  __m256i c3;
  if (left_shift)
    c3 = _mm256_sllv_epi32(a3, b3);
  else
    if (std::is_same<T, int8_t>::value)
      c3 = _mm256_srav_epi32(a3, b3);
    else
      c3 = _mm256_srlv_epi32(a3, b3);
  c3 = _mm256_and_si256(c3, keep_3);

  // Merge partial results into the final result.
  __m256i c01 = _mm256_or_si256(c0, c1);
  __m256i c23 = _mm256_or_si256(c2, c3);
  __m256i c = _mm256_or_si256(c01, c23);

  return c;
}

template <>
Vectorized<int64_t> inline operator<<(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
  return _mm256_sllv_epi64(a, b);
}

template <>
Vectorized<int32_t> inline operator<<(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
  return _mm256_sllv_epi32(a, b);
}

template <>
Vectorized<int16_t> inline operator<<(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
  return shift_256_16<true>(a, b);
}

template <>
Vectorized<int8_t> inline operator<<(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) {
  return shift_256_8<true>(a, b);
}

template <>
Vectorized<uint8_t> inline operator<<(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& b) {
  return shift_256_8<true>(a, b);
}
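
// Usage sketch (illustrative, not part of the upstream header): with the shift
// specializations above and below, per-lane variable shifts are written as
// ordinary operator expressions, e.g.
//
//   Vectorized<int16_t> vals(4);             // every 16-bit lane holds 4
//   Vectorized<int16_t> amts(3);             // shift each lane by 3
//   Vectorized<int16_t> res = vals << amts;  // every lane holds 32 (shift_256_16<true>)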

template <>
Vectorized<int64_t> inline operator>>(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
  // No vector instruction for right arithmetic shifting int64_t, so emulating it
  // instead.

  // Clamp the shift values such that values < 0 or > 64 are changed to 64,
  // which results in -1 for negative input and 0 for non-negative input.
  __m256i zero = _mm256_set1_epi64x(0);
  __m256i max_shift = _mm256_set1_epi64x(64);
  __m256i mask = _mm256_or_si256(_mm256_cmpgt_epi64(zero, b), _mm256_cmpgt_epi64(b, max_shift));
  __m256i shift = _mm256_blendv_epi8(b, max_shift, mask);
  // Shift the number logically to the right, thus filling the most
  // significant bits with 0s.  Then, replace these bits with the sign
  // bit.
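  //
  // Worked example (illustrative, not part of the upstream comment): for
  // a = -8 (0xFFFFFFFFFFFFFFF8) and b = 1, the logical shift below yields
  // 0x7FFFFFFFFFFFFFFC, sign_bits is all ones, and sign_ext becomes
  // -1 << 63 == 0x8000000000000000; OR-ing the two gives 0xFFFFFFFFFFFFFFFC,
  // i.e. -4, the expected arithmetic result of -8 >> 1.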
  __m256i sign_bits = _mm256_cmpgt_epi64(zero, a);
  __m256i sign_shift = _mm256_sub_epi64(max_shift, shift);
  __m256i sign_ext = _mm256_sllv_epi64(sign_bits, sign_shift);
  __m256i c = _mm256_srlv_epi64(a, shift);
  c = _mm256_or_si256(c, sign_ext);

  return c;
}

template <>
Vectorized<int32_t> inline operator>>(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
  return _mm256_srav_epi32(a, b);
}

template <>
Vectorized<int16_t> inline operator>>(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
  return shift_256_16<false>(a, b);
}

template <>
Vectorized<int8_t> inline operator>>(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) {
  return shift_256_8<false>(a, b);
}

template <>
Vectorized<uint8_t> inline operator>>(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& b) {
  return shift_256_8<false>(a, b);
}

#endif // CPU_CAPABILITY_AVX2

}}}