// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/Distributions.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

// Host-side launcher declarations for the CUDA distribution kernels.
//
// This header deliberately includes nothing: every type it mentions is only
// forward-declared below and is used exclusively by reference or pointer, so
// the declarations are usable without the full type definitions.
// NOTE(review): the definitions of these launchers are not in this header —
// presumably they live in the matching .cu translation unit; confirm before
// relying on any behavior described here.

namespace at {
// Forward declarations only; the complete types are defined elsewhere.
struct CUDAGeneratorImpl;
struct TensorIteratorBase;
class TensorBase;

namespace native {

// Launcher for the Poisson kernel.
//   ret    - tensor passed by const reference (presumably the output; verify
//            against the definition).
//   lambda - tensor passed by const reference (presumably the rate parameter).
//   gen    - pointer to the CUDA generator implementation to use.
void launch_poisson_cuda_kernel(
    const TensorBase &ret, const TensorBase &lambda, CUDAGeneratorImpl *gen);

// Launcher for the gamma kernel.
//   ret   - tensor passed by const reference (presumably the output; verify
//           against the definition).
//   alpha - tensor passed by const reference (presumably the shape parameter).
//   gen   - pointer to the CUDA generator implementation to use.
void launch_gamma_kernel(
    const TensorBase &ret, const TensorBase &alpha, CUDAGeneratorImpl *gen);

// Launcher for the binomial kernel.
//   iter - tensor iterator, taken by mutable reference; its operand layout is
//          not visible here (TODO confirm against the definition).
//   gen  - pointer to the CUDA generator implementation to use.
void launch_binomial_cuda_kernel(
    TensorIteratorBase &iter, CUDAGeneratorImpl *gen);

// Launcher for the Dirichlet kernel; takes only a tensor iterator (no
// generator argument).
void launch_dirichlet_kernel(TensorIteratorBase &iter);

// Launcher for the standard-gamma gradient kernel; takes only a tensor
// iterator.
void launch_standard_gamma_grad_kernel(TensorIteratorBase &iter);

// Launcher for the Dirichlet gradient kernel; takes only a tensor iterator.
void launch_dirichlet_grad_kernel(TensorIteratorBase &iter);

}}  // namespace at::native