#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/ceil_div.h>
#include <ATen/Dispatch.h>
#include <ATen/Utils.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/EmptyTensor.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/cuda/LaunchUtils.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <ATen/native/cuda/block_reduce.cuh>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/CUDAFunctions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty_native.h>
#include <ATen/ops/empty_like_native.h>
#include <ATen/ops/cumsum_cuda_dispatch.h>
#include <ATen/ops/uniform_native.h>
#endif

#include <curand.h>
#include <curand_kernel.h>
#include <curand_philox4x32_x.h>
#include <type_traits>

namespace at::native {

namespace {

template <
    typename T,
    typename = std::enable_if_t<
        std::is_floating_point_v<T> || std::is_convertible_v<T, float>>>
inline __device__ bool _isinf(T x) {
  if constexpr (std::is_floating_point_v<T>) {
    return ::isinf(x);
  } else {
    return ::isinf(static_cast<float>(x));
  }
}

#define MAX_NUM_BLOCKS 200

// Normalizes the L1 norm of every row to 1; used by multinomial
template <typename scalar_t>
C10_LAUNCH_BOUNDS_1(cuda::detail::CUDA_NUM_THREADS)
__global__ void renormRowsL1(scalar_t* dist, long rows, long cols) {
  extern __shared__ unsigned char my_smem[];
  scalar_t *smem = reinterpret_cast<scalar_t *>(my_smem);
  scalar_t zero = static_cast<scalar_t>(0);
  scalar_t val;
  for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
    scalar_t sum = static_cast<scalar_t>(0);
    for (int64_t col = threadIdx.x; col < cols; col += blockDim.x) {
      val = dist[row * cols + col];
      CUDA_KERNEL_ASSERT(!(val < zero)); // !(val < 0) rather than (val >= 0) so NaN also fails
      sum = sum + val;
    }

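    // Reduce the per-thread partial sums; thread 0 receives the full row sum
    // and broadcasts it to the rest of the block via smem[0].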
    sum = cuda_utils::BlockReduceSum(sum, smem);
    if (threadIdx.x == 0) {
      CUDA_KERNEL_ASSERT(!(val < zero)); // !(val < 0) rather than (val >= 0) so NaN also fails
      smem[0] = sum;
    }
    __syncthreads();

    sum = smem[0];
    if (sum > zero) {
      for (int64_t col = threadIdx.x; col < cols; col += blockDim.x) {
        dist[row * cols + col] = dist[row * cols + col] / sum;
      }
    }
  }
}

void renormRows(Tensor& t) {
  TORCH_CHECK(t.dim() == 2);
  int64_t rows = t.size(0);
  int64_t cols = t.size(1);

  auto props = at::cuda::getCurrentDeviceProperties();
  TORCH_CHECK(props != nullptr);
  int numSM = props->multiProcessorCount;
  const int64_t maxThreads = std::min(
      props->maxThreadsPerBlock, cuda_utils::kCUDABlockReduceMaxThreads);

  int warp_size = at::cuda::warp_size();
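  // Launch at most numSM * 4 blocks; extra rows are covered by the kernel's
  // grid-stride loop. The block size rounds the column count up to whole
  // warps, capped at maxThreads.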
  dim3 grid(rows < numSM * 4 ? rows : numSM * 4);
  dim3 block(std::min(maxThreads, warp_size * ceil_div(cols, int64_t{warp_size})));

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, t.scalar_type(), "renormRows_cuda", [&] {
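    // BlockReduceSum needs one scalar_t of dynamic shared memory per warp,
    // hence the (block.x / warp_size) * sizeof(scalar_t) allocation below.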
    renormRowsL1<scalar_t>
        <<<grid, block, (block.x / warp_size) * sizeof(scalar_t),
           at::cuda::getCurrentCUDAStream()>>>(t.mutable_data_ptr<scalar_t>(),
                                               rows, cols);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });
}

template <typename scalar_t>
__device__ int binarySearchForMultinomial(const scalar_t* cumdist,
                                          const scalar_t* dist,
                                          int size,
                                          scalar_t val) {
  int start = 0;
  int end = size;
  // cumdist[size - 1] = 0 => all zero prob dist
  CUDA_KERNEL_ASSERT(cumdist[size - 1] > static_cast<scalar_t>(0));

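  // Lower-bound binary search: find the first bucket whose cumulative
  // probability is >= val.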
  while (end - start > 0) {
    int mid = start + (end - start) / 2;

    scalar_t midVal = cumdist[mid];
    if (midVal < val) {
      start = mid + 1;
    } else {
      end = mid;
    }
  }

  if (start == size) {
    // No probability mass or precision problems; just return the
    // first non-zero element by setting start to size - 1 here;
    // the code below will move it to the last non-zero probability.
    // This can actually happen when the random number is 1
    // (github pytorch issue #4858).
    start = size - 1;
  }

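  // Walk back over trailing zero-probability categories so that a zero-mass
  // bucket is never returned.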
  while (start >= 1 && dist[start] == 0) start--;

  return start;
}

template <typename scalar_t>
__global__ void
sampleMultinomialWithReplacement(PhiloxCudaState philox_args,
                                 int totalSamples,
                                 int64_t* dest,
                                 int64_t distributions,
                                 int categories,
                                 const scalar_t* normDistPrefixSum,
                                 const scalar_t* normDist) {
  // At the moment, each warp computes one sample value in the binary
  // search due to divergence. It seems possible to compute multiple
  // values per warp and limit divergence later on.

  auto seeds = at::cuda::philox::unpack(philox_args);

  // global index formula for a 2D grid of 1D blocks
  int idx = blockIdx.y * gridDim.x * blockDim.x + blockIdx.x * blockDim.x + threadIdx.x;

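  // Each thread gets its own Philox subsequence, keyed by its global thread index.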
  curandStatePhilox4_32_10_t state;
  curand_init(std::get<0>(seeds),
              idx,
              std::get<1>(seeds),
              &state);

  // The block determines the distribution for which we generate a point
  for (int64_t curDist = blockIdx.y;
       curDist < distributions;
       curDist += gridDim.y) {
    for (int sample = blockIdx.x * blockDim.x + threadIdx.x;
         sample < totalSamples; sample += blockDim.x * gridDim.x) {

      // We discard 3 of the 4 generated numbers, but that's acceptable;
      // this kernel is not very efficient anyway.
      auto rand = curand_uniform4(&state);
      scalar_t r = static_cast<scalar_t>(rand.x);

      // Find the bucket that a uniform sample lies in
      int choice = binarySearchForMultinomial<scalar_t>(
          normDistPrefixSum + curDist * categories,
          normDist + curDist * categories,
          categories,
          r);

      dest[curDist * totalSamples + sample] = choice;
    }
  }
}

template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(cuda::detail::CUDA_NUM_THREADS)
__global__ void sampleMultinomialOnce(
    int64_t* dest,
    int64_t distributions,
    int categories,
    const scalar_t* sampled,
    const scalar_t* dist,
    int stride_dist,       // dist->stride(0)
    int stride_categories  // dist->stride(1)
) {
  extern __shared__ unsigned char my_smem[];
  __shared__ bool found;
  __shared__ unsigned foundPos;

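  // Shared memory is reused: first as scratch for BlockReduceSum, then as the
  // per-thread prefix-sum buffer (slots 0 and 1 also broadcast the distribution
  // sum and the sampled value).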
  accscalar_t *smem = reinterpret_cast<accscalar_t *>(my_smem);

  accscalar_t accZero = static_cast<accscalar_t>(0);
  scalar_t zero = static_cast<scalar_t>(0);

  for (int64_t curDist = blockIdx.x;
       curDist < distributions; curDist += gridDim.x) {
    // Each block handles one distribution.
    // First pass: find the total sum of the distribution.
    accscalar_t sum = accZero;
    scalar_t val;
    for (int cat = threadIdx.x; cat < categories; cat += blockDim.x) {
      val = dist[curDist * stride_dist + cat * stride_categories];
      CUDA_KERNEL_ASSERT(!at::_isnan(val));
      CUDA_KERNEL_ASSERT(!_isinf(val));
      CUDA_KERNEL_ASSERT(!(val < zero));
      sum = sum + static_cast<accscalar_t>(val);
    }

    // After the reduction, threadIdx.x == 0 holds the sum for this distribution
    sum = cuda_utils::BlockReduceSum(sum, smem);

    // Broadcast sum and sample value
    if (threadIdx.x == 0) {
      // Make sure the sum of our distribution didn't overflow
      CUDA_KERNEL_ASSERT(!_isinf(val));
      CUDA_KERNEL_ASSERT(sum > accZero);

      foundPos = 0;
      smem[0] = sum;
      smem[1] = sampled[curDist];
    }
    __syncthreads();

    sum = smem[0];
    scalar_t sample = static_cast<scalar_t>(smem[1]);
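    // Make sure every thread has read smem[0] and smem[1] before the scan
    // below overwrites them.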
    __syncthreads();

    if (sum == accZero) {
      // Choose the first element
      if (threadIdx.x == 0) {
        dest[curDist] = 0;
      }

      continue;
    }

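    // Process the categories in blockDim.x-wide chunks; each chunk is
    // prefix-summed in shared memory until the chunk containing the sample is
    // found.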
    int chunks = (categories + (int)blockDim.x - 1) / blockDim.x;
    accscalar_t prevHighProb = accZero;
    found = false;

    for (int chunk = 0; chunk < chunks && !found; ++chunk) {
      // All threads in bounds load a value
      int cat = chunk * blockDim.x + threadIdx.x;

      accscalar_t dist_val = cat < categories ?
          static_cast<accscalar_t>(dist[curDist * stride_dist + cat * stride_categories]) / sum :
          accZero;

      smem[threadIdx.x] = dist_val;
      __syncthreads();

      // Perform an inclusive prefix sum of the shared memory contents
      for (int offset = 1; offset < blockDim.x; offset *= 2) {
        accscalar_t val = accZero;

        if (threadIdx.x >= offset) {
          val = smem[threadIdx.x - offset] + smem[threadIdx.x];
        }

        __syncthreads();
        if (threadIdx.x >= offset) {
          smem[threadIdx.x] = val;
        }
        __syncthreads();
      }

      // Each thread will check to see if the sample falls in its
      // bucket
      scalar_t curBucket =
          static_cast<scalar_t>(smem[threadIdx.x] + prevHighProb);
      scalar_t prevBucket = static_cast<scalar_t>(
          threadIdx.x == 0 ? prevHighProb
                           : smem[threadIdx.x - 1] + prevHighProb);
      bool inBucket =
          (cat < categories) &&
          (!(sample >= curBucket) &&
           (sample >= prevBucket) &&
           (dist_val > zero));
      if (inBucket) {
        // We're done; we have the sample. If more than one thread lands in a
        // bucket (possible with rounding at bucket edges), atomicMax keeps the
        // highest category index.
        atomicMax(&foundPos, cat);
        found = true;
      }

      // Store the previous scan's high value for future use
      prevHighProb = prevHighProb + smem[blockDim.x - 1];

      __syncthreads();
    }

    if (threadIdx.x == 0) {
      if (found) {
        dest[curDist] = foundPos;
      } else {
        // This should address a rare bug where we don't select a valid index.
        // It likely occurs when, due to floating point rounding errors, the
        // cumulative sum does not add up to 1 and the uniform sample is greater
        // than that value. In that case dest[curDist] would be left
        // uninitialized, so we loop through the distribution and pick the
        // largest index where the distribution is non-zero. This is obviously
        // terribly inefficient, but given how rarely this occurs, it should not
        // be an issue.
        for (int cat = categories - 1; cat >= 0; --cat) {
          if (dist[curDist * stride_dist + cat * stride_categories] > zero) {
            dest[curDist] = cat;
            break;
          }
        }
      }
    }
  }
}

void multinomial_with_replacement_kernel_impl(
    Tensor& result,
    const Tensor& self,
    const int64_t n_sample,
    std::optional<Generator> generator) {
  auto gen = get_generator_or_default<CUDAGeneratorImpl>(generator, cuda::detail::getDefaultCUDAGenerator());

  int inputSize = self.dim();
  int64_t numDist =
      inputSize == 1 ? 1 : self.size(0);
  int numCategories =
      inputSize == 1 ? self.size(0) : self.size(1);

  // Restructure data for 2d
  auto self_v = inputSize == 1 ? self.view({numDist, numCategories}) : self;

  result.resize_({numDist, n_sample});

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, self_v.scalar_type(), "multinomial_kernel_cuda", [&] {
    using accscalar_t = at::acc_type<scalar_t, true>;
    auto props = at::cuda::getCurrentDeviceProperties();
    TORCH_CHECK(props != nullptr);
    int numSM = props->multiProcessorCount;
    int maxThreads = props->maxThreadsPerBlock;
    int maxShared = props->sharedMemPerBlock;

    int warp_size = at::cuda::warp_size();
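    // sampleMultinomialOnce scans one accscalar_t per thread in shared memory,
    // so it needs requiredThreads * sizeof(accscalar_t) bytes per block.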
    int requiredWarps = at::ceil_div(numCategories, warp_size);
    int requiredThreads = std::min(maxThreads, requiredWarps * warp_size);
    int requiredShared = requiredThreads * sizeof(accscalar_t);

    if (n_sample == 1 && maxShared >= requiredShared) {
      // Optimized implementation that avoids the renormalization and
      // prefix-sum temporaries needed by the generic path below.
      // To exploit greater parallelism for the sampling, generate the
      // uniform random samples in a separate kernel launch, into
      // temporarily allocated memory. The device RNG is thread-limited.
      Tensor sampled = at::detail::empty_cuda({numDist, n_sample}, self_v.options());
      at::native::uniform_(sampled, 0.0, 1.0, generator);

      dim3 block(requiredThreads);
      dim3 grid(std::min(static_cast<int>(numDist), numSM * 4));

      sampleMultinomialOnce<scalar_t, accscalar_t>
          <<<grid, block,
             requiredShared,
             at::cuda::getCurrentCUDAStream()>>>(
              result.mutable_data_ptr<int64_t>(),
              numDist,
              numCategories,
              sampled.const_data_ptr<scalar_t>(),
              self_v.const_data_ptr<scalar_t>(),
              self_v.stride(0),
              self_v.stride(1));
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    } else {
      // Generic, slow implementation with memory allocations

      // For sampling without replacement, we modify the distribution
      // for subsequent samples in this space
      Tensor origDist = native::empty_like(
          self_v,
          std::nullopt /* dtype */,
          std::nullopt /* layout */,
          std::nullopt /* device */,
          std::nullopt /* pin_memory */,
          LEGACY_CONTIGUOUS_MEMORY_FORMAT);
      origDist.copy_(self_v);

      Tensor normDist = native::empty_like(
          self_v,
          std::nullopt /* dtype */,
          std::nullopt /* layout */,
          std::nullopt /* device */,
          std::nullopt /* pin_memory */,
          LEGACY_CONTIGUOUS_MEMORY_FORMAT);

      Tensor prefixSum = native::empty_like(
          self_v,
          std::nullopt /* dtype */,
          std::nullopt /* layout */,
          std::nullopt /* device */,
          std::nullopt /* pin_memory */,
          LEGACY_CONTIGUOUS_MEMORY_FORMAT);

      // Renorm along rows
      normDist.copy_(origDist);
      renormRows(normDist);

      // Prefix sum along rows
      at::cuda::cumsum_out(prefixSum, normDist, 1);

      PhiloxCudaState rng_engine_inputs;

      // Binary search is warp divergent (so effectively we're running
      // with just a single thread), but for better utilization,
      // we need each block to have at least 4 warps.
      dim3 block(128);

      // Each block will generate a sample from one
      // distribution concurrently.
      int grid_y = std::min<int>(numDist, at::cuda::getCurrentDeviceProperties()->maxGridSize[1]);
      dim3 grid((n_sample - 1) / block.x + 1, grid_y);
      {
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(gen->mutex_);

        // Each thread generates a single sample for (numDist / grid.y)
        // distributions; however, since we have to use curand_uniform4
        // (see Note [Register spilling in curand call for CUDA < 10]),
        // the offset is 4 times that.
        auto offset = ((numDist - 1) / grid.y + 1) * 4;
        rng_engine_inputs = gen->philox_cuda_state(offset);
      }
      // Sample with replacement
      sampleMultinomialWithReplacement
          <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
              rng_engine_inputs,
              n_sample,
              result.mutable_data_ptr<int64_t>(),
              numDist, numCategories,
              prefixSum.const_data_ptr<scalar_t>(),
              normDist.const_data_ptr<scalar_t>());
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
  });

  if (inputSize == 1) {
    result.resize_({n_sample});
  }
}
} // anonymous namespace

REGISTER_DISPATCH(
    multinomial_with_replacement_stub,
    &multinomial_with_replacement_kernel_impl);
} // namespace at::native