#pragma once

#include <ATen/native/DispatchStub.h>
#include <c10/util/ArrayRef.h>
#include <optional>

namespace c10 {
class Scalar;
}

namespace at {
struct TensorIterator;
class Tensor;
}

namespace at::native {

// Reduction kernels that operate on a fully configured TensorIterator.
using reduce_fn = void(*)(TensorIterator &);

DECLARE_DISPATCH(reduce_fn, sum_stub);
DECLARE_DISPATCH(reduce_fn, nansum_stub);
DECLARE_DISPATCH(reduce_fn, prod_stub);
DECLARE_DISPATCH(reduce_fn, mean_stub);
DECLARE_DISPATCH(reduce_fn, and_stub);
DECLARE_DISPATCH(reduce_fn, or_stub);
DECLARE_DISPATCH(reduce_fn, min_values_stub);
DECLARE_DISPATCH(reduce_fn, max_values_stub);
DECLARE_DISPATCH(reduce_fn, argmax_stub);
DECLARE_DISPATCH(reduce_fn, argmin_stub);

// std/var kernel: `correction` is the degrees-of-freedom correction and
// `take_sqrt` selects std (sqrt of the variance) over var.
using reduce_std_var_function =
    void (*)(TensorIterator&, double correction, bool take_sqrt);
DECLARE_DISPATCH(reduce_std_var_function, std_var_stub);

using reduce_norm_fn =
    void (*)(Tensor&, const Tensor&, const c10::Scalar&, std::optional<int64_t>);
DECLARE_DISPATCH(reduce_norm_fn, norm_kernel);

using reduce_fn_flag = void(*)(TensorIterator &, const c10::Scalar&);
DECLARE_DISPATCH(reduce_fn_flag, norm_stub);

// Cumulative kernels take (result, self, dim); the structured variant is used
// by structured kernels, whose pre-allocated output is passed as const Tensor&.
using structured_cum_fn = void (*)(const Tensor&, const Tensor&, int64_t);
using cum_fn = void (*)(Tensor&, const Tensor&, int64_t);
DECLARE_DISPATCH(structured_cum_fn, cumsum_stub);
DECLARE_DISPATCH(structured_cum_fn, cumprod_stub);
DECLARE_DISPATCH(cum_fn, logcumsumexp_stub);

// Compute min and max together in a single pass.
DECLARE_DISPATCH(void (*)(const Tensor&, int64_t, bool, Tensor&, Tensor&), aminmax_stub);
DECLARE_DISPATCH(void (*)(const Tensor&, Tensor&, Tensor&), aminmax_allreduce_stub);

// Used in cuda/Normalization.cu
TORCH_API std::tuple<Tensor&,Tensor&> var_mean_out(
    Tensor &result1, Tensor &result2, const Tensor &self, IntArrayRef dim,
    int64_t correction, bool keepdim);

} // namespace at::native
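
// A minimal sketch (not part of this header) of how these stubs are typically
// wired up, assuming the usual DispatchStub pattern; the kernel name
// sum_kernel_impl and the iterator construction are illustrative only:
//
//   // In the operator implementation TU, materialize the stub once:
//   DEFINE_DISPATCH(sum_stub);
//
//   // In a per-backend kernel TU, register an implementation for it:
//   static void sum_kernel_impl(TensorIterator& iter) { /* reduce over iter */ }
//   REGISTER_DISPATCH(sum_stub, &sum_kernel_impl);
//
//   // At the call site, build a reduction TensorIterator over (result, self, dim)
//   // and invoke the stub with the device type to select the registered kernel:
//   sum_stub(iter.device_type(), iter);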