#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/OpMathType.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/Math.cuh>
#include <limits>

namespace at::native {

#if AT_USE_JITERATOR()
CONSTEXPR_EXCEPT_WIN_CUDA char atan_name[] = "atan_impl";
#endif

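// Elementwise arctangent over a TensorIterator. Complex dtypes are handled
// separately so they can go through the jiterator (runtime-compiled kernels)
// when it is available; real dtypes use a precompiled gpu_kernel.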
void atan_kernel_cuda(TensorIteratorBase& iter) {
  auto common_dtype = iter.common_dtype();
  if (at::isComplexType(common_dtype)) {
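    // Complex inputs: prefer the jiterator so std::atan is JIT-compiled for
    // the exact complex type; otherwise fall back to a precompiled kernel.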
#if AT_USE_JITERATOR()
    static const auto atan_string = jiterator_stringify(
        template <typename T>
        T atan_impl(T a) {
          return std::atan(a);
        }
    );
    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "atan_name", [&]() {
      jitted_gpu_kernel<
          /*name=*/ atan_name,
          /*return_dtype=*/ scalar_t,
          /*common_dtype=*/ scalar_t,
          /*arity=*/ 1>(iter, atan_string);
    });
#else
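    // Jiterator disabled: compute in the wider opmath type (complex<Half>
    // promotes to complex<float>) and narrow back to scalar_t on return.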
    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "atan_name", [&]() {
      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
        using opmath_t = at::opmath_type<scalar_t>;
        return ::atan(static_cast<opmath_t>(a));
      });
    });
#endif
  } else {
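    // Real floating-point inputs (plus Half and BFloat16) dispatch to a
    // precompiled kernel that calls the device-side ::atan.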
    AT_DISPATCH_FLOATING_TYPES_AND2(
        ScalarType::Half, ScalarType::BFloat16,
        common_dtype, "atan_cuda",
        [&]() {
          gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
            return ::atan(a);
          });
        });
  }
}

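// Register this kernel as the CUDA implementation behind atan_stub.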
REGISTER_DISPATCH(atan_stub, &atan_kernel_cuda);

} // namespace at::native