1 #pragma once
2
3 #include <ATen/cpu/vec/vec_base.h>
4 #include <array>
5
6 namespace at::vec {
7 inline namespace CPU_CAPABILITY {
8
9 /**
10 * @brief A class template representing a vectorized type with
11 * `N * Vectorized<T>::size()` elements, aiming to support vectors of
12 * arbitrary size. A specific use case of it is to represent vectors
13 * converted from data types with different sizes but with the same
14 * number of vector elements, e.g., `VectorizedN<float, 2>` can be
15 * a vector converted from two `Vectorized<bfloat16>`, `VectorizedN<int64_t, 2>`
16 * can be a vector converted from two `Vectorized<int32_t>` etc.
17 *
18 * It supports most of the operations of `Vectorized<T>`
19 * and the implementation delegates to `Vectorized<T>` with loops over `N`.
20 *
21 * @tparam T The underlying type of the vectorized elements.
22 * @tparam N The number of underlying `Vectorized<T>`.
23 */
24 template <typename T, int N>
25 class VectorizedN {
26 public:
27 using value_type = T;
28 using size_type = int;
29
30 static constexpr size_type size_T = sizeof(T);
size()31 static constexpr size_type size() {
32 return Vectorized<T>::size() * N;
33 }
34
35 private:
36 std::array<Vectorized<T>, N> values;
37
38 public:
39 // methods not implemented yet:
40 // variadic constructor, operator T*, as_bytes, zero_mask
41
42 #define VECTORIZEDN_DEFINE_UNARY_OP(op) \
43 VectorizedN<T, N> op() const { \
44 return unary_op([](const Vectorized<T>& a) { return a.op(); }); \
45 }
46
47 #define VECTORIZEDN_DEFINE_BINARY_OP(op) \
48 VectorizedN<T, N> op(const VectorizedN<T, N>& other) const { \
49 return binary_op( \
50 other, [](const Vectorized<T>& a, const Vectorized<T>& b) { \
51 return a.op(b); \
52 }); \
53 }
54
55 template <typename Op>
unary_op(Op op)56 inline VectorizedN<T, N> unary_op(Op op) const {
57 VectorizedN<T, N> result;
58 #ifndef _MSC_VER
59 #pragma unroll
60 #endif
61 for (int i = 0; i < N; ++i) {
62 result.values[i] = op(values[i]);
63 }
64 return result;
65 }
66
67 template <typename Op>
binary_op(const VectorizedN<T,N> & other,Op op)68 inline VectorizedN<T, N> binary_op(const VectorizedN<T, N>& other, Op op)
69 const {
70 VectorizedN<T, N> result;
71 #ifndef _MSC_VER
72 #pragma unroll
73 #endif
74 for (int i = 0; i < N; ++i) {
75 result.values[i] = op(values[i], other.values[i]);
76 }
77 return result;
78 }
79
80 VectorizedN() = default;
81
VectorizedN(T val)82 explicit VectorizedN(T val) {
83 for (int i = 0; i < N; ++i) {
84 values[i] = Vectorized<T>(val);
85 }
86 }
87
88 template <int L = N, typename std::enable_if_t<L == 1, int> = 0>
VectorizedN(const Vectorized<T> & val)89 VectorizedN(const Vectorized<T>& val) : values({val}) {}
90
91 template <int L = N, typename std::enable_if_t<L == 2, int> = 0>
VectorizedN(const Vectorized<T> & val_0,const Vectorized<T> & val_1)92 VectorizedN(const Vectorized<T>& val_0, const Vectorized<T>& val_1) : values({val_0, val_1}) {}
93
94 template <int L = N, typename std::enable_if_t<L == 1, int> = 0>
95 inline operator Vectorized<T>() const {
96 return values[0];
97 }
98
99 inline const Vectorized<T>& operator[](int i) const {
100 return values[i];
101 }
102
103 inline Vectorized<T>& operator[](int i) {
104 return values[i];
105 }
106
107 template <int64_t mask>
blend(const VectorizedN<T,N> & a,const VectorizedN<T,N> & b)108 static VectorizedN<T, N> blend(
109 const VectorizedN<T, N>& a,
110 const VectorizedN<T, N>& b) {
111 VectorizedN<T, N> result;
112 for (int i = 0; i < N; ++i) {
113 result.values[i] = Vectorized<T>::template blend<mask>(a.values[i], b.values[i]);
114 }
115 return result;
116 }
117
blendv(const VectorizedN<T,N> & a,const VectorizedN<T,N> & b,const VectorizedN<T,N> & mask)118 static VectorizedN<T, N> blendv(
119 const VectorizedN<T, N>& a,
120 const VectorizedN<T, N>& b,
121 const VectorizedN<T, N>& mask) {
122 VectorizedN<T, N> result;
123 for (int i = 0; i < N; ++i) {
124 result.values[i] =
125 Vectorized<T>::blendv(a.values[i], b.values[i], mask.values[i]);
126 }
127 return result;
128 }
129
130 template <typename step_t>
131 static VectorizedN<T, N> arange(
132 T base = static_cast<T>(0),
133 step_t step = static_cast<step_t>(1)) {
134 VectorizedN<T, N> result;
135 for (int i = 0; i < N; ++i) {
136 result.values[i] = Vectorized<T>::arange(base, step);
137 base += step * Vectorized<T>::size();
138 }
139 return result;
140 }
141
142 static VectorizedN<T, N> set(
143 const VectorizedN<T, N>& a,
144 const VectorizedN<T, N>& b,
145 int64_t count = size()) {
146 VectorizedN<T, N> result;
147 for (int i = 0; i < N; ++i) {
148 if (count > 0) {
149 result.values[i] = Vectorized<T>::set(
150 a.values[i],
151 b.values[i],
152 std::min(count, (int64_t)Vectorized<T>::size()));
153 count -= Vectorized<T>::size();
154 } else {
155 result.values[i] = a.values[i];
156 }
157 }
158 return result;
159 }
160
loadu(const void * ptr)161 static VectorizedN<T, N> loadu(const void* ptr) {
162 VectorizedN<T, N> result;
163 for (int i = 0; i < N; ++i) {
164 result.values[i] = Vectorized<T>::loadu(ptr);
165 ptr = static_cast<const T*>(ptr) + Vectorized<T>::size();
166 }
167 return result;
168 }
169
loadu(const void * ptr,int64_t count)170 static VectorizedN<T, N> loadu(const void* ptr, int64_t count) {
171 VectorizedN<T, N> result;
172 for (int i = 0; i < N; ++i) {
173 result.values[i] = Vectorized<T>::loadu(
174 ptr, std::min(count, (int64_t)Vectorized<T>::size()));
175 ptr = static_cast<const T*>(ptr) + Vectorized<T>::size();
176 count -= Vectorized<T>::size();
177 if (count <= 0) {
178 break;
179 }
180 }
181 return result;
182 }
183
store(void * ptr)184 void store(void* ptr) const {
185 for (int i = 0; i < N; ++i) {
186 values[i].store(ptr);
187 ptr = static_cast<T*>(ptr) + Vectorized<T>::size();
188 }
189 }
190
store(void * ptr,int count)191 void store(void* ptr, int count) const {
192 for (int i = 0; i < N; ++i) {
193 values[i].store(ptr, std::min(count, (int)Vectorized<T>::size()));
194 ptr = static_cast<T*>(ptr) + Vectorized<T>::size();
195 count -= Vectorized<T>::size();
196 if (count <= 0) {
197 break;
198 }
199 }
200 }
201
has_inf_nan()202 bool has_inf_nan() const {
203 for (int i = 0; i < N; ++i) {
204 if (values[i].has_inf_nan()) {
205 return true;
206 }
207 }
208 return false;
209 }
210
map(T (* const f)(T))211 VectorizedN<T, N> map(T (*const f)(T)) const {
212 VectorizedN<T, N> result;
213 for (int i = 0; i < N; ++i) {
214 result.values[i] = values[i].map(f);
215 }
216 return result;
217 }
218
map(T (* const f)(const T &))219 VectorizedN<T, N> map(T (*const f)(const T&)) const {
220 VectorizedN<T, N> result;
221 for (int i = 0; i < N; ++i) {
222 result.values[i] = values[i].map(f);
223 }
224 return result;
225 }
226
227 VECTORIZEDN_DEFINE_UNARY_OP(isnan)
228 VECTORIZEDN_DEFINE_UNARY_OP(abs)
229 VECTORIZEDN_DEFINE_UNARY_OP(sgn)
230 VECTORIZEDN_DEFINE_UNARY_OP(angle)
231 VECTORIZEDN_DEFINE_UNARY_OP(real)
232 VECTORIZEDN_DEFINE_UNARY_OP(imag)
233 VECTORIZEDN_DEFINE_UNARY_OP(conj)
234 VECTORIZEDN_DEFINE_UNARY_OP(acos)
235 VECTORIZEDN_DEFINE_UNARY_OP(acosh)
236 VECTORIZEDN_DEFINE_UNARY_OP(asin)
237 VECTORIZEDN_DEFINE_UNARY_OP(atan)
238 VECTORIZEDN_DEFINE_UNARY_OP(atanh)
239 VECTORIZEDN_DEFINE_BINARY_OP(atan2)
240 VECTORIZEDN_DEFINE_BINARY_OP(copysign)
241 VECTORIZEDN_DEFINE_UNARY_OP(erf)
242 VECTORIZEDN_DEFINE_UNARY_OP(erfc)
243 VECTORIZEDN_DEFINE_UNARY_OP(erfinv)
244 VECTORIZEDN_DEFINE_UNARY_OP(exp)
245 VECTORIZEDN_DEFINE_UNARY_OP(exp2)
246 VECTORIZEDN_DEFINE_UNARY_OP(expm1)
247 VECTORIZEDN_DEFINE_UNARY_OP(exp_u20)
248 VECTORIZEDN_DEFINE_UNARY_OP(frac)
249 VECTORIZEDN_DEFINE_BINARY_OP(fmod)
250 VECTORIZEDN_DEFINE_UNARY_OP(log)
251 VECTORIZEDN_DEFINE_UNARY_OP(log10)
252 VECTORIZEDN_DEFINE_UNARY_OP(log1p)
253 VECTORIZEDN_DEFINE_UNARY_OP(log2)
254 VECTORIZEDN_DEFINE_UNARY_OP(ceil)
255 VECTORIZEDN_DEFINE_UNARY_OP(cos)
256 VECTORIZEDN_DEFINE_UNARY_OP(cosh)
257 VECTORIZEDN_DEFINE_UNARY_OP(floor)
258 VECTORIZEDN_DEFINE_BINARY_OP(hypot)
259 VECTORIZEDN_DEFINE_UNARY_OP(i0)
260 VECTORIZEDN_DEFINE_UNARY_OP(i0e)
261 VECTORIZEDN_DEFINE_UNARY_OP(digamma)
262 VECTORIZEDN_DEFINE_BINARY_OP(igamma)
263 VECTORIZEDN_DEFINE_BINARY_OP(igammac)
264 VECTORIZEDN_DEFINE_UNARY_OP(neg)
265 VECTORIZEDN_DEFINE_BINARY_OP(nextafter)
266 VECTORIZEDN_DEFINE_UNARY_OP(round)
267 VECTORIZEDN_DEFINE_UNARY_OP(sin)
268 VECTORIZEDN_DEFINE_UNARY_OP(sinh)
269 VECTORIZEDN_DEFINE_UNARY_OP(tan)
270 VECTORIZEDN_DEFINE_UNARY_OP(tanh)
271 VECTORIZEDN_DEFINE_UNARY_OP(trunc)
272 VECTORIZEDN_DEFINE_UNARY_OP(lgamma)
273 VECTORIZEDN_DEFINE_UNARY_OP(sqrt)
274 VECTORIZEDN_DEFINE_UNARY_OP(reciprocal)
275 VECTORIZEDN_DEFINE_UNARY_OP(rsqrt)
276 VECTORIZEDN_DEFINE_BINARY_OP(pow)
277 VECTORIZEDN_DEFINE_BINARY_OP(operator==)
278 VECTORIZEDN_DEFINE_BINARY_OP(operator!=)
279 VECTORIZEDN_DEFINE_BINARY_OP(operator>=)
280 VECTORIZEDN_DEFINE_BINARY_OP(operator<=)
281 VECTORIZEDN_DEFINE_BINARY_OP(operator>)
282 VECTORIZEDN_DEFINE_BINARY_OP(operator<)
283 VECTORIZEDN_DEFINE_BINARY_OP(eq)
284 VECTORIZEDN_DEFINE_BINARY_OP(ne)
285 VECTORIZEDN_DEFINE_BINARY_OP(gt)
286 VECTORIZEDN_DEFINE_BINARY_OP(ge)
287 VECTORIZEDN_DEFINE_BINARY_OP(lt)
288 VECTORIZEDN_DEFINE_BINARY_OP(le)
289
290 #undef VECTORIZEDN_DEFINE_UNARY_OP
291 #undef VECTORIZEDN_DEFINE_BINARY_OP
292 };
293
294 #define VECTORIZEDN_DEFINE_UNARY_OP_GLOBAL(op) \
295 template <typename T, int N> \
296 inline VectorizedN<T, N> op(const VectorizedN<T, N>& a) { \
297 return a.unary_op([](const Vectorized<T>& a) { return op(a); }); \
298 }
299
300 #define VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(op) \
301 template <typename T, int N> \
302 inline VectorizedN<T, N> op( \
303 const VectorizedN<T, N>& a, const VectorizedN<T, N>& b) { \
304 return a.binary_op(b, [](const Vectorized<T>& a, const Vectorized<T>& b) { \
305 return op(a, b); \
306 }); \
307 }
308
309 #define VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(op) \
310 template <typename T, int N> \
311 inline VectorizedN<T, N>& op( \
312 VectorizedN<T, N>& a, const VectorizedN<T, N>& b) { \
313 a = a.binary_op(b, [](const Vectorized<T>& a, const Vectorized<T>& b) { \
314 return op(a, b); \
315 }); \
316 return a; \
317 }
318
319 VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator+)
320 VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator-)
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator *)321 VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator*)
322 VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator/)
323 VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator%)
324 VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator||)
325 VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator<<)
326 VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator>>)
327 VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(maximum)
328 VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(minimum)
329 VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(fmadd)
330 VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(fmsub)
331 VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(clamp)
332 VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(clamp_max)
333 VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(clamp_min)
334 VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator&)
335 VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator|)
336 VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator^)
337 VECTORIZEDN_DEFINE_UNARY_OP_GLOBAL(operator~)
338
339 VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator+=)
340 VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator-=)
341 VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator*=)
342 VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator/=)
343 VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator%=)
344 VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator<<=)
345 VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator>>=)
346
347 #undef VECTORIZEDN_DEFINE_UNARY_OP_GLOBAL
348 #undef VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL
349 #undef VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL
350
351 template <typename T, int N, typename OpVec>
352 inline T vec_reduce_all(const OpVec& vec_fun, VectorizedN<T, N> acc_vec) {
353 Vectorized<T> vec_result = acc_vec[0];
354 for (int i = 1; i < N; i++) {
355 vec_result = vec_fun(vec_result, acc_vec[i]);
356 }
357 return vec_reduce_all(vec_fun, vec_result);
358 }
359
360 } // namespace CPU_CAPABILITY
361 } // namespace at::vec
362