/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]

#include <executorch/kernels/optimized/vec/vec.h>

#include <array>
#include <cstdint>
#include <utility>

namespace executorch {
namespace vec {

// Slow path: combine the first `size` lanes of acc_vec into a single scalar by
// folding them into lane 0 one element at a time.
template <typename scalar_t, typename Op>
inline scalar_t vec_reduce_all(
    const Op& vec_fun,
    vec::Vectorized<scalar_t> acc_vec,
    int64_t size) {
  using Vec = vec::Vectorized<scalar_t>;
  scalar_t acc_arr[Vec::size()];
  acc_vec.store(acc_arr);
  for (int64_t i = 1; i < size; ++i) {
    // Build a vector whose lane 0 holds lane i of the original accumulator,
    // then fold it into the running accumulator; only lane 0 of the result is
    // ultimately read.
    std::array<scalar_t, Vec::size()> acc_arr_next = {0};
    acc_arr_next[0] = acc_arr[i];
    Vec acc_vec_next = Vec::loadu(acc_arr_next.data());
    acc_vec = vec_fun(acc_vec, acc_vec_next);
  }
  acc_vec.store(acc_arr);
  return acc_arr[0];
}

template <typename scalar_t, typename Op>
struct VecReduceAllSIMD {
  static inline scalar_t apply(const Op& vec_fun, const Vectorized<scalar_t>& acc_vec) {
    return vec_reduce_all(vec_fun, acc_vec, Vectorized<scalar_t>::size());
  }
};

#if defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && !defined(C10_MOBILE)
#if defined(CPU_CAPABILITY_AVX2)
template <typename Op>
struct VecReduceAllSIMD<float, Op> {
  static inline float apply(const Op& vec_fun, const Vectorized<float>& acc_vec) {
    using Vec = Vectorized<float>;
    Vec v = acc_vec;
    // 128-bit shuffle
    Vec v1 = _mm256_permute2f128_ps(v, v, 0x1);
    v = vec_fun(v, v1);
    // 64-bit shuffle
    v1 = _mm256_shuffle_ps(v, v, 0x4E);
    v = vec_fun(v, v1);
    // 32-bit shuffle
    v1 = _mm256_shuffle_ps(v, v, 0xB1);
    v = vec_fun(v, v1);
    return _mm256_cvtss_f32(v);
  }
};
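// Illustrative trace for the AVX2 specialization above (comment only, not part
// of the build): with an addition op and input lanes {1, 2, 3, 4, 5, 6, 7, 8},
// the three shuffle/combine steps produce
//   {6, 8, 10, 12, 6, 8, 10, 12}  after the 128-bit swap,
//   {16, 20, 16, 20, ...}         after the 64-bit swap,
//   {36, 36, ...}                 after the 32-bit swap,
// so lane 0 holds the full sum (36), which _mm256_cvtss_f32 extracts.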
#endif // defined(CPU_CAPABILITY_AVX2)
#if defined(CPU_CAPABILITY_AVX512)
template <typename Op>
struct VecReduceAllSIMD<float, Op> {
  static inline float apply(const Op& vec_fun, const Vectorized<float>& acc_vec) {
    using Vec = Vectorized<float>;
    Vec v = acc_vec;
    // 256-bit shuffle
    Vec v1 = _mm512_shuffle_f32x4(v, v, 0x4E);
    v = vec_fun(v, v1);
    // 128-bit shuffle
    v1 = _mm512_shuffle_f32x4(v, v, 0xB1);
    v = vec_fun(v, v1);
    // 64-bit shuffle
    v1 = _mm512_shuffle_ps(v, v, 0x4E);
    v = vec_fun(v, v1);
    // 32-bit shuffle
    v1 = _mm512_shuffle_ps(v, v, 0xB1);
    v = vec_fun(v, v1);
    return _mm512_cvtss_f32(v);
  }
};
#endif // defined(CPU_CAPABILITY_AVX512)
#endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && !defined(C10_MOBILE)

template <typename scalar_t, typename Op>
inline scalar_t vec_reduce_all(const Op& vec_fun, const Vectorized<scalar_t>& acc_vec) {
  return VecReduceAllSIMD<scalar_t, Op>::apply(vec_fun, acc_vec);
}

template <typename scalar_t, typename Op>
inline scalar_t reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) {
  using Vec = vec::Vectorized<scalar_t>;
  if (size < Vec::size())
    return vec_reduce_all(vec_fun, Vec::loadu(data, size), size);
  int64_t d = Vec::size();
  Vec acc_vec = Vec::loadu(data);
  for (; d < size - (size % Vec::size()); d += Vec::size()) {
    Vec data_vec = Vec::loadu(data + d);
    acc_vec = vec_fun(acc_vec, data_vec);
  }
  if (size - d > 0) {
    // Tail handling: combine the partially loaded tail with the accumulator,
    // then keep only the first (size - d) lanes of the combined result so the
    // untouched accumulator lanes stay intact.
    Vec data_vec = Vec::loadu(data + d, size - d);
    acc_vec = Vec::set(acc_vec, vec_fun(acc_vec, data_vec), size - d);
  }
  return vec_reduce_all(vec_fun, acc_vec);
}
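
// Usage sketch for reduce_all (illustrative comment, not part of the API
// surface): summing a float buffer, assuming Vectorized<float> provides
// operator+ as in the bundled vec.h; data_ptr and n are placeholder names.
//
//   float sum = executorch::vec::reduce_all<float>(
//       [](auto& x, auto& y) { return x + y; },
//       data_ptr,  // const float*
//       n);        // int64_t element count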

// similar to reduce_all, but reduces into two outputs
template <typename scalar_t, typename Op1, typename Op2>
inline std::pair<scalar_t, scalar_t> reduce2_all(const Op1& vec_fun1, const Op2& vec_fun2,
    const scalar_t* data, int64_t size) {
  using Vec = vec::Vectorized<scalar_t>;
  if (size < Vec::size()) {
    auto loaded_data = Vec::loadu(data, size);
    return std::pair<scalar_t, scalar_t>(
      vec_reduce_all(vec_fun1, loaded_data, size),
      vec_reduce_all(vec_fun2, loaded_data, size));
  }
  int64_t d = Vec::size();
  Vec acc_vec1 = Vec::loadu(data);
  Vec acc_vec2 = Vec::loadu(data);
  for (; d < size - (size % Vec::size()); d += Vec::size()) {
    Vec data_vec = Vec::loadu(data + d);
    acc_vec1 = vec_fun1(acc_vec1, data_vec);
    acc_vec2 = vec_fun2(acc_vec2, data_vec);
  }
  if (size - d > 0) {
    Vec data_vec = Vec::loadu(data + d, size - d);
    acc_vec1 = Vec::set(acc_vec1, vec_fun1(acc_vec1, data_vec), size - d);
    acc_vec2 = Vec::set(acc_vec2, vec_fun2(acc_vec2, data_vec), size - d);
  }
  return std::pair<scalar_t, scalar_t>(
    vec_reduce_all(vec_fun1, acc_vec1),
    vec_reduce_all(vec_fun2, acc_vec2));
}
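
// Usage sketch for reduce2_all (illustrative comment): computing the minimum
// and maximum in a single pass. This assumes the bundled vec.h exposes
// vec::minimum / vec::maximum helpers for Vectorized<float> (as ATen-style vec
// headers do); substitute whatever elementwise min/max your build provides.
//
//   std::pair<float, float> min_max = executorch::vec::reduce2_all<float>(
//       [](auto& x, auto& y) { return vec::minimum(x, y); },
//       [](auto& x, auto& y) { return vec::maximum(x, y); },
//       data_ptr, n);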

template <typename scalar_t, typename MapOp, typename ReduceOp>
inline scalar_t map_reduce_all(
    const MapOp& map_fun,
    const ReduceOp& red_fun,
    const scalar_t* data,
    int64_t size) {
  using Vec = vec::Vectorized<scalar_t>;
  if (size < Vec::size())
    return vec_reduce_all(red_fun, map_fun(Vec::loadu(data, size)), size);
  int64_t d = Vec::size();
  Vec acc_vec = map_fun(Vec::loadu(data));
  for (; d < size - (size % Vec::size()); d += Vec::size()) {
    Vec data_vec = Vec::loadu(data + d);
    data_vec = map_fun(data_vec);
    acc_vec = red_fun(acc_vec, data_vec);
  }
  if (size - d > 0) {
    Vec data_vec = Vec::loadu(data + d, size - d);
    data_vec = map_fun(data_vec);
    acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d);
  }
  return vec_reduce_all(red_fun, acc_vec);
}
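
// Usage sketch for map_reduce_all (illustrative comment): sum of squares,
// e.g. as the building block of an L2 norm, assuming Vectorized<float>
// provides operator* and operator+.
//
//   float sum_sq = executorch::vec::map_reduce_all<float>(
//       [](auto& x) { return x * x; },           // map: square each lane
//       [](auto& x, auto& y) { return x + y; },  // reduce: add
//       data_ptr, n);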

template <typename scalar_t, typename MapOp, typename ReduceOp>
inline scalar_t map2_reduce_all(
    const MapOp& map_fun,
    const ReduceOp& red_fun,
    const scalar_t* data,
    const scalar_t* data2,
    int64_t size) {
  using Vec = vec::Vectorized<scalar_t>;
  if (size < Vec::size()) {
    Vec data_vec = Vec::loadu(data, size);
    Vec data2_vec = Vec::loadu(data2, size);
    data_vec = map_fun(data_vec, data2_vec);
    return vec_reduce_all(red_fun, data_vec, size);
  }
  int64_t d = Vec::size();
  Vec acc_vec = map_fun(Vec::loadu(data), Vec::loadu(data2));
  for (; d < size - (size % Vec::size()); d += Vec::size()) {
    Vec data_vec = Vec::loadu(data + d);
    Vec data2_vec = Vec::loadu(data2 + d);
    data_vec = map_fun(data_vec, data2_vec);
    acc_vec = red_fun(acc_vec, data_vec);
  }
  if (size - d > 0) {
    Vec data_vec = Vec::loadu(data + d, size - d);
    Vec data2_vec = Vec::loadu(data2 + d, size - d);
    data_vec = map_fun(data_vec, data2_vec);
    acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d);
  }
  return vec_reduce_all(red_fun, acc_vec);
}
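
// Usage sketch for map2_reduce_all (illustrative comment): dot product of two
// buffers of equal length, assuming Vectorized<float> provides operator* and
// operator+.
//
//   float dot = executorch::vec::map2_reduce_all<float>(
//       [](auto& x, auto& y) { return x * y; },  // map: elementwise product
//       [](auto& x, auto& y) { return x + y; },  // reduce: add
//       a_ptr, b_ptr, n);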

template <typename scalar_t, typename MapOp, typename ReduceOp>
inline scalar_t map3_reduce_all(
    const MapOp& map_fun,
    const ReduceOp& red_fun,
    const scalar_t* data,
    const scalar_t* data2,
    const scalar_t* data3,
    int64_t size) {
  using Vec = vec::Vectorized<scalar_t>;
  if (size < Vec::size()) {
    Vec data_vec = Vec::loadu(data, size);
    Vec data2_vec = Vec::loadu(data2, size);
    Vec data3_vec = Vec::loadu(data3, size);
    data_vec = map_fun(data_vec, data2_vec, data3_vec);
    return vec_reduce_all(red_fun, data_vec, size);
  }

  int64_t d = Vec::size();
  Vec acc_vec = map_fun(Vec::loadu(data), Vec::loadu(data2), Vec::loadu(data3));
  for (; d < size - (size % Vec::size()); d += Vec::size()) {
    Vec data_vec = Vec::loadu(data + d);
    Vec data2_vec = Vec::loadu(data2 + d);
    Vec data3_vec = Vec::loadu(data3 + d);
    data_vec = map_fun(data_vec, data2_vec, data3_vec);
    acc_vec = red_fun(acc_vec, data_vec);
  }
  if (size - d > 0) {
    Vec data_vec = Vec::loadu(data + d, size - d);
    Vec data2_vec = Vec::loadu(data2 + d, size - d);
    Vec data3_vec = Vec::loadu(data3 + d, size - d);
    data_vec = map_fun(data_vec, data2_vec, data3_vec);
    acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d);
  }
  return vec_reduce_all(red_fun, acc_vec);
}

template <typename scalar_t, typename Op>
inline void map(
    const Op& vec_fun,
    scalar_t* output_data,
    const scalar_t* input_data,
    int64_t size) {
  using Vec = vec::Vectorized<scalar_t>;
  int64_t d = 0;
  for (; d < size - (size % Vec::size()); d += Vec::size()) {
    Vec output_vec = vec_fun(Vec::loadu(input_data + d));
    output_vec.store(output_data + d);
  }
  if (size - d > 0) {
    Vec output_vec = vec_fun(Vec::loadu(input_data + d, size - d));
    output_vec.store(output_data + d, size - d);
  }
}
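
// Usage sketch for map (illustrative comment): scaling every element by 2,
// assuming Vectorized<float> has a broadcast constructor and operator* as in
// the bundled vec.h.
//
//   executorch::vec::map<float>(
//       [](auto& x) { return x * executorch::vec::Vectorized<float>(2.0f); },
//       out_ptr, in_ptr, n);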

template <typename scalar_t, typename Op>
inline void map2(
    const Op& vec_fun,
    scalar_t* output_data,
    const scalar_t* input_data,
    const scalar_t* input_data2,
    int64_t size) {
  using Vec = vec::Vectorized<scalar_t>;
  int64_t d = 0;
  for (; d < size - (size % Vec::size()); d += Vec::size()) {
    Vec data_vec = Vec::loadu(input_data + d);
    Vec data_vec2 = Vec::loadu(input_data2 + d);
    Vec output_vec = vec_fun(data_vec, data_vec2);
    output_vec.store(output_data + d);
  }
  if (size - d > 0) {
    Vec data_vec = Vec::loadu(input_data + d, size - d);
    Vec data_vec2 = Vec::loadu(input_data2 + d, size - d);
    Vec output_vec = vec_fun(data_vec, data_vec2);
    output_vec.store(output_data + d, size - d);
  }
}
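
// Usage sketch for map2 (illustrative comment): elementwise addition of two
// equally sized buffers, assuming Vectorized<float> provides operator+.
//
//   executorch::vec::map2<float>(
//       [](auto& x, auto& y) { return x + y; },
//       out_ptr, a_ptr, b_ptr, n);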

template <typename scalar_t, typename Op>
inline void map3(
    const Op& vec_fun,
    scalar_t* output_data,
    const scalar_t* input_data1,
    const scalar_t* input_data2,
    const scalar_t* input_data3,
    int64_t size) {
  using Vec = vec::Vectorized<scalar_t>;
  int64_t d = 0;
  for (; d < size - (size % Vec::size()); d += Vec::size()) {
    Vec data_vec1 = Vec::loadu(input_data1 + d);
    Vec data_vec2 = Vec::loadu(input_data2 + d);
    Vec data_vec3 = Vec::loadu(input_data3 + d);
    Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3);
    output_vec.store(output_data + d);
  }
  if (size - d > 0) {
    Vec data_vec1 = Vec::loadu(input_data1 + d, size - d);
    Vec data_vec2 = Vec::loadu(input_data2 + d, size - d);
    Vec data_vec3 = Vec::loadu(input_data3 + d, size - d);
    Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3);
    output_vec.store(output_data + d, size - d);
  }
}

template <typename scalar_t, typename Op>
inline void map4(
    const Op& vec_fun,
    scalar_t* output_data,
    const scalar_t* input_data1,
    const scalar_t* input_data2,
    const scalar_t* input_data3,
    const scalar_t* input_data4,
    int64_t size) {
  using Vec = vec::Vectorized<scalar_t>;
  int64_t d = 0;
  for (; d < size - (size % Vec::size()); d += Vec::size()) {
    Vec data_vec1 = Vec::loadu(input_data1 + d);
    Vec data_vec2 = Vec::loadu(input_data2 + d);
    Vec data_vec3 = Vec::loadu(input_data3 + d);
    Vec data_vec4 = Vec::loadu(input_data4 + d);
    Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3, data_vec4);
    output_vec.store(output_data + d);
  }
  if (size - d > 0) {
    Vec data_vec1 = Vec::loadu(input_data1 + d, size - d);
    Vec data_vec2 = Vec::loadu(input_data2 + d, size - d);
    Vec data_vec3 = Vec::loadu(input_data3 + d, size - d);
    Vec data_vec4 = Vec::loadu(input_data4 + d, size - d);
    Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3, data_vec4);
    output_vec.store(output_data + d, size - d);
  }
}

// This function implements a broadcasting binary operation on two tensors,
// where the lhs tensor is treated as having shape [outer_size, broadcast_size, inner_size]
// and the rhs tensor is treated as having shape [outer_size, 1, inner_size].
// The middle dimension (dim 1) is the broadcast dimension.
// Any two N-dimensional tensors broadcasting over a single dim = broadcast_dim,
// with 0 < broadcast_dim < N-1, can be mapped onto this 3D view.
template <typename scalar_t, typename Op>
inline void broadcasting_map_3d_and_unsqueezed_3d(
    const Op& vec_fun,
    scalar_t* output_data,
    const scalar_t* lhs,
    const scalar_t* rhs,
    int64_t outer_size,
    int64_t broadcast_size,
    int64_t inner_size) {
  using Vec = vec::Vectorized<scalar_t>;
  int64_t outer_stride_lhs = inner_size * broadcast_size;
  int64_t outer_stride_rhs = inner_size;
  int64_t broadcast_stride_lhs = inner_size;
  for (int64_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
    const scalar_t* lhs_outer = lhs + outer_idx * outer_stride_lhs;
    scalar_t* output_data_row = output_data + outer_idx * outer_stride_lhs;
    const scalar_t* rhs_outer = rhs + outer_idx * outer_stride_rhs;
    for (int64_t broadcast_idx = 0; broadcast_idx < broadcast_size; ++broadcast_idx) {
      const scalar_t* lhs_outer_2 = lhs_outer + broadcast_idx * broadcast_stride_lhs;
      scalar_t* output_data_row_2 = output_data_row + broadcast_idx * broadcast_stride_lhs;
      int64_t inner_idx = 0;
      for (; inner_idx < inner_size - (inner_size % Vec::size()); inner_idx += Vec::size()) {
        Vec data_vec = Vec::loadu(lhs_outer_2 + inner_idx);
        Vec data_vec2 = Vec::loadu(rhs_outer + inner_idx);
        Vec output_vec = vec_fun(data_vec, data_vec2);
        output_vec.store(output_data_row_2 + inner_idx);
      }
      if (inner_size - inner_idx > 0) {
        Vec data_vec = Vec::loadu(lhs_outer_2 + inner_idx, inner_size - inner_idx);
        Vec data_vec2 = Vec::loadu(rhs_outer + inner_idx, inner_size - inner_idx);
        Vec output_vec = vec_fun(data_vec, data_vec2);
        output_vec.store(output_data_row_2 + inner_idx, inner_size - inner_idx);
      }
    }
  }
}
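
// Usage sketch for broadcasting_map_3d_and_unsqueezed_3d (illustrative
// comment): adding an rhs viewed as [4, 1, 8] to an lhs viewed as [4, 16, 8],
// i.e. broadcasting over dim 1, assuming Vectorized<float> provides operator+.
//
//   executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<float>(
//       [](auto& x, auto& y) { return x + y; },
//       out_ptr, lhs_ptr, rhs_ptr,
//       /*outer_size=*/4, /*broadcast_size=*/16, /*inner_size=*/8);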

template <typename scalar_t, typename Op>
inline void broadcasting_map_2d_by_1d(
    const Op& vec_fun,
    scalar_t* output_data,
    const scalar_t* input_data,
    const scalar_t* input_data2,
    int64_t size,
    int64_t size2) {
  broadcasting_map_3d_and_unsqueezed_3d(
      vec_fun, output_data, input_data, input_data2, 1, size, size2);
}

/*
The following function implements a broadcasting binary operation on two tensors,
where the lhs tensor is treated as having shape [outer_size, broadcast_size] and
the rhs tensor is treated as having shape [outer_size, 1].
Any two N-dimensional tensors with
lhs size = [lhs0, lhs1, ..., lhsN-1] and rhs size = [rhs0, rhs1, ..., 1]
can be mapped to this form by viewing them as
lhs size = [lhs0 * lhs1 * ... * lhsN-2, lhsN-1]
rhs size = [rhs0 * rhs1 * ... * rhsN-2, 1]
*/
template <typename scalar_t, typename Op>
inline void broadcasting_map_broadcast_last_dim(
    const Op& vec_fun,
    scalar_t* output_data,
    const scalar_t* lhs,
    const scalar_t* rhs,
    int64_t outer_size,
    int64_t broadcast_size) {
  using Vec = vec::Vectorized<scalar_t>;
  int64_t outer_stride_lhs = broadcast_size;
  for (int64_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
    const scalar_t* lhs_outer = lhs + outer_idx * outer_stride_lhs;
    scalar_t* output_data_row = output_data + outer_idx * outer_stride_lhs;
    int64_t inner_idx = 0;
    Vec data_vec2 = Vec(rhs[outer_idx]);
    for (; inner_idx < broadcast_size - (broadcast_size % Vec::size()); inner_idx += Vec::size()) {
      Vec data_vec = Vec::loadu(lhs_outer + inner_idx);
      Vec output_vec = vec_fun(data_vec, data_vec2);
      output_vec.store(output_data_row + inner_idx);
    }
    if (broadcast_size - inner_idx > 0) {
      Vec data_vec = Vec::loadu(lhs_outer + inner_idx, broadcast_size - inner_idx);
      Vec output_vec = vec_fun(data_vec, data_vec2);
      output_vec.store(output_data_row + inner_idx, broadcast_size - inner_idx);
    }
  }
}
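
// Usage sketch for broadcasting_map_broadcast_last_dim (illustrative comment):
// subtracting a per-row scalar rhs, viewed as [rows, 1], from an lhs viewed as
// [rows, cols], assuming Vectorized<float> provides operator-.
//
//   executorch::vec::broadcasting_map_broadcast_last_dim<float>(
//       [](auto& x, auto& y) { return x - y; },
//       out_ptr, lhs_ptr, rhs_ptr,
//       /*outer_size=*/rows, /*broadcast_size=*/cols);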

} // namespace vec
} // namespace executorch