1 #pragma once
2
3 // DO NOT DEFINE STATIC DATA IN THIS HEADER!
4 // See Note [Do not compile initializers with AVX]
5
6 #include <ATen/cpu/vec/vec.h>
7 #include <c10/util/irange.h>
8
9 namespace at::vec {
10
11 // slow path
12 template <typename scalar_t, typename Op>
vec_reduce_all(const Op & vec_fun,vec::Vectorized<scalar_t> acc_vec,int64_t size)13 inline scalar_t vec_reduce_all(
14 const Op& vec_fun,
15 vec::Vectorized<scalar_t> acc_vec,
16 int64_t size) {
17 using Vec = vec::Vectorized<scalar_t>;
18 scalar_t acc_arr[Vec::size()];
19 acc_vec.store(acc_arr);
20 for (const auto i : c10::irange(1, size)) {
21 std::array<scalar_t, Vec::size()> acc_arr_next = {0};
22 acc_arr_next[0] = acc_arr[i];
23 Vec acc_vec_next = Vec::loadu(acc_arr_next.data());
24 acc_vec = vec_fun(acc_vec, acc_vec_next);
25 }
26 acc_vec.store(acc_arr);
27 return acc_arr[0];
28 }
29
30 template <typename scalar_t, typename Op>
31 struct VecReduceAllSIMD {
applyVecReduceAllSIMD32 static inline scalar_t apply(const Op& vec_fun, const Vectorized<scalar_t>& acc_vec) {
33 return vec_reduce_all(vec_fun, acc_vec, Vectorized<scalar_t>::size());
34 }
35 };
36
37 #if defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && !defined(C10_MOBILE)
38 #if defined(CPU_CAPABILITY_AVX2)
template <typename Op>
struct VecReduceAllSIMD<float, Op> {
  // Tree reduction over all 8 float lanes in log2(8) = 3 shuffle+combine
  // steps: swap progressively smaller halves and fold with vec_fun, so the
  // full reduction accumulates into (at least) lane 0, then extract lane 0.
  // vec_fun must be lane-wise, associative and commutative.
  static inline float apply(const Op& vec_fun, const Vectorized<float>& acc_vec) {
    using Vec = Vectorized<float>;
    Vec v = acc_vec;
    // 128-bit shuffle
    Vec v1 = _mm256_permute2f128_ps(v, v, 0x1);
    v = vec_fun(v, v1);
    // 64-bit shuffle
    v1 = _mm256_shuffle_ps(v, v, 0x4E);
    v = vec_fun(v, v1);
    // 32-bit shuffle
    v1 = _mm256_shuffle_ps(v, v, 0xB1);
    v = vec_fun(v, v1);
    // Extract lane 0, which now holds the reduction of all 8 inputs.
    return _mm256_cvtss_f32(v);
  }
};
56 #endif // defined(CPU_CAPABILITY_AVX2)
57 #if defined(CPU_CAPABILITY_AVX512)
template <typename Op>
struct VecReduceAllSIMD<float, Op> {
  // Tree reduction over all 16 float lanes in log2(16) = 4 shuffle+combine
  // steps (swap 256-bit halves, then 128-bit, 64-bit, 32-bit), folding with
  // vec_fun each time so lane 0 ends up holding the full reduction.
  // vec_fun must be lane-wise, associative and commutative.
  static inline float apply(const Op& vec_fun, const Vectorized<float>& acc_vec) {
    using Vec = Vectorized<float>;
    Vec v = acc_vec;
    // 256-bit shuffle
    Vec v1 = _mm512_shuffle_f32x4(v, v, 0x4E);
    v = vec_fun(v, v1);
    // 128-bit shuffle
    v1 = _mm512_shuffle_f32x4(v, v, 0xB1);
    v = vec_fun(v, v1);
    // 64-bit shuffle
    v1 = _mm512_shuffle_ps(v, v, 0x4E);
    v = vec_fun(v, v1);
    // 32-bit shuffle
    v1 = _mm512_shuffle_ps(v, v, 0xB1);
    v = vec_fun(v, v1);
    // Extract lane 0, which now holds the reduction of all 16 inputs.
    return _mm512_cvtss_f32(v);
  }
};
78 #endif // defined(CPU_CAPABILITY_AVX512)
79 #endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && !defined(C10_MOBILE)
80
81 #if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__)
template <typename Op>
struct VecReduceAllSIMD<float, Op> {
  // NEON tree reduction over the 8 float lanes (two float32x4_t halves) in
  // 3 shuffle+combine steps; vec_fun must be lane-wise, associative and
  // commutative. Only the low half matters after the first step.
  static inline float apply(const Op& vec_fun, const Vectorized<float>& acc_vec) {
    using Vec = Vectorized<float>;
    Vec v = acc_vec;

    // 128-bit shuffle: [a1, a2, a3, a4, a5, a6, a7, a8] -> [a5, a6, a7, a8, a1, a2, a3, a4]
    Vec v1 = {v.get_high(), v.get_low()};
    // [a1+a5, a2+a6, a3+a7, a4+a8, -, -, -, -] ('+' stands for the reduction function. Note that the last 4 elements are not required)
    v = vec_fun(v, v1);

    // 64-bit shuffle: [a1+a5, a2+a6, a3+a7, a4+a8, -, -, -, -] -> [a3+a7, a4+a8, a1+a5, a2+a6, -, -, -, -]
    float32x4_t v1_1 = vextq_f32(v.get_low(), v.get_low(), 2);
    v1 = {v1_1, v1_1};
    // [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -, -, -, -]
    v = vec_fun(v, v1);

    // 32-bit shuffle: [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -, -, -, -] -> [a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, -, -, -, -]
    v1_1 = vrev64q_f32(v.get_low());
    v1 = {v1_1, v1_1};
    // [a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, -, -, -, -]
    v = vec_fun(v, v1);

    // Lane 0 of the low half now holds the full reduction.
    return v.get_low()[0];
  }
};
108 #endif // defined(__aarch64__)
109
110 template <typename scalar_t, typename Op>
111 inline scalar_t vec_reduce_all(const Op& vec_fun, const Vectorized<scalar_t>& acc_vec) {
112 return VecReduceAllSIMD<scalar_t, Op>::apply(vec_fun, acc_vec);
113 }
114
115 template <typename scalar_t, typename Op,
116 typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
117 inline scalar_t reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) {
118 using Vec = vec::Vectorized<scalar_t>;
119 if (size < Vec::size())
120 return vec_reduce_all(vec_fun, Vec::loadu(data, size), size);
121 int64_t d = Vec::size();
122 Vec acc_vec = Vec::loadu(data);
123 for (; d < size - (size % Vec::size()); d += Vec::size()) {
124 Vec data_vec = Vec::loadu(data + d);
125 acc_vec = vec_fun(acc_vec, data_vec);
126 }
127 if (size - d > 0) {
128 Vec data_vec = Vec::loadu(data + d, size - d);
129 acc_vec = Vec::set(acc_vec, vec_fun(acc_vec, data_vec), size - d);
130 }
131 return vec_reduce_all(vec_fun, acc_vec);
132 }
133
134 // similar to reduce_all, but reduces into two outputs
135 template <typename scalar_t, typename Op1, typename Op2,
136 typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
137 inline std::pair<scalar_t, scalar_t> reduce2_all(const Op1& vec_fun1, const Op2& vec_fun2,
138 const scalar_t* data, int64_t size) {
139 using Vec = vec::Vectorized<scalar_t>;
140 if (size < Vec::size()) {
141 auto loaded_data = Vec::loadu(data, size);
142 return std::pair<scalar_t, scalar_t>(
143 vec_reduce_all(vec_fun1, loaded_data, size),
144 vec_reduce_all(vec_fun2, loaded_data, size));
145 }
146 int64_t d = Vec::size();
147 Vec acc_vec1 = Vec::loadu(data);
148 Vec acc_vec2 = Vec::loadu(data);
149 for (; d < size - (size % Vec::size()); d += Vec::size()) {
150 Vec data_vec = Vec::loadu(data + d);
151 acc_vec1 = vec_fun1(acc_vec1, data_vec);
152 acc_vec2 = vec_fun2(acc_vec2, data_vec);
153 }
154 if (size - d > 0) {
155 Vec data_vec = Vec::loadu(data + d, size - d);
156 acc_vec1 = Vec::set(acc_vec1, vec_fun1(acc_vec1, data_vec), size - d);
157 acc_vec2 = Vec::set(acc_vec2, vec_fun2(acc_vec2, data_vec), size - d);
158 }
159 return std::pair<scalar_t, scalar_t>(
160 vec_reduce_all(vec_fun1, acc_vec1),
161 vec_reduce_all(vec_fun2, acc_vec2));
162 }
163
164 template <typename scalar_t, typename MapOp, typename ReduceOp,
165 typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
166 inline scalar_t map_reduce_all(
167 const MapOp& map_fun,
168 const ReduceOp& red_fun,
169 const scalar_t* data,
170 int64_t size) {
171 using Vec = vec::Vectorized<scalar_t>;
172 if (size < Vec::size())
173 return vec_reduce_all(red_fun, map_fun(Vec::loadu(data, size)), size);
174 int64_t d = Vec::size();
175 Vec acc_vec = map_fun(Vec::loadu(data));
176 for (; d < size - (size % Vec::size()); d += Vec::size()) {
177 Vec data_vec = Vec::loadu(data + d);
178 data_vec = map_fun(data_vec);
179 acc_vec = red_fun(acc_vec, data_vec);
180 }
181 if (size - d > 0) {
182 Vec data_vec = Vec::loadu(data + d, size - d);
183 data_vec = map_fun(data_vec);
184 acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d);
185 }
186 return vec_reduce_all(red_fun, acc_vec);
187 }
188
189 template <typename scalar_t, typename MapOp, typename ReduceOp,
190 typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
191 inline scalar_t map2_reduce_all(
192 const MapOp& map_fun,
193 const ReduceOp& red_fun,
194 const scalar_t* data,
195 const scalar_t* data2,
196 int64_t size) {
197 using Vec = vec::Vectorized<scalar_t>;
198 if (size < Vec::size()) {
199 Vec data_vec = Vec::loadu(data, size);
200 Vec data2_vec = Vec::loadu(data2, size);
201 data_vec = map_fun(data_vec, data2_vec);
202 return vec_reduce_all(red_fun, data_vec, size);
203 }
204 int64_t d = Vec::size();
205 Vec acc_vec = map_fun(Vec::loadu(data), Vec::loadu(data2));
206 for (; d < size - (size % Vec::size()); d += Vec::size()) {
207 Vec data_vec = Vec::loadu(data + d);
208 Vec data2_vec = Vec::loadu(data2 + d);
209 data_vec = map_fun(data_vec, data2_vec);
210 acc_vec = red_fun(acc_vec, data_vec);
211 }
212 if (size - d > 0) {
213 Vec data_vec = Vec::loadu(data + d, size - d);
214 Vec data2_vec = Vec::loadu(data2 + d, size - d);
215 data_vec = map_fun(data_vec, data2_vec);
216 acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d);
217 }
218 return vec_reduce_all(red_fun, acc_vec);
219 }
220
221 template <typename scalar_t, typename MapOp, typename ReduceOp,
222 typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
223 inline scalar_t map3_reduce_all(
224 const MapOp& map_fun,
225 const ReduceOp& red_fun,
226 const scalar_t* data,
227 const scalar_t* data2,
228 const scalar_t* data3,
229 int64_t size) {
230 using Vec = vec::Vectorized<scalar_t>;
231 if (size < Vec::size()) {
232 Vec data_vec = Vec::loadu(data, size);
233 Vec data2_vec = Vec::loadu(data2, size);
234 Vec data3_vec = Vec::loadu(data3, size);
235 data_vec = map_fun(data_vec, data2_vec, data3_vec);
236 return vec_reduce_all(red_fun, data_vec, size);
237 }
238
239 int64_t d = Vec::size();
240 Vec acc_vec = map_fun(Vec::loadu(data), Vec::loadu(data2), Vec::loadu(data3));
241 for (; d < size - (size % Vec::size()); d += Vec::size()) {
242 Vec data_vec = Vec::loadu(data + d);
243 Vec data2_vec = Vec::loadu(data2 + d);
244 Vec data3_vec = Vec::loadu(data3 + d);
245 data_vec = map_fun(data_vec, data2_vec, data3_vec);
246 acc_vec = red_fun(acc_vec, data_vec);
247 }
248 if (size - d > 0) {
249 Vec data_vec = Vec::loadu(data + d, size - d);
250 Vec data2_vec = Vec::loadu(data2 + d, size - d);
251 Vec data3_vec = Vec::loadu(data3 + d, size - d);
252 data_vec = map_fun(data_vec, data2_vec, data3_vec);
253 acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d);
254 }
255 return vec_reduce_all(red_fun, acc_vec);
256 }
257
258 template <typename scalar_t, typename Op,
259 typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
260 inline void map(
261 const Op& vec_fun,
262 scalar_t* output_data,
263 const scalar_t* input_data,
264 int64_t size) {
265 using Vec = vec::Vectorized<scalar_t>;
266 int64_t d = 0;
267 for (; d < size - (size % Vec::size()); d += Vec::size()) {
268 Vec output_vec = vec_fun(Vec::loadu(input_data + d));
269 output_vec.store(output_data + d);
270 }
271 if (size - d > 0) {
272 Vec output_vec = vec_fun(Vec::loadu(input_data + d, size - d));
273 output_vec.store(output_data + d, size - d);
274 }
275 }
276
277 template <typename scalar_t, typename Op,
278 typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
279 inline void map2(
280 const Op& vec_fun,
281 scalar_t* output_data,
282 const scalar_t* input_data,
283 const scalar_t* input_data2,
284 int64_t size) {
285 using Vec = vec::Vectorized<scalar_t>;
286 int64_t d = 0;
287 for (; d < size - (size % Vec::size()); d += Vec::size()) {
288 Vec data_vec = Vec::loadu(input_data + d);
289 Vec data_vec2 = Vec::loadu(input_data2 + d);
290 Vec output_vec = vec_fun(data_vec, data_vec2);
291 output_vec.store(output_data + d);
292 }
293 if (size - d > 0) {
294 Vec data_vec = Vec::loadu(input_data + d, size - d);
295 Vec data_vec2 = Vec::loadu(input_data2 + d, size - d);
296 Vec output_vec = vec_fun(data_vec, data_vec2);
297 output_vec.store(output_data + d, size - d);
298 }
299 }
300
301 template <typename scalar_t, typename Op,
302 typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
303 inline void map3(
304 const Op& vec_fun,
305 scalar_t* output_data,
306 const scalar_t* input_data1,
307 const scalar_t* input_data2,
308 const scalar_t* input_data3,
309 int64_t size) {
310 using Vec = vec::Vectorized<scalar_t>;
311 int64_t d = 0;
312 for (; d < size - (size % Vec::size()); d += Vec::size()) {
313 Vec data_vec1 = Vec::loadu(input_data1 + d);
314 Vec data_vec2 = Vec::loadu(input_data2 + d);
315 Vec data_vec3 = Vec::loadu(input_data3 + d);
316 Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3);
317 output_vec.store(output_data + d);
318 }
319 if (size - d > 0) {
320 Vec data_vec1 = Vec::loadu(input_data1 + d, size - d);
321 Vec data_vec2 = Vec::loadu(input_data2 + d, size - d);
322 Vec data_vec3 = Vec::loadu(input_data3 + d, size - d);
323 Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3);
324 output_vec.store(output_data + d, size - d);
325 }
326 }
327
328 template <typename scalar_t, typename Op,
329 typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
330 inline void map4(
331 const Op& vec_fun,
332 scalar_t* output_data,
333 const scalar_t* input_data1,
334 const scalar_t* input_data2,
335 const scalar_t* input_data3,
336 const scalar_t* input_data4,
337 int64_t size) {
338 using Vec = vec::Vectorized<scalar_t>;
339 int64_t d = 0;
340 for (; d < size - (size % Vec::size()); d += Vec::size()) {
341 Vec data_vec1 = Vec::loadu(input_data1 + d);
342 Vec data_vec2 = Vec::loadu(input_data2 + d);
343 Vec data_vec3 = Vec::loadu(input_data3 + d);
344 Vec data_vec4 = Vec::loadu(input_data4 + d);
345 Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3, data_vec4);
346 output_vec.store(output_data + d);
347 }
348 if (size - d > 0) {
349 Vec data_vec1 = Vec::loadu(input_data1 + d, size - d);
350 Vec data_vec2 = Vec::loadu(input_data2 + d, size - d);
351 Vec data_vec3 = Vec::loadu(input_data3 + d, size - d);
352 Vec data_vec4 = Vec::loadu(input_data4 + d, size - d);
353 Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3, data_vec4);
354 output_vec.store(output_data + d, size - d);
355 }
356 }
357
358 } // namespace at::vec
359