#pragma once

#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/ExpandBase.h>
#include <ATen/OpMathType.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>
#include <c10/util/Half.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <ATen/detail/FunctionTraits.h>
#include <ATen/core/DistributionsHelper.h>

#include <curand.h>
#include <curand_kernel.h>
#include <curand_philox4x32_x.h>
#include <cstdint>
#include <limits>
#include <utility>
#include <mutex>
#include <tuple>
#include <type_traits>

namespace at {
namespace native {
namespace {

// launch bounds used for kernels utilizing TensorIterator
const uint32_t block_size_bound = 256;
const uint32_t grid_size_bound = 4;
// At the time of writing, there is no curand_* call that increments the offset by more than 4.
// See: https://docs.nvidia.com/cuda/archive/11.8.0/curand/group__DEVICE.html
const uint32_t max_generator_offsets_per_curand_call = 4;

// Utility function that calculates the proper philox_offset for distributions
// utilizing TensorIterator. These distributions use a grid-stride loop in which
// each thread produces one element per iteration. At the tail of the grid-stride
// loop, when the tensor is large enough, the unroll loop kicks in and all four
// values returned by curand4 (e.g. the full float4) get used; for common tensor
// sizes we end up using only rand.x from each thread. The philox_offset is
// therefore calculated as
// (number of elements per thread * maximum generator increment per curand_* call),
// which ensures that the philox offset increment is not less than the number of
// random values used by each thread.
std::tuple<uint64_t, dim3, dim3> calc_execution_policy(const int64_t total_elements, const uint32_t unroll_factor) {
  const uint64_t numel = static_cast<uint64_t>(total_elements);
  const uint32_t block_size = block_size_bound;
  dim3 dim_block(block_size);
  dim3 grid((numel + block_size - 1) / block_size);
  uint32_t blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;
  grid.x = std::min(
      static_cast<uint32_t>(at::cuda::getCurrentDeviceProperties()->multiProcessorCount) * blocks_per_sm,
      grid.x);
  // number of times random will be generated per thread, to offset philox counter in thc random state
  uint64_t counter_offset = ((numel - 1) / (block_size * grid.x * unroll_factor) + 1) * max_generator_offsets_per_curand_call;
  return std::make_tuple(counter_offset, grid, dim_block);
}
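
// A minimal worked example of the arithmetic above (device numbers are
// assumptions for illustration only): on a GPU with 108 SMs and 2048 resident
// threads per SM, block_size = 256 gives blocks_per_sm = 8 and caps grid.x at
// 108 * 8 = 864. For numel = 1 << 20 and unroll_factor = 4:
//   counter_offset = ((1048576 - 1) / (256 * 864 * 4) + 1) * 4
//                  = (1 + 1) * 4
//                  = 8
// i.e. each thread runs at most 2 grid-stride iterations, and each iteration
// makes one curand_* call that advances the per-thread philox state by at most
// 4, so reserving an offset of 8 guarantees later kernels never overlap the
// random numbers consumed here.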

// grid stride loop kernel for distributions
template<typename accscalar_t, int unroll_factor, typename dist_t, typename transform_t>
C10_LAUNCH_BOUNDS_2(block_size_bound, grid_size_bound)
__global__ void distribution_elementwise_grid_stride_kernel(int numel,
                                                            PhiloxCudaState philox_args,
                                                            const dist_t dist_func,
                                                            const transform_t transform_func) {
  auto seeds = at::cuda::philox::unpack(philox_args);
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(std::get<0>(seeds),
              idx,
              std::get<1>(seeds),
              &state);

  int rounded_size = ((numel - 1)/(blockDim.x * gridDim.x * unroll_factor)+1) *
      blockDim.x * gridDim.x * unroll_factor;
  for(int linear_index = idx; linear_index < rounded_size; linear_index += blockDim.x * gridDim.x * unroll_factor) {
    auto rand = dist_func(&state);
    #pragma unroll
    for (int ii = 0; ii < unroll_factor; ii++) {
      int li = linear_index + blockDim.x * gridDim.x * ii;
      if (li < numel) {
        transform_func(li, static_cast<accscalar_t>((&rand.x)[ii]));
      }
    }
    __syncthreads();
  }
}
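
// A small worked illustration of the indexing above (numbers chosen only to
// keep the example readable): with blockDim.x = 256, gridDim.x = 2 and
// unroll_factor = 4, the grid covers blockDim.x * gridDim.x = 512 elements per
// unrolled step and 512 * 4 = 2048 elements per outer iteration. The thread
// with idx = 5 consumes rand.x, rand.y, rand.z, rand.w for elements 5, 517,
// 1029 and 1541, then strides forward by 2048 and repeats. rounded_size rounds
// numel up to a multiple of 2048 so that every thread executes the same number
// of outer iterations (and thus the same number of dist_func calls), matching
// the philox offset reserved by calc_execution_policy.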

/**
 * distribution_nullary_kernel is analogous to gpu_kernel in
 * ATen/native/cuda/Loops.cuh. Like gpu_kernel, it uses
 * TensorIterator to launch a kernel. However, the differences are
 *   - it launches a grid-stride loop based kernel. The kernel is not
 *     generic like elementwise_kernel in Loops.cuh and is specialized
 *     for the distribution kernels here.
 *   - For big size tensors, we can launch multiple kernels recursively
 *     (i.e. if (!iter.can_use_32bit_indexing())) and hence, the philox
 *     offset calculation is done in this function.
 *
 * FIXME: Can we specialize elementwise_kernel and launch_kernel in Loops.cuh
 * to have a grid-stride loop kernel and then use that to launch our distribution
 * kernels? Note that we need a grid-stride loop kernel because we found by testing
 * that it achieves peak effective bandwidth.
 */
template<typename scalar_t,
         typename accscalar_t,
         typename dist_func_return_t,
         typename RNG,
         typename dist_t,
         typename transform_t>
void distribution_nullary_kernel(at::TensorIteratorBase& iter,
                                 RNG gen,
                                 const dist_t& dist_func,
                                 const transform_t transform_func) {
  const int unroll_factor = sizeof(dist_func_return_t) / sizeof(accscalar_t);
  TORCH_CHECK(unroll_factor >= 1, "unroll_factor must be >= 1.");
  int64_t numel = iter.numel();
  if (numel == 0) {
    return;
  }

  auto execution_policy = calc_execution_policy(numel, unroll_factor);
  auto counter_offset = std::get<0>(execution_policy);
  auto grid = std::get<1>(execution_policy);
  auto block = std::get<2>(execution_policy);
  PhiloxCudaState rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    rng_engine_inputs = gen->philox_cuda_state(counter_offset);
  }

  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      distribution_nullary_kernel<scalar_t, accscalar_t, dist_func_return_t>(sub_iter,
        gen, dist_func, transform_func);
    }
    return;
  }

  char* out_data = (char*)iter.data_ptr(0);

  auto stream = at::cuda::getCurrentCUDAStream();
  if (iter.is_trivial_1d()) {
    auto strides = iter.get_inner_strides();
    int stride0 = strides[0];
    distribution_elementwise_grid_stride_kernel<accscalar_t, unroll_factor><<<grid, block, 0, stream>>>(
      numel,
      rng_engine_inputs,
      dist_func,
      [=]__device__(int idx, accscalar_t rand) {
        scalar_t* out = (scalar_t*)&out_data[stride0 * idx];
        *out = transform_func(rand);
      }
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  } else {
    auto offset_calc = make_offset_calculator<1>(iter);
    distribution_elementwise_grid_stride_kernel<accscalar_t, unroll_factor><<<grid, block, 0, stream>>>(
      numel,
      rng_engine_inputs,
      dist_func,
      [=]__device__(int idx, accscalar_t rand) {
        auto offsets = offset_calc.get(idx);
        scalar_t* out = (scalar_t*)&out_data[offsets[0]];
        *out = transform_func(rand);
      }
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
}
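
// Sketch of how a caller wires up distribution_nullary_kernel (illustration
// only; the real call sites are the helpers below such as uniform_and_transform).
// dist_func draws raw randoms from the per-thread philox state and its return
// type (the third template argument) fixes the unroll factor; transform_func
// maps each raw value to the output scalar:
//
//   // hypothetical example, assuming a float tensor behind `iter`
//   distribution_nullary_kernel<float, float, float4>(iter, gen,
//       [] __device__ (curandStatePhilox4_32_10_t* state) -> float4 {
//         return curand_uniform4(state);   // four floats in (0, 1]
//       },
//       [] __device__ (float rand) { return rand * 2.0f - 1.0f; });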

// Binary kernel
template <typename func_t, typename inp_offset_calc_t, typename out_offset_calc_t>
__global__ void distribution_binary_elementwise_kernel(
    int numel,
    func_t f,
    PhiloxCudaState philox_args,
    typename function_traits<func_t>::result_type *output_data,
    const typename function_traits<func_t>::template arg<1>::type *input_data_1,
    const typename function_traits<func_t>::template arg<2>::type *input_data_2,
    inp_offset_calc_t inp_calc,
    out_offset_calc_t out_calc) {
  auto seeds = at::cuda::philox::unpack(philox_args);

  using input_t_1 = typename function_traits<func_t>::template arg<1>::type;
  using input_t_2 = typename function_traits<func_t>::template arg<2>::type;

  input_t_1 inputs_1[thread_work_size()];
  input_t_2 inputs_2[thread_work_size()];

  int base_index = block_work_size() * blockIdx.x;
  int remaining = std::min<int>(numel - base_index, block_work_size());

  curandStatePhilox4_32_10_t state;
  curand_init(std::get<0>(seeds),
              blockIdx.x * blockDim.x + threadIdx.x,
              std::get<1>(seeds),
              &state);

  // load data into registers
  int thread_idx = threadIdx.x;
  #pragma unroll
  for (int i = 0; i < thread_work_size(); i++) {
    if (thread_idx >= remaining) {
      break;
    }
    int input_idx = thread_idx + base_index;
    auto offsets = inp_calc.get(input_idx);
    inputs_1[i] = input_data_1[offsets[0]];
    inputs_2[i] = input_data_2[offsets[1]];

    thread_idx += num_threads();
  }

  // compute and store
  thread_idx = threadIdx.x;
  #pragma unroll
  for (int i = 0; i < thread_work_size(); i++) {
    if (thread_idx >= remaining) {
      break;
    }
    int input_idx = thread_idx + base_index;
    auto offsets = out_calc.get(input_idx);
    output_data[offsets[0]] = f(state, inputs_1[i], inputs_2[i]);
    thread_idx += num_threads();
  }
}

template <typename func_t>
void distribution_binary_kernel(TensorIteratorBase &iter, PhiloxCudaState philox_args, const func_t &f) {
  static_assert(std::is_same<typename function_traits<func_t>::template arg<0>::type, curandStatePhilox4_32_10_t&>::value, "the first argument of functor must be curandStatePhilox4_32_10_t");
  using input_t_1 = typename function_traits<func_t>::template arg<1>::type;
  using input_t_2 = typename function_traits<func_t>::template arg<2>::type;
  using output_t = typename function_traits<func_t>::result_type;

  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      distribution_binary_kernel(sub_iter, philox_args, f);
    }
    return;
  }

  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(iter.can_use_32bit_indexing());

  int64_t numel = iter.numel();
  if (numel == 0) {
    return;
  }

  output_t *output_data = static_cast<output_t *>(iter.data_ptr(0));
  const input_t_1 *input_data_1 = static_cast<const input_t_1 *>(iter.data_ptr(1));
  const input_t_2 *input_data_2 = static_cast<const input_t_2 *>(iter.data_ptr(2));

  int64_t grid = (numel + block_work_size() - 1) / block_work_size();
  auto stream = at::cuda::getCurrentCUDAStream();

  if (iter.is_contiguous()) {
    distribution_binary_elementwise_kernel<<<grid, num_threads(), 0, stream>>>(
        numel, f, philox_args, output_data, input_data_1, input_data_2,
        TrivialOffsetCalculator<2>(), TrivialOffsetCalculator<1>());
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  } else {
    distribution_binary_elementwise_kernel<<<grid, num_threads(), 0, stream>>>(
        numel, f, philox_args, output_data, input_data_1, input_data_2,
        make_input_offset_calculator<2>(iter), make_output_offset_calculator(iter));
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
}
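
// Sketch of the functor shape distribution_binary_kernel expects (illustration
// only; the actual call sites live in the .cu files that include this header).
// The functor takes the per-thread philox state by reference plus one value
// from each input, e.g. a normal sample parameterized element-wise by mean and
// std tensors:
//
//   // hypothetical example, assuming float inputs and a float output
//   auto functor = [] __device__ (curandStatePhilox4_32_10_t& state,
//                                 float mean, float std) -> float {
//     return curand_normal(&state) * std + mean;
//   };
//   distribution_binary_kernel(iter, philox_args, functor);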

} // namespace
}} // namespace at::native


namespace at {
namespace native {
namespace templates {
namespace cuda {

// ==================================================== Random ========================================================

template<typename RNG>
void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG gen) {
  AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cuda", AT_WRAP([&] {
    if ((
      std::is_same<scalar_t, int64_t>::value ||
      std::is_same<scalar_t, double>::value ||
      std::is_same<scalar_t, float>::value ||
      std::is_same<scalar_t, at::BFloat16>::value) && range >= 1ULL << 32)
    {
      // define lambda to mod with range and add base
      auto random_func = [range, base] __device__ (uint64_t rand) {
        return transformation::uniform_int_from_to<scalar_t>(rand, range, base);
      };
      distribution_nullary_kernel<scalar_t, uint64_t, ulonglong2>(iter,
        gen,
        [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 {
          ulonglong2 ret;
          uint4 rand_val = curand4(state);
          ret.x = (static_cast<uint64_t>(rand_val.x) << 32) | rand_val.y;
          ret.y = (static_cast<uint64_t>(rand_val.z) << 32) | rand_val.w;
          return ret;
        },
        random_func);
    } else {
      auto random_func = [range, base] __device__ (uint32_t rand) {
        return transformation::uniform_int_from_to<scalar_t>(rand, range, base);
      };
      distribution_nullary_kernel<scalar_t, uint32_t, uint4>(iter,
        gen,
        [] __device__ (curandStatePhilox4_32_10_t* state) -> uint4 {
          return curand4(state);
        },
        random_func);
    }
  }), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
}
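
// Why the split above: a single 32-bit random can only cover 2^32 distinct
// values, so when range >= 2^32 (e.g. sampling int64 values over a range wider
// than 2^32) two outputs of curand4 are stitched into one 64-bit value,
// (uint64_t(rand_val.x) << 32) | rand_val.y, before the modular transform is
// applied. The 32-bit path remains the default because it consumes half as
// many random words per element.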

// This is a special kernel to handle a single specific case:
// from(inclusive) = std::numeric_limits<int64_t>::lowest()
// to(exclusive) = None (= std::numeric_limits<int64_t>::max() + 1)
template<typename RNG>
void random_full_64_bits_range_kernel(TensorIteratorBase& iter, RNG gen) {
  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::BFloat16, iter.dtype(), "random_full_64_bits_range_kernel_cuda", [&] {
    if (std::is_same<scalar_t, int64_t>::value ||
        std::is_same<scalar_t, double>::value ||
        std::is_same<scalar_t, float>::value ||
        std::is_same<scalar_t, at::BFloat16>::value) {
      auto random_func = [] __device__ (uint64_t rand) {
        return transformation::uniform_int_full_range<scalar_t>(rand);
      };
      distribution_nullary_kernel<scalar_t, uint64_t, ulonglong2>(iter,
        gen,
        [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 {
          ulonglong2 ret;
          uint4 rand_val = curand4(state);
          ret.x = (static_cast<uint64_t>(rand_val.x) << 32) | rand_val.y;
          ret.y = (static_cast<uint64_t>(rand_val.z) << 32) | rand_val.w;
          return ret;
        },
        random_func);
    } else {
      TORCH_CHECK(false, "random_full_64_bits_range_kernel_cuda handles only int64, double, float and bfloat16");
    }
  });
}

template<typename RNG>
struct RandomFromToKernel {
  void operator()(TensorIteratorBase& iter, uint64_t range, int64_t base, std::optional<Generator> gen) {
    random_from_to_kernel(iter, range, base, check_generator<RNG>(gen));
  }
  void operator()(TensorIteratorBase& iter, std::optional<Generator> gen) {
    random_full_64_bits_range_kernel(iter, check_generator<RNG>(gen));
  }
};

template<typename RNG>
void random_kernel(TensorIteratorBase& iter, RNG gen) {
  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "random_kernel_cuda", [&] {
    if (std::is_same<scalar_t, double>::value || std::is_same<scalar_t, int64_t>::value) {
      auto random_func = [] __device__ (uint64_t rand) {
        return transformation::uniform_int<scalar_t>(rand);
      };
      distribution_nullary_kernel<scalar_t, uint64_t, ulonglong2>(iter, gen,
        [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 {
          ulonglong2 ret;
          uint4 rand_val = curand4(state);
          ret.x = (static_cast<uint64_t>(rand_val.x) << 32) | rand_val.y;
          ret.y = (static_cast<uint64_t>(rand_val.z) << 32) | rand_val.w;
          return ret;
        },
        random_func);
    } else {
      auto random_func = [] __device__ (uint32_t rand) {
        return transformation::uniform_int<scalar_t>(rand);
      };
      distribution_nullary_kernel<scalar_t, uint32_t, uint4>(iter,
        gen,
        [] __device__ (curandStatePhilox4_32_10_t* state) -> uint4 {
          return curand4(state);
        },
        random_func);
    }
  });
}

template<typename RNG>
struct RandomKernel {
  void operator()(TensorIteratorBase& iter, RNG gen) {
    random_kernel(iter, gen);
  }
};

// ====================================================================================================================

template<typename scalar_t, typename accscalar_t, typename RNG, typename transform_t>
void uniform_and_transform(TensorIteratorBase& iter, RNG gen, transform_t transform) {
  if (std::is_same<scalar_t, double>::value) {
    distribution_nullary_kernel<scalar_t, accscalar_t, double2>(iter,
      gen,
      [] __device__ (curandStatePhilox4_32_10_t* state) -> double2 { return curand_uniform2_double(state); },
      transform);
  } else {
    distribution_nullary_kernel<scalar_t, accscalar_t, float4>(iter,
      gen,
      [] __device__ (curandStatePhilox4_32_10_t* state) -> float4 { return curand_uniform4(state); },
      transform);
  }
}

template<typename scalar_t, typename accscalar_t, typename RNG, typename transform_t>
void normal_and_transform(TensorIteratorBase& iter, RNG gen, transform_t transform) {
  if (std::is_same<scalar_t, double>::value) {
    distribution_nullary_kernel<scalar_t, accscalar_t, double2>(iter,
      gen,
      [] __device__ (curandStatePhilox4_32_10_t* state) -> double2 { return curand_normal2_double(state); },
      transform);
  } else {
    distribution_nullary_kernel<scalar_t, accscalar_t, float4>(iter,
      gen,
      [] __device__ (curandStatePhilox4_32_10_t* state) -> float4 { return curand_normal4(state); },
      transform);
  }
}
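
// Note on the two branches in the helpers above: the third template argument
// (double2 vs. float4) is the dist_func_return_t seen by
// distribution_nullary_kernel, so it also fixes the unroll factor:
// sizeof(double2) / sizeof(double) = 2 randoms per curand call for double
// tensors, and sizeof(float4) / sizeof(float) = 4 for every other dtype
// (half and bfloat16 accumulate in float).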

// ==================================================== Normal ========================================================

template<typename RNG>
void normal_kernel(const TensorBase &self, double mean_, double std_, RNG gen) {
  auto iter = TensorIterator::borrowing_nullary_op(self);
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "normal_kernel_cuda", [&] {
    using accscalar_t = at::acc_type<scalar_t, true>;
    auto mean = static_cast<accscalar_t>(mean_);
    auto std = static_cast<accscalar_t>(std_);
    // define lambda to multiply std and add mean
    auto normal_func = [mean, std] __device__ (accscalar_t rand) {
      return static_cast<scalar_t>(transformation::normal<accscalar_t>(rand, mean, std));
    };
    normal_and_transform<scalar_t, accscalar_t>(iter, gen, normal_func);
  });
}

template<typename RNG>
struct NormalKernel {
  void operator()(const TensorBase &self, double mean, double std, std::optional<Generator> gen) {
    normal_kernel(self, mean, std, check_generator<RNG>(gen));
  }
};

// ==================================================== Uniform ========================================================

template<typename RNG>
void uniform_kernel(TensorIteratorBase& iter, double from_, double to_, RNG gen) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "uniform_kernel_cuda", [&] {
    auto from = static_cast<scalar_t>(from_);
    auto to = static_cast<scalar_t>(to_);
    using opmath_t = at::opmath_type<scalar_t>;
    auto range = static_cast<opmath_t>(to-from);
    // define lambda to reverse bounds, multiply 'range' and add 'from_'
    auto uniform_func = [range, from, to] __device__ (opmath_t rand) {
      // Compute output value before reversing the bounds
      // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/96947
      auto value = static_cast<scalar_t>(rand * range + from);
      // reverse the bounds of curand4 from (0, 1] to [0, 1)
      // Note that this method is from legacy THCTensorRandom and is likely to give
      // you more 0-s, since the probability of getting 1-s is higher than 0-s and
      // by reversing the bounds, we are flipping the probabilities of 1-s and 0-s.
      // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/16706
      auto reverse_bound_value = value == to ? from : value;
      return reverse_bound_value;
    };
    uniform_and_transform<scalar_t, opmath_t>(iter, gen, uniform_func);
  });
}
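
// Concretely (values chosen for illustration): with from = 0.0 and to = 1.0,
// curand_uniform4 samples in (0, 1], so rand can be exactly 1.0f and
// value = rand * range + from can land exactly on to, which the half-open
// interval [from, to) must exclude. The reverse_bound_value step maps that one
// sample back to from, turning the effective support into [0.0, 1.0) at the
// cost of from being slightly over-represented (see the issues linked above).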

template<typename RNG>
struct UniformKernel {
  void operator()(TensorIteratorBase& iter, double from, double to, std::optional<Generator> gen) {
    uniform_kernel(iter, from, to, check_generator<RNG>(gen));
  }
};

// ================================================== LogNormal =======================================================

template<typename RNG>
void log_normal_kernel(TensorIteratorBase& iter, double mean_, double std_, RNG gen) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "log_normal_cuda", [&] {
    using accscalar_t = at::acc_type<scalar_t, true>;
    auto mean = static_cast<accscalar_t>(mean_);
    auto std = static_cast<accscalar_t>(std_);
    // define lambda for log_normal transformation
    auto log_normal_func = [mean, std] __device__ (accscalar_t rand) {
      return static_cast<scalar_t>(transformation::log_normal<accscalar_t>(transformation::normal<accscalar_t>(rand, mean, std)));
    };
    normal_and_transform<scalar_t, accscalar_t>(iter, gen, log_normal_func);
  });
}

template<typename RNG>
struct LogNormalKernel {
  void operator()(TensorIteratorBase& iter, double mean, double std, std::optional<Generator> gen) {
    log_normal_kernel(iter, mean, std, check_generator<RNG>(gen));
  }
};

// =================================================== Geometric ======================================================

template<typename RNG>
void geometric_kernel(TensorIteratorBase& iter, double p, RNG gen) {
  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "geometric_cuda", [&] {
    using accscalar_t = at::DiscreteDistributionType<scalar_t>::type;
    // define lambda for geometric transformation
    auto geometric_func = [p] __device__ (accscalar_t rand) {
      return static_cast<scalar_t>(transformation::geometric<accscalar_t>(rand, p));
    };
    uniform_and_transform<scalar_t, accscalar_t>(iter, gen, geometric_func);
  });
}

template<typename RNG>
struct GeometricKernel {
  void operator()(TensorIteratorBase& iter, double p, std::optional<Generator> gen) {
    geometric_kernel(iter, p, check_generator<RNG>(gen));
  }
};

// ================================================== Exponential =====================================================

template<typename RNG>
void exponential_kernel(TensorIteratorBase& iter, double lambda_, RNG gen) {
  TORCH_CHECK(isFloatingType(iter.dtype()), "Exponential distribution is a continuous probability distribution. dtype must be a floating point but you specified ", iter.dtype());
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exponential_cuda", [&] {
    using accscalar_t = at::acc_type<scalar_t, true>;
    auto lambda = static_cast<accscalar_t>(lambda_);
    // define lambda for exponential transformation
    auto exponential_func = [lambda] __device__ (accscalar_t rand) {
      return static_cast<scalar_t>(transformation::exponential<accscalar_t>(rand, lambda));
    };
    uniform_and_transform<scalar_t, accscalar_t>(iter, gen, exponential_func);
  });
}

template<typename RNG>
struct ExponentialKernel {
  void operator()(TensorIteratorBase& iter, double lambda, std::optional<Generator> gen) {
    exponential_kernel(iter, lambda, check_generator<RNG>(gen));
  }
};

// ==================================================== Cauchy ========================================================

template<typename RNG>
void cauchy_kernel(TensorIteratorBase& iter, double median_, double sigma_, RNG gen) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "cauchy_cuda", [&] {
    using accscalar_t = at::acc_type<scalar_t, true>;
    auto median = static_cast<accscalar_t>(median_);
    auto sigma = static_cast<accscalar_t>(sigma_);
    // define lambda for cauchy transformation
    auto cauchy_func = [median, sigma] __device__ (accscalar_t rand) {
      return static_cast<scalar_t>(transformation::cauchy<accscalar_t>(rand, median, sigma));
    };
    uniform_and_transform<scalar_t, accscalar_t>(iter, gen, cauchy_func);
  });
}

template<typename RNG>
struct CauchyKernel {
  void operator()(TensorIteratorBase& iter, double median, double sigma, std::optional<Generator> gen) {
    cauchy_kernel(iter, median, sigma, check_generator<RNG>(gen));
  }
};

// ==================================================== Bernoulli =====================================================

template<typename scalar_t, typename prob_t>
void bernoulli_tensor_cuda_kernel(
    const TensorBase &ret, const at::TensorBase &p,
    PhiloxCudaState philox_args) {
  auto functor = [philox_args] __device__(
          int n, scalar_t& v1, scalar_t& v2, scalar_t& v3, scalar_t& v4,
          const prob_t& p1, const prob_t& p2, const prob_t& p3, const prob_t& p4) {
        auto seeds = at::cuda::philox::unpack(philox_args);
        curandStatePhilox4_32_10_t state;
        curand_init(std::get<0>(seeds),
                    blockIdx.x * blockDim.x + threadIdx.x,
                    std::get<1>(seeds),
                    &state);

        // See Note [Register spilling in curand call for CUDA < 10]
        float4 rand = curand_uniform4(&state);
        switch (n) {
          case 4: {
            CUDA_KERNEL_ASSERT(0 <= p4 && p4 <= 1);
            v4 = static_cast<scalar_t>(rand.w <= p4);
            [[fallthrough]];
          }
          case 3: {
            CUDA_KERNEL_ASSERT(0 <= p3 && p3 <= 1);
            v3 = static_cast<scalar_t>(rand.z <= p3);
            [[fallthrough]];
          }
          case 2: {
            CUDA_KERNEL_ASSERT(0 <= p2 && p2 <= 1);
            v2 = static_cast<scalar_t>(rand.y <= p2);
            [[fallthrough]];
          }
          case 1: {
            CUDA_KERNEL_ASSERT(0 <= p1 && p1 <= 1);
            v1 = static_cast<scalar_t>(rand.x <= p1);
          }
        }
      };
  // The template argument `4` below indicates that we want to operate on four
  // elements at a time. See NOTE [ CUDA_tensor_applyN helpers ] for details.
  at::cuda::CUDA_tensor_apply2<scalar_t, const prob_t, 4, decltype(functor),
                               /*max_threads_per_block=*/512,
                               /*min_blocks_per_sm=*/2>(ret, p, functor);
}

template<typename RNG>
void bernoulli_kernel(const TensorBase &self, const TensorBase &p_, RNG gen) {
  PhiloxCudaState rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    rng_engine_inputs = gen->philox_cuda_state(10);
  }
  TORCH_CHECK(at::isFloatingType(p_.scalar_type()), "expected probabilities tensor to have floating type, got ", p_.scalar_type());
  // cast probabilities tensor to double for double `self` tensor, and to `float` for everything else
  const auto p_type = self.dtype() == at::kDouble ? at::kDouble : at::kFloat;
  auto p_cuda = p_.to(TensorOptions().device(self.device()).dtype(p_type));
  auto p = expand_inplace(self, p_cuda);
  AT_DISPATCH_ALL_TYPES_AND3(
    at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, self.scalar_type(), "bernoulli_tensor_cuda_self_", [&] {
      if (std::is_same<scalar_t, double>::value) {
        return bernoulli_tensor_cuda_kernel<double, double>(self, *p, rng_engine_inputs);
      } else {
        return bernoulli_tensor_cuda_kernel<scalar_t, float>(self, *p, rng_engine_inputs);
      }
  });
}

template<typename RNG>
void bernoulli_kernel(TensorIteratorBase& iter, double p, RNG gen) {
  AT_DISPATCH_ALL_TYPES_AND3(
    at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "bernoulli_scalar_cuda_", [&] {
      using accscalar_t = at::DiscreteDistributionType<scalar_t>::type;
      // define lambda for bernoulli transformation
      auto bernoulli_func = [p] __device__ (accscalar_t rand) {
        return static_cast<scalar_t>(transformation::bernoulli<accscalar_t>(rand, p));
      };
      uniform_and_transform<scalar_t, accscalar_t>(iter, gen, bernoulli_func);
  });
}

template<typename RNG>
struct BernoulliKernel {
  void operator()(TensorIteratorBase& iter, double p, std::optional<Generator> gen) {
    bernoulli_kernel(iter, p, check_generator<RNG>(gen));
  }
  void operator()(const TensorBase &self, const TensorBase &p_, std::optional<Generator> gen) {
    bernoulli_kernel(self, p_, check_generator<RNG>(gen));
  }
};

}}}}