/******************************************************************************
 * Copyright (c) 2023, Advanced Micro Devices, Inc.
 * Copyright (c) 2022, Tri Dao.
 * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/
#include <c10/core/ScalarType.h>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS

#include <cstdint>
#include <tuple>

#include <ATen/ops/zeros.h>

#ifdef USE_FLASH_ATTENTION
#include <ATen/core/Tensor.h>
#include <ATen/hip/HIPContext.h>
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include <ATen/hip/HIPGraphsUtils.cuh>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/reshape.h>
#include <ATen/ops/scalar_tensor.h>
#include <ATen/ops/sum.h>
#include <ATen/ops/slice.h>
#include <ATen/ops/narrow.h>
#include <ATen/ops/pad.h>
#endif

#include <ATen/native/transformers/hip/aotriton_adapter.h>
#include <ATen/native/transformers/hip/flash_attn/flash_api.h>

#include <c10/util/Exception.h>
#include <c10/util/CallOnce.h>

// AOTriton headers
#include <aotriton/flash.h>
#include <aotriton/runtime.h>

namespace pytorch_flash {

namespace {

void check_gpu_arch(hipStream_t stream) {
  auto ret = aotriton::v2::flash::check_gpu(stream);
  if (hipSuccess != ret) {
      TORCH_CHECK(false,
                  "[AOTriton] Accelerated SDPA only supports MI200/MI300X/Navi31 GPUs"
                  " (gfx90a:sramecc+:xnack-/gfx942:sramecc+:xnack-/gfx1100)")
  }
}

}

#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == at::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
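// Example: CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og) expands to
// TORCH_CHECK(q.sizes() == at::IntArrayRef({batch_size, seqlen_q, num_heads, head_size_og}), ...),
// i.e. a full-shape assertion on q.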

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
mha_fwd(const at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &k,         // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &v,         // batch_size x seqlen_k x num_heads_k x head_size
        std::optional<at::Tensor> &out_,             // batch_size x seqlen_q x num_heads x head_size
        std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
        const float p_dropout,
        const float softmax_scale,
        bool is_causal,
        int window_size_left,
        int window_size_right,
        const bool return_softmax,
        std::optional<at::Generator> gen_) {
  auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
  check_gpu_arch(stream);

  auto q_dtype = q.dtype();
  TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
              "FlashAttention only supports fp16 and bf16 data types");
  TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");

  CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);

  // FIXME: ROCM probably does not need this
  TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");

  const auto sizes = q.sizes();

  const int batch_size = sizes[0];
  int seqlen_q = sizes[1];
  int num_heads = sizes[2];
  const int head_size_og = sizes[3];
  const int seqlen_k = k.size(1);
  const int num_heads_k = k.size(2);
  TORCH_CHECK(batch_size > 0, "batch size must be positive");
  TORCH_CHECK(head_size_og % 8 == 0, "head_size must be a multiple of 8, this is ensured by padding!");
  TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
  TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

  if (seqlen_q == 1) { is_causal = false; }  // causal=true is the same as causal=false in this case
  if (is_causal) { window_size_right = 0; }

  CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
  CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
  CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);

  at::Tensor q_padded, k_padded, v_padded;
  q_padded = q;
  k_padded = k;
  v_padded = v;

  at::Tensor out;
  if (out_.has_value()) {
    out = out_.value();
    TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
    CHECK_DEVICE(out);
    TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
    CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
    if (head_size_og % 8 != 0) { out = at::empty_like(q_padded); }
  } else {
    out = at::empty_like(q_padded);
  }

  auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  const int head_size = round_multiple(head_size_og, 8);
  const int head_size_rounded = round_multiple(head_size, 32);
  const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
  const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
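  // round_multiple(x, m) rounds x up to the next multiple of m, e.g.
  // round_multiple(72, 8) == 72, round_multiple(72, 32) == 96 and
  // round_multiple(1000, 128) == 1024.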

  // Otherwise the kernel will be launched from cuda:0 device
  // Cast to char to avoid compiler warning about narrowing
  at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};

  // We want to checkpoint and save the RNG state for backward if dropout
  // We get the default generator and return the seed and offset which will
  // be used in the backward function
  auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
  at::Tensor seed_t, offset_t;

  at::PhiloxCudaState philox_state;
  bool use_philox_state = false;
  if (p_dropout > 0.0)  {
    // number of times random will be generated per thread, to offset philox counter in thc random
    // state
    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
    int64_t counter_offset = batch_size * num_heads * 32;
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    philox_state = gen->philox_cuda_state(counter_offset);
    if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) {
      auto [seed, offset] = at::cuda::philox::unpack(philox_state);
      seed_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(seed)), at::dtype(at::kLong).device(at::kCUDA));
      offset_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(offset)), at::dtype(at::kLong).device(at::kCUDA));
    } else {
      // See Note [CUDA Graph-safe RNG states] about the design
      use_philox_state = true;
      seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
      offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
    }
  } else {
    if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) {
      seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
      offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
    } else {
      seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
      offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
    }
  }

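  // PhiloxCudaState can carry the RNG state either by value (seed/offset already
  // known on the host, the non-capturing path) or as device pointers plus an
  // intra-graph offset (the CUDA-graph-capture path); see
  // Note [CUDA Graph-safe RNG states].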
  at::PhiloxCudaState philox_args;
  if (p_dropout > 0.0) {
    if (at::cuda::currentStreamCaptureStatus() ==
        at::cuda::CaptureStatus::None)
    {
      philox_args = at::PhiloxCudaState(*seed_t.data_ptr<int64_t>(), *offset_t.data_ptr<int64_t>());
    } else { // dropout + capture
      philox_args = at::PhiloxCudaState(seed_t.data_ptr<int64_t>(), offset_t.data_ptr<int64_t>(), 0);
    }
  }

  // Transpose tensors to meet AOTriton's Flash API
  at::Tensor q_t = q_padded.permute({0,2,1,3});
  at::Tensor k_t = k_padded.permute({0,2,1,3});
  at::Tensor v_t = v_padded.permute({0,2,1,3});
  at::Tensor output_t = out.permute({0,2,1,3});
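  // The permutes are strided views (no copies) that reorder the tensors from
  // (batch, seqlen, num_heads, head_size) to the (batch, num_heads, seqlen,
  // head_size) layout expected by the AOTriton kernels.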

  at::Tensor M = at::empty({batch_size * num_heads, seqlen_q}, at::dtype(at::kFloat).device(q.device())); // aka softmax_lse

  at::Tensor softmax_fa_t;
  if (return_softmax) {
    softmax_fa_t = at::empty({batch_size, num_heads, seqlen_q, seqlen_k},
                             at::dtype(q.dtype()).device(q.device()));
  } else {
    softmax_fa_t = at::empty({ 0, 0, 0, 0 }, at::dtype(q.dtype()).device(q.device()));
  }

  hipError_t err; // TODO: Error handling
  using aotriton::v2::flash::attn_fwd;
  using aotriton::TensorView;
  using sdp::aotriton_adapter::mk_aotensor;
  using sdp::aotriton_adapter::mk_aoscalartensor;
  using sdp::aotriton_adapter::mk_philoxtensor;
  using sdp::aotriton_adapter::cast_dtype;
  aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
  auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t);
  auto offset1 = use_philox_state ? mk_philoxtensor(philox_state.offset_.ptr) : mk_aoscalartensor(offset_t);
  auto offset2 = use_philox_state ? philox_state.offset_intragraph_ : 0;
  auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr<int64_t>()) : mk_philoxtensor(nullptr);
  auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr<int64_t>()) : mk_philoxtensor(nullptr);
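  // When the Philox state is graph-captured, seed_output/offset_output point at
  // seed_t/offset_t so the seed and offset actually used can be written back;
  // otherwise a null mk_philoxtensor is passed (no write-back requested).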
  err = attn_fwd(mk_aotensor(q_t, "q"),
                 mk_aotensor(k_t, "k"),
                 mk_aotensor(v_t, "v"),
                 empty_bias,
                 softmax_scale,
                 mk_aotensor<2>(M, "M"),
                 mk_aotensor(output_t, "Out"),
                 p_dropout,
                 seed,
                 offset1,
                 offset2,
                 seed_output,
                 offset_output,
                 mk_aotensor(softmax_fa_t, "encoded_softmax"),
                 is_causal,
                 stream);

  return {out, q_padded, k_padded, v_padded, M, seed_t, offset_t, softmax_fa_t};
}

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
mha_varlen_fwd(const at::Tensor &q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               std::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &cu_seqlens_q,  // b+1
               const at::Tensor &cu_seqlens_k,  // b+1
               std::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
               std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
               std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
               int max_seqlen_q,
               const int max_seqlen_k,
               const float p_dropout,
               const float softmax_scale,
               const bool zero_tensors,
               bool is_causal,
               int window_size_left,
               int window_size_right,
               const bool return_softmax,
               std::optional<at::Generator> gen_) {

  TORCH_CHECK(false, "mha_varlen_fwd not supported on ROCm");

  at::Tensor softmax_lse = at::empty({}, at::dtype(at::kFloat));
  at::Tensor p = at::empty({}, at::dtype(at::kFloat));
  at::Tensor offset_t = at::empty({}, at::dtype(at::kLong));
  at::Tensor seed_t = at::empty({}, at::dtype(at::kLong));
  at::Tensor out = at::empty({}, at::dtype(at::kFloat));

  return {out, q, k, v, softmax_lse, seed_t, offset_t, p};
}

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads x head_size_og
        const at::Tensor &q,   // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &k,   // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &v,   // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &out,   // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &softmax_lse,     // b x h x seqlen_q
        std::optional<at::Tensor> &dq_,   // batch_size x seqlen_q x num_heads x head_size
        std::optional<at::Tensor> &dk_,   // batch_size x seqlen_k x num_heads_k x head_size
        std::optional<at::Tensor> &dv_,   // batch_size x seqlen_k x num_heads_k x head_size
        std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
        const float p_dropout,         // probability to drop
        const float softmax_scale,
        const bool is_causal,
        int window_size_left,
        int window_size_right,
        const bool deterministic,
        const at::Tensor philox_seed,
        const at::Tensor philox_offset) {
  auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
  check_gpu_arch(stream);

  bool is_dropout = p_dropout > 0.0;

  auto q_dtype = q.dtype();
  TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
              "FlashAttention only supports fp16 and bf16 data types");
  TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
  TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");

  CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);

  TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
  TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");

  const auto sizes = q.sizes();

  const int batch_size = sizes[0];
  const int seqlen_q = sizes[1];
  const int num_heads = sizes[2];
  const int head_size_og = dout.size(3);
  const int head_size = sizes[3];
  const int seqlen_k = k.size(1);
  const int num_heads_k = k.size(2);
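  // head_size_og is read from dout (the incoming gradient) while head_size comes
  // from q; the checks below require both to be multiples of 8 and
  // head_size == round_multiple(head_size_og, 8), so the two must agree.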

  if (is_causal){
    TORCH_CHECK((seqlen_q == seqlen_k), "For backwards kernel seqlen_q must equal seqlen_k for causal kernels");
  }

  TORCH_CHECK(batch_size > 0, "batch size must be positive");
  TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
  TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!");
  TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
  TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

  auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  const int head_size_rounded = round_multiple(head_size, 32);
  const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
  const int seqlen_k_rounded = round_multiple(seqlen_k, 128);

  TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");

  CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
  CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
  CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
  CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
  CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og);

  at::Tensor dq, dk, dv;
  if (dq_.has_value()) {
    dq = dq_.value();
    TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
    CHECK_DEVICE(dq);
    TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
    CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
  } else {
    dq = at::empty_like(q);
  }
  if (dk_.has_value()) {
    dk = dk_.value();
    TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
    CHECK_DEVICE(dk);
    TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
    CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
  } else {
    dk = at::empty_like(k);
  }
  if (dv_.has_value()) {
    dv = dv_.value();
    TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
    CHECK_DEVICE(dv);
    TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
    CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
  } else {
    dv = at::empty_like(v);
  }

  // const at::Tensor& dout_padded = dout;

  // Otherwise the kernel will be launched from cuda:0 device
  // Cast to char to avoid compiler warning about narrowing
  at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};

  auto opts = q.options();
  auto softmax_d = at::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));

  at::Tensor dk_expanded, dv_expanded;
  if (num_heads_k != num_heads) {  // MQA / GQA
    dk_expanded = at::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
    dv_expanded = at::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
  } else {
    dk_expanded = dk;
    dv_expanded = dv;
  }
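  // For MQA/GQA (num_heads_k < num_heads), per-query-head buffers are allocated
  // so that the group dimension can be summed back into dk/dv after the kernel;
  // otherwise dk_expanded/dv_expanded simply alias dk/dv.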

  at::PhiloxCudaState philox_args;
  if (p_dropout > 0.0) {
    if (at::cuda::currentStreamCaptureStatus() ==
        at::cuda::CaptureStatus::None)
    {
      philox_args = at::PhiloxCudaState(*philox_seed.data_ptr<int64_t>(), *philox_offset.data_ptr<int64_t>());
    } else { // dropout + capture
      philox_args = at::PhiloxCudaState(philox_seed.data_ptr<int64_t>(), philox_offset.data_ptr<int64_t>(), 0);
    }
  }

  at::Tensor q_t = q.permute({0,2,1,3});
  at::Tensor k_t = k.permute({0,2,1,3});
  at::Tensor v_t = v.permute({0,2,1,3});
  at::Tensor out_t = out.permute({0,2,1,3});
  at::Tensor dq_t = dq.permute({0,2,1,3});
  at::Tensor dk_t = dk.permute({0,2,1,3});
  at::Tensor dv_t = dv.permute({0,2,1,3});
  at::Tensor dout_t = dout.permute({0,2,1,3});

  at::Tensor softmax_lse_cont = softmax_lse.contiguous();
  at::Tensor delta = at::empty_like(softmax_lse).contiguous();
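  // delta shares softmax_lse's shape and is passed to attn_bwd as a scratch
  // output; in the FlashAttention backward formulation it corresponds to the
  // per-row term D = rowsum(dout * out).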

  int d_head = head_size_og;
  hipError_t err; // TODO: Error handling
  {
    using aotriton::v2::flash::attn_bwd;
    using sdp::aotriton_adapter::mk_aotensor;
    using sdp::aotriton_adapter::mk_aoscalartensor;
    using sdp::aotriton_adapter::cast_dtype;
    aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
    err = attn_bwd(mk_aotensor(q_t, "q"),
                   mk_aotensor(k_t, "k"),
                   mk_aotensor(v_t, "v"),
                   empty_bias,
                   softmax_scale,
                   mk_aotensor(out_t, "out"),
                   mk_aotensor(dout_t, "dout"),
                   mk_aotensor(dq_t, "dq"),
                   mk_aotensor(dk_t, "dk"),
                   mk_aotensor(dv_t, "dv"),
                   empty_bias,
                   mk_aotensor<2>(softmax_lse_cont, "L"),
                   mk_aotensor<2>(delta, "delta"),
                   p_dropout,
                   mk_aoscalartensor(philox_seed),
                   mk_aoscalartensor(philox_offset),
                   0,
                   is_causal,
                   stream);
  }

  // For MQA/GQA we need to sum dK and dV across the groups
  if (num_heads_k != num_heads) {
    at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
    at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
  }
  return { dq, dk, dv, softmax_d };
#undef CALL_BWD_DROPOUT
#undef CALL_BWD
}

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads x head_size
               const at::Tensor &q,   // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &k,   // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &v,   // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &out,   // total_q x num_heads x head_size
               const at::Tensor &softmax_lse,     // b x h x s   softmax logsumexp
               std::optional<at::Tensor> &dq_,   // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               std::optional<at::Tensor> &dk_,   // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               std::optional<at::Tensor> &dv_,   // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &cu_seqlens_q,  // b+1
               const at::Tensor &cu_seqlens_k,  // b+1
               std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
               const int max_seqlen_q,
               const int max_seqlen_k,          // max sequence length to choose the kernel
               const float p_dropout,         // probability to drop
               const float softmax_scale,
               const bool zero_tensors,
               const bool is_causal,
               int window_size_left,
               int window_size_right,
               const bool deterministic,
               const at::Tensor philox_seed,
               const at::Tensor philox_offset) {
  TORCH_CHECK(false, "mha_varlen_bwd not supported on ROCm");

  at::Tensor softmax_d = at::empty({}, at::dtype(at::kFloat));

  return { q, k, v, softmax_d };
}
} // namespace pytorch_flash

#endif