#pragma once

// This file provides two functions to help write elementwise kernels:
//
//   cpu_kernel(TensorIterator iter, <lambda>)
//   cpu_kernel_vec(TensorIterator iter, <lambda>, <vec_lambda>)
//
// Both functions may generate vectorized code. The cpu_kernel implementation
// relies on the compiler's auto-vectorization. The cpu_kernel_vec
// implementation uses x86 SIMD intrinsics when available. These functions
// are only intended to be used in the ATen/native/cpu subdirectory, since files
// in other directories are not compiled with AVX/AVX2 enabled. See README.md
// for more details.
//
// For example, to write a multiplication kernel for float:
//
//   cpu_kernel(iter, [](float a, float b) { return a * b; });
//
// Or you may write:
//
//   cpu_kernel_vec(iter,
//     [](float a, float b) { return a * b; },
//     [](Vectorized<float> a, Vectorized<float> b) { return a * b; });
//
// See BinaryOpsKernel.cpp for the complete implementation.
//

#include <cstdint>
#include <c10/util/C++17.h>
#include <c10/util/Load.h>
#include <c10/util/irange.h>
#include <ATen/detail/FunctionTraits.h>
#include <ATen/native/cpu/IsContiguous.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/TensorIteratorDynamicCasting.h>
#include <ATen/cpu/vec/vec.h>

#include <utility>

namespace at::native { inline namespace CPU_CAPABILITY {

using namespace vec;

template <typename traits, std::size_t... INDEX>
typename traits::ArgsTuple
dereference_impl(char* C10_RESTRICT data[], const int64_t* strides, int64_t i,
                 std::index_sequence<INDEX...>) {
  return std::make_tuple(
      c10::load<typename traits::template arg<INDEX>::type>(
          data[INDEX] + i * strides[INDEX])...);
}

template <typename traits>
typename traits::ArgsTuple
dereference(char* C10_RESTRICT data[], const int64_t* strides, int64_t i) {
  using Indices = std::make_index_sequence<traits::arity>;
  return dereference_impl<traits>(data, strides, i, Indices{});
}
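
// For example (illustrative sketch; `op`, `data`, `strides`, and `i` follow
// the conventions of the callers below, with strides given in bytes):
//
//   auto op = [](float a, float b) { return a * b; };
//   using traits = function_traits<decltype(op)>;
//   // args is std::tuple<float, float>, holding the i-th element of each input:
//   //   { c10::load<float>(data[0] + i * strides[0]),
//   //     c10::load<float>(data[1] + i * strides[1]) }
//   auto args = dereference<traits>(data, strides, i);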

template <typename traits, std::size_t... INDEX>
typename traits::ArgsTuple
dereference_vec_impl(char* C10_RESTRICT data[],
                     const typename traits::result_type& opt_scalar,
                     size_t S,
                     int64_t i,
                     std::index_sequence<INDEX...>) {
  using Vec = typename traits::result_type;
  using scalar_t = typename Vec::value_type;
  return std::make_tuple(
      S == INDEX + 1 ?
      opt_scalar :
      Vec::loadu(data[INDEX] + i * sizeof(scalar_t))...);
}

template <typename traits>
typename traits::ArgsTuple
dereference_vec(char* C10_RESTRICT data[], const typename traits::result_type& opt_scalar, size_t S, int64_t i) {
  using Indices = std::make_index_sequence<traits::arity>;
  return dereference_vec_impl<traits>(data, opt_scalar, S, i, Indices{});
}
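
// Note: `S` is the 1-based index of the scalar argument, or 0 if there is
// none. The scalar argument is not loaded from memory; it is replaced by the
// pre-broadcast vector `opt_scalar` supplied by the caller.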

template <typename func_t,
    std::enable_if_t<!std::is_void_v<typename function_traits<func_t>::result_type>>* = nullptr>
inline void
execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t&& op) {
  using traits = function_traits<func_t>;
  using result_type = typename traits::result_type;
  for (; i < n; i++) {
    result_type* out_ptr = (result_type*)(data[0] + i * strides[0]);
    *out_ptr = c10::guts::apply(op, dereference<traits>(
        &data[1],
        &strides[1],
        i));
  }
}

template <typename func_t,
    std::enable_if_t<std::is_void_v<typename function_traits<func_t>::result_type>>* = nullptr>
inline void
execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t&& op) {
  using traits = function_traits<func_t>;
  for (; i < n; i++) {
    c10::guts::apply(op, dereference<traits>(
        &data[0],
        &strides[0],
        i));
  }
}

// Basic loop operation (one output, N inputs). May be auto-vectorized
// by the compiler. Supports inputs and outputs of different types.
template <typename func_t>
inline void
basic_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_t n, func_t&& op) {
  using traits = function_traits<func_t>;
  constexpr int ntensors = traits::arity + 1;

  // Copying strides to temporary array helps auto vectorization in older GCC
  // versions.
  int64_t strides[ntensors];
  for (const auto arg : c10::irange(ntensors)) {
    strides[arg] = strides_[arg];
  }

  execute_op(data, strides, i, n, std::forward<func_t>(op));
}
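
// For example (illustrative sketch; assumes the layout produced by the
// TensorIterator-based entry points below, where data[0]/strides[0] describe
// the output and the remaining slots describe the inputs, strides in bytes):
//
//   auto op = [](float a, float b) { return a + b; };
//   // out[j] = in0[j] + in1[j] for j in [0, n)
//   basic_loop(data, strides, /*i=*/0, /*n=*/n, op);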

// The recursive variadic template for iterating over the returned tuple
template<class T, size_t N>
struct TupleOutput {
  static void handle(char *C10_RESTRICT data[], const int64_t *strides, int64_t i,
                     const T &tuple) {
    TupleOutput<T, N - 1>::handle(data, strides, i, tuple);

    auto output = std::get<N - 1>(tuple);
    using output_type = decltype(output);
    output_type * out_ptr = (output_type *)(data[N - 1] + i * strides[N - 1]);
    *out_ptr = output;
  }
};

// Base case for the above recursive template
template<class T>
struct TupleOutput<T, 1> {
  static void handle(char *C10_RESTRICT data[], const int64_t *strides, int64_t i,
                     const T &tuple) {
    auto output = std::get<0>(tuple);
    using output_type = decltype(output);
    output_type* out_ptr = (output_type *)(data[0] + i * strides[0]);
    *out_ptr = output;
  }
};

template<class... Args>
void handle_tuple_outputs(char* C10_RESTRICT data[],
                          const int64_t* strides,
                          int64_t i,
                          const std::tuple<Args...> &tuple) {
  TupleOutput<decltype(tuple), sizeof...(Args)>::handle(data, strides, i, tuple);
}
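
// For example (sketch): with a returned std::tuple<float, int64_t>, this writes
// std::get<0>(tuple) to data[0] + i * strides[0] and std::get<1>(tuple) to
// data[1] + i * strides[1].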

// Loop operation for `cpu_kernel_multiple_outputs`.
// 1. Use `c10::guts::apply` to invoke the lambda passed to
//    `cpu_kernel_multiple_outputs` with the dereferenced inputs.
// 2. Iterate over the members of the returned tuple and write each member to
//    the corresponding output via `handle_tuple_outputs`.
template <typename func_t>
inline void
multiple_outputs_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_t n, func_t&& op) {
  using traits = function_traits<func_t>;

  using result_type = typename traits::result_type;
  constexpr int num_outputs = std::tuple_size<result_type>::value;
  constexpr int ntensors = traits::arity + num_outputs;

  // Copying strides to temporary array helps auto vectorization in older GCC
  // versions.
  int64_t strides[ntensors];
  for (const auto arg : c10::irange(ntensors)) {
    strides[arg] = strides_[arg];
  }

  for (; i < n; i++) {
    auto output = c10::guts::apply(op, dereference<traits>(
      &data[num_outputs],
      &strides[num_outputs],
      i));
    handle_tuple_outputs(data, strides, i, output);
  }
}

// Explicitly vectorized loop implementation. All inputs and outputs must be
// the same type and contiguous with one exception: a single input may be
// a scalar (stride 0). Its position is indicated by the argument `S`. If `S`
// is 0, then there are no scalar inputs.
template <typename func_t, typename vec_func_t>
inline void
vectorized_loop(char** C10_RESTRICT data_, int64_t n, int64_t S, func_t&& op, vec_func_t&& vop) {
  using traits = function_traits<vec_func_t>;
  using scalar_t = typename function_traits<func_t>::result_type;
  using Vec = Vectorized<scalar_t>;
  constexpr int ntensors = traits::arity + 1;

  char* C10_RESTRICT data[ntensors];
  for (const auto arg : c10::irange(ntensors)) {
    data[arg] = data_[arg];
  }

  Vec opt_scalar = Vec(S > 0 ? *(scalar_t*)data[S] : scalar_t(0));
  int64_t i = 0;
  for (; i <= n - 2 * Vec::size(); i += 2 * Vec::size()) {
    auto args1 = dereference_vec<traits>(&data[1], opt_scalar, S, i);
    auto args2 = dereference_vec<traits>(&data[1], opt_scalar, S, i + Vec::size());
    auto out1 = c10::guts::apply(vop, std::move(args1));
    auto out2 = c10::guts::apply(vop, std::move(args2));
    out1.store(data[0] + i * sizeof(scalar_t));
    out2.store(data[0] + (i + Vec::size()) * sizeof(scalar_t));
  }
  if (i < n) {
    int64_t strides[ntensors];
    for (const auto arg : c10::irange(ntensors)) {
      strides[arg] = (S > 0 && arg == S) ? 0 : sizeof(scalar_t);
    }
    basic_loop(data, strides, i, n, std::forward<func_t>(op));
  }
}
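
// The main loop above processes two vectors (2 * Vec::size() elements) per
// iteration; the remaining tail of fewer than 2 * Vec::size() elements falls
// back to the scalar `basic_loop` with strides built from sizeof(scalar_t)
// (and a zero stride for the scalar input, if any). For example, with AVX2 and
// scalar_t == float, Vec::size() is typically 8, so the vector loop handles 16
// floats per iteration.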

template <typename traits, typename cb_t>
inline void unroll_contiguous_scalar_checks(
    const int64_t* /*strides*/,
    std::index_sequence<>,
    cb_t&& cb) {
  cb(0);
}

template <typename traits, typename cb_t, size_t INDEX0, size_t ...INDEX>
inline void unroll_contiguous_scalar_checks(
    const int64_t* strides,
    std::index_sequence<INDEX0, INDEX...>,
    cb_t&& cb) {
  if (is_contiguous_scalar<traits, INDEX0 + 1>(strides)) {
    cb(INDEX0 + 1);
  } else {
    unroll_contiguous_scalar_checks<traits>(strides, std::index_sequence<INDEX...>{}, std::forward<cb_t>(cb));
  }
}

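// VectorizedLoop2d is the 2-d loop object handed to TensorIterator::for_each
// by cpu_kernel_vec below. For each row of the outer dimension it either runs
// the explicitly vectorized inner loop (when the inner dimension is fully
// contiguous, or contiguous except for one scalar argument found by
// unroll_contiguous_scalar_checks) or falls back to basic_loop, then advances
// the data pointers by the outer strides.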
template <typename op_t, typename vop_t>
struct VectorizedLoop2d {
  op_t op;
  vop_t vop;

  using traits = function_traits<op_t>;
  static constexpr int ntensors = traits::arity + 1;
  using data_t = std::array<char*, ntensors>;

  VectorizedLoop2d(op_t op, vop_t vop):
    op(std::move(op)), vop(std::move(vop)) {}

  static void advance(data_t &data, const int64_t *outer_strides) {
    for (const auto arg : c10::irange(data.size())) {
      data[arg] += outer_strides[arg];
    }
  }

  void operator()(char** base, const int64_t *strides, int64_t size0, int64_t size1) {
    data_t data;
    std::copy_n(base, ntensors, data.data());
    const int64_t *outer_strides = &strides[ntensors];

    if (is_contiguous<traits>(strides)) {
      for (const auto i C10_UNUSED : c10::irange(size1)) {
        vectorized_loop(data.data(), size0, 0, op, vop);
        advance(data, outer_strides);
      }
    } else {
      using Indices = std::make_index_sequence<traits::arity>;
      unroll_contiguous_scalar_checks<traits>(strides, Indices{}, [&](size_t idx) {
        if (idx) {
          for (const auto i C10_UNUSED : c10::irange(size1)) {
            vectorized_loop(data.data(), size0, idx, op, vop);
            advance(data, outer_strides);
          }
        } else {
          for (const auto i C10_UNUSED : c10::irange(size1)) {
            basic_loop(data.data(), strides, 0, size0, op);
            advance(data, outer_strides);
          }
        }
      });
    }
  }
};

template <typename op_t, typename vop_t>
VectorizedLoop2d<op_t, vop_t> make_vectorized_loop2d(
    op_t &&op, vop_t &&vop) {
  return VectorizedLoop2d<op_t, vop_t>(std::forward<op_t>(op), std::forward<vop_t>(vop));
}

template <typename func_t>
void cpu_kernel(TensorIteratorBase& iter, func_t&& op, int64_t grain_size = at::internal::GRAIN_SIZE) {
  using traits = function_traits<func_t>;
  // this could be extended to work with void return types
  TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
  // dynamic casting not currently supported on CPU
  TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));

  iter.for_each([&](char** data, const int64_t* strides, int64_t n) {
    // basic_loop can handle 1d slices with arbitrary strides, and 1d slices are
    // all that iter.for_each ever sends to the loop lambda
    basic_loop(data, strides, 0, n, op);
  }, grain_size);
  iter.cast_outputs();
}

// This function helps write elementwise kernels that require multiple outputs.
// It follows a structure similar to cpu_kernel's.
// Instead of the `basic_loop` function, a `multiple_outputs_loop` function is
// used to handle multiple return values.
// For now, the `needs_dynamic_casting` check is not added, as the lambda
// (`func_t`) passed to `multiple_outputs_loop` returns `std::tuple` instead of
// `scalar_t`. `gpu_kernel_multiple_outputs` is also implemented without this
// check. We could extend `needs_dynamic_casting` to support both `std::tuple`
// and `thrust::tuple` in the future.
template <typename func_t>
void cpu_kernel_multiple_outputs(TensorIteratorBase& iter, func_t&& op, int64_t grain_size = at::internal::GRAIN_SIZE) {
  using traits = function_traits<func_t>;
  TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);

  iter.for_each([&](char** data, const int64_t* strides, int64_t n) {
    multiple_outputs_loop(data, strides, 0, n, op);
  }, grain_size);
  iter.cast_outputs();
}
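
// Example usage (illustrative sketch; `iter` is assumed to be a
// TensorIteratorBase configured with two float outputs and one float input):
//
//   cpu_kernel_multiple_outputs(iter, [](float x) {
//     return std::make_tuple(x + 1.0f, x - 1.0f);
//   });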

template <bool check_dynamic_cast=true, typename func_t, typename vec_func_t>
void cpu_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop, int64_t grain_size = at::internal::GRAIN_SIZE) {
  using traits = function_traits<func_t>;
  // this could be extended to work with void return types
  TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
  // dynamic casting not currently supported on CPU, but some kernels (like Fill)
  // explicitly dynamic_cast, so we provide an opt-out of the check.
  if constexpr (check_dynamic_cast) {
    TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));
  }

  iter.for_each(make_vectorized_loop2d(std::forward<func_t>(op), std::forward<vec_func_t>(vop)), grain_size);
  iter.cast_outputs();
}

template <typename func_t>
void cpu_serial_kernel(TensorIteratorBase& iter, func_t&& op, const Range& range) {
  using traits = function_traits<func_t>;
  constexpr bool result_void = std::is_void_v<typename traits::result_type>;
  TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity &&
                        ((result_void && iter.noutputs() == 0) || (!result_void && iter.noutputs() == 1)));
  // dynamic casting not currently supported on CPU
  TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));

  iter.serial_for_each([&](char** data, const int64_t* strides, int64_t n) {
    basic_loop(data, strides, 0, n, op);
  }, range);
  iter.cast_outputs();
}

template <typename func_t>
void cpu_serial_kernel(TensorIteratorBase& iter, func_t&& op) {
  cpu_serial_kernel(iter, std::forward<func_t>(op), {0, iter.numel()});
}
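
// Example (illustrative sketch; this overload runs on the calling thread over
// the whole iterator):
//
//   cpu_serial_kernel(iter, [](float a) { return a * 2.0f; });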

template <typename func_t, typename vec_func_t>
void cpu_serial_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop, const Range& range) {
  using traits = function_traits<func_t>;
  // this could be extended to work with void return types
  TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
  // dynamic casting not currently supported on CPU
  TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));

  iter.serial_for_each(make_vectorized_loop2d(std::forward<func_t>(op), std::forward<vec_func_t>(vop)), range);
  iter.cast_outputs();
}

template <typename func_t, typename vec_func_t>
void cpu_serial_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop) {
  cpu_serial_kernel_vec(iter, std::forward<func_t>(op), std::forward<vec_func_t>(vop), {0, iter.numel()});
}

}} // namespace at::native::CPU_CAPABILITY