xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/CumminmaxKernel.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/core/TensorBase.h>
3 #include <ATen/Dispatch.h>
4 
5 #include <ATen/native/cuda/ScanKernels.h>
6 #include <ATen/native/cuda/ScanUtils.cuh>
7 
8 #include <limits>
9 #include <functional>
10 
11 namespace at::native {
12 
// Launches the CUDA cumulative-max scan over `self` along dimension `dim`,
// writing the running extrema into `values` and their positions into
// `indices` (both assumed pre-allocated by the caller — confirm at call site).
void launch_cummax_cuda_kernel(const TensorBase& self, const TensorBase& values, const TensorBase& indices, int64_t dim) {
  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16,
    self.scalar_type(), "cummax_cuda", [&]() {
      // Identity element for max: -inf for floating-point dtypes, the lowest
      // representable value for everything else.
      scalar_t init;
      if (self.is_floating_point()) {
        init = -std::numeric_limits<scalar_t>::infinity();
      } else {
        init = std::numeric_limits<scalar_t>::lowest();
      }
      scan_dim_with_indices<scalar_t>(self, values, indices, dim, init, std::greater_equal<scalar_t>());
    });
}
20 
// Launches the CUDA cumulative-min scan over `self` along dimension `dim`,
// writing the running extrema into `values` and their positions into
// `indices` (both assumed pre-allocated by the caller — confirm at call site).
void launch_cummin_cuda_kernel(const TensorBase& self, const TensorBase& values, const TensorBase& indices, int64_t dim) {
  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16,
    self.scalar_type(), "cummin_cuda", [&]() {
      // Identity element for min: +inf for floating-point dtypes, the largest
      // representable value for everything else.
      scalar_t init;
      if (self.is_floating_point()) {
        init = std::numeric_limits<scalar_t>::infinity();
      } else {
        init = std::numeric_limits<scalar_t>::max();
      }
      scan_dim_with_indices<scalar_t>(self, values, indices, dim, init, std::less_equal<scalar_t>());
    });
}
28 
29 } // namespace at::native
30