#pragma once

#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/ExpandBase.h>
#include <ATen/OpMathType.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>
#include <c10/util/Half.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <ATen/detail/FunctionTraits.h>
#include <ATen/core/DistributionsHelper.h>

#include <curand.h>
#include <curand_kernel.h>
#include <curand_philox4x32_x.h>
#include <cstdint>
#include <limits>
#include <utility>
#include <mutex>
#include <tuple>
#include <type_traits>

namespace at {
namespace native {
namespace {

// launch bounds used for kernels utilizing TensorIterator
const uint32_t block_size_bound = 256;
const uint32_t grid_size_bound = 4;
// At the time of writing, there is no curand_* call that increments the offset by more than 4.
// See: https://docs.nvidia.com/cuda/archive/11.8.0/curand/group__DEVICE.html
const uint32_t max_generator_offsets_per_curand_call = 4;

// Utility function that calculates the proper philox_offset for distributions
// utilizing TensorIterator. For these distributions we use a grid-stride loop
// in which each thread produces one element per iteration. At the edge of the
// grid-stride loop, if the tensor is large, the unroll loop kicks in and the
// float4 returned by curand4 starts getting fully utilized (for common tensor
// sizes we end up using only rand.x from each thread). The philox_offset is
// therefore computed as
// (number of iterations per thread) * (maximum generator increment per curand_* call),
// which guarantees that the philox offset increment is never less than the
// number of randoms consumed by each thread.
std::tuple<uint64_t, dim3, dim3> calc_execution_policy(const int64_t total_elements, const uint32_t unroll_factor) {
  const uint64_t numel = static_cast<uint64_t>(total_elements);
  const uint32_t block_size = block_size_bound;
  dim3 dim_block(block_size);
  dim3 grid((numel + block_size - 1) / block_size);
  uint32_t blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;
  grid.x = std::min(
      static_cast<uint32_t>(at::cuda::getCurrentDeviceProperties()->multiProcessorCount) * blocks_per_sm,
      grid.x);
  // number of times random will be generated per thread, used to offset the philox counter in the RNG state
  uint64_t counter_offset = ((numel - 1) / (block_size * grid.x * unroll_factor) + 1) * max_generator_offsets_per_curand_call;
  return std::make_tuple(counter_offset, grid, dim_block);
}
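
// Worked example (illustrative device numbers, not a guarantee): for
// numel = 1 << 20, unroll_factor = 4, block_size = 256 on a GPU with 108 SMs and
// maxThreadsPerMultiProcessor = 2048: blocks_per_sm = 2048 / 256 = 8,
// grid.x = min(108 * 8, 4096) = 864, each thread runs
// ceil(2^20 / (256 * 864 * 4)) = 2 grid-stride iterations, and
// counter_offset = 2 * 4 = 8 philox increments.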

// grid stride loop kernel for distributions
template<typename accscalar_t, int unroll_factor, typename dist_t, typename transform_t>
C10_LAUNCH_BOUNDS_2(block_size_bound, grid_size_bound)
__global__ void distribution_elementwise_grid_stride_kernel(int numel,
                                                            PhiloxCudaState philox_args,
                                                            const dist_t dist_func,
                                                            const transform_t transform_func) {
  auto seeds = at::cuda::philox::unpack(philox_args);
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(std::get<0>(seeds),
              idx,
              std::get<1>(seeds),
              &state);

  int rounded_size = ((numel - 1) / (blockDim.x * gridDim.x * unroll_factor) + 1) *
      blockDim.x * gridDim.x * unroll_factor;
  for (int linear_index = idx; linear_index < rounded_size; linear_index += blockDim.x * gridDim.x * unroll_factor) {
    auto rand = dist_func(&state);
    #pragma unroll
    for (int ii = 0; ii < unroll_factor; ii++) {
      int li = linear_index + blockDim.x * gridDim.x * ii;
      if (li < numel) {
        transform_func(li, static_cast<accscalar_t>((&rand.x)[ii]));
      }
    }
    __syncthreads();
  }
}
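
// Illustrative trace (hypothetical sizes): with numel = 1000, unroll_factor = 4 and
// blockDim.x * gridDim.x = 256 threads in total, rounded_size = ((999 / 1024) + 1) * 1024
// = 1024, so each thread runs exactly one grid-stride iteration, draws one vector
// from dist_func, and writes up to four elements spaced 256 apart; indices past
// numel are skipped by the `li < numel` guard.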

/**
 * distribution_nullary_kernel is analogous to gpu_kernel in
 * ATen/native/cuda/Loops.cuh. Like gpu_kernel, it uses
 * TensorIterator to launch a kernel. However, the differences are
 *   - it launches a grid-stride loop based kernel. The kernel is not
 *     generic like elementwise_kernel in Loops.cuh and is specialized
 *     for the distribution kernels here.
 *   - For large tensors, we can launch multiple kernels recursively
 *     (i.e. if (!iter.can_use_32bit_indexing())) and hence the philox
 *     offset calculation is done in this function.
 *
 * FIXME: Can we specialize elementwise_kernel and launch_kernel in Loops.cuh
 * to have a grid-stride loop kernel and then use that to launch our distribution
 * kernels? Note that we need a grid-stride loop kernel because we found, by testing,
 * that it achieves peak effective bandwidth.
 */
template<typename scalar_t,
         typename accscalar_t,
         typename dist_func_return_t,
         typename RNG,
         typename dist_t,
         typename transform_t>
void distribution_nullary_kernel(at::TensorIteratorBase& iter,
                                 RNG gen,
                                 const dist_t& dist_func,
                                 const transform_t transform_func) {
  const int unroll_factor = sizeof(dist_func_return_t) / sizeof(accscalar_t);
  TORCH_CHECK(unroll_factor >= 1, "unroll_factor must be >= 1.");
  int64_t numel = iter.numel();
  if (numel == 0) {
    return;
  }

  auto execution_policy = calc_execution_policy(numel, unroll_factor);
  auto counter_offset = std::get<0>(execution_policy);
  auto grid = std::get<1>(execution_policy);
  auto block = std::get<2>(execution_policy);
  PhiloxCudaState rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    rng_engine_inputs = gen->philox_cuda_state(counter_offset);
  }

  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      distribution_nullary_kernel<scalar_t, accscalar_t, dist_func_return_t>(sub_iter,
        gen, dist_func, transform_func);
    }
    return;
  }

  char* out_data = (char*)iter.data_ptr(0);

  auto stream = at::cuda::getCurrentCUDAStream();
  if (iter.is_trivial_1d()) {
    auto strides = iter.get_inner_strides();
    int stride0 = strides[0];
    distribution_elementwise_grid_stride_kernel<accscalar_t, unroll_factor><<<grid, block, 0, stream>>>(
      numel,
      rng_engine_inputs,
      dist_func,
      [=]__device__(int idx, accscalar_t rand) {
        scalar_t* out = (scalar_t*)&out_data[stride0 * idx];
        *out = transform_func(rand);
      }
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  } else {
    auto offset_calc = make_offset_calculator<1>(iter);
    distribution_elementwise_grid_stride_kernel<accscalar_t, unroll_factor><<<grid, block, 0, stream>>>(
      numel,
      rng_engine_inputs,
      dist_func,
      [=]__device__(int idx, accscalar_t rand) {
        auto offsets = offset_calc.get(idx);
        scalar_t* out = (scalar_t*)&out_data[offsets[0]];
        *out = transform_func(rand);
      }
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
}
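
// Usage sketch (illustrative; it mirrors the wiring used by uniform_and_transform
// further below): dist_func draws a vector of randoms from curand and transform_func
// maps each scalar onto the output element type.
//
//   distribution_nullary_kernel<float, float, float4>(iter, gen,
//       [] __device__ (curandStatePhilox4_32_10_t* state) -> float4 {
//         return curand_uniform4(state);
//       },
//       [] __device__ (float rand) { return rand; });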

// Binary kernel
template <typename func_t, typename inp_offset_calc_t, typename out_offset_calc_t>
__global__ void distribution_binary_elementwise_kernel(
    int numel,
    func_t f,
    PhiloxCudaState philox_args,
    typename function_traits<func_t>::result_type *output_data,
    const typename function_traits<func_t>::template arg<1>::type *input_data_1,
    const typename function_traits<func_t>::template arg<2>::type *input_data_2,
    inp_offset_calc_t inp_calc,
    out_offset_calc_t out_calc) {
  auto seeds = at::cuda::philox::unpack(philox_args);

  using input_t_1 = typename function_traits<func_t>::template arg<1>::type;
  using input_t_2 = typename function_traits<func_t>::template arg<2>::type;

  input_t_1 inputs_1[thread_work_size()];
  input_t_2 inputs_2[thread_work_size()];

  int base_index = block_work_size() * blockIdx.x;
  int remaining = std::min<int>(numel - base_index, block_work_size());

  curandStatePhilox4_32_10_t state;
  curand_init(std::get<0>(seeds),
              blockIdx.x * blockDim.x + threadIdx.x,
              std::get<1>(seeds),
              &state);

  // load data into registers
  int thread_idx = threadIdx.x;
  #pragma unroll
  for (int i = 0; i < thread_work_size(); i++) {
    if (thread_idx >= remaining) {
      break;
    }
    int input_idx = thread_idx + base_index;
    auto offsets = inp_calc.get(input_idx);
    inputs_1[i] = input_data_1[offsets[0]];
    inputs_2[i] = input_data_2[offsets[1]];

    thread_idx += num_threads();
  }

  // compute and store
  thread_idx = threadIdx.x;
  #pragma unroll
  for (int i = 0; i < thread_work_size(); i++) {
    if (thread_idx >= remaining) {
      break;
    }
    int input_idx = thread_idx + base_index;
    auto offsets = out_calc.get(input_idx);
    output_data[offsets[0]] = f(state, inputs_1[i], inputs_2[i]);
    thread_idx += num_threads();
  }
}

template <typename func_t>
void distribution_binary_kernel(TensorIteratorBase &iter, PhiloxCudaState philox_args, const func_t &f) {
  static_assert(std::is_same<typename function_traits<func_t>::template arg<0>::type, curandStatePhilox4_32_10_t&>::value, "the first argument of functor must be curandStatePhilox4_32_10_t");
  using input_t_1 = typename function_traits<func_t>::template arg<1>::type;
  using input_t_2 = typename function_traits<func_t>::template arg<2>::type;
  using output_t = typename function_traits<func_t>::result_type;

  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      distribution_binary_kernel(sub_iter, philox_args, f);
    }
    return;
  }

  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(iter.can_use_32bit_indexing());

  int64_t numel = iter.numel();
  if (numel == 0) {
    return;
  }

  output_t *output_data = static_cast<output_t *>(iter.data_ptr(0));
  const input_t_1 *input_data_1 = static_cast<const input_t_1 *>(iter.data_ptr(1));
  const input_t_2 *input_data_2 = static_cast<const input_t_2 *>(iter.data_ptr(2));

  int64_t grid = (numel + block_work_size() - 1) / block_work_size();
  auto stream = at::cuda::getCurrentCUDAStream();

  if (iter.is_contiguous()) {
    distribution_binary_elementwise_kernel<<<grid, num_threads(), 0, stream>>>(
        numel, f, philox_args, output_data, input_data_1, input_data_2,
        TrivialOffsetCalculator<2>(), TrivialOffsetCalculator<1>());
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  } else {
    distribution_binary_elementwise_kernel<<<grid, num_threads(), 0, stream>>>(
        numel, f, philox_args, output_data, input_data_1, input_data_2,
        make_input_offset_calculator<2>(iter), make_output_offset_calculator(iter));
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
}
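
// Usage sketch (illustrative, not an existing call site): the functor receives the
// per-thread curand state by reference followed by one value from each input tensor.
//
//   distribution_binary_kernel(iter, philox_args,
//       [] __device__ (curandStatePhilox4_32_10_t& state, float a, float b) -> float {
//         // blend the two inputs with a uniform draw in (0, 1]
//         float u = curand_uniform(&state);
//         return a + u * (b - a);
//       });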

} // namespace
}} // namespace at::native


namespace at {
namespace native {
namespace templates {
namespace cuda {

// ==================================================== Random ========================================================

template<typename RNG>
void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG gen) {
  AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cuda", AT_WRAP([&] {
    if ((
      std::is_same<scalar_t, int64_t>::value ||
      std::is_same<scalar_t, double>::value ||
      std::is_same<scalar_t, float>::value ||
      std::is_same<scalar_t, at::BFloat16>::value) && range >= 1ULL << 32)
    {
      // define lambda to mod with range and add base
      auto random_func = [range, base] __device__ (uint64_t rand) {
        return transformation::uniform_int_from_to<scalar_t>(rand, range, base);
      };
      distribution_nullary_kernel<scalar_t, uint64_t, ulonglong2>(iter,
        gen,
        [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 {
          ulonglong2 ret;
          uint4 rand_val = curand4(state);
          ret.x = (static_cast<uint64_t>(rand_val.x) << 32) | rand_val.y;
          ret.y = (static_cast<uint64_t>(rand_val.z) << 32) | rand_val.w;
          return ret;
        },
        random_func);
    } else {
      auto random_func = [range, base] __device__ (uint32_t rand) {
        return transformation::uniform_int_from_to<scalar_t>(rand, range, base);
      };
      distribution_nullary_kernel<scalar_t, uint32_t, uint4>(iter,
        gen,
        [] __device__ (curandStatePhilox4_32_10_t* state) -> uint4 {
          return curand4(state);
        },
        random_func);
    }
  }), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
}
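
// Note on the 64-bit path above (illustrative values): curand4 yields four 32-bit
// words and two of them are packed into each 64-bit output, e.g. for
// rand_val = {0x1, 0x2, 0x3, 0x4} the lambda returns ret.x = 0x0000000100000002 and
// ret.y = 0x0000000300000004, so one curand4 call feeds two 64-bit draws
// (unroll_factor = sizeof(ulonglong2) / sizeof(uint64_t) = 2).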

// This is a special kernel to handle the single specific case:
//   from (inclusive) = std::numeric_limits<int64_t>::lowest()
//   to (exclusive)   = None (= std::numeric_limits<int64_t>::max() + 1)
template<typename RNG>
void random_full_64_bits_range_kernel(TensorIteratorBase& iter, RNG gen) {
  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::BFloat16, iter.dtype(), "random_full_64_bits_range_kernel_cuda", [&] {
    if (std::is_same<scalar_t, int64_t>::value ||
        std::is_same<scalar_t, double>::value ||
        std::is_same<scalar_t, float>::value ||
        std::is_same<scalar_t, at::BFloat16>::value) {
      auto random_func = [] __device__ (uint64_t rand) {
        return transformation::uniform_int_full_range<scalar_t>(rand);
      };
      distribution_nullary_kernel<scalar_t, uint64_t, ulonglong2>(iter,
        gen,
        [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 {
          ulonglong2 ret;
          uint4 rand_val = curand4(state);
          ret.x = (static_cast<uint64_t>(rand_val.x) << 32) | rand_val.y;
          ret.y = (static_cast<uint64_t>(rand_val.z) << 32) | rand_val.w;
          return ret;
        },
        random_func);
    } else {
      TORCH_CHECK(false, "random_full_64_bits_range_kernel_cuda handles only int64, double, float and bfloat16");
    }
  });
}

template<typename RNG>
struct RandomFromToKernel {
  void operator()(TensorIteratorBase& iter, uint64_t range, int64_t base, std::optional<Generator> gen) {
    random_from_to_kernel(iter, range, base, check_generator<RNG>(gen));
  }
  void operator()(TensorIteratorBase& iter, std::optional<Generator> gen) {
    random_full_64_bits_range_kernel(iter, check_generator<RNG>(gen));
  }
};

template<typename RNG>
void random_kernel(TensorIteratorBase& iter, RNG gen) {
  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "random_kernel_cuda", [&] {
    if (std::is_same<scalar_t, double>::value || std::is_same<scalar_t, int64_t>::value) {
      auto random_func = [] __device__ (uint64_t rand) {
        return transformation::uniform_int<scalar_t>(rand);
      };
      distribution_nullary_kernel<scalar_t, uint64_t, ulonglong2>(iter, gen,
        [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 {
          ulonglong2 ret;
          uint4 rand_val = curand4(state);
          ret.x = (static_cast<uint64_t>(rand_val.x) << 32) | rand_val.y;
          ret.y = (static_cast<uint64_t>(rand_val.z) << 32) | rand_val.w;
          return ret;
        },
        random_func);
    } else {
      auto random_func = [] __device__ (uint32_t rand) {
        return transformation::uniform_int<scalar_t>(rand);
      };
      distribution_nullary_kernel<scalar_t, uint32_t, uint4>(iter,
        gen,
        [] __device__ (curandStatePhilox4_32_10_t* state) -> uint4 {
          return curand4(state);
        },
        random_func);
    }
  });
}

template<typename RNG>
struct RandomKernel {
  void operator()(TensorIteratorBase& iter, RNG gen) {
    random_kernel(iter, gen);
  }
};

// ====================================================================================================================

template<typename scalar_t, typename accscalar_t, typename RNG, typename transform_t>
void uniform_and_transform(TensorIteratorBase& iter, RNG gen, transform_t transform) {
  if (std::is_same<scalar_t, double>::value) {
    distribution_nullary_kernel<scalar_t, accscalar_t, double2>(iter,
      gen,
      [] __device__ (curandStatePhilox4_32_10_t* state) -> double2 { return curand_uniform2_double(state); },
      transform);
  } else {
    distribution_nullary_kernel<scalar_t, accscalar_t, float4>(iter,
      gen,
      [] __device__ (curandStatePhilox4_32_10_t* state) -> float4 { return curand_uniform4(state); },
      transform);
  }
}
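
// Note on the unroll factor (informational): the third template argument to
// distribution_nullary_kernel fixes how many randoms each curand call yields,
// e.g. sizeof(double2) / sizeof(double) = 2 for curand_uniform2_double and, in the
// common case where accscalar_t is float, sizeof(float4) / sizeof(float) = 4 for
// curand_uniform4.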

template<typename scalar_t, typename accscalar_t, typename RNG, typename transform_t>
void normal_and_transform(TensorIteratorBase& iter, RNG gen, transform_t transform) {
  if (std::is_same<scalar_t, double>::value) {
    distribution_nullary_kernel<scalar_t, accscalar_t, double2>(iter,
      gen,
      [] __device__ (curandStatePhilox4_32_10_t* state) -> double2 { return curand_normal2_double(state); },
      transform);
  } else {
    distribution_nullary_kernel<scalar_t, accscalar_t, float4>(iter,
      gen,
      [] __device__ (curandStatePhilox4_32_10_t* state) -> float4 { return curand_normal4(state); },
      transform);
  }
}

// ==================================================== Normal ========================================================

template<typename RNG>
void normal_kernel(const TensorBase &self, double mean_, double std_, RNG gen) {
  auto iter = TensorIterator::borrowing_nullary_op(self);
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "normal_kernel_cuda", [&] {
    using accscalar_t = at::acc_type<scalar_t, true>;
    auto mean = static_cast<accscalar_t>(mean_);
    auto std = static_cast<accscalar_t>(std_);
    // define lambda to multiply std and add mean
    auto normal_func = [mean, std] __device__ (accscalar_t rand) {
      return static_cast<scalar_t>(transformation::normal<accscalar_t>(rand, mean, std));
    };
    normal_and_transform<scalar_t, accscalar_t>(iter, gen, normal_func);
  });
}

template<typename RNG>
struct NormalKernel {
  void operator()(const TensorBase &self, double mean, double std, std::optional<Generator> gen) {
    normal_kernel(self, mean, std, check_generator<RNG>(gen));
  }
};

// ==================================================== Uniform ========================================================

template<typename RNG>
void uniform_kernel(TensorIteratorBase& iter, double from_, double to_, RNG gen) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "uniform_kernel_cuda", [&] {
    auto from = static_cast<scalar_t>(from_);
    auto to = static_cast<scalar_t>(to_);
    using opmath_t = at::opmath_type<scalar_t>;
    auto range = static_cast<opmath_t>(to - from);
    // define lambda to reverse bounds, multiply 'range' and add 'from_'
    auto uniform_func = [range, from, to] __device__ (opmath_t rand) {
      // Compute output value before reversing the bounds
      // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/96947
      auto value = static_cast<scalar_t>(rand * range + from);
      // Reverse the bounds of curand4 from (0, 1] to [0, 1).
      // Note that this method comes from the legacy THCTensorRandom and is likely to give
      // you more 0-s, since the probability of getting 1-s is higher than that of 0-s, and
      // by reversing the bounds we are flipping the probabilities of 1-s and 0-s.
      // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/16706
      auto reverse_bound_value = value == to ? from : value;
      return reverse_bound_value;
    };
    uniform_and_transform<scalar_t, opmath_t>(iter, gen, uniform_func);
  });
}
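
// Worked example of the bound reversal above (illustrative): with from = 0.0 and
// to = 1.0, curand_uniform4 returns values in (0, 1], so a draw of exactly 1.0 gives
// value == to and is remapped to from, producing outputs in [from, to).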

template<typename RNG>
struct UniformKernel {
  void operator()(TensorIteratorBase& iter, double from, double to, std::optional<Generator> gen) {
    uniform_kernel(iter, from, to, check_generator<RNG>(gen));
  }
};

// ================================================== LogNormal =======================================================

template<typename RNG>
void log_normal_kernel(TensorIteratorBase& iter, double mean_, double std_, RNG gen) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "log_normal_cuda", [&] {
    using accscalar_t = at::acc_type<scalar_t, true>;
    auto mean = static_cast<accscalar_t>(mean_);
    auto std = static_cast<accscalar_t>(std_);
    // define lambda for log_normal transformation
    auto log_normal_func = [mean, std] __device__ (accscalar_t rand) {
      return static_cast<scalar_t>(transformation::log_normal<accscalar_t>(transformation::normal<accscalar_t>(rand, mean, std)));
    };
    normal_and_transform<scalar_t, accscalar_t>(iter, gen, log_normal_func);
  });
}

template<typename RNG>
struct LogNormalKernel {
  void operator()(TensorIteratorBase& iter, double mean, double std, std::optional<Generator> gen) {
    log_normal_kernel(iter, mean, std, check_generator<RNG>(gen));
  }
};

// =================================================== Geometric ======================================================

template<typename RNG>
void geometric_kernel(TensorIteratorBase& iter, double p, RNG gen) {
  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "geometric_cuda", [&] {
    using accscalar_t = at::DiscreteDistributionType<scalar_t>::type;
    // define lambda for geometric transformation
    auto geometric_func = [p] __device__ (accscalar_t rand) {
      return static_cast<scalar_t>(transformation::geometric<accscalar_t>(rand, p));
    };
    uniform_and_transform<scalar_t, accscalar_t>(iter, gen, geometric_func);
  });
}

template<typename RNG>
struct GeometricKernel {
  void operator()(TensorIteratorBase& iter, double p, std::optional<Generator> gen) {
    geometric_kernel(iter, p, check_generator<RNG>(gen));
  }
};

// ================================================== Exponential =====================================================

template<typename RNG>
void exponential_kernel(TensorIteratorBase& iter, double lambda_, RNG gen) {
  TORCH_CHECK(isFloatingType(iter.dtype()), "Exponential distribution is a continuous probability distribution. dtype must be a floating point but you specified ", iter.dtype());
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exponential_cuda", [&] {
    using accscalar_t = at::acc_type<scalar_t, true>;
    auto lambda = static_cast<accscalar_t>(lambda_);
    // define lambda for exponential transformation
    auto exponential_func = [lambda] __device__ (accscalar_t rand) {
      return static_cast<scalar_t>(transformation::exponential<accscalar_t>(rand, lambda));
    };
    uniform_and_transform<scalar_t, accscalar_t>(iter, gen, exponential_func);
  });
}

template<typename RNG>
struct ExponentialKernel {
  void operator()(TensorIteratorBase& iter, double lambda, std::optional<Generator> gen) {
    exponential_kernel(iter, lambda, check_generator<RNG>(gen));
  }
};

// ==================================================== Cauchy ========================================================

template<typename RNG>
void cauchy_kernel(TensorIteratorBase& iter, double median_, double sigma_, RNG gen) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "cauchy_cuda", [&] {
    using accscalar_t = at::acc_type<scalar_t, true>;
    auto median = static_cast<accscalar_t>(median_);
    auto sigma = static_cast<accscalar_t>(sigma_);
    // define lambda for cauchy transformation
    auto cauchy_func = [median, sigma] __device__ (accscalar_t rand) {
      return static_cast<scalar_t>(transformation::cauchy<accscalar_t>(rand, median, sigma));
    };
    uniform_and_transform<scalar_t, accscalar_t>(iter, gen, cauchy_func);
  });
}

template<typename RNG>
struct CauchyKernel {
  void operator()(TensorIteratorBase& iter, double median, double sigma, std::optional<Generator> gen) {
    cauchy_kernel(iter, median, sigma, check_generator<RNG>(gen));
  }
};

// ==================================================== Bernoulli =====================================================

template<typename scalar_t, typename prob_t>
void bernoulli_tensor_cuda_kernel(
    const TensorBase &ret, const at::TensorBase &p,
    PhiloxCudaState philox_args) {
  auto functor = [philox_args] __device__(
          int n, scalar_t& v1, scalar_t& v2, scalar_t& v3, scalar_t& v4,
          const prob_t& p1, const prob_t& p2, const prob_t& p3, const prob_t& p4) {
        auto seeds = at::cuda::philox::unpack(philox_args);
        curandStatePhilox4_32_10_t state;
        curand_init(std::get<0>(seeds),
                    blockIdx.x * blockDim.x + threadIdx.x,
                    std::get<1>(seeds),
                    &state);

        // See Note [Register spilling in curand call for CUDA < 10]
        float4 rand = curand_uniform4(&state);
        switch (n) {
          case 4: {
            CUDA_KERNEL_ASSERT(0 <= p4 && p4 <= 1);
            v4 = static_cast<scalar_t>(rand.w <= p4);
            [[fallthrough]];
          }
          case 3: {
            CUDA_KERNEL_ASSERT(0 <= p3 && p3 <= 1);
            v3 = static_cast<scalar_t>(rand.z <= p3);
            [[fallthrough]];
          }
          case 2: {
            CUDA_KERNEL_ASSERT(0 <= p2 && p2 <= 1);
            v2 = static_cast<scalar_t>(rand.y <= p2);
            [[fallthrough]];
          }
          case 1: {
            CUDA_KERNEL_ASSERT(0 <= p1 && p1 <= 1);
            v1 = static_cast<scalar_t>(rand.x <= p1);
          }
        }
      };
  // The template argument `4` below indicates that we want to operate on four
  // elements at a time. See NOTE [ CUDA_tensor_applyN helpers ] for details.
  at::cuda::CUDA_tensor_apply2<scalar_t, const prob_t, 4, decltype(functor),
                               /*max_threads_per_block=*/512,
                               /*min_blocks_per_sm=*/2>(ret, p, functor);
}
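
// Illustrative behaviour of the functor above (hypothetical values): for a full
// group of n = 4 elements with p1 = 0.3 and rand.x = 0.25, the comparison
// rand.x <= p1 gives v1 = 1; with rand.x = 0.9 it gives v1 = 0. Because the switch
// falls through downward, a trailing partial group with n < 4 writes only v1..vn.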

template<typename RNG>
void bernoulli_kernel(const TensorBase &self, const TensorBase &p_, RNG gen) {
  PhiloxCudaState rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    rng_engine_inputs = gen->philox_cuda_state(10);
  }
  TORCH_CHECK(at::isFloatingType(p_.scalar_type()), "expected probabilities tensor to have floating type, got ", p_.scalar_type());
  // cast probabilities tensor to double for double `self` tensor, and to `float` for everything else
  const auto p_type = self.dtype() == at::kDouble ? at::kDouble : at::kFloat;
  auto p_cuda = p_.to(TensorOptions().device(self.device()).dtype(p_type));
  auto p = expand_inplace(self, p_cuda);
  AT_DISPATCH_ALL_TYPES_AND3(
    at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, self.scalar_type(), "bernoulli_tensor_cuda_self_", [&] {
      if (std::is_same<scalar_t, double>::value) {
        return bernoulli_tensor_cuda_kernel<double, double>(self, *p, rng_engine_inputs);
      } else {
        return bernoulli_tensor_cuda_kernel<scalar_t, float>(self, *p, rng_engine_inputs);
      }
    });
}

template<typename RNG>
void bernoulli_kernel(TensorIteratorBase& iter, double p, RNG gen) {
  AT_DISPATCH_ALL_TYPES_AND3(
    at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "bernoulli_scalar_cuda_", [&] {
      using accscalar_t = at::DiscreteDistributionType<scalar_t>::type;
      // define lambda for bernoulli transformation
      auto bernoulli_func = [p] __device__ (accscalar_t rand) {
        return static_cast<scalar_t>(transformation::bernoulli<accscalar_t>(rand, p));
      };
      uniform_and_transform<scalar_t, accscalar_t>(iter, gen, bernoulli_func);
    });
}

template<typename RNG>
struct BernoulliKernel {
  void operator()(TensorIteratorBase& iter, double p, std::optional<Generator> gen) {
    bernoulli_kernel(iter, p, check_generator<RNG>(gen));
  }
  void operator()(const TensorBase &self, const TensorBase &p_, std::optional<Generator> gen) {
    bernoulli_kernel(self, p_, check_generator<RNG>(gen));
  }
};

}}}} // namespace at::native::templates::cuda