#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/cuda/ReduceOps.h>

#include <ATen/native/ReduceOps.h>
#include <ATen/native/ReduceAllOps.h>
#include <ATen/native/ReduceOpsUtils.h>
#include <ATen/native/TensorCompare.h>

#include <ATen/Context.h>
#include <ATen/TensorUtils.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/core/NamedTensor.h>
#include <ATen/TensorIterator.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/full.h>
#include <ATen/ops/imag.h>
#include <ATen/ops/kthvalue_native.h>
#include <ATen/ops/median_native.h>
#include <ATen/ops/nanmedian_native.h>
#include <ATen/ops/where.h>
#endif

namespace at::native {
namespace {

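// Norm reduction on CUDA: extracts the norm order p from the Scalar argument,
// short-circuits empty inputs (+inf for p < 0, 0 otherwise), launches the norm
// reduction kernel, and zeroes the imaginary part of complex outputs since a
// norm is always real-valued.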
void norm_kernel_cuda(TensorIterator& iter, const Scalar& val) {
  double p;
  if (val.isIntegral(false)) {
    p = val.to<int64_t>();
  } else if (val.isFloatingPoint()) {
    p = val.to<double>();
  } else {
    TORCH_CHECK(false, "norm_kernel_cuda_impl expects norm to be integer or float");
  }
  if (iter.numel() == 0) {
    iter.output().fill_((p < 0) ? INFINITY : 0);
    return;
  }

  norm_launch_kernel(iter, p);

  if (isComplexType(iter.output().scalar_type())) {
    at::imag(iter.output()).zero_();
  }

}

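// Single-dimension min/max reductions that produce both the reduced values and
// their int64 indices. meta::make_reduction builds the TensorIterator; the
// reduction itself is performed by the CUDA launch kernels.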
void min_kernel_impl(const Tensor& result, const Tensor& indice, const Tensor& self, int64_t dim, bool keepdim) {
  auto iter = meta::make_reduction(self, result, indice, dim, keepdim, self.scalar_type(), kLong);
  min_launch_kernel(iter);
}

void max_kernel_impl(const Tensor& result, const Tensor& indice, const Tensor& self, int64_t dim, bool keepdim) {
  auto iter = meta::make_reduction(self, result, indice, dim, keepdim, self.scalar_type(), kLong);
  max_launch_kernel(iter);
}

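// Computes the minimum and maximum along a single dimension in one pass over
// the input. No kernel is launched when the iterator is empty.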
void aminmax_kernel_impl(
    const Tensor& self, int64_t dim, bool keepdim, Tensor& min_result, Tensor& max_result) {
  at::TensorIterator iter = make_reduction("aminmax_cuda", min_result,
                                           max_result, self, dim, keepdim, self.scalar_type());
  if (iter.numel() != 0) {
    aminmax_launch_kernel(iter);
  }
}

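// Full-tensor min/max reductions: the empty dim list makes make_reduction
// reduce over every dimension, producing a single scalar result.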
void min_all_kernel_impl(Tensor& result, const Tensor& input) {
  auto dtype = input.scalar_type();
  auto iter = make_reduction("min_all", result, input, IntArrayRef{}, false, dtype);
  min_all_launch_kernel(iter);
}

void max_all_kernel_impl(Tensor& result, const Tensor& input) {
  auto dtype = input.scalar_type();
  auto iter = make_reduction("max_all", result, input, IntArrayRef{}, false, dtype);
  max_all_launch_kernel(iter);
}

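// Simultaneous min/max over the whole tensor. Unlike the per-dimension variant
// above, an empty input is rejected explicitly because the result would be
// undefined.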
void aminmax_allreduce_kernel_impl(const Tensor& input, Tensor& min_result, Tensor& max_result) {
  auto dtype = input.scalar_type();
  auto iter = make_reduction("aminmax_cuda", min_result, max_result, input,
                             IntArrayRef{}, false, dtype);
  TORCH_CHECK(iter.numel() > 0, "min_max on a tensor with no elements is not defined.");
  aminmax_allreduce_launch_kernel(iter);
}

}  // namespace (anonymous)

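// Register the implementations above with the corresponding dispatch stubs so
// that the device-generic operators route to these kernels for CUDA tensors.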
REGISTER_CUDA_DISPATCH(min_stub, &min_kernel_impl);
REGISTER_CUDA_DISPATCH(max_stub, &max_kernel_impl);
REGISTER_CUDA_DISPATCH(min_all_stub, &min_all_kernel_impl);
REGISTER_CUDA_DISPATCH(max_all_stub, &max_all_kernel_impl);
REGISTER_CUDA_DISPATCH(aminmax_allreduce_stub, &aminmax_allreduce_kernel_impl);
REGISTER_CUDA_DISPATCH(aminmax_stub, &aminmax_kernel_impl);

REGISTER_CUDA_DISPATCH(norm_stub, &norm_kernel_cuda);

} // namespace at::native