/******************************************************************************
 * Copyright (c) 2023, Advanced Micro Devices, Inc.
 * Copyright (c) 2022, Tri Dao.
 * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/
#include <c10/core/ScalarType.h>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS

#include <cstdint>
#include <tuple>

#include <ATen/ops/zeros.h>

#ifdef USE_FLASH_ATTENTION
#include <ATen/core/Tensor.h>
#include <ATen/hip/HIPContext.h>
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include <ATen/hip/HIPGraphsUtils.cuh>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/reshape.h>
#include <ATen/ops/scalar_tensor.h>
#include <ATen/ops/sum.h>
#include <ATen/ops/slice.h>
#include <ATen/ops/narrow.h>
#include <ATen/ops/pad.h>
#endif

#include <ATen/native/transformers/hip/aotriton_adapter.h>
#include <ATen/native/transformers/hip/flash_attn/flash_api.h>

#include <c10/util/Exception.h>
#include <c10/util/CallOnce.h>

// AOTriton headers
#include <aotriton/flash.h>
#include <aotriton/runtime.h>

namespace pytorch_flash {

namespace {

void check_gpu_arch(hipStream_t stream) {
  auto ret = aotriton::v2::flash::check_gpu(stream);
  if (hipSuccess != ret) {
    TORCH_CHECK(false,
                "[AOTriton] Accelerated SDPA only supports MI200/MI300X/Navi31 GPUs"
                " (gfx90a:sramecc+:xnack-/gfx942:sramecc+:xnack-/gfx1100)");
  }
}

}

#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == at::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
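
// mha_fwd: dense (non-varlen) FlashAttention forward for the ROCm/AOTriton
// backend, matching the interface declared in flash_attn/flash_api.h.
// q/k/v arrive as (batch, seqlen, num_heads, head_size) and are permuted to
// (batch, num_heads, seqlen, head_size) views before the AOTriton kernel runs.
// The returned tuple is
//   {out, q_padded, k_padded, v_padded, softmax_lse, seed, offset, encoded_softmax},
// where seed/offset capture the philox RNG state that mha_bwd consumes again.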

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
mha_fwd(const at::Tensor &q,       // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &k,       // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &v,       // batch_size x seqlen_k x num_heads_k x head_size
        std::optional<at::Tensor> &out_,          // batch_size x seqlen_q x num_heads x head_size
        std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
        const float p_dropout,
        const float softmax_scale,
        bool is_causal,
        int window_size_left,
        int window_size_right,
        const bool return_softmax,
        std::optional<at::Generator> gen_) {
  auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
  check_gpu_arch(stream);

  auto q_dtype = q.dtype();
  TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
              "FlashAttention only supports fp16 and bf16 data types");
  TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");

  CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);

  // FIXME: ROCM probably does not need this
  TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");

  const auto sizes = q.sizes();

  const int batch_size = sizes[0];
  int seqlen_q = sizes[1];
  int num_heads = sizes[2];
  const int head_size_og = sizes[3];
  const int seqlen_k = k.size(1);
  const int num_heads_k = k.size(2);
  TORCH_CHECK(batch_size > 0, "batch size must be positive");
  TORCH_CHECK(head_size_og % 8 == 0, "head_size must be a multiple of 8, this is ensured by padding!");
  TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
  TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

  if (seqlen_q == 1) { is_causal = false; }  // causal=true is the same as causal=false in this case
  if (is_causal) { window_size_right = 0; }

  CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
  CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
  CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);

  at::Tensor q_padded, k_padded, v_padded;
  q_padded = q;
  k_padded = k;
  v_padded = v;

  at::Tensor out;
  if (out_.has_value()) {
    out = out_.value();
    TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
    CHECK_DEVICE(out);
    TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
    CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
    if (head_size_og % 8 != 0) { out = at::empty_like(q_padded); }
  } else {
    out = at::empty_like(q_padded);
  }

  auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  const int head_size = round_multiple(head_size_og, 8);
  const int head_size_rounded = round_multiple(head_size, 32);
  const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
  const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
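  // A quick arithmetic check of round_multiple: a head_size_og of 40 stays 40
  // when rounded to a multiple of 8, becomes 64 when rounded to a multiple of
  // 32, and a sequence length of 1000 rounds up to 1024 at granularity 128.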

  // Otherwise the kernel will be launched from cuda:0 device
  // Cast to char to avoid compiler warning about narrowing
  at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};

  // We want to checkpoint and save the RNG state for backward if dropout
  // We get the default generator and return the seed and offset which will
  // be used in the backward function
  auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
  at::Tensor seed_t, offset_t;

  at::PhiloxCudaState philox_state;
  bool use_philox_state = false;
  if (p_dropout > 0.0) {
    // number of times random will be generated per thread, to offset philox counter in thc random
    // state
    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
    int64_t counter_offset = batch_size * num_heads * 32;
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    philox_state = gen->philox_cuda_state(counter_offset);
    if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) {
      auto [seed, offset] = at::cuda::philox::unpack(philox_state);
      seed_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(seed)), at::dtype(at::kLong).device(at::kCUDA));
      offset_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(offset)), at::dtype(at::kLong).device(at::kCUDA));
    } else {
      // See Note [CUDA Graph-safe RNG states] about the design
      use_philox_state = true;
      seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
      offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
    }
  } else {
    if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) {
      seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
      offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
    } else {
      seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
      offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
    }
  }

  at::PhiloxCudaState philox_args;
  if (p_dropout > 0.0) {
    if (at::cuda::currentStreamCaptureStatus() ==
        at::cuda::CaptureStatus::None)
    {
      philox_args = at::PhiloxCudaState(*seed_t.data_ptr<int64_t>(), *offset_t.data_ptr<int64_t>());
    } else { // dropout + capture
      philox_args = at::PhiloxCudaState(seed_t.data_ptr<int64_t>(), offset_t.data_ptr<int64_t>(), 0);
    }
  }

  // Transpose tensors to meet AOTriton's Flash API
  at::Tensor q_t = q_padded.permute({0,2,1,3});
  at::Tensor k_t = k_padded.permute({0,2,1,3});
  at::Tensor v_t = v_padded.permute({0,2,1,3});
  at::Tensor output_t = out.permute({0,2,1,3});

  at::Tensor M = at::empty({batch_size * num_heads, seqlen_q}, at::dtype(at::kFloat).device(q.device())); // aka softmax_lse

  at::Tensor softmax_fa_t;
  if (return_softmax) {
    softmax_fa_t = at::empty({batch_size, num_heads, seqlen_q, seqlen_k},
                             at::dtype(q.dtype()).device(q.device()));
  } else {
    softmax_fa_t = at::empty({ 0, 0, 0, 0 }, at::dtype(q.dtype()).device(q.device()));
  }

  hipError_t err; // TODO: Error handling
  using aotriton::v2::flash::attn_fwd;
  using aotriton::TensorView;
  using sdp::aotriton_adapter::mk_aotensor;
  using sdp::aotriton_adapter::mk_aoscalartensor;
  using sdp::aotriton_adapter::mk_philoxtensor;
  using sdp::aotriton_adapter::cast_dtype;
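  // The seed/offset arguments built below follow two paths: in eager mode
  // seed_t/offset_t already hold the unpacked philox values and are passed as
  // device scalars, while under CUDA graph capture (use_philox_state) the
  // kernel receives pointers into the generator's capture state and writes the
  // seed/offset it actually used back into seed_t/offset_t (seed_output /
  // offset_output), per Note [CUDA Graph-safe RNG states].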
  aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
  auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t);
  auto offset1 = use_philox_state ? mk_philoxtensor(philox_state.offset_.ptr) : mk_aoscalartensor(offset_t);
  auto offset2 = use_philox_state ? philox_state.offset_intragraph_ : 0;
  auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr<int64_t>()) : mk_philoxtensor(nullptr);
  auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr<int64_t>()) : mk_philoxtensor(nullptr);
  err = attn_fwd(mk_aotensor(q_t, "q"),
                 mk_aotensor(k_t, "k"),
                 mk_aotensor(v_t, "v"),
                 empty_bias,
                 softmax_scale,
                 mk_aotensor<2>(M, "M"),
                 mk_aotensor(output_t, "Out"),
                 p_dropout,
                 seed,
                 offset1,
                 offset2,
                 seed_output,
                 offset_output,
                 mk_aotensor(softmax_fa_t, "encoded_softmax"),
                 is_causal,
                 stream);

  return {out, q_padded, k_padded, v_padded, M, seed_t, offset_t, softmax_fa_t};
}

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
mha_varlen_fwd(const at::Tensor &q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               std::optional<at::Tensor> &out_,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &cu_seqlens_q,   // b+1
               const at::Tensor &cu_seqlens_k,   // b+1
               std::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
               std::optional<at::Tensor> &block_table_,  // batch_size x max_num_blocks_per_seq
               std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
               int max_seqlen_q,
               const int max_seqlen_k,
               const float p_dropout,
               const float softmax_scale,
               const bool zero_tensors,
               bool is_causal,
               int window_size_left,
               int window_size_right,
               const bool return_softmax,
               std::optional<at::Generator> gen_) {

  TORCH_CHECK(false, "mha_varlen_fwd not supported on ROCm");

  at::Tensor softmax_lse = at::empty({}, at::dtype(at::kFloat));
  at::Tensor p = at::empty({}, at::dtype(at::kFloat));
  at::Tensor offset_t = at::empty({}, at::dtype(at::kLong));
  at::Tensor seed_t = at::empty({}, at::dtype(at::kLong));
  at::Tensor out = at::empty({}, at::dtype(at::kFloat));

  return {out, q, k, v, softmax_lse, seed_t, offset_t, p};
}
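
// mha_bwd: dense FlashAttention backward for the ROCm/AOTriton backend.
// philox_seed/philox_offset are the seed/offset tensors returned by mha_fwd,
// so the backward kernel can regenerate the same dropout pattern as the
// forward pass.  Inputs are permuted to (batch, num_heads, seqlen, head_size)
// views before calling attn_bwd, and the gradients are returned as
// {dq, dk, dv, softmax_d}.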
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads x head_size_og
        const at::Tensor &q,     // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &k,     // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &v,     // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &out,   // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &softmax_lse,  // b x h x seqlen_q
        std::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
        std::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
        std::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
        std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
        const float p_dropout,   // probability to drop
        const float softmax_scale,
        const bool is_causal,
        int window_size_left,
        int window_size_right,
        const bool deterministic,
        const at::Tensor philox_seed,
        const at::Tensor philox_offset) {
  auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
  check_gpu_arch(stream);

  bool is_dropout = p_dropout > 0.0;

  auto q_dtype = q.dtype();
  TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
              "FlashAttention only supports fp16 and bf16 data types");
  TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
  TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");

  CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);

  TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
  TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");

  const auto sizes = q.sizes();

  const int batch_size = sizes[0];
  const int seqlen_q = sizes[1];
  const int num_heads = sizes[2];
  const int head_size_og = dout.size(3);
  const int head_size = sizes[3];
  const int seqlen_k = k.size(1);
  const int num_heads_k = k.size(2);

  if (is_causal){
    TORCH_CHECK((seqlen_q == seqlen_k), "For backwards kernel seqlen_q must equal seqlen_k for causal kernels");
  }

  TORCH_CHECK(batch_size > 0, "batch size must be positive");
  TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
  TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!");
  TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
  TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

  auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  const int head_size_rounded = round_multiple(head_size, 32);
  const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
  const int seqlen_k_rounded = round_multiple(seqlen_k, 128);

  TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");

  CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
  CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
  CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
  CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
  CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og);

  at::Tensor dq, dk, dv;
  if (dq_.has_value()) {
    dq = dq_.value();
    TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
    CHECK_DEVICE(dq);
    TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
    CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
  } else {
    dq = at::empty_like(q);
  }
  if (dk_.has_value()) {
    dk = dk_.value();
    TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
    CHECK_DEVICE(dk);
    TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
    CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
  } else {
    dk = at::empty_like(k);
  }
  if (dv_.has_value()) {
    dv = dv_.value();
    TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
    CHECK_DEVICE(dv);
    TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
    CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
  } else {
    dv = at::empty_like(v);
  }

  // const at::Tensor& dout_padded = dout;

  // Otherwise the kernel will be launched from cuda:0 device
  // Cast to char to avoid compiler warning about narrowing
  at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};

  auto opts = q.options();
  auto softmax_d = at::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));

  at::Tensor dk_expanded, dv_expanded;
  if (num_heads_k != num_heads) {  // MQA / GQA
    dk_expanded = at::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
    dv_expanded = at::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
  } else {
    dk_expanded = dk;
    dv_expanded = dv;
  }

  at::PhiloxCudaState philox_args;
  if (p_dropout > 0.0) {
    if (at::cuda::currentStreamCaptureStatus() ==
        at::cuda::CaptureStatus::None)
    {
      philox_args = at::PhiloxCudaState(*philox_seed.data_ptr<int64_t>(), *philox_offset.data_ptr<int64_t>());
    } else { // dropout + capture
      philox_args = at::PhiloxCudaState(philox_seed.data_ptr<int64_t>(), philox_offset.data_ptr<int64_t>(), 0);
    }
  }

  at::Tensor q_t = q.permute({0,2,1,3});
  at::Tensor k_t = k.permute({0,2,1,3});
  at::Tensor v_t = v.permute({0,2,1,3});
  at::Tensor out_t = out.permute({0,2,1,3});
  at::Tensor dq_t = dq.permute({0,2,1,3});
  at::Tensor dk_t = dk.permute({0,2,1,3});
  at::Tensor dv_t = dv.permute({0,2,1,3});
  at::Tensor dout_t = dout.permute({0,2,1,3});

  at::Tensor softmax_lse_cont = softmax_lse.contiguous();
  at::Tensor delta = at::empty_like(softmax_lse).contiguous();

  int d_head = head_size_og;
  hipError_t err; // TODO: Error handling
  {
    using aotriton::v2::flash::attn_bwd;
    using sdp::aotriton_adapter::mk_aotensor;
    using sdp::aotriton_adapter::mk_aoscalartensor;
    using sdp::aotriton_adapter::cast_dtype;
    aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
    err = attn_bwd(mk_aotensor(q_t, "q"),
                   mk_aotensor(k_t, "k"),
                   mk_aotensor(v_t, "v"),
                   empty_bias,
                   softmax_scale,
                   mk_aotensor(out_t, "out"),
                   mk_aotensor(dout_t, "dout"),
                   mk_aotensor(dq_t, "dq"),
                   mk_aotensor(dk_t, "dk"),
                   mk_aotensor(dv_t, "dv"),
                   empty_bias,
                   mk_aotensor<2>(softmax_lse_cont, "L"),
                   mk_aotensor<2>(delta, "delta"),
                   p_dropout,
                   mk_aoscalartensor(philox_seed),
                   mk_aoscalartensor(philox_offset),
                   0,
                   is_causal,
                   stream);
  }

  // For MQA/GQA we need to sum dK and dV across the groups
  if (num_heads_k != num_heads) {
    at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
    at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
  }
  return { dq, dk, dv, softmax_d };
#undef CALL_BWD_DROPOUT
#undef CALL_BWD
}
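
// Varlen (ragged-batch) attention is not implemented for the ROCm/AOTriton
// backend: mha_varlen_fwd above and mha_varlen_bwd below unconditionally fail
// via TORCH_CHECK(false, ...), so the placeholder tensors they build are never
// actually returned to callers.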
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads x head_size
               const at::Tensor &q,     // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &k,     // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &v,     // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &out,   // total_q x num_heads x head_size
               const at::Tensor &softmax_lse,  // b x h x s softmax logsumexp
               std::optional<at::Tensor> &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               std::optional<at::Tensor> &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               std::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &cu_seqlens_q, // b+1
               const at::Tensor &cu_seqlens_k, // b+1
               std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
               const int max_seqlen_q,
               const int max_seqlen_k, // max sequence length to choose the kernel
               const float p_dropout,  // probability to drop
               const float softmax_scale,
               const bool zero_tensors,
               const bool is_causal,
               int window_size_left,
               int window_size_right,
               const bool deterministic,
               const at::Tensor philox_seed,
               const at::Tensor philox_offset) {
  TORCH_CHECK(false, "mha_varlen_bwd not supported on ROCm");

  at::Tensor softmax_d = at::empty({}, at::dtype(at::kFloat));

  return { q, k, v, softmax_d };
}
} // namespace pytorch_flash

#endif