#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>

// NOTE: CUDA on Windows requires that the enclosing function
// of a __device__ lambda not have internal linkage.

namespace at::native {

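// maximum_kernel_cuda/minimum_kernel_cuda propagate NaN (a NaN in either
// input yields NaN), while fmax_kernel_cuda/fmin_kernel_cuda follow the IEEE
// fmax/fmin convention and return the non-NaN operand when only one input is
// NaN.
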
void maximum_kernel_cuda(TensorIteratorBase& iter) {
  if (iter.dtype() == ScalarType::Bool) {
    // For bools, elementwise maximum is logical OR.
    opmath_symmetric_gpu_kernel_with_scalars<bool>(
        iter, []GPU_LAMBDA(bool a, bool b) -> bool {
      return a || b;
    });
  } else if (isIntegralType(iter.dtype(), /*includeBool=*/ false)) {
    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "max_elementwise_cuda", [&]() {
      opmath_symmetric_gpu_kernel_with_scalars<scalar_t>(
          iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        return ::max(a, b);
      });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "max_elementwise_cuda", [&]() {
      opmath_symmetric_gpu_kernel_with_scalars<scalar_t>(
          iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        // Propagate NaN: `x != x` is true only when x is NaN.
        if (a != a) {
          return a;
        } else if (b != b) {
          return b;
        } else {
          return ::max(a, b);
        }
      });
    });
  }
}

void minimum_kernel_cuda(TensorIteratorBase& iter) {
  if (iter.dtype() == ScalarType::Bool) {
    // For bools, elementwise minimum is logical AND.
    opmath_symmetric_gpu_kernel_with_scalars<bool>(iter, []GPU_LAMBDA(bool a, bool b) -> bool {
      return a && b;
    });
  } else if (isIntegralType(iter.dtype(), /*includeBool=*/ false)) {
    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "minimum_cuda", [&]() {
      opmath_symmetric_gpu_kernel_with_scalars<scalar_t>(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        return ::min(a, b);
      });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "min_elementwise_cuda", [&]() {
      opmath_symmetric_gpu_kernel_with_scalars<scalar_t>(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        // Propagate NaN: `x != x` is true only when x is NaN.
        if (a != a) {
          return a;
        } else if (b != b) {
          return b;
        } else {
          return ::min(a, b);
        }
      });
    });
  }
}

void fmax_kernel_cuda(TensorIteratorBase& iter) {
  if (isFloatingType(iter.common_dtype())) {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "fmax_cuda", [&]() {
      opmath_symmetric_gpu_kernel_with_scalars<scalar_t>(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        // ::fmax returns the non-NaN operand when exactly one input is NaN.
        return ::fmax(a, b);
      });
    });
  } else {
    // Integral and bool dtypes have no NaN, so fmax reduces to maximum.
    maximum_kernel_cuda(iter);
  }
}

void fmin_kernel_cuda(TensorIteratorBase& iter) {
  if (isFloatingType(iter.common_dtype())) {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "fmin_cuda", [&]() {
      opmath_symmetric_gpu_kernel_with_scalars<scalar_t>(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        // ::fmin returns the non-NaN operand when exactly one input is NaN.
        return ::fmin(a, b);
      });
    });
  } else {
    // Integral and bool dtypes have no NaN, so fmin reduces to minimum.
    minimum_kernel_cuda(iter);
  }
}

REGISTER_DISPATCH(maximum_stub, &maximum_kernel_cuda);
REGISTER_DISPATCH(minimum_stub, &minimum_kernel_cuda);
REGISTER_DISPATCH(fmax_stub, &fmax_kernel_cuda);
REGISTER_DISPATCH(fmin_stub, &fmin_kernel_cuda);
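
// Usage sketch (illustrative only; assumes a TensorIteratorBase `iter` already
// configured for a binary op on CUDA tensors): once registered above, these
// kernels are reached through the stubs declared via ATen/native/BinaryOps.h,
// roughly as
//
//   maximum_stub(kCUDA, iter);  // dispatches to maximum_kernel_cuda(iter)
//   fmin_stub(kCUDA, iter);     // dispatches to fmin_kernel_cuda(iter)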

} // namespace at::native