1# Owner(s): ["module: nn"] 2 3import contextlib 4from functools import partial 5from collections import namedtuple 6import sys 7import torch 8import torch.nn as nn 9import torch.nn.functional as F 10from torch.nn.functional import scaled_dot_product_attention 11from torch.nn.attention import sdpa_kernel, SDPBackend 12from torch.nn.attention.bias import CausalVariant, causal_lower_right, causal_upper_left 13from torch.nn.parameter import Parameter 14import unittest 15from unittest.mock import patch, MagicMock, ANY 16import math 17import torch.optim as optim 18from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCUDA, onlyCPU 19from typing import List, Tuple, Optional 20from torch.testing._internal.common_nn import NNTestCase 21from torch.testing._internal.common_utils import ( 22 TEST_WITH_ROCM, 23 skipIfRocm, 24 skipIfTorchDynamo, 25 TEST_FAIRSEQ, 26 run_tests, 27 parametrize, 28 freeze_rng_state, 29 TEST_WITH_CROSSREF, 30 slowTest, 31 set_default_dtype, 32 gradcheck, 33 make_tensor, 34 NOTEST_CPU, 35 IS_WINDOWS, 36 TEST_WITH_TORCHDYNAMO, 37) 38from torch._dynamo.testing import CompileCounterWithBackend 39 40 41from torch.testing._internal.common_methods_invocations import wrapper_set_seed 42from torch.testing._internal.common_cuda import ( 43 IS_JETSON, SM80OrLater, PLATFORM_SUPPORTS_FLASH_ATTENTION, 44 PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, 45 PLATFORM_SUPPORTS_FUSED_ATTENTION, 46 PLATFORM_SUPPORTS_CUDNN_ATTENTION 47) 48 49if TEST_FAIRSEQ: 50 import fairseq.models.transformer as fairseq_transformer 51 52SdpaShape = namedtuple('Sdpa_Shape', ['batch', 'num_heads', 'seq_len', 'head_dim']) 53Tolerances = namedtuple('Tolerances', ['atol', 'rtol']) 54 55@contextlib.contextmanager 56def use_deterministic_algorithims(mode: bool, warn_only: bool): 57 r""" 58 This context manager can be used to temporarily enable or disable deterministic algorithms. 59 Upon exiting the context manager, the previous state of the flag will be restored. 
    """
    previous_mode: bool = torch.are_deterministic_algorithms_enabled()
    previous_warn_only: bool = torch.is_deterministic_algorithms_warn_only_enabled()
    try:
        torch.use_deterministic_algorithms(mode, warn_only=warn_only)
        yield {}
    finally:
        torch.use_deterministic_algorithms(previous_mode, warn_only=previous_warn_only)


# Found in torch/testing/_comparison.py
default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float32: 1e-5}
default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float32: 1.3e-6}

isSM8XDevice = torch.cuda.is_available() and torch.cuda.get_device_capability() in [(8, 6), (8, 7), (8, 9)]
isSM90Device = torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0)
isSM5xDevice = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 5
isLessThanSM80Device = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8

def get_rtol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
    deviation = true_value - computed_value
    deviation = torch.abs(deviation / true_value)
    # Fill in the nans with the default rtol
    torch.nan_to_num_(deviation, nan=default_rtol[computed_value.dtype])
    return deviation.max().item()


def get_atol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
    deviation = true_value - computed_value
    atol = torch.abs(deviation).max().item()
    return atol


def get_tolerances(
    true_value: torch.Tensor,
    computed_value: torch.Tensor,
    fudge_factor: Optional[float] = None,
) -> Tuple[float, float]:
    """Returns the absolute and relative tolerances for comparing two tensors."""
    fudge_factor = fudge_factor if fudge_factor is not None else 1.0
    atol = get_atol(true_value, computed_value)
    rtol = get_rtol(true_value, computed_value)

    atol = fudge_factor * max(atol, default_atol[computed_value.dtype])
    rtol = fudge_factor * max(rtol, default_rtol[computed_value.dtype])
    # torch.isclose() has weird behavior around very large rtol values, see:
    # https://github.com/pytorch/pytorch/issues/102400
    if rtol > 1e30:
        rtol = default_rtol[computed_value.dtype]
    return atol, rtol


def query_key_value_clones(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, dtype: Optional[torch.dtype] = None):
    """Clones the query, key, and value tensors and moves them to the specified dtype."""
    if dtype is None:
        dtype = query.dtype
    query_ref = query.clone().detach().to(dtype).requires_grad_(query.requires_grad)
    key_ref = key.clone().detach().to(dtype).requires_grad_(key.requires_grad)
    value_ref = value.clone().detach().to(dtype).requires_grad_(value.requires_grad)
    return query_ref, key_ref, value_ref

def get_platform_specific_sdpa():
    ret = []
    if PLATFORM_SUPPORTS_FLASH_ATTENTION:
        ret.append(SDPBackend.FLASH_ATTENTION)
    if PLATFORM_SUPPORTS_MEM_EFF_ATTENTION:
        ret.append(SDPBackend.EFFICIENT_ATTENTION)
    if PLATFORM_SUPPORTS_CUDNN_ATTENTION:
        ret.append(SDPBackend.CUDNN_ATTENTION)
    if not ret:
        # Add a placeholder; an empty list causes "An empty arg_values was passed to @parametrize"
        ret.append(SDPBackend.EFFICIENT_ATTENTION)
    return ret

PLATFORM_SPECIFIC_SDPA = get_platform_specific_sdpa()
# Indicates that the Efficient attention backend can support:
# 1. sequences longer than 512
# 2. head dimensions larger than 64
MEM_EFF_CAPABILITY_MATCHES_SM80 = SM80OrLater or TEST_WITH_ROCM

def rand_sdpa_tensor(shape: SdpaShape, device: str, dtype: torch.dtype, type: str,
                     requires_grad: bool = False, packed: bool = False) -> torch.Tensor:
    """Creates a random dense or nested tensor with the given shape and type.

    Args:
        shape (Tuple[int]): Shape of Tensor to construct
        device (str): which device to create tensor on
        dtype (torch.dtype): Tensors' dtype
        type (str): Nested or Dense
        requires_grad (bool, optional): Tensors grad status. Defaults to False.
        packed (bool, optional): Whether to create a single QKV packed or not. Defaults to False.

    Returns:
        torch.Tensor: A new tensor
    """
    batch, num_heads, seq_len, head_dim = shape.batch, shape.num_heads, shape.seq_len, shape.head_dim
    if type == "nested":
        if isinstance(seq_len, list):
            def _size(i):
                return (seq_len[i], num_heads, head_dim) if not packed else (seq_len[i], 3 * num_heads * head_dim)

            return torch.nested.nested_tensor([
                torch.randn(_size(i), device=device, dtype=dtype, requires_grad=requires_grad)
                for i in range(batch)])
        else:
            size = (seq_len, num_heads, head_dim) if not packed else (seq_len, 3 * num_heads * head_dim)
            return torch.nested.nested_tensor([
                torch.randn(size, device=device, dtype=dtype, requires_grad=requires_grad)
                for _ in range(batch)])
    else:
        assert isinstance(seq_len, int)
        size = (batch, seq_len, num_heads, head_dim) if not packed else (batch, seq_len, 3 * num_heads * head_dim)
        return torch.randn(size, device=device, dtype=dtype, requires_grad=requires_grad)

def calculate_nt_tolerances(nt_ref_hp, nt_ref_lp, default_dtype, fudge_factor=1):
    # TODO: use NT ops once Max is implemented for NestedTensor instead of unrolling
    ref_atol = default_atol[default_dtype]
    ref_rtol = default_rtol[default_dtype]
    for tensor_component_ref, tensor_component_ref_lp in zip(nt_ref_hp.unbind(), nt_ref_lp.unbind()):
        ref_atol = max((fudge_factor * torch.abs(tensor_component_ref - tensor_component_ref_lp)).max().item(), ref_atol)
        ref_rtol = max(get_rtol(tensor_component_ref, tensor_component_ref_lp), ref_rtol)
    return ref_atol, ref_rtol

class TestTransformers(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    @onlyCUDA
    @unittest.skip("4D mask not supported yet - activate when 4D mask supported")
    def test_self_attn_TxT_attn_mask(self, device):
        embed_dim = 16
        num_heads = 4
        batch_size = 10
        tgt_len = 16

        query = torch.rand(batch_size, tgt_len, embed_dim, device=device)  # [N, T, D]
        attn_mask = torch.randint(0, 2, (tgt_len, tgt_len)).cuda().float()  # [T, T]
        attn_mask = attn_mask.masked_fill(attn_mask == 0, float('-inf')).masked_fill(attn_mask == 1, 0.0)

        attn_mask_4d = attn_mask.expand(batch_size, num_heads, tgt_len, tgt_len)

        mta_model = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).cuda()
        mta_model.eval()

        # Generate 3D results
        with torch.inference_mode():
            output_mask_4d = mta_model(query, query, query, attn_mask=attn_mask_4d)[0]
            output_mask_4d = output_mask_4d.transpose(0, 1)  # [N, T, D]

            output_mask_TxT = mta_model(query, query, query, attn_mask=attn_mask)[0]
            output_mask_TxT = output_mask_TxT.transpose(0, 1)  # [N, T, D]

            self.assertEqual(output_mask_4d, output_mask_TxT)

    @slowTest
    def
test_train_with_pad_and_catch_error(self, device): 216 iters = 100 217 pad_mask = torch.tensor([[1, 1, 0, 0]], dtype=torch.bool).to(device) 218 layer = nn.TransformerEncoderLayer( 219 d_model=2, 220 dim_feedforward=4, 221 nhead=2, 222 batch_first=True, 223 activation="gelu", 224 dropout=0, 225 ) 226 criterion = nn.MSELoss() 227 encoder = nn.TransformerEncoder(layer, 2).to(device) 228 optimizer = optim.SGD(encoder.parameters(), lr=0.1, momentum=0.9) 229 encoder.train() 230 for i in range(iters): 231 encoder.train() 232 optimizer.zero_grad() 233 inputs = torch.cat([torch.randn(1, 2, 2), torch.zeros(1, 2, 2)], dim=1).to(device) 234 235 outputs = encoder(inputs, src_key_padding_mask=pad_mask) 236 237 loss = criterion(outputs[:, 0:2, :], inputs[:, 0:2, :]) 238 loss.backward() 239 optimizer.step() 240 241 with torch.no_grad(): 242 test = torch.cat([torch.randn(1, 2, 2), torch.zeros(1, 2, 2)], dim=1).to(device) 243 244 # Expect uint8 type not supported 245 ex = None 246 try: 247 test_train_uint8 = encoder(test, src_key_padding_mask=pad_mask.to(torch.uint8)) 248 except AssertionError as e: 249 continue 250 self.assertFalse(e, "Failed to catch unsupported uint8 type exception") # noqa: F821 251 252 test_train_bool = encoder(test, src_key_padding_mask=pad_mask) 253 encoder.eval() 254 255 # Expect long type not supported 256 ex = None 257 try: 258 test_eval_uint8 = encoder(test, src_key_padding_mask=pad_mask.to(torch.int64)) 259 except AssertionError as e: 260 continue 261 self.assertFalse(e, "Failed to catch unsupported Long type exception") # noqa: F821 262 263 test_eval_bool = encoder(test, src_key_padding_mask=pad_mask) 264 l1_bool = nn.L1Loss()(test_train_bool[:, 0:2, :], test_eval_bool[:, 0:2, :]).item() 265 self.assertTrue(l1_bool < 1e-4, "Eval/Train difference in pad_mask BOOL") 266 267 @parametrize("attn_mask_dim", [2, 3, None]) 268 @parametrize("key_padding_mask_dim", [2, None]) 269 @parametrize("mask_dtype", [torch.bool, torch.float32]) 270 def test_multiheadattention_fastpath_attn_mask(self, device, attn_mask_dim, key_padding_mask_dim, mask_dtype): 271 with torch.no_grad(): 272 B = 2 273 L = 4 274 D = 8 275 H = 4 276 277 if attn_mask_dim == 2: 278 attn_mask = make_tensor((L, L), dtype=mask_dtype, device=device) 279 elif attn_mask_dim == 3: 280 attn_mask = make_tensor((B * H, L, L), dtype=mask_dtype, device=device) 281 elif attn_mask_dim is None: 282 attn_mask = None 283 284 if key_padding_mask_dim == 2: 285 key_padding_mask = make_tensor((B, L), dtype=mask_dtype, device=device) 286 elif key_padding_mask_dim is None: 287 key_padding_mask = None 288 289 mha = nn.MultiheadAttention(D, H, batch_first=True, device=device) 290 X = torch.randn(B, L, D, device=device) 291 292 mha.train() # disable fast path 293 out, _ = mha(X, X, X, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False) 294 mha.eval() # enable fast path 295 out_fp, _ = mha(X, X, X, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False) 296 self.assertEqual(out, out_fp) 297 298 @parametrize("nhead", [1, 4, 8]) 299 def test_transformerencoderlayer_src_mask(self, device, nhead): 300 batch_size = 2 301 seqlen = 4 302 d_model = 8 303 dim_feedforward = 32 304 305 model = torch.nn.TransformerEncoderLayer( 306 d_model=d_model, 307 nhead=nhead, 308 dim_feedforward=dim_feedforward, 309 batch_first=True).to(device) 310 src = torch.rand(batch_size, seqlen, d_model).to(device) # bs, seqlen, d_model 311 src_mask = torch.zeros(seqlen, seqlen).to(torch.bool).to(device) 312 313 model(src, 
src_mask=src_mask)
        model.eval()
        with torch.no_grad():
            model(src, src_mask=src_mask)

    @parametrize("use_torchscript", [False])
    @parametrize("enable_nested_tensor", [True, False])
    @parametrize("use_autocast", [True, False])
    @parametrize("d_model", [12, 256])
    def test_transformerencoder_fastpath(self, device, use_torchscript, enable_nested_tensor, use_autocast, d_model):
        """
        Test that the TransformerEncoder fastpath output matches the slowpath output.
        """
        torch.manual_seed(1234)
        nhead = 4
        dim_feedforward = d_model
        batch_first = True

        model = torch.nn.TransformerEncoder(
            torch.nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                batch_first=batch_first),
            num_layers=2,
            enable_nested_tensor=enable_nested_tensor
        ).to(device).eval()

        if use_torchscript:
            model = torch.jit.script(model)

        # each input is (input, mask)
        input_mask_pairs = [
            (
                torch.rand(3, 2, d_model),
                [
                    [0, 1],
                    [0, 1],
                    [1, 1]
                ]
            ),
            (
                torch.rand(2, 100, d_model),
                [
                    [0] * 98 + [1] * 2,
                    [0] * 90 + [1] * 10
                ]
            ),
            # softmax.cu switches from fast->slowpath at masked seqlen 1024. test 1024.
            (
                torch.rand(2, 1024, d_model),
                [
                    [0] * 1020 + [1] * 4,
                    [0] * 1024,
                ]
            ),
            (
                torch.rand(1, 1026, d_model),
                [[0] * 1024 + [1] * 2]
            ),
            # softmax.cu switches from fast->slowpath at masked seqlen 1024. test range of masks above 1024.
            (
                torch.rand(4, 1040, d_model),
                [
                    [0] * 1024 + [1] * 16,
                    [0] * 1025 + [1] * 15,
                    [0] * 1031 + [1] * 9,
                    [0] * 1040,
                ]
            )
        ]
        input_mask_pairs = [
            (
                torch.tensor(pair[0], device=device, dtype=torch.get_default_dtype()),  # float input
                torch.tensor(pair[1], device=device, dtype=torch.bool)  # bool mask
            ) for pair in input_mask_pairs
        ]

        maybe_autocast = torch.autocast("cuda", dtype=torch.float16) if use_autocast else contextlib.nullcontext()
        with maybe_autocast:
            for input, src_key_padding_mask in input_mask_pairs:
                with torch.no_grad():
                    fastpath_output = model(input, src_key_padding_mask=src_key_padding_mask)
                slowpath_output = model(input, src_key_padding_mask=src_key_padding_mask)  # reference
                # Make sure fastpath_output is the same shape as slowpath_output and the mask.
                # When enable_nested_tensor=True, fastpath_output may be smaller than the input tensor.
                # E.g. if the input has bs=1, seqlen=6, and we mask out 2 tokens, fastpath_output will have bs=1, seqlen=4.
                # Expand back to the old size to match.
                bs, true_seqlen, embed_dim = fastpath_output.shape
                expanded_seqlen = src_key_padding_mask.shape[1]
                fastpath_output_expanded = torch.zeros(bs, expanded_seqlen, embed_dim, device=device)
                fastpath_output_expanded[:, :true_seqlen, :] = fastpath_output
                # There are no guarantees on the output for masked tokens, so it may vary between slow/fast path. Set all to 0.
406 fastpath_output_expanded = fastpath_output_expanded.masked_fill(src_key_padding_mask.unsqueeze(-1), 0) 407 slowpath_output = slowpath_output.masked_fill(src_key_padding_mask.unsqueeze(-1), 0) 408 torch.testing.assert_close(fastpath_output_expanded, slowpath_output, rtol=1e-7, atol=1e-5) 409 410 @parametrize("with_no_grad", [True, False]) 411 @parametrize("training", [True, False]) 412 @parametrize("enable_nested_tensor", [False]) 413 def test_transformerencoder_square_input(self, with_no_grad, training, enable_nested_tensor, device): 414 """ 415 Test for edge cases when input of shape (batch size, sequence length, embedding dimension) has 416 batch size == sequence length 417 """ 418 model = torch.nn.TransformerEncoder( 419 torch.nn.TransformerEncoderLayer(d_model=4, nhead=2, dim_feedforward=16, dropout=0.0, batch_first=True), 420 num_layers=2, 421 enable_nested_tensor=enable_nested_tensor 422 ).to(device) 423 424 with torch.no_grad(): 425 # set constant weights of the model 426 for idx, p in enumerate(model.parameters()): 427 x = p.data 428 sz = x.view(-1).size(0) 429 shape = x.shape 430 x = torch.cos(torch.arange(0, sz).float().view(shape)) 431 p.data.copy_(x) 432 433 if training: 434 model = model.train() 435 else: 436 model = model.eval() 437 x = torch.arange(0, 16).reshape(2, 2, 4).to(torch.get_default_dtype()).to(device) 438 src_mask = torch.Tensor([[0, 1], [0, 0]]).to(torch.bool).to(device) 439 440 if with_no_grad: 441 cm = torch.no_grad() 442 else: 443 cm = contextlib.nullcontext() 444 with cm: 445 result = model(x, mask=src_mask) 446 447 ref_output = torch.Tensor([[[2.420306205749512, 0.017629241570830, -0.607857942581177, -0.085519507527351], 448 [2.420306205749512, 0.017629241570830, -0.607857942581177, -0.085519507527351]], 449 [[2.419836044311523, 0.017548924311996, -0.608187675476074, -0.085347734391689], 450 [2.419836044311523, 0.017548924311996, -0.608187675476074, -0.085347734391689]]] 451 ).to(device) 452 self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) 453 torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) 454 455 @parametrize("batch_first", [True, False]) 456 @parametrize("training", [True, False]) 457 @parametrize("enable_nested_tensor", [True, False]) 458 def test_transformerencoder(self, batch_first, training, enable_nested_tensor, device): 459 def get_a_test_layer(activation, batch_first=False): 460 d_model = 4 461 nhead = 2 462 dim_feedforward = 16 463 dropout = 0.0 464 465 layer = nn.TransformerEncoderLayer( 466 d_model, 467 nhead, 468 dim_feedforward=dim_feedforward, 469 dropout=dropout, 470 activation=activation, 471 batch_first=batch_first, 472 ).to(device) 473 474 with torch.no_grad(): 475 # set constant weights of the model 476 for idx, p in enumerate(layer.parameters()): 477 x = p.data 478 sz = x.view(-1).size(0) 479 shape = x.shape 480 x = torch.cos(torch.arange(0, sz).float().view(shape)) 481 p.data.copy_(x) 482 483 return layer 484 485 # this is a deterministic test for TransformerEncoder 486 activation = F.relu 487 488 def _test(batch_first, training, enable_nested_tensor): 489 def perm_fn(x): 490 return x.transpose(1, 0) if batch_first else x 491 492 encoder_layer = get_a_test_layer(activation=activation, 493 batch_first=batch_first) 494 495 model = nn.TransformerEncoder( 496 encoder_layer, 1, enable_nested_tensor=enable_nested_tensor 497 ).to(device) 498 499 if not training: 500 model = model.eval() 501 502 # deterministic input 503 encoder_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891], 504 [0.5387, 
0.1655, 0.3565, 0.0471]], 505 [[0.8335, 0.2799, 0.5031, 0.2947], 506 [0.1402, 0.0318, 0.7636, 0.1346]], 507 [[0.6333, 0.9344, 0.1376, 0.9938], 508 [0.8924, 0.2872, 0.6692, 0.2944]], 509 [[0.9897, 0.6915, 0.3154, 0.1733], 510 [0.8645, 0.3513, 0.3064, 0.0767]], 511 [[0.8117, 0.2366, 0.4838, 0.7881], 512 [0.3718, 0.4945, 0.9511, 0.0864]]] 513 )).to(device) 514 result = model(encoder_input) 515 ref_output = perm_fn(torch.tensor([[[2.428589, 0.020835, -0.602055, -0.085249], 516 [2.427987, 0.021213, -0.602496, -0.084103]], 517 [[2.424689, 0.019155, -0.604793, -0.085672], 518 [2.413863, 0.022211, -0.612486, -0.072490]], 519 [[2.433774, 0.021598, -0.598343, -0.087548], 520 [2.425104, 0.019748, -0.604515, -0.084839]], 521 [[2.436185, 0.022682, -0.596625, -0.087261], 522 [2.433556, 0.021891, -0.598509, -0.086832]], 523 [[2.416246, 0.017512, -0.610712, -0.082961], 524 [2.422901, 0.024187, -0.606178, -0.074929]]] 525 )).to(device) 526 self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) 527 torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) 528 529 # all 0 src_mask 530 src_mask = torch.zeros([5, 5]).to(device) == 1 531 result = model(encoder_input, mask=src_mask) 532 self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) 533 torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) 534 535 # all 0 536 mask = torch.zeros([2, 5]).to(device) == 1 537 result = model(encoder_input, src_key_padding_mask=mask) 538 self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) 539 torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) 540 541 mask[0, 1] = 1 542 mask[1, 3] = 1 543 mask[1, 4] = 1 544 result = model(encoder_input, src_key_padding_mask=mask) 545 ref_output = perm_fn(torch.tensor([[[2.429026, 0.020793, -0.601741, -0.085642], 546 [2.428811, 0.021445, -0.601912, -0.084252]], 547 [[2.425009, 0.019155, -0.604566, -0.085899], 548 [2.415408, 0.02249, -0.611415, -0.073]], 549 [[2.434199, 0.021682, -0.598039, -0.087699], 550 [2.42598, 0.019941, -0.603896, -0.085091]], 551 [[2.436457, 0.022736, -0.59643, -0.08736], 552 [2.434021, 0.022093, -0.598179, -0.08679]], 553 [[2.416531, 0.017498, -0.610513, -0.083181], 554 [2.4242, 0.024653, -0.605266, -0.074959]]] 555 )).to(device) 556 self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) 557 torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) 558 559 # test case 2, multiple layers no norm 560 model = nn.TransformerEncoder(encoder_layer, 2, enable_nested_tensor=enable_nested_tensor).to(device) 561 if not training: 562 model = model.eval() 563 result = model(encoder_input, src_key_padding_mask=mask) 564 ref_output = perm_fn(torch.tensor([[[2.419051, 0.017446, -0.608738, -0.085003], 565 [2.419102, 0.017452, -0.608703, -0.085026]], 566 [[2.419043, 0.017445, -0.608744, -0.084999], 567 [2.419052, 0.017446, -0.608738, -0.085004]], 568 [[2.419067, 0.017448, -0.608727, -0.085010], 569 [2.419098, 0.017452, -0.608706, -0.085024]], 570 [[2.419072, 0.017449, -0.608724, -0.085012], 571 [2.419119, 0.017455, -0.608691, -0.085034]], 572 [[2.419019, 0.017442, -0.608761, -0.084989], 573 [2.419075, 0.017449, -0.608722, -0.085014]]] 574 )).to(device) 575 self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) 576 torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) 577 578 model = nn.TransformerEncoder(encoder_layer, 6, enable_nested_tensor=enable_nested_tensor).to(device) 579 if not training: 580 model = model.eval() 581 result = model(encoder_input, 
src_key_padding_mask=mask) 582 ref_output = perm_fn(torch.tensor([[[2.419101, 0.017453, -0.608703, -0.085025], 583 [2.419101, 0.017453, -0.608704, -0.085025]], 584 [[2.419101, 0.017453, -0.608703, -0.085025], 585 [2.419101, 0.017453, -0.608704, -0.085025]], 586 [[2.419101, 0.017453, -0.608703, -0.085025], 587 [2.419101, 0.017453, -0.608704, -0.085025]], 588 [[2.419101, 0.017453, -0.608703, -0.085025], 589 [2.419101, 0.017453, -0.608704, -0.085025]], 590 [[2.419101, 0.017453, -0.608703, -0.085025], 591 [2.419101, 0.017453, -0.608704, -0.085025]]] 592 )).to(device) 593 self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) 594 torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) 595 596 # test case 3, multiple layers with norm 597 # d_model = 4 598 norm = nn.LayerNorm(4) 599 model = nn.TransformerEncoder(encoder_layer, 2, norm=norm, 600 enable_nested_tensor=enable_nested_tensor).to(device) 601 if not training: 602 model = model.eval() 603 result = model(encoder_input, src_key_padding_mask=mask) 604 ref_output = perm_fn(torch.tensor([[[1.695949, -0.357635, -0.893077, -0.445238], 605 [1.695955, -0.357639, -0.893050, -0.445266]], 606 [[1.695948, -0.357634, -0.893082, -0.445233], 607 [1.695950, -0.357635, -0.893077, -0.445238]], 608 [[1.695951, -0.357636, -0.893069, -0.445246], 609 [1.695955, -0.357639, -0.893052, -0.445264]], 610 [[1.695952, -0.357636, -0.893066, -0.445249], 611 [1.695957, -0.357641, -0.893041, -0.445276]], 612 [[1.695946, -0.357632, -0.893095, -0.445220], 613 [1.695952, -0.357637, -0.893065, -0.445251]]] 614 )).to(device) 615 self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) 616 torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) 617 618 model = nn.TransformerEncoder(encoder_layer, 6, norm=norm, 619 enable_nested_tensor=enable_nested_tensor).to(device) 620 if not training: 621 model = model.eval() 622 result = model(encoder_input, src_key_padding_mask=mask) 623 ref_output = perm_fn(torch.tensor([[[1.695955, -0.357639, -0.893051, -0.445265], 624 [1.695955, -0.357639, -0.893051, -0.445265]], 625 [[1.695955, -0.357639, -0.893051, -0.445265], 626 [1.695955, -0.357639, -0.893051, -0.445265]], 627 [[1.695955, -0.357639, -0.893051, -0.445265], 628 [1.695955, -0.357639, -0.893051, -0.445265]], 629 [[1.695955, -0.357639, -0.893051, -0.445265], 630 [1.695955, -0.357639, -0.893051, -0.445265]], 631 [[1.695955, -0.357639, -0.893051, -0.445265], 632 [1.695955, -0.357639, -0.893051, -0.445265]]] 633 )).to(device) 634 self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) 635 torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) 636 637 # TODO: remove set default dtype to double by making ref_output more precise. 638 # Added because this test was copied from test_nn.py, which has default 639 # dtype double. 
If default dtype is float, tests will say tensors not close because 640 # ref output precision too low 641 with set_default_dtype(torch.double): 642 if training: 643 cm = contextlib.nullcontext() 644 else: 645 cm = torch.no_grad() # transformer fast path requires no grad 646 with cm: 647 _test(batch_first, training, enable_nested_tensor) 648 649 @unittest.skipIf(sys.version_info < (3, 11), "not supported on pre-3.11 Python") 650 def test_encoder_padding_and_src_mask_bool(self): 651 encoder_layer = nn.TransformerEncoderLayer( 652 d_model=16, 653 nhead=2, 654 dim_feedforward=32, 655 dropout=0.1, 656 activation='relu', 657 batch_first=True, 658 ) 659 encoder_norm = nn.LayerNorm(16) 660 encoder = nn.TransformerEncoder( 661 encoder_layer, 2, encoder_norm 662 ) 663 664 inputs = torch.randn(2, 3, 16) 665 666 src_mask = torch.ones(3, 3, dtype=torch.bool).triu_(diagonal=1) 667 input_seq_len = torch.tensor([3, 2]) 668 padding_mask = ( 669 torch.arange(3)[None, :].cpu() >= input_seq_len[:, None] 670 ) 671 672 with (self.assertNoLogs(None) if not TEST_WITH_TORCHDYNAMO else contextlib.nullcontext()): 673 encoder( 674 inputs, 675 mask=src_mask, 676 src_key_padding_mask=padding_mask, 677 ) 678 679 @unittest.skipIf(sys.version_info < (3, 11), "not supported on pre-3.11 Python") 680 def test_decoder_padding_and_src_mask_bool(self): 681 682 def transformer_decoder(inputs, input_seq_len, memory): 683 decoder_layer = nn.TransformerDecoderLayer( 684 d_model=16, 685 nhead=2, 686 dim_feedforward=32, 687 dropout=0.1, 688 activation='relu', 689 batch_first=True, 690 ) 691 decoder_norm = nn.LayerNorm(16) 692 decoder = nn.TransformerDecoder( 693 decoder_layer, 2, decoder_norm 694 ) 695 696 src_mask = torch.ones( 697 inputs.shape[1], inputs.shape[1], dtype=torch.bool 698 ).triu_(diagonal=1) 699 padding_mask = ( 700 torch.arange(inputs.shape[1])[None, :].cpu() 701 >= input_seq_len[:, None] 702 ) 703 704 return decoder( 705 inputs, 706 memory, 707 tgt_mask=src_mask, 708 tgt_key_padding_mask=padding_mask, 709 memory_key_padding_mask=padding_mask, 710 ) 711 712 inputs = torch.randn(2, 3, 16) 713 memory = torch.randn(2, 3, 16) 714 input_seq_len = torch.tensor([3, 2]) 715 716 with self.assertNoLogs(None): 717 transformer_decoder(inputs, input_seq_len, memory) 718 719 def test_encoder_is_causal(self): 720 721 d_model = 3 722 layer = torch.nn.TransformerEncoderLayer(d_model, 1, 6, batch_first=True) 723 layer.eval() 724 x = torch.randn(1, 5, d_model) 725 unmasked_output = layer(x) 726 mask = torch.nn.Transformer.generate_square_subsequent_mask(x.size(1)) 727 is_causal_output = layer(x, src_mask=mask, is_causal=True) 728 masked_output = layer(x, src_mask=mask) 729 730 self.assertEqual(masked_output, is_causal_output) 731 732 @onlyCUDA 733 @parametrize("nb_heads", [1, 8]) 734 @parametrize("bias", [True, False]) 735 def test_mha_native_args(self, nb_heads, bias): 736 737 B, L, F = 8, 100, 128 738 batch_first = True 739 fast_path = True 740 use_pad_mask = (bias % 2) == 1 741 742 mha = nn.MultiheadAttention( 743 embed_dim=F, 744 num_heads=nb_heads, 745 batch_first=batch_first, 746 bias=bias 747 ).cuda() 748 mha.eval() 749 750 ctx = torch.no_grad if fast_path else contextlib.nullcontext 751 with ctx(): 752 x = torch.randn(B, L, F).cuda() 753 if not batch_first: 754 x = x.transpose(0, 1) 755 756 pad_mask = None 757 if use_pad_mask: 758 pad_mask = torch.zeros((B, L), dtype=torch.bool).cuda() 759 760 mha(query=x, key=x, value=x, key_padding_mask=pad_mask) 761 762 def test_kpm_mask_trailing_column_with_nested_tensor(self, device): 763 
encoder_layer = nn.TransformerEncoderLayer(
            d_model=256,
            nhead=4,
            dim_feedforward=512,
            activation='gelu',
            norm_first=False,
            batch_first=False,
        )
        transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=3, enable_nested_tensor=True).to(device)

        x = torch.randn(10, 6, 256).to(device)
        mask = torch.ones(6, 10)
        mask[0, :] = 0  # fully unmask the first batch element; every other sequence is entirely padded
        mask = mask.bool().to(device)
        out = transformer_encoder(src=x, src_key_padding_mask=mask)
        self.assertEqual(out.shape[1], 6)

    # In the CPU test environment, has_torch_function interposition prevents this
    # test from completing successfully, so it is run on CUDA only.
    @onlyCUDA
    def test_with_nested_tensor_input(self, device):
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=256,
            nhead=4,
            dim_feedforward=512,
            activation='gelu',
            norm_first=False,
            batch_first=True,
        )
        transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=3, enable_nested_tensor=True).to(device)

        transformer_encoder.eval()
        with torch.no_grad():
            x = torch.randn(6, 10, 256).to(device)
            mask = torch.ones(6, 10)
            mask[0, 0:] = 0  # keep all 10 tokens of batch element 0
            mask[2, 2:] = 0  # keep the last 8 tokens of batch element 2
            mask[4, 4:] = 0  # keep the last 6 tokens of batch element 4
            mask[5, 8:] = 0  # keep the last 2 tokens of batch element 5
            mask = mask.bool().to(device)
            x = torch._nested_tensor_from_mask(x, mask.logical_not(), mask_check=False)
            out = transformer_encoder(src=x, src_key_padding_mask=None)

        self.assertEqual(out.is_nested, True)



    def test_script_encoder_subclass(self, device):
        class MyCustomLayer(nn.TransformerEncoderLayer):
            pass

        encoder = nn.TransformerEncoder(
            MyCustomLayer(d_model=256, nhead=8), num_layers=6
        ).to(device=device)
        torch.jit.script(encoder)

    # brazenly adapted from test_transformerencoderlayer_src_mask to test execution of
    # a torchscripted TransformerEncoderLayer subclass
    def test_transformerencoderlayer_subclass(self, device):
        class MyCustomLayer(nn.TransformerEncoderLayer):
            pass

        nhead = 4
        batch_size = 2
        seqlen = 4
        d_model = 8
        dim_feedforward = 32

        model = MyCustomLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True).to(device)
        script_model = torch.jit.script(model)

        src = torch.rand(batch_size, seqlen, d_model).to(device)  # bs, seqlen, d_model
        src_mask = torch.zeros(seqlen, seqlen).to(torch.bool).to(device)

        torch.manual_seed(42)
        result = model(src, src_mask=src_mask)
        torch.manual_seed(42)
        scripted_result = script_model(src, src_mask=src_mask)
        self.assertEqual(result, scripted_result)

        model.eval()
        script_model = torch.jit.script(model)

        with torch.no_grad():
            result = model(src, src_mask=src_mask)
            scripted_result = script_model(src, src_mask=src_mask)
            self.assertEqual(result, scripted_result)


    def test_transformerencoderlayer_subclass_model(self, device):
        class MyCustomLayer(nn.TransformerEncoderLayer):
            pass

        nhead = 4
        batch_size = 2
        seqlen = 4
        d_model = 8
        dim_feedforward = 32

        layer = MyCustomLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True)
        model = nn.TransformerEncoder(
            layer, num_layers=6
        ).to(device=device)
        script_model = torch.jit.script(model)

876 src = torch.rand(batch_size, seqlen, d_model).to(device) # bs, seqlen, d_model 877 src_mask = torch.zeros(seqlen, seqlen).to(torch.bool).to(device) 878 879 torch.manual_seed(42) 880 result = model(src, mask=src_mask) 881 torch.manual_seed(42) 882 scripted_result = script_model(src, mask=src_mask) 883 self.assertEqual(result, scripted_result) 884 885 model.eval() 886 script_model = torch.jit.script(model) 887 888 with torch.no_grad(): 889 result = model(src, mask=src_mask) 890 scripted_result = script_model(src, mask=src_mask) 891 self.assertEqual(result, scripted_result) 892 893 894 @onlyCUDA 895 @unittest.skipIf(not TEST_FAIRSEQ, "Fairseq not found") 896 def test_decoder_only_layer(self): 897 DEFAULT_PADDING_IDX = 0 898 899 class FairseqDecoder(torch.nn.Module): 900 def __init__( 901 self, 902 embed_dim, 903 attention_heads, 904 ffn_embed_dim, 905 num_layers, 906 embedding_layer, # torch.nn.Embedding. Must have a padding_idx field 907 dropout=0, 908 normalize_before=False, 909 torch_encoder=None, # torch encoder that you can map weights from 910 activation="relu", 911 ): 912 super().__init__() 913 914 cfg = fairseq_transformer.TransformerConfig() 915 cfg.decoder.embed_dim = embed_dim 916 cfg.decoder.output_dim = embed_dim 917 cfg.decoder.attention_heads = attention_heads 918 cfg.decoder.ffn_embed_dim = ffn_embed_dim 919 cfg.dropout = dropout 920 cfg.decoder.normalize_before = normalize_before 921 cfg.decoder.layers = num_layers 922 # make embedding behavior same as other encoders 923 cfg.no_token_positional_embeddings = True 924 cfg.no_scale_embedding = True 925 cfg.activation_fn = activation 926 927 dictionary = {} # TODO: verify what this is 928 929 self.decoder = fairseq_transformer.TransformerDecoder( 930 cfg, 931 dictionary, 932 embedding_layer, 933 no_encoder_attn=True, 934 output_projection=None, 935 ) 936 937 if torch_encoder is not None: 938 self.decoder = torch_to_fairseq(torch_encoder, self.decoder) # noqa: F821 939 self.decoder = self.decoder.eval().cuda().half() 940 941 def forward( 942 self, 943 tokens, 944 src_lengths=None, 945 with_triangle_mask=False, 946 incremental_state=None, 947 ): 948 return self.decoder( 949 prev_output_tokens=tokens, 950 encoder_out=None, 951 incremental_state=incremental_state, 952 features_only=True, 953 full_context_alignment=not with_triangle_mask, 954 alignment_layer=None, 955 alignment_heads=None, 956 src_lengths=src_lengths, 957 return_all_hiddens=False, 958 )[0] 959 960 @parametrize("input_dim,attn_mask_dim,is_causal", 961 [(3, None, False), (3, 2, False), (3, 2, True), (3, 3, False), (3, 3, True), 962 (4, None, False), (4, 2, False), (4, 2, True), (4, 4, False), (4, 4, True)], 963 name_fn=lambda input_dim, attn_dim, is_causal: ( 964 f"{input_dim}D_input_dim_" + ( 965 f"{attn_dim}D_{'causal_' if is_causal else ''}attn_mask" 966 if attn_dim is not None else "no_attn_mask"))) 967 @parametrize("dropout_p", [0.0, 0.2, 0.5]) 968 @sdpa_kernel(backends=[SDPBackend.MATH]) 969 def test_scaled_dot_product_attention(self, device, input_dim, attn_mask_dim, is_causal, dropout_p): 970 def sdp_ref( 971 q, 972 k, 973 v, 974 attn_mask=None, 975 dropout_p=0.0): 976 E = q.size(-1) 977 q = q / math.sqrt(E) 978 # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns) 979 if attn_mask is not None: 980 attn = torch.baddbmm(attn_mask, q, k.transpose(-2, -1)) 981 else: 982 attn = torch.bmm(q, k.transpose(-2, -1)) 983 984 attn = torch.nn.functional.softmax(attn, dim=-1) 985 if dropout_p > 0.0: 986 attn = torch.nn.functional.dropout(attn, p=dropout_p) 987 # (B, Nt, Ns) x (B, Ns, 
E) -> (B, Nt, E) 988 output = torch.bmm(attn, v) 989 return output 990 # TODO: Support cross-device / dtype testing properly when instantiate_device_type_tests() is used. 991 dtypes = [torch.double, torch.float] 992 for dtype in dtypes: 993 994 def rand_tensor(*shape): 995 return torch.randn(shape, device=device, dtype=dtype) 996 997 # This test compares python and C++ implementations of SDP. 998 N, N_prime, L, S, E = 5, 2, 4, 3, 6 999 if input_dim == 3: 1000 query = rand_tensor(N, L, E) 1001 key = rand_tensor(N, S, E) 1002 value = rand_tensor(N, S, E) 1003 elif input_dim == 4: 1004 query = rand_tensor(N, N_prime, L, E) 1005 key = rand_tensor(N, N_prime, S, E) 1006 value = rand_tensor(N, N_prime, S, E) 1007 else: 1008 self.fail(f'Invalid input_dim {input_dim} encountered in SDP test') 1009 1010 attn_mask = None 1011 if attn_mask_dim is not None: 1012 assert attn_mask_dim in [2, input_dim] 1013 mask_size = (L, S) if attn_mask_dim == 2 else ((N, L, S) if input_dim == 3 else (N, N_prime, L, S)) 1014 attn_mask = (torch.ones(mask_size, device=device, dtype=torch.bool).tril() if is_causal 1015 else torch.randint(0, 2, size=mask_size, device=device, dtype=torch.bool)) 1016 1017 with freeze_rng_state(): 1018 # Python impl only supports float mask and 3D inputs. 1019 attn_mask_float = attn_mask 1020 if attn_mask_float is not None: 1021 attn_mask_float = torch.zeros_like(attn_mask, dtype=query.dtype) 1022 attn_mask_float.masked_fill_(attn_mask.logical_not(), float("-inf")) 1023 q, k, v = query.view(-1, L, E), key.view(-1, S, E), value.view(-1, S, E) 1024 a = attn_mask_float 1025 if a is not None and attn_mask_dim > 3: 1026 a = a.view(-1, L, S) 1027 expected = sdp_ref(q, k, v, attn_mask=a, dropout_p=dropout_p) 1028 if input_dim > 3: 1029 expected = expected.view(-1, N_prime, L, E) 1030 1031 with freeze_rng_state(): 1032 if is_causal: 1033 # NB: Don't pass attn_mask here 1034 actual = torch.nn.functional.scaled_dot_product_attention( 1035 query, key, value, None, dropout_p, is_causal) 1036 1037 # Error case: both explicit attn_mask and is_causal are set 1038 with self.assertRaisesRegex(RuntimeError, 1039 "Explicit attn_mask should not be set when is_causal=True"): 1040 torch.nn.functional.scaled_dot_product_attention( 1041 query, key, value, attn_mask, dropout_p, is_causal) 1042 else: 1043 actual = torch.nn.functional.scaled_dot_product_attention( 1044 query, key, value, attn_mask, dropout_p, is_causal) 1045 1046 self.assertEqual(actual, expected) 1047 1048 if attn_mask_dim is None: 1049 q = q.double().clone() 1050 k = k.double().clone() 1051 v = v.double().clone() 1052 q.requires_grad_() 1053 k.requires_grad_() 1054 v.requires_grad_() 1055 1056 assert gradcheck(lambda *args, **kwargs: wrapper_set_seed(sdp_ref, *args, **kwargs), 1057 (q, k, v, attn_mask, dropout_p)) 1058 assert gradcheck(lambda *args, **kwargs: 1059 wrapper_set_seed(torch.nn.functional.scaled_dot_product_attention, *args, **kwargs), 1060 (q, k, v, attn_mask, dropout_p)) 1061 1062 def test_incompatible_mask(self, device): 1063 def ones_tensor(*shape): 1064 return torch.ones(shape, dtype=torch.float32) 1065 S, L, E, H = 1, 2, 4, 1 1066 qkv = ones_tensor(S, L, E) 1067 1068 mha = nn.MultiheadAttention(E, H) 1069 mha.in_proj_weight = Parameter(torch.ones((E * 3, E))) 1070 mha.out_proj.weight = Parameter(torch.ones((E, E))) 1071 qkv = qkv.to(float) 1072 kpm = ones_tensor(S, L) * float("-inf") 1073 am = ones_tensor(L, L).to(bool) 1074 1075 def func(): 1076 return mha(qkv, qkv, qkv, need_weights=False, key_padding_mask=kpm, attn_mask=am) 1077 
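        # The masks constructed above are incompatible with the inputs, so the call below is expected to raise a RuntimeError.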
        self.assertRaises(RuntimeError, func)

    @unittest.skipIf(TEST_WITH_CROSSREF, 'Fastpath not available with crossref')
    @torch.no_grad()
    def test_mask_check_fastpath(self):
        """
        Test that the fastpath is executed independently of the masks that are passed.
        If the passed key padding mask is left aligned or mask_check=False, test that nested tensors are used
        (sparsity fastpath), otherwise use the fastpath with regular tensors.
        Also test that the fast path is executed with both a key padding mask and an attention mask passed at the same time.
        """

        x = torch.Tensor([[[1, 2], [3, 4], [5, 6]]]).to(torch.float)

        def _test_fastpath(model, key_padding_mask, mock_return_value, attn_mask=None, nested_tensors=True):
            with patch('torch._transformer_encoder_layer_fwd') as fastpath_mock:
                fastpath_mock.return_value = mock_return_value
                model(x, src_key_padding_mask=key_padding_mask, mask=attn_mask)

                # If the mock was called, the fastpath was taken
                self.assertTrue(fastpath_mock.called)

                # If the mock was called with nested tensors, the sparsity fastpath was taken
                for call_args, _ in fastpath_mock.call_args_list:
                    self.assertEqual(call_args[0].is_nested, nested_tensors)

        encoder_layer = torch.nn.TransformerEncoderLayer(d_model=2, nhead=2, dim_feedforward=8, batch_first=True)

        model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=True, mask_check=True)
        model.eval()

        aligned_key_padding_mask = torch.Tensor([[0, 0, 1]]).to(torch.bool)
        not_aligned_key_padding_mask = torch.Tensor([[1, 0, 1]]).to(torch.bool)
        attn_mask = torch.Tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]]).to(torch.bool)
        nested_tensor_return_value = torch.nested.nested_tensor([torch.ones((2, 2), dtype=torch.float)])
        tensor_return_value = torch.ones((1, 3, 2), dtype=torch.float)

        # A left-aligned mask results in the sparsity fastpath
        _test_fastpath(model, aligned_key_padding_mask, nested_tensor_return_value, nested_tensors=True)

        # A mask that is not left-aligned results in the regular fastpath
        _test_fastpath(model, not_aligned_key_padding_mask, tensor_return_value, nested_tensors=False)

        model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=False, mask_check=True)
        model.eval()

        # If nested tensors are disabled, the regular fastpath is always taken
        _test_fastpath(model, aligned_key_padding_mask, tensor_return_value, nested_tensors=False)
        _test_fastpath(model, not_aligned_key_padding_mask, tensor_return_value, nested_tensors=False)
        # The fast path is taken if both an attention mask and a key padding mask are present
        _test_fastpath(model, aligned_key_padding_mask, tensor_return_value, attn_mask=attn_mask, nested_tensors=False)

        model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=True, mask_check=False)
        model.eval()

        # Disabling the mask check results in the sparsity fastpath, independently of the mask
        _test_fastpath(model, aligned_key_padding_mask, nested_tensor_return_value, nested_tensors=True)
        _test_fastpath(model, not_aligned_key_padding_mask, nested_tensor_return_value, nested_tensors=True)

    # Test that MHA does not fail when bias is None
    def test_bias_is_none(self):
        x = torch.rand((1, 5, 10))
        model = torch.nn.modules.activation.MultiheadAttention(10, 1, bias=False, batch_first=True)
        model.eval()
        model(x, x, x)
        # completes without error

    def
test_transformer_bias_is_none(self, device): 1146 batch_size = 2 1147 seqlen = 3 1148 d_model = 8 1149 nhead = 4 1150 1151 encoder_layer = torch.nn.TransformerEncoderLayer(d_model, nhead, bias=False, batch_first=True, device=device) 1152 encoder_layer.eval() 1153 x = torch.randn(batch_size, seqlen, d_model, device=device) 1154 # runs without error 1155 encoder_layer(x) 1156 1157 with self.assertWarnsRegex(UserWarning, "encoder_layer.self_attn was passed bias=False"): 1158 encoder = torch.nn.TransformerEncoder(encoder_layer, num_layers=1).eval() 1159 encoder(x) 1160 1161 with self.assertWarnsRegex(UserWarning, "self_attn was passed bias=False"): 1162 transformer = torch.nn.Transformer( 1163 d_model=d_model, nhead=nhead, bias=False, batch_first=True, device=device 1164 ).eval() 1165 transformer(x, x) 1166 1167 def test_train_with_is_causal(self, device): 1168 # training with is_causal 1169 S, L, E, H = 1, 2, 2, 1 1170 layer = nn.TransformerEncoderLayer( 1171 d_model=2, 1172 dim_feedforward=4, 1173 nhead=H, 1174 batch_first=True, 1175 activation="gelu", 1176 dropout=0, 1177 ) 1178 criterion = nn.MSELoss() 1179 encoder = nn.TransformerEncoder(layer, 2).to(device) 1180 optimizer = optim.SGD(encoder.parameters(), lr=0.1, momentum=0.9) 1181 encoder.train() 1182 1183 encoder.train() 1184 optimizer.zero_grad() 1185 inputs = torch.randn(S, L, E).to(device) 1186 mask = torch.nn.Transformer.generate_square_subsequent_mask( 1187 inputs.size(1), device=device 1188 ) 1189 1190 outputs = encoder(inputs, mask=mask, is_causal=True) 1191 1192 loss = criterion(outputs[:, 0:2, :], inputs[:, 0:2, :]) 1193 loss.backward() 1194 optimizer.step() 1195 1196 # inference with is_causal 1197 t_qvk = torch.randn((S, L, E), device=device, dtype=torch.float32) 1198 mha = nn.MultiheadAttention(E, H).to(device) 1199 mask = torch.nn.Transformer.generate_square_subsequent_mask( 1200 S, device=device 1201 ) 1202 1203 attn_out, _ = mha(t_qvk, t_qvk, t_qvk, attn_mask=mask, is_causal=True) 1204 1205 # Can't give only is_causal 1206 attn_mask = torch.randint(0, 2, size=(L, L), device=device, dtype=torch.bool) 1207 with self.assertRaises(RuntimeError): 1208 _ = mha(t_qvk, t_qvk, t_qvk, is_causal=True) 1209 1210 # # Passing a causal mask sets is_causal to 1 1211 causal_mask = torch.triu( 1212 torch.ones(L, L, device=inputs.device) * float('-inf'), diagonal=1 1213 ).to(torch.bool) 1214 1215 mock_layer = MagicMock(torch.nn.MultiheadAttention(E, H), return_value=inputs) 1216 encoder.layers[1] = mock_layer 1217 outputs = encoder(inputs, mask=causal_mask) 1218 mock_layer.assert_called_with(ANY, src_mask=ANY, is_causal=True, src_key_padding_mask=ANY) 1219 1220 # check expected numerical values with all kernels 1221 self.is_causal_kernels([SDPBackend.MATH], device) 1222 1223 def is_causal_kernels(self, kernels, device): 1224 def ones_tensor(*shape): 1225 return torch.ones(shape, device=device, dtype=torch.float32).to(device) 1226 S, L, E, H = 1, 2, 4, 1 1227 qkv = ones_tensor(S, L, E) 1228 1229 mha = nn.MultiheadAttention(E, H).to(device) 1230 mha.in_proj_weight = Parameter(torch.ones((E * 3, E), device=device)) 1231 mha.out_proj.weight = Parameter(torch.ones((E, E), device=device)) 1232 expected = torch.ones(size=(S, L, E)).to(device) * 16 1233 mask = torch.nn.Transformer.generate_square_subsequent_mask( 1234 qkv.size(1), device=device 1235 ) 1236 1237 for kernel in kernels: 1238 with sdpa_kernel(backends=[kernel]): 1239 actual, _ = mha(qkv, qkv, qkv, attn_mask=mask, need_weights=False, is_causal=True) 1240 
self.assertTrue(torch.equal(actual, expected))

                if kernel != SDPBackend.MATH:
                    # fails with an embedding size that is not a multiple of 4
                    with self.assertRaisesRegex(RuntimeError, "No available kernel"):
                        qkv_f, mha_f = ones_tensor(S, L, 2), nn.MultiheadAttention(2, H).to(device)
                        mask = torch.nn.Transformer.generate_square_subsequent_mask(
                            qkv_f.size(1), device=device
                        )
                        _ = mha_f(qkv_f, qkv_f, qkv_f, attn_mask=mask, need_weights=False, is_causal=True)
                        torch.cuda.synchronize()

    @skipIfRocm  # Missing EFFICIENT_ATTENTION
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Platform does not support fused SDPA or pre-SM80 hardware"
    )
    def test_is_causal_gpu(self):
        device = 'cuda'
        self.is_causal_kernels([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION], device)

    def test_script_mha_in_proj_weight_none(self):
        mha = torch.nn.MultiheadAttention(
            embed_dim=128, num_heads=8, kdim=256, vdim=256
        ).eval()

        torch.jit.script(mha)

    @unittest.skipIf(TEST_WITH_CROSSREF, 'Fastpath not available with crossref')
    @torch.no_grad()
    def test_disable_fastpath(self, device):
        def _test_te_fastpath_called(model, args, kwargs=None, return_value=None, is_called=True):
            if kwargs is None:
                kwargs = {}
            with patch('torch._transformer_encoder_layer_fwd') as fastpath_mock:
                fastpath_mock.return_value = return_value
                output = model(*args, **kwargs)
                self.assertTrue(fastpath_mock.called == is_called)

        def _test_mha_fastpath_called(model, args, kwargs=None, return_value=None, is_called=True):
            if kwargs is None:
                kwargs = {}
            with patch('torch._native_multi_head_attention') as fastpath_mock:
                fastpath_mock.return_value = return_value
                output = model(*args, **kwargs)
                self.assertTrue(fastpath_mock.called == is_called)

        inp = torch.tensor([[[1, 2], [3, 4], [5, 6]]], dtype=torch.float32, device=device)
        aligned_key_padding_mask = torch.tensor([[0, 0, 1]], dtype=torch.bool, device=device)
        src_key_padding_mask = torch.tensor([[1, 0, 1]], dtype=torch.bool, device=device)
        attn_mask = torch.tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]], dtype=torch.bool, device=device)
        te_return_value = torch.ones((1, 3, 2), dtype=torch.float32)

        encoder_layer = torch.nn.TransformerEncoderLayer(d_model=2, nhead=2, dim_feedforward=8, batch_first=True)
        te = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=True, mask_check=True)
        te = te.to(device).eval()

        t = torch.nn.Transformer(d_model=2, nhead=2, batch_first=True, device=device).eval()
        src = torch.tensor([[[0, 1], [2, 3], [4, 5]]], dtype=torch.float32, device=device)
        tgt = torch.tensor([[[0, 1], [2, 3], [4, 5], [6, 7]]], dtype=torch.float32, device=device)
        t_return_value = torch.ones((1, 3, 2), dtype=torch.float32, device=device)

        mha = nn.MultiheadAttention(2, 2, batch_first=True, device=device).eval()
        q = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.float32, device=device)
        mha_return_value = torch.ones((1, 3, 2), dtype=torch.float32, device=device)

        _test_te_fastpath_called(
            te, (inp,), kwargs={'src_key_padding_mask': src_key_padding_mask},
            return_value=te_return_value, is_called=True
        )
        _test_te_fastpath_called(t, (src, tgt), return_value=t_return_value, is_called=True)
        _test_mha_fastpath_called(mha, (q, q, q,), return_value=mha_return_value, is_called=True)

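        # Globally disable the fastpath and verify that the native fused kernels are no longer called.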
1312 torch.backends.mha.set_fastpath_enabled(False) 1313 _test_te_fastpath_called( 1314 te, (inp,), kwargs={'src_key_padding_mask': src_key_padding_mask}, 1315 return_value=te_return_value, is_called=False 1316 ) 1317 _test_te_fastpath_called(t, (src, tgt), return_value=t_return_value, is_called=False) 1318 _test_mha_fastpath_called(mha, (q, q, q,), return_value=mha_return_value, is_called=False) 1319 1320 torch.backends.mha.set_fastpath_enabled(True) 1321 _test_te_fastpath_called( 1322 te, (inp,), kwargs={'src_key_padding_mask': src_key_padding_mask}, 1323 return_value=te_return_value, is_called=True 1324 ) 1325 _test_te_fastpath_called(t, (src, tgt), return_value=t_return_value, is_called=True) 1326 _test_mha_fastpath_called(mha, (q, q, q,), return_value=mha_return_value, is_called=True) 1327 1328 1329class TestSDPAFailureModes(NNTestCase): 1330 """ Used to test the failure modes of scaled_dot_product_attention 1331 """ 1332 _do_cuda_memory_leak_check = True 1333 _do_cuda_non_default_stream = True 1334 1335 @onlyCUDA 1336 @unittest.skipIf( 1337 not PLATFORM_SUPPORTS_FLASH_ATTENTION or not isSM8XDevice, 1338 "Does not support fused SDPA or not SM86+ hardware", 1339 ) 1340 @parametrize("head_dim", [193, 204, 256]) 1341 @parametrize("dropout_p", [0.0, 0.2]) 1342 def test_flash_backward_failure_sm86plus(self, device, head_dim: int, dropout_p: float): 1343 dtype = torch.float16 1344 make_tensor = partial(torch.rand, device=device, dtype=dtype) 1345 # See check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89 in 1346 # pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.h 1347 size = (2, 2, 4, head_dim) 1348 q, k, v = make_tensor(size), make_tensor(size), make_tensor(size) 1349 1350 with sdpa_kernel(backends=[SDPBackend.MATH]): 1351 math_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False) 1352 1353 with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): 1354 # Should not fail because inputs don't require grad 1355 flash_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False) 1356 1357 self.assertEqual(math_ref, flash_ref, atol=1e-3, rtol=1e-3) 1358 1359 # Should fail because inputs require grad 1360 q = make_tensor(size, requires_grad=True) 1361 k = make_tensor(size, requires_grad=True) 1362 v = make_tensor(size, requires_grad=True) 1363 if 192 < head_dim <= 224 or (head_dim > 224 and dropout_p != 0.0): 1364 self.assertRaises( 1365 RuntimeError, 1366 lambda: torch.nn.functional.scaled_dot_product_attention( 1367 q, k, v, None, dropout_p, False 1368 ), 1369 ) 1370 else: 1371 flash_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, None, dropout_p, False) 1372 1373 @onlyCUDA 1374 def test_dispatch_fails_no_backend(self, device): 1375 dtype = torch.float16 1376 with sdpa_kernel(backends=[SDPBackend.ERROR]): 1377 size = (2, 3, 4) 1378 q = torch.randn(size, device=device, dtype=dtype) 1379 k = torch.randn(size, device=device, dtype=dtype) 1380 v = torch.randn(size, device=device, dtype=dtype) 1381 self.assertRaisesRegex(RuntimeError, "No viable backend for scaled_dot_product_attention was found.", 1382 lambda: torch._fused_sdp_choice(q, k, v)) 1383 self.assertRaisesRegex(RuntimeError, "No viable backend for scaled_dot_product_attention was found.", 1384 lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v)) 1385 1386 @onlyCUDA 1387 @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention") 1388 @parametrize( 1389 "kernel", 1390 PLATFORM_SPECIFIC_SDPA, 
1391 ) 1392 def test_invalid_fused_inputs_dim_3(self, device, kernel: SDPBackend): 1393 with sdpa_kernel(backends=[kernel]): 1394 # Dim is not 4 1395 size = (2, 3, 8) 1396 dtype = torch.float16 1397 q = torch.randn(size, device=device, dtype=dtype) 1398 k = torch.randn(size, device=device, dtype=dtype) 1399 v = torch.randn(size, device=device, dtype=dtype) 1400 with self.assertWarnsRegex(UserWarning, "Both fused kernels requires query, key and value to be 4 dimensional"): 1401 self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( 1402 q, k, v, None, 0.0, False)) 1403 1404 @onlyCUDA 1405 @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention") 1406 @parametrize( 1407 "kernel", 1408 PLATFORM_SPECIFIC_SDPA, 1409 ) 1410 def test_invalid_fused_inputs_broadcast(self, device, kernel: SDPBackend): 1411 with sdpa_kernel(backends=[kernel]): 1412 # Fused Kernels don't support broadcasting for dense inputs 1413 dtype = torch.float16 1414 size = (2, 4, 3, 8) 1415 size_broadcast = (1, 4, 3, 8) 1416 q = torch.randn(size_broadcast, device=device, dtype=dtype) 1417 k = torch.randn(size, device=device, dtype=dtype) 1418 v = torch.randn(size, device=device, dtype=dtype) 1419 self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( 1420 q, k, v, None, 0.0, False)) 1421 1422 @onlyCUDA 1423 @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention") 1424 @parametrize("kernel", PLATFORM_SPECIFIC_SDPA) 1425 def test_invalid_sequence_lengths(self, device, kernel: SDPBackend): 1426 with sdpa_kernel(backends=[kernel]): 1427 # Passing in a q,k,v with 0 length sequences will error 1428 dtype = torch.float16 1429 make_tensor = partial(torch.rand, device=device, dtype=dtype) 1430 size = SdpaShape(2, 2, 0, 8) 1431 q, k, v = make_tensor(size), make_tensor(size), make_tensor(size) 1432 with self.assertWarnsRegex(UserWarning, "Both fused kernels do not support zero seq_len_q or seq_len_kv."): 1433 self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( 1434 q, k, v, None, 0.0, False)) 1435 1436 @onlyCUDA 1437 @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention") 1438 @parametrize("kernel", PLATFORM_SPECIFIC_SDPA) 1439 def test_invalid_last_dim_stride(self, device, kernel: SDPBackend): 1440 with sdpa_kernel(backends=[kernel]): 1441 # Passing in a q,k,v with last dim stride not equal to 1 will error 1442 dtype = torch.float16 1443 make_tensor = partial(torch.rand, device=device, dtype=dtype) 1444 size = SdpaShape(2, 2, 8, 8) 1445 q, k, v = make_tensor(size), make_tensor(size), make_tensor(size) 1446 q.as_strided_(size, [2, 2, 2, 2]) 1447 with self.assertWarnsRegex(UserWarning, "Both fused kernels require the last dimension of the input to have stride 1."): 1448 self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( 1449 q, k, v, None, 0.0, False)) 1450 1451 @onlyCUDA 1452 @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not flash_attention fused scaled dot product attention") 1453 @parametrize("kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]) 1454 def test_invalid_fused_inputs_head_dim(self, device, kernel: SDPBackend): 1455 with sdpa_kernel(backends=[kernel]): 1456 # The embed dim per head is not divisible by 8 for flash attention 1457 dtype = torch.float16 1458 make_tensor = partial(torch.rand, 
device=device, dtype=dtype) 1459 size = SdpaShape(2, 2, 3, 9) if kernel == SDPBackend.EFFICIENT_ATTENTION else SdpaShape(2, 2, 3, 257) 1460 if TEST_WITH_ROCM: # On ROCM, FA and EA share the backend GPU kernels 1461 size = SdpaShape(2, 2, 3, 257) 1462 q, k, v = make_tensor(size), make_tensor(size), make_tensor(size) 1463 self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( 1464 q, k, v, None, 0.0, False)) 1465 1466 @onlyCUDA 1467 @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention") 1468 @parametrize( 1469 "kernel", 1470 PLATFORM_SPECIFIC_SDPA, 1471 ) 1472 def test_invalid_fused_inputs_invalid_dtype(self, device, kernel: SDPBackend): 1473 with sdpa_kernel(backends=[kernel]): 1474 # Invalid dtype for both Flash Attention and Mem Efficient Attention 1475 size = SdpaShape(2, 2, 3, 16) 1476 make_tensor = partial(torch.rand, device=device, dtype=torch.float64) 1477 q, k, v = make_tensor(size), make_tensor(size), make_tensor(size) 1478 self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( 1479 q, k, v, None, 0.0, False)) 1480 1481 @onlyCUDA 1482 @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention") 1483 @parametrize("kernel", [SDPBackend.FLASH_ATTENTION]) 1484 def test_invalid_fused_inputs_attn_mask_present(self, device, kernel: SDPBackend): 1485 with sdpa_kernel(backends=[kernel]): 1486 # Failures for unsupported SDP args 1487 size = SdpaShape(2, 2, 3, 16) 1488 make_tensor = partial(torch.rand, size, device=device, dtype=torch.float16) 1489 q, k, v = make_tensor(), make_tensor(), make_tensor() 1490 # Non-None attention mask 1491 mask = torch.ones((2, 2, 3, 3), device=device, dtype=q.dtype) 1492 self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( 1493 q, k, v, mask, 0.0, False)) 1494 1495 @onlyCUDA 1496 @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support fused SDPA or pre-SM80 hardware") 1497 def test_unaligned_tensors(self, device): 1498 # The alignment is dependent on arch so we specify SM80OrLater 1499 dtype = torch.float16 1500 size = SdpaShape(2, 2, 8, 5) 1501 make_tensor = partial(torch.rand, size, device=device, dtype=dtype) 1502 q, k, v = make_tensor(), make_tensor(), make_tensor() 1503 with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]): 1504 ctxmgr = self.assertRaises(RuntimeError) if not TEST_WITH_ROCM else contextlib.nullcontext() 1505 with ctxmgr: 1506 torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False) 1507 1508 @onlyCUDA 1509 @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support fused SDPA or pre-SM80 hardware") 1510 def test_flash_fail_fp32(self, device): 1511 dtype = torch.float 1512 size = SdpaShape(16, 16, 32, 32) 1513 make_tensor = partial(torch.rand, size, device=device, dtype=dtype) 1514 q, k, v = make_tensor(), make_tensor(), make_tensor() 1515 with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): 1516 with self.assertWarnsRegex(UserWarning, "Expected query, key and value to all be of dtype: {Half, BFloat16}"): 1517 self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( 1518 q, k, v, None, 0.0, False)) 1519 1520 @onlyCUDA 1521 @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware") 1522 def test_flash_autocast_fp32_float16(self, device): 1523 dtype = torch.float 1524 size = SdpaShape(16, 16, 32, 32) 1525 make_tensor =
partial(torch.rand, size, device=device, dtype=dtype) 1526 q, k, v = make_tensor(), make_tensor(), make_tensor() 1527 with torch.autocast(device_type='cuda', dtype=torch.float16): 1528 with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): 1529 _ = torch.nn.functional.scaled_dot_product_attention( 1530 q, k, v, None, 0.0, False) 1531 1532 @onlyCUDA 1533 @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware") 1534 def test_flash_autocast_fp32_bfloat16(self, device): 1535 dtype = torch.float 1536 size = SdpaShape(16, 16, 32, 32) 1537 make_tensor = partial(torch.rand, size, device=device, dtype=dtype) 1538 q, k, v = make_tensor(), make_tensor(), make_tensor() 1539 with torch.autocast(device_type='cuda', dtype=torch.bfloat16): 1540 with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): 1541 _ = torch.nn.functional.scaled_dot_product_attention( 1542 q, k, v, None, 0.0, False) 1543 1544 # Note: do not truncate the list according to platforms. These tests should always raise errors. 1545 @parametrize("kernel", [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]) 1546 def test_invalid_inputs_different_datatypes(self, device, kernel: SDPBackend): 1547 with sdpa_kernel(backends=[kernel]): 1548 # Different datatypes 1549 shape = (1, 4, 8, 16) 1550 query = torch.randn(shape, dtype=torch.float32, device=device) 1551 key = torch.randn(shape, dtype=torch.float16, device=device) 1552 value = torch.randn(shape, dtype=torch.float16, device=device) 1553 self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(query, key, value)) 1554 1555 @onlyCUDA 1556 @parametrize("kernel", [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]) 1557 def test_invalid_inputs_different_devices(self, device, kernel: SDPBackend): 1558 # Different devices 1559 shape = (1, 4, 8, 16) 1560 query = torch.randn(shape, dtype=torch.float32, device=device) 1561 key = torch.randn(shape, dtype=torch.float16, device='cpu') 1562 value = torch.randn(shape, dtype=torch.float16, device='cpu') 1563 self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(query, key, value)) 1564 1565 @parametrize("kernel", [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]) 1566 def test_invalid_inputs_1_dimensional_inputs(self, device, kernel: SDPBackend): 1567 with sdpa_kernel(backends=[kernel]): 1568 # 1 dimensional input 1569 shape = (1, 4) 1570 query = torch.randn(4, dtype=torch.float16, device=device) 1571 key = torch.randn(shape, dtype=torch.float16, device=device) 1572 value = torch.randn(shape, dtype=torch.float16, device=device) 1573 self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(query, key, value)) 1574 1575 @onlyCUDA 1576 @skipIfRocm # Missing EFFICIENT_ATTENTION 1577 @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") 1578 def test_fused_kernels_nested_broadcasting_error_cases(self, device): 1579 # one of k,v needs to be broadcasted and other has non consistent seq_len dim 1580 rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float32) 1581 batch, num_heads, head_dim = 32, 8, 64 1582 seq_lens_q = torch.randint(low=1, high=32, size=(batch,)).tolist() 1583 seq_lens_v = torch.randint(low=1, high=32, size=(batch,)).tolist() 1584 1585 q_shape = SdpaShape(batch, num_heads, seq_lens_q, head_dim) 1586 k_shape = SdpaShape(1, num_heads, 1, head_dim) 1587 v_shape = SdpaShape(batch, num_heads, 
seq_lens_v, head_dim) 1588 1589 query = rand_nested_tensor(q_shape).transpose(1, 2) 1590 key = rand_nested_tensor(k_shape).transpose(1, 2) 1591 value = rand_nested_tensor(v_shape).transpose(1, 2) 1592 1593 with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]): 1594 with self.assertRaisesRegex(RuntimeError, "No available kernel"): 1595 torch.nn.functional.scaled_dot_product_attention( 1596 query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) 1597 1598 @onlyCUDA 1599 @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Fused SDPA was not built for this system") 1600 def test_nested_fails_on_padding_head_dim(self, device): 1601 dtype = torch.bfloat16 1602 seq_len_list = [2, 4, 5, 6, 7] 1603 shape = SdpaShape(5, 8, seq_len_list, 57) 1604 make_tensor = partial(rand_sdpa_tensor, shape=shape, type="nested", device=device, dtype=dtype) 1605 q, k, v = make_tensor(), make_tensor(), make_tensor() 1606 with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): 1607 with self.assertWarnsRegex(UserWarning, "For NestedTensor inputs, Flash attention requires"): 1608 self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( 1609 q, k, v, None, 0.0, False)) 1610 1611 @onlyCUDA 1612 @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION or not isLessThanSM80Device, 1613 "Current platform does not support fused SDPA or is an SM80+ device.") 1614 def test_mem_efficient_fail_bfloat16_less_than_sm80(self, device): 1615 dtype = torch.bfloat16 1616 size = SdpaShape(16, 16, 32, 32) 1617 make_tensor = partial(torch.rand, size, device=device, dtype=dtype) 1618 q, k, v = make_tensor(), make_tensor(), make_tensor() 1619 with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]): 1620 with self.assertWarnsRegex(UserWarning, "Expected query, key and value to all be of dtype: {Half, Float}"): 1621 self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( 1622 q, k, v, None, 0.0, False)) 1623 1624 @onlyCUDA 1625 @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention") 1626 def test_flash_attention_large_bf16_nan_values(self, device): 1627 query = torch.full((1, 1, 1, 64), 133120.0, dtype=torch.bfloat16, device="cuda") 1628 key = torch.full((1, 1, 1, 64), 133120.0, dtype=torch.bfloat16, device="cuda") 1629 value = torch.full((1, 1, 1, 64), 133120.0, dtype=torch.bfloat16, device="cuda") 1630 1631 with sdpa_kernel(SDPBackend.FLASH_ATTENTION): 1632 out = torch.nn.functional.scaled_dot_product_attention(query, key, value) 1633 1634 self.assertFalse(torch.isnan(out).any(), "Output should not contain NaNs!") 1635 1636 @onlyCUDA 1637 @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system") 1638 @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if 1639 PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION]) 1640 def test_fused_kernels_seq_len_0_inputs(self, device, fused_kernel): 1641 rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float16) 1642 batch, num_heads, head_dim = 32, 16, 64 1643 seq_lens = torch.randint(low=1, high=32, size=(batch,)) 1644 # make sure some seq_lens are 0 1645 num_zeros = 10 1646 indices = torch.randint(low=0, high=batch, size=(num_zeros,)) 1647 seq_lens.scatter_(0, indices, 0) 1648 1649 shape = SdpaShape(batch, num_heads, seq_lens.tolist(), head_dim) 1650 query = rand_nested_tensor(shape) 1651 key = rand_nested_tensor(shape) 1652 value =
rand_nested_tensor(shape) 1653 1654 query = query.transpose(1, 2) 1655 key = key.transpose(1, 2) 1656 value = value.transpose(1, 2) 1657 1658 with sdpa_kernel(backends=[fused_kernel]): 1659 with self.assertRaisesRegex(RuntimeError, "No available kernel"): 1660 torch.nn.functional.scaled_dot_product_attention( 1661 query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) 1662 1663 @onlyCUDA 1664 @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Fused SDPA was not built for this system") 1665 def test_fused_kernels_nested_broadcasting_requires_grad_failure(self, device): 1666 rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float16, requires_grad=True) 1667 batch, num_heads, head_dim, head_dim_v = 32, 16, 64, 64 1668 seq_lens = torch.randint(low=1, high=32, size=(batch,)).tolist() 1669 q_shape = SdpaShape(1, num_heads, 1, head_dim) 1670 k_shape = SdpaShape(batch, num_heads, seq_lens, head_dim) 1671 v_shape = SdpaShape(batch, 1, seq_lens, head_dim_v) 1672 1673 # create a dense query 1674 query = torch.randn(q_shape, device=device, dtype=torch.float16, requires_grad=True) 1675 key = rand_nested_tensor(k_shape) 1676 value = rand_nested_tensor(v_shape) 1677 1678 query = query.transpose(1, 2) 1679 key = key.transpose(1, 2) 1680 value = value.transpose(1, 2) 1681 1682 with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): 1683 with self.assertWarnsRegex(UserWarning, "Both fused kernels do not support training with broadcasted NT inputs"): 1684 with self.assertRaisesRegex(RuntimeError, "No available kernel"): 1685 out = torch.nn.functional.scaled_dot_product_attention( 1686 query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) 1687 1688 @onlyCUDA 1689 @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention") 1690 def test_flash_attention_fail_with_non_square_causal_attention(self, device): 1691 dtype = torch.bfloat16 1692 q_shape = SdpaShape(1, 1, 8, 16) 1693 kv_shape = SdpaShape(1, 1, 12, 16) 1694 make_q = partial(torch.rand, q_shape, device=device, dtype=dtype) 1695 make_kv = partial(torch.rand, kv_shape, device=device, dtype=dtype) 1696 q, k, v = make_q(), make_kv(), make_kv() 1697 warning_str = "Flash attention does not support the is_causal flag when seqlen_q != seqlen_k." 
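        # For context (a sketch, not exercised by this test): with seqlen_q != seqlen_k the
        # supported way to express causal masking is an explicit causal bias object rather
        # than the is_causal flag, roughly:
        #   attn_bias = causal_lower_right(8, 12)  # (seq_len_q, seq_len_kv); helper imported at the top
        #   scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)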
1698 with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): 1699 with self.assertWarnsRegex(UserWarning, warning_str): 1700 self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( 1701 q, k, v, None, 0.0, is_causal=True)) 1702 1703def _get_block_size_n(device, head_dim, is_dropout, is_causal): 1704 # This should match the block sizes in the CUDA kernel 1705 assert head_dim <= 256 1706 major, minor = torch.cuda.get_device_capability(device) 1707 is_sm8x = major == 8 and minor > 0 # Only include sm86 and sm89, exclude sm80 (A100) 1708 is_sm80 = major == 8 and minor == 0 1709 is_sm90 = major == 9 and minor == 0 1710 if head_dim <= 32: 1711 return 128 1712 if head_dim <= 64: 1713 return 128 if not is_dropout else 64 1714 elif head_dim <= 96: 1715 return 64 1716 elif head_dim <= 128: 1717 if is_sm8x: 1718 return 64 if (not is_dropout and is_causal) else 32 1719 else: 1720 return 64 if not is_dropout else 32 1721 elif head_dim <= 160: 1722 if is_sm8x: 1723 return 64 1724 else: 1725 return 32 1726 elif head_dim <= 192: 1727 return 64 1728 elif head_dim <= 224: 1729 return 64 1730 elif head_dim <= 256: 1731 return 64 1732 1733 1734def pad_last_dim(input_tensor, alignment_size, slice: bool = False): 1735 last_dim_size = input_tensor.size(-1) 1736 if (last_dim_size % alignment_size == 0): 1737 return input_tensor, last_dim_size 1738 pad_count = alignment_size - (last_dim_size % alignment_size) 1739 padded_tensor = F.pad(input_tensor, (0, pad_count)) 1740 if slice: 1741 return padded_tensor[..., :last_dim_size], last_dim_size 1742 return padded_tensor, last_dim_size 1743 1744 1745class TestSDPA(NNTestCase): 1746 """ Used to test generic functionality of scaled_dot_product_attention 1747 Summary: 1748 If you are adding a new test to this class, make sure that it runs 1749 for both cpu and cuda. If your test is only applicable to cuda, 1750 add it to TestSDPACudaOnly.
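        Most tests here follow the same pattern: build q/k/v, pin a backend with
        sdpa_kernel, and compare against a reference. A minimal sketch (names are
        illustrative only, not shared fixtures):
            q = k = v = torch.randn(2, 4, 8, 16, device=device)
            with sdpa_kernel(backends=[SDPBackend.MATH]):
                out = F.scaled_dot_product_attention(q, k, v)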
1751 """ 1752 @parametrize("contiguous_inputs", [True, False]) 1753 def test_sdp_math_gradcheck(self, device, contiguous_inputs: bool): 1754 1755 batch_size, seq_len, num_heads, head_dim = 4, 4, 2, 16 1756 shape = SdpaShape(batch_size, num_heads, seq_len, head_dim) 1757 make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, 1758 dtype=torch.float64, requires_grad=True, packed=True) 1759 1760 qkv = make_tensor(shape) 1761 query, key, value = qkv.chunk(3, dim=-1) 1762 1763 query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 1764 key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 1765 value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 1766 1767 if contiguous_inputs: 1768 query = query.contiguous() 1769 key = key.contiguous() 1770 value = value.contiguous() 1771 1772 with sdpa_kernel(backends=[SDPBackend.MATH]): 1773 assert gradcheck(lambda *args, **kwargs: 1774 wrapper_set_seed(torch.nn.functional.scaled_dot_product_attention, *args, **kwargs), 1775 (query, key, value, None, 0.0, False) 1776 ) 1777 1778 @onlyCPU 1779 @parametrize("type", ["dense", "nested"]) 1780 @parametrize("dropout", [0.0, 0.7]) 1781 @parametrize("dtype", [torch.float64, torch.float32, torch.bfloat16, torch.half]) 1782 def test_fused_sdp_choice_cpu(self, device, type: str, dropout: float, dtype: torch.dtype): 1783 # Test that cpu and nestedtensor cpu return MATH backend 1784 make_tensor = partial(rand_sdpa_tensor, type=type, device=device, dtype=dtype) 1785 size = SdpaShape(2, 8, 128, 64) 1786 q, k, v = make_tensor(size), make_tensor(size), make_tensor(size) 1787 if type == "nested" \ 1788 or dropout > 0.0 \ 1789 or dtype not in [torch.float32, torch.float64, torch.bfloat16, torch.float16]: 1790 assert torch._fused_sdp_choice(q, k, v, dropout_p=dropout) == SDPBackend.MATH.value 1791 else: 1792 assert torch._fused_sdp_choice(q, k, v, dropout_p=dropout) == SDPBackend.FLASH_ATTENTION.value 1793 1794 @onlyCPU 1795 @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION]) 1796 @parametrize("dtype", [torch.float64, torch.float32, torch.bfloat16, torch.float16]) 1797 @parametrize("batch_size", [2, 12]) 1798 @parametrize("seq_len", [267, 1030]) 1799 @parametrize("n_head", [1, 3]) 1800 @parametrize("head_dim", [8, 16]) 1801 @parametrize("causal", [True, False]) 1802 @parametrize("train", [True, False]) 1803 def test_scaled_dot_product_fused_attention_vs_math_cpu( 1804 self, 1805 device, 1806 fused_kernel, 1807 dtype, 1808 batch_size, 1809 seq_len, 1810 n_head, 1811 head_dim, 1812 causal, 1813 train, 1814 ): 1815 atol = 1e-5 1816 rtol = 5e-6 1817 if dtype is torch.bfloat16: 1818 atol = 5e-2 1819 rtol = 5e-2 1820 if dtype is torch.float16: 1821 atol = 1e-2 1822 rtol = 1e-2 1823 1824 n_embd = n_head * head_dim 1825 make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=dtype, packed=True, requires_grad=False) 1826 shape = SdpaShape(batch_size, n_head, seq_len, head_dim) 1827 x = make_tensor(shape) 1828 x2 = x.clone() 1829 1830 if train: 1831 x.requires_grad_(True) 1832 x2.requires_grad_(True) 1833 1834 q, k, v = x.split(n_embd, dim=2) 1835 q2, k2, v2 = x2.split(n_embd, dim=2) 1836 1837 if dtype in [torch.bfloat16, torch.float16]: 1838 q2 = q2.float() 1839 k2 = k2.float() 1840 v2 = v2.float() 1841 1842 # (B, nh, T, hs) 1843 k = k.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2) 1844 q = q.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2) 1845 v = v.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2) 1846 k2 = 
k2.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2) 1847 q2 = q2.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2) 1848 v2 = v2.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2) 1849 1850 with sdpa_kernel(backends=[fused_kernel]): 1851 actual = torch.nn.functional.scaled_dot_product_attention( 1852 q, k, v, attn_mask=None, dropout_p=0.0, is_causal=causal) 1853 with sdpa_kernel(backends=[SDPBackend.MATH]): 1854 math_ref = torch.nn.functional.scaled_dot_product_attention( 1855 q2, k2, v2, attn_mask=None, dropout_p=0.0, is_causal=causal) 1856 1857 if dtype in [torch.bfloat16, torch.float16]: 1858 math_ref = math_ref.to(dtype) 1859 1860 self.assertEqual(actual, math_ref, atol=atol, rtol=rtol) 1861 1862 if train: 1863 actual.sum().backward() 1864 math_ref.sum().backward() 1865 1866 grad_x, grad_x2 = x.grad, x2.grad 1867 grad_q_actual, grad_k_actual, grad_v_actual = grad_x.split(n_embd, dim=2) 1868 grad_q_ref, grad_k_ref, grad_v_ref = grad_x2.split(n_embd, dim=2) 1869 1870 self.assertEqual(grad_q_actual, grad_q_ref, atol=atol, rtol=rtol) 1871 self.assertEqual(grad_k_actual, grad_k_ref, atol=atol, rtol=rtol) 1872 self.assertEqual(grad_v_actual, grad_v_ref, atol=atol, rtol=rtol) 1873 1874 @onlyCPU 1875 @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION]) 1876 @parametrize("dtype", [torch.float64, torch.float32, torch.bfloat16, torch.float16]) 1877 @parametrize("batch_size", [2, 12]) 1878 @parametrize("q_seq_len", [267, 1030]) 1879 @parametrize("kv_seq_len", [514, 1179]) 1880 @parametrize("n_head", [1, 3]) 1881 @parametrize("head_dim", [8, 16]) 1882 @parametrize("mask_dim", [2, 4]) 1883 @parametrize("bool_mask", [0, 1]) 1884 @parametrize("train", [True, False]) 1885 def test_scaled_dot_product_fused_attention_mask_vs_math_cpu( 1886 self, 1887 device, 1888 fused_kernel, 1889 dtype, 1890 batch_size, 1891 q_seq_len, 1892 kv_seq_len, 1893 n_head, 1894 head_dim, 1895 mask_dim, 1896 bool_mask, 1897 train, 1898 ): 1899 tol = Tolerances(1e-5, 5e-6) 1900 if dtype is torch.bfloat16: 1901 tol = Tolerances(5e-2, 5e-2) 1902 if dtype is torch.float16: 1903 tol = Tolerances(1e-2, 1e-2) 1904 1905 make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=dtype, requires_grad=False) 1906 q_shape = SdpaShape(batch_size, n_head, q_seq_len, head_dim) 1907 kv_shape = SdpaShape(batch_size, n_head, kv_seq_len, head_dim) 1908 q = make_tensor(q_shape) 1909 k = make_tensor(kv_shape) 1910 v = make_tensor(kv_shape) 1911 q2, k2, v2 = q.clone(), k.clone(), v.clone() 1912 1913 if train: 1914 q.requires_grad_(True) 1915 k.requires_grad_(True) 1916 v.requires_grad_(True) 1917 q2.requires_grad_(True) 1918 k2.requires_grad_(True) 1919 v2.requires_grad_(True) 1920 1921 if dtype in [torch.bfloat16, torch.float16]: 1922 q2, k2, v2 = q2.float(), k2.float(), v2.float() 1923 # (B, nh, T, hs) 1924 q = q.view(batch_size, q_seq_len, n_head, head_dim).transpose(1, 2) 1925 k = k.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2) 1926 v = v.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2) 1927 if mask_dim == 4: 1928 mask_shape = (batch_size, n_head, q_seq_len, kv_seq_len) 1929 else: 1930 mask_shape = (q_seq_len, kv_seq_len) 1931 if bool_mask: 1932 attn_mask = torch.randint(0, 2, size=mask_shape, dtype=torch.bool, device=device) 1933 else: 1934 attn_mask = torch.randn(mask_shape, dtype=dtype, device=device) 1935 q2 = q2.view(batch_size, q_seq_len, n_head, head_dim).transpose(1, 2) 1936 k2 = k2.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2) 1937 v2 = 
v2.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2) 1938 1939 with sdpa_kernel(backends=[fused_kernel]): 1940 actual = torch.nn.functional.scaled_dot_product_attention( 1941 q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False) 1942 with sdpa_kernel(backends=[SDPBackend.MATH]): 1943 if not bool_mask and dtype in [torch.bfloat16, torch.float16]: 1944 attn_mask = attn_mask.float() 1945 math_ref = torch.nn.functional.scaled_dot_product_attention( 1946 q2, k2, v2, attn_mask=attn_mask, dropout_p=0.0, is_causal=False) 1947 1948 if dtype in [torch.bfloat16, torch.float16]: 1949 math_ref = math_ref.to(dtype) 1950 1951 self.assertEqual(actual, math_ref, atol=tol.atol, rtol=tol.rtol) 1952 1953 if train: 1954 actual.sum().backward() 1955 math_ref.sum().backward() 1956 1957 grad_q_actual, grad_k_actual, grad_v_actual = q.grad, k.grad, v.grad 1958 grad_q_ref, grad_k_ref, grad_v_ref = q2.grad, k2.grad, v2.grad 1959 1960 self.assertEqual(grad_q_actual, grad_q_ref, atol=tol.atol, rtol=tol.rtol) 1961 self.assertEqual(grad_k_actual, grad_k_ref, atol=tol.atol, rtol=tol.rtol) 1962 self.assertEqual(grad_v_actual, grad_v_ref, atol=tol.atol, rtol=tol.rtol) 1963 1964 @onlyCPU 1965 def test_scaled_dot_product_fused_attention_with_inf(self, device): 1966 # https://github.com/pytorch/pytorch/issues/127055. 1967 full = torch.full((600, 600), float("-inf"), device=device) 1968 mask = torch.triu(full, diagonal=1) + torch.tril(full, diagonal=-10) 1969 make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=torch.float32, requires_grad=False) 1970 input_shape = SdpaShape(1, 600, 2, 8) 1971 q = make_tensor(input_shape) 1972 k = make_tensor(input_shape) 1973 v = make_tensor(input_shape) 1974 with sdpa_kernel(backends=[SDPBackend.MATH]): 1975 math_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask) 1976 with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): 1977 actual = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask) 1978 self.assertEqual(math_ref, actual) 1979 1980 @parametrize("kernel", [SDPBackend.MATH]) 1981 def test_scaled_dot_product_attention_math_with_negative_scale(self, device, kernel: SDPBackend): 1982 # https://github.com/pytorch/pytorch/issues/105190. 1983 def ref(x): 1984 v1 = torch.matmul(x, x.transpose(-1, -2)) 1985 v2 = v1 / -0.0001 1986 v3 = v2.softmax(dim=-1) 1987 v4 = torch.matmul(v3, x) 1988 return v4 1989 1990 x = torch.randn(1, 3, 64, 64, device=device) 1991 ref_result = ref(x) 1992 with sdpa_kernel(backends=[kernel]): 1993 sdp_math = torch.nn.functional.scaled_dot_product_attention(x, x, x, scale=-1.0 / 0.0001) 1994 self.assertEqual(ref_result, sdp_math) 1995 1996class TestSDPACudaOnly(NNTestCase): 1997 """ Used to test CUDA only functionality of scaled_dot_product_attention 1998 Quirks: 1999 There is some trickiness with this function. Its runtime behavior 2000 is dependent on the CUDA architecture you are testing it on. See 2001 `PLATFORM_SUPPORTS_FUSED_ATTENTION` at the top of the file. 2002 Summary: 2003 Math: always supported 2004 FlashAttention: Supported on sm80 or newer hardware 2005 MemEfficientAttention: Supported on sm50 or newer hardware 2006 """ 2007 _do_cuda_memory_leak_check = True 2008 _do_cuda_non_default_stream = True 2009 2010 # TODO: used for testing the scores, e.g.
testing ALIBI we don't need this now 2011 def normalize_flash_attn_S( 2012 self, 2013 attn_unnorm, 2014 q, 2015 k, 2016 v, 2017 query_padding_mask=None, 2018 key_padding_mask=None, 2019 attn_bias=None, 2020 is_dropout=False, 2021 causal=False, 2022 window_size=(-1, -1), # -1 means infinite window size 2023 scale=None, 2024 ): 2025 """ 2026 Arguments: 2027 q: (batch_size, seqlen_q, nheads, head_dim) 2028 k, v: (batch_size, seqlen_k, nheads, head_dim) 2029 key_padding_mask: (batch_size, seqlen_q) 2030 attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) 2031 Output: 2032 softmax_lse: (batch_size, nheads, seqlen_q) 2033 softmax_max: (batch_size, nheads, seqlen_q) 2034 """ 2035 q = q.transpose(1, 2) 2036 k = k.transpose(1, 2) 2037 v = v.transpose(1, 2) 2038 if causal: 2039 window_size = (window_size[0], 0) 2040 q, k, v = q.float(), k.float(), v.float() 2041 _, seqlen_q, _, head_dim = q.shape 2042 seqlen_k = k.shape[1] 2043 b = q.shape[0] 2044 from torch.nn.attention.bias import _calculate_scale 2045 scale = _calculate_scale(head_dim, scale) 2046 scores = torch.matmul(q.transpose(1, 2) * scale, k.permute(0, 2, 3, 1)) 2047 if key_padding_mask is not None: 2048 scores.masked_fill_(~key_padding_mask.view(b, 1, 1, -1), float("-inf")) 2049 if window_size[0] >= 0 or window_size[1] >= 0: 2050 local_mask = self.construct_local_mask( 2051 seqlen_q, 2052 seqlen_k, 2053 window_size, 2054 query_padding_mask, 2055 key_padding_mask, 2056 q.device, 2057 ) 2058 scores.masked_fill_(local_mask, float("-inf")) 2059 if attn_bias is not None: 2060 scores = scores + attn_bias.to(dtype=scores.dtype) 2061 block_size_n = _get_block_size_n(scores.device, head_dim, is_dropout, causal) 2062 scores_block = scores.split(block_size_n, dim=-1) 2063 lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1) 2064 lse = torch.logsumexp(lse_block, dim=-1) 2065 # lse could be -inf (i.e. all values in scores are -inf), and we want to set those to inf 2066 # so that when we do torch.exp(m - lse), we get 0.0 instead of NaN. 
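        # (Worked example: if an entire row of scores is -inf, logsumexp gives lse = -inf;
        #  after the replacement below, exp(m - lse) = exp(-inf - inf) = exp(-inf) = 0.0,
        #  whereas exp(-inf - (-inf)) would evaluate to exp(nan) = nan.)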
2067 lse[lse == float("-inf")] = float("inf") 2068 scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1) 2069 cummax_block = torch.cummax(scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1) 2070 attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1) 2071 attn_norm = torch.cat( 2072 [ 2073 a * (torch.exp(m - lse)).unsqueeze(-1) 2074 for a, m in zip(attn_unnorm_block, cummax_block) 2075 ], 2076 dim=-1, 2077 ) 2078 if query_padding_mask is not None: 2079 attn_norm.masked_fill_(~query_padding_mask.view(b, 1, -1, 1), 0.0) 2080 # attn_norm.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) 2081 return attn_norm.to(dtype=attn_unnorm.dtype) 2082 2083 def construct_local_mask(self, seqlen_q, seqlen_k, window_size, query_padding_mask, key_padding_mask, device): 2084 # row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") 2085 row_idx = torch.arange(seqlen_q, device=device, dtype=torch.long).view(-1, 1) 2086 col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) 2087 sk = ( 2088 seqlen_k 2089 if key_padding_mask is None 2090 else key_padding_mask.sum(-1).view(-1, 1, 1, 1) 2091 # else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") 2092 ) 2093 sq = ( 2094 seqlen_q 2095 if query_padding_mask is None 2096 else query_padding_mask.sum(-1).view(-1, 1, 1, 1) 2097 # else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") 2098 ) 2099 if window_size[0] < 0: 2100 return col_idx > row_idx + sk - sq + window_size[1] 2101 else: 2102 sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk 2103 return torch.logical_or( 2104 col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), 2105 col_idx < row_idx + sk - sq - window_size[0], 2106 ) 2107 2108 def convert_flash_attn_S_to_softmax( 2109 self, 2110 S, 2111 seqlen_q, 2112 seqlen_k, 2113 query_padding_mask, 2114 key_padding_mask, 2115 causal=False, 2116 window_size=(-1, -1), # -1 means infinite window size 2117 ): 2118 """FlashAttention stores the S matrix in a different way. 2119 Arguments: 2120 S: (batch_size, nheads, seqlen_q, seqlen_k) 2121 query_padding_mask: (batch_size, seqlen_q) 2122 key_padding_mask: (batch_size, seqlen_k) 2123 """ 2124 if TEST_WITH_ROCM: 2125 return S 2126 b = S.shape[0] 2127 2128 if causal: 2129 window_size = (window_size[0], 0) 2130 seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:] 2131 S_converted = S 2132 if window_size[0] >= 0 or window_size[1] >= 0: 2133 local_mask = self.construct_local_mask( 2134 seqlen_q, 2135 seqlen_k, 2136 window_size, 2137 query_padding_mask, 2138 key_padding_mask, 2139 S.device, 2140 ) 2141 local_mask = F.pad( 2142 local_mask, 2143 (0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q), 2144 value=True, 2145 ) 2146 S_converted = S_converted.masked_fill(local_mask, 0.0) 2147 2148 # Need to zero out things not in attention_mask in case S was initialized with random values 2149 # and some of those values aren't overwritten. 
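        # (Here seqlen_q_rounded / seqlen_k_rounded are the padded extents FlashAttention
        #  actually materializes, while the *_og values below recover the caller-visible
        #  lengths so the padding can be masked out and cropped away at the end.)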
2150 seqlen_q_og = ( 2151 query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded 2152 ) 2153 if query_padding_mask is not None: 2154 query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q_rounded - seqlen_q_og)) 2155 # S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) 2156 S_converted = S_converted.masked_fill(~query_padding_mask.view(b, 1, -1, 1), 0.0) 2157 seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k 2158 if key_padding_mask is not None: 2159 key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k_rounded - seqlen_k_og)) 2160 S_converted = S_converted.masked_fill(~key_padding_mask.view(b, 1, 1, -1), 0.0) 2161 # S_converted = S_converted.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0) 2162 S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q_rounded)) 2163 S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded)) 2164 return S_converted[:, :, :seqlen_q, :seqlen_k] 2165 2166 @skipIfRocm # No cuDNN Attention 2167 @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system") 2168 def test_cudnn_attention_different_dk_dv(self, device): 2169 dtype = torch.bfloat16 2170 make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True) 2171 batch, num_heads, head_dim_k, head_dim_v = 32, 16, 128, 64 2172 seq_len = 640 2173 q_shape = SdpaShape(batch, num_heads, seq_len, head_dim_k) 2174 k_shape = SdpaShape(batch, num_heads, seq_len, head_dim_k) 2175 v_shape = SdpaShape(batch, num_heads, seq_len, head_dim_v) 2176 query, key, value = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape) 2177 2178 with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]): 2179 actual = torch.nn.functional.scaled_dot_product_attention( 2180 query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) 2181 with sdpa_kernel(backends=[SDPBackend.MATH]): 2182 math_ref = torch.nn.functional.scaled_dot_product_attention( 2183 query.contiguous().to(torch.float32), 2184 key.contiguous().to(torch.float32), 2185 value.contiguous().to(torch.float32), 2186 attn_mask=None, dropout_p=0.0, is_causal=False) 2187 2188 self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2) 2189 2190 2191 @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") 2192 @parametrize("mask_dim", [1, 2, 3, 4]) 2193 def test_mem_efficient_attention_mask_variants(self, device, mask_dim: List[int]): 2194 dtype = torch.float16 2195 make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True) 2196 batch, num_heads, head_dim = 8, 8, 64 2197 seq_len_q, seq_len_kv = 64, 32 2198 query = make_tensor(SdpaShape(batch, num_heads, seq_len_q, head_dim)) 2199 kv_shape = SdpaShape(batch, num_heads, seq_len_kv, head_dim) 2200 key, value = make_tensor(kv_shape), make_tensor(kv_shape) 2201 2202 if mask_dim == 1: 2203 mask = torch.randn((seq_len_kv,), device=device, dtype=dtype) 2204 elif mask_dim == 2: 2205 mask = torch.randn((seq_len_q, seq_len_kv), device=device, dtype=dtype) 2206 elif mask_dim == 3: 2207 mask = torch.randn((num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype) 2208 elif mask_dim == 4: 2209 mask = torch.randn((batch, num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype) 2210 with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]): 2211 out = F.scaled_dot_product_attention(query, key, value, mask) 2212 
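            # Sanity-check sketch: whatever rank the mask has (1-4), it broadcasts against
            # the (batch, num_heads, seq_len_q, seq_len_kv) attention weights, so the output
            # keeps the query layout.
            self.assertEqual(out.shape, torch.Size([batch, num_heads, seq_len_q, head_dim]))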
out.sum().backward() 2213 2214 @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") 2215 @parametrize("dtype", [torch.float, torch.float16]) 2216 def test_mem_eff_attention_pad_mask(self, device, dtype): 2217 make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True) 2218 batch, num_heads, head_dim = 8, 8, 64 2219 seq_len_q, seq_len_kv = 64, 15 2220 query = make_tensor(SdpaShape(batch, num_heads, seq_len_q, head_dim)) 2221 kv_shape = SdpaShape(batch, num_heads, seq_len_kv, head_dim) 2222 key, value = make_tensor(kv_shape), make_tensor(kv_shape) 2223 mask = torch.randn((batch, num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype) 2224 with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]): 2225 out = F.scaled_dot_product_attention(query, key, value, mask) 2226 out.sum().backward() 2227 2228 @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") 2229 @parametrize("dtype", [torch.float, torch.float16]) 2230 def test_mem_eff_attention_non_contiguous_mask(self, device, dtype): 2231 make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True) 2232 batch, num_heads, head_dim = 8, 8, 64 2233 seq_len_q, seq_len_kv = 64, 16 2234 query = make_tensor(SdpaShape(batch, num_heads, seq_len_q, head_dim)) 2235 kv_shape = SdpaShape(batch, num_heads, seq_len_kv, head_dim) 2236 key, value = make_tensor(kv_shape), make_tensor(kv_shape) 2237 mask = torch.randn((batch, num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype) 2238 mask = torch.as_strided(mask, (batch, num_heads, seq_len_q, seq_len_kv), (0, 0, 0, 1)) 2239 with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]): 2240 out = F.scaled_dot_product_attention(query, key, value, mask) 2241 out.sum().backward() 2242 2243 @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") 2244 @parametrize("dtype", [torch.float, torch.float16]) 2245 def test_mem_eff_attention_long_sequence_mask(self, device, dtype): 2246 if torch.cuda.get_device_properties('cuda').total_memory < 80 * 2**30: 2247 unittest.skip("This test requires substatnial GPU memory.") 2248 return 2249 make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True) 2250 batch, num_heads, head_dim = 1, 32, 64 2251 seq_len_q, seq_len_kv = 8192, 8192 2252 query = make_tensor(SdpaShape(batch, num_heads, seq_len_q, head_dim)) 2253 kv_shape = SdpaShape(batch, num_heads, seq_len_kv, head_dim) 2254 key, value = make_tensor(kv_shape), make_tensor(kv_shape) 2255 mask = torch.randn((batch, num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype) 2256 with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]): 2257 out = F.scaled_dot_product_attention(query, key, value, mask) 2258 out.sum().backward() 2259 2260 @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") 2261 def test_mem_eff_attention_non_contig_mask_bug(self, device): 2262 # Without the fix this produces `AssertionError: assert 0.07352933287620544 < 1e-07` 2263 # Shapes taken from repro 2264 query_size = (3, 16, 1, 128) 2265 query_strides = (2304, 128, 2048, 1) 2266 key_size = (3, 16, 14, 128) 2267 key_strides = (3584, 0, 256, 1) 2268 value_size = (3, 16, 14, 128) 2269 value_strides = (3584, 0, 256, 1) 2270 attention_mask_size = (3, 1, 1, 14) 2271 attn_mask_strides = (14, 14, 14, 1) 2272 2273 # Calculate the number of elements needed for each tensor 2274 query_num_elements 
= max(size * stride for size, stride in zip(query_size, query_strides)) 2275 key_num_elements = max(size * stride for size, stride in zip(key_size, key_strides)) 2276 value_num_elements = max(size * stride for size, stride in zip(value_size, value_strides)) 2277 attention_mask_num_elements = max(size * stride for size, stride in zip(attention_mask_size, attn_mask_strides)) 2278 2279 # Create the tensors with the specified sizes and strides 2280 query = torch.randn(query_num_elements, device=device).as_strided(query_size, query_strides) 2281 key = torch.randn(key_num_elements, device=device).as_strided(key_size, key_strides) 2282 value = torch.randn(value_num_elements, device=device).as_strided(value_size, value_strides) 2283 bias = torch.randn(attention_mask_num_elements, device=device).as_strided(attention_mask_size, attn_mask_strides) 2284 2285 with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]): 2286 out = F.scaled_dot_product_attention(query, key, value, bias) 2287 out_contig = F.scaled_dot_product_attention(query, key, value, bias.contiguous()) 2288 2289 max_diff = (out - out_contig).abs().mean() 2290 self.assertTrue(max_diff.item() < 1e-7) 2291 2292 @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Fused SDPA was not built for this system") 2293 def test_singelton_head_dim_stride_ne_1(self, device): 2294 query = torch.tensor([[[[1, 2]]]], dtype=torch.float16, device=device) 2295 query = query.transpose(-1, -2) 2296 key = torch.tensor([[[[1]]]], dtype=torch.float16, device=device) 2297 value = torch.tensor([[[[1]]]], dtype=torch.float16, device=device) 2298 2299 with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False): 2300 scaled_dot_product_attention(query, key, value) 2301 2302 @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") 2303 @parametrize("type", ["dense", "nested"]) 2304 @parametrize("is_contiguous", [True, False]) 2305 def test_scaled_dot_product_attention_fused_kernels_packed(self, device, type: str, is_contiguous: bool): 2306 if TEST_WITH_ROCM and type == 'nested': 2307 self.skipTest("ROCM does not support efficient attention on nested tensors, for now") 2308 make_tensor = partial(rand_sdpa_tensor, type=type, device=device, dtype=torch.float16, packed=True) 2309 2310 batch_size, seq_len, num_heads, head_dim = 32, 64, 16, 64 2311 shape = SdpaShape(batch_size, num_heads, seq_len, head_dim) 2312 2313 # Test Packed 2314 qkv = make_tensor(shape) 2315 query, key, value = qkv.chunk(3, dim=-1) 2316 2317 query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2318 value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2319 key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2320 2321 if is_contiguous: 2322 query = query.contiguous() 2323 key = key.contiguous() 2324 value = value.contiguous() 2325 2326 with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]): 2327 actual = torch.nn.functional.scaled_dot_product_attention( 2328 query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) 2329 with sdpa_kernel(backends=[SDPBackend.MATH]): 2330 math_ref = torch.nn.functional.scaled_dot_product_attention( 2331 query.contiguous(), key.contiguous(), value.contiguous(), 2332 attn_mask=None, dropout_p=0.0, is_causal=False) 2333 2334 self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=2e-3, rtol=1e-2) 2335 2336 @skipIfRocm # Missing nested and EFFICIENT_ATTENTION 2337 @unittest.skipIf(not 
PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system") 2338 @parametrize("type", ["dense", "nested"]) 2339 @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if 2340 PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION]) 2341 def test_scaled_dot_product_attention_fused_kernels_packed_accuracy(self, device, type: str, fused_kernel: str): 2342 def rand_nt(shape): 2343 batch, seq_len, num_heads, head_dim = shape 2344 tensors = [6 * torch.rand((seq_len, 3 * num_heads * head_dim), device=device, dtype=torch.float32) - 3 2345 for _ in range(batch)] 2346 return (torch.nested.nested_tensor(tensors, device=device, dtype=torch.float32), 2347 torch.nested.nested_tensor(tensors, device=device, dtype=torch.float16)) 2348 2349 def rand_tensor(shape): 2350 batch, seq_len, num_heads, head_dim = shape 2351 tensor = 6 * torch.rand((batch, seq_len, 3 * num_heads * head_dim), device=device, dtype=torch.float32) - 3 2352 return tensor, tensor.to(dtype=torch.float16) 2353 2354 batch_size, seq_len, num_heads, head_dim = 16, 8, 4, 64 2355 shape = (batch_size, seq_len, num_heads, head_dim) 2356 2357 # Test Packed 2358 qkv, qkv_low_precision = rand_tensor(shape) if type == "dense" else rand_nt(shape) 2359 query, key, value = qkv.chunk(3, dim=-1) 2360 query_lp, key_lp, value_lp = qkv_low_precision.chunk(3, dim=-1) 2361 2362 query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2363 key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2364 value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2365 2366 query_lp = query_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2367 key_lp = key_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2368 value_lp = value_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2369 2370 with sdpa_kernel(backends=[fused_kernel]): 2371 actual = torch.nn.functional.scaled_dot_product_attention( 2372 query_lp, key_lp, value_lp, attn_mask=None, dropout_p=0.0, is_causal=False) 2373 2374 with sdpa_kernel(backends=[SDPBackend.MATH]): 2375 math_ref_lp = torch.nn.functional.scaled_dot_product_attention( 2376 query_lp.contiguous(), key_lp.contiguous(), value_lp.contiguous(), 2377 attn_mask=None, dropout_p=0.0, is_causal=False) 2378 2379 math_query = query.contiguous() 2380 math_key = key.contiguous() 2381 math_value = value.contiguous() 2382 2383 math_ref = torch.nn.functional.scaled_dot_product_attention( 2384 math_query, math_key, math_value, attn_mask=None, dropout_p=0.0, is_causal=False) 2385 2386 actual_test = actual 2387 math_ref_test = math_ref 2388 math_ref_lp_test = math_ref_lp 2389 2390 if actual_test.is_nested: 2391 actual_test = torch.nested.to_padded_tensor(actual_test.contiguous(), padding=0.0) 2392 math_ref_test = torch.nested.to_padded_tensor(math_ref_test, padding=0.0) 2393 math_ref_lp_test = torch.nested.to_padded_tensor(math_ref_lp_test, padding=0.0) 2394 2395 actual_test = actual_test.to(dtype=torch.float32).contiguous() 2396 math_ref_test = math_ref_test.to(dtype=torch.float32).contiguous() 2397 math_ref_lp_test = math_ref_lp_test.to(dtype=torch.float32).contiguous() 2398 2399 self.assertEqual(math_ref_test, math_ref_lp_test, atol=7e-3, rtol=7e-3) 2400 self.assertEqual(actual_test, math_ref_test, atol=5e-3, rtol=5e-3) 2401 2402 @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Efficient Attention was not built for this system") 2403 @parametrize("contiguous_inputs", [True, False]) 2404 
@parametrize("is_causal", [True, False]) 2405 def test_sdp_mem_efficient_grad_against_math(self, device, contiguous_inputs: bool, is_causal: bool): 2406 batch_size, seq_len, num_heads, head_dim = 4, 4, 2, 16 2407 make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, 2408 dtype=torch.float64, requires_grad=True, packed=True) 2409 2410 qkv = make_tensor(SdpaShape(batch_size, num_heads, seq_len, head_dim)) 2411 qkv_lp = qkv.detach().clone().to(torch.float32).requires_grad_() 2412 2413 query, key, value = qkv.chunk(3, dim=-1) 2414 query_lp, key_lp, value_lp = qkv_lp.chunk(3, dim=-1) 2415 2416 query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2417 key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2418 value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2419 2420 query_lp = query_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2421 key_lp = key_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2422 value_lp = value_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2423 2424 if contiguous_inputs: 2425 query = query.contiguous() 2426 key = key.contiguous() 2427 value = value.contiguous() 2428 2429 query_lp = query_lp.contiguous() 2430 key_lp = key_lp.contiguous() 2431 value_lp = value_lp.contiguous() 2432 2433 with sdpa_kernel(backends=[SDPBackend.MATH]): 2434 out = torch.nn.functional.scaled_dot_product_attention(query, key, value, None, 0.0, is_causal) 2435 2436 with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]): 2437 out_lp = torch.nn.functional.scaled_dot_product_attention( 2438 query_lp, key_lp, value_lp, None, 0.0, is_causal) 2439 2440 rand_upward = torch.rand_like(out) 2441 rand_upward_lp = rand_upward.to(torch.float32) 2442 2443 out.backward(rand_upward) 2444 out_lp.backward(rand_upward_lp) 2445 2446 # Cast up and compare 2447 self.assertEqual(qkv.grad, qkv_lp.grad.to(torch.float64), atol=1e-5, rtol=1e-5) 2448 2449 @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention was not built for this system") 2450 @parametrize("contiguous_inputs", [True, False]) 2451 @parametrize("is_causal", [True, False]) 2452 @parametrize("dtype", [torch.float16, torch.bfloat16]) 2453 def test_sdp_flash_attention_grad_against_math(self, device, contiguous_inputs: bool, is_causal: bool, dtype: torch.dtype): 2454 batch_size, seq_len, num_heads, head_dim = 4, 4, 2, 16 2455 make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, 2456 dtype=torch.float64, requires_grad=True, packed=True) 2457 2458 qkv = make_tensor(SdpaShape(batch_size, num_heads, seq_len, head_dim)) 2459 qkv_lp = qkv.detach().clone().to(dtype).requires_grad_() 2460 2461 query, key, value = qkv.chunk(3, dim=-1) 2462 query_lp, key_lp, value_lp = qkv_lp.chunk(3, dim=-1) 2463 2464 query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2465 key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2466 value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2467 2468 query_lp = query_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2469 key_lp = key_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2470 value_lp = value_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2471 2472 if contiguous_inputs: 2473 query = query.contiguous() 2474 key = key.contiguous() 2475 value = value.contiguous() 2476 2477 query_lp = query_lp.contiguous() 2478 key_lp = key_lp.contiguous() 2479 value_lp = value_lp.contiguous() 2480 2481 with 
sdpa_kernel(backends=[SDPBackend.MATH]): 2482 out = torch.nn.functional.scaled_dot_product_attention(query, key, value, None, 0.0, is_causal) 2483 2484 with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): 2485 out_lp = torch.nn.functional.scaled_dot_product_attention( 2486 query_lp, key_lp, value_lp, None, 0.0, is_causal) 2487 2488 rand_upward = torch.rand_like(out) 2489 rand_upward_lp = rand_upward.to(dtype) 2490 2491 out.backward(rand_upward) 2492 out_lp.backward(rand_upward_lp) 2493 2494 # Cast up and compare 2495 # Since we are doing the compute in reduced precision we have to bump the tolerance, 2496 # and use an even looser tolerance for bfloat16 2497 atol = 7e-4 if dtype == torch.float16 else 7e-3 2498 rtol = 7e-4 if dtype == torch.float16 else 7e-3 2499 if TEST_WITH_ROCM: 2500 atol = 9e-4 if dtype == torch.float16 else 9e-3 2501 self.assertEqual(qkv.grad, qkv_lp.grad.to(torch.float64), atol=atol, rtol=rtol) 2502 2503 @skipIfRocm # Missing nested and EFFICIENT_ATTENTION 2504 @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Platform does not support fused SDPA") 2505 @parametrize("type", ["dense", "nested"]) 2506 def test_fused_sdp_choice(self, device, type: str): 2507 batch_size, seq_len, num_heads, head_dim = 2, 128, 8, 64 2508 shape = SdpaShape(batch_size, num_heads, seq_len, head_dim) 2509 make_tensor = partial(rand_sdpa_tensor, device=device, dtype=torch.float16, packed=True, requires_grad=True) 2510 2511 qkv = make_tensor(shape, type=type) 2512 query, key, value = qkv.chunk(3, dim=-1) 2513 2514 query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2515 value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2516 key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2517 2518 if PLATFORM_SUPPORTS_FLASH_ATTENTION: 2519 assert torch._fused_sdp_choice(query, key, value) == SDPBackend.FLASH_ATTENTION.value 2520 else: 2521 assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION.value 2522 2523 # Change dtype to float32 so that efficient attention should get chosen 2524 make_tensor = partial(rand_sdpa_tensor, device=device, dtype=torch.float32, packed=True) 2525 2526 qkv = make_tensor(shape, type=type) 2527 query, key, value = qkv.chunk(3, dim=-1) 2528 2529 query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2530 value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2531 key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2532 2533 assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION.value 2534 2535 @skipIfRocm # Missing triton.float32 ("triton" prefix is to locate skipped UTs), and deterministic algo 2536 @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA") 2537 @parametrize("warn_only", [True, False]) 2538 def test_sdp_choice_with_determinism(self, device, warn_only): 2539 batch_size, seq_len, num_heads, head_dim = 1, 64, 8, 64 2540 shape = SdpaShape(batch_size, num_heads, seq_len, head_dim) 2541 make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=torch.float32, packed=False) 2542 query, key, value = make_tensor(shape), make_tensor(shape), make_tensor(shape) 2543 2544 with use_deterministic_algorithims(True, warn_only=warn_only): 2545 with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]): 2546 assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION.value 2547 2548 @skipIfRocm # Missing deterministic algo 2549
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system") 2550 @parametrize("fused_kernel", PLATFORM_SPECIFIC_SDPA) 2551 @parametrize("warn_only", [True, False]) 2552 def test_fused_backwards_throws_determinism_warning(self, device, warn_only, fused_kernel): 2553 batch_size, seq_len, num_heads, head_dim = 1, 64, 8, 64 2554 shape = SdpaShape(batch_size, num_heads, seq_len, head_dim) 2555 make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=torch.float16, packed=False, requires_grad=True) 2556 query, key, value = make_tensor(shape), make_tensor(shape), make_tensor(shape) 2557 2558 kernel_name = "Memory Efficient attention" if fused_kernel == SDPBackend.EFFICIENT_ATTENTION else "Flash Attention" 2559 warning_context = ( 2560 self.assertWarnsRegex( 2561 UserWarning, 2562 f"{kernel_name} defaults to a non-deterministic algorithm.", 2563 ) 2564 if warn_only 2565 else contextlib.nullcontext() 2566 ) 2567 with use_deterministic_algorithims(True, warn_only=warn_only): 2568 with sdpa_kernel(backends=[fused_kernel]): 2569 with warning_context: 2570 torch.nn.functional.scaled_dot_product_attention(query, key, value).sum().backward() 2571 2572 @unittest.skip("This test is not behaving deterministically on CI/CD") 2573 @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Platform does not support fused SDPA") 2574 def test_mem_eff_backwards_determinism(self, device): 2575 # Need big seq_len to ensure that num_splits > 1 2576 dtype = torch.float32 2577 batch_size, seq_len, n_heads, head_dim = 1, 1024, 8, 64 2578 query = torch.rand(batch_size, n_heads, seq_len, head_dim, 2579 device=device, dtype=dtype, requires_grad=True) 2580 key = torch.rand(batch_size, n_heads, seq_len, head_dim, device=device, 2581 dtype=dtype, requires_grad=True) 2582 value = torch.rand(batch_size, n_heads, seq_len, head_dim, 2583 device=device, dtype=dtype, requires_grad=True) 2584 2585 with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]): 2586 # Run once to establish baseline 2587 out = F.scaled_dot_product_attention(query, key, value) 2588 upward_grad = torch.rand_like(out) 2589 out.backward(upward_grad) 2590 initial_query_grad = query.grad 2591 2592 # Re-run the op with the same upward grad and check that the backward is 2593 # not deterministic 2594 diff_answer_once = False 2595 for _ in range(100): 2596 query.grad = None 2597 out = F.scaled_dot_product_attention(query, key, value) 2598 out.backward(upward_grad) 2599 if not torch.equal(initial_query_grad, query.grad): 2600 diff_answer_once = True 2601 break 2602 self.assertTrue(diff_answer_once) 2603 2604 with use_deterministic_algorithims(True, warn_only=False): 2605 query.grad = None 2606 out = F.scaled_dot_product_attention(query, key, value) 2607 upward_grad = torch.rand_like(out) 2608 out.backward(upward_grad) 2609 initial_query_grad = query.grad 2610 2611 # Re-run the op with the same upward grad and check that the backward is 2612 # deterministic now that we have enforced it 2613 diff_answer_once = False 2614 for _ in range(100): 2615 query.grad = None 2616 out = F.scaled_dot_product_attention(query, key, value) 2617 out.backward(upward_grad) 2618 if not torch.equal(initial_query_grad, query.grad): 2619 diff_answer_once = True 2620 break 2621 self.assertFalse(diff_answer_once) 2622 2623 # verified passing successfully on H100 2624 @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA") 2625 @unittest.skipIf(IS_JETSON, "causing sigkill on Jetson") 2626
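    # The wider seq_len / head_dim / dtype grids below assume an SM80-class memory-efficient
    # attention build (see MEM_EFF_CAPABILITY_MATCHES_SM80 earlier in this file); other
    # devices fall back to the smaller parameter lists.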
@parametrize("batch_size", [1, 8]) 2627 @parametrize("seq_len_q", [4, 8, 64, 128, 256, 512, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 2628 else [4, 8, 64, 128, 256, 512]) 2629 @parametrize("seq_len_k", [4, 8, 64, 128, 256, 512, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 2630 else [4, 8, 64, 128, 256, 512]) 2631 @parametrize("head_dim", [8, 16, 32, 64, 72, 96, 128] if MEM_EFF_CAPABILITY_MATCHES_SM80 2632 else [8, 16, 32, 64]) 2633 @parametrize("is_causal", [False, True]) 2634 @parametrize("dropout_p", [0.0, 0.22]) 2635 @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if MEM_EFF_CAPABILITY_MATCHES_SM80 2636 else [torch.float16, torch.float32]) 2637 @parametrize("scale", [None, "l1"]) 2638 def test_mem_efficient_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int, 2639 head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype, 2640 scale: str): 2641 def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, device=device): 2642 mask = torch.empty((batch_size, n_heads, q_len, kv_len), device=device, dtype=torch.float32) 2643 rand_uniform = torch._fill_mem_eff_dropout_mask_(mask, p, seed, offset) 2644 mask = (rand_uniform > p).to(torch.float32) 2645 return mask 2646 if max(seq_len_q, seq_len_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory < 40 * 2**30: 2647 unittest.skip("Reference implementation OOM") 2648 return 2649 if TEST_WITH_ROCM and seq_len_q * seq_len_k * head_dim * batch_size > 1024 * 1024 * 128: 2650 torch.cuda.empty_cache() # Prevent memory fragmentation 2651 seed = 42 2652 scale = scale if scale is None else (1 / head_dim) 2653 n_heads = 4 2654 query = torch.rand(batch_size, n_heads, seq_len_q, head_dim, 2655 device=device, dtype=dtype, requires_grad=True) 2656 key = torch.rand(batch_size, n_heads, seq_len_k, head_dim, device=device, 2657 dtype=dtype, requires_grad=True) 2658 value = torch.rand(batch_size, n_heads, seq_len_k, head_dim, 2659 device=device, dtype=dtype, requires_grad=True) 2660 2661 # Run the math kernel on low precision references 2662 query_ref_lp, key_ref_lp, value_ref_lp = query_key_value_clones(query, key, value, dtype=dtype) 2663 2664 higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32 2665 query_ref, key_ref, value_ref = query_key_value_clones(query, key, value, dtype=higher_precision_dtype) 2666 2667 # Create real output 2668 with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]): 2669 # Set the seed and run the kernel 2670 torch.manual_seed(seed) 2671 out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale) 2672 2673 if dropout_p == 0.0: 2674 with sdpa_kernel(backends=[SDPBackend.MATH]): 2675 # High Precision Math Reference 2676 out_ref = F.scaled_dot_product_attention(query_ref, key_ref, value_ref, 2677 dropout_p=dropout_p, is_causal=is_causal, scale=scale) 2678 # Low Precision Math Reference 2679 out_lp_ref = F.scaled_dot_product_attention(query_ref_lp, key_ref_lp, value_ref_lp, 2680 dropout_p=dropout_p, is_causal=is_causal, scale=scale) 2681 else: 2682 if seq_len_q > 1024: 2683 self.skipTest("Will call _fill_mem_eff_dropout_mask with too many threads!") 2684 # Create the dropout_mask 2685 torch.manual_seed(seed) 2686 dropout_mask = _get_mem_eff_drop_mask(batch_size, n_heads, seq_len_q, seq_len_k, dropout_p, seed, 0, device=device) 2687 # High Precision Math Reference 2688 out_ref = torch.ops.aten._scaled_dot_product_attention_math( 2689 query_ref, 
key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal, scale=scale, dropout_mask=dropout_mask)[0] 2690 # Low Precision Math Reference 2691 out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math( 2692 query_ref_lp, key_ref_lp, value_ref_lp, dropout_p=dropout_p, is_causal=is_causal, scale=scale, 2693 dropout_mask=dropout_mask)[0] 2694 2695 upstream_grad = torch.rand_like(out, requires_grad=False) 2696 2697 out.backward(upstream_grad) 2698 out_ref.backward(upstream_grad.to(out_ref.dtype)) 2699 out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype)) 2700 2701 # [Note] Fused Tolerances 2702 # Establish the numerical error between the "true" high precision math output 2703 # and the low precision math reference. We use this reference for the atol 2704 # And we use the default rtol for the low precision type. 2705 # We then provide a fudge factor for gradients respectively to account 2706 # for the use of the fused kernel rather than the eager implemntation. 2707 output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref) 2708 2709 # Fudge Factor when dropout is enabled 2710 dropout_fudge_factor = 1.0 if dropout_p == 0.0 else 2.0 2711 2712 query_fudge_factor = dropout_fudge_factor 2713 grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor) 2714 2715 # TODO: Investigate why grad_k needs larger tolerances 2716 key_fudge_factor = 8 * dropout_fudge_factor 2717 grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad, key_fudge_factor) 2718 2719 value_fudge_factor = 7 if not SM80OrLater and dtype == torch.float16 else 1.0 2720 if TEST_WITH_ROCM: 2721 value_fudge_factor = max(2.0, value_fudge_factor) 2722 grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad, value_fudge_factor) 2723 2724 self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol) 2725 self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype), 2726 atol=grad_q_ref_atol, rtol=grad_q_ref_rtol) 2727 self.assertEqual(key.grad, key_ref.grad.to(key.grad.dtype), 2728 atol=grad_k_ref_atol, rtol=grad_k_ref_rtol) 2729 self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype), 2730 atol=grad_v_ref_atol, rtol=grad_v_ref_rtol) 2731 2732 2733 @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA") 2734 @unittest.skipIf(IS_JETSON, "causing sigkill on Jetson") 2735 @parametrize("batch_size", [1, 8]) 2736 @parametrize("seq_len_q", [4, 8, 64, 128, 256, 312, 512, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 2737 else [4, 8, 64, 128, 152, 256, 512]) 2738 @parametrize("seq_len_k", [4, 8, 64, 65, 128, 256, 408, 512, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 2739 else [4, 8, 37, 64, 128, 256, 512]) 2740 @parametrize("head_dim", [8, 16, 32, 64, 72, 96, 128] if MEM_EFF_CAPABILITY_MATCHES_SM80 2741 else [8, 16, 32, 64]) 2742 @parametrize("is_causal", [False]) 2743 @parametrize("dropout_p", [0.0, 0.22]) 2744 @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if MEM_EFF_CAPABILITY_MATCHES_SM80 2745 else [torch.float16, torch.float32]) 2746 @parametrize("scale", [None, "l1"]) 2747 def test_mem_efficient_attention_attn_mask_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, 2748 seq_len_k: int, head_dim: int, is_causal: bool, 2749 dropout_p: float, dtype: torch.dtype, 2750 scale: str): 2751 def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, device=device): 2752 mask = torch.empty((batch_size, n_heads, 
                                q_len, kv_len), device=device, dtype=torch.float32)
            rand_uniform = torch._fill_mem_eff_dropout_mask_(mask, p, seed, offset)
            mask = (rand_uniform > p).to(torch.float32)
            return mask
        if max(seq_len_q, seq_len_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory < 40 * 2**30:
            # unittest.skip() called as a plain function is a no-op here; use skipTest
            # so the case is actually reported as skipped instead of silently passing.
            self.skipTest("Reference implementation OOM")
        if TEST_WITH_ROCM and dtype == torch.float32:
            self.skipTest("Skip fp32 attn_mask gradients on ROCM, for now.")
        if TEST_WITH_ROCM and seq_len_q * seq_len_k * head_dim * batch_size > 1024 * 1024 * 128:
            torch.cuda.empty_cache()  # Prevent memory fragmentation
        seed = 42
        scale = scale if scale is None else (1 / head_dim)
        n_heads = 4
        query = torch.rand(batch_size, n_heads, seq_len_q, head_dim,
                           device=device, dtype=dtype, requires_grad=True)
        key = torch.rand(batch_size, n_heads, seq_len_k, head_dim, device=device,
                         dtype=dtype, requires_grad=True)
        value = torch.rand(batch_size, n_heads, seq_len_k, head_dim,
                           device=device, dtype=dtype, requires_grad=True)

        attn_mask = torch.rand(seq_len_q, seq_len_k, device=device, dtype=dtype, requires_grad=True)

        # Run the math kernel on low precision references
        query_ref_lp, key_ref_lp, value_ref_lp = query_key_value_clones(query, key, value, dtype=dtype)
        attn_mask_ref_lp = attn_mask.detach().to(dtype).requires_grad_(True)

        higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
        query_ref, key_ref, value_ref = query_key_value_clones(query, key, value, dtype=higher_precision_dtype)
        attn_mask_ref = attn_mask.detach().to(higher_precision_dtype).requires_grad_(True)

        # Create real output
        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
            # Set the seed and run the kernel
            torch.manual_seed(seed)
            out = F.scaled_dot_product_attention(query, key, value, attn_mask, dropout_p=dropout_p,
                                                 is_causal=is_causal, scale=scale)

        if dropout_p == 0.0:
            with sdpa_kernel(backends=[SDPBackend.MATH]):
                # High Precision Math Reference
                out_ref = F.scaled_dot_product_attention(query_ref, key_ref, value_ref, attn_mask_ref,
                                                          dropout_p=dropout_p, is_causal=is_causal, scale=scale)
                # Low Precision Math Reference
                out_lp_ref = F.scaled_dot_product_attention(query_ref_lp, key_ref_lp, value_ref_lp, attn_mask_ref_lp,
                                                            dropout_p=dropout_p, is_causal=is_causal, scale=scale)
        else:
            if seq_len_q > 1024:
                self.skipTest("Will call _fill_mem_eff_dropout_mask with too many threads!")
            # Create the dropout_mask
            torch.manual_seed(seed)
            dropout_mask = _get_mem_eff_drop_mask(batch_size, n_heads, seq_len_q,
                                                  seq_len_k, dropout_p, seed, 0, device=device)
            # High Precision Math Reference
            out_ref = torch.ops.aten._scaled_dot_product_attention_math(
                query_ref, key_ref, value_ref, attn_mask_ref, dropout_p=dropout_p, is_causal=is_causal,
                scale=scale, dropout_mask=dropout_mask)[0]
            # Low Precision Math Reference
            out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
                query_ref_lp, key_ref_lp, value_ref_lp, attn_mask_ref_lp,
                dropout_p=dropout_p, is_causal=is_causal, scale=scale,
                dropout_mask=dropout_mask)[0]

        upstream_grad = torch.rand_like(out, requires_grad=False)

        out.backward(upstream_grad)
        out_ref.backward(upstream_grad.to(out_ref.dtype))
        out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype))
        # [Note] Fused Tolerances
        # Establish the numerical error between the "true" high precision math output
        # and the low precision math reference. We use this error for the atol,
        # and we use the default rtol for the low precision type.
        # We then apply a per-gradient fudge factor to account for the use of the
        # fused kernel rather than the eager implementation.
        output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref)

        # Fudge Factor when dropout is enabled
        dropout_fudge_factor = 1.0 if dropout_p == 0.0 else 1.75
        mask_fudge_factor = 1.0 if attn_mask is None else 1.5

        query_fudge_factor = dropout_fudge_factor
        grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor)

        # TODO: Investigate why grad_k needs larger tolerances
        key_fudge_factor = 8 * dropout_fudge_factor * mask_fudge_factor
        grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad, key_fudge_factor)

        value_fudge_factor = 7 if not SM80OrLater and dtype == torch.float16 else 1.0
        if TEST_WITH_ROCM:
            value_fudge_factor = max(2.0, value_fudge_factor)
        grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad, value_fudge_factor)

        mask_fudge_factor = 12 if attn_mask.numel() > 512 else 22
        grad_attn_mask_atol, grad_attn_mask_rtol = get_tolerances(
            attn_mask_ref.grad, attn_mask_ref_lp.grad, mask_fudge_factor)

        self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol)
        self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype),
                         atol=grad_q_ref_atol, rtol=grad_q_ref_rtol)
        self.assertEqual(key.grad, key_ref.grad.to(key.grad.dtype),
                         atol=grad_k_ref_atol, rtol=grad_k_ref_rtol)
        self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
                         atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)

        self.assertEqual(attn_mask.grad, attn_mask_ref.grad.to(attn_mask.grad.dtype),
                         atol=grad_attn_mask_atol, rtol=grad_attn_mask_rtol)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
    @unittest.skipIf(IS_JETSON, "causing sigkill on Jetson")
    @parametrize("batch_size", [1, 8])
    @parametrize("seq_len_q", [4, 8, 64, 143, 256, 512, 1024, 2048])
    @parametrize("seq_len_k", [4, 8, 64, 128, 256, 587, 1024, 2048])
    @parametrize("head_dim", [8, 16, 21, 32, 64, 72, 96, 128, 160, 192, 203, 256])
    @parametrize("is_causal", [True, False])
    @parametrize("dropout_p", [0.0, 0.22, 0.48])
    @parametrize("dtype", [torch.float16, torch.bfloat16])
    @parametrize("scale", [None, "l1"])
    def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
                                               head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
                                               scale: str):
        if isSM8XDevice and head_dim in range(193, 256 + 1):
            self.skipTest("Flash attention on sm86, sm87, and sm89 for headdim > 192 currently disabled")
        if is_causal and seq_len_q != seq_len_k:
            self.skipTest("Flash V2 does not accept is_causal when seq_len_q != seq_len_k")
        if TEST_WITH_ROCM and seq_len_q >= 1024 and seq_len_k >= 1024 and batch_size > 1:
            torch.cuda.empty_cache()  # Prevent memory fragmentation

        scale = scale if scale is None else (1 / head_dim)
        n_heads = 4
        query = torch.rand(batch_size, n_heads, seq_len_q, head_dim,
                           device=device, dtype=dtype, requires_grad=True)
        key = torch.rand(batch_size, n_heads, seq_len_k, head_dim, device=device,
                         dtype=dtype, requires_grad=True)
        value = torch.rand(batch_size, n_heads, seq_len_k, head_dim,
                           device=device, dtype=dtype, requires_grad=True)

        # Run the math kernel on low precision references
        query_ref_lp, key_ref_lp, value_ref_lp = query_key_value_clones(query, key, value, dtype=dtype)

        higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
        query_ref, key_ref, value_ref = query_key_value_clones(query, key, value, dtype=higher_precision_dtype)

        is_dropout = dropout_p > 0.0

        if not is_dropout:
            with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
                out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
            with sdpa_kernel(backends=[SDPBackend.MATH]):
                # High Precision Math Reference
                out_ref = F.scaled_dot_product_attention(
                    query_ref, key_ref, value_ref, is_causal=is_causal, scale=scale)
                # Low Precision Math Reference
                out_lp_ref = F.scaled_dot_product_attention(
                    query_ref_lp, key_ref_lp, value_ref_lp, is_causal=is_causal, scale=scale)
        else:
            # Problem: We pad sizes in the composite region of the top level SDPA, but we need the
            # debug mask when we have dropout. So we manually pad up here when testing dropout.
            q_padded, q_og_size = pad_last_dim(query, 8)
            k_padded, k_og_size = pad_last_dim(key, 8)
            v_padded, v_og_size = pad_last_dim(value, 8)
            # scale needs to be calculated on the og head_size
            if scale is None:
                scale = 1 / math.sqrt(q_og_size)
            output_tuple = torch.ops.aten._scaled_dot_product_flash_attention(
                q_padded, k_padded, v_padded, dropout_p=dropout_p, is_causal=is_causal, scale=scale, return_debug_mask=is_dropout)
            out = output_tuple[0]
            out = out[..., :v_og_size]
            # Build dropout_mask
            dbug_mask = output_tuple[-1]
            query_padding_mask = torch.ones(
                batch_size, seq_len_q, device=device, dtype=torch.bool)
            key_padding_mask = torch.ones(
                batch_size, seq_len_k, device=device, dtype=torch.bool)

            softmax_mask = self.convert_flash_attn_S_to_softmax(
                dbug_mask, seq_len_q, seq_len_k, query_padding_mask, key_padding_mask,
                causal=is_causal)[:, :, :seq_len_q, :seq_len_k]
            dropout_mask = softmax_mask >= 0
            # attn_unnorm = softmax_mask.abs()
            # attn = self.normalize_flash_attn_S(attn_unnorm, q_padded,
            #                                    k_padded, v_padded, query_padding_mask,
            #                                    key_padding_mask, None, True, is_causal, scale=scale)

            # High Precision Math Reference
            out_ref = torch.ops.aten._scaled_dot_product_attention_math(
                query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal, scale=scale, dropout_mask=dropout_mask)[0]
            # Low Precision Math Reference
            out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
                query_ref_lp, key_ref_lp, value_ref_lp, dropout_p=dropout_p, is_causal=is_causal, scale=scale,
                dropout_mask=dropout_mask)[0]

        upstream_grad = torch.rand_like(out, requires_grad=False)

        # backward for flash attention on sm86, sm87, and sm89 for headdim >= 193 currently disabled
        if isSM8XDevice and head_dim in range(193, 256):
            self.assertRaises(RuntimeError, lambda: out.backward(upstream_grad))
            return
        out.backward(upstream_grad)
        out_ref.backward(upstream_grad.to(out_ref.dtype))
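        # Hedged aside (comments only; nothing here is executed): the checks that
        # follow compare the fused kernel result against a high precision math
        # reference, with tolerances derived from a low precision math reference.
        # A sketch of the pattern, where `fudge`, `high_prec_ref`, `low_prec_ref`,
        # and `fused_out` are hypothetical placeholder names:
        #
        #     atol, rtol = get_tolerances(high_prec_ref, low_prec_ref, fudge)
        #     self.assertEqual(fused_out, high_prec_ref.to(fused_out.dtype), atol=atol, rtol=rtol)
        #
        # i.e. the gap between the two math references bounds how much additional
        # error the fused kernel may introduce. See [Note] Fused Tolerances earlier
        # in this class for the full rationale.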
out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype)) 2954 2955 # See [Note] Fused Tolerances above 2956 output_fudge_factor = 3 if head_dim % 8 != 0 or TEST_WITH_ROCM else 1 2957 output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref, output_fudge_factor) 2958 2959 # TODO: Investigate why grad_q needs larger tolerances 2960 query_fudge_factor = 4 2961 grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor) 2962 2963 key_fudge_factor = 2 2964 grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad, key_fudge_factor) 2965 2966 value_fudge_factor = 2 2967 grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad, value_fudge_factor) 2968 2969 self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol) 2970 self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype), 2971 atol=grad_q_ref_atol, rtol=grad_q_ref_rtol) 2972 self.assertEqual(key.grad, key_ref.grad.to(key.grad.dtype), 2973 atol=grad_k_ref_atol, rtol=grad_k_ref_rtol) 2974 self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype), 2975 atol=grad_v_ref_atol, rtol=grad_v_ref_rtol) 2976 2977 @skipIfRocm # FIXME: "capturing stream has unjoined work" 2978 @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware") 2979 @parametrize("batch_size", [1, 8]) 2980 @parametrize("seq_len_q", [256, 512, 1024]) 2981 @parametrize("seq_len_k", [256, 512, 1024]) 2982 @parametrize("head_dim", [32, 64]) 2983 @parametrize("is_causal", [True, False]) 2984 @parametrize("dropout_p", [0.0, 0.22]) 2985 @parametrize("dtype", [torch.float16,]) 2986 @parametrize("scale", [None, "l1"]) 2987 @parametrize("fused_kernel", PLATFORM_SPECIFIC_SDPA) 2988 def test_fused_attention_vs_math_ref_grads_cudagraph(self, device, batch_size: int, seq_len_q: int, seq_len_k: int, 2989 head_dim: int, 2990 is_causal: bool, 2991 dropout_p: float, 2992 dtype: torch.dtype, 2993 scale: str, 2994 fused_kernel: SDPBackend): 2995 def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, dropout_p, seed, offset, device=device): 2996 mask = torch.empty((batch_size, n_heads, q_len, kv_len), device=device, dtype=torch.float32) 2997 rand_uniform = torch._fill_mem_eff_dropout_mask_(mask, dropout_p, seed, offset) 2998 mask = (rand_uniform > dropout_p).to(torch.float32) 2999 return mask 3000 3001 def get_dropout_mask(output, fused_kernel, batch_size, n_heads, q_len, kv_len, dropout_p, device=device): 3002 if fused_kernel == SDPBackend.EFFICIENT_ATTENTION: 3003 output_seed, output_offset = output_tuple[2], output_tuple[3] 3004 output_seed = output_seed.item() 3005 output_offset = output_offset.item() 3006 return _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, 3007 dropout_p, output_seed, output_offset, device=device) 3008 else: 3009 # Build dropout_mask 3010 dbug_mask = output_tuple[-1] 3011 query_padding_mask = torch.ones( 3012 batch_size, seq_len_q, device=device, dtype=torch.bool) 3013 key_padding_mask = torch.ones( 3014 batch_size, seq_len_k, device=device, dtype=torch.bool) 3015 3016 softmax_mask = self.convert_flash_attn_S_to_softmax( 3017 dbug_mask, seq_len_q, seq_len_k, query_padding_mask, key_padding_mask, 3018 causal=is_causal)[:, :, :seq_len_q, :seq_len_k] 3019 dropout_mask = softmax_mask >= 0 3020 return dropout_mask 3021 3022 if fused_kernel == SDPBackend.FLASH_ATTENTION and is_causal and seq_len_q != seq_len_k: 3023 self.skipTest("Flash V2 does not accept is_casual when 
seq_len_q != seq_len_k") 3024 3025 seed = 42 3026 scale = scale if scale is None else (1 / head_dim) 3027 n_heads = 4 3028 query = torch.rand(batch_size, n_heads, seq_len_q, head_dim, 3029 device=device, dtype=dtype, requires_grad=True) 3030 key = torch.rand(batch_size, n_heads, seq_len_k, head_dim, device=device, 3031 dtype=dtype, requires_grad=True) 3032 value = torch.rand(batch_size, n_heads, seq_len_k, head_dim, 3033 device=device, dtype=dtype, requires_grad=True) 3034 3035 fused_op = (torch.ops.aten._scaled_dot_product_efficient_attention 3036 if fused_kernel == SDPBackend.EFFICIENT_ATTENTION else torch.ops.aten._scaled_dot_product_flash_attention) 3037 # Run the math kernel on low precision references 3038 query_ref_lp, key_ref_lp, value_ref_lp = query_key_value_clones(query, key, value, dtype=dtype) 3039 3040 higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32 3041 query_ref, key_ref, value_ref = query_key_value_clones(query, key, value, dtype=higher_precision_dtype) 3042 3043 # warmup 3044 s = torch.cuda.Stream() 3045 s.wait_stream(torch.cuda.current_stream()) 3046 # Set the global seed before capture 3047 torch.manual_seed(seed) 3048 kwargs = {"dropout_p": dropout_p, "is_causal": is_causal, "scale": scale} 3049 if fused_kernel == SDPBackend.EFFICIENT_ATTENTION: 3050 kwargs["compute_log_sumexp"] = True 3051 kwargs["attn_bias"] = None 3052 if fused_kernel == SDPBackend.FLASH_ATTENTION: 3053 kwargs['return_debug_mask'] = dropout_p > 0.0 3054 with torch.cuda.stream(s): 3055 # Create real output 3056 output_tuple = fused_op(query, key, value, **kwargs) 3057 3058 torch.cuda.current_stream().wait_stream(s) 3059 out = output_tuple[0] 3060 upstream_grad = torch.rand_like(out, requires_grad=False) 3061 s.wait_stream(torch.cuda.current_stream()) 3062 with torch.cuda.stream(s): 3063 out.backward(upstream_grad) 3064 for x in (query, key, value): 3065 x.grad = None 3066 g = torch.cuda.CUDAGraph() 3067 # Create real output 3068 with torch.cuda.graph(g): 3069 tmp = torch.rand_like(query, device=query.device) # test non-zero intragraph offset 3070 # Create real output 3071 output_tuple = fused_op(query, key, value, **kwargs) 3072 assert all(not isinstance(o, torch.Tensor) or o.is_cuda for o in output_tuple) 3073 g.replay() 3074 out_first = output_tuple[0].clone() 3075 g.replay() 3076 out = output_tuple[0] 3077 if dropout_p == 0.0: 3078 self.assertEqual(out_first, out, atol=0, rtol=0) 3079 else: 3080 # replays produce different results 3081 self.assertNotEqual(out_first, out) 3082 3083 with sdpa_kernel(backends=[SDPBackend.MATH]): 3084 if dropout_p == 0.0: 3085 # High Precision Math Reference 3086 out_ref = F.scaled_dot_product_attention(query_ref, key_ref, value_ref, 3087 dropout_p=dropout_p, is_causal=is_causal, scale=scale) 3088 # Low Precision Math Reference 3089 out_lp_ref = F.scaled_dot_product_attention(query_ref_lp, key_ref_lp, value_ref_lp, 3090 dropout_p=dropout_p, is_causal=is_causal, scale=scale) 3091 else: 3092 # Create the dropout_mask 3093 dropout_mask = get_dropout_mask(output_tuple, fused_kernel, batch_size, 3094 n_heads, seq_len_q, seq_len_k, dropout_p, device) 3095 # High Precision Math Reference 3096 out_ref = torch.ops.aten._scaled_dot_product_attention_math( 3097 query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal, 3098 scale=scale, dropout_mask=dropout_mask)[0] 3099 # Low Precision Math Reference 3100 out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math( 3101 query_ref_lp, key_ref_lp, value_ref_lp, 
dropout_p=dropout_p, is_causal=is_causal, scale=scale, 3102 dropout_mask=dropout_mask)[0] 3103 3104 3105 g1 = torch.cuda.CUDAGraph() 3106 with torch.cuda.graph(g1): 3107 out.backward(upstream_grad) 3108 g1.replay() 3109 out_ref.backward(upstream_grad.to(out_ref.dtype)) 3110 out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype)) 3111 3112 # [Note] Fused Tolerances 3113 # Establish the numerical error between the "true" high precision math output 3114 # and the low precision math reference. We use this reference for the atol 3115 # And we use the default rtol for the low precision type. 3116 # We then provide a fudge factor for gradients respectively to account 3117 # for the use of the fused kernel rather than the eager implemntation. 3118 output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref) 3119 3120 # Fudge Factor when dropout is enabled 3121 dropout_fudge_factor = 1.0 if dropout_p == 0.0 else 1.5 3122 3123 query_fudge_factor = dropout_fudge_factor 3124 grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor) 3125 3126 # TODO: Investigate why grad_k needs larger tolerances 3127 key_fudge_factor = 8 * dropout_fudge_factor 3128 grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad, key_fudge_factor) 3129 3130 value_fudge_factor = 7 if not SM80OrLater and dtype == torch.float16 else 1.0 3131 grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad, value_fudge_factor) 3132 3133 self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol) 3134 self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype), 3135 atol=grad_q_ref_atol, rtol=grad_q_ref_rtol) 3136 self.assertEqual(key.grad, key_ref.grad.to(key.grad.dtype), 3137 atol=grad_k_ref_atol, rtol=grad_k_ref_rtol) 3138 self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype), 3139 atol=grad_v_ref_atol, rtol=grad_v_ref_rtol) 3140 3141 @skipIfRocm # Nested Tensor 3142 @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system") 3143 @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if 3144 PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION]) 3145 def test_fused_kernels_seq_len_1_inputs(self, device, fused_kernel): 3146 rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float16) 3147 batch, num_heads, head_dim = 32, 16, 64 3148 seq_lens = torch.randint(low=1, high=32, size=(batch,)) 3149 # make sure some seq_lens are 1 3150 num_ones = 10 3151 indices = torch.randint(low=0, high=batch, size=(num_ones,)) 3152 seq_lens.scatter_(0, indices, 1) 3153 3154 shape = SdpaShape(batch, num_heads, seq_lens.tolist(), head_dim) 3155 query = rand_nested_tensor(shape) 3156 key = rand_nested_tensor(shape) 3157 value = rand_nested_tensor(shape) 3158 3159 query = query.transpose(1, 2) 3160 key = key.transpose(1, 2) 3161 value = value.transpose(1, 2) 3162 3163 with sdpa_kernel(backends=[fused_kernel]): 3164 actual = torch.nn.functional.scaled_dot_product_attention( 3165 query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) 3166 with sdpa_kernel(backends=[SDPBackend.MATH]): 3167 math_ref = torch.nn.functional.scaled_dot_product_attention( 3168 query.contiguous().to(torch.float32), 3169 key.contiguous().to(torch.float32), 3170 value.contiguous().to(torch.float32), 3171 attn_mask=None, dropout_p=0.0, is_causal=False) 3172 3173 self.assertEqual(actual.contiguous(), 
math_ref.contiguous().to(torch.float16), atol=1e-3, rtol=1e-2) 3174 3175 3176 @skipIfRocm # Nested tensor 3177 @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system") 3178 @parametrize("kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if 3179 PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION]) 3180 @parametrize("expand_q_batch", [True, False]) 3181 @parametrize("expand_k_batch", [True, False]) 3182 @parametrize("expand_v_batch", [True, False]) 3183 @parametrize("expand_q_num_heads", [True, False]) 3184 @parametrize("expand_k_num_heads", [True, False]) 3185 @parametrize("expand_v_num_heads", [True, False]) 3186 def test_fused_kernels_nested_broadcasting( 3187 self, 3188 device, 3189 kernel, 3190 expand_q_batch, 3191 expand_k_batch, 3192 expand_v_batch, 3193 expand_q_num_heads, 3194 expand_k_num_heads, 3195 expand_v_num_heads, 3196 ): 3197 is_efficient = kernel == SDPBackend.EFFICIENT_ATTENTION 3198 dtype = torch.float32 if is_efficient else torch.float16 3199 rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=dtype) 3200 batch, num_heads, head_dim = 32, 8, 64 3201 head_dim_v = 32 if is_efficient else head_dim 3202 seq_lens_q = (torch.randint(low=1, high=5, size=(1,)).item() 3203 if expand_q_batch 3204 else torch.randint(low=1, high=32, size=(batch,)).tolist()) 3205 seq_lens_kv = (torch.randint(low=1, high=5, size=(1,)).item() 3206 if (expand_k_batch or expand_v_batch) 3207 else torch.randint(low=1, high=32, size=(batch,)).tolist()) 3208 3209 batch_q = 1 if expand_q_batch else batch 3210 batch_k = 1 if expand_k_batch else batch 3211 batch_v = 1 if expand_v_batch else batch 3212 3213 # handle case where all batch_sizes are 1 3214 batch = max(batch_q, batch_k, batch_v) 3215 3216 num_heads_q = 1 if expand_q_num_heads else num_heads 3217 num_heads_k = 1 if expand_k_num_heads else num_heads 3218 num_heads_v = 1 if expand_v_num_heads else num_heads 3219 3220 # handle case where all num_heads are 1 3221 num_heads = max(num_heads_q, num_heads_k, num_heads_v) 3222 3223 q_shape = SdpaShape(batch_q, num_heads_q, seq_lens_q, head_dim) 3224 k_shape = SdpaShape(batch_k, num_heads_k, seq_lens_kv, head_dim) 3225 v_shape = SdpaShape(batch_v, num_heads_v, seq_lens_kv, head_dim_v) 3226 3227 query = rand_nested_tensor(q_shape) 3228 key = rand_nested_tensor(k_shape) 3229 value = rand_nested_tensor(v_shape) 3230 3231 def _broadcast(t, batch_broadcasted, num_heads_broadcasted): 3232 if batch_broadcasted and num_heads_broadcasted: 3233 # (1, seq_len, 1, head_dim) -> (batch, seq_len, num_heads, head_dim) 3234 result = torch.nested.nested_tensor( 3235 [t[0].expand(-1, num_heads, t.size(-1)) for _ in range(batch)], dtype=torch.float32) 3236 elif batch_broadcasted: 3237 # (1, seq_len, num_heads, head_dim) -> (batch, seq_len, num_heads, head_dim) 3238 result = torch.nested.nested_tensor([t[0] for _ in range(batch)], dtype=torch.float32) 3239 elif num_heads_broadcasted: 3240 # (batch, seq_len, 1, head_dim) -> (batch, seq_len, num_heads, head_dim) 3241 result = torch.nested.nested_tensor([x.expand(-1, num_heads, t.size(-1)) 3242 for x in t.unbind()], dtype=torch.float32) 3243 else: 3244 result = t.to(torch.float32) 3245 return result 3246 3247 query_expanded = _broadcast(query, expand_q_batch, expand_q_num_heads).transpose(1, 2) 3248 key_expanded = _broadcast(key, expand_k_batch, expand_k_num_heads).transpose(1, 2) 3249 value_expanded = _broadcast(value, expand_v_batch, expand_v_num_heads).transpose(1, 2) 
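        # Hedged shape sketch (comment only; the concrete combination is just an
        # illustration): before the transpose below the nested tensors are laid out
        # as (batch, ragged seq_len, num_heads, head_dim). With, say,
        # expand_q_batch=True and expand_k_num_heads=True, the fused kernel is handed
        #     query: batch dim of 1      -> broadcast across the batch
        #     key:   num_heads dim of 1  -> broadcast across the heads
        # while the math reference below consumes query_expanded / key_expanded /
        # value_expanded, which _broadcast materialized to the full sizes above.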
3250 3251 query = query.transpose(1, 2) 3252 key = key.transpose(1, 2) 3253 value = value.transpose(1, 2) 3254 3255 with sdpa_kernel(backends=[kernel]): 3256 actual = torch.nn.functional.scaled_dot_product_attention( 3257 query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) 3258 with sdpa_kernel(backends=[SDPBackend.MATH]): 3259 math_ref = torch.nn.functional.scaled_dot_product_attention( 3260 query_expanded.contiguous(), key_expanded.contiguous(), value_expanded.contiguous(), 3261 attn_mask=None, dropout_p=0.0, is_causal=False) 3262 3263 self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2) 3264 3265 @skipIfRocm # Nested tensor 3266 @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") 3267 def test_fused_kernels_nested_broadcasting_query_dense(self, device): 3268 rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float32) 3269 batch, num_heads, head_dim, head_dim_v = 32, 16, 64, 96 3270 seq_lens = torch.randint(low=1, high=32, size=(batch,)).tolist() 3271 q_shape = (1, 1, num_heads, head_dim) 3272 k_shape = SdpaShape(batch, num_heads, seq_lens, head_dim) 3273 v_shape = SdpaShape(batch, 1, seq_lens, head_dim_v) 3274 3275 # create a dense query 3276 query = torch.randn(q_shape, device=device, dtype=torch.float32) 3277 key = rand_nested_tensor(k_shape) 3278 value = rand_nested_tensor(v_shape) 3279 3280 # (1, 1, num_heads, head_dim) -> (batch, 1, num_heads, head_dim) 3281 query_expanded = torch.nested.nested_tensor([query.squeeze(0) for _ in range(batch)]).transpose(1, 2) 3282 # (batch, seq_lens, 1, head_dim) -> (batch, seq_lens, num_heads, head_dim) 3283 value_expanded = torch.nested.nested_tensor( 3284 [t.expand(-1, num_heads, head_dim_v) for t in value.unbind()]).transpose(1, 2) 3285 3286 query = query.transpose(1, 2) 3287 key = key.transpose(1, 2) 3288 value = value.transpose(1, 2) 3289 3290 with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]): 3291 actual = torch.nn.functional.scaled_dot_product_attention( 3292 query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) 3293 with sdpa_kernel(backends=[SDPBackend.MATH]): 3294 math_ref = torch.nn.functional.scaled_dot_product_attention( 3295 query_expanded.contiguous(), key.contiguous(), value_expanded.contiguous(), 3296 attn_mask=None, dropout_p=0.0, is_causal=False) 3297 3298 self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=1e-3, rtol=1e-2) 3299 3300 @onlyCUDA 3301 @skipIfRocm # Nested tensor 3302 @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware") 3303 @parametrize("batch_size", [8, 32]) 3304 @parametrize("max_seq_len_q", [32, 256]) 3305 @parametrize("max_seq_len_kv", [32, 256]) 3306 @parametrize("head_dim", [8, 64]) 3307 @parametrize("dropout_p", [0.0, 0.1]) 3308 @parametrize("dtype", [torch.float16]) 3309 @parametrize("scale", [None, "l1"]) 3310 @parametrize("is_causal", [True, False]) 3311 def test_flash_attention_vs_math_ref_grads_nestedtensor(self, device, batch_size: int, max_seq_len_q: int, max_seq_len_kv: int, 3312 head_dim: int, dropout_p: float, dtype: torch.dtype, 3313 scale: str, is_causal: bool): 3314 if is_causal: 3315 # TODO we should support this 3316 self.assertRaisesRegex(RuntimeError, "Nested tensors for query / key are not supported when is_causal=True") 3317 return 3318 scale = scale if scale is None else (1 / head_dim) 3319 n_heads = 4 3320 seq_lens_q = torch.randint(low=1, high=max_seq_len_q, 
size=(batch_size,)) 3321 # Set one entry to max length 3322 seq_lens_q[torch.randint(0, batch_size, size=(1,))] = max_seq_len_q 3323 seq_lens_kv = torch.randint(low=1, high=max_seq_len_kv, size=(batch_size,)) 3324 seq_lens_kv[torch.randint(0, batch_size, size=(1,))] = max_seq_len_kv 3325 3326 def rand_nt(sequence_list, num_heads, head_dim): 3327 tensors = [torch.rand((num_heads, seq_len, head_dim)) for seq_len in sequence_list] 3328 return torch.nested.nested_tensor(tensors, requires_grad=True, device=device, dtype=dtype) 3329 3330 query = rand_nt(seq_lens_q, n_heads, head_dim) 3331 key = rand_nt(seq_lens_kv, n_heads, head_dim) 3332 value = rand_nt(seq_lens_kv, n_heads, head_dim) 3333 3334 # Run the math kernel on low precision references 3335 query_ref_lp = query.clone().detach().requires_grad_(True) 3336 key_ref_lp = key.clone().detach().requires_grad_(True) 3337 value_ref_lp = value.clone().detach().requires_grad_(True) 3338 3339 query_ref = query.clone().detach().to(torch.float32).requires_grad_(True) 3340 key_ref = key.clone().detach().to(torch.float32).requires_grad_(True) 3341 value_ref = value.clone().detach().to(torch.float32).requires_grad_(True) 3342 3343 is_dropout = dropout_p > 0.0 3344 3345 if not is_dropout: 3346 with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): 3347 out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale) 3348 with sdpa_kernel(backends=[SDPBackend.MATH]): 3349 # High Precision Math Reference 3350 out_ref = F.scaled_dot_product_attention( 3351 query_ref, key_ref, value_ref, is_causal=is_causal, scale=scale) 3352 # Low Precision Math Reference 3353 out_lp_ref = F.scaled_dot_product_attention( 3354 query_ref_lp, key_ref_lp, value_ref_lp, is_causal=is_causal, scale=scale) 3355 else: 3356 # Create real output 3357 output_tuple = torch.ops.aten._scaled_dot_product_flash_attention( 3358 query, key, value, dropout_p=dropout_p, is_causal=is_causal, 3359 scale=scale, return_debug_mask=is_dropout) 3360 out = output_tuple[0] 3361 dbug_mask = output_tuple[-1] 3362 3363 query_padding_mask = torch.arange(max_seq_len_q).unsqueeze(0).expand( 3364 batch_size, max_seq_len_q 3365 ) < seq_lens_q.unsqueeze(-1) 3366 query_padding_mask = query_padding_mask.to("cuda") 3367 3368 key_padding_mask = torch.arange(max_seq_len_kv).unsqueeze(0).expand( 3369 batch_size, max_seq_len_kv 3370 ) < seq_lens_kv.unsqueeze(-1) 3371 key_padding_mask = key_padding_mask.to("cuda") 3372 3373 softmax_mask = self.convert_flash_attn_S_to_softmax( 3374 dbug_mask, max_seq_len_q, max_seq_len_kv, query_padding_mask, key_padding_mask, causal=is_causal) 3375 dropout_mask = softmax_mask >= 0 3376 nt_stack = [] 3377 for tensor_component in range(batch_size): 3378 batch_stack = [] 3379 for head in range(n_heads): 3380 batch_stack.append(dropout_mask[tensor_component, head, 3381 0:seq_lens_q[tensor_component], 3382 0:seq_lens_kv[tensor_component]].unsqueeze(0)) 3383 nt_stack.append(torch.cat(batch_stack)) 3384 nested_dropout_mask = torch.nested.nested_tensor(nt_stack) 3385 # High Precision Math Reference 3386 out_ref = torch.ops.aten._scaled_dot_product_attention_math( 3387 query_ref, key_ref, value_ref, dropout_p=dropout_p, 3388 is_causal=is_causal, scale=scale, dropout_mask=nested_dropout_mask)[0] 3389 # Low Precision Math Reference 3390 out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math( 3391 query_ref_lp, key_ref_lp, value_ref_lp, dropout_p=dropout_p, is_causal=is_causal, scale=scale, 3392 dropout_mask=nested_dropout_mask)[0] 3393 3394 
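        # Hedged recap (comment only): in the dropout branch above, the dense debug
        # mask (roughly batch_size x n_heads x max_seq_len_q x max_seq_len_kv) is
        # sliced per batch element down to its true (seq_lens_q[i], seq_lens_kv[i])
        # extent and re-packed into nested_dropout_mask, so the math references see
        # exactly the dropout pattern the flash kernel actually applied.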
upstream_grad = out.detach().clone().contiguous() 3395 3396 out.backward(upstream_grad) 3397 out_ref.backward(upstream_grad.to(out_ref.dtype)) 3398 out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype)) 3399 3400 # See [Note] Fused Tolerances above 3401 output_ref_atol, output_ref_rtol = calculate_nt_tolerances(out_ref, out_lp_ref, out.dtype) 3402 grad_q_ref_atol, grad_q_ref_rtol = calculate_nt_tolerances(query_ref.grad, query_ref_lp.grad, 3403 query.grad.dtype, fudge_factor=4) 3404 grad_k_ref_atol, grad_k_ref_rtol = calculate_nt_tolerances(key_ref.grad, key_ref_lp.grad, key.grad.dtype) 3405 grad_v_ref_atol, grad_v_ref_rtol = calculate_nt_tolerances(value_ref.grad, value_ref_lp.grad, value.grad.dtype) 3406 3407 self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol) 3408 self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype), 3409 atol=grad_q_ref_atol, rtol=grad_q_ref_rtol) 3410 self.assertEqual(key.grad.contiguous(), key_ref.grad.contiguous().to(key.grad.dtype), 3411 atol=grad_k_ref_atol, rtol=grad_k_ref_rtol) 3412 self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype), 3413 atol=grad_v_ref_atol, rtol=grad_v_ref_rtol) 3414 3415class TestAttnBias(NNTestCase): 3416 3417 def run_test( 3418 self, 3419 device, 3420 make_q, 3421 make_kv, 3422 attn_bias=None, 3423 forw_tolerances: Optional[Tolerances] = None, 3424 grad_tolerances: Optional[Tolerances] = None, 3425 backend=None, 3426 ): 3427 if backend is not None: 3428 torch._dynamo.reset() 3429 3430 query, key, value = make_q(), make_kv(), make_kv() 3431 query_prototype, key_prototype, value_prototype = query_key_value_clones(query, key, value) 3432 3433 realized = attn_bias._materialize(device) if attn_bias is not None else None 3434 pytorch_output = scaled_dot_product_attention( 3435 query, key, value, attn_mask=realized, dropout_p=0.0, is_causal=False 3436 ) 3437 3438 sdpa_op = ( 3439 torch.compile(scaled_dot_product_attention, backend=backend) 3440 if backend is not None 3441 else scaled_dot_product_attention 3442 ) 3443 sdpa_output = sdpa_op( 3444 query_prototype, 3445 key_prototype, 3446 value_prototype, 3447 attn_mask=attn_bias, 3448 dropout_p=0.0, 3449 is_causal=False, 3450 scale=None, 3451 ) 3452 3453 dOut = torch.randn_like(pytorch_output) 3454 pytorch_output.backward(dOut) 3455 sdpa_output.backward(dOut) 3456 3457 # Use default assert_close tolerances for dtypes 3458 if forw_tolerances is None: 3459 forw_tolerances = Tolerances(atol=None, rtol=None) 3460 if grad_tolerances is None: 3461 grad_tolerances = Tolerances(atol=None, rtol=None) 3462 3463 torch.testing.assert_close(pytorch_output, sdpa_output, rtol=forw_tolerances.rtol, atol=forw_tolerances.atol) 3464 torch.testing.assert_close(query.grad, query_prototype.grad, rtol=grad_tolerances.rtol, atol=grad_tolerances.atol) 3465 torch.testing.assert_close(key.grad, key_prototype.grad, rtol=grad_tolerances.rtol, atol=grad_tolerances.atol) 3466 torch.testing.assert_close(value.grad, value_prototype.grad, rtol=grad_tolerances.rtol, atol=grad_tolerances.atol) 3467 3468 @skipIfRocm # No support for the second variant for now 3469 @parametrize("causal_variant", [CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT]) 3470 @parametrize( 3471 "shape", 3472 [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)], 3473 ) 3474 def test_causal_variants(self, device, causal_variant: CausalVariant, shape: List[Tuple[int]]): 3475 make_tensor = partial( 3476 torch.rand, device=device, dtype=torch.float16, 
requires_grad=True 3477 ) 3478 3479 bsz, num_heads, seq_len_q, seq_len_kv, head_dim = shape 3480 make_q_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_q, head_dim)) 3481 make_kv_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_kv, head_dim)) 3482 if causal_variant == CausalVariant.LOWER_RIGHT and seq_len_q > seq_len_kv: 3483 self.skipTest( 3484 "Lower right causal mask will produce NaNs in the output when seq_len_q > seq_len_kv!" 3485 ) 3486 3487 forw_tol = Tolerances(1e-3, 1e-3) 3488 grad_tol = Tolerances(5e-3, 5e-3) 3489 3490 if causal_variant == CausalVariant.UPPER_LEFT: 3491 attn_bias = causal_upper_left(seq_len_q, seq_len_kv) 3492 else: 3493 attn_bias = causal_lower_right(seq_len_q, seq_len_kv) 3494 3495 self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=None) 3496 3497 @skipIfRocm # CausalVariant 3498 @parametrize("causal_variant", [CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT]) 3499 @parametrize( 3500 "shape", 3501 [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)], 3502 ) 3503 @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on windows") 3504 @skipIfTorchDynamo("This function already calls torch.compile.") 3505 def test_causal_variants_compile(self, device, causal_variant: CausalVariant, shape: List[Tuple[int]]): 3506 cnts = CompileCounterWithBackend("aot_eager") 3507 make_tensor = partial( 3508 torch.rand, device=device, dtype=torch.float16, requires_grad=True 3509 ) 3510 3511 bsz, num_heads, seq_len_q, seq_len_kv, head_dim = shape 3512 make_q_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_q, head_dim)) 3513 make_kv_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_kv, head_dim)) 3514 if causal_variant == CausalVariant.LOWER_RIGHT and seq_len_q > seq_len_kv: 3515 self.skipTest( 3516 "Lower right causal mask will produce NaNs in the output when seq_len_q > seq_len_kv!" 
3517 ) 3518 forw_tol = Tolerances(1e-3, 1e-3) 3519 grad_tol = Tolerances(5e-3, 5e-3) 3520 3521 if causal_variant == CausalVariant.UPPER_LEFT: 3522 attn_bias = causal_upper_left(seq_len_q, seq_len_kv) 3523 else: 3524 attn_bias = causal_lower_right(seq_len_q, seq_len_kv) 3525 3526 self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=cnts) 3527 self.assertEqual(cnts.frame_count, 1, "Compiled graph should have 1 frame!") 3528 3529 @parametrize("shape", [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)]) 3530 def test_is_causal_equals_upper_left(self, device, shape: List[Tuple[int]]): 3531 make_tensor = partial( 3532 torch.rand, device=device, dtype=torch.float16, requires_grad=True 3533 ) 3534 3535 bsz, num_heads, seq_len_q, seq_len_kv, head_dim = shape 3536 make_q_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_q, head_dim)) 3537 make_kv_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_kv, head_dim)) 3538 3539 forw_tol = Tolerances(1e-3, 1e-3) 3540 grad_tol = Tolerances(5e-3, 5e-3) 3541 3542 query = make_q_tensor() 3543 key = make_kv_tensor() 3544 value = make_kv_tensor() 3545 attn_bias = causal_upper_left(seq_len_q, seq_len_kv) 3546 3547 out_attn_bias = scaled_dot_product_attention(query, key, value, attn_mask=attn_bias, dropout_p=0.0) 3548 out_is_causal = scaled_dot_product_attention(query, key, value, is_causal=True, dropout_p=0.0) 3549 torch.testing.assert_close(out_attn_bias, out_is_causal, rtol=forw_tol.rtol, atol=forw_tol.atol) 3550 3551 def test_is_causal_and_mask_fails(self, device): 3552 make_tensor = partial( 3553 torch.rand, device=device, dtype=torch.float16, requires_grad=True 3554 ) 3555 make_q_tensor = partial(make_tensor, SdpaShape(16, 16, 128, 16)) 3556 make_kv_tensor = partial(make_tensor, SdpaShape(16, 16, 128, 16)) 3557 3558 query = make_q_tensor() 3559 key = make_kv_tensor() 3560 value = make_kv_tensor() 3561 attn_bias = causal_upper_left(128, 128) 3562 3563 with self.assertRaisesRegex(ValueError, "CausalBias should not be used with causal=True"): 3564 scaled_dot_product_attention(query, key, value, attn_mask=attn_bias, is_causal=True, dropout_p=0.0) 3565 3566if NOTEST_CPU: 3567 device_types = ("cuda", ) 3568else: 3569 device_types = ("cpu", "cuda") 3570 3571instantiate_device_type_tests(TestTransformers, globals(), only_for=device_types) 3572instantiate_device_type_tests(TestSDPAFailureModes, globals(), only_for=device_types) 3573instantiate_device_type_tests(TestSDPA, globals(), only_for=device_types) 3574instantiate_device_type_tests(TestSDPACudaOnly, globals(), only_for=("cuda")) 3575instantiate_device_type_tests(TestAttnBias, globals(), only_for=device_types) 3576 3577if __name__ == '__main__': 3578 run_tests() 3579
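

# Hedged usage sketch (illustrative only; not invoked by any test above). It shows
# the sdpa_kernel / SDPBackend pattern these tests exercise: restrict kernel
# selection to a set of fused backends for the duration of the context. The helper
# name and concrete shapes are assumptions made for this example, and it requires a
# CUDA device with at least one fused SDPA backend available.
def _example_fused_sdpa(device="cuda", dtype=torch.float16):
    # (batch, num_heads, seq_len, head_dim) inputs, as used throughout this file
    q = torch.randn(2, 4, 128, 64, device=device, dtype=dtype)
    k = torch.randn(2, 4, 128, 64, device=device, dtype=dtype)
    v = torch.randn(2, 4, 128, 64, device=device, dtype=dtype)
    with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
        # Only the listed fused kernels may be selected inside this context
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)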