xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cpu/vec/vec_n.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/cpu/vec/vec_base.h>
4 #include <array>
5 
6 namespace at::vec {
7 inline namespace CPU_CAPABILITY {
8 
9 /**
10  * @brief A class template representing a vectorized type with
11  * `N * Vectorized<T>::size()` elements, aiming to support vectors of
12  * arbitrary size. A specific use case of it is to represent vectors
13  * converted from data types with different sizes but with the same
14  * number of vector elements, e.g., `VectorizedN<float, 2>` can be
15  * a vector converted from two `Vectorized<bfloat16>`, `VectorizedN<int64_t, 2>`
16  * can be a vector converted from two `Vectorized<int32_t>` etc.
17  *
18  * It supports most of the operations of `Vectorized<T>`
19  * and the implementation delegates to `Vectorized<T>` with loops over `N`.
20  *
21  * @tparam T The underlying type of the vectorized elements.
22  * @tparam N The number of underlying `Vectorized<T>`.
23  */
24 template <typename T, int N>
25 class VectorizedN {
26  public:
27   using value_type = T;
28   using size_type = int;
29 
30   static constexpr size_type size_T = sizeof(T);
size()31   static constexpr size_type size() {
32     return Vectorized<T>::size() * N;
33   }
34 
35  private:
36   std::array<Vectorized<T>, N> values;
37 
38  public:
39   // methods not implemented yet:
40   // variadic constructor, operator T*, as_bytes, zero_mask
41 
42 #define VECTORIZEDN_DEFINE_UNARY_OP(op)                             \
43   VectorizedN<T, N> op() const {                                    \
44     return unary_op([](const Vectorized<T>& a) { return a.op(); }); \
45   }
46 
47 #define VECTORIZEDN_DEFINE_BINARY_OP(op)                            \
48   VectorizedN<T, N> op(const VectorizedN<T, N>& other) const {      \
49     return binary_op(                                               \
50         other, [](const Vectorized<T>& a, const Vectorized<T>& b) { \
51           return a.op(b);                                           \
52         });                                                         \
53   }
54 
55   template <typename Op>
unary_op(Op op)56   inline VectorizedN<T, N> unary_op(Op op) const {
57     VectorizedN<T, N> result;
58 #ifndef _MSC_VER
59 #pragma unroll
60 #endif
61     for (int i = 0; i < N; ++i) {
62       result.values[i] = op(values[i]);
63     }
64     return result;
65   }
66 
67   template <typename Op>
binary_op(const VectorizedN<T,N> & other,Op op)68   inline VectorizedN<T, N> binary_op(const VectorizedN<T, N>& other, Op op)
69       const {
70     VectorizedN<T, N> result;
71 #ifndef _MSC_VER
72 #pragma unroll
73 #endif
74     for (int i = 0; i < N; ++i) {
75       result.values[i] = op(values[i], other.values[i]);
76     }
77     return result;
78   }
79 
80   VectorizedN() = default;
81 
VectorizedN(T val)82   explicit VectorizedN(T val) {
83     for (int i = 0; i < N; ++i) {
84       values[i] = Vectorized<T>(val);
85     }
86   }
87 
88   template <int L = N, typename std::enable_if_t<L == 1, int> = 0>
VectorizedN(const Vectorized<T> & val)89   VectorizedN(const Vectorized<T>& val) : values({val}) {}
90 
91   template <int L = N, typename std::enable_if_t<L == 2, int> = 0>
VectorizedN(const Vectorized<T> & val_0,const Vectorized<T> & val_1)92   VectorizedN(const Vectorized<T>& val_0, const Vectorized<T>& val_1) : values({val_0, val_1}) {}
93 
94   template <int L = N, typename std::enable_if_t<L == 1, int> = 0>
95   inline operator Vectorized<T>() const {
96     return values[0];
97   }
98 
99   inline const Vectorized<T>& operator[](int i) const {
100     return values[i];
101   }
102 
103   inline Vectorized<T>& operator[](int i) {
104     return values[i];
105   }
106 
107   template <int64_t mask>
blend(const VectorizedN<T,N> & a,const VectorizedN<T,N> & b)108   static VectorizedN<T, N> blend(
109       const VectorizedN<T, N>& a,
110       const VectorizedN<T, N>& b) {
111     VectorizedN<T, N> result;
112     for (int i = 0; i < N; ++i) {
113       result.values[i] = Vectorized<T>::template blend<mask>(a.values[i], b.values[i]);
114     }
115     return result;
116   }
117 
blendv(const VectorizedN<T,N> & a,const VectorizedN<T,N> & b,const VectorizedN<T,N> & mask)118   static VectorizedN<T, N> blendv(
119       const VectorizedN<T, N>& a,
120       const VectorizedN<T, N>& b,
121       const VectorizedN<T, N>& mask) {
122     VectorizedN<T, N> result;
123     for (int i = 0; i < N; ++i) {
124       result.values[i] =
125           Vectorized<T>::blendv(a.values[i], b.values[i], mask.values[i]);
126     }
127     return result;
128   }
129 
130   template <typename step_t>
131   static VectorizedN<T, N> arange(
132       T base = static_cast<T>(0),
133       step_t step = static_cast<step_t>(1)) {
134     VectorizedN<T, N> result;
135     for (int i = 0; i < N; ++i) {
136       result.values[i] = Vectorized<T>::arange(base, step);
137       base += step * Vectorized<T>::size();
138     }
139     return result;
140   }
141 
142   static VectorizedN<T, N> set(
143       const VectorizedN<T, N>& a,
144       const VectorizedN<T, N>& b,
145       int64_t count = size()) {
146     VectorizedN<T, N> result;
147     for (int i = 0; i < N; ++i) {
148       if (count > 0) {
149         result.values[i] = Vectorized<T>::set(
150             a.values[i],
151             b.values[i],
152             std::min(count, (int64_t)Vectorized<T>::size()));
153         count -= Vectorized<T>::size();
154       } else {
155         result.values[i] = a.values[i];
156       }
157     }
158     return result;
159   }
160 
loadu(const void * ptr)161   static VectorizedN<T, N> loadu(const void* ptr) {
162     VectorizedN<T, N> result;
163     for (int i = 0; i < N; ++i) {
164       result.values[i] = Vectorized<T>::loadu(ptr);
165       ptr = static_cast<const T*>(ptr) + Vectorized<T>::size();
166     }
167     return result;
168   }
169 
loadu(const void * ptr,int64_t count)170   static VectorizedN<T, N> loadu(const void* ptr, int64_t count) {
171     VectorizedN<T, N> result;
172     for (int i = 0; i < N; ++i) {
173       result.values[i] = Vectorized<T>::loadu(
174           ptr, std::min(count, (int64_t)Vectorized<T>::size()));
175       ptr = static_cast<const T*>(ptr) + Vectorized<T>::size();
176       count -= Vectorized<T>::size();
177       if (count <= 0) {
178         break;
179       }
180     }
181     return result;
182   }
183 
store(void * ptr)184   void store(void* ptr) const {
185     for (int i = 0; i < N; ++i) {
186       values[i].store(ptr);
187       ptr = static_cast<T*>(ptr) + Vectorized<T>::size();
188     }
189   }
190 
store(void * ptr,int count)191   void store(void* ptr, int count) const {
192     for (int i = 0; i < N; ++i) {
193       values[i].store(ptr, std::min(count, (int)Vectorized<T>::size()));
194       ptr = static_cast<T*>(ptr) + Vectorized<T>::size();
195       count -= Vectorized<T>::size();
196       if (count <= 0) {
197         break;
198       }
199     }
200   }
201 
has_inf_nan()202   bool has_inf_nan() const {
203     for (int i = 0; i < N; ++i) {
204       if (values[i].has_inf_nan()) {
205         return true;
206       }
207     }
208     return false;
209   }
210 
map(T (* const f)(T))211   VectorizedN<T, N> map(T (*const f)(T)) const {
212     VectorizedN<T, N> result;
213     for (int i = 0; i < N; ++i) {
214       result.values[i] = values[i].map(f);
215     }
216     return result;
217   }
218 
map(T (* const f)(const T &))219   VectorizedN<T, N> map(T (*const f)(const T&)) const {
220     VectorizedN<T, N> result;
221     for (int i = 0; i < N; ++i) {
222       result.values[i] = values[i].map(f);
223     }
224     return result;
225   }
226 
227   VECTORIZEDN_DEFINE_UNARY_OP(isnan)
228   VECTORIZEDN_DEFINE_UNARY_OP(abs)
229   VECTORIZEDN_DEFINE_UNARY_OP(sgn)
230   VECTORIZEDN_DEFINE_UNARY_OP(angle)
231   VECTORIZEDN_DEFINE_UNARY_OP(real)
232   VECTORIZEDN_DEFINE_UNARY_OP(imag)
233   VECTORIZEDN_DEFINE_UNARY_OP(conj)
234   VECTORIZEDN_DEFINE_UNARY_OP(acos)
235   VECTORIZEDN_DEFINE_UNARY_OP(acosh)
236   VECTORIZEDN_DEFINE_UNARY_OP(asin)
237   VECTORIZEDN_DEFINE_UNARY_OP(atan)
238   VECTORIZEDN_DEFINE_UNARY_OP(atanh)
239   VECTORIZEDN_DEFINE_BINARY_OP(atan2)
240   VECTORIZEDN_DEFINE_BINARY_OP(copysign)
241   VECTORIZEDN_DEFINE_UNARY_OP(erf)
242   VECTORIZEDN_DEFINE_UNARY_OP(erfc)
243   VECTORIZEDN_DEFINE_UNARY_OP(erfinv)
244   VECTORIZEDN_DEFINE_UNARY_OP(exp)
245   VECTORIZEDN_DEFINE_UNARY_OP(exp2)
246   VECTORIZEDN_DEFINE_UNARY_OP(expm1)
247   VECTORIZEDN_DEFINE_UNARY_OP(exp_u20)
248   VECTORIZEDN_DEFINE_UNARY_OP(frac)
249   VECTORIZEDN_DEFINE_BINARY_OP(fmod)
250   VECTORIZEDN_DEFINE_UNARY_OP(log)
251   VECTORIZEDN_DEFINE_UNARY_OP(log10)
252   VECTORIZEDN_DEFINE_UNARY_OP(log1p)
253   VECTORIZEDN_DEFINE_UNARY_OP(log2)
254   VECTORIZEDN_DEFINE_UNARY_OP(ceil)
255   VECTORIZEDN_DEFINE_UNARY_OP(cos)
256   VECTORIZEDN_DEFINE_UNARY_OP(cosh)
257   VECTORIZEDN_DEFINE_UNARY_OP(floor)
258   VECTORIZEDN_DEFINE_BINARY_OP(hypot)
259   VECTORIZEDN_DEFINE_UNARY_OP(i0)
260   VECTORIZEDN_DEFINE_UNARY_OP(i0e)
261   VECTORIZEDN_DEFINE_UNARY_OP(digamma)
262   VECTORIZEDN_DEFINE_BINARY_OP(igamma)
263   VECTORIZEDN_DEFINE_BINARY_OP(igammac)
264   VECTORIZEDN_DEFINE_UNARY_OP(neg)
265   VECTORIZEDN_DEFINE_BINARY_OP(nextafter)
266   VECTORIZEDN_DEFINE_UNARY_OP(round)
267   VECTORIZEDN_DEFINE_UNARY_OP(sin)
268   VECTORIZEDN_DEFINE_UNARY_OP(sinh)
269   VECTORIZEDN_DEFINE_UNARY_OP(tan)
270   VECTORIZEDN_DEFINE_UNARY_OP(tanh)
271   VECTORIZEDN_DEFINE_UNARY_OP(trunc)
272   VECTORIZEDN_DEFINE_UNARY_OP(lgamma)
273   VECTORIZEDN_DEFINE_UNARY_OP(sqrt)
274   VECTORIZEDN_DEFINE_UNARY_OP(reciprocal)
275   VECTORIZEDN_DEFINE_UNARY_OP(rsqrt)
276   VECTORIZEDN_DEFINE_BINARY_OP(pow)
277   VECTORIZEDN_DEFINE_BINARY_OP(operator==)
278   VECTORIZEDN_DEFINE_BINARY_OP(operator!=)
279   VECTORIZEDN_DEFINE_BINARY_OP(operator>=)
280   VECTORIZEDN_DEFINE_BINARY_OP(operator<=)
281   VECTORIZEDN_DEFINE_BINARY_OP(operator>)
282   VECTORIZEDN_DEFINE_BINARY_OP(operator<)
283   VECTORIZEDN_DEFINE_BINARY_OP(eq)
284   VECTORIZEDN_DEFINE_BINARY_OP(ne)
285   VECTORIZEDN_DEFINE_BINARY_OP(gt)
286   VECTORIZEDN_DEFINE_BINARY_OP(ge)
287   VECTORIZEDN_DEFINE_BINARY_OP(lt)
288   VECTORIZEDN_DEFINE_BINARY_OP(le)
289 
290 #undef VECTORIZEDN_DEFINE_UNARY_OP
291 #undef VECTORIZEDN_DEFINE_BINARY_OP
292 };
293 
294 #define VECTORIZEDN_DEFINE_UNARY_OP_GLOBAL(op)                       \
295   template <typename T, int N>                                       \
296   inline VectorizedN<T, N> op(const VectorizedN<T, N>& a) {          \
297     return a.unary_op([](const Vectorized<T>& a) { return op(a); }); \
298   }
299 
300 #define VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(op)                                \
301   template <typename T, int N>                                                 \
302   inline VectorizedN<T, N> op(                                                 \
303       const VectorizedN<T, N>& a, const VectorizedN<T, N>& b) {                \
304     return a.binary_op(b, [](const Vectorized<T>& a, const Vectorized<T>& b) { \
305       return op(a, b);                                                         \
306     });                                                                        \
307   }
308 
309 #define VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(op)                     \
310   template <typename T, int N>                                              \
311   inline VectorizedN<T, N>& op(                                             \
312       VectorizedN<T, N>& a, const VectorizedN<T, N>& b) {                   \
313     a = a.binary_op(b, [](const Vectorized<T>& a, const Vectorized<T>& b) { \
314       return op(a, b);                                                      \
315     });                                                                     \
316     return a;                                                               \
317   }
318 
319 VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator+)
320 VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator-)
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator *)321 VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator*)
322 VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator/)
323 VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator%)
324 VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator||)
325 VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator<<)
326 VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator>>)
327 VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(maximum)
328 VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(minimum)
329 VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(fmadd)
330 VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(fmsub)
331 VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(clamp)
332 VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(clamp_max)
333 VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(clamp_min)
334 VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator&)
335 VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator|)
336 VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator^)
337 VECTORIZEDN_DEFINE_UNARY_OP_GLOBAL(operator~)
338 
339 VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator+=)
340 VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator-=)
341 VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator*=)
342 VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator/=)
343 VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator%=)
344 VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator<<=)
345 VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator>>=)
346 
347 #undef VECTORIZEDN_DEFINE_UNARY_OP_GLOBAL
348 #undef VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL
349 #undef VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL
350 
351 template <typename T, int N, typename OpVec>
352 inline T vec_reduce_all(const OpVec& vec_fun, VectorizedN<T, N> acc_vec) {
353   Vectorized<T> vec_result = acc_vec[0];
354   for (int i = 1; i < N; i++) {
355     vec_result = vec_fun(vec_result, acc_vec[i]);
356   }
357   return vec_reduce_all(vec_fun, vec_result);
358 }
359 
360 } // namespace CPU_CAPABILITY
361 } // namespace at::vec
362