/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]

#include <array>   // std::array, used by the scalar reduction fallback
#include <cstdint> // int64_t
#include <utility> // std::pair, returned by reduce2_all

#include <executorch/kernels/optimized/vec/vec.h>

namespace executorch {
namespace vec {

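// Horizontally reduces the first `size` lanes of `acc_vec` with `vec_fun`,
// folding one lane at a time through the vector op. This is the generic
// fallback used when no SIMD-specialized reduction is available.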
// slow path
template <typename scalar_t, typename Op>
inline scalar_t vec_reduce_all(
    const Op& vec_fun,
    vec::Vectorized<scalar_t> acc_vec,
    int64_t size) {
  using Vec = vec::Vectorized<scalar_t>;
  scalar_t acc_arr[Vec::size()];
  acc_vec.store(acc_arr);
  for (int64_t i = 1; i < size; ++i) {
    std::array<scalar_t, Vec::size()> acc_arr_next = {0};
    acc_arr_next[0] = acc_arr[i];
    Vec acc_vec_next = Vec::loadu(acc_arr_next.data());
    acc_vec = vec_fun(acc_vec, acc_vec_next);
  }
  acc_vec.store(acc_arr);
  return acc_arr[0];
}

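// Functor that reduces all lanes of a vector with `vec_fun`. The generic
// implementation defers to the element-by-element fallback above; the float
// specializations below use shuffle-based reductions where AVX2/AVX512 are
// available.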
template <typename scalar_t, typename Op>
struct VecReduceAllSIMD {
  static inline scalar_t apply(const Op& vec_fun, const Vectorized<scalar_t>& acc_vec) {
    return vec_reduce_all(vec_fun, acc_vec, Vectorized<scalar_t>::size());
  }
};

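// SIMD specializations for float: reduce across lanes with a logarithmic
// sequence of lane shuffles instead of the element-by-element fallback.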
#if defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && !defined(C10_MOBILE)
#if defined(CPU_CAPABILITY_AVX2)
template <typename Op>
struct VecReduceAllSIMD<float, Op> {
  static inline float apply(const Op& vec_fun, const Vectorized<float>& acc_vec) {
    using Vec = Vectorized<float>;
    Vec v = acc_vec;
    // 128-bit shuffle
    Vec v1 = _mm256_permute2f128_ps(v, v, 0x1);
    v = vec_fun(v, v1);
    // 64-bit shuffle
    v1 = _mm256_shuffle_ps(v, v, 0x4E);
    v = vec_fun(v, v1);
    // 32-bit shuffle
    v1 = _mm256_shuffle_ps(v, v, 0xB1);
    v = vec_fun(v, v1);
    return _mm256_cvtss_f32(v);
  }
};
#endif // defined(CPU_CAPABILITY_AVX2)
#if defined(CPU_CAPABILITY_AVX512)
template <typename Op>
struct VecReduceAllSIMD<float, Op> {
  static inline float apply(const Op& vec_fun, const Vectorized<float>& acc_vec) {
    using Vec = Vectorized<float>;
    Vec v = acc_vec;
    // 256-bit shuffle
    Vec v1 = _mm512_shuffle_f32x4(v, v, 0x4E);
    v = vec_fun(v, v1);
    // 128-bit shuffle
    v1 = _mm512_shuffle_f32x4(v, v, 0xB1);
    v = vec_fun(v, v1);
    // 64-bit shuffle
    v1 = _mm512_shuffle_ps(v, v, 0x4E);
    v = vec_fun(v, v1);
    // 32-bit shuffle
    v1 = _mm512_shuffle_ps(v, v, 0xB1);
    v = vec_fun(v, v1);
    return _mm512_cvtss_f32(v);
  }
};
#endif // defined(CPU_CAPABILITY_AVX512)
#endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && !defined(C10_MOBILE)

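// Reduces all lanes of `acc_vec` into a single scalar using `vec_fun`.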
template <typename scalar_t, typename Op>
inline scalar_t vec_reduce_all(const Op& vec_fun, const Vectorized<scalar_t>& acc_vec) {
  return VecReduceAllSIMD<scalar_t, Op>::apply(vec_fun, acc_vec);
}

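// Reduces `size` contiguous scalars in `data` to a single value with `vec_fun`,
// processing full vectors first and folding the partially filled tail via
// Vec::set so that out-of-range lanes do not contaminate the result.
//
// Illustrative usage (assumes Vectorized<float> provides operator+, as in the
// ATen-derived vec classes):
//   float total = reduce_all<float>(
//       [](const auto& x, const auto& y) { return x + y; }, buf, n);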
template <typename scalar_t, typename Op>
inline scalar_t reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) {
  using Vec = vec::Vectorized<scalar_t>;
  if (size < Vec::size())
    return vec_reduce_all(vec_fun, Vec::loadu(data, size), size);
  int64_t d = Vec::size();
  Vec acc_vec = Vec::loadu(data);
  for (; d < size - (size % Vec::size()); d += Vec::size()) {
    Vec data_vec = Vec::loadu(data + d);
    acc_vec = vec_fun(acc_vec, data_vec);
  }
  if (size - d > 0) {
    Vec data_vec = Vec::loadu(data + d, size - d);
    acc_vec = Vec::set(acc_vec, vec_fun(acc_vec, data_vec), size - d);
  }
  return vec_reduce_all(vec_fun, acc_vec);
}

// similar to reduce_all, but reduces into two outputs
template <typename scalar_t, typename Op1, typename Op2>
inline std::pair<scalar_t, scalar_t> reduce2_all(const Op1& vec_fun1, const Op2& vec_fun2,
    const scalar_t* data, int64_t size) {
  using Vec = vec::Vectorized<scalar_t>;
  if (size < Vec::size()) {
    auto loaded_data = Vec::loadu(data, size);
    return std::pair<scalar_t, scalar_t>(
      vec_reduce_all(vec_fun1, loaded_data, size),
      vec_reduce_all(vec_fun2, loaded_data, size));
  }
  int64_t d = Vec::size();
  Vec acc_vec1 = Vec::loadu(data);
  Vec acc_vec2 = Vec::loadu(data);
  for (; d < size - (size % Vec::size()); d += Vec::size()) {
    Vec data_vec = Vec::loadu(data + d);
    acc_vec1 = vec_fun1(acc_vec1, data_vec);
    acc_vec2 = vec_fun2(acc_vec2, data_vec);
  }
  if (size - d > 0) {
    Vec data_vec = Vec::loadu(data + d, size - d);
    acc_vec1 = Vec::set(acc_vec1, vec_fun1(acc_vec1, data_vec), size - d);
    acc_vec2 = Vec::set(acc_vec2, vec_fun2(acc_vec2, data_vec), size - d);
  }
  return std::pair<scalar_t, scalar_t>(
    vec_reduce_all(vec_fun1, acc_vec1),
    vec_reduce_all(vec_fun2, acc_vec2));
}

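// Applies `map_fun` element-wise to `data` and reduces the mapped values with
// `red_fun`, without materializing an intermediate buffer.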
template <typename scalar_t, typename MapOp, typename ReduceOp>
inline scalar_t map_reduce_all(
    const MapOp& map_fun,
    const ReduceOp& red_fun,
    const scalar_t* data,
    int64_t size) {
  using Vec = vec::Vectorized<scalar_t>;
  if (size < Vec::size())
    return vec_reduce_all(red_fun, map_fun(Vec::loadu(data, size)), size);
  int64_t d = Vec::size();
  Vec acc_vec = map_fun(Vec::loadu(data));
  for (; d < size - (size % Vec::size()); d += Vec::size()) {
    Vec data_vec = Vec::loadu(data + d);
    data_vec = map_fun(data_vec);
    acc_vec = red_fun(acc_vec, data_vec);
  }
  if (size - d > 0) {
    Vec data_vec = Vec::loadu(data + d, size - d);
    data_vec = map_fun(data_vec);
    acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d);
  }
  return vec_reduce_all(red_fun, acc_vec);
}

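// Same as map_reduce_all, but `map_fun` combines corresponding elements of
// `data` and `data2` before the reduction.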
template <typename scalar_t, typename MapOp, typename ReduceOp>
inline scalar_t map2_reduce_all(
    const MapOp& map_fun,
    const ReduceOp& red_fun,
    const scalar_t* data,
    const scalar_t* data2,
    int64_t size) {
  using Vec = vec::Vectorized<scalar_t>;
  if (size < Vec::size()) {
    Vec data_vec = Vec::loadu(data, size);
    Vec data2_vec = Vec::loadu(data2, size);
    data_vec = map_fun(data_vec, data2_vec);
    return vec_reduce_all(red_fun, data_vec, size);
  }
  int64_t d = Vec::size();
  Vec acc_vec = map_fun(Vec::loadu(data), Vec::loadu(data2));
  for (; d < size - (size % Vec::size()); d += Vec::size()) {
    Vec data_vec = Vec::loadu(data + d);
    Vec data2_vec = Vec::loadu(data2 + d);
    data_vec = map_fun(data_vec, data2_vec);
    acc_vec = red_fun(acc_vec, data_vec);
  }
  if (size - d > 0) {
    Vec data_vec = Vec::loadu(data + d, size - d);
    Vec data2_vec = Vec::loadu(data2 + d, size - d);
    data_vec = map_fun(data_vec, data2_vec);
    acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d);
  }
  return vec_reduce_all(red_fun, acc_vec);
}

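// Same as map_reduce_all, but `map_fun` combines corresponding elements of
// `data`, `data2`, and `data3` before the reduction.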
template <typename scalar_t, typename MapOp, typename ReduceOp>
inline scalar_t map3_reduce_all(
    const MapOp& map_fun,
    const ReduceOp& red_fun,
    const scalar_t* data,
    const scalar_t* data2,
    const scalar_t* data3,
    int64_t size) {
  using Vec = vec::Vectorized<scalar_t>;
  if (size < Vec::size()) {
    Vec data_vec = Vec::loadu(data, size);
    Vec data2_vec = Vec::loadu(data2, size);
    Vec data3_vec = Vec::loadu(data3, size);
    data_vec = map_fun(data_vec, data2_vec, data3_vec);
    return vec_reduce_all(red_fun, data_vec, size);
  }

  int64_t d = Vec::size();
  Vec acc_vec = map_fun(Vec::loadu(data), Vec::loadu(data2), Vec::loadu(data3));
  for (; d < size - (size % Vec::size()); d += Vec::size()) {
    Vec data_vec = Vec::loadu(data + d);
    Vec data2_vec = Vec::loadu(data2 + d);
    Vec data3_vec = Vec::loadu(data3 + d);
    data_vec = map_fun(data_vec, data2_vec, data3_vec);
    acc_vec = red_fun(acc_vec, data_vec);
  }
  if (size - d > 0) {
    Vec data_vec = Vec::loadu(data + d, size - d);
    Vec data2_vec = Vec::loadu(data2 + d, size - d);
    Vec data3_vec = Vec::loadu(data3 + d, size - d);
    data_vec = map_fun(data_vec, data2_vec, data3_vec);
    acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d);
  }
  return vec_reduce_all(red_fun, acc_vec);
}

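// Applies `vec_fun` to each vector of `input_data` and stores the result in
// `output_data`; the tail that does not fill a whole vector is handled with
// the partial loadu/store overloads.
//
// Illustrative usage (assumes Vectorized<float> provides operator* and a
// broadcasting scalar constructor, as in the ATen-derived vec classes):
//   map<float>(
//       [](const auto& x) { return x * Vectorized<float>(2.f); }, out, in, n);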
template <typename scalar_t, typename Op>
inline void map(
    const Op& vec_fun,
    scalar_t* output_data,
    const scalar_t* input_data,
    int64_t size) {
  using Vec = vec::Vectorized<scalar_t>;
  int64_t d = 0;
  for (; d < size - (size % Vec::size()); d += Vec::size()) {
    Vec output_vec = vec_fun(Vec::loadu(input_data + d));
    output_vec.store(output_data + d);
  }
  if (size - d > 0) {
    Vec output_vec = vec_fun(Vec::loadu(input_data + d, size - d));
    output_vec.store(output_data + d, size - d);
  }
}

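// Element-wise over two inputs:
// output_data[i] = vec_fun(input_data[i], input_data2[i]) for i in [0, size).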
template <typename scalar_t, typename Op>
inline void map2(
    const Op& vec_fun,
    scalar_t* output_data,
    const scalar_t* input_data,
    const scalar_t* input_data2,
    int64_t size) {
  using Vec = vec::Vectorized<scalar_t>;
  int64_t d = 0;
  for (; d < size - (size % Vec::size()); d += Vec::size()) {
    Vec data_vec = Vec::loadu(input_data + d);
    Vec data_vec2 = Vec::loadu(input_data2 + d);
    Vec output_vec = vec_fun(data_vec, data_vec2);
    output_vec.store(output_data + d);
  }
  if (size - d > 0) {
    Vec data_vec = Vec::loadu(input_data + d, size - d);
    Vec data_vec2 = Vec::loadu(input_data2 + d, size - d);
    Vec output_vec = vec_fun(data_vec, data_vec2);
    output_vec.store(output_data + d, size - d);
  }
}

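// Element-wise over three inputs:
// output_data[i] = vec_fun(input_data1[i], input_data2[i], input_data3[i]).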
template <typename scalar_t, typename Op>
inline void map3(
    const Op& vec_fun,
    scalar_t* output_data,
    const scalar_t* input_data1,
    const scalar_t* input_data2,
    const scalar_t* input_data3,
    int64_t size) {
  using Vec = vec::Vectorized<scalar_t>;
  int64_t d = 0;
  for (; d < size - (size % Vec::size()); d += Vec::size()) {
    Vec data_vec1 = Vec::loadu(input_data1 + d);
    Vec data_vec2 = Vec::loadu(input_data2 + d);
    Vec data_vec3 = Vec::loadu(input_data3 + d);
    Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3);
    output_vec.store(output_data + d);
  }
  if (size - d > 0) {
    Vec data_vec1 = Vec::loadu(input_data1 + d, size - d);
    Vec data_vec2 = Vec::loadu(input_data2 + d, size - d);
    Vec data_vec3 = Vec::loadu(input_data3 + d, size - d);
    Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3);
    output_vec.store(output_data + d, size - d);
  }
}

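// Element-wise over four inputs:
// output_data[i] = vec_fun(input_data1[i], input_data2[i], input_data3[i], input_data4[i]).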
template <typename scalar_t, typename Op>
inline void map4(
    const Op& vec_fun,
    scalar_t* output_data,
    const scalar_t* input_data1,
    const scalar_t* input_data2,
    const scalar_t* input_data3,
    const scalar_t* input_data4,
    int64_t size) {
  using Vec = vec::Vectorized<scalar_t>;
  int64_t d = 0;
  for (; d < size - (size % Vec::size()); d += Vec::size()) {
    Vec data_vec1 = Vec::loadu(input_data1 + d);
    Vec data_vec2 = Vec::loadu(input_data2 + d);
    Vec data_vec3 = Vec::loadu(input_data3 + d);
    Vec data_vec4 = Vec::loadu(input_data4 + d);
    Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3, data_vec4);
    output_vec.store(output_data + d);
  }
  if (size - d > 0) {
    Vec data_vec1 = Vec::loadu(input_data1 + d, size - d);
    Vec data_vec2 = Vec::loadu(input_data2 + d, size - d);
    Vec data_vec3 = Vec::loadu(input_data3 + d, size - d);
    Vec data_vec4 = Vec::loadu(input_data4 + d, size - d);
    Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3, data_vec4);
    output_vec.store(output_data + d, size - d);
  }
}

// This function implements a broadcasting binary operation on two tensors,
// where the lhs tensor is treated as having shape [outer_size, broadcast_size, inner_size]
// and the rhs tensor is treated as having shape [outer_size, 1, inner_size];
// the middle dimension (dim 1) is the broadcasting dimension.
// Any two N-dimensional tensors broadcasting on dim = broadcast_dim,
// with 0 < broadcast_dim < N-1, can be mapped onto this formula.
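// Illustrative example: for an lhs viewed as [2, 3, 4] and an rhs viewed as
// [2, 1, 4], call with outer_size = 2, broadcast_size = 3, inner_size = 4;
// each of the 2 * 3 length-4 rows of lhs is combined element-wise with the
// single matching row of rhs.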
template <typename scalar_t, typename Op>
inline void broadcasting_map_3d_and_unsqueezed_3d(
    const Op& vec_fun,
    scalar_t* output_data,
    const scalar_t* lhs,
    const scalar_t* rhs,
    int64_t outer_size,
    int64_t broadcast_size,
    int64_t inner_size) {
  using Vec = vec::Vectorized<scalar_t>;
  int64_t outer_stride_lhs = inner_size * broadcast_size;
  int64_t outer_stride_rhs = inner_size;
  int64_t broadcast_stride_lhs = inner_size;
  for (int64_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
    const scalar_t* lhs_outer = lhs + outer_idx * outer_stride_lhs;
    scalar_t* output_data_row = output_data + outer_idx * outer_stride_lhs;
    const scalar_t* rhs_outer = rhs + outer_idx * outer_stride_rhs;
    for (int64_t broadcast_idx = 0; broadcast_idx < broadcast_size; ++broadcast_idx) {
      const scalar_t* lhs_outer_2 = lhs_outer + broadcast_idx * broadcast_stride_lhs;
      scalar_t* output_data_row_2 = output_data_row + broadcast_idx * broadcast_stride_lhs;
      int64_t inner_idx = 0;
      for (; inner_idx < inner_size - (inner_size % Vec::size()); inner_idx += Vec::size()) {
        Vec data_vec = Vec::loadu(lhs_outer_2 + inner_idx);
        Vec data_vec2 = Vec::loadu(rhs_outer + inner_idx);
        Vec output_vec = vec_fun(data_vec, data_vec2);
        output_vec.store(output_data_row_2 + inner_idx);
      }
      if (inner_size - inner_idx > 0) {
        Vec data_vec = Vec::loadu(lhs_outer_2 + inner_idx, inner_size - inner_idx);
        Vec data_vec2 = Vec::loadu(rhs_outer + inner_idx, inner_size - inner_idx);
        Vec output_vec = vec_fun(data_vec, data_vec2);
        output_vec.store(output_data_row_2 + inner_idx, inner_size - inner_idx);
      }
    }
  }
}

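// Treats input_data as a 2-D [size, size2] lhs and broadcasts the 1-D
// input_data2 of length size2 across its rows; implemented as the 3-D case
// above with outer_size = 1.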
template <typename scalar_t, typename Op>
inline void broadcasting_map_2d_by_1d(
    const Op& vec_fun,
    scalar_t* output_data,
    const scalar_t* input_data,
    const scalar_t* input_data2,
    int64_t size,
    int64_t size2) {
  broadcasting_map_3d_and_unsqueezed_3d(vec_fun, output_data, input_data, input_data2, 1, size, size2);
}

/*
The following function implements a broadcasting binary operation on two tensors,
where the lhs tensor is treated as having shape [outer_size, broadcast_size] and
the rhs tensor as having shape [outer_size, 1].
Any two N-dimensional tensors with
lhs size = [lhs0, lhs1, ..., lhsN-1] and rhs size = [rhs0, rhs1, ..., 1]
can be mapped to this formula by viewing the two tensors as
lhs size = [lhs0 * lhs1 * ... * lhsN-2, lhsN-1]
rhs size = [rhs0 * rhs1 * ... * rhsN-2, 1]
*/
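// Illustrative example: for an lhs viewed as [6, 8] and an rhs viewed as
// [6, 1], call with outer_size = 6, broadcast_size = 8; each scalar rhs[i]
// is broadcast across row i of lhs.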
template <typename scalar_t, typename Op>
inline void broadcasting_map_broadcast_last_dim(
    const Op& vec_fun,
    scalar_t* output_data,
    const scalar_t* lhs,
    const scalar_t* rhs,
    int64_t outer_size,
    int64_t broadcast_size) {
  using Vec = vec::Vectorized<scalar_t>;
  int64_t outer_stride_lhs = broadcast_size;
  for (int64_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
    const scalar_t* lhs_outer = lhs + outer_idx * outer_stride_lhs;
    scalar_t* output_data_row = output_data + outer_idx * outer_stride_lhs;
    int64_t inner_idx = 0;
    Vec data_vec2 = Vec(rhs[outer_idx]);
    for (; inner_idx < broadcast_size - (broadcast_size % Vec::size()); inner_idx += Vec::size()) {
      Vec data_vec = Vec::loadu(lhs_outer + inner_idx);
      Vec output_vec = vec_fun(data_vec, data_vec2);
      output_vec.store(output_data_row + inner_idx);
    }
    if (broadcast_size - inner_idx > 0) {
      Vec data_vec = Vec::loadu(lhs_outer + inner_idx, broadcast_size - inner_idx);
      Vec output_vec = vec_fun(data_vec, data_vec2);
      output_vec.store(output_data_row + inner_idx, broadcast_size - inner_idx);
    }
  }
}

} // namespace vec
} // namespace executorch