#pragma once

#include <cfloat>
#include <limits>
#include <stdint.h>
#include <cuda_fp16.h>
#include <c10/macros/Macros.h>

#include <ATen/cuda/DeviceUtils.cuh>

namespace {

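// Smallest n such that (1 << n) >= value, i.e. ceil(log2(value)); used below to
// select which template instantiation of the warp softmax kernels to launch.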
int log2_ceil(int value) {
    int log2_value = 0;
    while ((1 << log2_value) < value) ++log2_value;
    return log2_value;
}

template<typename T>
struct Add {
  __device__ __forceinline__ T operator()(T a, T b) const {
    return a + b;
  }
};

template<typename T>
struct Max {
  __device__ __forceinline__ T operator()(T a, T b) const {
    return a < b ? b : a;
  }
};

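// Butterfly reduction over the lanes of a warp using XOR shuffles. `sum` holds
// WARP_BATCH independent partial values per thread; after log2(WARP_SIZE) steps
// every participating lane holds the fully reduced value for each batch.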
template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
    ReduceOp<acc_t> r;
    #pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        #pragma unroll
        for (int i = 0;  i < WARP_BATCH;  ++i) {
            acc_t b = WARP_SHFL_XOR(sum[i], offset, WARP_SIZE);
            sum[i] = r(sum[i], b);
        }
    }
}

// The softmax_warp_* methods perform softmax forward and backward propagation on samples spanning the fast dimension.
// Each sample contains element_count scalar elements. element_count can be any integer value <= 1024.
// The template arguments have the following meaning:
// One "WARP" works on one "BATCH". One "BATCH" contains "WARP_BATCH" samples.
// WARP_BATCH is equal to 1 when element_count is large, and > 1 when element_count is small.
// A "WARP" contains "C10_WARP_SIZE" threads; these threads are guaranteed to belong to the same warp.
// This is important because it means only __shfl_ instructions are required for reductions.
// Note that this means WARP_SIZE must be a power of two and <= architecture warp size.
// CUDA warp size is 32 for all existing GPU architectures, but there is no guarantee this will not change for future architectures.
// ROCm warp size is 64 for all currently ROCm-supported GPU architectures, but this may change for future archs.
// is_log_softmax is a flag indicating whether SoftMax or LogSoftMax should be computed.
// is_masked is a flag indicating whether SoftMax or MaskedSoftMax should be computed.
// The template can be instantiated with any floating point type for the type arguments input_t, output_t and acc_t.
// This allows SoftMax to be fused with a cast immediately following the SoftMax.
// The mask should have the same shape as input, with a boolean indicating whether each value is masked.
// The head_chunk_size is only used for the transformer mask softmax and is equal to H * D * D.
// For instance:
// input_t=half,  acc_t=float, output_t=half  => read half tensor, float accumulators, write half tensor.
// input_t=half,  acc_t=float, output_t=float => read half tensor, float accumulators, write float tensor.
// input_t=float, acc_t=float, output_t=half  => read float tensor, float accumulators, write half tensor.

template <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax, bool is_masked>
__global__ void softmax_warp_forward(output_t *dst, const input_t *src, int batch_size, int stride, int element_count, const bool *mask = nullptr, const int head_chunk_size = -1, bool is_transformer_mask = false)
{
    // WARP_SIZE and WARP_BATCH must match the warp_size and batches_per_warp values computed in dispatch_softmax_forward.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;

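    // first_batch is the first row handled by this warp; within a row each thread owns
    // WARP_ITERATIONS elements strided by WARP_SIZE, kept entirely in registers.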
    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

    // batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to be computed within this WARP.
    int local_batches = batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;
    int idx_offset = first_batch * stride + local_idx;

    src += idx_offset;
    dst += idx_offset;

    if (is_transformer_mask) {
        mask += ((first_batch * stride) / head_chunk_size) * stride + local_idx;
    } else {
        mask += idx_offset;
    }
    // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop,
    // but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep
    // the nested loops.
    // This should have no impact on performance because the loops are unrolled anyway.

    // load data from global memory
    acc_t elements[WARP_BATCH][WARP_ITERATIONS];
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;
        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
            int element_index = local_idx + it * WARP_SIZE;
            if (element_index < batch_element_count) {
                elements[i][it] = src[i*element_count+it*WARP_SIZE];
            } else {
                elements[i][it] = -std::numeric_limits<acc_t>::infinity();
            }
        }
    }

    // compute max_value
    acc_t max_value[WARP_BATCH];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;
        bool is_meaningful_max = false;
        max_value[i] = elements[i][0];
        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
            if (is_masked) {
                int idx = it*WARP_SIZE;
                if ((idx + local_idx) < batch_element_count) {
                    if (!is_transformer_mask) {
                        idx += i*element_count;
                    }
                    if (!mask[idx]) {
                        max_value[i] = (is_meaningful_max && max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
                        is_meaningful_max = true;
                    }
                }
            } else {
                max_value[i] = max_value[i] > elements[i][it] ? max_value[i] : elements[i][it];
            }
        }
        if (is_masked) {
            if (!is_meaningful_max) {
                max_value[i] = -std::numeric_limits<acc_t>::infinity();
            }
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);

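    // Accumulate the softmax denominator. Subtracting the (warp-reduced) row max before
    // exponentiating keeps exp() in a numerically safe range; for log-softmax the
    // exponentials are only summed, not stored back into `elements`.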
    acc_t sum[WARP_BATCH] { 0.0f };
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;
        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
            if (!is_masked) {
                if (is_log_softmax) {
                    sum[i] += std::exp(elements[i][it] - max_value[i]);
                } else {
                    elements[i][it] = std::exp(elements[i][it] - max_value[i]);
                    sum[i] += elements[i][it];
                }
            } else {
                int idx = it*WARP_SIZE;
                bool valid = (idx + local_idx) < batch_element_count;
                if (!is_transformer_mask) {
                    idx += i*element_count;
                }
                if (valid) {
                    if (!mask[idx]) {
                        if (is_log_softmax) {
                            sum[i] += std::exp(elements[i][it] - max_value[i]);
                        } else {
                            elements[i][it] = std::exp(elements[i][it] - max_value[i]);
                            sum[i] += elements[i][it];
                        }
                    } else {
                        if (!is_log_softmax) {
                            // Masked values are treated as -infinity, and std::exp(-infinity) is 0.
                            elements[i][it] = 0;
                        }
                    }
                } else {
                    if (!is_log_softmax) {
                        elements[i][it] = 0.;
                    }
                }
            }
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
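    // For log-softmax the output is x - max - log(sum(exp(x - max))); for softmax it is
    // exp(x - max) / sum. A fully masked row leaves sum == 0, in which case NaN is written
    // for the (non-log) softmax case rather than dividing by zero.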
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        if (i >= local_batches)
            break;
        if (is_log_softmax) sum[i] = std::log(sum[i]);
        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
            int element_index = local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                if (is_log_softmax) {
                    dst[i*element_count+it*WARP_SIZE] = elements[i][it] - max_value[i] - sum[i];
                } else if (sum[i] == 0) {
                    dst[i*element_count+it*WARP_SIZE] = std::numeric_limits<acc_t>::quiet_NaN();
                } else {
                    dst[i*element_count+it*WARP_SIZE] = elements[i][it] / sum[i];
                }
            } else {
                break;
            }
        }
    }
}

template <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax, bool is_masked>
__global__ void softmax_warp_backward(output_t *gradInput, const input_t *grad, const input_t *output, int batch_size, int stride, int element_count, const bool *mask = nullptr)
{
    // WARP_SIZE and WARP_BATCH must match the warp_size and batches_per_warp values computed in dispatch_softmax_backward.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;

    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

    // batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to be computed within this WARP.
    int local_batches = batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x % WARP_SIZE;

    // the first element to process by the current thread
    int thread_offset = first_batch * stride + local_idx;
    grad += thread_offset;
    output += thread_offset;
    gradInput += thread_offset;
    if (is_masked) {
        mask += thread_offset;
    }

    // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop,
    // but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep
    // the nested loops.
    // This should have no impact on performance because the loops are unrolled anyway.

    // load data from global memory
    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];
    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;
        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
            int element_index = local_idx + it * WARP_SIZE;
            if (element_index < batch_element_count) {
                grad_reg[i][it] = grad[i*element_count+it*WARP_SIZE];
                output_reg[i][it] = output[i*element_count+it*WARP_SIZE];
            } else {
                grad_reg[i][it] = acc_t(0);
                output_reg[i][it] = acc_t(0);
            }
        }
    }

    acc_t sum[WARP_BATCH] { 0.0f };
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
            if (!is_masked || !mask[i*element_count+it*WARP_SIZE]) {
                sum[i] += grad_reg[i][it];
            }
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

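    // Gradient formulas applied below, with g = grad_reg, y = output_reg, S = row sum of g:
    //   log-softmax: dx_i = g_i - exp(y_i) * S      (y is the saved log-softmax output)
    //   softmax:     dx_i = g_i - y_i * S
    // For the softmax case this matches the softmax backward pass when the caller passes
    // `grad` already multiplied elementwise by `output`, which is how the host-side softmax
    // backward is expected to invoke this kernel. Masked positions receive a zero gradient.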
    // store result
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
            int element_index = local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                if (is_masked && mask[i*element_count+it*WARP_SIZE]) {
                    gradInput[i*element_count+it*WARP_SIZE] = 0;
                }
                // compute gradients
                else if (is_log_softmax) {
                    gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]);
                } else {
                    gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - output_reg[i][it] * sum[i]);
                }
            }
        }
    }
}

} // end of anonymous namespace

template<typename input_t, typename output_t, typename acc_t, bool is_log_softmax, bool is_masked>
void dispatch_softmax_forward(output_t *dst, const input_t *src, int softmax_elements, int softmax_elements_stride, int batch_count, const bool *mask = nullptr, int chunk_size = -1, bool is_transformer_mask = false)
{
    TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 1024 );
    if (softmax_elements == 0) {
        return;
    } else {
        int log2_elements = log2_ceil(softmax_elements);
        const int next_power_of_two = 1 << log2_elements;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
        int warp_size = at::cuda::warp_size();
        warp_size = (next_power_of_two < warp_size) ? next_power_of_two : warp_size;

        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
        dim3 threads(warp_size, warps_per_block, 1);
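        // Worked example (illustrative sizes, assuming a CUDA warp size of 32):
        // softmax_elements = 300 -> log2_elements = 9, next_power_of_two = 512,
        // warp_size = 32, batches_per_warp = 1, warps_per_block = 128 / 32 = 4,
        // batches_per_block = 4, blocks = ceil(batch_count / 4).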
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            #define LAUNCH_SOFTMAX_WARP_FORWARD(L2E) case L2E:                    \
            softmax_warp_forward<input_t, output_t, acc_t, L2E, is_log_softmax, is_masked>   \
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,   \
                    src, batch_count, softmax_elements_stride, softmax_elements, mask, chunk_size, is_transformer_mask); \
            C10_CUDA_KERNEL_LAUNCH_CHECK();                                       \
            break;

            LAUNCH_SOFTMAX_WARP_FORWARD(0);  // 1
            LAUNCH_SOFTMAX_WARP_FORWARD(1);  // 2
            LAUNCH_SOFTMAX_WARP_FORWARD(2);  // 4
            LAUNCH_SOFTMAX_WARP_FORWARD(3);  // 8
            LAUNCH_SOFTMAX_WARP_FORWARD(4);  // 16
            LAUNCH_SOFTMAX_WARP_FORWARD(5);  // 32
            LAUNCH_SOFTMAX_WARP_FORWARD(6);  // 64
            LAUNCH_SOFTMAX_WARP_FORWARD(7);  // 128
            LAUNCH_SOFTMAX_WARP_FORWARD(8);  // 256
            LAUNCH_SOFTMAX_WARP_FORWARD(9);  // 512
            LAUNCH_SOFTMAX_WARP_FORWARD(10); // 1024
            default:
                break;
        }
    }
}
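
// A minimal caller sketch (hypothetical; the pointer names and surrounding setup are
// illustrative and not part of this header). For a contiguous [batch, dim] float input
// with dim <= 1024, softmax over the last dimension can be dispatched with stride == dim:
//
//   dispatch_softmax_forward<float, float, float, /*is_log_softmax=*/false, /*is_masked=*/false>(
//       out_ptr, in_ptr, /*softmax_elements=*/dim, /*softmax_elements_stride=*/dim, /*batch_count=*/batch);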

template<typename input_t, typename output_t, typename acc_t, bool is_log_softmax, bool is_masked>
void dispatch_softmax_backward(output_t *grad_input, const input_t *grad, const input_t *output, int softmax_elements, int softmax_elements_stride, int batch_count, const bool *mask = nullptr)
{
    TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 1024 );
    if (softmax_elements == 0) {
        return;
    } else {
        int log2_elements = log2_ceil(softmax_elements);
        const int next_power_of_two = 1 << log2_elements;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
        int warp_size = at::cuda::warp_size();
        warp_size = (next_power_of_two < warp_size) ? next_power_of_two : warp_size;

        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            #define LAUNCH_SOFTMAX_WARP_BACKWARD(L2E) case L2E:                      \
            softmax_warp_backward<input_t, output_t, acc_t, L2E, is_log_softmax, is_masked> \
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>       \
                (grad_input, grad, output, batch_count, softmax_elements_stride, \
                softmax_elements, mask);                                              \
            C10_CUDA_KERNEL_LAUNCH_CHECK();                                      \
            break;

            LAUNCH_SOFTMAX_WARP_BACKWARD(0); // 1
            LAUNCH_SOFTMAX_WARP_BACKWARD(1); // 2
            LAUNCH_SOFTMAX_WARP_BACKWARD(2); // 4
            LAUNCH_SOFTMAX_WARP_BACKWARD(3); // 8
            LAUNCH_SOFTMAX_WARP_BACKWARD(4); // 16
            LAUNCH_SOFTMAX_WARP_BACKWARD(5); // 32
            LAUNCH_SOFTMAX_WARP_BACKWARD(6); // 64
            LAUNCH_SOFTMAX_WARP_BACKWARD(7); // 128
            LAUNCH_SOFTMAX_WARP_BACKWARD(8); // 256
            LAUNCH_SOFTMAX_WARP_BACKWARD(9); // 512
            LAUNCH_SOFTMAX_WARP_BACKWARD(10); // 1024
            default:
                break;
        }
    }
}