1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/Dispatch.h>
3 #include <ATen/NumericUtils.h>
4 #include <ATen/native/DispatchStub.h>
5 #include <ATen/native/ReduceAllOps.h>
6 #include <ATen/native/ReduceOps.h>
7 #include <ATen/native/SharedReduceOps.h>
8 #include <ATen/native/TensorCompare.h>
9 #include <ATen/native/TensorIterator.h>
10 #include <ATen/native/cuda/ReduceOps.h>
11 #include <ATen/cuda/NumericLimits.cuh>
12 #include <ATen/native/cuda/Reduce.cuh>
13
14 #include <ATen/Dispatch.h>
15 #include <ATen/NumericUtils.h>
16 #include <ATen/cuda/NumericLimits.cuh>
17
18 namespace at::native {
19
// Binary combiner for a NaN-propagating maximum reduction.
//
// If `a` is NaN it is returned immediately. Otherwise the larger of `a` and
// `b` is returned, with ties resolving to `b`. For IEEE floating-point types,
// a NaN in `b` also wins: `a > NaN` is false, so `b` is selected.
template <typename acc_t>
struct MaxNanFunctor {
  __device__ __forceinline__ acc_t operator()(acc_t a, acc_t b) const {
    if (at::_isnan(a)) {
      return a;
    }
    return (a > b) ? a : b;
  }
};
26
27 template <typename scalar_t, typename acc_t = scalar_t>
max_values_kernel_cuda_impl(TensorIterator & iter)28 void max_values_kernel_cuda_impl(TensorIterator& iter) {
29 gpu_reduce_kernel<scalar_t, scalar_t>(
30 iter,
31 func_wrapper<acc_t>(MaxNanFunctor<acc_t>()),
32 at::numeric_limits<acc_t>::lower_bound());
33 }
34
max_values_kernel_cuda(TensorIterator & iter)35 void max_values_kernel_cuda(TensorIterator& iter) {
36 AT_DISPATCH_ALL_TYPES_AND3(
37 kBFloat16, kHalf, kBool, iter.dtype(), "max_values_cuda", [&]() {
38 max_values_kernel_cuda_impl<scalar_t>(iter);
39 });
40 }
41
max_launch_kernel(TensorIterator & iter)42 void max_launch_kernel(TensorIterator& iter) {
43 AT_DISPATCH_ALL_TYPES_AND3(
44 kBFloat16, kHalf, kBool, iter.input_dtype(), "max_cuda", [&]() {
45 gpu_reduce_kernel<scalar_t, scalar_t>(
46 iter,
47 MaxOps<scalar_t>{},
48 thrust::pair<scalar_t, int64_t>(
49 at::numeric_limits<scalar_t>::lower_bound(), 0));
50 });
51 }
52
max_all_launch_kernel(TensorIterator & iter)53 void max_all_launch_kernel(TensorIterator &iter) {
54 AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "max_all_cuda", [&] {
55 max_values_kernel_cuda_impl<scalar_t>(iter);
56 });
57 }
58
// Hook max_values_kernel_cuda into the max_values_stub dispatch table.
// NOTE(review): this translation unit compiles device code, so this is
// presumably the CUDA slot of the stub — confirm.
REGISTER_DISPATCH(max_values_stub, &max_values_kernel_cuda);
60
61 } // namespace at::native
62