#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/cuda/ReduceOps.h>

#include <ATen/native/ReduceOps.h>
#include <ATen/native/ReduceAllOps.h>
#include <ATen/native/ReduceOpsUtils.h>
#include <ATen/native/TensorCompare.h>

#include <ATen/Context.h>
#include <ATen/TensorUtils.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/core/NamedTensor.h>
#include <ATen/TensorIterator.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/full.h>
#include <ATen/ops/imag.h>
#include <ATen/ops/kthvalue_native.h>
#include <ATen/ops/median_native.h>
#include <ATen/ops/nanmedian_native.h>
#include <ATen/ops/where.h>
#endif

namespace at::native {
namespace {

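// p-norm reduction. The norm order is taken from an integral or floating-point
// Scalar; empty inputs fill the output with 0 (or +inf when p < 0), and for
// complex output types the imaginary part is zeroed so the norm is real-valued.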
void norm_kernel_cuda(TensorIterator& iter, const Scalar& val) {
  double p;
  if (val.isIntegral(false)) {
    p = val.to<int64_t>();
  } else if (val.isFloatingPoint()) {
    p = val.to<double>();
  } else {
    TORCH_CHECK(false, "norm_kernel_cuda_impl expects norm to be integer or float");
  }
  if (iter.numel() == 0) {
    iter.output().fill_((p < 0) ? INFINITY : 0);
    return;
  }

  norm_launch_kernel(iter, p);

  if (isComplexType(iter.output().scalar_type())) {
    at::imag(iter.output()).zero_();
  }
}

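// Dim-wise min reduction producing both the min values and their int64 indices.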
void min_kernel_impl(const Tensor& result, const Tensor& indice, const Tensor& self, int64_t dim, bool keepdim) {
  auto iter = meta::make_reduction(self, result, indice, dim, keepdim, self.scalar_type(), kLong);
  min_launch_kernel(iter);
}

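// Dim-wise max reduction producing both the max values and their int64 indices.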
void max_kernel_impl(const Tensor& result, const Tensor& indice, const Tensor& self, int64_t dim, bool keepdim) {
  auto iter = meta::make_reduction(self, result, indice, dim, keepdim, self.scalar_type(), kLong);
  max_launch_kernel(iter);
}

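// Dim-wise combined min/max reduction; nothing is launched for empty iterators.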
void aminmax_kernel_impl(
    const Tensor& self, int64_t dim, bool keepdim, Tensor& min_result, Tensor& max_result) {
  at::TensorIterator iter = make_reduction("aminmax_cuda", min_result,
                                           max_result, self, dim, keepdim, self.scalar_type());
  if (iter.numel() != 0) {
    aminmax_launch_kernel(iter);
  }
}

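// Full (all-dims) min reduction to a single value.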
void min_all_kernel_impl(Tensor& result, const Tensor& input) {
  auto dtype = input.scalar_type();
  auto iter = make_reduction("min_all", result, input, IntArrayRef{}, false, dtype);
  min_all_launch_kernel(iter);
}

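// Full (all-dims) max reduction to a single value.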
void max_all_kernel_impl(Tensor& result, const Tensor& input) {
  auto dtype = input.scalar_type();
  auto iter = make_reduction("max_all", result, input, IntArrayRef{}, false, dtype);
  max_all_launch_kernel(iter);
}

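// Full (all-dims) combined min/max reduction; empty inputs are rejected.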
void aminmax_allreduce_kernel_impl(const Tensor& input, Tensor& min_result, Tensor& max_result) {
  auto dtype = input.scalar_type();
  auto iter = make_reduction("aminmax_cuda", min_result, max_result, input,
                             IntArrayRef{}, false, dtype);
  TORCH_CHECK(iter.numel() > 0, "min_max on a tensor with no elements is not defined.");
  aminmax_allreduce_launch_kernel(iter);
}

} // namespace (anonymous)

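// Route the device-generic reduction stubs to the CUDA implementations above.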
REGISTER_CUDA_DISPATCH(min_stub, &min_kernel_impl);
REGISTER_CUDA_DISPATCH(max_stub, &max_kernel_impl);
REGISTER_CUDA_DISPATCH(min_all_stub, &min_all_kernel_impl);
REGISTER_CUDA_DISPATCH(max_all_stub, &max_all_kernel_impl);
REGISTER_CUDA_DISPATCH(aminmax_allreduce_stub, &aminmax_allreduce_kernel_impl);
REGISTER_CUDA_DISPATCH(aminmax_stub, &aminmax_kernel_impl);

REGISTER_CUDA_DISPATCH(norm_stub, &norm_kernel_cuda);

} // namespace at::native