#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>

// NOTE: CUDA on Windows requires that the enclosing function
// of a __device__ lambda not have internal linkage.

namespace at::native {

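// Elementwise maximum. Bool inputs reduce to logical OR, integral types use
// ::max directly, and floating-point types propagate NaN: `a != a` is true
// only when `a` is NaN, so a NaN operand is returned instead of being
// silently dropped by ::max.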
void maximum_kernel_cuda(TensorIteratorBase& iter) {
  if (iter.dtype() == ScalarType::Bool) {
    opmath_symmetric_gpu_kernel_with_scalars<bool>(
        iter, []GPU_LAMBDA(bool a, bool b) -> bool {
      return a || b;
    });
  } else if (isIntegralType(iter.dtype(), /*includeBool=*/ false)) {
    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "max_elementwise_cuda", [&]() {
      opmath_symmetric_gpu_kernel_with_scalars<scalar_t>(
          iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        return ::max(a, b);
      });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "max_elementwise_cuda", [&]() {
      opmath_symmetric_gpu_kernel_with_scalars<scalar_t>(
          iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        if (a != a) {
          return a;
        } else if (b != b) {
          return b;
        } else {
          return ::max(a, b);
        }
      });
    });
  }
}

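// Elementwise minimum. Mirrors maximum_kernel_cuda: logical AND for bool,
// ::min for integral types, and explicit NaN propagation for floating-point
// types.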
void minimum_kernel_cuda(TensorIteratorBase& iter) {
  if (iter.dtype() == ScalarType::Bool) {
    opmath_symmetric_gpu_kernel_with_scalars<bool>(iter, []GPU_LAMBDA(bool a, bool b) -> bool {
      return a && b;
    });
  } else if (isIntegralType(iter.dtype(), /*includeBool=*/ false)) {
    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "minimum_cuda", [&]() {
      opmath_symmetric_gpu_kernel_with_scalars<scalar_t>(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        return ::min(a, b);
      });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "min_elementwise_cuda", [&]() {
      opmath_symmetric_gpu_kernel_with_scalars<scalar_t>(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        if (a != a) {
          return a;
        } else if (b != b) {
          return b;
        } else {
          return ::min(a, b);
        }
      });
    });
  }
}

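// fmax differs from maximum for floating-point types by following the IEEE
// semantics of ::fmax: when exactly one operand is NaN, the non-NaN operand
// is returned. Bool and integral dtypes cannot hold NaN, so they reuse
// maximum_kernel_cuda.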
void fmax_kernel_cuda(TensorIteratorBase& iter) {
  if (isFloatingType(iter.common_dtype())) {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "fmax_cuda", [&]() {
      opmath_symmetric_gpu_kernel_with_scalars<scalar_t>(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        return ::fmax(a, b);
      });
    });
  } else {
    maximum_kernel_cuda(iter);
  }
}

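// fmin mirrors fmax: ::fmin returns the non-NaN operand when only one input
// is NaN; non-floating dtypes fall back to minimum_kernel_cuda.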
void fmin_kernel_cuda(TensorIteratorBase& iter) {
  if (isFloatingType(iter.common_dtype())) {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "fmin_cuda", [&]() {
      opmath_symmetric_gpu_kernel_with_scalars<scalar_t>(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        return ::fmin(a, b);
      });
    });
  } else {
    minimum_kernel_cuda(iter);
  }
}

REGISTER_DISPATCH(maximum_stub, &maximum_kernel_cuda);
REGISTER_DISPATCH(minimum_stub, &minimum_kernel_cuda);
REGISTER_DISPATCH(fmax_stub, &fmax_kernel_cuda);
REGISTER_DISPATCH(fmin_stub, &fmin_kernel_cuda);

} // namespace at::native