#pragma once

#include <thrust/tuple.h>

#include <ATen/native/SharedReduceOps.h>
#include <ATen/cuda/DeviceUtils.cuh>

namespace at {
namespace native {
namespace cuda_utils {

constexpr int kCUDABlockReduceNumThreads = 512;
// Algorithmic limitation: BlockReduce does two WarpReduce calls, each
// of which reduces C10_WARP_SIZE elements. So, at most
// C10_WARP_SIZE**2 elements can be reduced at a time.
// NOTE: This is >= the max block size on current hardware anyway (1024).
constexpr int kCUDABlockReduceMaxThreads = C10_WARP_SIZE * C10_WARP_SIZE;
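
// Illustrative sketch (not part of the original header; the helper name is
// hypothetical): how a caller might sanitize a requested block size before
// launching a kernel that uses the reductions below.
constexpr int ClampBlockReduceThreads(int requested) {
  // Round down to a multiple of C10_WARP_SIZE (the reductions below assume
  // the block size is one), then clamp to
  // [C10_WARP_SIZE, kCUDABlockReduceMaxThreads].
  const int rounded = (requested / C10_WARP_SIZE) * C10_WARP_SIZE;
  return rounded < C10_WARP_SIZE
      ? C10_WARP_SIZE
      : (rounded > kCUDABlockReduceMaxThreads ? kCUDABlockReduceMaxThreads
                                              : rounded);
}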

// Sums `val` across all threads in a warp.
//
// Assumptions:
// - The size of each block should be a multiple of `C10_WARP_SIZE`
template <typename T>
__inline__ __device__ T WarpReduceSum(T val) {
#pragma unroll
  for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) {
    val += WARP_SHFL_DOWN(val, offset);
  }
  return val;
}
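
// Illustrative sketch (hypothetical kernel, not part of the original header):
// each warp sums C10_WARP_SIZE consecutive elements of `in`, and lane 0 of
// every warp writes that warp's partial sum to `out`. Assumes the block size
// is a multiple of C10_WARP_SIZE.
template <typename T>
__global__ void WarpSumExampleKernel(const T* in, T* out, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  T v = i < n ? in[i] : T(0);
  v = WarpReduceSum(v);
  // After the shuffle-down tree, only lane 0 of each warp holds the warp sum.
  if (i < n && threadIdx.x % C10_WARP_SIZE == 0) {
    out[i / C10_WARP_SIZE] = v;
  }
}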

// Picks the maximum `val` across all threads in a warp.
//
// Assumptions:
// - The size of each block should be a multiple of `C10_WARP_SIZE`
template <typename T>
__inline__ __device__ T WarpReduceMax(T val) {
#pragma unroll
  for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) {
    val = max_propagate_nan(val, WARP_SHFL_DOWN(val, offset));
  }
  return val;
}

struct Block1D {
  static __forceinline__ __device__ int Tid() { return threadIdx.x; }

  static __forceinline__ __device__ int Warps() {
    return blockDim.x / C10_WARP_SIZE;
  }
};

struct Block2D {
  static __forceinline__ __device__ int Tid() {
    return threadIdx.x + threadIdx.y * blockDim.x;
  }

  static __forceinline__ __device__ int Warps() {
    return blockDim.x * blockDim.y / C10_WARP_SIZE;
  }
};
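
// For example, with C10_WARP_SIZE == 32 a dim3(32, 8) block used with Block2D
// gives Tid() values in [0, 256) and Warps() == 8. The reductions below take
// the block shape as the `B` template parameter so the same code handles both
// 1D and 2D launches.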

// Sums `val` across all threads in a block.
//
// Warning: the return value is only valid for thread 0.
// Assumptions:
// - The size of each block should be a multiple of `C10_WARP_SIZE`
// - `shared` should be a pointer to shared memory with size of, at least,
//   `sizeof(T) * number_of_warps`
template <typename T, typename B = Block1D>
__inline__ __device__ T BlockReduceSum(T val, T* shared) {
  const int tid = B::Tid();
  const int lid = tid % C10_WARP_SIZE;
  const int wid = tid / C10_WARP_SIZE;
  val = WarpReduceSum(val);
  __syncthreads(); // prevent races when BlockReduces are called in a row.
  if (lid == 0) {
    shared[wid] = val;
  }
  __syncthreads();
  val = (tid < B::Warps()) ? shared[lid] : T(0);
  if (wid == 0) {
    val = WarpReduceSum(val);
  }
  return val;
}
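
// Illustrative sketch (hypothetical kernel, not part of the original header):
// a single block reduces `n` elements of `in` to one sum. The shared buffer
// is sized for the largest possible number of warps per block, which meets
// the "at least sizeof(T) * number_of_warps" requirement above.
template <typename T>
__global__ void BlockSumExampleKernel(const T* in, T* out, int n) {
  __shared__ T shared[kCUDABlockReduceMaxThreads / C10_WARP_SIZE];
  T sum = T(0);
  // Thread-stride loop over the input within this one block.
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    sum += in[i];
  }
  sum = BlockReduceSum(sum, shared);
  if (threadIdx.x == 0) {
    *out = sum; // the result is only valid on thread 0
  }
}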

// Picks out the maximum `val` across all threads in a block.
//
// Warning: the return value is only valid for thread 0.
// Assumptions:
// - The size of each block should be a multiple of `C10_WARP_SIZE`
// - `shared` should be a pointer to shared memory with size of, at least,
//   `sizeof(T) * number_of_warps`
template <typename T, typename B = Block1D>
__inline__ __device__ T BlockReduceMax(T val, T* shared) {
  const int tid = B::Tid();
  const int lid = tid % C10_WARP_SIZE;
  const int wid = tid / C10_WARP_SIZE;
  val = WarpReduceMax(val);
  __syncthreads(); // prevent races when BlockReduces are called in a row.
  if (lid == 0) {
    shared[wid] = val;
  }
  __syncthreads();
  val = (tid < B::Warps()) ? shared[lid] : T(std::numeric_limits<T>::lowest());
  if (wid == 0) {
    val = WarpReduceMax(val);
  }
  return val;
}

template <typename T, class ReduceOp>
__inline__ __device__ T WarpReduce(T val, const ReduceOp& op) {
#pragma unroll
  for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) {
    val = op.combine(val, op.warp_shfl_down(val, offset));
  }
  return val;
}

template <typename T, class ReduceOp, typename B = Block1D>
__inline__ __device__ T
BlockReduce(T val, const ReduceOp& op, const T& identity_element, T* shared) {
  const int tid = B::Tid();
  const int lid = tid % C10_WARP_SIZE;
  const int wid = tid / C10_WARP_SIZE;
  val = WarpReduce(val, op);
  __syncthreads(); // prevent races when BlockReduces are called in a row.
  if (lid == 0) {
    shared[wid] = val;
  }
  __syncthreads();
  val = (tid < B::Warps()) ? shared[lid] : identity_element;
  if (wid == 0) {
    val = WarpReduce(val, op);
  }
  return val;
}
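
// Illustrative sketch (hypothetical op and kernel, not part of the original
// header): the generic BlockReduce only requires an op exposing `combine` and
// `warp_shfl_down`, as used above. A minimal min-reduction could look like:
template <typename T>
struct MinOpSketch {
  __device__ T combine(T a, T b) const {
    return b < a ? b : a;
  }
  __device__ T warp_shfl_down(T val, int offset) const {
    return WARP_SHFL_DOWN(val, offset);
  }
};

template <typename T>
__global__ void BlockMinExampleKernel(const T* in, T* out, int n) {
  __shared__ T shared[kCUDABlockReduceMaxThreads / C10_WARP_SIZE];
  const T identity = std::numeric_limits<T>::max();
  const int tid = threadIdx.x;
  // Out-of-range threads contribute the identity element.
  T v = tid < n ? in[tid] : identity;
  v = BlockReduce(v, MinOpSketch<T>{}, identity, shared);
  if (tid == 0) {
    *out = v; // the result is only valid on thread 0
  }
}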

} // namespace cuda_utils
} // namespace native
} // namespace at