#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/cuda/Distributions.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/cuda/DistributionTemplates.h>

#include <curand.h>
#include <curand_kernel.h>
#include <curand_philox4x32_x.h>
#include <utility>
#include <functional>

#include <ATen/native/Distributions.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/TensorIterator.h>

#include <cstdint>
#include <limits>
#include <utility>
#include <type_traits>

/**
 * Note [Register spilling in curand call for CUDA < 10]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 * For CUDA < 10, the curandStatePhilox4_32_10_t engine achieves poor performance (60% of SOL bandwidth)
 * when called to generate one random number at a time. This is because the line
 *            unsigned ret = (&state->output.x)[state->STATE++];
 * in
 *            QUALIFIERS unsigned int curand(curandStatePhilox4_32_10_t *state)
 * in curand_kernel.h dynamically indexes into state.output, preventing the compiler from ever
 * storing state.output in registers.
 *
 * CUDA 10 fixed this problem. However, for backwards compatibility, the following kernels
 * use curand distributions built on the curand4 call, which does not have the
 * register spilling problem.
 */

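// For reference, a minimal illustrative sketch of the curand4-style pattern
// the note above refers to (a hypothetical helper, not used below): a single
// curand_uniform4 call produces four floats at once, so state.output is
// consumed without the dynamic indexing that caused the spill.
//
//   __device__ float4 four_uniforms(curandStatePhilox4_32_10_t* state) {
//     return curand_uniform4(state);  // .x/.y/.z/.w are samples in (0, 1]
//   }
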
namespace {

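// Fills `ret` elementwise with Poisson(lambda) samples: each GPU thread seeds
// its own Philox state from `philox_args` and draws with curand_poisson.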
template <typename scalar_t>
void poisson_cuda_kernel(
    const at::TensorBase &ret,
    const at::TensorBase &lambda,
    at::PhiloxCudaState philox_args) {
  auto functor = [philox_args] __device__(
          scalar_t & ret_val, const scalar_t& lambda) {
        CUDA_KERNEL_ASSERT(lambda >= 0 && "invalid Poisson rate, expected rate to be non-negative");
        auto seeds = at::cuda::philox::unpack(philox_args);
        curandStatePhilox4_32_10_t state;
        curand_init(std::get<0>(seeds),
                    blockIdx.x * blockDim.x + threadIdx.x,
                    std::get<1>(seeds),
                    &state);
        ret_val = static_cast<scalar_t>(curand_poisson(&state, lambda));
      };
  at::cuda::CUDA_tensor_apply2<scalar_t, scalar_t, decltype(functor),
                               /*max_threads_per_block=*/512,
                               /*min_blocks_per_sm=*/2>(ret, lambda, functor);
}

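// Turns raw Philox bits into uniform floats in [0, 1): only the low 24 bits of
// each 32-bit draw are kept (MASK) and scaled by 2^-24 (DIVISOR), so every
// result is exactly representable as a float.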
struct curand_uniform_wrapper {
  curandStatePhilox4_32_10_t &state;
  __device__ curand_uniform_wrapper(curandStatePhilox4_32_10_t &state): state(state) {}
  __device__ float operator()() {
    uint32_t val = curand(&state); // need just bits
    constexpr auto MASK = static_cast<uint32_t>((static_cast<uint64_t>(1) << std::numeric_limits<float>::digits) - 1);
    constexpr auto DIVISOR = static_cast<float>(1) / (static_cast<uint32_t>(1) << std::numeric_limits<float>::digits);
    return (val & MASK) * DIVISOR;
  }
};

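// Draws Binomial(count, prob) samples elementwise through
// distribution_binary_kernel; the sampler in ATen/native/Distributions.h
// consumes uniforms produced by curand_uniform_wrapper above.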
template <typename scalar_t>
void binomial_cuda_kernel(
    at::TensorIteratorBase &iter,
    at::PhiloxCudaState philox_args) {
  using accscalar_t = at::acc_type<scalar_t, true>;

  at::native::distribution_binary_kernel(iter, philox_args,
      [] GPU_LAMBDA (curandStatePhilox4_32_10_t& state, scalar_t count, scalar_t prob) {
        #if defined(__CUDA_ARCH__) || defined(USE_ROCM)
        auto uniform_lambda = curand_uniform_wrapper(state);
        BaseSampler<accscalar_t, decltype(uniform_lambda)> standard_uniform(uniform_lambda);
        auto sample = sample_binomial<scalar_t, accscalar_t, decltype(uniform_lambda)>(count, prob, standard_uniform);
        return static_cast<scalar_t>(sample);
        #else
        return count; // host compilation pass only; never executed on device
        #endif
      }
  );
}

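// Draws Gamma(alpha, 1) samples elementwise via sample_gamma from
// ATen/native/Distributions.h, then clamps each sample up to
// std::numeric_limits<scalar_t>::min() so results stay strictly positive.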
template <typename scalar_t>
void gamma_cuda_kernel(
    const at::TensorBase &ret,
    const at::TensorBase &alpha,
    at::PhiloxCudaState philox_args) {
  using accscalar_t = at::acc_type<scalar_t, true>;
  auto functor = [philox_args] __device__(
          scalar_t & ret_val, const scalar_t& alpha) {
        auto seeds = at::cuda::philox::unpack(philox_args);
        curandStatePhilox4_32_10_t state;
        curand_init(std::get<0>(seeds),
                    blockIdx.x * blockDim.x + threadIdx.x,
                    std::get<1>(seeds),
                    &state);

        auto uniform_lambda = [&state] __device__ () {
          return curand_uniform(&state);
        };
        BaseSampler<accscalar_t, decltype(uniform_lambda)> standard_uniform(uniform_lambda);

        auto normal_lambda = [&state] __device__ () {
          return curand_normal(&state);
        };
        BaseSampler<accscalar_t, decltype(normal_lambda)> standard_normal(normal_lambda);
        auto sample = sample_gamma<scalar_t, accscalar_t, decltype(uniform_lambda), decltype(normal_lambda)>(alpha, standard_uniform, standard_normal);
        auto min_value = std::numeric_limits<scalar_t>::min();
        ret_val = (min_value > sample) ? min_value : sample;
      };
  at::cuda::CUDA_tensor_apply2<scalar_t, scalar_t, decltype(functor),
                               /*max_threads_per_block=*/256,
                               /*min_blocks_per_sm=*/2>(ret, alpha, functor);
}

} // namespace

namespace at::native {

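// Normalizes per-element Gamma draws into Dirichlet samples: each output is
// gamma / gamma_sum, clamped to [min, 1 - eps] so it stays strictly inside
// the open interval (0, 1).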
void launch_dirichlet_kernel(at::TensorIteratorBase &iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
                                  iter.input_dtype(), "dirichlet_cuda", [&] {
    at::native::gpu_kernel(
        iter,
        [] GPU_LAMBDA (scalar_t gamma, scalar_t gamma_sum) {
      auto ret_val = gamma / gamma_sum;
      auto min_value = std::numeric_limits<scalar_t>::min();
      auto max_value = 1 - std::numeric_limits<scalar_t>::epsilon();
      ret_val = (min_value > ret_val) ? min_value : ret_val;
      ret_val = (max_value < ret_val) ? max_value : ret_val;
      return ret_val;
    });
  });
}

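// Reserves Philox counter space under the generator lock (philox_cuda_state),
// then dispatches the kernel above over the supported floating-point dtypes;
// the kernel itself runs without holding the lock.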
void launch_poisson_cuda_kernel(
    const TensorBase &ret, const TensorBase &lambda, CUDAGeneratorImpl *gen) {
  PhiloxCudaState rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    rng_engine_inputs = gen->philox_cuda_state(20);
  }
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, ret.scalar_type(), "poisson_cuda", [&] {
    poisson_cuda_kernel<scalar_t>(ret, lambda, rng_engine_inputs);
  });
}

void launch_binomial_cuda_kernel(
    TensorIteratorBase &iter, CUDAGeneratorImpl *gen) {
  PhiloxCudaState rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    rng_engine_inputs = gen->philox_cuda_state(42);
  }
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.input_dtype(), "binomial_cuda", [&] {
    binomial_cuda_kernel<scalar_t>(iter, rng_engine_inputs);
  });
}

void launch_gamma_kernel(
    const TensorBase &ret, const TensorBase &alpha, CUDAGeneratorImpl *gen) {
  PhiloxCudaState rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    rng_engine_inputs = gen->philox_cuda_state(10);
  }
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, ret.scalar_type(), "gamma_cuda", [&] {
    gamma_cuda_kernel<scalar_t>(ret, alpha, rng_engine_inputs);
  });
}

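// Backward of gamma sampling: standard_gamma_grad_one computes, per element,
// the derivative of the sample with respect to the concentration parameter.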
void launch_standard_gamma_grad_kernel(TensorIteratorBase &iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.input_dtype(), "_standard_gamma_grad_cuda", [&] {
    using accscalar_t = at::acc_type<scalar_t, true>;
    gpu_kernel(iter,
      [] GPU_LAMBDA (scalar_t self_val, scalar_t output_val) {
        return standard_gamma_grad_one<scalar_t, accscalar_t>(self_val, output_val);
      });
  });
}

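// Backward of Dirichlet sampling: dirichlet_grad_one computes the per-element
// derivative of the sample x with respect to its concentration alpha, given
// the total concentration.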
void launch_dirichlet_grad_kernel(TensorIteratorBase &iter) {
  AT_DISPATCH_FLOATING_TYPES(iter.input_dtype(), "_dirichlet_grad_cuda", [&] {
    using accscalar_t = at::acc_type<scalar_t, true>;
    at::native::gpu_kernel(iter,
      [] GPU_LAMBDA (scalar_t x_val, scalar_t alpha_val, scalar_t total_val) -> scalar_t {
        return dirichlet_grad_one<scalar_t, accscalar_t>(x_val, alpha_val, total_val);
      });
  });
}

} // namespace at::native