1 /******************************************************************************
2 * Copyright (c) 2024, Tri Dao.
3 ******************************************************************************/
4 #include <c10/core/ScalarType.h>
5 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
6
7 #include <cstdint>
8 #include <tuple>
9
10
11 #ifdef USE_FLASH_ATTENTION
12 #include <ATen/core/Tensor.h>
13 #include <ATen/cuda/CUDAContext.h>
14 #include <c10/cuda/CUDAGuard.h>
15 #include <ATen/cuda/CUDAGraphsUtils.cuh>
16
17 #ifndef AT_PER_OPERATOR_HEADERS
18 #include <ATen/Functions.h>
19 #include <ATen/NativeFunctions.h>
20 #else
21 #include <ATen/ops/empty.h>
22 #include <ATen/ops/empty_like.h>
23 #include <ATen/ops/reshape.h>
24 #include <ATen/ops/scalar_tensor.h>
25 #include <ATen/ops/sum.h>
26 #include <ATen/ops/slice.h>
27 #include <ATen/ops/narrow.h>
28 #include <ATen/ops/pad.h>
29 #include <ATen/ops/zeros.h>
30 #endif
31
32
33 #include <cutlass/numeric_types.h>
34
35 #include <ATen/native/transformers/cuda/flash_attn/flash.h>
36 #include <ATen/native/transformers/cuda/flash_attn/flash_api.h>
37 #include <ATen/native/transformers/cuda/flash_attn/static_switch.h>
38
39 #include <c10/util/Exception.h>
40
41 namespace pytorch_flash {
42
43 #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
44 #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == at::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
45 #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
46
47
48 void set_params_fprop(Flash_fwd_params &params,
49 // sizes
50 const size_t b,
51 const size_t seqlen_q,
52 const size_t seqlen_k,
53 const size_t seqlen_q_rounded,
54 const size_t seqlen_k_rounded,
55 const size_t h,
56 const size_t h_k,
57 const size_t d,
58 const size_t d_rounded,
59 // device pointers
60 const at::Tensor q,
61 const at::Tensor k,
62 const at::Tensor v,
63 at::Tensor out,
64 void *cu_seqlens_q_d,
65 void *cu_seqlens_k_d,
66 void *seqused_k,
67 void *p_d,
68 void *softmax_lse_d,
69 float p_dropout,
70 float softmax_scale,
71 int window_size_left,
72 int window_size_right,
73 bool seqlenq_ngroups_swapped=false) {
74
75 // Reset the parameters
76 params = {};
77
78 params.is_bf16 = q.dtype() == at::kBFloat16;
79
80 // Set the pointers and strides.
81 params.q_ptr = q.data_ptr();
82 params.k_ptr = k.data_ptr();
83 params.v_ptr = v.data_ptr();
84     // All strides are in elements, not bytes.
85 params.q_row_stride = q.stride(-3);
86 params.k_row_stride = k.stride(-3);
87 params.v_row_stride = v.stride(-3);
88 params.q_head_stride = q.stride(-2);
89 params.k_head_stride = k.stride(-2);
90 params.v_head_stride = v.stride(-2);
91 params.o_ptr = out.data_ptr();
92 params.o_row_stride = out.stride(-3);
93 params.o_head_stride = out.stride(-2);
94
95 if (cu_seqlens_q_d == nullptr) {
96 params.q_batch_stride = q.stride(0);
97 params.k_batch_stride = k.stride(0);
98 params.v_batch_stride = v.stride(0);
99 params.o_batch_stride = out.stride(0);
100 if (seqlenq_ngroups_swapped) {
101 params.q_batch_stride *= seqlen_q;
102 params.o_batch_stride *= seqlen_q;
103 }
104 }
105
106 params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
107 params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
108 params.seqused_k = static_cast<int *>(seqused_k);
109
110 // P = softmax(QK^T)
111 params.p_ptr = p_d;
112
113 // Softmax sum
114 params.softmax_lse_ptr = softmax_lse_d;
115
116 // Set the dimensions.
117 params.b = b;
118 params.h = h;
119 params.h_k = h_k;
120 params.h_h_k_ratio = h / h_k;
121 params.seqlen_q = seqlen_q;
122 params.seqlen_k = seqlen_k;
123 params.seqlen_q_rounded = seqlen_q_rounded;
124 params.seqlen_k_rounded = seqlen_k_rounded;
125 params.d = d;
126 params.d_rounded = d_rounded;
127
128 // Set the different scale values.
129 params.scale_softmax = softmax_scale;
130 params.scale_softmax_log2 = softmax_scale * M_LOG2E;
131
132 // Set this to probability of keeping an element to simplify things.
133 params.p_dropout = 1.f - p_dropout;
134 // Convert p from float to int so we don't have to convert the random uint to float to compare.
135 // [Minor] We want to round down since when we do the comparison we use <= instead of <
136 // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
137 // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
138 params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
139 params.rp_dropout = 1.f / params.p_dropout;
140 params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
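    // Illustrative arithmetic (numbers are an assumption, not from the original source): with a drop
    // probability of 0.2, params.p_dropout (the keep probability) becomes 0.8,
    // p_dropout_in_uint8_t = floor(0.8 * 255) = 204, rp_dropout = 1 / 0.8 = 1.25,
    // and scale_softmax_rp_dropout = 1.25 * softmax_scale.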
141 TORCH_CHECK(p_dropout < 1.f);
142 #ifdef FLASHATTENTION_DISABLE_DROPOUT
143 TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
144 #endif
145
146 // Causal is the special case where window_size_right == 0 and window_size_left < 0.
147 // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
148 params.is_causal = window_size_left < 0 && window_size_right == 0;
149
150 if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; }
151 if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; }
152 params.window_size_left = window_size_left;
153 params.window_size_right = window_size_right;
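    // Worked example (illustrative), assuming seqlen_k = 2048:
    //   (-1, -1)  -> stays (-1, -1): unrestricted (full) attention
    //   (-1,  0)  -> (2048,  0):     causal attention (is_causal was already set above)
    //   (256, -1) -> (256, 2048):    local attention looking back at most 256 keys, unlimited ahead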
154
155 #ifdef FLASHATTENTION_DISABLE_LOCAL
156 TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0),
157 "This flash attention build does not support local attention.");
158 #endif
159
160 params.is_seqlens_k_cumulative = true;
161
162 #ifdef FLASHATTENTION_DISABLE_UNEVEN_K
163 TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
164 #endif
165 }
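// A minimal stride sketch (illustrative, assuming contiguous non-varlen tensors of shape
// (batch, seqlen, nheads, headdim)): for such a layout,
//   q.stride(0)  == seqlen * nheads * headdim   // batch stride (used when cu_seqlens_q_d == nullptr)
//   q.stride(-3) == nheads * headdim            // row stride, one step along seqlen
//   q.stride(-2) == headdim                     // head stride
//   q.stride(-1) == 1                           // contiguous last dim, enforced by the callers below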
166
167 void set_params_dgrad(Flash_bwd_params &params,
168 // sizes
169 const size_t b,
170 const size_t seqlen_q,
171 const size_t seqlen_k,
172 const size_t seqlen_q_rounded,
173 const size_t seqlen_k_rounded,
174 const size_t h,
175 const size_t h_k,
176 const size_t d,
177 const size_t d_rounded,
178 // device pointers
179 const at::Tensor q,
180 const at::Tensor k,
181 const at::Tensor v,
182 const at::Tensor out,
183 const at::Tensor dout,
184 at::Tensor dq,
185 at::Tensor dk,
186 at::Tensor dv,
187 void *cu_seqlens_q_d,
188 void *cu_seqlens_k_d,
189 void *dq_accum_d,
190 void *dk_accum_d,
191 void *dv_accum_d,
192 void *softmax_lse_d,
193 void *dsoftmax_sum_d,
194 float p_dropout,
195 float softmax_scale,
196 int window_size_left,
197 int window_size_right,
198 bool deterministic) {
199
200 set_params_fprop(params,
201 b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
202 q, k, v, out,
203 cu_seqlens_q_d,
204 cu_seqlens_k_d,
205 nullptr,
206 nullptr,
207 softmax_lse_d,
208 p_dropout,
209 softmax_scale,
210 window_size_left,
211 window_size_right);
212
213 // Set the pointers and strides.
214 params.do_ptr = dout.data_ptr();
215 params.do_row_stride = dout.stride(-3);
216 params.do_head_stride = dout.stride(-2);
217 params.dq_ptr = dq.data_ptr();
218 params.dk_ptr = dk.data_ptr();
219 params.dv_ptr = dv.data_ptr();
220 params.dq_row_stride = dq.stride(-3);
221 params.dk_row_stride = dk.stride(-3);
222 params.dv_row_stride = dv.stride(-3);
223 params.dq_head_stride = dq.stride(-2);
224 params.dk_head_stride = dk.stride(-2);
225 params.dv_head_stride = dv.stride(-2);
226
227 if (cu_seqlens_q_d == nullptr) {
228 params.do_batch_stride = dout.stride(0);
229 params.dq_batch_stride = dq.stride(0);
230 params.dk_batch_stride = dk.stride(0);
231 params.dv_batch_stride = dv.stride(0);
232 }
233
234 params.dq_accum_ptr = dq_accum_d;
235 params.dk_accum_ptr = dk_accum_d;
236 params.dv_accum_ptr = dv_accum_d;
237
238 // Softmax sum
239 params.dsoftmax_sum = dsoftmax_sum_d;
240
241 params.deterministic = deterministic;
242 }
243
244 void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
245 FP16_SWITCH(!params.is_bf16, [&] {
246 HEADDIM_SWITCH(params.d, [&] {
247 if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
248 run_mha_fwd_<elem_type, kHeadDim>(params, stream);
249 } else {
250 run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
251 }
252 });
253 });
254 }
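// Rough sketch of what the switch macros above do (see static_switch.h); this is an approximation of
// the expansion, not the exact macro text:
//   if (!params.is_bf16) { using elem_type = cutlass::half_t;     /* body */ }
//   else                 { using elem_type = cutlass::bfloat16_t; /* body */ }
// HEADDIM_SWITCH similarly binds a compile-time kHeadDim bucket (at most 256) covering params.d,
// so each (dtype, head-dim) combination gets its own template instantiation.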
255
256 // Find the number of splits that maximizes the occupancy. For example, if we have
257 // batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
258 // better than having 3 splits (efficiency = 0.67). However, we also don't want too many
259 // splits as that would incur more HBM reads/writes.
260 // So we find the best efficiency, then find the smallest number of splits that gets 85%
261 // of the best efficiency.
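// Worked example for the numbers above: batch * n_heads = 48, 108 SMs ->
//   1 split:  n_waves = 48/108  = 0.44, efficiency = 0.44 / ceil(0.44) = 0.44
//   2 splits: n_waves = 96/108  = 0.89, efficiency = 0.89
//   3 splits: n_waves = 144/108 = 1.33, efficiency = 1.33 / 2 = 0.67
// The best efficiency is 0.89, and 2 is the smallest split count within 85% of it, so we pick 2.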
262 inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
263 // If we have enough to almost fill the SMs, then just use 1 split
264 if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }
265 max_splits = std::min({max_splits, num_SMs, num_n_blocks});
266 float max_efficiency = 0.f;
267 std::vector<float> efficiency;
268 efficiency.reserve(max_splits);
269 auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
270 // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
271 // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
272 // (i.e. it's 11 splits anyway).
273 // So we check if the number of blocks per split is the same as the previous num_splits.
274 auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
275 return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
276 };
277 for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
278 if (!is_split_eligible(num_splits)) {
279 efficiency.push_back(0.f);
280 } else {
281 float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;
282 float eff = n_waves / ceil(n_waves);
283 // printf("num_splits = %d, eff = %f\n", num_splits, eff);
284 if (eff > max_efficiency) { max_efficiency = eff; }
285 efficiency.push_back(eff);
286 }
287 }
288 for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
289 if (!is_split_eligible(num_splits)) { continue; }
290 if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
291 // printf("num_splits chosen = %d\n", num_splits);
292 return num_splits;
293 }
294 }
295 return 1;
296 }
297 std::tuple<at::Tensor, at::Tensor> set_params_splitkv(Flash_fwd_params &params, const int batch_size,
298 const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q,
299 const int head_size_rounded, const float p_dropout,
300 const int num_splits, cudaDeviceProp *dprops, struct c10::TensorOptions opts) {
301
302 // This needs to match with run_mha_fwd_splitkv_dispatch
303 const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
304 const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;
305 // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
306 // In any case we don't expect seqlen_q to be larger than 64 for inference.
307 const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64;
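    // Illustrative sizing (numbers are an assumption): head_size = 128, max_seqlen_k = 8192,
    // max_seqlen_q = 1  ->  block_n = 128, num_n_blocks = 64, num_m_blocks = 1.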
308 params.num_splits = num_splits;
309 at::Tensor softmax_lse_accum;
310 at::Tensor out_accum;
311
312 if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout
313 if (num_splits < 1) {
314 // We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block.
315 params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount * 2, num_n_blocks, 128);
316 }
317 if (params.num_splits > 1) {
318 softmax_lse_accum = at::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
319 out_accum = at::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
320 params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
321 params.oaccum_ptr = out_accum.data_ptr();
322 }
323 TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
324 }
325
326 return std::make_tuple(softmax_lse_accum, out_accum);
327 }
328
329 void set_params_alibi(Flash_fwd_params &params, std::optional<at::Tensor> &alibi_slopes_, int batch_size, int num_heads) {
330 #ifdef FLASHATTENTION_DISABLE_ALIBI
331 TORCH_CHECK(!alibi_slopes_.has_value(), "This flash attention build does not support alibi.");
332 params.alibi_slopes_ptr = nullptr;
333 #else
334 if (alibi_slopes_.has_value()) {
335 auto alibi_slopes = alibi_slopes_.value();
336 TORCH_CHECK(alibi_slopes.dtype() == at::kFloat, "ALiBi slopes must have dtype fp32");
337 CHECK_DEVICE(alibi_slopes);
338 TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
339 TORCH_CHECK(alibi_slopes.sizes() == at::IntArrayRef({num_heads}) || alibi_slopes.sizes() == at::IntArrayRef({batch_size, num_heads}));
340 params.alibi_slopes_ptr = alibi_slopes.data_ptr();
341 params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
342 } else {
343 params.alibi_slopes_ptr = nullptr;
344 }
345 #endif
346 }
347
348 // return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p};
349 std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
350 mha_fwd(const at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
351 const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
352 const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
353 std::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
354 std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
355 const float p_dropout,
356 const float softmax_scale,
357 bool is_causal,
358 int window_size_left,
359 int window_size_right,
360 const bool return_softmax,
361 std::optional<at::Generator> gen_) {
362
363 auto dprops = at::cuda::getCurrentDeviceProperties();
364 // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
365 bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
366 bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
367 TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
368 // We will support Turing in the near future
369 // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
370
371 auto q_dtype = q.dtype();
372 TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
373                 "FlashAttention only supports fp16 and bf16 data types");
374 if (q_dtype == at::kBFloat16) {
375 TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
376 }
377 TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
378 TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
379
380 CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
381
382 TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
383 TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
384 TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
385
386 const auto sizes = q.sizes();
387
388 const int batch_size = sizes[0];
389 int seqlen_q = sizes[1];
390 int num_heads = sizes[2];
391 const int head_size_og = sizes[3];
392 const int seqlen_k = k.size(1);
393 const int num_heads_k = k.size(2);
394 TORCH_CHECK(batch_size > 0, "batch size must be positive");
395 TORCH_CHECK(head_size_og % 8 == 0, "head_size must be a multiple of 8, this is ensured by padding!");
396 TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
397 TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
398
399 if (window_size_left >= seqlen_k) { window_size_left = -1; }
400 if (window_size_right >= seqlen_k) { window_size_right = -1; }
401
402 // causal=true is the same as causal=false in this case
403 if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
404 if (is_causal) { window_size_right = 0; }
405
406 // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
407 // H/t Daniel Haziza
408 const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
409 const int ngroups = num_heads / num_heads_k;
410 at::Tensor temp_q = q;
411 if (seqlenq_ngroups_swapped) {
412 temp_q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
413 seqlen_q = ngroups;
414 num_heads = num_heads_k;
415 }
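    // Illustrative shapes for the swap above: batch_size = 4, num_heads = 32, num_heads_k = 8 gives
    // ngroups = 4, so q goes from (4, 1, 32, d) to (4, 4, 8, d); the kernel then sees seqlen_q = 4 and
    // num_heads = 8, and the output is transposed back at the end of this function.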
416
417 CHECK_SHAPE(temp_q, batch_size, seqlen_q, num_heads, head_size_og);
418 CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
419 CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);
420
421 at::Tensor q_padded, k_padded, v_padded;
422 q_padded = temp_q;
423 k_padded = k;
424 v_padded = v;
425
426 at::Tensor out;
427 if (out_.has_value()) {
428 out = out_.value();
429 TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
430 CHECK_DEVICE(out);
431 TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
432 CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size_og);
433 if (seqlenq_ngroups_swapped) {
434 out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
435 }
436 CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
437 if (head_size_og % 8 != 0) { out = at::empty_like(q_padded); }
438 } else {
439 out = at::empty_like(q_padded);
440 }
441
442 auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
443 const int head_size = round_multiple(head_size_og, 8);
444 const int head_size_rounded = round_multiple(head_size, 32);
445 const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
446 const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
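    // Illustrative rounding: head_size_og = 72 -> head_size = 72, head_size_rounded = 96;
    // seqlen_q = 1000 -> seqlen_q_rounded = 1024; seqlen_k = 4096 -> seqlen_k_rounded = 4096.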
447
448 // Otherwise the kernel will be launched from cuda:0 device
449 // Cast to char to avoid compiler warning about narrowing
450 at::cuda::CUDAGuard device_guard{(char)q.get_device()};
451
452 auto opts = q.options();
453
454 auto softmax_lse = at::empty({batch_size, num_heads, seqlen_q }, opts.dtype(at::kFloat));
455
456 at::Tensor p;
457 // Only return softmax if there's dropout to reduce compilation time
458 if (return_softmax) {
459 TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
460 p = at::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
461 }
462
463 Flash_fwd_params params;
464 set_params_fprop(params,
465 batch_size,
466 seqlen_q, seqlen_k,
467 seqlen_q_rounded, seqlen_k_rounded,
468 num_heads, num_heads_k,
469 head_size, head_size_rounded,
470 q_padded, k_padded, v_padded, out,
471 /*cu_seqlens_q_d=*/nullptr,
472 /*cu_seqlens_k_d=*/nullptr,
473 /*seqused_k=*/nullptr,
474 return_softmax ? p.data_ptr() : nullptr,
475 softmax_lse.data_ptr(),
476 p_dropout,
477 softmax_scale,
478 window_size_left,
479 window_size_right);
480
481
482 // Keep references to these tensors to extend their lifetime
483 at::Tensor softmax_lse_accum, out_accum;
484 std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(params, batch_size, num_heads,
485 head_size, seqlen_k, seqlen_q,
486 head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts);
487
488 // We want to checkpoint and save the RNG state for backward if dropout
489 // We get the default generator and return the seed and offset which will
490 // be used in the backward function
491 at::Tensor seed_t, offset_t;
492 if (p_dropout > 0.0) {
493 auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
494 // number of times random will be generated per thread, to offset philox counter in thc random
495 // state
496 // We use a custom RNG that increases the offset by batch_size * nheads * 32.
497 int64_t counter_offset = params.b * params.h * 32;
498 // See Note [Acquire lock when using random generators]
499 std::lock_guard<std::mutex> lock(gen->mutex_);
500 at::PhiloxCudaState philox_state = gen->philox_cuda_state(counter_offset);
501 if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) {
502 auto [seed, offset] = at::cuda::philox::unpack(philox_state);
503 seed_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(seed)), at::dtype(at::kLong));
504 offset_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(offset)), at::dtype(at::kLong));
505 } else {
506 seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
507 offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
508 params.seed = seed_t.data_ptr<int64_t>();
509 params.extragraph_offset = offset_t.data_ptr<int64_t>();
510 }
511 params.philox_args = philox_state;
512 } else {
513 if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) {
514 seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
515 offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
516 } else {
517 seed_t = at::empty({}, at::dtype(at::kLong));
518 offset_t = at::empty({}, at::dtype(at::kLong));
519 }
520
521 }
522
523 set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
524
525 if (seqlen_k > 0) {
526 auto stream = at::cuda::getCurrentCUDAStream().stream();
527 run_mha_fwd(params, stream);
528 } else {
529 // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
530 out.zero_();
531 softmax_lse.fill_(std::numeric_limits<float>::infinity());
532 }
533
534 if (seqlenq_ngroups_swapped) {
535 out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
536 q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
537 softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
538 }
539 return {out, q_padded, k_padded, v_padded, softmax_lse, seed_t, offset_t, p};
540 }
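// A minimal call sketch (illustrative only; tensor names and values here are assumptions, not taken
// from the original source). q, k, v are fp16/bf16 CUDA tensors of shape
// (batch, seqlen, nheads, headdim) with a contiguous last dimension:
//
//   std::optional<at::Tensor> out_opt, alibi_opt;        // no preallocated output, no ALiBi
//   std::optional<at::Generator> gen_opt;                // default CUDA generator
//   auto [out, q_pad, k_pad, v_pad, lse, seed, offset, p] = pytorch_flash::mha_fwd(
//       q, k, v, out_opt, alibi_opt,
//       /*p_dropout=*/0.f, /*softmax_scale=*/1.f / std::sqrt(float(headdim)),
//       /*is_causal=*/true, /*window_size_left=*/-1, /*window_size_right=*/-1,
//       /*return_softmax=*/false, gen_opt);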
541
542 std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
543 mha_varlen_fwd(const at::Tensor &q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
544 const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
545 const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
546                std::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
547 const at::Tensor &cu_seqlens_q, // b+1
548 const at::Tensor &cu_seqlens_k, // b+1
549 std::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
550 std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
551 std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
552 int max_seqlen_q,
553 const int max_seqlen_k,
554 const float p_dropout,
555 const float softmax_scale,
556 const bool zero_tensors,
557 bool is_causal,
558 int window_size_left,
559 int window_size_right,
560 const bool return_softmax,
561 std::optional<at::Generator> gen_) {
562
563 auto dprops = at::cuda::getCurrentDeviceProperties();
564 // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
565 bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
566 bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
567 TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
568 // We will support Turing in the near future
569 // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
570
571 auto q_dtype = q.dtype();
572 TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
573                 "FlashAttention only supports fp16 and bf16 data types");
574 if (q_dtype == at::kBFloat16) {
575 TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
576 }
577 TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
578 TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
579 TORCH_CHECK(cu_seqlens_q.dtype() == at::kInt, "cu_seqlens_q must have dtype int32");
580 TORCH_CHECK(cu_seqlens_k.dtype() == at::kInt, "cu_seqlens_k must have dtype int32");
581
582 CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
583 CHECK_DEVICE(cu_seqlens_q);
584 CHECK_DEVICE(cu_seqlens_k);
585
586 at::Tensor block_table;
587 const bool paged_KV = block_table_.has_value();
588 if (paged_KV) {
589 block_table = block_table_.value();
590 CHECK_DEVICE(block_table);
591 TORCH_CHECK(block_table.dtype() == at::kInt, "block_table must have dtype torch.int32");
592 TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
593 }
594
595 TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
596 TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
597 TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
598 CHECK_CONTIGUOUS(cu_seqlens_q);
599 CHECK_CONTIGUOUS(cu_seqlens_k);
600
601 const auto sizes = q.sizes();
602
603 const int batch_size = cu_seqlens_q.numel() - 1;
604 int num_heads = sizes[1];
605 const int head_size_og = sizes[2];
606 const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
607
608 const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
609 const int num_blocks = !paged_KV ? 0 : k.size(0);
610 const int page_block_size = !paged_KV ? 1 : k.size(1);
611 TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");
612
613 if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case
614 if (is_causal) { window_size_right = 0; }
615
616 void *cu_seqlens_q_d = cu_seqlens_q.data_ptr();
617
618 // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
619 // H/t Daniel Haziza
620 const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
621 at::Tensor temp_q = q;
622 const int ngroups = num_heads / num_heads_k;
623 if (seqlenq_ngroups_swapped) {
624 temp_q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
625 max_seqlen_q = ngroups;
626 num_heads = num_heads_k;
627 cu_seqlens_q_d = nullptr;
628 }
629
630 const int total_q = temp_q.sizes()[0];
631
632 TORCH_CHECK(batch_size > 0, "batch size must be positive");
633 TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
634 TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
635     TORCH_CHECK(head_size_og % 8 == 0, "head_size must be a multiple of 8, this is ensured by padding!");
636
637 if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
638 if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
639
640 CHECK_SHAPE(temp_q, total_q, num_heads, head_size_og);
641 if (!paged_KV) {
642 const int total_k = k.size(0);
643 CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
644 CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
645 } else {
646 CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size_og);
647 CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size_og);
648 CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
649 }
650 CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
651 CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
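    // Illustrative example: per-sequence lengths {3, 5, 2} give cu_seqlens_q = cu_seqlens_k = {0, 3, 8, 10}
    // (int32, length batch_size + 1 = 4) and total_q = total_k = 10 packed rows.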
652 if (seqused_k.has_value()){
653 auto seqused_k_ = seqused_k.value();
654 TORCH_CHECK(seqused_k_.dtype() == at::kInt, "seqused_k must have dtype int32");
655 TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
656 TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
657 CHECK_SHAPE(seqused_k_, batch_size);
658 }
659
660 at::Tensor q_padded, k_padded, v_padded;
661 q_padded = temp_q;
662 k_padded = k;
663 v_padded = v;
664
665 at::Tensor out;
666 if (out_.has_value()) {
667 out = out_.value();
668 TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
669 CHECK_DEVICE(out);
670 TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
671         CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og);
672         if (seqlenq_ngroups_swapped) {
673             out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
674         }
675         CHECK_SHAPE(out, total_q, num_heads, head_size_og);
676 if (head_size_og % 8 != 0) { out = at::empty_like(q_padded); }
677 } else {
678 out = at::empty_like(q_padded);
679 }
680
681 auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
682 const int head_size = round_multiple(head_size_og, 8);
683 const int head_size_rounded = round_multiple(head_size, 32);
684 const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
685 const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
686
687 // Otherwise the kernel will be launched from cuda:0 device
688 // Cast to char to avoid compiler warning about narrowing
689 at::cuda::CUDAGuard device_guard{(char)q.get_device()};
690
691 auto opts = q.options();
692
693 auto softmax_lse = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
694 at::Tensor p;
695 // Only return softmax if there's dropout to reduce compilation time
696 if (return_softmax) {
697 TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
698 p = at::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
699 }
700
701 if (zero_tensors) {
702 out.zero_();
703 softmax_lse.fill_(-std::numeric_limits<float>::infinity());
704 if (return_softmax) {p.zero_();}
705 }
706
707 Flash_fwd_params params;
708 set_params_fprop(params,
709 batch_size,
710 max_seqlen_q, max_seqlen_k,
711 seqlen_q_rounded, seqlen_k_rounded,
712 num_heads, num_heads_k,
713 head_size, head_size_rounded,
714 q_padded, k_padded, v_padded, out,
715 cu_seqlens_q_d,
716 cu_seqlens_k.data_ptr(),
717 seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
718 return_softmax ? p.data_ptr() : nullptr,
719 softmax_lse.data_ptr(),
720 p_dropout,
721 softmax_scale,
722 window_size_left,
723 window_size_right,
724 seqlenq_ngroups_swapped);
725 if (paged_KV) {
726 params.block_table = block_table.data_ptr<int>();
727 params.block_table_batch_stride = block_table.stride(0);
728 params.k_batch_stride = k_padded.stride(0);
729 params.v_batch_stride = v_padded.stride(0);
730 }
731 params.page_block_size = page_block_size;
732 // Keep references to these tensors to extend their lifetime
733 at::Tensor softmax_lse_accum, out_accum;
734 if (seqlenq_ngroups_swapped) {
735 // Only apply split-k for decoding
736 std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(params, batch_size, num_heads,
737 head_size, max_seqlen_k, max_seqlen_q,
738 head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts);
739 }
740
741 // We want to checkpoint and save the RNG state for backward if dropout
742 // We get the default generator and return the seed and offset which will
743 // be used in the backward function
744 auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
745 at::Tensor seed_t, offset_t;
746 if (p_dropout > 0.0) {
747 // number of times random will be generated per thread, to offset philox counter in thc random
748 // state
749 // We use a custom RNG that increases the offset by batch_size * nheads * 32.
750 int64_t counter_offset = params.b * params.h * 32;
751 // See Note [Acquire lock when using random generators]
752 std::lock_guard<std::mutex> lock(gen->mutex_);
753 at::PhiloxCudaState philox_state = gen->philox_cuda_state(counter_offset);
754 if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) {
755 auto [seed, offset] = at::cuda::philox::unpack(philox_state);
756 seed_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(seed)), at::dtype(at::kLong));
757 offset_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(offset)), at::dtype(at::kLong));
758 } else {
759 seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
760 offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
761 params.seed = seed_t.data_ptr<int64_t>();
762 params.extragraph_offset = offset_t.data_ptr<int64_t>();
763 }
764 params.philox_args = philox_state;
765 } else {
766 if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) {
767 seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
768 offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
769 } else {
770 seed_t = at::empty({}, at::dtype(at::kLong));
771 offset_t = at::empty({}, at::dtype(at::kLong));
772 }
773
774 }
775
776 set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
777
778 if (max_seqlen_k > 0) {
779 auto stream = at::cuda::getCurrentCUDAStream().stream();
780 run_mha_fwd(params, stream, paged_KV);
781 } else {
782 // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
783 out.zero_();
784 softmax_lse.fill_(std::numeric_limits<float>::infinity());
785 }
786
787 if (seqlenq_ngroups_swapped) {
788 std::array<int64_t, 4> size_before = {batch_size, max_seqlen_q, num_heads_k, head_size_og};
789 std::array<int64_t, 3> size_after = {batch_size, num_heads_k * max_seqlen_q, head_size_og};
790 out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
791 q_padded = q_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
792 softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * max_seqlen_q, 1});
793 }
794
795 return {out, q_padded, k_padded, v_padded, softmax_lse, seed_t, offset_t, p};
796 }
797
798 void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
799 FP16_SWITCH(!params.is_bf16, [&] {
800 HEADDIM_SWITCH(params.d, [&] {
801 run_mha_bwd_<elem_type, kHeadDim>(params, stream);
802 });
803 });
804 }
805
806 std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
807 mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads x head_size_og
808 const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
809 const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
810 const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
811 const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
812 const at::Tensor &softmax_lse, // b x h x seqlen_q
813 std::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
814 std::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
815 std::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
816 std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
817 const float p_dropout, // probability to drop
818 const float softmax_scale,
819 const bool is_causal,
820 int window_size_left,
821 int window_size_right,
822 const bool deterministic,
823 const at::Tensor philox_seed,
824 const at::Tensor philox_offset) {
825
826 #ifdef FLASHATTENTION_DISABLE_BACKWARD
827 TORCH_CHECK(false, "This flash attention build does not support backward.");
828 #endif
829 if (is_causal) { window_size_right = 0; }
830 auto dprops = at::cuda::getCurrentDeviceProperties();
831 // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
832 bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
833 bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
834 bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
835 TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
836 // We will support Turing in the near future
837 // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
838
839 bool is_dropout = p_dropout > 0.0;
840 auto stream = at::cuda::getCurrentCUDAStream().stream();
841
842 auto q_dtype = q.dtype();
843 TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
844                 "FlashAttention only supports fp16 and bf16 data types");
845 if (q_dtype == at::kBFloat16) {
846 TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
847 }
848 TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
849 TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
850 TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
851 TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
852
853 CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
854 CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
855
856 TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
857 TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
858 TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
859 TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
860 TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
861
862 const auto sizes = q.sizes();
863
864 const int batch_size = sizes[0];
865 const int seqlen_q = sizes[1];
866 const int num_heads = sizes[2];
867 const int head_size_og = dout.size(3);
868 const int head_size = sizes[3];
869 const int seqlen_k = k.size(1);
870 const int num_heads_k = k.size(2);
871 TORCH_CHECK(batch_size > 0, "batch size must be positive");
872 TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
873 TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!");
874 TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
875 if (head_size > 192 && (head_size <= 224 || is_dropout)) {
876 TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800");
877 }
878 TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
879
880 auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
881 const int head_size_rounded = round_multiple(head_size, 32);
882 const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
883 const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
884
885 TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
886
887 if (window_size_left >= seqlen_k) { window_size_left = -1; }
888 if (window_size_right >= seqlen_k) { window_size_right = -1; }
889
890 CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
891 CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
892 CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
893 CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
894 CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og);
895
896 at::Tensor dq, dk, dv;
897 if (dq_.has_value()) {
898 dq = dq_.value();
899 TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
900 CHECK_DEVICE(dq);
901 TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
902 CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
903 } else {
904 dq = at::empty_like(q);
905 }
906 if (dk_.has_value()) {
907 dk = dk_.value();
908 TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
909 CHECK_DEVICE(dk);
910 TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
911 CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
912 } else {
913 dk = at::empty_like(k);
914 }
915 if (dv_.has_value()) {
916 dv = dv_.value();
917 TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
918 CHECK_DEVICE(dv);
919 TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
920 CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
921 } else {
922 dv = at::empty_like(v);
923 }
924
925 // bool loop = seqlen_k > blocksize_c;
926 // TODO: change later, for now set to true for simplicity
927 bool loop = true;
928
929 // Otherwise the kernel will be launched from cuda:0 device
930 // Cast to char to avoid compiler warning about narrowing
931 at::cuda::CUDAGuard device_guard{(char)q.get_device()};
932
933 auto opts = q.options();
934 auto softmax_d = at::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
935 at::Tensor dq_accum;
936 at::Tensor dk_accum, dv_accum;
937 if (loop) {
938 if (!deterministic) {
939 dq_accum = at::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
940 } else {
941 const int nsplits = (dprops->multiProcessorCount + batch_size * num_heads - 1) / (batch_size * num_heads);
942 dq_accum = at::zeros({nsplits, batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
943 }
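        // Illustrative (numbers are an assumption): with 108 SMs and batch_size * num_heads = 8,
        // nsplits = ceil(108 / 8) = 14. Each split gets its own dq_accum slice
        // (see dq_accum_split_stride below), so concurrently running thread blocks never
        // atomically accumulate into the same memory, which keeps the result deterministic.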
944 // dk_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
945 // dv_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
946 }
947
948 at::Tensor dk_expanded, dv_expanded;
949 if (num_heads_k != num_heads) { // MQA / GQA
950 dk_expanded = at::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
951 dv_expanded = at::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
952 } else {
953 dk_expanded = dk;
954 dv_expanded = dv;
955 }
956
957 Flash_bwd_params params;
958
959 set_params_dgrad(params,
960 batch_size,
961 seqlen_q, seqlen_k,
962 seqlen_q_rounded, seqlen_k_rounded,
963 num_heads, num_heads_k,
964 head_size, head_size_rounded,
965 q, k, v, out,
966 dout, dq, dk_expanded, dv_expanded,
967 nullptr,
968 nullptr,
969 loop ? dq_accum.data_ptr() : nullptr,
970 // loop ? dk_accum.data_ptr() : nullptr,
971 // loop ? dv_accum.data_ptr() : nullptr,
972 nullptr,
973 nullptr,
974 softmax_lse.data_ptr(),
975 softmax_d.data_ptr(),
976 p_dropout,
977 softmax_scale,
978 window_size_left,
979 window_size_right,
980 deterministic);
981 params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
982
983 auto launch = &run_mha_bwd;
984
985 at::PhiloxCudaState philox_args;
986 if (is_dropout) {
987 if (at::cuda::currentStreamCaptureStatus() ==
988 at::cuda::CaptureStatus::None)
989 {
990 philox_args = at::PhiloxCudaState(*philox_seed.data_ptr<int64_t>(), *philox_offset.data_ptr<int64_t>());
991 } else { // dropout + capture
992 philox_args = at::PhiloxCudaState(
993 philox_seed.data_ptr<int64_t>(), philox_offset.data_ptr<int64_t>(), 0);
994 }
995 }
996 params.philox_args = philox_args;
997
998 set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
999
1000 if (seqlen_q > 0) {
1001 launch(params, stream);
1002 } else {
1003 // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
1004 dk_expanded.zero_();
1005 dv_expanded.zero_();
1006 softmax_d.zero_();
1007 }
1008
1009 // For MQA/GQA we need to sum dK and dV across the groups
1010 if (num_heads_k != num_heads) {
1011 at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
1012 at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
1013 }
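    // Illustrative shapes for the reduction above: num_heads = 32, num_heads_k = 8 ->
    // dk_expanded (b, seqlen_k, 32, d) is viewed as (b, seqlen_k, 8, 4, d) and summed over the
    // group dimension (dim 3), producing dk of shape (b, seqlen_k, 8, d); likewise for dv.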
1014 return { dq, dk, dv, softmax_d };
1015 }
1016
1017 std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
1018 mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads x head_size
1019 const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
1020 const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
1021 const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
1022 const at::Tensor &out, // total_q x num_heads x head_size
1023 const at::Tensor &softmax_lse, // b x h x s softmax logsumexp
1024 std::optional<at::Tensor> &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
1025 std::optional<at::Tensor> &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
1026 std::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
1027 const at::Tensor &cu_seqlens_q, // b+1
1028 const at::Tensor &cu_seqlens_k, // b+1
1029 std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
1030 const int max_seqlen_q,
1031 const int max_seqlen_k, // max sequence length to choose the kernel
1032 const float p_dropout, // probability to drop
1033 const float softmax_scale,
1034 const bool zero_tensors,
1035 const bool is_causal,
1036 int window_size_left,
1037 int window_size_right,
1038 const bool deterministic,
1039 const at::Tensor philox_seed,
1040 const at::Tensor philox_offset)
1041 {
1042
1043 #ifdef FLASHATTENTION_DISABLE_BACKWARD
1044 TORCH_CHECK(false, "This flash attention build does not support backward.");
1045 #endif
1046
1047 if (is_causal) { window_size_right = 0; }
1048 auto dprops = at::cuda::getCurrentDeviceProperties();
1049 // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
1050 bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
1051 bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
1052 bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
1053 TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
1054 // We will support Turing in the near future
1055 // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
1056 bool is_dropout = p_dropout > 0.0;
1057 auto stream = at::cuda::getCurrentCUDAStream().stream();
1058
1059 auto q_dtype = q.dtype();
1060 TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
1061                 "FlashAttention only supports fp16 and bf16 data types");
1062 if (q_dtype == at::kBFloat16) {
1063 TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
1064 }
1065 TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
1066 TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
1067 TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
1068 TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
1069 TORCH_CHECK(cu_seqlens_q.dtype() == at::kInt, "cu_seqlens_q must have dtype int32");
1070 TORCH_CHECK(cu_seqlens_k.dtype() == at::kInt, "cu_seqlens_k must have dtype int32");
1071
1072 CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
1073 CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
1074 CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k);
1075
1076 TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
1077 TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
1078 TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
1079 TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
1080 TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
1081 CHECK_CONTIGUOUS(cu_seqlens_q);
1082 CHECK_CONTIGUOUS(cu_seqlens_k);
1083
1084 const auto sizes = q.sizes();
1085
1086 const int total_q = sizes[0];
1087 const int batch_size = cu_seqlens_q.numel() - 1;
1088 const int num_heads = sizes[1];
1089 const int head_size_og = dout.size(2);
1090 const int head_size = sizes[2];
1091 const int total_k = k.size(0);
1092 const int num_heads_k = k.size(1);
1093 TORCH_CHECK(batch_size > 0, "batch size must be positive");
1094 TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
1095 TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!");
1096 TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
1097 if (head_size > 192 && (head_size <= 224 || is_dropout)) {
1098 TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800");
1099 }
1100 TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
1101
1102 auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
1103 const int head_size_rounded = round_multiple(head_size, 32);
1104 const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
1105 const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
1106
1107 TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
1108
1109 if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
1110 if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
1111
1112 CHECK_SHAPE(q, total_q, num_heads, head_size);
1113 CHECK_SHAPE(k, total_k, num_heads_k, head_size);
1114 CHECK_SHAPE(v, total_k, num_heads_k, head_size);
1115 CHECK_SHAPE(out, total_q, num_heads, head_size);
1116 CHECK_SHAPE(dout, total_q, num_heads, head_size_og);
1117 CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
1118 CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
1119
1120 at::Tensor dq, dk, dv;
1121 if (dq_.has_value()) {
1122 dq = dq_.value();
1123 TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
1124 CHECK_DEVICE(dq);
1125 TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
1126 CHECK_SHAPE(dq, total_q, num_heads, head_size);
1127 } else {
1128 dq = at::empty_like(q);
1129 }
1130 if (dk_.has_value()) {
1131 dk = dk_.value();
1132 TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
1133 CHECK_DEVICE(dk);
1134 TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
1135 CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
1136 } else {
1137 dk = at::empty_like(k);
1138 }
1139 if (dv_.has_value()) {
1140 dv = dv_.value();
1141 TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
1142 CHECK_DEVICE(dv);
1143 TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
1144 CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
1145 } else {
1146 dv = at::empty_like(v);
1147 }
1148
1149 // bool loop = max_seqlen_k > blocksize_c;
1150 // TODO: change later, for now set to true for simplicity
1151 bool loop = true;
1152
1153 // Otherwise the kernel will be launched from cuda:0 device
1154 // Cast to char to avoid compiler warning about narrowing
1155 at::cuda::CUDAGuard device_guard{(char)q.get_device()};
1156
1157 auto opts = q.options();
1158 auto softmax_d = at::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
1159 at::Tensor dq_accum;
1160 if (loop) {
1161 // We don't want to allocate dq_accum of size (batch, seqlen_q_rounded, num_heads, head_size_rounded)
1162 // because that would be too large if there is a very long sequence and the rest of the sequences are short.
1163 // Instead, we allocate dq_accum of size (total_q + 128 * batch, num_heads, head_size_rounded).
1164 // Note that 128 is the max block size on the seqlen_q dimension.
1165 // For dQ, the i-th sequence is stored in indices from cu_seqlens[i] + 128 * i to
1166     // cu_seqlens[i + 1] + 128 * i - 1. This ensures that the i-th sequence and (i + 1)-th sequence will
1167 // be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally
1168 // allowed to do. So we won't have to do any bound checking, and performance should stay the same.
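    // Illustrative layout (numbers are an assumption): cu_seqlens_q = {0, 100, 400} (batch_size = 2) ->
    // sequence 0 occupies rows [0, 99] and sequence 1 rows [228, 527] of dq_accum, which has
    // total_q + 128 * batch_size = 400 + 256 = 656 rows, so each sequence can spill up to 128 rows
    // past its end without touching the next sequence or running off the allocation.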
1169 if (!deterministic) {
1170 dq_accum = at::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
1171 } else {
1172 const int nsplits = (dprops->multiProcessorCount + batch_size * num_heads - 1) / (batch_size * num_heads);
1173 dq_accum = at::zeros({nsplits, total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
1174 }
1175 }
1176
1177 at::Tensor dk_expanded, dv_expanded;
1178 if (num_heads_k != num_heads) { // MQA / GQA
1179 dk_expanded = at::empty({total_k, num_heads, head_size}, opts);
1180 dv_expanded = at::empty({total_k, num_heads, head_size}, opts);
1181 } else {
1182 dk_expanded = dk;
1183 dv_expanded = dv;
1184 }
1185
1186     if (zero_tensors) {
1187 dq.zero_();
1188 dk_expanded.zero_();
1189 dv_expanded.zero_();
1190 softmax_d.zero_();
1191 }

  Flash_bwd_params params;

  set_params_dgrad(params,
                   batch_size,
                   max_seqlen_q, max_seqlen_k,
                   seqlen_q_rounded, seqlen_k_rounded,
                   num_heads, num_heads_k,
                   head_size, head_size_rounded,
                   q, k, v, out,
                   dout, dq, dk_expanded, dv_expanded,
                   cu_seqlens_q.data_ptr(),
                   cu_seqlens_k.data_ptr(),
                   loop ? dq_accum.data_ptr() : nullptr,
                   nullptr,
                   nullptr,
                   softmax_lse.data_ptr(),
                   softmax_d.data_ptr(),
                   p_dropout,
                   softmax_scale,
                   window_size_left,
                   window_size_right,
                   deterministic);
  params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);

  auto launch = &run_mha_bwd;

  at::PhiloxCudaState philox_args;
  if (is_dropout) {
    if (at::cuda::currentStreamCaptureStatus() ==
        at::cuda::CaptureStatus::None)
    {
      philox_args = at::PhiloxCudaState(*philox_seed.data_ptr<int64_t>(), *philox_offset.data_ptr<int64_t>());
    } else { // dropout + capture
      philox_args = at::PhiloxCudaState(
          philox_seed.data_ptr<int64_t>(), philox_offset.data_ptr<int64_t>(), 0);
    }
  }
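  // Note on the two branches above: outside CUDA graph capture the seed/offset values are read on
  // the host and stored in philox_args by value; during capture only the device pointers are
  // recorded, so the kernel reads the actual values at replay time.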
  params.philox_args = philox_args;

  set_params_alibi(params, alibi_slopes_, batch_size, num_heads);

  if (max_seqlen_q > 0) {
    launch(params, stream);
  } else {
    // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
    dk_expanded.zero_();
    dv_expanded.zero_();
    softmax_d.zero_();
  }

  // For MQA/GQA we need to sum dK and dV across the groups
  if (num_heads_k != num_heads) {
    at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
    at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
  }
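  // Worked example (hypothetical sizes): with num_heads = 32 and num_heads_k = 8, dk_expanded has
  // shape (total_k, 32, head_size); viewing it as (total_k, 8, 4, head_size) and summing over dim 2
  // folds the gradients of the 4 query heads that share each KV head back into a single dK
  // (and likewise for dV).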

  return { dq, dk, dv, softmax_d };
}

std::tuple<at::Tensor, at::Tensor>
mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
                const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
                const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
                std::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
                std::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
                std::optional<const at::Tensor> &seqlens_k_, // batch_size
                std::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
                std::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
                std::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
                std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
                std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
                std::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
                const float softmax_scale,
                bool is_causal,
                int window_size_left,
                int window_size_right,
                bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
                int num_splits
                ) {

  auto dprops = at::cuda::getCurrentDeviceProperties();
  // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
  bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
  bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
  TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
  // We will support Turing in the near future
  // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");

  auto q_dtype = q.dtype();
  TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
              "FlashAttention only supports fp16 and bf16 data types");
  if (q_dtype == at::kBFloat16) {
    TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
  }
  TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
  TORCH_CHECK(vcache.dtype() == q_dtype, "query and value must have the same dtype");

  CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);

  TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");

  at::Tensor block_table;
  const bool paged_KV = block_table_.has_value();
  if (paged_KV) {
    TORCH_CHECK(!cache_batch_idx_.has_value(), "Paged KV cache does not support cache_batch_idx");
    block_table = block_table_.value();
    CHECK_DEVICE(block_table);
    TORCH_CHECK(block_table.dtype() == at::kInt, "block_table must have dtype torch.int32");
    TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
  }

  const auto sizes = q.sizes();

  const int batch_size = sizes[0];
  int seqlen_q = sizes[1];
  int num_heads = sizes[2];
  const int head_size_og = sizes[3];

  const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
  const int num_blocks = !paged_KV ? 0 : kcache.size(0);
  const int page_block_size = !paged_KV ? 1 : kcache.size(1);
  TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");
  const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size;
  const int num_heads_k = kcache.size(2);
  const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size;
  TORCH_CHECK(batch_size > 0, "batch size must be positive");
  TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
  TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

  // causal=true is the same as causal=false in this case
  if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
  if (is_causal) { window_size_right = 0; }
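  // Setting window_size_right = 0 expresses the causal mask as a special case of the sliding-window
  // mask: each query position may attend to keys at or before its own (bottom-right aligned)
  // position and to nothing after it.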

  // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
  // H/t Daniel Haziza
  const bool seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
  if (seqlenq_ngroups_swapped) {
    const int ngroups = num_heads / num_heads_k;
    q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
    seqlen_q = ngroups;
    num_heads = num_heads_k;
  }
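  // Worked example (hypothetical sizes): for a decode step with q of shape (b, 1, 32, d) and
  // num_heads_k = 8, ngroups = 4 and q is viewed as (b, 4, 8, d), i.e. the 4 query heads sharing
  // each KV head are treated as 4 query positions of a single head, which keeps more of the GPU busy.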

  if (window_size_left >= seqlen_k) { window_size_left = -1; }
  if (window_size_right >= seqlen_k) { window_size_right = -1; }

  CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
  if (!paged_KV) {
    CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
    CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
  } else {
    CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
    CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
    CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
  }

  at::Tensor q_padded, kcache_padded, vcache_padded;
  if (head_size_og % 8 != 0) {
    q_padded = at::pad(q, {0, 8 - head_size_og % 8});
    kcache_padded = at::pad(kcache, {0, 8 - head_size_og % 8});
    vcache_padded = at::pad(vcache, {0, 8 - head_size_og % 8});
    // q_padded = at::nn::functional::pad(q, at::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
    // kcache_padded = at::nn::functional::pad(kcache, at::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
    // vcache_padded = at::nn::functional::pad(vcache, at::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  } else {
    q_padded = q;
    kcache_padded = kcache;
    vcache_padded = vcache;
  }
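  // Example of the padding above (hypothetical head size): head_size_og = 100 gives
  // 8 - 100 % 8 = 4 zero columns, so q/kcache/vcache are padded to a head size of 104 before the
  // kernel runs; the extra columns are sliced off again after the call.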

  at::Tensor out;
  if (out_.has_value()) {
    out = out_.value();
    TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
    CHECK_DEVICE(out);
    TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
    CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
    if (head_size_og % 8 != 0) { out = at::empty_like(q_padded); }
  } else {
    out = at::empty_like(q_padded);
  }

  auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  const int head_size = round_multiple(head_size_og, 8);
  const int head_size_rounded = round_multiple(head_size, 32);
  const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
  const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
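  // Example of the rounding above (hypothetical sizes): head_size_og = 100 -> head_size = 104 and
  // head_size_rounded = 128; seqlen_q = 1 -> seqlen_q_rounded = 128; seqlen_k = 4000 ->
  // seqlen_k_rounded = 4096.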

  // Otherwise the kernel will be launched from cuda:0 device
  // Cast to char to avoid compiler warning about narrowing
  at::cuda::CUDAGuard device_guard{(char)q.get_device()};

  auto opts = q.options();

  auto softmax_lse = at::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));

  Flash_fwd_params params;
  set_params_fprop(params,
                   batch_size,
                   seqlen_q, seqlen_k,
                   seqlen_q_rounded, seqlen_k_rounded,
                   num_heads, num_heads_k,
                   head_size, head_size_rounded,
                   q_padded, kcache_padded, vcache_padded, out,
                   /*cu_seqlens_q_d=*/nullptr,
                   /*cu_seqlens_k_d=*/nullptr,
                   /*seqused_k=*/nullptr,
                   /*p_ptr=*/nullptr,
                   softmax_lse.data_ptr(),
                   /*p_dropout=*/0.f,
                   softmax_scale,
                   window_size_left,
                   window_size_right);

  at::Tensor k, v, k_padded, v_padded;
  if (k_.has_value()) {
    TORCH_CHECK(v_.has_value(), "If key is supplied, value must also be passed in");
    TORCH_CHECK(seqlens_k_.has_value(), "If key is supplied, seqlens_k must also be passed in");
    TORCH_CHECK(seqlen_q <= seqlen_k, "If key is supplied, it must have seqlen <= the seqlen of the KV cache");
    k = k_.value();
    v = v_.value();
    TORCH_CHECK(k.dtype() == q_dtype, "Key must have the same dtype as query");
    TORCH_CHECK(v.dtype() == q_dtype, "Value must have the same dtype as query");
    CHECK_DEVICE(k); CHECK_DEVICE(v);
    TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension");
    TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension");
    int seqlen_knew = k.size(1);
    CHECK_SHAPE(k, batch_size, seqlen_knew, num_heads_k, head_size_og);
    CHECK_SHAPE(v, batch_size, seqlen_knew, num_heads_k, head_size_og);
    if (head_size_og % 8 != 0) {
      k_padded = at::pad(k, {0, 8 - head_size_og % 8});
      v_padded = at::pad(v, {0, 8 - head_size_og % 8});
      // k_padded = at::nn::functional::pad(k, at::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
      // v_padded = at::nn::functional::pad(v, at::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
    } else {
      k_padded = k;
      v_padded = v;
    }
    params.seqlen_knew = seqlen_knew;
    params.knew_ptr = k_padded.data_ptr();
    params.vnew_ptr = v_padded.data_ptr();
    // All stride are in elements, not bytes.
    params.knew_batch_stride = k_padded.stride(0);
    params.vnew_batch_stride = v_padded.stride(0);
    params.knew_row_stride = k_padded.stride(-3);
    params.vnew_row_stride = v_padded.stride(-3);
    params.knew_head_stride = k_padded.stride(-2);
    params.vnew_head_stride = v_padded.stride(-2);
  }

  if (seqlens_k_.has_value()) {
    auto seqlens_k = seqlens_k_.value();
    TORCH_CHECK(seqlens_k.dtype() == at::kInt, "seqlens_k must have dtype int32");
    CHECK_DEVICE(seqlens_k);
    CHECK_CONTIGUOUS(seqlens_k);
    CHECK_SHAPE(seqlens_k, batch_size);
    params.cu_seqlens_k = static_cast<int *>(seqlens_k.data_ptr());
  }
  params.is_seqlens_k_cumulative = !(seqlens_k_.has_value());
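  // When seqlens_k_ is given it holds the current length of each cache entry (one value per batch
  // element) rather than cumulative offsets, hence is_seqlens_k_cumulative is set to false above;
  // without it the kernel treats every cache entry as already filled up to seqlen_k.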

  if (rotary_cos_.has_value()) {
    TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");
    auto rotary_cos = rotary_cos_.value();
    CHECK_DEVICE(rotary_cos);
    params.rotary_dim = rotary_cos.size(1) * 2;
    TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim");
    TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported");
    const int seqlen_ro = rotary_cos.size(0);
    TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache");
    CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2);
    CHECK_CONTIGUOUS(rotary_cos);
    TORCH_CHECK(rotary_cos.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query");

    TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided");
    auto rotary_sin = rotary_sin_.value();
    CHECK_DEVICE(rotary_sin);
    CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2);
    CHECK_CONTIGUOUS(rotary_sin);
    TORCH_CHECK(rotary_sin.scalar_type() == q_dtype, "rotary_sin must have the same dtype as query");
    params.rotary_cos_ptr = rotary_cos.data_ptr();
    params.rotary_sin_ptr = rotary_sin.data_ptr();
    params.is_rotary_interleaved = is_rotary_interleaved;
  } else {
    params.rotary_dim = 0;
  }
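  // Rotary layout assumed above: rotary_cos/rotary_sin have shape (seqlen_ro, rotary_dim / 2), so
  // e.g. rotary_dim = 64 means 32 cos/sin values per position. With is_rotary_interleaved the
  // rotation pairs dimensions (0, 1), (2, 3), ...; otherwise it pairs (0, rotary_dim / 2),
  // (1, rotary_dim / 2 + 1), ... (GPT-NeoX style).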

  if (cache_batch_idx_.has_value()) {
    auto cache_batch_idx = cache_batch_idx_.value();
    CHECK_DEVICE(cache_batch_idx);
    CHECK_CONTIGUOUS(cache_batch_idx);
    TORCH_CHECK(cache_batch_idx.scalar_type() == at::kInt, "cache_batch_idx must have dtype int32");
    params.cache_batch_idx = reinterpret_cast<int *>(cache_batch_idx.data_ptr());
  }

  // Keep references to these tensors to extend their lifetime
  at::Tensor softmax_lse_accum, out_accum;
  std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(params, batch_size, num_heads,
      head_size, seqlen_k, seqlen_q,
      head_size_rounded, /*dropout*/0.f, num_splits, dprops, opts);

  if (paged_KV) {
    params.block_table = block_table.data_ptr<int>();
    params.block_table_batch_stride = block_table.stride(0);
  }
  params.page_block_size = page_block_size;
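  // Paged-KV addressing sketch (hypothetical sizes): with page_block_size = 256, token t of batch
  // element b lives at kcache[block_table[b][t / 256]][t % 256], so a sequence of 1000 cached
  // tokens needs ceil(1000 / 256) = 4 entries in block_table[b].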

  set_params_alibi(params, alibi_slopes_, batch_size, num_heads);

  auto stream = at::cuda::getCurrentCUDAStream().stream();
  // Only the split kernel supports appending to the KV cache, indexing into the cache with
  // cache_batch_idx, or a paged KV cache.
  run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() || cache_batch_idx_.has_value() || paged_KV);

  if (head_size_og % 8 != 0) {
    // out = out.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)});
    out = out.narrow(-1, 0, head_size_og);
    if (out_.has_value()) { out_.value().copy_(out); }
    if (k_.has_value()) {
      // It's expensive to copy the KV cache here for the case where the head size is not divisible by 8,
      // but we don't expect to get this case in practice. This is just so that the code works for that case.
      kcache.copy_(kcache_padded.narrow(-1, 0, head_size_og));
      vcache.copy_(vcache_padded.narrow(-1, 0, head_size_og));
      // kcache.copy_(kcache_padded.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)}));
      // vcache.copy_(vcache_padded.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)}));
    }
  }

  if (seqlenq_ngroups_swapped) {
    out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
    softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
  }
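  // The reshape above undoes the ngroups swap from earlier: out goes from (b, ngroups, num_heads_k, d)
  // back to (b, 1, num_heads_k * ngroups, d), so the caller sees the original
  // (batch_size, seqlen_q = 1, num_heads, head_size) layout, and softmax_lse is reshaped to match.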
  return {out, softmax_lse};
}

} // namespace pytorch_flash

#endif