1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/native/cuda/Distributions.h>
3 #include <ATen/TensorIterator.h>
4 #include <ATen/cuda/CUDAGeneratorImpl.h>
5
6 #ifndef AT_PER_OPERATOR_HEADERS
7 #include <ATen/Functions.h>
8 #include <ATen/NativeFunctions.h>
9 #else
10 #include <ATen/ops/_dirichlet_grad_native.h>
11 #include <ATen/ops/_sample_dirichlet_native.h>
12 #include <ATen/ops/_standard_gamma_grad_native.h>
13 #include <ATen/ops/_standard_gamma_native.h>
14 #include <ATen/ops/binomial_native.h>
15 #include <ATen/ops/empty.h>
16 #include <ATen/ops/poisson_native.h>
17 #endif
18
19 namespace at::native {
20
// Draws one Poisson sample per element, using `lambda` as the elementwise
// rate. The result tensor has the same sizes and options as `lambda`;
// sampling itself is delegated to launch_poisson_cuda_kernel.
Tensor _s_poisson_cuda(const Tensor& lambda, std::optional<Generator> gen_) {
  // Fall back to the default CUDA generator when the caller passes none.
  auto generator = get_generator_or_default<CUDAGeneratorImpl>(
      gen_, cuda::detail::getDefaultCUDAGenerator());
  Tensor result = at::empty(lambda.sizes(), lambda.options());
  launch_poisson_cuda_kernel(result, lambda, generator);
  return result;
}
27
// Draws one binomial sample per element from (`count`, `prob`) pairs.
// The output matches `count`'s sizes and options; the iterator pairs the
// output with both inputs and the sampling runs in launch_binomial_cuda_kernel.
Tensor _s_binomial_cuda(const Tensor& count, const Tensor& prob, std::optional<Generator> gen_) {
  // Fall back to the default CUDA generator when the caller passes none.
  auto generator = get_generator_or_default<CUDAGeneratorImpl>(
      gen_, cuda::detail::getDefaultCUDAGenerator());
  Tensor result = at::empty(count.sizes(), count.options());
  auto iter = at::TensorIteratorConfig()
                  .add_output(result)
                  .add_input(count)
                  .add_input(prob)
                  .build();
  launch_binomial_cuda_kernel(iter, generator);
  return result;
}
39
// Draws one standard-Gamma sample per element, with `alpha` as the
// elementwise concentration. The result has `alpha`'s sizes and options;
// sampling is delegated to launch_gamma_kernel.
Tensor _s_gamma_cuda(const Tensor& alpha, std::optional<Generator> gen_) {
  // Fall back to the default CUDA generator when the caller passes none.
  auto generator = get_generator_or_default<CUDAGeneratorImpl>(
      gen_, cuda::detail::getDefaultCUDAGenerator());
  Tensor result = at::empty(alpha.sizes(), alpha.options());
  launch_gamma_kernel(result, alpha, generator);
  return result;
}
46
// Draws Dirichlet samples parameterized by `alpha`: first fills the output
// with standard-Gamma(alpha) draws, then feeds each draw together with its
// last-dim sum through launch_dirichlet_kernel (which presumably performs the
// gamma-over-sum normalization — the kernel body lives elsewhere).
Tensor _s_dirichlet_cuda(const Tensor& alpha, std::optional<Generator> gen_) {
  // Fall back to the default CUDA generator when the caller passes none.
  auto generator = get_generator_or_default<CUDAGeneratorImpl>(
      gen_, cuda::detail::getDefaultCUDAGenerator());
  Tensor samples = at::empty(alpha.sizes(), alpha.options());
  launch_gamma_kernel(samples, alpha, generator);
  // keepdim=true so the per-row sum broadcasts against `samples` in the iterator.
  auto row_sum = samples.sum(/*dim=*/-1, /*keepdim=*/true);
  auto iter = at::TensorIteratorConfig()
                  .add_output(samples)
                  .add_input(samples)
                  .add_input(row_sum)
                  .build();
  launch_dirichlet_kernel(iter);
  return samples;
}
60
// Elementwise backward for the standard-Gamma sampler: pairs `self` with
// `output` through a TensorIterator and lets
// launch_standard_gamma_grad_kernel compute the gradient values.
// The result has `self`'s sizes and options.
Tensor _standard_gamma_grad_cuda(const Tensor& self, const Tensor& output) {
  Tensor grad = at::empty(self.sizes(), self.options());
  auto iter = at::TensorIteratorConfig()
                  .add_output(grad)
                  .add_input(self)
                  .add_input(output)
                  .build();
  launch_standard_gamma_grad_kernel(iter);
  return grad;
}
71
// Elementwise backward for the Dirichlet sampler: pairs the sample `x`, the
// concentration `alpha`, and the `total` tensor through a TensorIterator and
// lets launch_dirichlet_grad_kernel compute the gradient values.
// The result has `x`'s sizes and options.
Tensor _dirichlet_grad_cuda(const Tensor& x, const Tensor& alpha, const Tensor& total) {
  Tensor grad = at::empty(x.sizes(), x.options());
  auto iter = at::TensorIteratorConfig()
                  .add_output(grad)
                  .add_input(x)
                  .add_input(alpha)
                  .add_input(total)
                  .build();
  launch_dirichlet_grad_kernel(iter);
  return grad;
}
83
84 } // namespace at::native
85