1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/native/TensorIterator.h>
3 #include <ATen/native/cuda/Reduce.cuh>
4 #include <ATen/native/cuda/ReduceOps.h>
5 #include <ATen/native/DispatchStub.h>
6 #include <ATen/native/SharedReduceOps.h>
7 #include <ATen/Dispatch.h>
8 #include <ATen/cuda/NumericLimits.cuh>
9 #include <ATen/native/ReduceOps.h>
10 #include <ATen/native/ReduceAllOps.h>
11 #include <ATen/native/TensorCompare.h>
12 #include <ATen/NumericUtils.h>
13
14 #include <ATen/Dispatch.h>
15 #include <ATen/NumericUtils.h>
16 #include <ATen/cuda/NumericLimits.cuh>
17
18
19 namespace at::native {
20
// Binary "minimum" combiner used by the reductions below.
//
// Returns `lhs` when `lhs` is NaN (per at::_isnan) or when `lhs < rhs`;
// otherwise returns `rhs`. Two consequences of that ordering:
//   * a NaN on the left is returned immediately;
//   * a NaN on the right is returned as well, because `lhs < NaN` is false,
//     so a NaN in either operand ends up in the result;
//   * when the operands compare equal, `rhs` is returned.
template <typename acc_t>
struct MinNanFunctor {
  __device__ __forceinline__ acc_t operator()(acc_t lhs, acc_t rhs) const {
    // Guard first: a NaN left operand wins without evaluating the comparison.
    if (at::_isnan(lhs)) {
      return lhs;
    }
    return (lhs < rhs) ? lhs : rhs;
  }
};
27
// Runs a min reduction over `iter` for one concrete element type.
//
// Template parameters:
//   scalar_t - element type; used as both template arguments of
//              gpu_reduce_kernel.
//   acc_t    - type the combiner and the initial value are expressed in
//              (defaults to scalar_t).
//
// The reduction is seeded with at::numeric_limits<acc_t>::upper_bound() and
// folds values with MinNanFunctor<acc_t> wrapped by func_wrapper<acc_t>.
template <typename scalar_t, typename acc_t = scalar_t>
void min_values_kernel_cuda_impl(TensorIterator& iter) {
  const auto combine = func_wrapper<acc_t>(MinNanFunctor<acc_t>());
  const auto identity = at::numeric_limits<acc_t>::upper_bound();
  gpu_reduce_kernel<scalar_t, scalar_t>(iter, combine, identity);
}
34
// Type-dispatching front end for the values-only min reduction.
//
// Selects scalar_t from iter.dtype() among the types handled by
// AT_DISPATCH_ALL_TYPES_AND3 extended with kBFloat16, kHalf and kBool, then
// forwards to min_values_kernel_cuda_impl<scalar_t>. The dispatch macro is
// responsible for rejecting any dtype outside that set.
void min_values_kernel_cuda(TensorIterator& iter) {
  const auto dtype = iter.dtype();
  AT_DISPATCH_ALL_TYPES_AND3(
      kBFloat16, kHalf, kBool, dtype, "min_values_cuda", [&] {
        min_values_kernel_cuda_impl<scalar_t>(iter);
      });
}
40
// Launches the MinOps-based min reduction over `iter`.
//
// Dispatches on iter.input_dtype() (not iter.dtype()) over the types handled
// by AT_DISPATCH_ALL_TYPES_AND3 extended with kBFloat16, kHalf and kBool.
// The reduction uses MinOps<scalar_t> and is seeded with the pair
// (numeric_limits<scalar_t>::upper_bound(), 0).
// NOTE(review): the pair is presumably (running minimum, its index) -- confirm
// against the definition of MinOps.
void min_launch_kernel(TensorIterator &iter) {
  const auto in_dtype = iter.input_dtype();
  AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, in_dtype, "min_cuda", [&] {
    using seed_pair_t = thrust::pair<scalar_t, int64_t>;
    const seed_pair_t identity(at::numeric_limits<scalar_t>::upper_bound(), 0);
    gpu_reduce_kernel<scalar_t, scalar_t>(iter, MinOps<scalar_t>{}, identity);
  });
}
49
// Launches the values-only min reduction, selecting the element type from the
// iterator's *input* dtype.
//
// Identical to min_values_kernel_cuda except that it dispatches on
// iter.input_dtype() and reports errors under the name "min_all_cuda".
void min_all_launch_kernel(TensorIterator &iter) {
  const auto in_dtype = iter.input_dtype();
  AT_DISPATCH_ALL_TYPES_AND3(
      kBFloat16, kHalf, kBool, in_dtype, "min_all_cuda", [&]() {
        min_values_kernel_cuda_impl<scalar_t>(iter);
      });
}
55
// Register min_values_kernel_cuda as the implementation behind
// min_values_stub. The other two launchers in this file (min_launch_kernel,
// min_all_launch_kernel) are not registered with a stub here.
// NOTE(review): presumably they are called directly by code elsewhere --
// verify against their declarations and callers.
REGISTER_DISPATCH(min_values_stub, &min_values_kernel_cuda);
57
58 } // namespace at::native
59