// aten/src/ATen/native/cpu/Reduce.h
#pragma once

#include <ATen/native/cpu/Loops.h>
#include <ATen/Parallel.h>
#include <c10/util/TypeList.h>
#include <c10/core/Scalar.h>
#include <c10/util/irange.h>

#include <sstream>
#include <type_traits>

namespace at { namespace native { inline namespace CPU_CAPABILITY {

using namespace vec;

#define VEC_LOOP_HEADER(func_t, data) \
  using scalar_t = typename function_traits<func_t>::result_type; \
  using Vec = Vectorized<scalar_t>; \
  char* out_ptr = data[0]; \
  (void) out_ptr;

// reduction that is contiguous over the input in dim 0
template <typename traits>
inline bool is_contiguous_reduction(const int64_t* strides) {
  return strides[0] == 0 &&
         strides[1] == sizeof(typename traits::arg2_t);
}

// reduction that is contiguous over the input in dim 1
template <typename traits>
inline bool is_outer_reduction(const int64_t* strides) {
  return strides[0] == 0 &&
         strides[2] == sizeof(typename traits::result_type) &&
         strides[3] == sizeof(typename traits::arg2_t);
}

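// Accumulates a block of 4 * Vec::size() lanes: row 0 is loaded from the input,
// and rows 1..n-1 (each `stride` bytes apart) are folded in with `vop`. With
// reduce == true the four partial vectors are merged with `vop`, horizontally
// reduced with `op`, and folded into the scalar already stored at data[0]; with
// reduce == false each partial vector is merged element-wise (via `vop`) into
// the corresponding vector already stored at data[0].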
template <typename func_t, typename vec_func_t>
inline void vectorized_reduction(char** data, int64_t n, int64_t stride,
                                 func_t op, vec_func_t vop, bool reduce) {
  VEC_LOOP_HEADER(func_t, data)
  const char* in1_ptr = data[1];
  Vec acc[4];
  for (const auto j : c10::irange(4)) {
    acc[j] = Vec::loadu(in1_ptr + j * Vec::size() * sizeof(scalar_t));
  }
  for (const auto i : c10::irange(1, n)) {
    const char* ptr = in1_ptr + stride * i;
    acc[0] = vop(acc[0], Vec::loadu(ptr + (0 * Vec::size() * sizeof(scalar_t))));
    acc[1] = vop(acc[1], Vec::loadu(ptr + (1 * Vec::size() * sizeof(scalar_t))));
    acc[2] = vop(acc[2], Vec::loadu(ptr + (2 * Vec::size() * sizeof(scalar_t))));
    acc[3] = vop(acc[3], Vec::loadu(ptr + (3 * Vec::size() * sizeof(scalar_t))));
  }
  if (reduce) {
    scalar_t buffer[Vec::size()];
    acc[0] = vop(vop(acc[0], acc[1]), vop(acc[2], acc[3]));
    acc[0].store(buffer);
    for (const auto j : c10::irange(1, Vec::size())) {
      buffer[0] = op(buffer[0], buffer[j]);
    }
    auto dst = (scalar_t*)out_ptr;
    *dst = op(*dst, buffer[0]);
  } else {
    for (const auto j : c10::irange(4)) {
      auto dst = out_ptr + j * Vec::size() * sizeof(scalar_t);
      acc[j] = vop(acc[j], Vec::loadu(dst));
      acc[j].store(dst);
    }
  }
}

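// Invokes f() `n` times, advancing the two data pointers by their respective
// strides after each call.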
template <typename F>
inline void UNARY_OUTER_LOOP(char* data[2], const int64_t strides[2], int64_t n, F f) {
  for (const auto j C10_UNUSED : c10::irange(n)) {
    f();
    data[0] += strides[0];
    data[1] += strides[1];
  }
}

// computes the reduction out = op(out, in)
template <typename func_t, typename vec_func_t>
inline void vectorized_inner_reduction(char** data, int64_t n, func_t op, vec_func_t vop) {
  VEC_LOOP_HEADER(func_t, data)
  int64_t vector_stride = 4 * Vec::size() * sizeof(scalar_t);
  int64_t count = n / (4 * Vec::size());
  if (count > 0) {
    vectorized_reduction(data, count, vector_stride, op, vop, /*reduce=*/true);
  }
  char* ptrs[3] = { data[0], data[0], data[1] };
  int64_t strides[] = { 0, 0, sizeof(scalar_t) };
  basic_loop(ptrs, strides, count * 4 * Vec::size(), n, op);
}

// computes the reduction out = op(out, in)
template <typename func_t, typename vec_func_t>
inline void vectorized_outer_reduction(char** data, int64_t inner_stride, int64_t size0, int64_t size1, func_t op, vec_func_t vop) {
  VEC_LOOP_HEADER(func_t, data)

  // reduce down each column of 4 * Vec::size() elements (128 or 256 bytes)
#if defined(CPU_CAPABILITY_AVX512)
  int64_t outer_stride[2] = { 256, 256 };
#else
  int64_t outer_stride[2] = { 128, 128 };
#endif
  UNARY_OUTER_LOOP(data, outer_stride, size1 / (4 * Vec::size()), [&] {
    vectorized_reduction(data, size0, inner_stride, op, vop, /*reduce=*/false);
  });

  // reduce down the remaining columns
  int64_t step[] = { sizeof(scalar_t), sizeof(scalar_t) };
  int64_t remaining = size1 % (4 * Vec::size());
  UNARY_OUTER_LOOP(data, step, remaining, [&] {
    char* ptrs[3] = { data[0], data[0], data[1] };
    int64_t strides[] = { 0, 0, inner_stride };
    basic_loop(ptrs, strides, 0, size0, op);
  });
}

template<typename traits, typename res_t>
static void set_result(const int index, const res_t result, const TensorIteratorBase &iter, const int num_outputs) {
  // static_assert(std::is_same<res_t, typename traits::arg2_t>::value, "data types must match");
  if (index < num_outputs) {
    char *out = (char *) iter.data_ptr(index);
    *(res_t *) out = result;
  }
}

template<typename traits, typename res_t>
static void set_results(const res_t result, const TensorIteratorBase &iter, const int num_outputs) {
  AT_ASSERT(num_outputs == 1);
  set_result<traits>(0, result, iter, num_outputs);
}

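// Compile-time recursion over the result tuple: element i is written to output i
// while i < num_outputs; the index at which the walk stops is returned so the
// caller can verify that every output was filled.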
template<typename traits, std::size_t i = 0, typename... tuple_t>
inline typename std::enable_if<i == sizeof...(tuple_t), std::size_t>::type
for_each_in_tuple(const std::tuple<tuple_t...>& /*t*/, const TensorIteratorBase& /*iter*/, const int /*num_outputs*/) {
  return i;
}

template<typename traits, std::size_t i = 0, typename... tuple_t>
inline typename std::enable_if<i < sizeof...(tuple_t), std::size_t>::type
for_each_in_tuple(const std::tuple<tuple_t...>& t, const TensorIteratorBase &iter, const int num_outputs) {
  if (i < (size_t)num_outputs) {
    set_result<traits>(i, std::get<i>(t), iter, num_outputs);
    return for_each_in_tuple<traits, i + 1, tuple_t...>(t, iter, num_outputs);
  }
  return i;
}

template<typename traits, typename... res_t>
static void set_results(const std::tuple<res_t...>& result, const TensorIteratorBase &iter, const int num_outputs) {
  AT_ASSERT(num_outputs >= 1);
  std::size_t result_size = for_each_in_tuple<traits>(result, iter, num_outputs);
  AT_ASSERT((size_t)num_outputs == result_size);
}

template <typename T, typename... Args>
struct all_same : std::conjunction<
  std::is_same<T, Args>...
> {};

// data_t is the input/output data type.
// acc_t is a type that contains all the necessary data
// to continue reducing.
// index_t is a one-dimensional index.
//
// ops_t is such that &ops_t::reduce, &ops_t::combine, and &ops_t::project exist and satisfy
// the following.
// reduce: (acc_t, data_t, index_t) -> acc_t adds one data point to the accumulated value.
// combine: (acc_t, acc_t) -> acc_t combines two accumulated values into one.
// project: acc_t -> out_t finishes the reduction, getting the required output.
//
// Additionally, acc_t must be default-constructible:
// acc_t {} is an identity for combine,
// and project(acc_t {}) is the value of the operation on zero elements.
//
// The point of `combine` is to support parallelization -
// the idea is to run one sequence of `reduce` calls per thread of execution,
// and then to combine them at the end with `combine`.
//
// If there is more than one output element,
// our parallelization strategy is to use one thread for each of them,
// which means that `combine` will never be called.
//
// If, on the other hand, there is only one, then we split the input
// into several pieces, reduce each separately, and then combine them.

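// For illustration only (a hypothetical sketch, not part of this header): a
// minimal ops_t that sums floats, assuming acc_t == data_t == float. It exists
// purely to make the reduce/combine/project contract above concrete. Note that
// binary_kernel_reduce below also calls ops.translate_idx(acc, base_idx), which
// for an index-agnostic reduction can simply return acc unchanged.
//
//   struct SumOps {
//     float reduce(float acc, float data, int64_t /*idx*/) const { return acc + data; }
//     float combine(float a, float b) const { return a + b; }
//     float project(float acc) const { return acc; }
//     float translate_idx(float acc, int64_t /*base_idx*/) const { return acc; }
//   };
//
//   // e.g. binary_kernel_reduce(iter, SumOps{}, /*init=*/0.f);
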
template <typename ops_t, typename init_t>
void binary_kernel_reduce(TensorIteratorBase& iter, ops_t ops, init_t init) {
  using rf_t = decltype(&ops_t::reduce);
  using cf_t = decltype(&ops_t::combine);
  using pf_t = decltype(&ops_t::project);
  using r_traits = binary_function_traits<rf_t>;
  using c_traits = binary_function_traits<cf_t>;
  using p_traits = unary_function_traits<pf_t>;
  using acc_t = typename p_traits::arg1_t;
  using data_t = typename r_traits::arg2_t;
  static_assert(
    all_same<
      acc_t,
      init_t,
      typename r_traits::arg1_t,
      typename r_traits::result_type,
      typename c_traits::arg1_t,
      typename c_traits::arg2_t,
      typename c_traits::result_type>::value,
    "all accumulate types must match");
  static_assert(
    std::is_default_constructible<acc_t>::value,
    "the accumulate type must be default-constructible"
  );
  const int num_outputs = iter.noutputs();
  iter.foreach_reduced_elt([&ops, &init, num_outputs](TensorIteratorBase &sub_iter) {
    auto reduction_body = [&ops, &sub_iter, num_outputs](acc_t acc, int64_t begin, int64_t end) -> acc_t {
      int ntensors = sub_iter.ntensors();
      sub_iter.serial_for_each([&acc, &ops, num_outputs, ntensors, begin](char** data, const int64_t* strides, int64_t size) {
        AT_ASSERT(ntensors - num_outputs == 1);
        char *in = data[ntensors - 1];
        int64_t stride = strides[ntensors - 1];
        for (const auto i : c10::irange(size)) {
          acc = ops.reduce(acc, c10::load<data_t>(in), begin + i);
          in += stride;
        }
      }, {begin, end});
      return ops.translate_idx(acc, sub_iter.view_offsets()[0]);
    };
    acc_t total_acc = init;
    auto numel = sub_iter.numel();
    if (numel < at::internal::GRAIN_SIZE || at::get_num_threads() == 1 ||
        at::in_parallel_region()) {
      total_acc = reduction_body(total_acc, 0, numel);
    } else {
      int max_threads = at::get_num_threads();
      AT_ASSERT(max_threads > 0);
      static_assert(
        !std::is_same<acc_t, bool>::value,
        "Concurrently modifying different references into std::vector<bool> is UB."
      );
      std::vector<acc_t> buffer((unsigned)max_threads, init);
      at::parallel_for(0, numel, internal::GRAIN_SIZE,
        [&](int64_t begin, int64_t end) {
          auto& acc = buffer[at::get_thread_num()];
          acc = reduction_body(acc, begin, end);
        }
      );
      for (const auto i : c10::irange(max_threads)) {
        total_acc = ops.combine(total_acc, buffer[i]);
      }
    }
    set_results<r_traits>(ops.project(total_acc), sub_iter, num_outputs);
  });
}

template <typename func_t, typename vec_func_t>
void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, double ident = 0) {
  using traits = binary_function_traits<func_t>;
  static_assert(
    all_same<
      typename traits::result_type,
      typename traits::arg1_t,
      typename traits::arg2_t>::value,
    "all types must match");

  iter.output_base().fill_(ident);
  iter.parallel_reduce([&](char** data, const int64_t* strides, int64_t size0, int64_t size1) {
    int64_t outer_strides[] = { strides[2], strides[3] };
    if (is_contiguous_reduction<traits>(strides)) {
      // input is contiguous in dim 0, output is reduced in dim 0
      UNARY_OUTER_LOOP(data, outer_strides, size1, [&] {
        vectorized_inner_reduction(data, size0, op, vop);
      });
    } else if (is_outer_reduction<traits>(strides)) {
      // input and output are contiguous in dim 1
      int64_t inner_stride = strides[1]; // stride of input in dim 0
      vectorized_outer_reduction(data, inner_stride, size0, size1, op, vop);
    } else {
      UNARY_OUTER_LOOP(data, outer_strides, size1, [&] {
        char* ptrs[3] = { data[0], data[0], data[1] };
        int64_t inner_strides[3] = { strides[0], strides[0], strides[1] };
        basic_loop(ptrs, inner_strides, 0, size0, op);
      });
    }
  });
}
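
// For illustration only (an assumed caller, not part of this header): a sum
// reduction would pass a scalar op together with its Vectorized counterpart,
// e.g. for float:
//
//   binary_kernel_reduce_vec(
//       iter,
//       [](float a, float b) -> float { return a + b; },
//       [](Vectorized<float> a, Vectorized<float> b) { return a + b; });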

// When the reduction is over the innermost dimension (dim 0 in TensorIterator)
// and the input is contiguous in that dimension, `binary_kernel_reduce_lastdim`
// can be used.
inline bool is_reduce_lastdim(TensorIteratorBase& iter) {
  return iter.num_reduce_dims() == 1 && iter.is_dim_reduced(0)
      && iter.ninputs() == 1 && iter.strides(1)[0] == iter.element_size(1);
}

template <typename reduce_func_t>
void binary_kernel_reduce_lastdim(TensorIteratorBase& iter, reduce_func_t reduce_op) {
  auto shape = iter.shape();
  int64_t dim_size = shape[0];
  int64_t grain_size = std::max((int64_t) 1, at::internal::GRAIN_SIZE / dim_size);
  TensorIterator sub_iter(iter);
  // create a sub iterator to parallelize over all non-reduce dims
  sub_iter.narrow(0, 0, 1);
  auto loop = [&](char** data, const int64_t* strides, int64_t size) {
    char* out = data[0];
    char* in = data[1];
    for (int64_t i = 0; i < size; ++i) {
      reduce_op(out, in, dim_size);
      out += strides[0];
      in += strides[1];
    }
  };
  sub_iter.for_each(loop, grain_size);
}
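
// For illustration only (an assumed caller, not part of this header): a
// hypothetical reduce_op that sums one contiguous row of floats into the
// corresponding output element.
//
//   auto sum_lastdim = [](char* out, char* in, int64_t dim_size) {
//     float acc = 0.f;
//     for (int64_t k = 0; k < dim_size; ++k) {
//       acc += reinterpret_cast<const float*>(in)[k];
//     }
//     *reinterpret_cast<float*>(out) = acc;
//   };
//   // if (is_reduce_lastdim(iter)) binary_kernel_reduce_lastdim(iter, sum_lastdim);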

}}}  // namespace at::native::CPU_CAPABILITY