#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <algorithm>

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/OpMathType.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/native/ReduceOps.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/SharedReduceOps.h>
#include <ATen/native/ReduceOpsUtils.h>
#include <ATen/native/cpu/Reduce.h>
#include <ATen/native/cpu/LogAddExp.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/imag.h>
#endif

#include <c10/util/irange.h>
#include <ATen/AccumulateType.h>

namespace at::native { namespace {

using namespace vec;

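// Shared driver for the cumulative (scan) kernels below. It resizes `result`
// to match `self`, builds a TensorIterator with the scan dimension squashed
// out of the iteration shape, and calls the functor `f` once per 1-D slice
// along `dim`. `f` gets raw data pointers plus the strides of `result` and
// `self` along `dim`, and walks the whole slice sequentially starting from
// `init_val`. The grain size is divided by the slice length so the work
// estimate accounts for the per-slice loop inside `f`.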
template <typename scalar_t, typename func_t>
static inline void cpu_cum_base_kernel(const Tensor& result,
    const Tensor& self,
    int64_t dim,
    const func_t& f,
    scalar_t init_val) {
  if (result.sizes() != self.sizes()) {
    at::native::resize_output(result, self.sizes());
  }
  if (self.numel() == 0) {
    return;
  }
  const auto input_ndim = self.dim();
  if (input_ndim == 0) {
    result.fill_(self);
    return;
  }

  // TODO This probably should be using at::native::make_reduction
  auto iter = TensorIteratorConfig()
    .check_all_same_dtype(false)
    .resize_outputs(false)
    // NOLINTNEXTLINE(bugprone-argument-comment)
    .declare_static_shape(self.sizes(), /*squash_dim=*/dim)
    .add_output(result)
    .add_const_input(self)
    .build();

  auto result_dim_stride = ensure_nonempty_stride(result, dim);
  auto self_dim_stride = ensure_nonempty_stride(self, dim);

  auto loop = [&](char** data, const int64_t* strides, int64_t n) {
    auto* result_data_bytes = data[0];
    const auto* self_data_bytes = data[1];

    for (const auto i C10_UNUSED : c10::irange(n)) {
      f(
        (scalar_t*)result_data_bytes, result_dim_stride,
        (scalar_t*)self_data_bytes, self_dim_stride, init_val
      );
      result_data_bytes += strides[0];
      self_data_bytes += strides[1];
    }
  };

  int64_t grain_size = internal::GRAIN_SIZE / std::max(int64_t{1}, self.size(dim));
  iter.for_each(loop, grain_size);
}

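// cumsum: running sum along `dim`. Partial sums are accumulated in
// at::acc_type<scalar_t, false> (float for reduced-precision floating types)
// and cast back to scalar_t per element. Example: cumsum of [1, 2, 3] along
// dim 0 is [1, 3, 6].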
static void cumsum_cpu_kernel(const Tensor& result, const Tensor& self, int64_t dim) {
  auto wrap_dim = maybe_wrap_dim(dim, self.dim());
  int64_t self_dim_size = ensure_nonempty_size(self, wrap_dim);

  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, self.scalar_type(), "cumsum_out_cpu", [&] {
    cpu_cum_base_kernel<scalar_t>(result, self, wrap_dim, [&] (
      scalar_t* result_data, auto result_dim_stride,
      const scalar_t* self_data, auto self_dim_stride, scalar_t init_val) {
        // NOLINTNEXTLINE(bugprone-signed-char-misuse)
        auto cum_number = (at::acc_type<scalar_t, false>)init_val;
        for (const auto i : c10::irange(self_dim_size)) {
          cum_number += self_data[i * self_dim_stride];
          result_data[i * result_dim_stride] = (scalar_t)cum_number;
        }
      }, /*init_val=*/ 0
    );
  });
}

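// cumprod: running product along `dim`, structured exactly like cumsum above
// but with multiplication and identity 1. Example: cumprod of [1, 2, 3] is
// [1, 2, 6].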
static void cumprod_cpu_kernel(const Tensor& result, const Tensor& self, int64_t dim) {
  auto wrap_dim = maybe_wrap_dim(dim, self.dim());
  int64_t self_dim_size = ensure_nonempty_size(self, wrap_dim);

  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, self.scalar_type(), "cumprod_out_cpu", [&] {
    cpu_cum_base_kernel<scalar_t>(result, self, wrap_dim, [&] (
      scalar_t* result_data, auto result_dim_stride,
      const scalar_t* self_data, auto self_dim_stride, scalar_t init_val) {
        // NOLINTNEXTLINE(bugprone-signed-char-misuse)
        auto cum_number = (at::acc_type<scalar_t, false>)init_val;
        for (const auto i : c10::irange(self_dim_size)) {
          cum_number *= self_data[i * self_dim_stride];
          result_data[i * result_dim_stride] = (scalar_t)cum_number;
        }
      }, /*init_val=*/ 1
    );
  });
}

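// logcumsumexp: running log(sum(exp(x))) along `dim`. Exponentiating directly
// would overflow, so each step uses a numerically stable log-add-exp update
// (_log_add_exp_helper from LogAddExp.h); for real inputs this amounts to
//   logaddexp(a, b) = max(a, b) + log1p(exp(-|a - b|)).
// The running value starts at -inf, the identity element of logaddexp.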
static void logcumsumexp_cpu_kernel(Tensor& result, const Tensor& self, int64_t dim) {
  auto wrap_dim = maybe_wrap_dim(dim, self.dim());
  int64_t self_dim_size = ensure_nonempty_size(self, wrap_dim);

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, self.scalar_type(), "logcumsumexp_out_cpu", [&] {
    cpu_cum_base_kernel<scalar_t>(result, self, wrap_dim, [&] (
      scalar_t* result_data, auto result_dim_stride,
      const scalar_t* self_data, auto self_dim_stride, scalar_t init_val) {
        using accscalar_t = at::acc_type<scalar_t, false>;
        auto cum_number = (accscalar_t)init_val;
        for (const auto i : c10::irange(self_dim_size)) {
          accscalar_t x = self_data[i * self_dim_stride];

          cum_number = _log_add_exp_helper(x, cum_number);
          result_data[i * result_dim_stride] = static_cast<scalar_t>(cum_number);
        }
      }, /*init_val=*/ -std::numeric_limits<scalar_t>::infinity()
    );
  });
}

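// std/var: single-pass reduction using Welford's online algorithm
// (WelfordOps in SharedReduceOps.h), which updates a running mean and a sum
// of squared deviations per element and is numerically stabler than the naive
// E[x^2] - E[x]^2 formula. `correction` adjusts the divisor (correction = 1
// gives Bessel's correction) and `take_sqrt` selects std vs. var.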
static void std_var_kernel_impl(TensorIterator& iter, double correction, bool take_sqrt) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "std_cpu", [&] {
    binary_kernel_reduce(
        iter,
        WelfordOps<
            scalar_t,
            double,
            int64_t,
            std::tuple<scalar_t, scalar_t>>{correction, take_sqrt},
        WelfordData<double, int64_t>());
  });
}

static void prod_kernel_impl(TensorIterator& iter) {
  // Workaround for the error: '*' in boolean context, suggest '&&' instead
  // [-Werror=int-in-bool-context]
  if (iter.dtype() == ScalarType::Bool) {
    using scalar_t = bool;
    binary_kernel_reduce_vec(
        iter,
        [=](scalar_t a, scalar_t b)
            __ubsan_ignore_undefined__ -> scalar_t { return a && b; },
        [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b)
            __ubsan_ignore_undefined__ { return a && b; },
        // NOLINTNEXTLINE(bugprone-argument-comment)
        /*identity=*/1);
  } else {
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.dtype(), "prod_out_cpu", [&] {
      binary_kernel_reduce_vec(
          iter,
          [=](scalar_t a, scalar_t b)
              __ubsan_ignore_undefined__ -> scalar_t { return a * b; },
          [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b)
              __ubsan_ignore_undefined__ { return a * b; },
          // NOLINTNEXTLINE(bugprone-argument-comment)
          /*identity=*/1);
    });
  }
}

template <typename scalar_t, typename acc_t>
inline void norm_two_reduce_step(Vectorized<acc_t>& acc_vec, Vectorized<scalar_t>& data_vec) {
  acc_vec += data_vec * data_vec;
}

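// BFloat16 specialization: widen each BFloat16 vector into two float vectors
// before squaring, so the squares and the accumulation happen in float.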
template <>
inline void norm_two_reduce_step(Vectorized<float>& acc_fvec, Vectorized<BFloat16>& data_bvec) {
  auto [data_fvec0, data_fvec1] = convert_bfloat16_float(data_bvec);
  acc_fvec += data_fvec0 * data_fvec0;
  acc_fvec += data_fvec1 * data_fvec1;
}

// This reduction accumulates results as the type `acc_t`. By default, when
// `scalar_t` is complex, `acc_t` is the downgraded real number type.
// Otherwise, `acc_t` and `scalar_t` are the same type.
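// The norm order `val` selects a specialized reduction op:
//   val == 0    -> NormZeroOps  (count of non-zero elements)
//   val == 1    -> NormOneOps   (sum of absolute values)
//   val == 2    -> NormTwoOps   (sqrt of sum of squares)
//   val == inf  -> AbsMaxOps    (max absolute value)
//   val == -inf -> AbsMinOps    (min absolute value)
//   otherwise   -> NormOps      (general (sum |x|^p)^(1/p))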
template <typename scalar_t, typename acc_t=typename scalar_value_type<scalar_t>::type, typename out_t=typename scalar_value_type<scalar_t>::type>
void norm_kernel_cpu_impl(TensorIterator& iter, const double& val) {
  if (val == 0.0) {
    binary_kernel_reduce(iter, NormZeroOps<scalar_t, acc_t, out_t>(), acc_t(0));
  } else if (val == 1.0) {
    binary_kernel_reduce(iter, NormOneOps<scalar_t, acc_t, out_t>(), acc_t(0));
  } else if (val == 2.0) {
    binary_kernel_reduce(iter, NormTwoOps<scalar_t, acc_t, out_t>(), acc_t(0));
  } else if (val == INFINITY) {
    binary_kernel_reduce(iter, AbsMaxOps<scalar_t, acc_t, out_t>(), acc_t(0));
  } else if (val == -INFINITY) {
    binary_kernel_reduce(iter, AbsMinOps<scalar_t, acc_t, out_t>(), std::numeric_limits<acc_t>::infinity());
  } else {
    binary_kernel_reduce(iter, NormOps<scalar_t, acc_t, out_t>{acc_t(val)}, acc_t(0));
  }
}

static void norm_kernel_tensor_iterator_impl(
    TensorIterator& iter,
    const Scalar& p) {
  double val = 0;
  if (p.isIntegral(false)) {
    val = p.to<int64_t>();
  } else if (p.isFloatingPoint()) {
    val = p.to<double>();
  } else {
    TORCH_CHECK(false, "norm_kernel_cpu expects norm to be integer or float");
  }
  if (iter.numel() == 0) {
    iter.output().fill_((val < 0) ? INFINITY : 0);
    return;
  }

  if (val == 2.0 && is_reduce_lastdim(iter) &&
      iter.dtype(0) == iter.input_dtype() &&
      (iter.input_dtype() == kFloat || iter.input_dtype() == kDouble ||
       iter.input_dtype() == kBFloat16)) {
    // If we can vectorize over the last dimension and the dtype
    // of the output is the same as that of the input,
    // then we go through the vectorised path.
    AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, iter.input_dtype(), "norm_cpu", [&] {
        // use float as accumulate type for BFloat16
        using acc_t = at::opmath_type<scalar_t>;
        binary_kernel_reduce_lastdim(iter, [](char* result_data_bytes, char* self_data_bytes, int64_t size) {
          scalar_t* result_data = (scalar_t*)result_data_bytes;
          scalar_t* self_data = (scalar_t*)self_data_bytes;

          using Vec = Vectorized<scalar_t>;
          using fVec = Vectorized<acc_t>;
          fVec acc_vec{acc_t(0)};
          acc_t buffer[fVec::size()];
          int64_t d = 0;
          for (; d < size - (size % Vec::size()); d += Vec::size()) {
            Vec data_vec = Vec::loadu(self_data + d);
            norm_two_reduce_step(acc_vec, data_vec);
          }
          acc_vec.store(buffer);
          for (int j = 1; j < fVec::size(); j++) {
            buffer[0] = buffer[0] + buffer[j];
          }
          for (; d < size; d++) {
            acc_t data_val = acc_t(self_data[d]);
            buffer[0] += data_val * data_val;
          }
          result_data[0] = scalar_t(std::sqrt(buffer[0]));
        });
      });
  } else {
    if (iter.dtype(0) == kHalf) {
      return norm_kernel_cpu_impl<at::Half, float>(iter, val);
    } else if (iter.input_dtype() == kHalf && iter.dtype(0) == kFloat) {
      // type promotion that does cast and reduction in a single kernel
      return norm_kernel_cpu_impl<at::Half, float, float>(iter, val);
    } else if (iter.dtype(0) == kBFloat16) {
      return norm_kernel_cpu_impl<at::BFloat16, float>(iter, val);
    } else if (iter.input_dtype() == kBFloat16 && iter.dtype(0) == kFloat) {
      // type promotion that does cast and reduction in a single kernel
      return norm_kernel_cpu_impl<at::BFloat16, float, float>(iter, val);
    }

    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.input_dtype(), "norm_cpu", [&] {
      norm_kernel_cpu_impl<scalar_t>(iter, val);
    });

    // For complex outputs, the above kernels do not touch the imaginary values,
    // so we must zero them out
    if (isComplexType(iter.output().scalar_type())) {
      at::imag(iter.output()).zero_();
    }
  }
}

static void and_kernel_impl(TensorIterator& iter) {
  if (iter.dtype() == ScalarType::Byte) {
    // Refer [all, any : uint8 compatibility]
    binary_kernel_reduce_vec(
        iter,
        [=](uint8_t a, uint8_t b) -> uint8_t { return (a && b) ? 1 : 0; },
        [=](Vectorized<uint8_t> a, Vectorized<uint8_t> b) {
          // NB: != returns 0xFF rather than 0x01, so we must negate to get
          // the desired result
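          // (the all-ones mask 0xFF is -1 in two's complement; negating it
          // gives 0x01, and 0x00 stays 0x00)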
          return (a != Vectorized<uint8_t>(0)).neg() & (b != Vectorized<uint8_t>(0)).neg();
        },
        /*ident=*/true);
  } else {
    binary_kernel_reduce_vec(
        iter,
        [=](bool a, bool b) -> bool { return a && b; },
        [=](Vectorized<bool> a, Vectorized<bool> b) {
          // Adding the implementation here instead of in vec256_base to avoid
          // return value inconsistency. Other comparison operators in
          // vec256_base return -1/0 (all bit 1 / all bit 0) as true/false to
          // follow the AVX2 convention. This would be convenient when combined
          // with other vectorized operations. For example, one can use the
          // logical operation results as a mask for a bit operation to
          // retrieve/reset multiple elements in a vector.
          //
          // In this method, users would expect, e.g., all(), to return 1/0 as
          // true/false.
          Vectorized<bool> c = Vectorized<bool>();

          for (decltype(c.size()) i = 0; i != Vectorized<bool>::size(); i++) {
            c[i] = a[i] && b[i];
          }
          return c;
        },
        /*ident=*/true);
  }
}

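// or_kernel_impl mirrors and_kernel_impl, reducing with logical OR and
// identity false.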
static void or_kernel_impl(TensorIterator& iter) {
  if (iter.dtype() == ScalarType::Byte) {
    // Refer [all, any : uint8 compatibility]
    binary_kernel_reduce_vec(
        iter,
        [=](uint8_t a, uint8_t b) -> uint8_t { return (a || b) ? 1 : 0; },
        [=](Vectorized<uint8_t> a, Vectorized<uint8_t> b) {
          return (a != Vectorized<uint8_t>(0)).neg() | (b != Vectorized<uint8_t>(0)).neg();
        },
        /*ident=*/false);
  } else {
    binary_kernel_reduce_vec(
        iter,
        [=](bool a, bool b) -> bool { return a || b; },
        [=](Vectorized<bool> a, Vectorized<bool> b) {
          Vectorized<bool> c = Vectorized<bool>();

          for (decltype(c.size()) i = 0; i != Vectorized<bool>::size(); i++) {
            c[i] = a[i] || b[i];
          }
          return c;
        },
        /*ident=*/false);
  }
}

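// MinValuesOps reuses MinOps, which reduces over (value, index) pairs, but
// projects out only the value: min_values returns the minimum elements
// themselves rather than their indices.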
template<typename scalar_t>
struct MinValuesOps: public at::native::MinOps<scalar_t> {
  using arg_t = typename MinOps<scalar_t>::arg_t;
  static scalar_t project(arg_t arg) {
    return arg.first;
  }
};

static void min_values_kernel_impl(TensorIterator& iter) {
  if (iter.dtype() == kLong) {
    // This case is special because Vectorized<int64_t> does not
    // handle upper_bound<int64_t>().
    // See: https://github.com/pytorch/pytorch/issues/43254
    using scalar_t = int64_t;
    binary_kernel_reduce(
      iter,
      MinValuesOps<scalar_t>{},
      std::pair<scalar_t, int64_t>(upper_bound<scalar_t>(), -1));
    return;
  }
  AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.dtype(), "min_values_cpu", [&iter] {
    binary_kernel_reduce_vec(
      iter,
      [](scalar_t a, scalar_t b) -> scalar_t { return min_impl(a, b); },
      [](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return minimum(a, b); },
      static_cast<double>(upper_bound<scalar_t>()));
  });
}

static void max_values_kernel_impl(TensorIterator& iter) {
  AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.dtype(), "max_values_cpu", [&iter] {
    binary_kernel_reduce_vec(
      iter,
      [](scalar_t a, scalar_t b) -> scalar_t { return max_impl(a, b); },
      [](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return maximum(a, b); },
      lower_bound<scalar_t>());
  });
}

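// argmax/argmin reduce over (value, index) pairs via ArgMaxOps/ArgMinOps and
// write out only the index. When reducing over the last dimension
// (is_reduce_lastdim), a plain sequential scan per slice is used instead of
// the generic binary_kernel_reduce path.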
static void argmax_kernel_impl(TensorIterator &iter) {
  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(1), "argmax_cpu", [&] {
    if (is_reduce_lastdim(iter)) {
      using arg_t = std::pair<scalar_t, int64_t>;
      auto op = ArgMaxOps<scalar_t>{};
      binary_kernel_reduce_lastdim(iter, [&](char* result_data_bytes, char* self_data_bytes, int64_t size) {
        int64_t* result_data = (int64_t*)result_data_bytes;
        scalar_t* self_data = (scalar_t*)self_data_bytes;

        arg_t acc = arg_t(lower_bound<scalar_t>(), 0);
        for (int64_t i = 0; i < size; i++) {
          acc = op.reduce(acc, self_data[i], i);
        }
        result_data[0] = acc.second;
      });
      return;
    }
    binary_kernel_reduce(
      iter,
      ArgMaxOps<scalar_t>{},
      std::pair<scalar_t, int64_t>(lower_bound<scalar_t>(), 0));
  });
}

static void argmin_kernel_impl(TensorIterator &iter) {
  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(1), "argmin_cpu", [&] {
    if (is_reduce_lastdim(iter)) {
      using arg_t = std::pair<scalar_t, int64_t>;
      auto op = ArgMinOps<scalar_t>{};
      binary_kernel_reduce_lastdim(iter, [&](char* result_data_bytes, char* self_data_bytes, int64_t size) {
        int64_t* result_data = (int64_t*)result_data_bytes;
        scalar_t* self_data = (scalar_t*)self_data_bytes;

        arg_t acc = arg_t(upper_bound<scalar_t>(), 0);
        for (int64_t i = 0; i < size; i++) {
          acc = op.reduce(acc, self_data[i], i);
        }
        result_data[0] = acc.second;
      });
      return;
    }
    binary_kernel_reduce(
      iter,
      ArgMinOps<scalar_t>{},
      std::pair<scalar_t, int64_t>(upper_bound<scalar_t>(), 0));
  });
}

}  // anonymous namespace

REGISTER_DISPATCH(std_var_stub, &std_var_kernel_impl);
REGISTER_DISPATCH(prod_stub, &prod_kernel_impl);
// mean implementation for CPU is in aten/src/ATen/native/ReduceOps.cpp
// but mean_stub must be defined for CPU as well
REGISTER_DISPATCH(mean_stub, nullptr);
REGISTER_DISPATCH(norm_stub, &norm_kernel_tensor_iterator_impl);
REGISTER_DISPATCH(and_stub, &and_kernel_impl);
REGISTER_DISPATCH(or_stub, &or_kernel_impl);
REGISTER_DISPATCH(min_values_stub, &min_values_kernel_impl);
REGISTER_DISPATCH(max_values_stub, &max_values_kernel_impl);
REGISTER_DISPATCH(argmax_stub, &argmax_kernel_impl);
REGISTER_DISPATCH(argmin_stub, &argmin_kernel_impl);

REGISTER_DISPATCH(cumprod_stub, &cumprod_cpu_kernel);
REGISTER_DISPATCH(cumsum_stub, &cumsum_cpu_kernel);
REGISTER_DISPATCH(logcumsumexp_stub, &logcumsumexp_cpu_kernel);

}  // namespace at::native