1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/core/TensorBase.h>
3 #include <ATen/Dispatch.h>
4
5 #include <ATen/native/cuda/ScanKernels.h>
6 #include <ATen/native/cuda/ScanUtils.cuh>
7
8 #include <limits>
9 #include <functional>
10
11 namespace at::native {
12
// Launch the CUDA cumulative-max scan along `dim`, writing results into
// `values` and the position of each running maximum into `indices`.
// Dispatches over all types plus Bool/Half/BFloat16 on self's scalar type.
void launch_cummax_cuda_kernel(const TensorBase& self, const TensorBase& values, const TensorBase& indices, int64_t dim) {
  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16,
    self.scalar_type(), "cummax_cuda", [&]() {
      // Identity element for max: -inf for floating-point types, the
      // smallest representable value otherwise.
      scalar_t init;
      if (self.is_floating_point()) {
        init = -std::numeric_limits<scalar_t>::infinity();
      } else {
        init = std::numeric_limits<scalar_t>::lowest();
      }
      // >= (rather than >) keeps the earliest index on ties consistent with
      // the scan's comparison convention.
      scan_dim_with_indices<scalar_t>(self, values, indices, dim, init, std::greater_equal<scalar_t>());
    });
}
20
// Launch the CUDA cumulative-min scan along `dim`, writing results into
// `values` and the position of each running minimum into `indices`.
// Dispatches over all types plus Bool/Half/BFloat16 on self's scalar type.
void launch_cummin_cuda_kernel(const TensorBase& self, const TensorBase& values, const TensorBase& indices, int64_t dim) {
  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16,
    self.scalar_type(), "cummin_cuda", [&]() {
      // Identity element for min: +inf for floating-point types, the
      // largest representable value otherwise.
      scalar_t init;
      if (self.is_floating_point()) {
        init = std::numeric_limits<scalar_t>::infinity();
      } else {
        init = std::numeric_limits<scalar_t>::max();
      }
      // <= (rather than <) keeps the earliest index on ties consistent with
      // the scan's comparison convention.
      scan_dim_with_indices<scalar_t>(self, values, indices, dim, init, std::less_equal<scalar_t>());
    });
}
28
29 } // namespace at::native
30