// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/BinaryGeometricKernels.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/Dispatch.h>
3 #include <ATen/native/DispatchStub.h>
4 #include <ATen/native/cuda/Loops.cuh>
5 #include <ATen/native/TensorIterator.h>
6 #include <ATen/native/BinaryOps.h>
7 
8 // NOTE: CUDA on Windows requires that the enclosing function
9 // of a __device__ lambda not have internal linkage.
10 
11 namespace at::native {
12 
// Elementwise atan2(y, x) over the iterator's operands.
// Dispatches on the common dtype (all floating types plus Half/BFloat16)
// and launches a pointwise GPU kernel; scalar operands are handled by the
// `_with_scalars` launcher. atan2 is NOT symmetric in its arguments, so the
// plain (non-symmetric) launcher is used here.
void atan2_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.common_dtype(),
      "atan2_cuda",
      [&]() {
        gpu_kernel_with_scalars(
            iter,
            // y is the ordinate (first operand), x the abscissa (second),
            // matching the ::atan2(y, x) convention.
            [] GPU_LAMBDA(scalar_t y, scalar_t x) -> scalar_t {
              return ::atan2(y, x);
            });
      });
}
23 
// Elementwise hypot(x, y) = sqrt(x*x + y*y) over the iterator's operands.
// Dispatches on the common dtype (all floating types plus Half/BFloat16).
// hypot is symmetric in its two arguments, so the symmetric opmath launcher
// is used: it computes in the opmath type (e.g. float for Half/BFloat16)
// and can reuse one specialization for the tensor-scalar and scalar-tensor
// argument orders.
void hypot_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.common_dtype(),
      "hypot_cuda",
      [&]() {
        opmath_symmetric_gpu_kernel_with_scalars<scalar_t>(
            iter,
            [] GPU_LAMBDA(scalar_t x, scalar_t y) -> scalar_t {
              return ::hypot(x, y);
            });
      });
}
35 
// Hook the CUDA implementations into the device-agnostic dispatch stubs
// declared in ATen/native/BinaryOps.h.
REGISTER_DISPATCH(atan2_stub, &atan2_kernel_cuda);
REGISTER_DISPATCH(hypot_stub, &hypot_kernel_cuda);
38 
39 } // namespace at::native
40