xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/cuda/attention_backward.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <string_view>
2 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
3 #include <cstdint>
4 #include <type_traits>
5 
6 #include <ATen/core/Tensor.h>
7 #include <ATen/TensorOperators.h>
8 
9 #include <ATen/cuda/CUDAContext.h>
10 #include <ATen/cuda/CUDAGraphsUtils.cuh>
11 #include <c10/cuda/CUDAMathCompat.h>
12 #include <c10/util/Exception.h>
13 #include <c10/util/bit_cast.h>
14 
15 #include <c10/core/TensorImpl.h>
16 #include <ATen/native/nested/NestedTensorTransformerFunctions.h>
17 #include <ATen/native/nested/NestedTensorUtils.h>
18 #include <ATen/native/transformers/attention.h>
19 #include <ATen/native/transformers/cuda/sdp_utils.h>
20 #include <ATen/native/transformers/sdp_utils_cpp.h>
21 #include <ATen/cuda/CUDAGeneratorImpl.h>
22 
23 #ifndef AT_PER_OPERATOR_HEADERS
24 #include <ATen/Functions.h>
25 #include <ATen/NativeFunctions.h>
26 #else
27 #include <ATen/ops/_flash_attention_backward.h>
28 #include <ATen/ops/_flash_attention_backward_native.h>
29 #include <ATen/ops/_efficient_attention_backward.h>
30 #include <ATen/ops/_efficient_attention_backward_native.h>
31 #include <ATen/ops/_scaled_dot_product_flash_attention_backward_native.h>
32 #endif
33 
34 #ifdef USE_FLASH_ATTENTION
35 // FlashAttention Specific Imports
36 #include <ATen/native/transformers/cuda/flash_attn/flash_api.h>
37 #endif
38 #ifdef USE_MEM_EFF_ATTENTION
39 #ifndef USE_ROCM
40 // MemoryEfficient Attention Specific Imports for CUDA
41 #include <ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h>
42 #include <ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB.h>
43 #include <ATen/native/transformers/cuda/mem_eff_attention/gemm_kernel_utils.h>
44 #include <ATen/native/transformers/cuda/mem_eff_attention/pytorch_utils.h>
45 #else
46 // MemoryEfficient Attention Specific Imports for ROCM
47 #include <ATen/native/transformers/hip/aotriton_adapter.h>
48 #include <aotriton/flash.h>
49 #include <aotriton/runtime.h>
50 #endif
51 #endif
52 
53 #ifdef __HIP_PLATFORM_AMD__
54 #include <ATen/native/cudnn/hip/MHA.h>
55 #else
56 #include <ATen/native/cudnn/MHA.h>
57 #endif
58 
59 namespace at::native {
60 
_flash_attention_backward(const Tensor & grad_out,const Tensor & query,const Tensor & key,const Tensor & value,const Tensor & out,const Tensor & logsumexp,const Tensor & cumulative_sequence_length_q,const Tensor & cumulative_sequence_length_k,int64_t max_seqlen_batch_q,int64_t max_seqlen_batch_k,double dropout_p,bool is_causal,const Tensor & philox_seed,const Tensor & philox_offset,std::optional<double> scale,std::optional<int64_t> window_size_left,std::optional<int64_t> window_size_right)61 std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
62     const Tensor& grad_out,
63     const Tensor& query,
64     const Tensor& key,
65     const Tensor& value,
66     const Tensor& out,
67     const Tensor& logsumexp,
68     const Tensor& cumulative_sequence_length_q,
69     const Tensor& cumulative_sequence_length_k,
70     int64_t max_seqlen_batch_q,
71     int64_t max_seqlen_batch_k,
72     double dropout_p,
73     bool is_causal,
74     const Tensor& philox_seed,
75     const Tensor& philox_offset,
76     std::optional<double> scale,
77     std::optional<int64_t> window_size_left,
78     std::optional<int64_t> window_size_right) {
79 #if defined(USE_FLASH_ATTENTION)
80   const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
81   //  CUDA code assumes that dout is contiguous
82   auto contiguous_grad_out = grad_out.contiguous();
83   auto contiguous_out = out.contiguous();
84 
85   const int non_null_window_left = window_size_left.has_value() ? window_size_left.value() : -1;
86   const int non_null_window_right = window_size_right.has_value() ? window_size_right.value() : -1;
87 
88   std::optional<at::Tensor> dq{std::nullopt};
89   std::optional<at::Tensor> dk{std::nullopt};
90   std::optional<at::Tensor> dv{std::nullopt};
91 
92   //  The kernel computes irregardless we will drop for this functions return
93   Tensor grad_softmax;
94 
95   // Currently unused args:
96   std::optional<at::Tensor> alibi_slopes{std::nullopt};
97 
98   bool determinisitic{false};
99   auto& ctx = at::globalContext();
100   if (ctx.deterministicAlgorithms()) {
101     if (ctx.deterministicAlgorithmsWarnOnly()) {
102       TORCH_WARN_ONCE(
103           "Flash Attention defaults to a non-deterministic algorithm. ",
104           "To explicitly enable determinism call torch.use_deterministic_algorithms(True, warn_only=False).");
105     } else {
106       determinisitic = true;
107     }
108   }
109 
110   // We check the whether the cumulative_sequence_length_q is defined
111   // in order to determine whether we are using varlen or dense forward
112   if (cumulative_sequence_length_q.defined()) {
113     // Varlen forward
114     auto [dQuery, dKey, dValue, dSoftmax] = pytorch_flash::mha_varlen_bwd(
115         contiguous_grad_out,
116         query,
117         key,
118         value,
119         contiguous_out,
120         logsumexp,
121         dq,
122         dk,
123         dv,
124         cumulative_sequence_length_q,
125         cumulative_sequence_length_k,
126         alibi_slopes,
127         max_seqlen_batch_q,
128         max_seqlen_batch_k,
129         dropout_p,
130         softmax_scale,
131         false /*zero_tensors*/,
132         is_causal,
133         non_null_window_left,
134         non_null_window_right,
135         determinisitic,
136         philox_seed,
137         philox_offset);
138     return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue));
139   } else {
140     // Dense forward
141     auto [dQuery, dKey, dValue, dSoftmax] = pytorch_flash::mha_bwd(
142         contiguous_grad_out,
143         query,
144         key,
145         value,
146         contiguous_out,
147         logsumexp,
148         dq,
149         dk,
150         dv,
151         alibi_slopes,
152         dropout_p,
153         softmax_scale,
154         is_causal,
155         non_null_window_left,
156         non_null_window_right,
157         determinisitic,
158         philox_seed,
159         philox_offset);
160     return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue));
161   }
162 #endif
163   TORCH_CHECK(false, "USE_FLASH_ATTENTION was not enabled for build.");
164   return std::make_tuple(Tensor(), Tensor(), Tensor());
165 }
166 
_scaled_dot_product_cudnn_attention_backward_cuda(const Tensor & grad_out,const Tensor & query,const Tensor & key,const Tensor & value,const Tensor & out,const Tensor & logsumexp,const Tensor & philox_seed,const Tensor & philox_offset,const Tensor & attn_bias,const Tensor & cum_seq_q,const Tensor & cum_seq_k,const int64_t max_q,const int64_t max_k,double dropout_p,bool is_causal,std::optional<double> scale)167 std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_backward_cuda(
168     const Tensor& grad_out,
169     const Tensor& query,
170     const Tensor& key,
171     const Tensor& value,
172     const Tensor& out,
173     const Tensor& logsumexp,
174     const Tensor& philox_seed,
175     const Tensor& philox_offset,
176     const Tensor& attn_bias,
177     const Tensor& cum_seq_q,
178     const Tensor& cum_seq_k,
179     const int64_t max_q,
180     const int64_t max_k,
181     double dropout_p,
182     bool is_causal,
183     std::optional<double> scale) {
184 
185     auto& ctx = at::globalContext();
186     if (ctx.deterministicAlgorithms()) {
187       if (ctx.deterministicAlgorithmsWarnOnly()) {
188         TORCH_WARN_ONCE(
189             "cuDNN Attention defaults to a non-deterministic algorithm. ",
190             "To explicitly enable determinism call torch.use_deterministic_algorithms(True, warn_only=False).");
191       }
192     }
193 
194     const int64_t batch_size = query.size(0);
195     const int64_t num_heads = query.size(1);
196     const int64_t head_dim_qk = query.size(3);
197     const int64_t head_dim_v = value.size(3);
198     const int64_t max_seqlen_batch_q = query.size(2);
199     const int64_t max_seqlen_batch_k = key.size(2);
200 
201     // This is needed because SaveVariable automatically converts
202     // std::optional to undefined tensor
203     std::optional<Tensor> attn_bias_;
204     if (attn_bias.defined()) {
205       attn_bias_ = attn_bias;
206     }
207     if (attn_bias_.has_value()) {
208       const auto bias_dim = attn_bias_.value().dim();
209       if (bias_dim == 2) {
210         attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
211       } else if (bias_dim == 3) {
212         attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
213       } else {
214         attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k});
215         TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D");
216       }
217     }
218 
219     const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
220     auto dq = at::empty_like(query);
221     auto dk = at::empty_like(key);
222     auto dv = at::empty_like(value);
223     run_cudnn_SDP_bprop(batch_size /*int64_t b*/,
224                         num_heads /*int64_t h*/,
225                         max_q/*int64_t s_q*/,
226                         max_k/*int64_t s_kv*/,
227                         head_dim_qk /*int64_t d_qk*/,
228                         head_dim_v /*int64_t d_v*/,
229                         softmax_scale /*float scaling_factor*/,
230                         is_causal /*bool is_causal*/,
231                         dropout_p /*float dropout_probability*/,
232                         query /*const Tensor& q*/,
233                         key /*const Tensor& k*/,
234                         value /*const Tensor& v*/,
235                         attn_bias_ /*const std::optional<Tensor>& attn_bias*/,
236                         out /*const Tensor& o*/,
237                         grad_out/*const Tensor& dO*/,
238                         logsumexp.unsqueeze(-1)/*const Tensor& softmaxstats*/,
239                         dq/*Tensor& dQ*/,
240                         dk/*Tensor& dK*/,
241                         dv/*Tensor& dV*/,
242                         philox_seed/*Tensor& dropoutseed*/,
243                         philox_offset/*Tensor& dropoutoffset*/);
244     return std::make_tuple(std::move(dq), std::move(dk), std::move(dv));
245 }
246 
247 std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
_efficient_attention_backward(const at::Tensor & grad_out_,const at::Tensor & query,const at::Tensor & key,const at::Tensor & value,const std::optional<at::Tensor> & kernel_bias,const at::Tensor & out,const std::optional<at::Tensor> & cu_seqlens_q_dummy,const std::optional<at::Tensor> & cu_seqlens_k_dummy,int64_t max_seqlen_q,int64_t max_seqlen_k,const at::Tensor & logsumexp,double dropout_p,const at::Tensor & philox_seed,const at::Tensor & philox_offset,int64_t custom_mask_type,const bool bias_requires_grad,const std::optional<double> scale,std::optional<int64_t> num_splits_key,const std::optional<int64_t> window_size,const bool shared_storage_dqdkdv)248 _efficient_attention_backward(
249     const at::Tensor& grad_out_,
250     const at::Tensor& query,
251     const at::Tensor& key,
252     const at::Tensor& value,
253     const std::optional<at::Tensor>& kernel_bias, // additive attention bias
254     const at::Tensor& out,
255     // (Mode 1MHK only) [b+1]: cu_seqlens_q[b] contains the
256     // position of the first query token for batch $b
257     const std::optional<at::Tensor>& cu_seqlens_q_dummy,
258     // (Mode 1MHK only) [b+1]: cu_seqlens_k[b] contains the
259     // position of the first key token for batch $b
260     const std::optional<at::Tensor>& cu_seqlens_k_dummy,
261     // (Mode 1MHK only) Maximum sequence length across batches
262     int64_t max_seqlen_q,
263     // (Mode 1MHK only) Maximum sequence length across batches
264     int64_t max_seqlen_k,
265     const at::Tensor& logsumexp,
266     double dropout_p, // dropout probability
267     const at::Tensor& philox_seed, // seed using for generating random numbers for dropout
268     const at::Tensor& philox_offset, // offset into random number sequence
269     int64_t custom_mask_type,
270     const bool bias_requires_grad,
271     const std::optional<double> scale,
272     std::optional <int64_t> num_splits_key,
273     const std::optional<int64_t> window_size,
274     const bool shared_storage_dqdkdv) {
275   #if defined(USE_MEM_EFF_ATTENTION)
276   if (!grad_out_.defined()) {
277     return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{});
278   }
279   // This path is used when we directly call _efficient_attention_forward
280   // from python.
281   // This is needed because SaveVariable automatically converts
282   // std::optional to undefined tensor
283   std::optional<Tensor> bias, cu_seqlens_q, cu_seqlens_k;
284   bias = kernel_bias.has_value() && !kernel_bias->defined() ? std::nullopt : kernel_bias;
285   cu_seqlens_q = cu_seqlens_q_dummy.has_value() && !cu_seqlens_q_dummy->defined() ? std::nullopt : cu_seqlens_q_dummy;
286   cu_seqlens_k = cu_seqlens_k_dummy.has_value() && !cu_seqlens_k_dummy->defined() ? std::nullopt : cu_seqlens_k_dummy;
287 
288     // ndim
289   TORCH_CHECK(query.dim() == grad_out_.dim());
290   TORCH_CHECK(query.dim() == key.dim());
291   TORCH_CHECK(query.dim() == value.dim());
292   TORCH_CHECK(query.dim() == 4);
293 
294   // batch size
295   TORCH_CHECK(query.size(0) == grad_out_.size(0));
296   TORCH_CHECK(query.size(0) == key.size(0));
297   TORCH_CHECK(query.size(0) == value.size(0));
298 
299   // seqlen
300   TORCH_CHECK(key.size(1) == value.size(1));
301   TORCH_CHECK(query.size(1) == grad_out_.size(1));
302 
303   // Num heads
304   TORCH_CHECK(query.size(2) == key.size(2));
305   TORCH_CHECK(query.size(2) == value.size(2));
306   TORCH_CHECK(query.size(2) == grad_out_.size(2));
307 
308   // Embedding per head
309   TORCH_CHECK(query.size(3) == key.size(3));
310   TORCH_CHECK(value.size(3) == grad_out_.size(3));
311 
312   // handle potentially non-contiguous grad_out through a copy
313   auto grad_out = grad_out_.contiguous();
314   CHECK_NOSPARSE_CONTIGUOUS_CUDA(grad_out);
315 
316   CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(query);
317   CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(key);
318   CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(value);
319 
320   TORCH_CHECK(cu_seqlens_q.has_value() == cu_seqlens_k.has_value());
321   TORCH_CHECK(
322       !(cu_seqlens_q.has_value() && bias.has_value()),
323       "cu seqlen + bias not supported");
324   if (cu_seqlens_q.has_value()) {
325     TORCH_CHECK(cu_seqlens_q->scalar_type() == at::ScalarType::Int);
326     TORCH_CHECK(cu_seqlens_k->scalar_type() == at::ScalarType::Int);
327     TORCH_CHECK(cu_seqlens_q->dim() == 1 && cu_seqlens_k->dim() == 1);
328     CHECK_NOSPARSE_CONTIGUOUS_CUDA((*cu_seqlens_q));
329     CHECK_NOSPARSE_CONTIGUOUS_CUDA((*cu_seqlens_k));
330     TORCH_CHECK(cu_seqlens_q->size(0) == cu_seqlens_k->size(0));
331     TORCH_CHECK(query.size(0) == 1, "cu_seqlen only supports batch_size=1");
332     TORCH_CHECK(max_seqlen_q > 0, "max_seqlen_q required with `cu_seqlens_q`");
333     TORCH_CHECK(max_seqlen_k > 0, "max_seqlen_k required with `cu_seqlens_k`");
334     TORCH_CHECK(
335         max_seqlen_k <= key.size(1), "Invalid max_seqlen_k:", max_seqlen_k);
336     TORCH_CHECK(
337         max_seqlen_q <= query.size(1), "Invalid max_seqlen_q:", max_seqlen_q);
338   } else {
339     max_seqlen_q = query.size(1);
340     max_seqlen_k = key.size(1);
341   }
342 
343   at::cuda::CUDAGuard device_guard(query.device());
344   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
345 
346   int64_t B = query.size(0);
347   int64_t M = query.size(1);
348   int64_t N = key.size(1);
349   int64_t nH = query.size(2);
350   int64_t K = query.size(3);
351   int64_t Kv = value.size(3);
352 
353   at::Tensor grad_q, grad_k, grad_v, grad_bias;
354   if (shared_storage_dqdkdv) {
355     // Create one big contiguous chunk
356     // This is because q, k and v usually come from a single
357     // output of a linear layer that is chunked.
358     // Creating the gradients with the right layout saves us
359     // a `torch.cat` call in the backward pass
360     TORCH_CHECK(
361       query.size(1) == key.size(1),
362       "`shared_storage_dqdkdv` is only supported when Q/K/V "
363       "have the same sequence length: got ", query.size(1),
364       " query tokens and ", key.size(1), " key/value tokens"
365     );
366     TORCH_CHECK(
367       query.size(3) == key.size(3),
368       "`shared_storage_dqdkdv` is only supported when Q/K/V "
369       "have the same embed dim: got ", query.size(3),
370       " for Q, and ", key.size(3), " for K"
371     );
372     at::Tensor chunk = at::empty({B, M, 3, nH, K}, query.options());
373     grad_q = chunk.select(2, 0);
374     grad_k = chunk.select(2, 1);
375     grad_v = chunk.select(2, 2);
376   } else {
377     grad_q = at::empty(query.sizes(), query.options());
378     grad_k = at::empty(key.sizes(), key.options());
379     grad_v = at::empty(value.sizes(), value.options());
380   }
381 
382   if (bias_requires_grad) {
383     // force alignment for the last dim
384     std::vector<int64_t> sz = bias->sizes().vec();
385     int64_t lastDim = sz[sz.size() - 1];
386     int64_t alignTo = 16;
387     sz[sz.size() - 1] = alignTo * ((lastDim + alignTo - 1) / alignTo);
388     grad_bias = at::empty(sz, bias->options())
389                     .slice(/*dim=*/-1, /*start=*/0, /*end=*/lastDim);
390   }
391 
392   const bool use_dropout = std::fpclassify(dropout_p) != FP_ZERO;
393 
394   // See Note [Seed and Offset Device]
395   at::PhiloxCudaState rng_engine_inputs;
396   if (use_dropout) {
397     if (at::cuda::currentStreamCaptureStatus() ==
398         at::cuda::CaptureStatus::None) {
399       rng_engine_inputs = at::PhiloxCudaState(
400           *philox_seed.data_ptr<int64_t>(),
401           *philox_offset.data_ptr<int64_t>());
402     } else { // dropout + capture
403       rng_engine_inputs = at::PhiloxCudaState(
404           philox_seed.data_ptr<int64_t>(),
405           philox_offset.data_ptr<int64_t>(),
406           0);
407     }
408   }
409 
410 #ifdef USE_ROCM
411   // ROCM Implementation
412   TORCH_CHECK(!num_splits_key.has_value(),
413               "ROCM does not support num_split_keys in _efficient_attention_forward");
414   TORCH_CHECK(!window_size.has_value(),
415               "ROCM does not support window_size in _efficient_attention_forward");
416   auto ret = aotriton::v2::flash::check_gpu(stream);
417   if (hipSuccess != ret) {
418     TORCH_CHECK(false,
419                 "[AOTriton] Accelerated SDPA only supports MI200/MI300X/Navi31 GPUs"
420                 " (gfx90a:sramecc+:xnack-/gfx942:sramecc+:xnack-/gfx1100)")
421   }
422   const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
423   bool is_causal;
424   if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
425     is_causal = true;
426   } else if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
427     is_causal = false;
428   } else {
429     TORCH_CHECK(false, "[_efficient_attention_backward] Unsupported mask type in AOTriton, for now");
430   }
431   at::Tensor q_t = query.permute({0,2,1,3});
432   at::Tensor k_t = key.permute({0,2,1,3});
433   at::Tensor v_t = value.permute({0,2,1,3});
434   at::Tensor out_t = out.permute({0,2,1,3});
435   at::Tensor dq_t = grad_q.permute({0,2,1,3});
436   at::Tensor dk_t = grad_k.permute({0,2,1,3});
437   at::Tensor dv_t = grad_v.permute({0,2,1,3});
438   at::Tensor dout_t = grad_out.permute({0,2,1,3});
439   at::Tensor softmax_lse = logsumexp.view({B * nH, max_seqlen_q});
440   at::Tensor delta = at::empty_like(softmax_lse).contiguous();
441 
442   hipError_t err;
443   using aotriton::v2::flash::attn_bwd;
444   using sdp::aotriton_adapter::mk_aotensor;
445   using sdp::aotriton_adapter::mk_aoscalartensor;
446   using sdp::aotriton_adapter::cast_dtype;
447   aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, cast_dtype(query.dtype()));
448   err = attn_bwd(mk_aotensor(q_t, "q"),
449                  mk_aotensor(k_t, "k"),
450                  mk_aotensor(v_t, "v"),
451                  bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4,
452                  softmax_scale,
453                  mk_aotensor(out_t, "out"),
454                  mk_aotensor(dout_t, "dout"),
455                  mk_aotensor(dq_t, "dq"),
456                  mk_aotensor(dk_t, "dk"),
457                  mk_aotensor(dv_t, "dv"),
458                  bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4,
459                  mk_aotensor<2>(softmax_lse, "L"),
460                  mk_aotensor<2>(delta, "delta"),
461                  float(dropout_p),
462                  mk_aoscalartensor(philox_seed),
463                  mk_aoscalartensor(philox_offset),
464                  0,
465                  is_causal,
466                  stream);
467 #else
468   at::Tensor workspace;
469   cudaDeviceProp* p = at::cuda::getDeviceProperties(query.device().index());
470   const int computeCapability = p->major * 10 + p->minor;
471 
472   bool kernel_launched = false;
473   const auto maxK = std::max(query.size(3), value.size(3));
474   const auto maxShmem = p->sharedMemPerBlockOptin;
475 
476   auto launchKernel = [&](auto _k, auto kernel_fn) {
477     using Kernel = decltype(_k);
478     using scalar_t = typename Kernel::scalar_t;
479     (void)_k;
480 
481     if (kernel_launched) {
482       return;
483     }
484     // Check if this kernel is compatible
485     if (Kernel::kMaxK < maxK) {
486       return;
487     }
488     // Dropout must be supported if we need it
489     if (use_dropout && !Kernel::kApplyDropout) {
490       return;
491     }
492     if (Kernel::kKeysQueriesAlignedToBlockSize &&
493         (cu_seqlens_q.has_value() || M % Kernel::kBlockSizeI ||
494          N % Kernel::kBlockSizeJ)) {
495       return;
496     }
497     // Alignment
498     if ((query.stride(2) % Kernel::kMinimumAlignment) ||
499         (key.stride(2) % Kernel::kMinimumAlignment) ||
500         (value.stride(2) % Kernel::kMinimumAlignment)) {
501       return;
502     }
503     // Uses too much shmem
504     size_t smem_bytes = sizeof(typename Kernel::SharedStorage);
505     if (smem_bytes > maxShmem) {
506       return;
507     }
508 
509     kernel_launched = true;
510 
511     // TODO: Fuse this into a kernel?
512     // This is a bottleneck for smaller sequences (M <= 128)
513     auto delta = Kernel::kKernelComputesDelta
514         ? at::empty({B, nH, M}, query.options().dtype(at::ScalarType::Float))
515         : (grad_out.to(at::kFloat) * out.to(at::kFloat))
516               .sum(-1)
517               .transpose(-2, -1)
518               .contiguous();
519     TORCH_INTERNAL_ASSERT(delta.size(0) == B);
520     TORCH_INTERNAL_ASSERT(delta.size(1) == nH);
521     TORCH_INTERNAL_ASSERT(delta.size(2) == M);
522 
523     typename Kernel::Params p;
524     p.query_ptr = (const scalar_t*)query.const_data_ptr();
525     p.key_ptr = (const scalar_t*)key.const_data_ptr();
526     p.value_ptr = (const scalar_t*)value.const_data_ptr();
527     p.logsumexp_ptr = (typename Kernel::lse_scalar_t const *)logsumexp.const_data_ptr();
528     p.output_ptr = (const scalar_t*)out.const_data_ptr();
529     p.grad_output_ptr = (const scalar_t*)grad_out.const_data_ptr();
530     p.grad_query_ptr = (scalar_t*)grad_q.data_ptr();
531     p.grad_key_ptr = (scalar_t*)grad_k.data_ptr();
532     p.grad_value_ptr = (scalar_t*)grad_v.data_ptr();
533     p.delta_ptr = (float*)delta.data_ptr();
534     p.head_dim = query.size(3);
535     p.head_dim_value = value.size(3);
536     p.num_queries = max_seqlen_q;
537     p.num_keys = max_seqlen_k;
538     p.num_batches = cu_seqlens_q.has_value() ? cu_seqlens_q->size(0) - 1 : B;
539     p.num_heads = nH;
540     p.custom_mask_type = custom_mask_type;
541     p.scale = sdp::calculate_scale(query, scale).as_float_unchecked();
542     if (cu_seqlens_q.has_value()) {
543       p.cu_seqlens_q_ptr = (const int32_t*)cu_seqlens_q->const_data_ptr();
544       p.cu_seqlens_k_ptr = (const int32_t*)cu_seqlens_k->const_data_ptr();
545     }
546     if (window_size.has_value()) {
547       p.window_size = *window_size;
548     }
549 
550     ASSIGN_CHECK_OVERFLOW(p.lse_strideB, logsumexp.stride(0));
551     ASSIGN_CHECK_OVERFLOW(p.lse_strideH, logsumexp.stride(1));
552     ASSIGN_CHECK_OVERFLOW(p.gO_strideB, grad_out.stride(0));
553     ASSIGN_CHECK_OVERFLOW(p.gO_strideM, grad_out.stride(1));
554     ASSIGN_CHECK_OVERFLOW(p.gO_strideH, grad_out.stride(2));
555 
556     ASSIGN_CHECK_OVERFLOW(p.o_strideB, out.stride(0));
557     ASSIGN_CHECK_OVERFLOW(p.o_strideH, out.stride(2));
558 
559     ASSIGN_CHECK_OVERFLOW(p.gQ_strideB, grad_q.stride(0));
560     ASSIGN_CHECK_OVERFLOW(p.gK_strideB, grad_k.stride(0));
561     ASSIGN_CHECK_OVERFLOW(p.gV_strideB, grad_v.stride(0));
562     ASSIGN_CHECK_OVERFLOW(p.gQ_strideH, grad_q.stride(2));
563     ASSIGN_CHECK_OVERFLOW(p.gK_strideH, grad_k.stride(2));
564     ASSIGN_CHECK_OVERFLOW(p.gV_strideH, grad_v.stride(2));
565     p.gQKV_strideM_multiplier = shared_storage_dqdkdv ? 3 : 1;
566     TORCH_INTERNAL_ASSERT(p.gQ_strideM() == grad_q.stride(1));
567     TORCH_INTERNAL_ASSERT(p.gK_strideM() == grad_k.stride(1));
568     TORCH_INTERNAL_ASSERT(p.gV_strideM() == grad_v.stride(1));
569 
570     ASSIGN_CHECK_OVERFLOW(p.q_strideB, query.stride(0));
571     ASSIGN_CHECK_OVERFLOW(p.k_strideB, key.stride(0));
572     ASSIGN_CHECK_OVERFLOW(p.v_strideB, value.stride(0));
573     ASSIGN_CHECK_OVERFLOW(p.q_strideM, query.stride(1));
574     ASSIGN_CHECK_OVERFLOW(p.k_strideM, key.stride(1));
575     ASSIGN_CHECK_OVERFLOW(p.v_strideM, value.stride(1));
576     ASSIGN_CHECK_OVERFLOW(p.q_strideH, query.stride(2));
577     ASSIGN_CHECK_OVERFLOW(p.k_strideH, key.stride(2));
578     ASSIGN_CHECK_OVERFLOW(p.v_strideH, value.stride(2));
579     ASSIGN_CHECK_OVERFLOW(p.delta_strideB, delta.stride(0));
580     ASSIGN_CHECK_OVERFLOW(p.delta_strideH, delta.stride(1));
581 
582     if (bias.has_value()) {
583       CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA((*bias));
584       TORCH_CHECK(
585           bias->scalar_type() == CutlassToAtenDtype<scalar_t>::atScalarType(),
586           "invalid dtype for bias - should match query's dtype");
587 
588       p.bias_ptr = (scalar_t*)bias->data_ptr();
589 
590       TORCH_CHECK(bias->dim() == 4, "Bias expected in BMHK format");
591       TORCH_CHECK(
592           bias->size(0) == query.size(0),
593           "attn_bias: wrong shape (batch dimension)");
594       TORCH_CHECK(
595           bias->size(1) == query.size(2),
596           "attn_bias: wrong shape (head dimension)");
597       TORCH_CHECK(
598           bias->size(2) == query.size(1),
599           "attn_bias: wrong shape (seqlenQ dimension)");
600       TORCH_CHECK(
601           bias->size(3) == key.size(1),
602           "attn_bias: wrong shape (seqlenKV dimension)");
603       TORCH_CHECK(
604           bias->stride(3) == 1,
605           "attn_bias: wrong alignment (last dimension must be contiguous)");
606       ASSIGN_CHECK_OVERFLOW(p.bias_strideB, bias->stride(0));
607       ASSIGN_CHECK_OVERFLOW(p.bias_strideH, bias->stride(1));
608       ASSIGN_CHECK_OVERFLOW(p.bias_strideM, bias->stride(2));
609 
610       if (bias_requires_grad) {
611         p.grad_bias_ptr = (scalar_t*)grad_bias.data_ptr();
612 
613         ASSIGN_CHECK_OVERFLOW(p.gB_strideB, grad_bias.stride(0));
614         ASSIGN_CHECK_OVERFLOW(p.gB_strideH, grad_bias.stride(1));
615         ASSIGN_CHECK_OVERFLOW(p.gB_strideM, grad_bias.stride(2));
616       }
617     }
618 
619     if (use_dropout) {
620       p.rng_engine_inputs = rng_engine_inputs;
621       p.dropout_prob = dropout_p;
622     }
623 
624     // Heuristic for finding optimal number of splits
625     auto parallelism_without_split_key =
626         p.getBlocksGrid().x * p.getBlocksGrid().y * p.getBlocksGrid().z;
627     p.num_splits_key = cutlass::ceil_div(p.num_keys, Kernel::kBlockSizeJ);
628     if (num_splits_key.has_value()) {
629       p.num_splits_key =
630           std::min<int64_t>(p.num_splits_key, num_splits_key.value());
631     } else {
632       // Keys splitting heuristic
633 
634       // If we already have enough parallelism, split-keys can help
635       // better use L2 cache.
636       // This is negligible when the seqlen is too small tho
637       if (parallelism_without_split_key >= 256 &&
638           p.num_keys <= 2 * Kernel::kBlockSizeJ) {
639         p.num_splits_key = 1;
640       }
641       // Increasing `split_keys` leads to using more gmem for temporary storage
642       // when we need a staging area for gK/gV. let's avoid that
643       if (Kernel::kNeedsAccumGradK || Kernel::kNeedsAccumGradV) {
644         p.num_splits_key = std::min(
645             int(p.num_splits_key), 200 / (p.num_batches * p.num_heads));
646       }
647     }
648     if (!Kernel::kEnableSplitKeys || p.num_splits_key < 1) {
649       p.num_splits_key = 1;
650     }
651 
652     auto& ctx = at::globalContext();
653     if (ctx.deterministicAlgorithms()) {
654       if (ctx.deterministicAlgorithmsWarnOnly()) {
655         TORCH_WARN_ONCE(
656             "Memory Efficient attention defaults to a non-deterministic algorithm. ",
657             "To explicitly enable determinism call torch.use_deterministic_algorithms(True, warn_only=False).");
658       } else {
659         TORCH_CHECK(
660             num_splits_key.value_or(1) <= 1,
661             "Using `num_splits_key > 1` makes the algorithm non-deterministic, and pytorch's deterministic mode is enabled");
662         p.num_splits_key = 1;
663       }
664     }
665     int64_t size_bytes = p.workspace_size();
666     if (size_bytes) {
667       workspace =
668           at::empty({size_bytes}, query.options().dtype(at::ScalarType::Byte));
669       p.workspace = (float*)workspace.data_ptr();
670       if (p.should_zero_workspace()) {
671         workspace.zero_();
672       }
673     }
674 
675     // Handle the edge-cases where some tensors are empty
676     if (p.num_queries == 0 || p.num_keys == 0 || p.num_batches == 0 ||
677         p.num_heads == 0) {
678       grad_k.zero_();
679       grad_v.zero_();
680       grad_q.zero_();
681       return;
682     }
683     Kernel::check_supported(p);
684 
685     if (smem_bytes > 0xc000) {
686       // https://docs.nvidia.com/cuda/cuda-c-programming-guide/#features-and-technical-specifications-technical-specifications-per-compute-capability
687       auto err = cudaFuncSetAttribute(
688           kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
689       TORCH_CHECK(
690           err != cudaErrorInvalidValue,
691           "This GPU does not have enough shared-memory (kernel requires ",
692           smem_bytes / 1024,
693           " kb)");
694       AT_CUDA_CHECK(err);
695     }
696 
697     // second syntax resulted in the error below on windows
698     // error C3495: 'kernel_fn': a simple capture must be a variable
699     // with automatic storage duration declared
700     // in the reaching scope of the lambda
701 #ifdef _WIN32
702     cudaFuncAttributes attr;
703     AT_CUDA_CHECK(cudaFuncGetAttributes(&attr, kernel_fn));
704     TORCH_INTERNAL_ASSERT(
705         attr.binaryVersion >= Kernel::ArchTag::kMinComputeCapability,
706         "Something went wrong in the build process");
707 #else
708     auto checkBinaryArchMatches = [&]() {
709       cudaFuncAttributes attr;
710       AT_CUDA_CHECK(cudaFuncGetAttributes(&attr, kernel_fn));
711       return attr.binaryVersion >= Kernel::ArchTag::kMinComputeCapability;
712     };
713     TORCH_INTERNAL_ASSERT(
714         checkBinaryArchMatches(), "Something went wrong in the build process");
715 #endif
716 
717     kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, stream>>>(p);
718   };
719 
720   DISPATCH_TYPES(query, ([&]() {
721                    dispatch_cutlassB<scalar_t>(launchKernel, computeCapability);
722                  }));
723   TORCH_CHECK(kernel_launched, "cutlassB: no kernel found to launch!");
724   AT_CUDA_CHECK(cudaGetLastError());
725 #endif // USE_ROCM
726   return std::make_tuple(std::move(grad_q), std::move(grad_k), std::move(grad_v), std::move(grad_bias));
727   #endif // defined(USE_MEM_EFF_ATTENTION)
728   TORCH_CHECK(false, "USE_MEM_EFF_ATTENTION was not enabled for build.")
729   return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{});
730 }
731 
_scaled_dot_product_flash_attention_backward_cuda(const at::Tensor & grad_out_,const at::Tensor & query,const at::Tensor & key,const at::Tensor & value,const at::Tensor & out,const at::Tensor & logsumexp,const Tensor & cumulative_sequence_length_q,const Tensor & cumulative_sequence_length_k,const int64_t max_seqlen_batch_q,const int64_t max_seqlen_batch_k,double dropout_p,bool is_causal,const at::Tensor & philox_seed,const at::Tensor & philox_offset,std::optional<double> scale)732 std::tuple<at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_flash_attention_backward_cuda(
733     const at::Tensor& grad_out_,
734     const at::Tensor& query,
735     const at::Tensor& key,
736     const at::Tensor& value,
737     const at::Tensor& out,
738     const at::Tensor& logsumexp,
739     const Tensor& cumulative_sequence_length_q,
740     const Tensor& cumulative_sequence_length_k,
741     const int64_t max_seqlen_batch_q,
742     const int64_t max_seqlen_batch_k,
743     double dropout_p,
744     bool is_causal,
745     const at::Tensor& philox_seed,
746     const at::Tensor& philox_offset,
747     std::optional<double> scale){
748   if (!grad_out_.defined()) {
749     return std::make_tuple(Tensor{}, Tensor{}, Tensor{});
750   }
751 
752   Tensor q_t = query.transpose(1, 2);
753   Tensor k_t = key.transpose(1, 2);
754   Tensor v_t = value.transpose(1, 2);
755 
756   Tensor grad_out_t = grad_out_.transpose(1,2);
757   Tensor out_t = out.transpose(1,2);
758 
759   auto [grad_q, grad_k, grad_v] = at::_flash_attention_backward(
760     grad_out_t,
761     q_t,
762     k_t,
763     v_t,
764     out_t,
765     logsumexp,
766     cumulative_sequence_length_q,
767     cumulative_sequence_length_k,
768     max_seqlen_batch_q,
769     max_seqlen_batch_k,
770     dropout_p,
771     is_causal,
772     philox_seed,
773     philox_offset,
774     scale);
775 
776   grad_q = grad_q.transpose(1,2);
777   grad_k = grad_k.transpose(1,2);
778   grad_v = grad_v.transpose(1,2);
779 
780   return std::make_tuple(std::move(grad_q), std::move(grad_k), std::move(grad_v));
781 }
782 
783 
_scaled_dot_product_efficient_attention_backward_cuda(const at::Tensor & grad_out_,const at::Tensor & query,const at::Tensor & key,const at::Tensor & value,const at::Tensor & attn_bias,const at::Tensor & out,const at::Tensor & logsumexp,const at::Tensor & philox_seed,const at::Tensor & philox_offset,double dropout_p,std::array<bool,4> grad_input_mask,bool causal,std::optional<double> scale)784 std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_efficient_attention_backward_cuda(
785     const at::Tensor& grad_out_,
786     const at::Tensor& query,
787     const at::Tensor& key,
788     const at::Tensor& value,
789     const at::Tensor& attn_bias,
790     const at::Tensor& out,
791     const at::Tensor& logsumexp,
792     const at::Tensor& philox_seed,
793     const at::Tensor& philox_offset,
794     double dropout_p,
795     std::array<bool, 4> grad_input_mask,
796     bool causal,
797     std::optional<double> scale) {
798 
799   if (!grad_out_.defined()) {
800     return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{});
801   }
802   auto grad_out = grad_out_.transpose(1, 2);
803   auto out_t = out.transpose(1, 2);
804   auto q_t = query.transpose(1, 2);
805   auto k_t = key.transpose(1, 2);
806   auto v_t = value.transpose(1, 2);
807 
808   // This is needed because SaveVariable automatically converts
809   // std::optional to undefined tensor
810   std::optional<Tensor> kernel_bias;
811   if (attn_bias.defined()) {
812     kernel_bias = attn_bias;
813   }
814   // Will add with signauter changes for dropout and bias
815   // We are only handling Dense inputs, but this should be passed
816   // from forward to backward
817   int64_t max_seqlen_q = q_t.size(1);
818   int64_t max_seqlen_k = k_t.size(1);
819 
820   sdp::CustomMaskType custom_mask_type = causal
821     ? sdp::CustomMaskType::CausalFromTopLeft
822     : sdp::CustomMaskType::NoCustomMask;
823   auto [grad_q, grad_k, grad_v, grad_bias] =
824       at::_efficient_attention_backward(
825           grad_out,
826           q_t,
827           k_t,
828           v_t,
829           kernel_bias,
830           out_t,
831           std::nullopt,
832           std::nullopt,
833           max_seqlen_q,
834           max_seqlen_k,
835           logsumexp,
836           dropout_p,
837           philox_seed,
838           philox_offset,
839           static_cast<int64_t>(custom_mask_type),
840           grad_input_mask[3],
841           scale,
842           std::nullopt);  // num_split_keys
843   return std::make_tuple(
844       grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2), grad_bias);
845 }
846 
847 } // namespace at::native
848