xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cpu/vec/functional_base.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 // DO NOT DEFINE STATIC DATA IN THIS HEADER!
4 // See Note [Do not compile initializers with AVX]
5 
6 #include <ATen/cpu/vec/vec.h>
7 #include <c10/util/irange.h>
8 
9 namespace at::vec {
10 
11 // slow path
12 template <typename scalar_t, typename Op>
vec_reduce_all(const Op & vec_fun,vec::Vectorized<scalar_t> acc_vec,int64_t size)13 inline scalar_t vec_reduce_all(
14     const Op& vec_fun,
15     vec::Vectorized<scalar_t> acc_vec,
16     int64_t size) {
17   using Vec = vec::Vectorized<scalar_t>;
18   scalar_t acc_arr[Vec::size()];
19   acc_vec.store(acc_arr);
20   for (const auto i : c10::irange(1, size)) {
21     std::array<scalar_t, Vec::size()> acc_arr_next = {0};
22     acc_arr_next[0] = acc_arr[i];
23     Vec acc_vec_next = Vec::loadu(acc_arr_next.data());
24     acc_vec = vec_fun(acc_vec, acc_vec_next);
25   }
26   acc_vec.store(acc_arr);
27   return acc_arr[0];
28 }
29 
30 template <typename scalar_t, typename Op>
31 struct VecReduceAllSIMD {
applyVecReduceAllSIMD32   static inline scalar_t apply(const Op& vec_fun, const Vectorized<scalar_t>& acc_vec) {
33     return vec_reduce_all(vec_fun, acc_vec, Vectorized<scalar_t>::size());
34   }
35 };
36 
37 #if defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && !defined(C10_MOBILE)
38 #if defined(CPU_CAPABILITY_AVX2)
39 template <typename Op>
40 struct VecReduceAllSIMD<float, Op> {
41   static inline float apply(const Op& vec_fun, const Vectorized<float>& acc_vec) {
42     using Vec = Vectorized<float>;
43     Vec v = acc_vec;
44     // 128-bit shuffle
45     Vec v1 = _mm256_permute2f128_ps(v, v, 0x1);
46     v = vec_fun(v, v1);
47     // 64-bit shuffle
48     v1 = _mm256_shuffle_ps(v, v, 0x4E);
49     v = vec_fun(v, v1);
50     // 32-bit shuffle
51     v1 = _mm256_shuffle_ps(v, v, 0xB1);
52     v = vec_fun(v, v1);
53     return _mm256_cvtss_f32(v);
54   }
55 };
56 #endif // defined(CPU_CAPABILITY_AVX2)
57 #if defined(CPU_CAPABILITY_AVX512)
58 template <typename Op>
59 struct VecReduceAllSIMD<float, Op> {
60   static inline float apply(const Op& vec_fun, const Vectorized<float>& acc_vec) {
61     using Vec = Vectorized<float>;
62     Vec v = acc_vec;
63     // 256-bit shuffle
64     Vec v1 = _mm512_shuffle_f32x4(v, v, 0x4E);
65     v = vec_fun(v, v1);
66     // 128-bit shuffle
67     v1 = _mm512_shuffle_f32x4(v, v, 0xB1);
68     v = vec_fun(v, v1);
69     // 64-bit shuffle
70     v1 = _mm512_shuffle_ps(v, v, 0x4E);
71     v = vec_fun(v, v1);
72     // 32-bit shuffle
73     v1 = _mm512_shuffle_ps(v, v, 0xB1);
74     v = vec_fun(v, v1);
75     return _mm512_cvtss_f32(v);
76   }
77 };
78 #endif // defined(CPU_CAPABILITY_AVX512)
79 #endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && !defined(C10_MOBILE)
80 
81 #if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__)
82 template <typename Op>
83 struct VecReduceAllSIMD<float, Op> {
84   static inline float apply(const Op& vec_fun, const Vectorized<float>& acc_vec) {
85     using Vec = Vectorized<float>;
86     Vec v = acc_vec;
87 
88     // 128-bit shuffle: [a1, a2, a3, a4, a5, a6, a7, a8] -> [a5, a6, a7, a8, a1, a2, a3, a4]
89     Vec v1 = {v.get_high(), v.get_low()};
90     // [a1+a5, a2+a6, a3+a7, a4+a8, -, -, -, -] ('+' stands for the reduction function. Note that the last 4 elements are not required)
91     v = vec_fun(v, v1);
92 
93     // 64-bit shuffle: [a1+a5, a2+a6, a3+a7, a4+a8, -, -, -, -] -> [a3+a7, a4+a8, a1+a5, a2+a6, -, -, -, -]
94     float32x4_t v1_1 = vextq_f32(v.get_low(), v.get_low(), 2);
95     v1 = {v1_1, v1_1};
96     // [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -, -, -, -]
97     v = vec_fun(v, v1);
98 
99     // 32-bit shuffle: [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -, -, -, -] -> [a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, -, -, -, -]
100     v1_1 = vrev64q_f32(v.get_low());
101     v1 = {v1_1, v1_1};
102     // [a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, -, -, -, -]
103     v = vec_fun(v, v1);
104 
105     return v.get_low()[0];
106   }
107 };
108 #endif // defined(__aarch64__)
109 
110 template <typename scalar_t, typename Op>
111 inline scalar_t vec_reduce_all(const Op& vec_fun, const Vectorized<scalar_t>& acc_vec) {
112   return VecReduceAllSIMD<scalar_t, Op>::apply(vec_fun, acc_vec);
113 }
114 
115 template <typename scalar_t, typename Op,
116           typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
117 inline scalar_t reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) {
118   using Vec = vec::Vectorized<scalar_t>;
119   if (size < Vec::size())
120     return vec_reduce_all(vec_fun, Vec::loadu(data, size), size);
121   int64_t d = Vec::size();
122   Vec acc_vec = Vec::loadu(data);
123   for (; d < size - (size % Vec::size()); d += Vec::size()) {
124     Vec data_vec = Vec::loadu(data + d);
125     acc_vec = vec_fun(acc_vec, data_vec);
126   }
127   if (size - d > 0) {
128     Vec data_vec = Vec::loadu(data + d, size - d);
129     acc_vec = Vec::set(acc_vec, vec_fun(acc_vec, data_vec), size - d);
130   }
131   return vec_reduce_all(vec_fun, acc_vec);
132 }
133 
134 // similar to reduce_all, but reduces into two outputs
135 template <typename scalar_t, typename Op1, typename Op2,
136           typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
137 inline std::pair<scalar_t, scalar_t> reduce2_all(const Op1& vec_fun1, const Op2& vec_fun2,
138     const scalar_t* data, int64_t size) {
139   using Vec = vec::Vectorized<scalar_t>;
140   if (size < Vec::size()) {
141     auto loaded_data = Vec::loadu(data, size);
142     return std::pair<scalar_t, scalar_t>(
143       vec_reduce_all(vec_fun1, loaded_data, size),
144       vec_reduce_all(vec_fun2, loaded_data, size));
145   }
146   int64_t d = Vec::size();
147   Vec acc_vec1 = Vec::loadu(data);
148   Vec acc_vec2 = Vec::loadu(data);
149   for (; d < size - (size % Vec::size()); d += Vec::size()) {
150     Vec data_vec = Vec::loadu(data + d);
151     acc_vec1 = vec_fun1(acc_vec1, data_vec);
152     acc_vec2 = vec_fun2(acc_vec2, data_vec);
153   }
154   if (size - d > 0) {
155     Vec data_vec = Vec::loadu(data + d, size - d);
156     acc_vec1 = Vec::set(acc_vec1, vec_fun1(acc_vec1, data_vec), size - d);
157     acc_vec2 = Vec::set(acc_vec2, vec_fun2(acc_vec2, data_vec), size - d);
158   }
159   return std::pair<scalar_t, scalar_t>(
160     vec_reduce_all(vec_fun1, acc_vec1),
161     vec_reduce_all(vec_fun2, acc_vec2));
162 }
163 
164 template <typename scalar_t, typename MapOp, typename ReduceOp,
165           typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
166 inline scalar_t map_reduce_all(
167     const MapOp& map_fun,
168     const ReduceOp& red_fun,
169     const scalar_t* data,
170     int64_t size) {
171   using Vec = vec::Vectorized<scalar_t>;
172   if (size < Vec::size())
173     return vec_reduce_all(red_fun, map_fun(Vec::loadu(data, size)), size);
174   int64_t d = Vec::size();
175   Vec acc_vec = map_fun(Vec::loadu(data));
176   for (; d < size - (size % Vec::size()); d += Vec::size()) {
177     Vec data_vec = Vec::loadu(data + d);
178     data_vec = map_fun(data_vec);
179     acc_vec = red_fun(acc_vec, data_vec);
180   }
181   if (size - d > 0) {
182     Vec data_vec = Vec::loadu(data + d, size - d);
183     data_vec = map_fun(data_vec);
184     acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d);
185   }
186   return vec_reduce_all(red_fun, acc_vec);
187 }
188 
189 template <typename scalar_t, typename MapOp, typename ReduceOp,
190           typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
191 inline scalar_t map2_reduce_all(
192     const MapOp& map_fun,
193     const ReduceOp& red_fun,
194     const scalar_t* data,
195     const scalar_t* data2,
196     int64_t size) {
197   using Vec = vec::Vectorized<scalar_t>;
198   if (size < Vec::size()) {
199     Vec data_vec = Vec::loadu(data, size);
200     Vec data2_vec = Vec::loadu(data2, size);
201     data_vec = map_fun(data_vec, data2_vec);
202     return vec_reduce_all(red_fun, data_vec, size);
203   }
204   int64_t d = Vec::size();
205   Vec acc_vec = map_fun(Vec::loadu(data), Vec::loadu(data2));
206   for (; d < size - (size % Vec::size()); d += Vec::size()) {
207     Vec data_vec = Vec::loadu(data + d);
208     Vec data2_vec = Vec::loadu(data2 + d);
209     data_vec = map_fun(data_vec, data2_vec);
210     acc_vec = red_fun(acc_vec, data_vec);
211   }
212   if (size - d > 0) {
213     Vec data_vec = Vec::loadu(data + d, size - d);
214     Vec data2_vec = Vec::loadu(data2 + d, size - d);
215     data_vec = map_fun(data_vec, data2_vec);
216     acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d);
217   }
218   return vec_reduce_all(red_fun, acc_vec);
219 }
220 
221 template <typename scalar_t, typename MapOp, typename ReduceOp,
222           typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
223 inline scalar_t map3_reduce_all(
224     const MapOp& map_fun,
225     const ReduceOp& red_fun,
226     const scalar_t* data,
227     const scalar_t* data2,
228     const scalar_t* data3,
229     int64_t size) {
230   using Vec = vec::Vectorized<scalar_t>;
231   if (size < Vec::size()) {
232     Vec data_vec = Vec::loadu(data, size);
233     Vec data2_vec = Vec::loadu(data2, size);
234     Vec data3_vec = Vec::loadu(data3, size);
235     data_vec = map_fun(data_vec, data2_vec, data3_vec);
236     return vec_reduce_all(red_fun, data_vec, size);
237   }
238 
239   int64_t d = Vec::size();
240   Vec acc_vec = map_fun(Vec::loadu(data), Vec::loadu(data2), Vec::loadu(data3));
241   for (; d < size - (size % Vec::size()); d += Vec::size()) {
242     Vec data_vec = Vec::loadu(data + d);
243     Vec data2_vec = Vec::loadu(data2 + d);
244     Vec data3_vec = Vec::loadu(data3 + d);
245     data_vec = map_fun(data_vec, data2_vec, data3_vec);
246     acc_vec = red_fun(acc_vec, data_vec);
247   }
248   if (size - d > 0) {
249     Vec data_vec = Vec::loadu(data + d, size - d);
250     Vec data2_vec = Vec::loadu(data2 + d, size - d);
251     Vec data3_vec = Vec::loadu(data3 + d, size - d);
252     data_vec = map_fun(data_vec, data2_vec, data3_vec);
253     acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d);
254   }
255   return vec_reduce_all(red_fun, acc_vec);
256 }
257 
258 template <typename scalar_t, typename Op,
259           typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
260 inline void map(
261     const Op& vec_fun,
262     scalar_t* output_data,
263     const scalar_t* input_data,
264     int64_t size) {
265   using Vec = vec::Vectorized<scalar_t>;
266   int64_t d = 0;
267   for (; d < size - (size % Vec::size()); d += Vec::size()) {
268     Vec output_vec = vec_fun(Vec::loadu(input_data + d));
269     output_vec.store(output_data + d);
270   }
271   if (size - d > 0) {
272     Vec output_vec = vec_fun(Vec::loadu(input_data + d, size - d));
273     output_vec.store(output_data + d, size - d);
274   }
275 }
276 
277 template <typename scalar_t, typename Op,
278           typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
279 inline void map2(
280     const Op& vec_fun,
281     scalar_t* output_data,
282     const scalar_t* input_data,
283     const scalar_t* input_data2,
284     int64_t size) {
285   using Vec = vec::Vectorized<scalar_t>;
286   int64_t d = 0;
287   for (; d < size - (size % Vec::size()); d += Vec::size()) {
288     Vec data_vec = Vec::loadu(input_data + d);
289     Vec data_vec2 = Vec::loadu(input_data2 + d);
290     Vec output_vec = vec_fun(data_vec, data_vec2);
291     output_vec.store(output_data + d);
292   }
293   if (size - d > 0) {
294     Vec data_vec = Vec::loadu(input_data + d, size - d);
295     Vec data_vec2 = Vec::loadu(input_data2 + d, size - d);
296     Vec output_vec = vec_fun(data_vec, data_vec2);
297     output_vec.store(output_data + d, size - d);
298   }
299 }
300 
301 template <typename scalar_t, typename Op,
302           typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
303 inline void map3(
304     const Op& vec_fun,
305     scalar_t* output_data,
306     const scalar_t* input_data1,
307     const scalar_t* input_data2,
308     const scalar_t* input_data3,
309     int64_t size) {
310   using Vec = vec::Vectorized<scalar_t>;
311   int64_t d = 0;
312   for (; d < size - (size % Vec::size()); d += Vec::size()) {
313     Vec data_vec1 = Vec::loadu(input_data1 + d);
314     Vec data_vec2 = Vec::loadu(input_data2 + d);
315     Vec data_vec3 = Vec::loadu(input_data3 + d);
316     Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3);
317     output_vec.store(output_data + d);
318   }
319   if (size - d > 0) {
320     Vec data_vec1 = Vec::loadu(input_data1 + d, size - d);
321     Vec data_vec2 = Vec::loadu(input_data2 + d, size - d);
322     Vec data_vec3 = Vec::loadu(input_data3 + d, size - d);
323     Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3);
324     output_vec.store(output_data + d, size - d);
325   }
326 }
327 
328 template <typename scalar_t, typename Op,
329           typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
330 inline void map4(
331     const Op& vec_fun,
332     scalar_t* output_data,
333     const scalar_t* input_data1,
334     const scalar_t* input_data2,
335     const scalar_t* input_data3,
336     const scalar_t* input_data4,
337     int64_t size) {
338   using Vec = vec::Vectorized<scalar_t>;
339   int64_t d = 0;
340   for (; d < size - (size % Vec::size()); d += Vec::size()) {
341     Vec data_vec1 = Vec::loadu(input_data1 + d);
342     Vec data_vec2 = Vec::loadu(input_data2 + d);
343     Vec data_vec3 = Vec::loadu(input_data3 + d);
344     Vec data_vec4 = Vec::loadu(input_data4 + d);
345     Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3, data_vec4);
346     output_vec.store(output_data + d);
347   }
348   if (size - d > 0) {
349     Vec data_vec1 = Vec::loadu(input_data1 + d, size - d);
350     Vec data_vec2 = Vec::loadu(input_data2 + d, size - d);
351     Vec data_vec3 = Vec::loadu(input_data3 + d, size - d);
352     Vec data_vec4 = Vec::loadu(input_data4 + d, size - d);
353     Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3, data_vec4);
354     output_vec.store(output_data + d, size - d);
355   }
356 }
357 
358 } // namespace at::vec
359