#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/cuda/Distributions.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/cuda/DistributionTemplates.h>

#include <curand.h>
#include <curand_kernel.h>
#include <curand_philox4x32_x.h>
#include <utility>
#include <functional>

#include <ATen/native/Distributions.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/TensorIterator.h>

#include <cstdint>
#include <limits>
#include <utility>
#include <type_traits>

/**
 * Note [Register spilling in curand call for CUDA < 10]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 * For CUDA < 10, the curandStatePhilox4_32_10_t engine achieves poor performance
 * (about 60% of SOL bandwidth) when called to generate one random number at a time.
 * This is because the line
 *   unsigned ret = (&state->output.x)[state->STATE++];
 * in
 *   QUALIFIERS unsigned int curand(curandStatePhilox4_32_10_t *state)
 * in curand_kernel.h dynamically indexes into state.output, preventing the compiler
 * from ever storing state.output in registers.
 *
 * CUDA 10 fixed this problem. However, for backwards compatibility, the kernels
 * below use curand distributions built on the curand4 call, which does not have
 * the register-spilling problem.
 */
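
/*
 * Illustrative sketch (hypothetical, for exposition only; the production logic
 * lives in DistributionTemplates.h): rather than calling curand_uniform once per
 * element, a kernel can draw four values per Philox round with curand_uniform4
 * and process four elements per iteration, which keeps the generator output in
 * registers on CUDA < 10:
 *
 *   // assumes `state` was curand_init'ed, `out` has n elements, n % 4 == 0,
 *   // and tid is the global thread index
 *   for (int64_t i = 4 * tid; i < n; i += 4 * int64_t(gridDim.x) * blockDim.x) {
 *     float4 rand = curand_uniform4(&state);
 *     out[i + 0] = rand.x;
 *     out[i + 1] = rand.y;
 *     out[i + 2] = rand.z;
 *     out[i + 3] = rand.w;
 *   }
 */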

namespace {

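// Fills `ret` with Poisson(lambda) samples. The device functor unpacks the packed
// Philox (seed, offset) pair, seeds a curand state with the global thread index as
// the subsequence, and draws one sample per element via curand_poisson.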
template <typename scalar_t>
void poisson_cuda_kernel(
    const at::TensorBase &ret,
    const at::TensorBase &lambda,
    at::PhiloxCudaState philox_args) {
  auto functor = [philox_args] __device__(
          scalar_t & ret_val, const scalar_t& lambda) {
        CUDA_KERNEL_ASSERT(lambda >= 0 && "invalid Poisson rate, expected rate to be non-negative");
        auto seeds = at::cuda::philox::unpack(philox_args);
        curandStatePhilox4_32_10_t state;
        curand_init(std::get<0>(seeds),
                    blockIdx.x * blockDim.x + threadIdx.x,
                    std::get<1>(seeds),
                    &state);
        ret_val = static_cast<scalar_t>(curand_poisson(&state, lambda));
      };
  at::cuda::CUDA_tensor_apply2<scalar_t, scalar_t, decltype(functor),
                               /*max_threads_per_block=*/512,
                               /*min_blocks_per_sm=*/2>(ret, lambda, functor);
}

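// Adapts a Philox state to the nullary uniform-in-[0, 1) interface expected by
// BaseSampler: it takes 32 raw bits from curand and keeps only the low 24 bits
// (float's mantissa width), scaled so the result is exactly representable.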
struct curand_uniform_wrapper {
  curandStatePhilox4_32_10_t &state;
  __device__ curand_uniform_wrapper(curandStatePhilox4_32_10_t &state): state(state) {}
  __device__ float operator()() {
    uint32_t val = curand(&state); // need just bits
    constexpr auto MASK = static_cast<uint32_t>((static_cast<uint64_t>(1) << std::numeric_limits<float>::digits) - 1);
    constexpr auto DIVISOR = static_cast<float>(1) / (static_cast<uint32_t>(1) << std::numeric_limits<float>::digits);
    return (val & MASK) * DIVISOR;
  }
};

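// Samples Binomial(count, prob) elementwise through distribution_binary_kernel.
// Arithmetic runs in the accumulate type for scalar_t, and the uniform source is
// the curand_uniform_wrapper above.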
template <typename scalar_t>
void binomial_cuda_kernel(
    at::TensorIteratorBase &iter,
    at::PhiloxCudaState philox_args) {
  using accscalar_t = at::acc_type<scalar_t, true>;

  at::native::distribution_binary_kernel(iter, philox_args,
      [] GPU_LAMBDA (curandStatePhilox4_32_10_t& state, scalar_t count, scalar_t prob) {
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
        auto uniform_lambda = curand_uniform_wrapper(state);
        BaseSampler<accscalar_t, decltype(uniform_lambda)> standard_uniform(uniform_lambda);
        auto sample = sample_binomial<scalar_t, accscalar_t, decltype(uniform_lambda)>(count, prob, standard_uniform);
        return static_cast<scalar_t>(sample);
#else
        return count; // unused; keeps the host-side instantiation of this lambda well-formed
#endif
      }
  );
}

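// Fills `ret` with standard Gamma(alpha) samples. The device functor seeds a
// Philox state as in the Poisson kernel, builds uniform and normal samplers on
// top of it for sample_gamma, and clamps the result up to the smallest positive
// normal value so downstream code never sees an exact zero.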
template <typename scalar_t>
void gamma_cuda_kernel(
    const at::TensorBase &ret,
    const at::TensorBase &alpha,
    at::PhiloxCudaState philox_args) {
  using accscalar_t = at::acc_type<scalar_t, true>;
  auto functor = [philox_args] __device__(
          scalar_t & ret_val, const scalar_t& alpha) {
        auto seeds = at::cuda::philox::unpack(philox_args);
        curandStatePhilox4_32_10_t state;
        curand_init(std::get<0>(seeds),
                    blockIdx.x * blockDim.x + threadIdx.x,
                    std::get<1>(seeds),
                    &state);

        auto uniform_lambda = [&state] __device__ () {
          return curand_uniform(&state);
        };
        BaseSampler<accscalar_t, decltype(uniform_lambda)> standard_uniform(uniform_lambda);

        auto normal_lambda = [&state] __device__ () {
          return curand_normal(&state);
        };
        BaseSampler<accscalar_t, decltype(normal_lambda)> standard_normal(normal_lambda);
        auto sample = sample_gamma<scalar_t, accscalar_t, decltype(uniform_lambda), decltype(normal_lambda)>(alpha, standard_uniform, standard_normal);
        auto min_value = std::numeric_limits<scalar_t>::min();
        ret_val = (min_value > sample) ? min_value : sample;
      };
  at::cuda::CUDA_tensor_apply2<scalar_t, scalar_t, decltype(functor),
                               /*max_threads_per_block=*/256,
                               /*min_blocks_per_sm=*/2>(ret, alpha, functor);
}

} // namespace

namespace at::native {

void launch_dirichlet_kernel(at::TensorIteratorBase &iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
                                  iter.input_dtype(), "dirichlet_cuda", [&] {
    at::native::gpu_kernel(
        iter,
        [] GPU_LAMBDA (scalar_t gamma, scalar_t gamma_sum) {
          auto ret_val = gamma / gamma_sum;
          auto min_value = std::numeric_limits<scalar_t>::min();
          auto max_value = 1 - std::numeric_limits<scalar_t>::epsilon();
          ret_val = (min_value > ret_val) ? min_value : ret_val;
          ret_val = (max_value < ret_val) ? max_value : ret_val;
          return ret_val;
        });
  });
}

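// The launch_* entry points below reserve Philox offsets under the generator's
// mutex (see Note [Acquire lock when using random generators]) and then dispatch
// to the templated kernels above. The literal passed to philox_cuda_state is the
// per-thread counter increment being reserved, i.e. an upper bound on how much
// randomness a single thread may consume in the kernel.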
void launch_poisson_cuda_kernel(
    const TensorBase &ret, const TensorBase &lambda, CUDAGeneratorImpl *gen) {
  PhiloxCudaState rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    rng_engine_inputs = gen->philox_cuda_state(20);
  }
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, ret.scalar_type(), "poisson_cuda", [&] {
    poisson_cuda_kernel<scalar_t>(ret, lambda, rng_engine_inputs);
  });
}

void launch_binomial_cuda_kernel(
    TensorIteratorBase &iter, CUDAGeneratorImpl *gen) {
  PhiloxCudaState rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    rng_engine_inputs = gen->philox_cuda_state(42);
  }
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.input_dtype(), "binomial_cuda", [&] {
    binomial_cuda_kernel<scalar_t>(iter, rng_engine_inputs);
  });
}

void launch_gamma_kernel(
    const TensorBase &ret, const TensorBase &alpha, CUDAGeneratorImpl *gen) {
  PhiloxCudaState rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    rng_engine_inputs = gen->philox_cuda_state(10);
  }
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, ret.scalar_type(), "gamma_cuda", [&] {
    gamma_cuda_kernel<scalar_t>(ret, alpha, rng_engine_inputs);
  });
}

void launch_standard_gamma_grad_kernel(TensorIteratorBase &iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.input_dtype(), "_standard_gamma_grad_cuda", [&] {
    using accscalar_t = at::acc_type<scalar_t, true>;
    gpu_kernel(iter,
        [] GPU_LAMBDA (scalar_t self_val, scalar_t output_val) {
          return standard_gamma_grad_one<scalar_t, accscalar_t>(self_val, output_val);
        });
  });
}

void launch_dirichlet_grad_kernel(TensorIteratorBase &iter) {
  AT_DISPATCH_FLOATING_TYPES(iter.input_dtype(), "_dirichlet_grad_cuda", [&] {
    using accscalar_t = at::acc_type<scalar_t, true>;
    at::native::gpu_kernel(iter,
        [] GPU_LAMBDA (scalar_t x_val, scalar_t alpha_val, scalar_t total_val) -> scalar_t {
          return dirichlet_grad_one<scalar_t, accscalar_t>(x_val, alpha_val, total_val);
        });
  });
}

} // namespace at::native