// aten/src/ATen/cuda/detail/KernelUtils.h
#pragma once

#include <limits>
#include <c10/util/Exception.h>

namespace at::cuda::detail {

// CUDA: grid stride looping
//
// int64_t _i_n_d_e_x specifically prevents overflow in the loop increment.
// If input.numel() < INT_MAX, then _i_n_d_e_x < INT_MAX on every iteration
// except possibly the last, where _i_n_d_e_x += blockDim.x * gridDim.x may
// exceed INT_MAX. But in that case _i_n_d_e_x >= n, so the loop terminates
// and the out-of-range value assigned to i is never used.
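//
// Worked example (hypothetical numbers): with n = INT_MAX - 1 and a stride
// of blockDim.x * gridDim.x = 2^20, the final increment pushes the int64_t
// _i_n_d_e_x above INT_MAX without wrapping; the condition _i_n_d_e_x < n
// then fails, so the narrowed int value written to i on that last increment
// is never read.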
#define CUDA_KERNEL_LOOP_TYPE(i, n, index_type)                         \
  int64_t _i_n_d_e_x = blockIdx.x * blockDim.x + threadIdx.x;           \
  for (index_type i=_i_n_d_e_x; _i_n_d_e_x < (n); _i_n_d_e_x+=blockDim.x * gridDim.x, i=_i_n_d_e_x)

#define CUDA_KERNEL_LOOP(i, n) CUDA_KERNEL_LOOP_TYPE(i, n, int)
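//
// Example usage (a minimal sketch; `scale_kernel` is a hypothetical kernel,
// not part of this header). Each thread strides through the flattened index
// space, so the grid does not need to cover all n elements:
//
//   __global__ void scale_kernel(float* out, const float* in, float alpha, int n) {
//     CUDA_KERNEL_LOOP(i, n) {
//       out[i] = alpha * in[i];
//     }
//   }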

// Use 1024 threads per block, which requires CUDA compute capability
// 2.0 (sm_2x) or above.
constexpr int CUDA_NUM_THREADS = 1024;

// CUDA: number of blocks needed to cover N threads.
inline int GET_BLOCKS(const int64_t N, const int64_t max_threads_per_block=CUDA_NUM_THREADS) {
  TORCH_INTERNAL_ASSERT(N > 0, "CUDA kernel launch blocks must be positive, but got N=", N);
  constexpr int64_t max_int = std::numeric_limits<int>::max();

  // Round-up division; N > 0 is asserted above, so this cannot overflow.
  auto block_num = (N - 1) / max_threads_per_block + 1;
  TORCH_INTERNAL_ASSERT(block_num <= max_int, "Can't schedule too many blocks on CUDA device");

  return static_cast<int>(block_num);
}
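
// Example launch (a sketch; assumes the hypothetical `scale_kernel` above
// and device pointers `out_ptr`/`in_ptr`; at::cuda::getCurrentCUDAStream()
// and C10_CUDA_KERNEL_LAUNCH_CHECK() come from c10/cuda):
//
//   const int64_t n = input.numel();
//   scale_kernel<<<GET_BLOCKS(n), CUDA_NUM_THREADS, 0,
//                  at::cuda::getCurrentCUDAStream()>>>(
//       out_ptr, in_ptr, alpha, static_cast<int>(n));
//   C10_CUDA_KERNEL_LAUNCH_CHECK();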

}  // namespace at::cuda::detail