xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/Distributions.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/native/cuda/Distributions.h>
3 #include <ATen/TensorIterator.h>
4 #include <ATen/cuda/CUDAGeneratorImpl.h>
5 
6 #ifndef AT_PER_OPERATOR_HEADERS
7 #include <ATen/Functions.h>
8 #include <ATen/NativeFunctions.h>
9 #else
10 #include <ATen/ops/_dirichlet_grad_native.h>
11 #include <ATen/ops/_sample_dirichlet_native.h>
12 #include <ATen/ops/_standard_gamma_grad_native.h>
13 #include <ATen/ops/_standard_gamma_native.h>
14 #include <ATen/ops/binomial_native.h>
15 #include <ATen/ops/empty.h>
16 #include <ATen/ops/poisson_native.h>
17 #endif
18 
19 namespace at::native {
20 
// Allocates a tensor with the sizes and options of `lambda`, hands it to
// launch_poisson_cuda_kernel together with `lambda` and the resolved CUDA
// generator, and returns it.
//
// `gen_` is resolved through get_generator_or_default, with the default CUDA
// generator passed as the fallback argument.
Tensor _s_poisson_cuda(const Tensor& lambda, std::optional<Generator> gen_) {
  const Generator& fallback_gen = cuda::detail::getDefaultCUDAGenerator();
  auto* cuda_gen = get_generator_or_default<CUDAGeneratorImpl>(gen_, fallback_gen);

  Tensor samples = at::empty(lambda.sizes(), lambda.options());
  launch_poisson_cuda_kernel(samples, lambda, cuda_gen);
  return samples;
}
27 
// Allocates an output with the sizes and options of `count`, builds a
// TensorIterator over (output, count, prob), and passes it with the resolved
// CUDA generator to launch_binomial_cuda_kernel. Returns the output tensor.
//
// `gen_` is resolved through get_generator_or_default, with the default CUDA
// generator passed as the fallback argument.
// NOTE(review): the output is sized from `count` alone; presumably callers
// pass `count` and `prob` with matching shapes -- confirm against callers.
Tensor _s_binomial_cuda(const Tensor& count, const Tensor& prob, std::optional<Generator> gen_) {
  const Generator& fallback_gen = cuda::detail::getDefaultCUDAGenerator();
  auto* cuda_gen = get_generator_or_default<CUDAGeneratorImpl>(gen_, fallback_gen);

  Tensor samples = at::empty(count.sizes(), count.options());

  // Operand order matters: output first, then count, then prob.
  at::TensorIteratorConfig config;
  config.add_output(samples);
  config.add_input(count);
  config.add_input(prob);
  at::TensorIterator iter = config.build();

  launch_binomial_cuda_kernel(iter, cuda_gen);
  return samples;
}
39 
// Allocates a tensor with the sizes and options of `alpha`, passes it to
// launch_gamma_kernel together with `alpha` and the resolved CUDA generator,
// and returns it.
//
// `gen_` is resolved through get_generator_or_default, with the default CUDA
// generator passed as the fallback argument.
Tensor _s_gamma_cuda(const Tensor& alpha, std::optional<Generator> gen_) {
  const Generator& fallback_gen = cuda::detail::getDefaultCUDAGenerator();
  auto* cuda_gen = get_generator_or_default<CUDAGeneratorImpl>(gen_, fallback_gen);

  Tensor samples = at::empty(alpha.sizes(), alpha.options());
  launch_gamma_kernel(samples, alpha, cuda_gen);
  return samples;
}
46 
// Two-stage computation over a tensor shaped like `alpha`:
//   1. launch_gamma_kernel fills the tensor from `alpha` and the generator.
//   2. The tensor is summed over its last dimension (keepdim=true) and
//      launch_dirichlet_kernel is run on an iterator whose output AND first
//      input are that same tensor, with the sum as second input.
// NOTE(review): step 2 presumably normalizes each slice by its sum in place;
// the kernel body is not visible here -- confirm in the kernel source.
//
// `gen_` is resolved through get_generator_or_default, with the default CUDA
// generator passed as the fallback argument.
Tensor _s_dirichlet_cuda(const Tensor& alpha, std::optional<Generator> gen_) {
  const Generator& fallback_gen = cuda::detail::getDefaultCUDAGenerator();
  auto* cuda_gen = get_generator_or_default<CUDAGeneratorImpl>(gen_, fallback_gen);

  Tensor samples = at::empty(alpha.sizes(), alpha.options());
  launch_gamma_kernel(samples, alpha, cuda_gen);

  // Reduction along the last dimension, kept as a size-1 dim.
  auto last_dim_total = samples.sum(/*dim=*/-1, /*keepdim=*/true);

  // `samples` is deliberately registered as both the output and an input.
  at::TensorIteratorConfig config;
  config.add_output(samples);
  config.add_input(samples);
  config.add_input(last_dim_total);
  at::TensorIterator iter = config.build();

  launch_dirichlet_kernel(iter);
  return samples;
}
60 
// Allocates a result with the sizes and options of `self`, builds a
// TensorIterator over (result, self, output), runs
// launch_standard_gamma_grad_kernel on it, and returns the result.
Tensor _standard_gamma_grad_cuda(const Tensor& self, const Tensor& output) {
  Tensor grad = at::empty(self.sizes(), self.options());

  // Operand order matters: result first, then self, then output.
  at::TensorIteratorConfig config;
  config.add_output(grad);
  config.add_input(self);
  config.add_input(output);
  TensorIterator iter = config.build();

  launch_standard_gamma_grad_kernel(iter);
  return grad;
}
71 
// Allocates a result with the sizes and options of `x`, builds a
// TensorIterator over (result, x, alpha, total), runs
// launch_dirichlet_grad_kernel on it, and returns the result.
Tensor _dirichlet_grad_cuda(const Tensor& x, const Tensor& alpha, const Tensor& total) {
  Tensor grad = at::empty(x.sizes(), x.options());

  // Operand order matters: result first, then x, alpha, total.
  at::TensorIteratorConfig config;
  config.add_output(grad);
  config.add_input(x);
  config.add_input(alpha);
  config.add_input(total);
  TensorIterator iter = config.build();

  launch_dirichlet_grad_kernel(iter);
  return grad;
}
83 
84 } // namespace at::native
85