// aten/src/ATen/native/cuda/UnaryGeometricAtanKernel.cu
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/OpMathType.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/Math.cuh>
#include <limits>

namespace at::native {

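// Device-side function name referenced by the jiterator string below; it is
// passed to jitted_gpu_kernel as the /*name=*/ template argument.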
#if AT_USE_JITERATOR()
CONSTEXPR_EXCEPT_WIN_CUDA char atan_name[] = "atan_impl";
#endif

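// CUDA kernel for at::atan. Complex inputs (including ComplexHalf) take the
// jiterator path when AT_USE_JITERATOR() is enabled, compiling atan_string at
// runtime; otherwise they fall back to a precompiled gpu_kernel that upcasts
// to opmath_t before calling ::atan. Real floating types (including Half and
// BFloat16) call ::atan directly.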
void atan_kernel_cuda(TensorIteratorBase& iter) {
  auto common_dtype = iter.common_dtype();
  if (at::isComplexType(common_dtype)) {
#if AT_USE_JITERATOR()
    static const auto atan_string = jiterator_stringify(
        template <typename T>
        T atan_impl(T a) {
          return std::atan(a);
        }
    );
    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "atan_name", [&]() {
      jitted_gpu_kernel<
          /*name=*/ atan_name,
          /*return_dtype=*/ scalar_t,
          /*common_dtype=*/ scalar_t,
          /*arity=*/ 1>(iter, atan_string);
    });
#else
    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "atan_name", [&]() {
      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
        using opmath_t = at::opmath_type<scalar_t>;
        return ::atan(static_cast<opmath_t>(a));
      });
    });
#endif
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(
        ScalarType::Half, ScalarType::BFloat16,
        common_dtype, "atan_cuda",
        [&]() {
          gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
            return ::atan(a);
          });
        });
  }
}

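// Register this kernel as the CUDA implementation behind atan_stub, so that
// at::atan on CUDA tensors dispatches here.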
REGISTER_DISPATCH(atan_stub, &atan_kernel_cuda);

} // namespace at::native