#pragma once

#include <limits>

#include <thrust/tuple.h>

#include <ATen/native/SharedReduceOps.h>
#include <ATen/cuda/DeviceUtils.cuh>
namespace at {
namespace native {
namespace cuda_utils {

constexpr int kCUDABlockReduceNumThreads = 512;
// Algorithmic limitation: BlockReduce does two WarpReduce calls, each
// of which reduces C10_WARP_SIZE elements. So, at most
// C10_WARP_SIZE**2 elements can be reduced at a time.
// NOTE: This is >= the max block size on current hardware anyway (1024).
constexpr int kCUDABlockReduceMaxThreads = C10_WARP_SIZE * C10_WARP_SIZE;
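//
// Worked example: with C10_WARP_SIZE == 32, a 512-thread block splits into
// 512 / 32 = 16 warps. The first WarpReduce pass leaves one partial result
// per warp; the second pass has warp 0 reduce those 16 partials. Since the
// second pass can combine at most C10_WARP_SIZE partials, a block may hold
// at most 32 * 32 = 1024 threads.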

// Sums `val` across all threads in a warp.
//
// Assumptions:
//   - The size of each block should be a multiple of `C10_WARP_SIZE`
template <typename T>
__inline__ __device__ T WarpReduceSum(T val) {
#pragma unroll
  for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) {
    val += WARP_SHFL_DOWN(val, offset);
  }
  return val;
}
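//
// For example, with C10_WARP_SIZE == 32 the loop visits offsets 16, 8, 4, 2,
// and 1. After the first step, lane i (for i < 16) holds val[i] + val[i + 16];
// after the final step, lane 0 holds the sum of all 32 lanes. Other lanes end
// up with partial sums, so callers should read the result from lane 0 only.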

// Picks the maximum `val` across all threads in a warp.
//
// Assumptions:
//   - The size of each block should be a multiple of `C10_WARP_SIZE`
template <typename T>
__inline__ __device__ T WarpReduceMax(T val) {
#pragma unroll
  for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) {
    val = max_propagate_nan(val, WARP_SHFL_DOWN(val, offset));
  }
  return val;
}

// Maps a one-dimensional thread block to a linear thread id and a warp count.
struct Block1D {
    static __forceinline__ __device__ int Tid() { return threadIdx.x; }

    static __forceinline__ __device__ int Warps() {
        return blockDim.x / C10_WARP_SIZE;
    }
};

// Maps a two-dimensional thread block to a linear thread id and a warp count.
struct Block2D {
    static __forceinline__ __device__ int Tid() {
        return threadIdx.x + threadIdx.y * blockDim.x;
    }

    static __forceinline__ __device__ int Warps() {
        return blockDim.x * blockDim.y / C10_WARP_SIZE;
    }
};

// Sums `val` across all threads in a block.
//
// Warning: the return value is only valid for thread 0.
// Assumptions:
//   - The size of each block should be a multiple of `C10_WARP_SIZE`
//   - `shared` should point to at least `sizeof(T) * number_of_warps` bytes
//     of shared memory
template <typename T, typename B = Block1D>
__inline__ __device__ T BlockReduceSum(T val, T* shared) {
  const int tid = B::Tid();
  const int lid = tid % C10_WARP_SIZE;
  const int wid = tid / C10_WARP_SIZE;
  val = WarpReduceSum(val);
  __syncthreads(); // prevent races when BlockReduces are called in a row.
  if (lid == 0) {
    shared[wid] = val; // lane 0 of each warp publishes its partial sum
  }
  __syncthreads();
  // The first B::Warps() lanes pick up the per-warp partials; the rest use 0,
  // so warp 0 can reduce them to the final sum.
  val = (tid < B::Warps()) ? shared[lid] : T(0);
  if (wid == 0) {
    val = WarpReduceSum(val);
  }
  return val;
}
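//
// Example usage (a minimal sketch; `SumKernel`, `in`, `out`, and `n` are
// hypothetical names, not part of this header):
//
//   __global__ void SumKernel(const float* in, float* out, int n) {
//     // One slot per warp; C10_WARP_SIZE slots always suffice because a
//     // block has at most kCUDABlockReduceMaxThreads threads.
//     __shared__ float shared[C10_WARP_SIZE];
//     float val = 0.0f;
//     for (int i = threadIdx.x; i < n; i += blockDim.x) {
//       val += in[i];
//     }
//     val = BlockReduceSum(val, shared);
//     if (threadIdx.x == 0) {
//       *out = val;
//     }
//   }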

// Picks out the maximum `val` across all threads in a block.
//
// Warning: the return value is only valid for thread 0.
// Assumptions:
//   - The size of each block should be a multiple of `C10_WARP_SIZE`
//   - `shared` should point to at least `sizeof(T) * number_of_warps` bytes
//     of shared memory
template <typename T, typename B = Block1D>
__inline__ __device__ T BlockReduceMax(T val, T* shared) {
  const int tid = B::Tid();
  const int lid = tid % C10_WARP_SIZE;
  const int wid = tid / C10_WARP_SIZE;
  val = WarpReduceMax(val);
  __syncthreads(); // prevent races when BlockReduces are called in a row.
  if (lid == 0) {
    shared[wid] = val; // lane 0 of each warp publishes its partial max
  }
  __syncthreads();
  // The first B::Warps() lanes pick up the per-warp partials; the rest use
  // the lowest representable value so they cannot win the max.
  val = (tid < B::Warps()) ? shared[lid] : T(std::numeric_limits<T>::lowest());
  if (wid == 0) {
    val = WarpReduceMax(val);
  }
  return val;
}
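//
// Note: max_propagate_nan (rather than a plain max) is used above so that a
// NaN in any lane propagates to the block-wide result instead of being
// silently dropped.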

// Reduces `val` across all threads in a warp, combining values with
// `op.combine` and exchanging them with `op.warp_shfl_down`.
//
// Assumptions:
//   - The size of each block should be a multiple of `C10_WARP_SIZE`
template <typename T, class ReduceOp>
__inline__ __device__ T WarpReduce(T val, const ReduceOp& op) {
#pragma unroll
  for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) {
    val = op.combine(val, op.warp_shfl_down(val, offset));
  }
  return val;
}

// Reduces `val` across all threads in a block using `op`, filling lanes that
// have no per-warp partial to read with `identity_element`.
//
// Warning: the return value is only valid for thread 0.
// Assumptions:
//   - The size of each block should be a multiple of `C10_WARP_SIZE`
//   - `shared` should point to at least `sizeof(T) * number_of_warps` bytes
//     of shared memory
//   - `identity_element` should be an identity for `op.combine`
template <typename T, class ReduceOp, typename B = Block1D>
__inline__ __device__ T
BlockReduce(T val, const ReduceOp& op, const T& identity_element, T* shared) {
  const int tid = B::Tid();
  const int lid = tid % C10_WARP_SIZE;
  const int wid = tid / C10_WARP_SIZE;
  val = WarpReduce(val, op);
  __syncthreads(); // prevent races when BlockReduces are called in a row.
  if (lid == 0) {
    shared[wid] = val;
  }
  __syncthreads();
  val = (tid < B::Warps()) ? shared[lid] : identity_element;
  if (wid == 0) {
    val = WarpReduce(val, op);
  }
  return val;
}
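//
// The ReduceOp only needs `combine` and `warp_shfl_down` members. A minimal
// sketch of a compatible functor (`MinOp` is a hypothetical example, not one
// of the ops in SharedReduceOps.h):
//
//   template <typename T>
//   struct MinOp {
//     __device__ T combine(T a, T b) const { return b < a ? b : a; }
//     __device__ T warp_shfl_down(T val, int offset) const {
//       return WARP_SHFL_DOWN(val, offset);
//     }
//   };
//
// A block-wide minimum (valid on thread 0) then reads:
//
//   __shared__ float shared[C10_WARP_SIZE];
//   float m = BlockReduce(
//       val, MinOp<float>{}, std::numeric_limits<float>::infinity(), shared);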

} // namespace cuda_utils
} // namespace native
} // namespace at