xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/ReduceOps.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <ATen/native/DispatchStub.h>
#include <c10/util/ArrayRef.h>

#include <cstdint>
#include <optional>
#include <tuple>
6 
namespace c10 {
// Forward declaration only: this header refers to Scalar exclusively through
// `const c10::Scalar&`, so the complete type is not needed here.
class Scalar;
} // namespace c10
10 
namespace at {
// Forward declarations: this header only uses these types by reference
// (Tensor&, const Tensor&, TensorIterator&), so their definitions are not
// required. Keep the class-key (`struct` vs `class`) matching the real
// definitions to avoid mismatched-tag warnings.
class Tensor;
struct TensorIterator;
} // namespace at
15 
16 namespace at::native {
17 
18 using reduce_fn = void(*)(TensorIterator &);
19 
20 DECLARE_DISPATCH(reduce_fn, sum_stub);
21 DECLARE_DISPATCH(reduce_fn, nansum_stub);
22 DECLARE_DISPATCH(reduce_fn, prod_stub);
23 DECLARE_DISPATCH(reduce_fn, mean_stub);
24 DECLARE_DISPATCH(reduce_fn, and_stub);
25 DECLARE_DISPATCH(reduce_fn, or_stub);
26 DECLARE_DISPATCH(reduce_fn, min_values_stub);
27 DECLARE_DISPATCH(reduce_fn, max_values_stub);
28 DECLARE_DISPATCH(reduce_fn, argmax_stub);
29 DECLARE_DISPATCH(reduce_fn, argmin_stub);
30 
31 using reduce_std_var_function =
32     void (*)(TensorIterator&, double correction, bool take_sqrt);
33 DECLARE_DISPATCH(reduce_std_var_function, std_var_stub);
34 
35 using reduce_norm_fn =
36     void (*)(Tensor&, const Tensor&, const c10::Scalar&, std::optional<int64_t>);
37 DECLARE_DISPATCH(reduce_norm_fn, norm_kernel);
38 
39 using reduce_fn_flag = void(*)(TensorIterator &, const c10::Scalar&);
40 DECLARE_DISPATCH(reduce_fn_flag, norm_stub);
41 
42 using structured_cum_fn = void (*)(const Tensor&, const Tensor&, int64_t);
43 using cum_fn = void (*)(Tensor&, const Tensor&, int64_t);
44 DECLARE_DISPATCH(structured_cum_fn, cumsum_stub);
45 DECLARE_DISPATCH(structured_cum_fn, cumprod_stub);
46 DECLARE_DISPATCH(cum_fn, logcumsumexp_stub);
47 
48 DECLARE_DISPATCH(void (*)(const Tensor&, int64_t, bool, Tensor&, Tensor&), aminmax_stub);
49 DECLARE_DISPATCH(void (*)(const Tensor&, Tensor&, Tensor&), aminmax_allreduce_stub);
50 
51 // Used in cuda/Normalization.cu
52 TORCH_API std::tuple<Tensor&,Tensor&> var_mean_out(
53     Tensor &result1, Tensor &result2, const Tensor &self, IntArrayRef dim,
54     int64_t correction, bool keepdim);
55 
56 } // namespace at::native
57