1# Owner(s): ["module: nn"]
2
3import contextlib
4from functools import partial
5from collections import namedtuple
6import sys
7import torch
8import torch.nn as nn
9import torch.nn.functional as F
10from torch.nn.functional import scaled_dot_product_attention
11from torch.nn.attention import sdpa_kernel, SDPBackend
12from torch.nn.attention.bias import CausalVariant, causal_lower_right, causal_upper_left
13from torch.nn.parameter import Parameter
14import unittest
15from unittest.mock import patch, MagicMock, ANY
16import math
17import torch.optim as optim
18from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCUDA, onlyCPU
19from typing import List, Tuple, Optional
20from torch.testing._internal.common_nn import NNTestCase
21from torch.testing._internal.common_utils import (
22    TEST_WITH_ROCM,
23    skipIfRocm,
24    skipIfTorchDynamo,
25    TEST_FAIRSEQ,
26    run_tests,
27    parametrize,
28    freeze_rng_state,
29    TEST_WITH_CROSSREF,
30    slowTest,
31    set_default_dtype,
32    gradcheck,
33    make_tensor,
34    NOTEST_CPU,
35    IS_WINDOWS,
36    TEST_WITH_TORCHDYNAMO,
37)
38from torch._dynamo.testing import CompileCounterWithBackend
39
40
41from torch.testing._internal.common_methods_invocations import wrapper_set_seed
42from torch.testing._internal.common_cuda import (
43    IS_JETSON, SM80OrLater, PLATFORM_SUPPORTS_FLASH_ATTENTION,
44    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
45    PLATFORM_SUPPORTS_FUSED_ATTENTION,
46    PLATFORM_SUPPORTS_CUDNN_ATTENTION
47)
48
49if TEST_FAIRSEQ:
50    import fairseq.models.transformer as fairseq_transformer
51
SdpaShape = namedtuple('SdpaShape', ['batch', 'num_heads', 'seq_len', 'head_dim'])
53Tolerances = namedtuple('Tolerances', ['atol', 'rtol'])
54
55@contextlib.contextmanager
56def use_deterministic_algorithims(mode: bool, warn_only: bool):
57    r"""
58    This context manager can be used to temporarily enable or disable deterministic algorithms.
59    Upon exiting the context manager, the previous state of the flag will be restored.
60    """
61    previous_mode: bool = torch.are_deterministic_algorithms_enabled()
62    previous_warn_only: bool = torch.is_deterministic_algorithms_warn_only_enabled()
63    try:
64        torch.use_deterministic_algorithms(mode, warn_only=warn_only)
65        yield {}
66    finally:
67        torch.use_deterministic_algorithms(previous_mode, warn_only=previous_warn_only)
68
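# Illustrative usage sketch (not exercised by the tests in this file): run an op with
# deterministic algorithms temporarily enabled; the previous global setting is restored
# on exit. The shapes and the warn_only choice here are arbitrary.
def _example_use_deterministic_algorithims() -> torch.Tensor:
    q = k = v = torch.randn(2, 4, 8, 16)
    with use_deterministic_algorithims(True, warn_only=True):
        return scaled_dot_product_attention(q, k, v)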
69
70# Found in torch/testing/_comparison.py
71default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float32: 1e-5}
72default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float32: 1.3e-6}
73
74isSM8XDevice = torch.cuda.is_available() and torch.cuda.get_device_capability() in [(8, 6), (8, 7), (8, 9)]
75isSM90Device = torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0)
76isSM5xDevice = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 5
77isLessThanSM80Device = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8
78
79def get_rtol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
80    deviation = true_value - computed_value
81    deviation = torch.abs(deviation / true_value)
82    # Fill in the nans with the default rtol
83    torch.nan_to_num_(deviation, nan=default_rtol[computed_value.dtype])
84    return deviation.max().item()
85
86
87def get_atol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
88    deviation = true_value - computed_value
89    atol = torch.abs(deviation).max().item()
90    return atol
91
92
93def get_tolerances(
94    true_value: torch.Tensor,
95    computed_value: torch.Tensor,
96    fudge_factor: Optional[float] = None,
97) -> Tuple[float, float]:
98    """Returns the absolute and relative tolerances for comparing two tensors."""
99    fudge_factor = fudge_factor if fudge_factor is not None else 1.0
100    atol = get_atol(true_value, computed_value)
101    rtol = get_rtol(true_value, computed_value)
102
103    atol = fudge_factor * max(atol, default_atol[computed_value.dtype])
104    rtol = fudge_factor * max(rtol, default_rtol[computed_value.dtype])
    # torch.isclose() has weird behavior when rtol becomes very large; see:
106    # https://github.com/pytorch/pytorch/issues/102400
107    if rtol > 1e30:
108        rtol = default_rtol[computed_value.dtype]
109    return atol, rtol
110
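# Minimal sketch (illustrative only) of how the tolerance helpers above are combined:
# derive atol/rtol from a trusted reference and a computed result, then compare with
# torch.testing.assert_close. The tensors here are made up.
def _example_get_tolerances() -> None:
    reference = torch.randn(8, 8, dtype=torch.float32)
    computed = reference + 1e-6 * torch.randn_like(reference)
    atol, rtol = get_tolerances(reference, computed, fudge_factor=2.0)
    torch.testing.assert_close(computed, reference, atol=atol, rtol=rtol)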
111
def query_key_value_clones(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, dtype: Optional[torch.dtype] = None):
    """ Clones the query, key, and value tensors and casts them to the specified dtype. """
114    if dtype is None:
115        dtype = query.dtype
116    query_ref = query.clone().detach().to(dtype).requires_grad_(query.requires_grad)
117    key_ref = key.clone().detach().to(dtype).requires_grad_(key.requires_grad)
118    value_ref = value.clone().detach().to(dtype).requires_grad_(value.requires_grad)
119    return query_ref, key_ref, value_ref
120
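# Hedged example of the intended query_key_value_clones workflow (not used by the tests
# below): keep float64 reference copies of q/k/v so a lower-precision SDPA result can be
# checked against a high-precision ground truth. Shapes and tolerances are arbitrary.
def _example_query_key_value_clones() -> None:
    q = torch.randn(2, 4, 8, 16, requires_grad=True)
    k = torch.randn(2, 4, 8, 16, requires_grad=True)
    v = torch.randn(2, 4, 8, 16, requires_grad=True)
    q_ref, k_ref, v_ref = query_key_value_clones(q, k, v, dtype=torch.float64)
    out = scaled_dot_product_attention(q, k, v)
    out_ref = scaled_dot_product_attention(q_ref, k_ref, v_ref)
    torch.testing.assert_close(out.to(torch.float64), out_ref, atol=1e-5, rtol=1e-5)
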
121def get_platform_specific_sdpa():
122    ret = []
123    if PLATFORM_SUPPORTS_FLASH_ATTENTION:
124        ret.append(SDPBackend.FLASH_ATTENTION)
125    if PLATFORM_SUPPORTS_MEM_EFF_ATTENTION:
126        ret.append(SDPBackend.EFFICIENT_ATTENTION)
127    if PLATFORM_SUPPORTS_CUDNN_ATTENTION:
128        ret.append(SDPBackend.CUDNN_ATTENTION)
129    if not ret:
130        # Add a placeholder, an empty list causes "An empty arg_values was passed to @parametrize"
131        ret.append(SDPBackend.EFFICIENT_ATTENTION)
132    return ret
133
134PLATFORM_SPECIFIC_SDPA = get_platform_specific_sdpa()
# Indicates whether the efficient attention backend can support:
# 1. sequences longer than 512
# 2. head dimension larger than 64
138MEM_EFF_CAPABILITY_MATCHES_SM80 = SM80OrLater or TEST_WITH_ROCM
139
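# Hedged sketch (not exercised here) of how a specific SDPA backend is pinned for a
# single call via sdpa_kernel; the tests below do the same with entries taken from
# PLATFORM_SPECIFIC_SDPA. MATH is used here because it is always available.
def _example_sdpa_kernel_selection() -> torch.Tensor:
    q = k = v = torch.randn(2, 4, 8, 16)
    with sdpa_kernel(backends=[SDPBackend.MATH]):
        return scaled_dot_product_attention(q, k, v)
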
140def rand_sdpa_tensor(shape: SdpaShape, device: str, dtype: torch.dtype, type: str,
141                     requires_grad: bool = False, packed: bool = False) -> torch.Tensor:
    """Creates a random dense or nested tensor with the given shape and type.
143
144    Args:
145        shape (Tuple[int]): Shape of Tensor to construct
146        device (str): which device to create tensor on
147        dtype (torch.dtype): Tensors' dtype
148        type (str): Nested or Dense
        requires_grad (bool, optional): Tensor's grad status. Defaults to False.
        packed (bool, optional): Whether to create a single packed QKV tensor. Defaults to False.
151
152    Returns:
153        torch.Tensor: A new tensor
154    """
155    batch, num_heads, seq_len, head_dim = shape.batch, shape.num_heads, shape.seq_len, shape.head_dim
156    if type == "nested":
157        if isinstance(seq_len, list):
158            def _size(i):
159                return (seq_len[i], num_heads, head_dim) if not packed else (seq_len[i], 3 * num_heads * head_dim)
160
161            return torch.nested.nested_tensor([
162                torch.randn(_size(i), device=device, dtype=dtype, requires_grad=requires_grad)
163                for i in range(batch)])
164        else:
165            size = (seq_len, num_heads, head_dim) if not packed else (seq_len, 3 * num_heads * head_dim)
166            return torch.nested.nested_tensor([
167                torch.randn(size, device=device, dtype=dtype, requires_grad=requires_grad)
168                for _ in range(batch)])
169    else:
        assert isinstance(seq_len, int)
171        size = (batch, seq_len, num_heads, head_dim) if not packed else (batch, seq_len, 3 * num_heads * head_dim)
172        return torch.randn(size, device=device, dtype=dtype, requires_grad=requires_grad)
173
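# Brief illustration of rand_sdpa_tensor (the shapes are arbitrary): a dense tensor uses a
# single int seq_len, while passing a list of per-batch lengths with type="nested" yields a
# nested tensor with ragged sequence lengths.
def _example_rand_sdpa_tensor() -> None:
    dense_shape = SdpaShape(batch=2, num_heads=4, seq_len=8, head_dim=16)
    dense = rand_sdpa_tensor(dense_shape, "cpu", torch.float32, type="dense")
    assert dense.shape == (2, 8, 4, 16)

    nested_shape = SdpaShape(batch=2, num_heads=4, seq_len=[6, 8], head_dim=16)
    nested = rand_sdpa_tensor(nested_shape, "cpu", torch.float32, type="nested")
    assert nested.is_nested
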
174def calculate_nt_tolerances(nt_ref_hp, nt_ref_lp, default_dtype, fudge_factor=1):
175    # TODO use NT ops when we have implemented Max for NestedTensor instead of unrolling
176    ref_atol = default_atol[default_dtype]
177    ref_rtol = default_rtol[default_dtype]
178    for tensor_component_ref, tensor_component_ref_lp in zip(nt_ref_hp.unbind(), nt_ref_lp.unbind()):
179        ref_atol = max((fudge_factor * torch.abs(tensor_component_ref - tensor_component_ref_lp)).max().item(), ref_atol)
180        ref_rtol = max(get_rtol(tensor_component_ref, tensor_component_ref_lp), ref_rtol)
181    return ref_atol, ref_rtol
182
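# Sketch of how calculate_nt_tolerances is meant to be used (illustrative values only):
# derive tolerances from a high-precision nested reference and its low-precision
# counterpart, then compare component-wise with those tolerances.
def _example_calculate_nt_tolerances() -> None:
    components = [torch.randn(6, 4, dtype=torch.float64), torch.randn(8, 4, dtype=torch.float64)]
    nt_ref_hp = torch.nested.nested_tensor(components)
    nt_ref_lp = torch.nested.nested_tensor([c.to(torch.float32) for c in components])
    atol, rtol = calculate_nt_tolerances(nt_ref_hp, nt_ref_lp, default_dtype=torch.float32)
    for ref, lp in zip(nt_ref_hp.unbind(), nt_ref_lp.unbind()):
        torch.testing.assert_close(lp.to(torch.float64), ref, atol=atol, rtol=rtol)
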
183class TestTransformers(NNTestCase):
184    _do_cuda_memory_leak_check = True
185    _do_cuda_non_default_stream = True
186
187    @onlyCUDA
188    @unittest.skip("4D mask not supported yet - activate when 4D mask supported")
189    def test_self_attn_TxT_attn_mask(self, device):
190        embed_dim = 16
191        num_heads = 4
192        batch_size = 10
193        tgt_len = 16
194
195        query = torch.rand(batch_size, tgt_len, embed_dim, device=device)  # [N, T, D]
196        attn_mask = torch.randint(0, 2, (tgt_len, tgt_len)).cuda().float()  # [T, T]
197        attn_mask = attn_mask.masked_fill(attn_mask == 0, float('-inf')).masked_fill(attn_mask == 1, 0.0)
198
199        attn_mask_4d = attn_mask.expand(batch_size, num_heads, tgt_len, tgt_len)
200
201        mta_model = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).cuda()
202        mta_model.eval()
203
204        # Generate 3D results
205        with torch.inference_mode():
206            output_mask_4d = mta_model(query, query, query, attn_mask=attn_mask_4d)[0]
207            output_mask_4d = output_mask_4d.transpose(0, 1)  # [N, T, D]
208
209            output_mask_TxT = mta_model(query, query, query, attn_mask=attn_mask)[0]
210            output_mask_TxT = output_mask_TxT.transpose(0, 1)  # [N, T, D]
211
212            self.assertEqual(output_mask_4d, output_mask_TxT)
213
214    @slowTest
215    def test_train_with_pad_and_catch_error(self, device):
216        iters = 100
217        pad_mask = torch.tensor([[1, 1, 0, 0]], dtype=torch.bool).to(device)
218        layer = nn.TransformerEncoderLayer(
219            d_model=2,
220            dim_feedforward=4,
221            nhead=2,
222            batch_first=True,
223            activation="gelu",
224            dropout=0,
225        )
226        criterion = nn.MSELoss()
227        encoder = nn.TransformerEncoder(layer, 2).to(device)
228        optimizer = optim.SGD(encoder.parameters(), lr=0.1, momentum=0.9)
229        encoder.train()
230        for i in range(iters):
231            encoder.train()
232            optimizer.zero_grad()
233            inputs = torch.cat([torch.randn(1, 2, 2), torch.zeros(1, 2, 2)], dim=1).to(device)
234
235            outputs = encoder(inputs, src_key_padding_mask=pad_mask)
236
237            loss = criterion(outputs[:, 0:2, :], inputs[:, 0:2, :])
238            loss.backward()
239            optimizer.step()
240
241            with torch.no_grad():
242                test = torch.cat([torch.randn(1, 2, 2), torch.zeros(1, 2, 2)], dim=1).to(device)
243
                # Expect the uint8 mask dtype to be rejected
                try:
                    encoder(test, src_key_padding_mask=pad_mask.to(torch.uint8))
                except AssertionError:
                    continue
                self.fail("Failed to catch unsupported uint8 type exception")

                test_train_bool = encoder(test, src_key_padding_mask=pad_mask)
                encoder.eval()

                # Expect the int64 (long) mask dtype to be rejected
                try:
                    encoder(test, src_key_padding_mask=pad_mask.to(torch.int64))
                except AssertionError:
                    continue
                self.fail("Failed to catch unsupported Long type exception")
262
263                test_eval_bool = encoder(test, src_key_padding_mask=pad_mask)
264                l1_bool = nn.L1Loss()(test_train_bool[:, 0:2, :], test_eval_bool[:, 0:2, :]).item()
265                self.assertTrue(l1_bool < 1e-4, "Eval/Train difference in pad_mask BOOL")
266
267    @parametrize("attn_mask_dim", [2, 3, None])
268    @parametrize("key_padding_mask_dim", [2, None])
269    @parametrize("mask_dtype", [torch.bool, torch.float32])
270    def test_multiheadattention_fastpath_attn_mask(self, device, attn_mask_dim, key_padding_mask_dim, mask_dtype):
271        with torch.no_grad():
272            B = 2
273            L = 4
274            D = 8
275            H = 4
276
277            if attn_mask_dim == 2:
278                attn_mask = make_tensor((L, L), dtype=mask_dtype, device=device)
279            elif attn_mask_dim == 3:
280                attn_mask = make_tensor((B * H, L, L), dtype=mask_dtype, device=device)
281            elif attn_mask_dim is None:
282                attn_mask = None
283
284            if key_padding_mask_dim == 2:
285                key_padding_mask = make_tensor((B, L), dtype=mask_dtype, device=device)
286            elif key_padding_mask_dim is None:
287                key_padding_mask = None
288
289            mha = nn.MultiheadAttention(D, H, batch_first=True, device=device)
290            X = torch.randn(B, L, D, device=device)
291
292            mha.train()  # disable fast path
293            out, _ = mha(X, X, X, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)
294            mha.eval()  # enable fast path
295            out_fp, _ = mha(X, X, X, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)
296            self.assertEqual(out, out_fp)
297
298    @parametrize("nhead", [1, 4, 8])
299    def test_transformerencoderlayer_src_mask(self, device, nhead):
300        batch_size = 2
301        seqlen = 4
302        d_model = 8
303        dim_feedforward = 32
304
305        model = torch.nn.TransformerEncoderLayer(
306            d_model=d_model,
307            nhead=nhead,
308            dim_feedforward=dim_feedforward,
309            batch_first=True).to(device)
310        src = torch.rand(batch_size, seqlen, d_model).to(device)  # bs, seqlen, d_model
311        src_mask = torch.zeros(seqlen, seqlen).to(torch.bool).to(device)
312
313        model(src, src_mask=src_mask)
314        model.eval()
315        with torch.no_grad():
316            model(src, src_mask=src_mask)
317
318    @parametrize("use_torchscript", [False])
319    @parametrize("enable_nested_tensor", [True, False])
320    @parametrize("use_autocast", [True, False])
321    @parametrize("d_model", [12, 256])
322    def test_transformerencoder_fastpath(self, device, use_torchscript, enable_nested_tensor, use_autocast, d_model):
323        """
324        Test TransformerEncoder fastpath output matches slowpath output
325        """
326        torch.manual_seed(1234)
327        nhead = 4
328        dim_feedforward = d_model
329        batch_first = True
330
331        model = torch.nn.TransformerEncoder(
332            torch.nn.TransformerEncoderLayer(
333                d_model=d_model,
334                nhead=nhead,
335                dim_feedforward=dim_feedforward,
336                batch_first=batch_first),
337            num_layers=2,
338            enable_nested_tensor=enable_nested_tensor
339        ).to(device).eval()
340
341        if use_torchscript:
342            model = torch.jit.script(model)
343
344        # each input is (input, mask)
345        input_mask_pairs = [
346            (
347                torch.rand(3, 2, d_model),
348                [
349                    [0, 1],
350                    [0, 1],
351                    [1, 1]
352                ]
353            ),
354            (
355                torch.rand(2, 100, d_model),
356                [
357                    [0] * 98 + [1] * 2,
358                    [0] * 90 + [1] * 10
359                ]
360            ),
361            # softmax.cu switches from fast->slowpath at masked seqlen 1024. test 1024.
362            (
363                torch.rand(2, 1024, d_model),
364                [
365                    [0] * 1020 + [1] * 4,
366                    [0] * 1024,
367                ]
368            ),
369            (
370                torch.rand(1, 1026, d_model),
371                [[0] * 1024 + [1] * 2]
372            ),
373            # softmax.cu switches from fast->slowpath at masked seqlen 1024. test range of masks above 1024.
374            (
375                torch.rand(4, 1040, d_model),
376                [
377                    [0] * 1024 + [1] * 16,
378                    [0] * 1025 + [1] * 15,
379                    [0] * 1031 + [1] * 9,
380                    [0] * 1040,
381                ]
382            )
383        ]
384        input_mask_pairs = [
385            (
386                torch.tensor(pair[0], device=device, dtype=torch.get_default_dtype()),  # float input
387                torch.tensor(pair[1], device=device, dtype=torch.bool)  # bool mask
388            ) for pair in input_mask_pairs
389        ]
390
391        maybe_autocast = torch.autocast("cuda", dtype=torch.float16) if use_autocast else contextlib.nullcontext()
392        with maybe_autocast:
393            for input, src_key_padding_mask in input_mask_pairs:
394                with torch.no_grad():
395                    fastpath_output = model(input, src_key_padding_mask=src_key_padding_mask)
396                slowpath_output = model(input, src_key_padding_mask=src_key_padding_mask)  # reference
                # Make sure fastpath_output has the same shape as slowpath_output and the mask.
                # When enable_nested_tensor=True, fastpath_output may be smaller than the input tensor.
                # E.g. if the input has bs=1, seqlen=6, and we mask out 2 tokens, fastpath_output will have bs=1, seqlen=4.
                # Expand back to the old size to match.
401                bs, true_seqlen, embed_dim = fastpath_output.shape
402                expanded_seqlen = src_key_padding_mask.shape[1]
403                fastpath_output_expanded = torch.zeros(bs, expanded_seqlen, embed_dim, device=device)
404                fastpath_output_expanded[:, :true_seqlen, :] = fastpath_output
                # No guarantees on outputs corresponding to masked tokens, so they may vary between slow/fast paths; set them all to 0.
406                fastpath_output_expanded = fastpath_output_expanded.masked_fill(src_key_padding_mask.unsqueeze(-1), 0)
407                slowpath_output = slowpath_output.masked_fill(src_key_padding_mask.unsqueeze(-1), 0)
408                torch.testing.assert_close(fastpath_output_expanded, slowpath_output, rtol=1e-7, atol=1e-5)
409
410    @parametrize("with_no_grad", [True, False])
411    @parametrize("training", [True, False])
412    @parametrize("enable_nested_tensor", [False])
413    def test_transformerencoder_square_input(self, with_no_grad, training, enable_nested_tensor, device):
414        """
415        Test for edge cases when input of shape (batch size, sequence length, embedding dimension) has
416        batch size == sequence length
417        """
418        model = torch.nn.TransformerEncoder(
419            torch.nn.TransformerEncoderLayer(d_model=4, nhead=2, dim_feedforward=16, dropout=0.0, batch_first=True),
420            num_layers=2,
421            enable_nested_tensor=enable_nested_tensor
422        ).to(device)
423
424        with torch.no_grad():
425            # set constant weights of the model
426            for idx, p in enumerate(model.parameters()):
427                x = p.data
428                sz = x.view(-1).size(0)
429                shape = x.shape
430                x = torch.cos(torch.arange(0, sz).float().view(shape))
431                p.data.copy_(x)
432
433        if training:
434            model = model.train()
435        else:
436            model = model.eval()
437        x = torch.arange(0, 16).reshape(2, 2, 4).to(torch.get_default_dtype()).to(device)
438        src_mask = torch.Tensor([[0, 1], [0, 0]]).to(torch.bool).to(device)
439
440        if with_no_grad:
441            cm = torch.no_grad()
442        else:
443            cm = contextlib.nullcontext()
444        with cm:
445            result = model(x, mask=src_mask)
446
447        ref_output = torch.Tensor([[[2.420306205749512, 0.017629241570830, -0.607857942581177, -0.085519507527351],
448                                    [2.420306205749512, 0.017629241570830, -0.607857942581177, -0.085519507527351]],
449                                   [[2.419836044311523, 0.017548924311996, -0.608187675476074, -0.085347734391689],
450                                    [2.419836044311523, 0.017548924311996, -0.608187675476074, -0.085347734391689]]]
451                                  ).to(device)
452        self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
453        torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
454
455    @parametrize("batch_first", [True, False])
456    @parametrize("training", [True, False])
457    @parametrize("enable_nested_tensor", [True, False])
458    def test_transformerencoder(self, batch_first, training, enable_nested_tensor, device):
459        def get_a_test_layer(activation, batch_first=False):
460            d_model = 4
461            nhead = 2
462            dim_feedforward = 16
463            dropout = 0.0
464
465            layer = nn.TransformerEncoderLayer(
466                d_model,
467                nhead,
468                dim_feedforward=dim_feedforward,
469                dropout=dropout,
470                activation=activation,
471                batch_first=batch_first,
472            ).to(device)
473
474            with torch.no_grad():
475                # set constant weights of the model
476                for idx, p in enumerate(layer.parameters()):
477                    x = p.data
478                    sz = x.view(-1).size(0)
479                    shape = x.shape
480                    x = torch.cos(torch.arange(0, sz).float().view(shape))
481                    p.data.copy_(x)
482
483            return layer
484
485        # this is a deterministic test for TransformerEncoder
486        activation = F.relu
487
488        def _test(batch_first, training, enable_nested_tensor):
489            def perm_fn(x):
490                return x.transpose(1, 0) if batch_first else x
491
492            encoder_layer = get_a_test_layer(activation=activation,
493                                             batch_first=batch_first)
494
495            model = nn.TransformerEncoder(
496                encoder_layer, 1, enable_nested_tensor=enable_nested_tensor
497            ).to(device)
498
499            if not training:
500                model = model.eval()
501
502            # deterministic input
503            encoder_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891],
504                                                   [0.5387, 0.1655, 0.3565, 0.0471]],
505                                                  [[0.8335, 0.2799, 0.5031, 0.2947],
506                                                   [0.1402, 0.0318, 0.7636, 0.1346]],
507                                                  [[0.6333, 0.9344, 0.1376, 0.9938],
508                                                   [0.8924, 0.2872, 0.6692, 0.2944]],
509                                                  [[0.9897, 0.6915, 0.3154, 0.1733],
510                                                   [0.8645, 0.3513, 0.3064, 0.0767]],
511                                                  [[0.8117, 0.2366, 0.4838, 0.7881],
512                                                   [0.3718, 0.4945, 0.9511, 0.0864]]]
513                                                 )).to(device)
514            result = model(encoder_input)
515            ref_output = perm_fn(torch.tensor([[[2.428589, 0.020835, -0.602055, -0.085249],
516                                                [2.427987, 0.021213, -0.602496, -0.084103]],
517                                               [[2.424689, 0.019155, -0.604793, -0.085672],
518                                                [2.413863, 0.022211, -0.612486, -0.072490]],
519                                               [[2.433774, 0.021598, -0.598343, -0.087548],
520                                                [2.425104, 0.019748, -0.604515, -0.084839]],
521                                               [[2.436185, 0.022682, -0.596625, -0.087261],
522                                                [2.433556, 0.021891, -0.598509, -0.086832]],
523                                               [[2.416246, 0.017512, -0.610712, -0.082961],
524                                                [2.422901, 0.024187, -0.606178, -0.074929]]]
525                                              )).to(device)
526            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
527            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
528
529            # all 0 src_mask
530            src_mask = torch.zeros([5, 5]).to(device) == 1
531            result = model(encoder_input, mask=src_mask)
532            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
533            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
534
535            # all 0
536            mask = torch.zeros([2, 5]).to(device) == 1
537            result = model(encoder_input, src_key_padding_mask=mask)
538            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
539            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
540
541            mask[0, 1] = 1
542            mask[1, 3] = 1
543            mask[1, 4] = 1
544            result = model(encoder_input, src_key_padding_mask=mask)
545            ref_output = perm_fn(torch.tensor([[[2.429026, 0.020793, -0.601741, -0.085642],
546                                                [2.428811, 0.021445, -0.601912, -0.084252]],
547                                               [[2.425009, 0.019155, -0.604566, -0.085899],
548                                                [2.415408, 0.02249, -0.611415, -0.073]],
549                                               [[2.434199, 0.021682, -0.598039, -0.087699],
550                                                [2.42598, 0.019941, -0.603896, -0.085091]],
551                                               [[2.436457, 0.022736, -0.59643, -0.08736],
552                                                [2.434021, 0.022093, -0.598179, -0.08679]],
553                                               [[2.416531, 0.017498, -0.610513, -0.083181],
554                                                [2.4242, 0.024653, -0.605266, -0.074959]]]
555                                              )).to(device)
556            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
557            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
558
559            # test case 2, multiple layers no norm
560            model = nn.TransformerEncoder(encoder_layer, 2, enable_nested_tensor=enable_nested_tensor).to(device)
561            if not training:
562                model = model.eval()
563            result = model(encoder_input, src_key_padding_mask=mask)
564            ref_output = perm_fn(torch.tensor([[[2.419051, 0.017446, -0.608738, -0.085003],
565                                                [2.419102, 0.017452, -0.608703, -0.085026]],
566                                               [[2.419043, 0.017445, -0.608744, -0.084999],
567                                                [2.419052, 0.017446, -0.608738, -0.085004]],
568                                               [[2.419067, 0.017448, -0.608727, -0.085010],
569                                                [2.419098, 0.017452, -0.608706, -0.085024]],
570                                               [[2.419072, 0.017449, -0.608724, -0.085012],
571                                                [2.419119, 0.017455, -0.608691, -0.085034]],
572                                               [[2.419019, 0.017442, -0.608761, -0.084989],
573                                                [2.419075, 0.017449, -0.608722, -0.085014]]]
574                                              )).to(device)
575            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
576            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
577
578            model = nn.TransformerEncoder(encoder_layer, 6, enable_nested_tensor=enable_nested_tensor).to(device)
579            if not training:
580                model = model.eval()
581            result = model(encoder_input, src_key_padding_mask=mask)
582            ref_output = perm_fn(torch.tensor([[[2.419101, 0.017453, -0.608703, -0.085025],
583                                                [2.419101, 0.017453, -0.608704, -0.085025]],
584                                               [[2.419101, 0.017453, -0.608703, -0.085025],
585                                                [2.419101, 0.017453, -0.608704, -0.085025]],
586                                               [[2.419101, 0.017453, -0.608703, -0.085025],
587                                                [2.419101, 0.017453, -0.608704, -0.085025]],
588                                               [[2.419101, 0.017453, -0.608703, -0.085025],
589                                                [2.419101, 0.017453, -0.608704, -0.085025]],
590                                               [[2.419101, 0.017453, -0.608703, -0.085025],
591                                                [2.419101, 0.017453, -0.608704, -0.085025]]]
592                                              )).to(device)
593            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
594            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
595
596            # test case 3, multiple layers with norm
597            # d_model = 4
598            norm = nn.LayerNorm(4)
599            model = nn.TransformerEncoder(encoder_layer, 2, norm=norm,
600                                          enable_nested_tensor=enable_nested_tensor).to(device)
601            if not training:
602                model = model.eval()
603            result = model(encoder_input, src_key_padding_mask=mask)
604            ref_output = perm_fn(torch.tensor([[[1.695949, -0.357635, -0.893077, -0.445238],
605                                                [1.695955, -0.357639, -0.893050, -0.445266]],
606                                               [[1.695948, -0.357634, -0.893082, -0.445233],
607                                                [1.695950, -0.357635, -0.893077, -0.445238]],
608                                               [[1.695951, -0.357636, -0.893069, -0.445246],
609                                                [1.695955, -0.357639, -0.893052, -0.445264]],
610                                               [[1.695952, -0.357636, -0.893066, -0.445249],
611                                                [1.695957, -0.357641, -0.893041, -0.445276]],
612                                               [[1.695946, -0.357632, -0.893095, -0.445220],
613                                                [1.695952, -0.357637, -0.893065, -0.445251]]]
614                                              )).to(device)
615            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
616            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
617
618            model = nn.TransformerEncoder(encoder_layer, 6, norm=norm,
619                                          enable_nested_tensor=enable_nested_tensor).to(device)
620            if not training:
621                model = model.eval()
622            result = model(encoder_input, src_key_padding_mask=mask)
623            ref_output = perm_fn(torch.tensor([[[1.695955, -0.357639, -0.893051, -0.445265],
624                                                [1.695955, -0.357639, -0.893051, -0.445265]],
625                                               [[1.695955, -0.357639, -0.893051, -0.445265],
626                                                [1.695955, -0.357639, -0.893051, -0.445265]],
627                                               [[1.695955, -0.357639, -0.893051, -0.445265],
628                                                [1.695955, -0.357639, -0.893051, -0.445265]],
629                                               [[1.695955, -0.357639, -0.893051, -0.445265],
630                                                [1.695955, -0.357639, -0.893051, -0.445265]],
631                                               [[1.695955, -0.357639, -0.893051, -0.445265],
632                                                [1.695955, -0.357639, -0.893051, -0.445265]]]
633                                              )).to(device)
634            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
635            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
636
        # TODO: remove the set_default_dtype(torch.double) below by making ref_output more precise.
        # Added because this test was copied from test_nn.py, which has default
        # dtype double. If the default dtype is float, the tests report that the tensors are not
        # close because the ref output precision is too low.
641        with set_default_dtype(torch.double):
642            if training:
643                cm = contextlib.nullcontext()
644            else:
645                cm = torch.no_grad()  # transformer fast path requires no grad
646            with cm:
647                _test(batch_first, training, enable_nested_tensor)
648
649    @unittest.skipIf(sys.version_info < (3, 11), "not supported on pre-3.11 Python")
650    def test_encoder_padding_and_src_mask_bool(self):
651        encoder_layer = nn.TransformerEncoderLayer(
652            d_model=16,
653            nhead=2,
654            dim_feedforward=32,
655            dropout=0.1,
656            activation='relu',
657            batch_first=True,
658        )
659        encoder_norm = nn.LayerNorm(16)
660        encoder = nn.TransformerEncoder(
661            encoder_layer, 2, encoder_norm
662        )
663
664        inputs = torch.randn(2, 3, 16)
665
666        src_mask = torch.ones(3, 3, dtype=torch.bool).triu_(diagonal=1)
667        input_seq_len = torch.tensor([3, 2])
668        padding_mask = (
669            torch.arange(3)[None, :].cpu() >= input_seq_len[:, None]
670        )
671
672        with (self.assertNoLogs(None) if not TEST_WITH_TORCHDYNAMO else contextlib.nullcontext()):
673            encoder(
674                inputs,
675                mask=src_mask,
676                src_key_padding_mask=padding_mask,
677            )
678
679    @unittest.skipIf(sys.version_info < (3, 11), "not supported on pre-3.11 Python")
680    def test_decoder_padding_and_src_mask_bool(self):
681
682        def transformer_decoder(inputs, input_seq_len, memory):
683            decoder_layer = nn.TransformerDecoderLayer(
684                d_model=16,
685                nhead=2,
686                dim_feedforward=32,
687                dropout=0.1,
688                activation='relu',
689                batch_first=True,
690            )
691            decoder_norm = nn.LayerNorm(16)
692            decoder = nn.TransformerDecoder(
693                decoder_layer, 2, decoder_norm
694            )
695
696            src_mask = torch.ones(
697                inputs.shape[1], inputs.shape[1], dtype=torch.bool
698            ).triu_(diagonal=1)
699            padding_mask = (
700                torch.arange(inputs.shape[1])[None, :].cpu()
701                >= input_seq_len[:, None]
702            )
703
704            return decoder(
705                inputs,
706                memory,
707                tgt_mask=src_mask,
708                tgt_key_padding_mask=padding_mask,
709                memory_key_padding_mask=padding_mask,
710            )
711
712        inputs = torch.randn(2, 3, 16)
713        memory = torch.randn(2, 3, 16)
714        input_seq_len = torch.tensor([3, 2])
715
716        with self.assertNoLogs(None):
717            transformer_decoder(inputs, input_seq_len, memory)
718
719    def test_encoder_is_causal(self):
720
721        d_model = 3
722        layer = torch.nn.TransformerEncoderLayer(d_model, 1, 6, batch_first=True)
723        layer.eval()
724        x = torch.randn(1, 5, d_model)
725        unmasked_output = layer(x)
726        mask = torch.nn.Transformer.generate_square_subsequent_mask(x.size(1))
727        is_causal_output = layer(x, src_mask=mask, is_causal=True)
728        masked_output = layer(x, src_mask=mask)
729
730        self.assertEqual(masked_output, is_causal_output)
731
732    @onlyCUDA
733    @parametrize("nb_heads", [1, 8])
734    @parametrize("bias", [True, False])
735    def test_mha_native_args(self, nb_heads, bias):
736
737        B, L, F = 8, 100, 128
738        batch_first = True
739        fast_path = True
        use_pad_mask = bias
741
742        mha = nn.MultiheadAttention(
743            embed_dim=F,
744            num_heads=nb_heads,
745            batch_first=batch_first,
746            bias=bias
747        ).cuda()
748        mha.eval()
749
750        ctx = torch.no_grad if fast_path else contextlib.nullcontext
751        with ctx():
752            x = torch.randn(B, L, F).cuda()
753            if not batch_first:
754                x = x.transpose(0, 1)
755
756            pad_mask = None
757            if use_pad_mask:
758                pad_mask = torch.zeros((B, L), dtype=torch.bool).cuda()
759
760            mha(query=x, key=x, value=x, key_padding_mask=pad_mask)
761
762    def test_kpm_mask_trailing_column_with_nested_tensor(self, device):
763        encoder_layer = nn.TransformerEncoderLayer(
764            d_model=256,
765            nhead=4,
766            dim_feedforward=512,
767            activation='gelu',
768            norm_first=False,
769            batch_first=False,
770        )
771        transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=3, enable_nested_tensor=True).to(device)
772
773        x = torch.randn(10, 6, 256).to(device)
774        mask = torch.ones(6, 10)
        mask[0, :] = 0  # leave only the first batch element unmasked; all others are fully padded
776        mask = mask.bool().to(device)
777        out = transformer_encoder(src=x, src_key_padding_mask=mask)
778        self.assertEqual(out.shape[1], 6)
779
    # The CPU unit test trips has_torch_function checks in the test environment,
    #   preventing successful completion
782    @onlyCUDA
783    def test_with_nested_tensor_input(self, device):
784        encoder_layer = nn.TransformerEncoderLayer(
785            d_model=256,
786            nhead=4,
787            dim_feedforward=512,
788            activation='gelu',
789            norm_first=False,
790            batch_first=True,
791        )
792        transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=3, enable_nested_tensor=True).to(device)
793
794        transformer_encoder.eval()
795        with torch.no_grad():
796            x = torch.randn(6, 10, 256).to(device)
797            mask = torch.ones(6, 10)
            # zero out a varying number of positions per sequence; after the
            # logical_not below these become the tokens kept in the nested tensor
            mask[0, 0:] = 0
            mask[2, 2:] = 0
            mask[4, 4:] = 0
            mask[5, 8:] = 0
802            mask = mask.bool().to(device)
803            x = torch._nested_tensor_from_mask(x, mask.logical_not(), mask_check=False)
804            out = transformer_encoder(src=x, src_key_padding_mask=None)
805
806        self.assertEqual(out.is_nested, True)
807
808
809
810    def test_script_encoder_subclass(self, device):
811        class MyCustomLayer(nn.TransformerEncoderLayer):
812            pass
813
814        encoder = nn.TransformerEncoder(
815            MyCustomLayer(d_model=256, nhead=8), num_layers=6
816        ).to(device=device)
817        torch.jit.script(encoder)
818
819    # brazenly adapted from test_transformerencoderlayer_src_mask to test execution of
820    # torchscripted transformerencoderlayer subclass
821    def test_transformerencoderlayer_subclass(self, device):
822        class MyCustomLayer(nn.TransformerEncoderLayer):
823            pass
824
825        nhead = 4
826        batch_size = 2
827        seqlen = 4
828        d_model = 8
829        dim_feedforward = 32
830
831        model = MyCustomLayer(
832            d_model=d_model,
833            nhead=nhead,
834            dim_feedforward=dim_feedforward,
835            batch_first=True).to(device)
836        script_model = torch.jit.script(model)
837
838        src = torch.rand(batch_size, seqlen, d_model).to(device)  # bs, seqlen, d_model
839        src_mask = torch.zeros(seqlen, seqlen).to(torch.bool).to(device)
840
841        torch.manual_seed(42)
842        result = model(src, src_mask=src_mask)
843        torch.manual_seed(42)
844        scripted_result = script_model(src, src_mask=src_mask)
845        self.assertEqual(result, scripted_result)
846
847        model.eval()
848        script_model = torch.jit.script(model)
849
850        with torch.no_grad():
851            result = model(src, src_mask=src_mask)
852            scripted_result = script_model(src, src_mask=src_mask)
853            self.assertEqual(result, scripted_result)
854
855
856    def test_transformerencoderlayer_subclass_model(self, device):
857        class MyCustomLayer(nn.TransformerEncoderLayer):
858            pass
859
860        nhead = 4
861        batch_size = 2
862        seqlen = 4
863        d_model = 8
864        dim_feedforward = 32
865
866        layer = MyCustomLayer(
867            d_model=d_model,
868            nhead=nhead,
869            dim_feedforward=dim_feedforward,
870            batch_first=True)
871        model = nn.TransformerEncoder(
872            layer, num_layers=6
873        ).to(device=device)
874        script_model = torch.jit.script(model)
875
876        src = torch.rand(batch_size, seqlen, d_model).to(device)  # bs, seqlen, d_model
877        src_mask = torch.zeros(seqlen, seqlen).to(torch.bool).to(device)
878
879        torch.manual_seed(42)
880        result = model(src, mask=src_mask)
881        torch.manual_seed(42)
882        scripted_result = script_model(src, mask=src_mask)
883        self.assertEqual(result, scripted_result)
884
885        model.eval()
886        script_model = torch.jit.script(model)
887
888        with torch.no_grad():
889            result = model(src, mask=src_mask)
890            scripted_result = script_model(src, mask=src_mask)
891            self.assertEqual(result, scripted_result)
892
893
894    @onlyCUDA
895    @unittest.skipIf(not TEST_FAIRSEQ, "Fairseq not found")
896    def test_decoder_only_layer(self):
897        DEFAULT_PADDING_IDX = 0
898
899        class FairseqDecoder(torch.nn.Module):
900            def __init__(
901                self,
902                embed_dim,
903                attention_heads,
904                ffn_embed_dim,
905                num_layers,
906                embedding_layer,  # torch.nn.Embedding. Must have a padding_idx field
907                dropout=0,
908                normalize_before=False,
909                torch_encoder=None,  # torch encoder that you can map weights from
910                activation="relu",
911            ):
912                super().__init__()
913
914                cfg = fairseq_transformer.TransformerConfig()
915                cfg.decoder.embed_dim = embed_dim
916                cfg.decoder.output_dim = embed_dim
917                cfg.decoder.attention_heads = attention_heads
918                cfg.decoder.ffn_embed_dim = ffn_embed_dim
919                cfg.dropout = dropout
920                cfg.decoder.normalize_before = normalize_before
921                cfg.decoder.layers = num_layers
922                # make embedding behavior same as other encoders
923                cfg.no_token_positional_embeddings = True
924                cfg.no_scale_embedding = True
925                cfg.activation_fn = activation
926
927                dictionary = {}  # TODO: verify what this is
928
929                self.decoder = fairseq_transformer.TransformerDecoder(
930                    cfg,
931                    dictionary,
932                    embedding_layer,
933                    no_encoder_attn=True,
934                    output_projection=None,
935                )
936
937                if torch_encoder is not None:
938                    self.decoder = torch_to_fairseq(torch_encoder, self.decoder)  # noqa: F821
939                self.decoder = self.decoder.eval().cuda().half()
940
941            def forward(
942                self,
943                tokens,
944                src_lengths=None,
945                with_triangle_mask=False,
946                incremental_state=None,
947            ):
948                return self.decoder(
949                    prev_output_tokens=tokens,
950                    encoder_out=None,
951                    incremental_state=incremental_state,
952                    features_only=True,
953                    full_context_alignment=not with_triangle_mask,
954                    alignment_layer=None,
955                    alignment_heads=None,
956                    src_lengths=src_lengths,
957                    return_all_hiddens=False,
958                )[0]
959
960    @parametrize("input_dim,attn_mask_dim,is_causal",
961                 [(3, None, False), (3, 2, False), (3, 2, True), (3, 3, False), (3, 3, True),
962                  (4, None, False), (4, 2, False), (4, 2, True), (4, 4, False), (4, 4, True)],
963                 name_fn=lambda input_dim, attn_dim, is_causal: (
964                     f"{input_dim}D_input_dim_" + (
965                         f"{attn_dim}D_{'causal_' if is_causal else ''}attn_mask"
966                         if attn_dim is not None else "no_attn_mask")))
967    @parametrize("dropout_p", [0.0, 0.2, 0.5])
968    @sdpa_kernel(backends=[SDPBackend.MATH])
969    def test_scaled_dot_product_attention(self, device, input_dim, attn_mask_dim, is_causal, dropout_p):
970        def sdp_ref(
971                q,
972                k,
973                v,
974                attn_mask=None,
975                dropout_p=0.0):
976            E = q.size(-1)
977            q = q / math.sqrt(E)
978            # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
979            if attn_mask is not None:
980                attn = torch.baddbmm(attn_mask, q, k.transpose(-2, -1))
981            else:
982                attn = torch.bmm(q, k.transpose(-2, -1))
983
984            attn = torch.nn.functional.softmax(attn, dim=-1)
985            if dropout_p > 0.0:
986                attn = torch.nn.functional.dropout(attn, p=dropout_p)
987            # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
988            output = torch.bmm(attn, v)
989            return output
990        # TODO: Support cross-device / dtype testing properly when instantiate_device_type_tests() is used.
991        dtypes = [torch.double, torch.float]
992        for dtype in dtypes:
993
994            def rand_tensor(*shape):
995                return torch.randn(shape, device=device, dtype=dtype)
996
997            # This test compares python and C++ implementations of SDP.
998            N, N_prime, L, S, E = 5, 2, 4, 3, 6
999            if input_dim == 3:
1000                query = rand_tensor(N, L, E)
1001                key = rand_tensor(N, S, E)
1002                value = rand_tensor(N, S, E)
1003            elif input_dim == 4:
1004                query = rand_tensor(N, N_prime, L, E)
1005                key = rand_tensor(N, N_prime, S, E)
1006                value = rand_tensor(N, N_prime, S, E)
1007            else:
1008                self.fail(f'Invalid input_dim {input_dim} encountered in SDP test')
1009
1010            attn_mask = None
1011            if attn_mask_dim is not None:
1012                assert attn_mask_dim in [2, input_dim]
1013                mask_size = (L, S) if attn_mask_dim == 2 else ((N, L, S) if input_dim == 3 else (N, N_prime, L, S))
1014                attn_mask = (torch.ones(mask_size, device=device, dtype=torch.bool).tril() if is_causal
1015                             else torch.randint(0, 2, size=mask_size, device=device, dtype=torch.bool))
1016
1017            with freeze_rng_state():
1018                # Python impl only supports float mask and 3D inputs.
1019                attn_mask_float = attn_mask
1020                if attn_mask_float is not None:
1021                    attn_mask_float = torch.zeros_like(attn_mask, dtype=query.dtype)
1022                    attn_mask_float.masked_fill_(attn_mask.logical_not(), float("-inf"))
1023                q, k, v = query.view(-1, L, E), key.view(-1, S, E), value.view(-1, S, E)
1024                a = attn_mask_float
1025                if a is not None and attn_mask_dim > 3:
1026                    a = a.view(-1, L, S)
1027                expected = sdp_ref(q, k, v, attn_mask=a, dropout_p=dropout_p)
1028                if input_dim > 3:
1029                    expected = expected.view(-1, N_prime, L, E)
1030
1031            with freeze_rng_state():
1032                if is_causal:
1033                    # NB: Don't pass attn_mask here
1034                    actual = torch.nn.functional.scaled_dot_product_attention(
1035                        query, key, value, None, dropout_p, is_causal)
1036
1037                    # Error case: both explicit attn_mask and is_causal are set
1038                    with self.assertRaisesRegex(RuntimeError,
1039                                                "Explicit attn_mask should not be set when is_causal=True"):
1040                        torch.nn.functional.scaled_dot_product_attention(
1041                            query, key, value, attn_mask, dropout_p, is_causal)
1042                else:
1043                    actual = torch.nn.functional.scaled_dot_product_attention(
1044                        query, key, value, attn_mask, dropout_p, is_causal)
1045
1046                self.assertEqual(actual, expected)
1047
1048        if attn_mask_dim is None:
1049            q = q.double().clone()
1050            k = k.double().clone()
1051            v = v.double().clone()
1052            q.requires_grad_()
1053            k.requires_grad_()
1054            v.requires_grad_()
1055
1056            assert gradcheck(lambda *args, **kwargs: wrapper_set_seed(sdp_ref, *args, **kwargs),
1057                             (q, k, v, attn_mask, dropout_p))
1058            assert gradcheck(lambda *args, **kwargs:
1059                             wrapper_set_seed(torch.nn.functional.scaled_dot_product_attention, *args, **kwargs),
1060                             (q, k, v, attn_mask, dropout_p))
1061
    def test_incompatible_mask(self, device):
        def ones_tensor(*shape):
            return torch.ones(shape, dtype=torch.float32)
        S, L, E, H = 1, 2, 4, 1
        qkv = ones_tensor(S, L, E)

        mha = nn.MultiheadAttention(E, H)
        mha.in_proj_weight = Parameter(torch.ones((E * 3, E)))
        mha.out_proj.weight = Parameter(torch.ones((E, E)))
        qkv = qkv.to(float)
        kpm = ones_tensor(S, L) * float("-inf")
        am = ones_tensor(L, L).to(bool)

        def func():
            return mha(qkv, qkv, qkv, need_weights=False, key_padding_mask=kpm, attn_mask=am)

        self.assertRaises(RuntimeError, func)
1079
1080    @unittest.skipIf(TEST_WITH_CROSSREF, 'Fastpath not available with crossref')
1081    @torch.no_grad()
1082    def test_mask_check_fastpath(self):
1083        """
1084        Test that fastpath is executed independently of the masks that are passed.
1085        If the passed key padding mask is left aligned or mask_check=False, test that nested tensors are used
1086        (sparsity fastpath), otherwise use fastpath with traditional tensors.
1087        Also test that fast path is executed with both key padding mask and attention mask passed at the same time.
1088        """
1089
1090        x = torch.Tensor([[[1, 2], [3, 4], [5, 6]]]).to(torch.float)
1091
1092        def _test_fastpath(model, key_padding_mask, mock_return_value, attn_mask=None, nested_tensors=True):
1093            with patch('torch._transformer_encoder_layer_fwd') as fastpath_mock:
1094                fastpath_mock.return_value = mock_return_value
1095                model(x, src_key_padding_mask=key_padding_mask, mask=attn_mask)
1096
1097                # If mock was called, fastpath was taken
1098                self.assertTrue(fastpath_mock.called)
1099
1100                # If mock was called with nested tensors, sparsity fastpath was taken
1101                for call_args, _ in fastpath_mock.call_args_list:
1102                    self.assertEqual(call_args[0].is_nested, nested_tensors)
1103
1104        encoder_layer = torch.nn.TransformerEncoderLayer(d_model=2, nhead=2, dim_feedforward=8, batch_first=True)
1105
1106        model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=True, mask_check=True)
1107        model.eval()
1108
1109        aligned_key_padding_mask = torch.Tensor([[0, 0, 1]]).to(torch.bool)
1110        not_aligned_key_padding_mask = torch.Tensor([[1, 0, 1]]).to(torch.bool)
1111        attn_mask = torch.Tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]]).to(torch.bool)
1112        nested_tensor_return_value = torch.nested.nested_tensor([torch.ones((2, 2), dtype=torch.float)])
1113        tensor_return_value = torch.ones((1, 3, 2), dtype=torch.float)
1114
1115        # Left aligned mask results in sparsity fastpath
1116        _test_fastpath(model, aligned_key_padding_mask, nested_tensor_return_value, nested_tensors=True)
1117
1118        # Not aligned mask results in fastpath
1119        _test_fastpath(model, not_aligned_key_padding_mask, tensor_return_value, nested_tensors=False)
1120
1121        model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=False, mask_check=True)
1122        model.eval()
1123
1124        # If nested tensor disabled, fastpath is always taken
1125        _test_fastpath(model, aligned_key_padding_mask, tensor_return_value, nested_tensors=False)
1126        _test_fastpath(model, not_aligned_key_padding_mask, tensor_return_value, nested_tensors=False)
1127        # Fast path is taken if both attention mask and key padding mask are present
1128        _test_fastpath(model, aligned_key_padding_mask, tensor_return_value, attn_mask=attn_mask, nested_tensors=False)
1129
1130        model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=True, mask_check=False)
1131        model.eval()
1132
        # Mask check disabled results in sparsity fastpath, independently of the mask
1134        _test_fastpath(model, aligned_key_padding_mask, nested_tensor_return_value, nested_tensors=True)
1135        _test_fastpath(model, not_aligned_key_padding_mask, nested_tensor_return_value, nested_tensors=True)
1136
    # Regression test: MHA used to fail when bias was None
1138    def test_bias_is_none(self):
1139        x = torch.rand((1, 5, 10))
1140        model = torch.nn.modules.activation.MultiheadAttention(10, 1, bias=False, batch_first=True)
1141        model.eval()
1142        model(x, x, x)
1143        # completes without error
1144
1145    def test_transformer_bias_is_none(self, device):
1146        batch_size = 2
1147        seqlen = 3
1148        d_model = 8
1149        nhead = 4
1150
1151        encoder_layer = torch.nn.TransformerEncoderLayer(d_model, nhead, bias=False, batch_first=True, device=device)
1152        encoder_layer.eval()
1153        x = torch.randn(batch_size, seqlen, d_model, device=device)
1154        # runs without error
1155        encoder_layer(x)
1156
1157        with self.assertWarnsRegex(UserWarning, "encoder_layer.self_attn was passed bias=False"):
1158            encoder = torch.nn.TransformerEncoder(encoder_layer, num_layers=1).eval()
1159            encoder(x)
1160
1161        with self.assertWarnsRegex(UserWarning, "self_attn was passed bias=False"):
1162            transformer = torch.nn.Transformer(
1163                d_model=d_model, nhead=nhead, bias=False, batch_first=True, device=device
1164            ).eval()
1165            transformer(x, x)
1166
1167    def test_train_with_is_causal(self, device):
1168        # training with is_causal
1169        S, L, E, H = 1, 2, 2, 1
1170        layer = nn.TransformerEncoderLayer(
1171            d_model=2,
1172            dim_feedforward=4,
1173            nhead=H,
1174            batch_first=True,
1175            activation="gelu",
1176            dropout=0,
1177        )
1178        criterion = nn.MSELoss()
1179        encoder = nn.TransformerEncoder(layer, 2).to(device)
1180        optimizer = optim.SGD(encoder.parameters(), lr=0.1, momentum=0.9)
1181        encoder.train()
1184        optimizer.zero_grad()
1185        inputs = torch.randn(S, L, E).to(device)
1186        mask = torch.nn.Transformer.generate_square_subsequent_mask(
1187            inputs.size(1), device=device
1188        )
1189
1190        outputs = encoder(inputs, mask=mask, is_causal=True)
1191
1192        loss = criterion(outputs[:, 0:2, :], inputs[:, 0:2, :])
1193        loss.backward()
1194        optimizer.step()
1195
1196        # inference with is_causal
1197        t_qvk = torch.randn((S, L, E), device=device, dtype=torch.float32)
1198        mha = nn.MultiheadAttention(E, H).to(device)
1199        mask = torch.nn.Transformer.generate_square_subsequent_mask(
1200            S, device=device
1201        )
1202
1203        attn_out, _ = mha(t_qvk, t_qvk, t_qvk, attn_mask=mask, is_causal=True)
1204
1205        # Passing only is_causal, without an attn_mask, is an error
1206        attn_mask = torch.randint(0, 2, size=(L, L), device=device, dtype=torch.bool)
1207        with self.assertRaises(RuntimeError):
1208            _ = mha(t_qvk, t_qvk, t_qvk, is_causal=True)
1209
1210        # Passing a causal mask sets is_causal to True
1211        causal_mask = torch.triu(
1212            torch.ones(L, L, device=inputs.device) * float('-inf'), diagonal=1
1213        ).to(torch.bool)
1214
1215        mock_layer = MagicMock(torch.nn.MultiheadAttention(E, H), return_value=inputs)
1216        encoder.layers[1] = mock_layer
1217        outputs = encoder(inputs, mask=causal_mask)
1218        mock_layer.assert_called_with(ANY, src_mask=ANY, is_causal=True, src_key_padding_mask=ANY)
1219
1220        # check expected numerical values with all kernels
1221        self.is_causal_kernels([SDPBackend.MATH], device)
1222
1223    def is_causal_kernels(self, kernels, device):
1224        def ones_tensor(*shape):
1225            return torch.ones(shape, device=device, dtype=torch.float32).to(device)
1226        S, L, E, H = 1, 2, 4, 1
1227        qkv = ones_tensor(S, L, E)
1228
1229        mha = nn.MultiheadAttention(E, H).to(device)
1230        mha.in_proj_weight = Parameter(torch.ones((E * 3, E), device=device))
1231        mha.out_proj.weight = Parameter(torch.ones((E, E), device=device))
1232        expected = torch.ones(size=(S, L, E)).to(device) * 16
1233        mask = torch.nn.Transformer.generate_square_subsequent_mask(
1234            qkv.size(1), device=device
1235        )
1236
1237        for kernel in kernels:
1238            with sdpa_kernel(backends=[kernel]):
1239                actual, _ = mha(qkv, qkv, qkv, attn_mask=mask, need_weights=False, is_causal=True)
1240                self.assertTrue(torch.equal(actual, expected))
1241
1242                if kernel != SDPBackend.MATH:
1243                    # fails when the embedding size is not a multiple of 4
1244                    with self.assertRaisesRegex(RuntimeError, "No available kernel"):
1245                        qkv_f, mha_f = ones_tensor(S, L, 2), nn.MultiheadAttention(2, H).to(device)
1246                        mask = torch.nn.Transformer.generate_square_subsequent_mask(
1247                            qkv_f.size(1), device=device
1248                        )
1249                        _ = mha_f(qkv_f, qkv_f, qkv_f, attn_mask=mask, need_weights=False, is_causal=True)
1250                        torch.cuda.synchronize()
1251
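    # Illustrative sketch (not part of the original suite): under the MATH backend,
    # is_causal=True should be numerically equivalent to passing the explicit additive
    # causal mask produced by generate_square_subsequent_mask. Shapes below are
    # arbitrary example values.
    def _sketch_is_causal_matches_explicit_mask(self, device):
        q = torch.randn(2, 2, 4, 8, device=device)
        k = torch.randn(2, 2, 4, 8, device=device)
        v = torch.randn(2, 2, 4, 8, device=device)
        causal_mask = torch.nn.Transformer.generate_square_subsequent_mask(4, device=device)
        with sdpa_kernel(backends=[SDPBackend.MATH]):
            out_flag = scaled_dot_product_attention(q, k, v, is_causal=True)
            out_mask = scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
        self.assertEqual(out_flag, out_mask)
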
1252    @skipIfRocm  # Missing EFFICIENT_ATTENTION
1253    @unittest.skipIf(
1254        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Platform does not support fused SDPA or pre-SM80 hardware"
1255    )
1256    def test_is_causal_gpu(self):
1257        device = 'cuda'
1258        self.is_causal_kernels([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION], device)
1259
1260    def test_script_mha_in_proj_weight_none(self):
1261        mha = torch.nn.MultiheadAttention(
1262            embed_dim=128, num_heads=8, kdim=256, vdim=256
1263        ).eval()
1264
1265        torch.jit.script(mha)
1266
1267    @unittest.skipIf(TEST_WITH_CROSSREF, 'Fastpath not available with crossref')
1268    @torch.no_grad()
1269    def test_disable_fastpath(self, device):
1270        def _test_te_fastpath_called(model, args, kwargs=None, return_value=None, is_called=True):
1271            if kwargs is None:
1272                kwargs = {}
1273            with patch('torch._transformer_encoder_layer_fwd') as fastpath_mock:
1274                fastpath_mock.return_value = return_value
1275                output = model(*args, **kwargs)
1276                self.assertTrue(fastpath_mock.called == is_called)
1277
1278        def _test_mha_fastpath_called(model, args, kwargs=None, return_value=None, is_called=True):
1279            if kwargs is None:
1280                kwargs = {}
1281            with patch('torch._native_multi_head_attention') as fastpath_mock:
1282                fastpath_mock.return_value = return_value
1283                output = model(*args, **kwargs)
1284                self.assertTrue(fastpath_mock.called == is_called)
1285
1286        inp = torch.tensor([[[1, 2], [3, 4], [5, 6]]], dtype=torch.float32, device=device)
1287        aligned_key_padding_mask = torch.tensor([[0, 0, 1]], dtype=torch.bool, device=device)
1288        src_key_padding_mask = torch.tensor([[1, 0, 1]], dtype=torch.bool, device=device)
1289        attn_mask = torch.tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]], dtype=torch.bool, device=device)
1290        te_return_value = torch.ones((1, 3, 2), dtype=torch.float32)
1291
1292        encoder_layer = torch.nn.TransformerEncoderLayer(d_model=2, nhead=2, dim_feedforward=8, batch_first=True)
1293        te = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=True, mask_check=True)
1294        te = te.to(device).eval()
1295
1296        t = torch.nn.Transformer(d_model=2, nhead=2, batch_first=True, device=device).eval()
1297        src = torch.tensor([[[0, 1], [2, 3], [4, 5]]], dtype=torch.float32, device=device)
1298        tgt = torch.tensor([[[0, 1], [2, 3], [4, 5], [6, 7]]], dtype=torch.float32, device=device)
1299        t_return_value = torch.ones((1, 3, 2), dtype=torch.float32, device=device)
1300
1301        mha = nn.MultiheadAttention(2, 2, batch_first=True, device=device).eval()
1302        q = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.float32, device=device)
1303        mha_return_value = torch.ones((1, 3, 2), dtype=torch.float32, device=device)
1304
1305        _test_te_fastpath_called(
1306            te, (inp,), kwargs={'src_key_padding_mask': src_key_padding_mask},
1307            return_value=te_return_value, is_called=True
1308        )
1309        _test_te_fastpath_called(t, (src, tgt), return_value=t_return_value, is_called=True)
1310        _test_mha_fastpath_called(mha, (q, q, q,), return_value=mha_return_value, is_called=True)
1311
1312        torch.backends.mha.set_fastpath_enabled(False)
1313        _test_te_fastpath_called(
1314            te, (inp,), kwargs={'src_key_padding_mask': src_key_padding_mask},
1315            return_value=te_return_value, is_called=False
1316        )
1317        _test_te_fastpath_called(t, (src, tgt), return_value=t_return_value, is_called=False)
1318        _test_mha_fastpath_called(mha, (q, q, q,), return_value=mha_return_value, is_called=False)
1319
1320        torch.backends.mha.set_fastpath_enabled(True)
1321        _test_te_fastpath_called(
1322            te, (inp,), kwargs={'src_key_padding_mask': src_key_padding_mask},
1323            return_value=te_return_value, is_called=True
1324        )
1325        _test_te_fastpath_called(t, (src, tgt), return_value=t_return_value, is_called=True)
1326        _test_mha_fastpath_called(mha, (q, q, q,), return_value=mha_return_value, is_called=True)
1327
1328
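# Illustrative sketch (not part of the original suite): the global MHA / TransformerEncoder
# fastpath toggled by the test above can be scoped with a small context manager so the
# previous state is always restored. This helper is an assumption added for illustration only.
@contextlib.contextmanager
def _fastpath_disabled():
    previous = torch.backends.mha.get_fastpath_enabled()
    torch.backends.mha.set_fastpath_enabled(False)
    try:
        yield
    finally:
        torch.backends.mha.set_fastpath_enabled(previous)
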
1329class TestSDPAFailureModes(NNTestCase):
1330    """ Used to test the failure modes of scaled_dot_product_attention
1331    """
1332    _do_cuda_memory_leak_check = True
1333    _do_cuda_non_default_stream = True
1334
1335    @onlyCUDA
1336    @unittest.skipIf(
1337        not PLATFORM_SUPPORTS_FLASH_ATTENTION or not isSM8XDevice,
1338        "Does not support fused SDPA or not SM86+ hardware",
1339    )
1340    @parametrize("head_dim", [193, 204, 256])
1341    @parametrize("dropout_p", [0.0, 0.2])
1342    def test_flash_backward_failure_sm86plus(self, device, head_dim: int, dropout_p: float):
1343        dtype = torch.float16
1344        make_tensor = partial(torch.rand, device=device, dtype=dtype)
1345        # See check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89 in
1346        # pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.h
1347        size = (2, 2, 4, head_dim)
1348        q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
1349
1350        with sdpa_kernel(backends=[SDPBackend.MATH]):
1351            math_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False)
1352
1353        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
1354            # Should not fail because inputs don't require grad
1355            flash_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False)
1356
1357            self.assertEqual(math_ref, flash_ref, atol=1e-3, rtol=1e-3)
1358
1359            # Should fail because inputs require grad
1360            q = make_tensor(size, requires_grad=True)
1361            k = make_tensor(size, requires_grad=True)
1362            v = make_tensor(size, requires_grad=True)
1363            if 192 < head_dim <= 224 or (head_dim > 224 and dropout_p != 0.0):
1364                self.assertRaises(
1365                    RuntimeError,
1366                    lambda: torch.nn.functional.scaled_dot_product_attention(
1367                        q, k, v, None, dropout_p, False
1368                    ),
1369                )
1370            else:
1371                flash_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, None, dropout_p, False)
1372
1373    @onlyCUDA
1374    def test_dispatch_fails_no_backend(self, device):
1375        dtype = torch.float16
1376        with sdpa_kernel(backends=[SDPBackend.ERROR]):
1377            size = (2, 3, 4)
1378            q = torch.randn(size, device=device, dtype=dtype)
1379            k = torch.randn(size, device=device, dtype=dtype)
1380            v = torch.randn(size, device=device, dtype=dtype)
1381            self.assertRaisesRegex(RuntimeError, "No viable backend for scaled_dot_product_attention was found.",
1382                                   lambda: torch._fused_sdp_choice(q, k, v))
1383            self.assertRaisesRegex(RuntimeError, "No viable backend for scaled_dot_product_attention was found.",
1384                                   lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v))
1385
1386    @onlyCUDA
1387    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
1388    @parametrize(
1389        "kernel",
1390        PLATFORM_SPECIFIC_SDPA,
1391    )
1392    def test_invalid_fused_inputs_dim_3(self, device, kernel: SDPBackend):
1393        with sdpa_kernel(backends=[kernel]):
1394            # Dim is not 4
1395            size = (2, 3, 8)
1396            dtype = torch.float16
1397            q = torch.randn(size, device=device, dtype=dtype)
1398            k = torch.randn(size, device=device, dtype=dtype)
1399            v = torch.randn(size, device=device, dtype=dtype)
1400            with self.assertWarnsRegex(UserWarning, "Both fused kernels requires query, key and value to be 4 dimensional"):
1401                self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1402                    q, k, v, None, 0.0, False))
1403
1404    @onlyCUDA
1405    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
1406    @parametrize(
1407        "kernel",
1408        PLATFORM_SPECIFIC_SDPA,
1409    )
1410    def test_invalid_fused_inputs_broadcast(self, device, kernel: SDPBackend):
1411        with sdpa_kernel(backends=[kernel]):
1412            # Fused kernels don't support broadcasting for dense inputs
1413            dtype = torch.float16
1414            size = (2, 4, 3, 8)
1415            size_broadcast = (1, 4, 3, 8)
1416            q = torch.randn(size_broadcast, device=device, dtype=dtype)
1417            k = torch.randn(size, device=device, dtype=dtype)
1418            v = torch.randn(size, device=device, dtype=dtype)
1419            self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1420                q, k, v, None, 0.0, False))
1421
1422    @onlyCUDA
1423    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
1424    @parametrize("kernel", PLATFORM_SPECIFIC_SDPA)
1425    def test_invalid_sequence_lengths(self, device, kernel: SDPBackend):
1426        with sdpa_kernel(backends=[kernel]):
1427            # Passing in q, k, v with zero-length sequences will error
1428            dtype = torch.float16
1429            make_tensor = partial(torch.rand, device=device, dtype=dtype)
1430            size = SdpaShape(2, 2, 0, 8)
1431            q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
1432            with self.assertWarnsRegex(UserWarning, "Both fused kernels do not support zero seq_len_q or seq_len_kv."):
1433                self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1434                    q, k, v, None, 0.0, False))
1435
1436    @onlyCUDA
1437    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
1438    @parametrize("kernel", PLATFORM_SPECIFIC_SDPA)
1439    def test_invalid_last_dim_stride(self, device, kernel: SDPBackend):
1440        with sdpa_kernel(backends=[kernel]):
1441            # Passing in q, k, v whose last-dim stride is not 1 will error
1442            dtype = torch.float16
1443            make_tensor = partial(torch.rand, device=device, dtype=dtype)
1444            size = SdpaShape(2, 2, 8, 8)
1445            q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
1446            q.as_strided_(size, [2, 2, 2, 2])
1447            with self.assertWarnsRegex(UserWarning, "Both fused kernels require the last dimension of the input to have stride 1."):
1448                self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1449                    q, k, v, None, 0.0, False))
1450
1451    @onlyCUDA
1452    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention fused scaled dot product attention")
1453    @parametrize("kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
1454    def test_invalid_fused_inputs_head_dim(self, device, kernel: SDPBackend):
1455        with sdpa_kernel(backends=[kernel]):
1456            # The embed dim per head is invalid for the selected kernel (9 for efficient attention, 257 for flash attention)
1457            dtype = torch.float16
1458            make_tensor = partial(torch.rand, device=device, dtype=dtype)
1459            size = SdpaShape(2, 2, 3, 9) if kernel == SDPBackend.EFFICIENT_ATTENTION else SdpaShape(2, 2, 3, 257)
1460            if TEST_WITH_ROCM:  # On ROCM, FA and EA share the backend GPU kernels
1461                size = SdpaShape(2, 2, 3, 257)
1462            q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
1463            self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1464                q, k, v, None, 0.0, False))
1465
1466    @onlyCUDA
1467    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
1468    @parametrize(
1469        "kernel",
1470        PLATFORM_SPECIFIC_SDPA,
1471    )
1472    def test_invalid_fused_inputs_invalid_dtype(self, device, kernel: SDPBackend):
1473        with sdpa_kernel(backends=[kernel]):
1474            # Invalid dtype for both Flash Attention and Mem Efficient Attention
1475            size = SdpaShape(2, 2, 3, 16)
1476            make_tensor = partial(torch.rand, device=device, dtype=torch.float64)
1477            q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
1478            self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1479                q, k, v, None, 0.0, False))
1480
1481    @onlyCUDA
1482    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention")
1483    @parametrize("kernel", [SDPBackend.FLASH_ATTENTION])
1484    def test_invalid_fused_inputs_attn_mask_present(self, device, kernel: SDPBackend):
1485        with sdpa_kernel(backends=[kernel]):
1486            # Failures for unsupported SDP args
1487            size = SdpaShape(2, 2, 3, 16)
1488            make_tensor = partial(torch.rand, size, device=device, dtype=torch.float16)
1489            q, k, v = make_tensor(), make_tensor(), make_tensor()
1490            # Non-None attention mask
1491            mask = torch.ones((2, 2, 3, 3), device=device, dtype=q.dtype)
1492            self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1493                q, k, v, mask, 0.0, False))
1494
1495    @onlyCUDA
1496    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support fused SDPA or pre-SM80 hardware")
1497    def test_unaligned_tensors(self, device):
1498        # The alignment is dependent on arch so we specify SM80OrLater
1499        dtype = torch.float16
1500        size = SdpaShape(2, 2, 8, 5)
1501        make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
1502        q, k, v = make_tensor(), make_tensor(), make_tensor()
1503        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
1504            ctxmgr = self.assertRaises(RuntimeError) if not TEST_WITH_ROCM else contextlib.nullcontext()
1505            with ctxmgr:
1506                torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False)
1507
1508    @onlyCUDA
1509    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support fused SDPA or pre-SM80 hardware")
1510    def test_flash_fail_fp32(self, device):
1511        dtype = torch.float
1512        size = SdpaShape(16, 16, 32, 32)
1513        make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
1514        q, k, v = make_tensor(), make_tensor(), make_tensor()
1515        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
1516            with self.assertWarnsRegex(UserWarning, "Expected query, key and value to all be of dtype: {Half, BFloat16}"):
1517                self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1518                    q, k, v, None, 0.0, False))
1519
1520    @onlyCUDA
1521    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
1522    def test_flash_autocast_fp32_float16(self, device):
1523        dtype = torch.float
1524        size = SdpaShape(16, 16, 32, 32)
1525        make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
1526        q, k, v = make_tensor(), make_tensor(), make_tensor()
1527        with torch.autocast(device_type='cuda', dtype=torch.float16):
1528            with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
1529                _ = torch.nn.functional.scaled_dot_product_attention(
1530                    q, k, v, None, 0.0, False)
1531
1532    @onlyCUDA
1533    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
1534    def test_flash_autocast_fp32_bfloat16(self, device):
1535        dtype = torch.float
1536        size = SdpaShape(16, 16, 32, 32)
1537        make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
1538        q, k, v = make_tensor(), make_tensor(), make_tensor()
1539        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
1540            with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
1541                _ = torch.nn.functional.scaled_dot_product_attention(
1542                    q, k, v, None, 0.0, False)
1543
1544    # Note: do not truncate the list according to platforms. These tests should always raise errors.
1545    @parametrize("kernel", [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
1546    def test_invalid_inputs_different_datatypes(self, device, kernel: SDPBackend):
1547        with sdpa_kernel(backends=[kernel]):
1548            # Different datatypes
1549            shape = (1, 4, 8, 16)
1550            query = torch.randn(shape, dtype=torch.float32, device=device)
1551            key = torch.randn(shape, dtype=torch.float16, device=device)
1552            value = torch.randn(shape, dtype=torch.float16, device=device)
1553            self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(query, key, value))
1554
1555    @onlyCUDA
1556    @parametrize("kernel", [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
1557    def test_invalid_inputs_different_devices(self, device, kernel: SDPBackend):
1558        # Different devices
1559        shape = (1, 4, 8, 16)
1560        query = torch.randn(shape, dtype=torch.float32, device=device)
1561        key = torch.randn(shape, dtype=torch.float16, device='cpu')
1562        value = torch.randn(shape, dtype=torch.float16, device='cpu')
1563        self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(query, key, value))
1564
1565    @parametrize("kernel", [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
1566    def test_invalid_inputs_1_dimensional_inputs(self, device, kernel: SDPBackend):
1567        with sdpa_kernel(backends=[kernel]):
1568            # 1 dimensional input
1569            shape = (1, 4)
1570            query = torch.randn(4, dtype=torch.float16, device=device)
1571            key = torch.randn(shape, dtype=torch.float16, device=device)
1572            value = torch.randn(shape, dtype=torch.float16, device=device)
1573            self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(query, key, value))
1574
1575    @onlyCUDA
1576    @skipIfRocm  # Missing EFFICIENT_ATTENTION
1577    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
1578    def test_fused_kernels_nested_broadcasting_error_cases(self, device):
1579        # one of k, v needs to be broadcast and the other has an inconsistent seq_len dim
1580        rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float32)
1581        batch, num_heads, head_dim = 32, 8, 64
1582        seq_lens_q = torch.randint(low=1, high=32, size=(batch,)).tolist()
1583        seq_lens_v = torch.randint(low=1, high=32, size=(batch,)).tolist()
1584
1585        q_shape = SdpaShape(batch, num_heads, seq_lens_q, head_dim)
1586        k_shape = SdpaShape(1, num_heads, 1, head_dim)
1587        v_shape = SdpaShape(batch, num_heads, seq_lens_v, head_dim)
1588
1589        query = rand_nested_tensor(q_shape).transpose(1, 2)
1590        key = rand_nested_tensor(k_shape).transpose(1, 2)
1591        value = rand_nested_tensor(v_shape).transpose(1, 2)
1592
1593        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
1594            with self.assertRaisesRegex(RuntimeError, "No available kernel"):
1595                torch.nn.functional.scaled_dot_product_attention(
1596                    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
1597
1598    @onlyCUDA
1599    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Fused SDPA was not built for this system")
1600    def test_nested_fails_on_padding_head_dim(self, device):
1601        dtype = torch.bfloat16
1602        seq_len_list = [2, 4, 5, 6, 7]
1603        shape = SdpaShape(5, 8, seq_len_list, 57)
1604        make_tensor = partial(rand_sdpa_tensor, shape=shape, type="nested", device=device, dtype=dtype)
1605        q, k, v = make_tensor(), make_tensor(), make_tensor()
1606        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
1607            with self.assertWarnsRegex(UserWarning, "For NestedTensor inputs, Flash attention requires"):
1608                self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1609                    q, k, v, None, 0.0, False))
1610
1611    @onlyCUDA
1612    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION or not isLessThanSM80Device,
1613                     "Current platform does not support fused SDPA or is an SM80+ device.")
1614    def test_mem_efficient_fail_bfloat16_less_than_sm80(self, device):
1615        dtype = torch.bfloat16
1616        size = SdpaShape(16, 16, 32, 32)
1617        make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
1618        q, k, v = make_tensor(), make_tensor(), make_tensor()
1619        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
1620            with self.assertWarnsRegex(UserWarning, "Expected query, key and value to all be of dtype: {Half, Float}"):
1621                self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1622                    q, k, v, None, 0.0, False))
1623
1624    @onlyCUDA
1625    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention")
1626    def test_flash_attention_large_bf16_nan_values(self, device):
1627        query = torch.full((1, 1, 1, 64), 133120.0, dtype=torch.bfloat16, device="cuda")
1628        key = torch.full((1, 1, 1, 64), 133120.0, dtype=torch.bfloat16, device="cuda")
1629        value = torch.full((1, 1, 1, 64), 133120.0, dtype=torch.bfloat16, device="cuda")
1630
1631        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
1632            out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
1633
1634        self.assertFalse(torch.isnan(out).any(), "Output should not contain NaNs!")
1635
1636    @onlyCUDA
1637    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
1638    @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
1639                 PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION])
1640    def test_fused_kernels_seq_len_0_inputs(self, device, fused_kernel):
1641        rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float16)
1642        batch, num_heads, head_dim = 32, 16, 64
1643        seq_lens = torch.randint(low=1, high=32, size=(batch,))
1644        # make sure some seq_lens are 0
1645        num_zeros = 10
1646        indices = torch.randint(low=0, high=batch, size=(num_zeros,))
1647        seq_lens.scatter_(0, indices, 0)
1648
1649        shape = SdpaShape(batch, num_heads, seq_lens.tolist(), head_dim)
1650        query = rand_nested_tensor(shape)
1651        key = rand_nested_tensor(shape)
1652        value = rand_nested_tensor(shape)
1653
1654        query = query.transpose(1, 2)
1655        key = key.transpose(1, 2)
1656        value = value.transpose(1, 2)
1657
1658        with sdpa_kernel(backends=[fused_kernel]):
1659            with self.assertRaisesRegex(RuntimeError, "No available kernel"):
1660                torch.nn.functional.scaled_dot_product_attention(
1661                    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
1662
1663    @onlyCUDA
1664    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Fused SDPA was not built for this system")
1665    def test_fused_kernels_nested_broadcasting_requires_grad_failure(self, device):
1666        rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float16, requires_grad=True)
1667        batch, num_heads, head_dim, head_dim_v = 32, 16, 64, 64
1668        seq_lens = torch.randint(low=1, high=32, size=(batch,)).tolist()
1669        q_shape = SdpaShape(1, num_heads, 1, head_dim)
1670        k_shape = SdpaShape(batch, num_heads, seq_lens, head_dim)
1671        v_shape = SdpaShape(batch, 1, seq_lens, head_dim_v)
1672
1673        # create a dense query
1674        query = torch.randn(q_shape, device=device, dtype=torch.float16, requires_grad=True)
1675        key = rand_nested_tensor(k_shape)
1676        value = rand_nested_tensor(v_shape)
1677
1678        query = query.transpose(1, 2)
1679        key = key.transpose(1, 2)
1680        value = value.transpose(1, 2)
1681
1682        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
1683            with self.assertWarnsRegex(UserWarning, "Both fused kernels do not support training with broadcasted NT inputs"):
1684                with self.assertRaisesRegex(RuntimeError, "No available kernel"):
1685                    out = torch.nn.functional.scaled_dot_product_attention(
1686                        query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
1687
1688    @onlyCUDA
1689    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention")
1690    def test_flash_attention_fail_with_non_square_causal_attention(self, device):
1691        dtype = torch.bfloat16
1692        q_shape = SdpaShape(1, 1, 8, 16)
1693        kv_shape = SdpaShape(1, 1, 12, 16)
1694        make_q = partial(torch.rand, q_shape, device=device, dtype=dtype)
1695        make_kv = partial(torch.rand, kv_shape, device=device, dtype=dtype)
1696        q, k, v = make_q(), make_kv(), make_kv()
1697        warning_str = "Flash attention does not support the is_causal flag when seqlen_q != seqlen_k."
1698        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
1699            with self.assertWarnsRegex(UserWarning, warning_str):
1700                self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1701                    q, k, v, None, 0.0, is_causal=True))
1702
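# Illustrative sketch (not part of the original suite): for non-square causal attention
# (seqlen_q != seqlen_k), an explicit causal bias such as causal_lower_right can be passed
# as attn_mask instead of is_causal=True. Shapes below are arbitrary example values.
def _example_non_square_causal_bias(device="cpu"):
    q = torch.rand(1, 1, 8, 16, device=device)
    k = torch.rand(1, 1, 12, 16, device=device)
    v = torch.rand(1, 1, 12, 16, device=device)
    # Aligns the causal diagonal to the bottom-right corner of the score matrix.
    bias = causal_lower_right(8, 12)
    return scaled_dot_product_attention(q, k, v, attn_mask=bias)
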
1703def _get_block_size_n(device, head_dim, is_dropout, is_causal):
1704    # This should match the block sizes in the CUDA kernel
1705    assert head_dim <= 256
1706    major, minor = torch.cuda.get_device_capability(device)
1707    is_sm8x = major == 8 and minor > 0  # sm8x other than sm80 (A100), e.g. sm86/sm89
1708    is_sm80 = major == 8 and minor == 0
1709    is_sm90 = major == 9 and minor == 0
1710    if head_dim <= 32:
1711        return 128
1712    if head_dim <= 64:
1713        return 128 if not is_dropout else 64
1714    elif head_dim <= 96:
1715        return 64
1716    elif head_dim <= 128:
1717        if is_sm8x:
1718            return 64 if (not is_dropout and is_causal) else 32
1719        else:
1720            return 64 if not is_dropout else 32
1721    elif head_dim <= 160:
1722        if is_sm8x:
1723            return 64
1724        else:
1725            return 32
1726    elif head_dim <= 192:
1727        return 64
1728    elif head_dim <= 224:
1729        return 64
1730    elif head_dim <= 256:
1731        return 64
1732
1733
1734def pad_last_dim(input_tensor, alignment_size, slice: bool = False):
1735    last_dim_size = input_tensor.size(-1)
1736    if (last_dim_size % alignment_size == 0):
1737        return input_tensor, last_dim_size
1738    pad_count = alignment_size - (last_dim_size % alignment_size)
1739    padded_tensor = F.pad(input_tensor, (0, pad_count))
1740    if slice:
1741        return padded_tensor[..., :last_dim_size], last_dim_size
1742    return padded_tensor, last_dim_size
1743
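# Illustrative sketch (not part of the original suite): pad_last_dim pads the head dimension
# up to a multiple of alignment_size (optionally slicing back to the original size), which is
# how unaligned head dims can be fed to kernels that require alignment. Values below are
# arbitrary examples.
def _example_pad_last_dim():
    x = torch.randn(2, 2, 8, 5)                 # head_dim 5 is not a multiple of 8
    padded, orig_dim = pad_last_dim(x, 8)
    assert padded.size(-1) == 8 and orig_dim == 5
    sliced, _ = pad_last_dim(x, 8, slice=True)  # pad, then slice back to head_dim 5
    assert sliced.size(-1) == 5
    return padded, sliced
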
1744
1745class TestSDPA(NNTestCase):
1746    """ Used to test generic functionality of scaled_dot_product_attention
1747    Summary:
1748        If you are adding a new test to this class, make sure that it runs
1749        for both cpu and cuda. If your test is only applicable to cuda,
1750        add it to TestSDPACudaOnly.
1751    """
1752    @parametrize("contiguous_inputs", [True, False])
1753    def test_sdp_math_gradcheck(self, device, contiguous_inputs: bool):
1754
1755        batch_size, seq_len, num_heads, head_dim = 4, 4, 2, 16
1756        shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
1757        make_tensor = partial(rand_sdpa_tensor, type="dense", device=device,
1758                              dtype=torch.float64, requires_grad=True, packed=True)
1759
1760        qkv = make_tensor(shape)
1761        query, key, value = qkv.chunk(3, dim=-1)
1762
1763        query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
1764        key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
1765        value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
1766
1767        if contiguous_inputs:
1768            query = query.contiguous()
1769            key = key.contiguous()
1770            value = value.contiguous()
1771
1772        with sdpa_kernel(backends=[SDPBackend.MATH]):
1773            assert gradcheck(lambda *args, **kwargs:
1774                             wrapper_set_seed(torch.nn.functional.scaled_dot_product_attention, *args, **kwargs),
1775                             (query, key, value, None, 0.0, False)
1776                             )
1777
1778    @onlyCPU
1779    @parametrize("type", ["dense", "nested"])
1780    @parametrize("dropout", [0.0, 0.7])
1781    @parametrize("dtype", [torch.float64, torch.float32, torch.bfloat16, torch.half])
1782    def test_fused_sdp_choice_cpu(self, device, type: str, dropout: float, dtype: torch.dtype):
1783        # On CPU, nested tensors, dropout, or unsupported dtypes fall back to the MATH backend; otherwise flash attention is chosen
1784        make_tensor = partial(rand_sdpa_tensor, type=type, device=device, dtype=dtype)
1785        size = SdpaShape(2, 8, 128, 64)
1786        q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
1787        if type == "nested" \
1788                or dropout > 0.0 \
1789                or dtype not in [torch.float32, torch.float64, torch.bfloat16, torch.float16]:
1790            assert torch._fused_sdp_choice(q, k, v, dropout_p=dropout) == SDPBackend.MATH.value
1791        else:
1792            assert torch._fused_sdp_choice(q, k, v, dropout_p=dropout) == SDPBackend.FLASH_ATTENTION.value
1793
1794    @onlyCPU
1795    @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION])
1796    @parametrize("dtype", [torch.float64, torch.float32, torch.bfloat16, torch.float16])
1797    @parametrize("batch_size", [2, 12])
1798    @parametrize("seq_len", [267, 1030])
1799    @parametrize("n_head", [1, 3])
1800    @parametrize("head_dim", [8, 16])
1801    @parametrize("causal", [True, False])
1802    @parametrize("train", [True, False])
1803    def test_scaled_dot_product_fused_attention_vs_math_cpu(
1804        self,
1805        device,
1806        fused_kernel,
1807        dtype,
1808        batch_size,
1809        seq_len,
1810        n_head,
1811        head_dim,
1812        causal,
1813        train,
1814    ):
1815        atol = 1e-5
1816        rtol = 5e-6
1817        if dtype is torch.bfloat16:
1818            atol = 5e-2
1819            rtol = 5e-2
1820        if dtype is torch.float16:
1821            atol = 1e-2
1822            rtol = 1e-2
1823
1824        n_embd = n_head * head_dim
1825        make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=dtype, packed=True, requires_grad=False)
1826        shape = SdpaShape(batch_size, n_head, seq_len, head_dim)
1827        x = make_tensor(shape)
1828        x2 = x.clone()
1829
1830        if train:
1831            x.requires_grad_(True)
1832            x2.requires_grad_(True)
1833
1834        q, k, v = x.split(n_embd, dim=2)
1835        q2, k2, v2 = x2.split(n_embd, dim=2)
1836
1837        if dtype in [torch.bfloat16, torch.float16]:
1838            q2 = q2.float()
1839            k2 = k2.float()
1840            v2 = v2.float()
1841
1842        # (B, nh, T, hs)
1843        k = k.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2)
1844        q = q.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2)
1845        v = v.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2)
1846        k2 = k2.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2)
1847        q2 = q2.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2)
1848        v2 = v2.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2)
1849
1850        with sdpa_kernel(backends=[fused_kernel]):
1851            actual = torch.nn.functional.scaled_dot_product_attention(
1852                q, k, v, attn_mask=None, dropout_p=0.0, is_causal=causal)
1853        with sdpa_kernel(backends=[SDPBackend.MATH]):
1854            math_ref = torch.nn.functional.scaled_dot_product_attention(
1855                q2, k2, v2, attn_mask=None, dropout_p=0.0, is_causal=causal)
1856
1857        if dtype in [torch.bfloat16, torch.float16]:
1858            math_ref = math_ref.to(dtype)
1859
1860        self.assertEqual(actual, math_ref, atol=atol, rtol=rtol)
1861
1862        if train:
1863            actual.sum().backward()
1864            math_ref.sum().backward()
1865
1866            grad_x, grad_x2 = x.grad, x2.grad
1867            grad_q_actual, grad_k_actual, grad_v_actual = grad_x.split(n_embd, dim=2)
1868            grad_q_ref, grad_k_ref, grad_v_ref = grad_x2.split(n_embd, dim=2)
1869
1870            self.assertEqual(grad_q_actual, grad_q_ref, atol=atol, rtol=rtol)
1871            self.assertEqual(grad_k_actual, grad_k_ref, atol=atol, rtol=rtol)
1872            self.assertEqual(grad_v_actual, grad_v_ref, atol=atol, rtol=rtol)
1873
1874    @onlyCPU
1875    @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION])
1876    @parametrize("dtype", [torch.float64, torch.float32, torch.bfloat16, torch.float16])
1877    @parametrize("batch_size", [2, 12])
1878    @parametrize("q_seq_len", [267, 1030])
1879    @parametrize("kv_seq_len", [514, 1179])
1880    @parametrize("n_head", [1, 3])
1881    @parametrize("head_dim", [8, 16])
1882    @parametrize("mask_dim", [2, 4])
1883    @parametrize("bool_mask", [0, 1])
1884    @parametrize("train", [True, False])
1885    def test_scaled_dot_product_fused_attention_mask_vs_math_cpu(
1886        self,
1887        device,
1888        fused_kernel,
1889        dtype,
1890        batch_size,
1891        q_seq_len,
1892        kv_seq_len,
1893        n_head,
1894        head_dim,
1895        mask_dim,
1896        bool_mask,
1897        train,
1898    ):
1899        tol = Tolerances(1e-5, 5e-6)
1900        if dtype is torch.bfloat16:
1901            tol = Tolerances(5e-2, 5e-2)
1902        if dtype is torch.float16:
1903            tol = Tolerances(1e-2, 1e-2)
1904
1905        make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=dtype, requires_grad=False)
1906        q_shape = SdpaShape(batch_size, n_head, q_seq_len, head_dim)
1907        kv_shape = SdpaShape(batch_size, n_head, kv_seq_len, head_dim)
1908        q = make_tensor(q_shape)
1909        k = make_tensor(kv_shape)
1910        v = make_tensor(kv_shape)
1911        q2, k2, v2 = q.clone(), k.clone(), v.clone()
1912
1913        if train:
1914            q.requires_grad_(True)
1915            k.requires_grad_(True)
1916            v.requires_grad_(True)
1917            q2.requires_grad_(True)
1918            k2.requires_grad_(True)
1919            v2.requires_grad_(True)
1920
1921        if dtype in [torch.bfloat16, torch.float16]:
1922            q2, k2, v2 = q2.float(), k2.float(), v2.float()
1923        # (B, nh, T, hs)
1924        q = q.view(batch_size, q_seq_len, n_head, head_dim).transpose(1, 2)
1925        k = k.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
1926        v = v.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
1927        if mask_dim == 4:
1928            mask_shape = (batch_size, n_head, q_seq_len, kv_seq_len)
1929        else:
1930            mask_shape = (q_seq_len, kv_seq_len)
1931        if bool_mask:
1932            attn_mask = torch.randint(0, 2, size=mask_shape, dtype=torch.bool, device=device)
1933        else:
1934            attn_mask = torch.randn(mask_shape, dtype=dtype, device=device)
1935        q2 = q2.view(batch_size, q_seq_len, n_head, head_dim).transpose(1, 2)
1936        k2 = k2.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
1937        v2 = v2.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
1938
1939        with sdpa_kernel(backends=[fused_kernel]):
1940            actual = torch.nn.functional.scaled_dot_product_attention(
1941                q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
1942        with sdpa_kernel(backends=[SDPBackend.MATH]):
1943            if not bool_mask and dtype in [torch.bfloat16, torch.float16]:
1944                attn_mask = attn_mask.float()
1945            math_ref = torch.nn.functional.scaled_dot_product_attention(
1946                q2, k2, v2, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
1947
1948        if dtype in [torch.bfloat16, torch.float16]:
1949            math_ref = math_ref.to(dtype)
1950
1951        self.assertEqual(actual, math_ref, atol=tol.atol, rtol=tol.rtol)
1952
1953        if train:
1954            actual.sum().backward()
1955            math_ref.sum().backward()
1956
1957            grad_q_actual, grad_k_actual, grad_v_actual = q.grad, k.grad, v.grad
1958            grad_q_ref, grad_k_ref, grad_v_ref = q2.grad, k2.grad, v2.grad
1959
1960            self.assertEqual(grad_q_actual, grad_q_ref, atol=tol.atol, rtol=tol.rtol)
1961            self.assertEqual(grad_k_actual, grad_k_ref, atol=tol.atol, rtol=tol.rtol)
1962            self.assertEqual(grad_v_actual, grad_v_ref, atol=tol.atol, rtol=tol.rtol)
1963
1964    @onlyCPU
1965    def test_scaled_dot_product_fused_attention_with_inf(self, device):
1966        # https://github.com/pytorch/pytorch/issues/127055.
1967        full = torch.full((600, 600), float("-inf"), device=device)
1968        mask = torch.triu(full, diagonal=1) + torch.tril(full, diagonal=-10)
1969        make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=torch.float32, requires_grad=False)
1970        input_shape = SdpaShape(1, 600, 2, 8)
1971        q = make_tensor(input_shape)
1972        k = make_tensor(input_shape)
1973        v = make_tensor(input_shape)
1974        with sdpa_kernel(backends=[SDPBackend.MATH]):
1975            math_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
1976        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
1977            actual = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
1978        self.assertEqual(math_ref, actual)
1979
1980    @parametrize("kernel", [SDPBackend.MATH])
1981    def test_scaled_dot_product_attention_math_with_negative_scale(self, device, kernel: SDPBackend):
1982        # https://github.com/pytorch/pytorch/issues/105190.
1983        def ref(x):
1984            v1 = torch.matmul(x, x.transpose(-1, -2))
1985            v2 = v1 / -0.0001
1986            v3 = v2.softmax(dim=-1)
1987            v4 = torch.matmul(v3, x)
1988            return v4
1989
1990        x = torch.randn(1, 3, 64, 64, device=device)
1991        ref_result = ref(x)
1992        with sdpa_kernel(backends=[kernel]):
1993            sdp_math = torch.nn.functional.scaled_dot_product_attention(x, x, x, scale=-1.0 / 0.0001)
1994        self.assertEqual(ref_result, sdp_math)
1995
1996class TestSDPACudaOnly(NNTestCase):
1997    """ Used to test CUDA only functionality of scaled_dot_product_attention
1998    Quirks:
1999        There is some trickiness with this function. Its runtime behavior
2000        is dependent on the CUDA architecture you are testing it on. See
2001        `PLATFORM_SUPPORTS_FUSED_ATTENTION` at the top of the file.
2002        Summary:
2003            Math: always supported
2004            FlashAttention: Supported on sm80 or newer hardware
2005            MemEfficientAttention: Supported on sm50 or newer hardware
2006    """
2007    _do_cuda_memory_leak_check = True
2008    _do_cuda_non_default_stream = True
2009
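    # Illustrative sketch (not part of the original suite): the backend that
    # scaled_dot_product_attention would dispatch to for a given set of inputs can be
    # inspected with torch._fused_sdp_choice, mirroring the per-arch support summary in the
    # class docstring. Shapes are arbitrary example values.
    def _sketch_query_sdp_choice(self, device):
        shape = SdpaShape(2, 8, 64, 64)
        q = torch.randn(shape, device=device, dtype=torch.float16)
        k = torch.randn(shape, device=device, dtype=torch.float16)
        v = torch.randn(shape, device=device, dtype=torch.float16)
        # Integer value of the chosen backend, e.g. SDPBackend.FLASH_ATTENTION.value.
        return torch._fused_sdp_choice(q, k, v)
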
2010    # TODO: used for testing the scores, e.g. testing ALiBi; we don't need this now
2011    def normalize_flash_attn_S(
2012        self,
2013        attn_unnorm,
2014        q,
2015        k,
2016        v,
2017        query_padding_mask=None,
2018        key_padding_mask=None,
2019        attn_bias=None,
2020        is_dropout=False,
2021        causal=False,
2022        window_size=(-1, -1),  # -1 means infinite window size
2023        scale=None,
2024    ):
2025        """
2026        Arguments:
2027            q: (batch_size, seqlen_q, nheads, head_dim)
2028            k, v: (batch_size, seqlen_k, nheads, head_dim)
2029            key_padding_mask: (batch_size, seqlen_q)
2030            attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
2031        Output:
2032            softmax_lse: (batch_size, nheads, seqlen_q)
2033            softmax_max: (batch_size, nheads, seqlen_q)
2034        """
2035        q = q.transpose(1, 2)
2036        k = k.transpose(1, 2)
2037        v = v.transpose(1, 2)
2038        if causal:
2039            window_size = (window_size[0], 0)
2040        q, k, v = q.float(), k.float(), v.float()
2041        _, seqlen_q, _, head_dim = q.shape
2042        seqlen_k = k.shape[1]
2043        b = q.shape[0]
2044        from torch.nn.attention.bias import _calculate_scale
2045        scale = _calculate_scale(head_dim, scale)
2046        scores = torch.matmul(q.transpose(1, 2) * scale, k.permute(0, 2, 3, 1))
2047        if key_padding_mask is not None:
2048            scores.masked_fill_(~key_padding_mask.view(b, 1, 1, -1), float("-inf"))
2049        if window_size[0] >= 0 or window_size[1] >= 0:
2050            local_mask = self.construct_local_mask(
2051                seqlen_q,
2052                seqlen_k,
2053                window_size,
2054                query_padding_mask,
2055                key_padding_mask,
2056                q.device,
2057            )
2058            scores.masked_fill_(local_mask, float("-inf"))
2059        if attn_bias is not None:
2060            scores = scores + attn_bias.to(dtype=scores.dtype)
2061        block_size_n = _get_block_size_n(scores.device, head_dim, is_dropout, causal)
2062        scores_block = scores.split(block_size_n, dim=-1)
2063        lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1)
2064        lse = torch.logsumexp(lse_block, dim=-1)
2065        # lse could be -inf (i.e. all values in scores are -inf), and we want to set those to inf
2066        # so that when we do torch.exp(m - lse), we get 0.0 instead of NaN.
2067        lse[lse == float("-inf")] = float("inf")
2068        scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1)
2069        cummax_block = torch.cummax(scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1)
2070        attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1)
2071        attn_norm = torch.cat(
2072            [
2073                a * (torch.exp(m - lse)).unsqueeze(-1)
2074                for a, m in zip(attn_unnorm_block, cummax_block)
2075            ],
2076            dim=-1,
2077        )
2078        if query_padding_mask is not None:
2079            attn_norm.masked_fill_(~query_padding_mask.view(b, 1, -1, 1), 0.0)
2080            # attn_norm.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
2081        return attn_norm.to(dtype=attn_unnorm.dtype)
2082
2083    def construct_local_mask(self, seqlen_q, seqlen_k, window_size, query_padding_mask, key_padding_mask, device):
2084        # row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
2085        row_idx = torch.arange(seqlen_q, device=device, dtype=torch.long).view(-1, 1)
2086        col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
2087        sk = (
2088            seqlen_k
2089            if key_padding_mask is None
2090            else key_padding_mask.sum(-1).view(-1, 1, 1, 1)
2091            # else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
2092        )
2093        sq = (
2094            seqlen_q
2095            if query_padding_mask is None
2096            else query_padding_mask.sum(-1).view(-1, 1, 1, 1)
2097            # else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
2098        )
2099        if window_size[0] < 0:
2100            return col_idx > row_idx + sk - sq + window_size[1]
2101        else:
2102            sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
2103            return torch.logical_or(
2104                col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
2105                col_idx < row_idx + sk - sq - window_size[0],
2106            )
2107
2108    def convert_flash_attn_S_to_softmax(
2109        self,
2110        S,
2111        seqlen_q,
2112        seqlen_k,
2113        query_padding_mask,
2114        key_padding_mask,
2115        causal=False,
2116        window_size=(-1, -1),  # -1 means infinite window size
2117    ):
2118        """FlashAttention stores the S matrix in a different way.
2119        Arguments:
2120            S: (batch_size, nheads, seqlen_q, seqlen_k)
2121            query_padding_mask: (batch_size, seqlen_q)
2122            key_padding_mask: (batch_size, seqlen_k)
2123        """
2124        if TEST_WITH_ROCM:
2125            return S
2126        b = S.shape[0]
2127
2128        if causal:
2129            window_size = (window_size[0], 0)
2130        seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:]
2131        S_converted = S
2132        if window_size[0] >= 0 or window_size[1] >= 0:
2133            local_mask = self.construct_local_mask(
2134                seqlen_q,
2135                seqlen_k,
2136                window_size,
2137                query_padding_mask,
2138                key_padding_mask,
2139                S.device,
2140            )
2141            local_mask = F.pad(
2142                local_mask,
2143                (0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q),
2144                value=True,
2145            )
2146            S_converted = S_converted.masked_fill(local_mask, 0.0)
2147
2148        # Need to zero out things not in attention_mask in case S was initialized with random values
2149        # and some of those values aren't overwritten.
2150        seqlen_q_og = (
2151            query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded
2152        )
2153        if query_padding_mask is not None:
2154            query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q_rounded - seqlen_q_og))
2155            # S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
2156            S_converted = S_converted.masked_fill(~query_padding_mask.view(b, 1, -1, 1), 0.0)
2157        seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k
2158        if key_padding_mask is not None:
2159            key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k_rounded - seqlen_k_og))
2160            S_converted = S_converted.masked_fill(~key_padding_mask.view(b, 1, 1, -1), 0.0)
2161            # S_converted = S_converted.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0)
2162        S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q_rounded))
2163        S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded))
2164        return S_converted[:, :, :seqlen_q, :seqlen_k]
2165
2166    @skipIfRocm  # No cuDNN Attention
2167    @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
2168    def test_cudnn_attention_different_dk_dv(self, device):
2169        dtype = torch.bfloat16
2170        make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True)
2171        batch, num_heads, head_dim_k, head_dim_v = 32, 16, 128, 64
2172        seq_len = 640
2173        q_shape = SdpaShape(batch, num_heads, seq_len, head_dim_k)
2174        k_shape = SdpaShape(batch, num_heads, seq_len, head_dim_k)
2175        v_shape = SdpaShape(batch, num_heads, seq_len, head_dim_v)
2176        query, key, value = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape)
2177
2178        with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]):
2179            actual = torch.nn.functional.scaled_dot_product_attention(
2180                query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
2181        with sdpa_kernel(backends=[SDPBackend.MATH]):
2182            math_ref = torch.nn.functional.scaled_dot_product_attention(
2183                query.contiguous().to(torch.float32),
2184                key.contiguous().to(torch.float32),
2185                value.contiguous().to(torch.float32),
2186                attn_mask=None, dropout_p=0.0, is_causal=False)
2187
2188        self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)
2189
2190
2191    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
2192    @parametrize("mask_dim", [1, 2, 3, 4])
2193    def test_mem_efficient_attention_mask_variants(self, device, mask_dim: int):
2194        dtype = torch.float16
2195        make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True)
2196        batch, num_heads, head_dim = 8, 8, 64
2197        seq_len_q, seq_len_kv = 64, 32
2198        query = make_tensor(SdpaShape(batch, num_heads, seq_len_q, head_dim))
2199        kv_shape = SdpaShape(batch, num_heads, seq_len_kv, head_dim)
2200        key, value = make_tensor(kv_shape), make_tensor(kv_shape)
2201
2202        if mask_dim == 1:
2203            mask = torch.randn((seq_len_kv,), device=device, dtype=dtype)
2204        elif mask_dim == 2:
2205            mask = torch.randn((seq_len_q, seq_len_kv), device=device, dtype=dtype)
2206        elif mask_dim == 3:
2207            mask = torch.randn((num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype)
2208        elif mask_dim == 4:
2209            mask = torch.randn((batch, num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype)
2210        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
2211            out = F.scaled_dot_product_attention(query, key, value, mask)
2212        out.sum().backward()
2213
2214    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
2215    @parametrize("dtype", [torch.float, torch.float16])
2216    def test_mem_eff_attention_pad_mask(self, device, dtype):
2217        make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True)
2218        batch, num_heads, head_dim = 8, 8, 64
2219        seq_len_q, seq_len_kv = 64, 15
2220        query = make_tensor(SdpaShape(batch, num_heads, seq_len_q, head_dim))
2221        kv_shape = SdpaShape(batch, num_heads, seq_len_kv, head_dim)
2222        key, value = make_tensor(kv_shape), make_tensor(kv_shape)
2223        mask = torch.randn((batch, num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype)
2224        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
2225            out = F.scaled_dot_product_attention(query, key, value, mask)
2226        out.sum().backward()
2227
2228    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
2229    @parametrize("dtype", [torch.float, torch.float16])
2230    def test_mem_eff_attention_non_contiguous_mask(self, device, dtype):
2231        make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True)
2232        batch, num_heads, head_dim = 8, 8, 64
2233        seq_len_q, seq_len_kv = 64, 16
2234        query = make_tensor(SdpaShape(batch, num_heads, seq_len_q, head_dim))
2235        kv_shape = SdpaShape(batch, num_heads, seq_len_kv, head_dim)
2236        key, value = make_tensor(kv_shape), make_tensor(kv_shape)
2237        mask = torch.randn((batch, num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype)
2238        mask = torch.as_strided(mask, (batch, num_heads, seq_len_q, seq_len_kv), (0, 0, 0, 1))
2239        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
2240            out = F.scaled_dot_product_attention(query, key, value, mask)
2241        out.sum().backward()
2242
2243    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
2244    @parametrize("dtype", [torch.float, torch.float16])
2245    def test_mem_eff_attention_long_sequence_mask(self, device, dtype):
2246        if torch.cuda.get_device_properties('cuda').total_memory < 80 * 2**30:
2247            self.skipTest("This test requires substantial GPU memory.")
2249        make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True)
2250        batch, num_heads, head_dim = 1, 32, 64
2251        seq_len_q, seq_len_kv = 8192, 8192
2252        query = make_tensor(SdpaShape(batch, num_heads, seq_len_q, head_dim))
2253        kv_shape = SdpaShape(batch, num_heads, seq_len_kv, head_dim)
2254        key, value = make_tensor(kv_shape), make_tensor(kv_shape)
2255        mask = torch.randn((batch, num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype)
2256        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
2257            out = F.scaled_dot_product_attention(query, key, value, mask)
2258        out.sum().backward()
2259
2260    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
2261    def test_mem_eff_attention_non_contig_mask_bug(self, device):
2262        # Without the fix this produces `AssertionError: assert 0.07352933287620544 < 1e-07`
2263        # Shapes taken from repro
2264        query_size = (3, 16, 1, 128)
2265        query_strides = (2304, 128, 2048, 1)
2266        key_size = (3, 16, 14, 128)
2267        key_strides = (3584, 0, 256, 1)
2268        value_size = (3, 16, 14, 128)
2269        value_strides = (3584, 0, 256, 1)
2270        attention_mask_size = (3, 1, 1, 14)
2271        attn_mask_strides = (14, 14, 14, 1)
2272
2273        # Calculate the number of elements needed for each tensor
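        # max(size * stride) over-allocates slightly, but it is a sufficient backing size for these strided views.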
2274        query_num_elements = max(size * stride for size, stride in zip(query_size, query_strides))
2275        key_num_elements = max(size * stride for size, stride in zip(key_size, key_strides))
2276        value_num_elements = max(size * stride for size, stride in zip(value_size, value_strides))
2277        attention_mask_num_elements = max(size * stride for size, stride in zip(attention_mask_size, attn_mask_strides))
2278
2279        # Create the tensors with the specified sizes and strides
2280        query = torch.randn(query_num_elements, device=device).as_strided(query_size, query_strides)
2281        key = torch.randn(key_num_elements, device=device).as_strided(key_size, key_strides)
2282        value = torch.randn(value_num_elements, device=device).as_strided(value_size, value_strides)
2283        bias = torch.randn(attention_mask_num_elements, device=device).as_strided(attention_mask_size, attn_mask_strides)
2284
2285        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
2286            out = F.scaled_dot_product_attention(query, key, value, bias)
2287            out_contig = F.scaled_dot_product_attention(query, key, value, bias.contiguous())
2288
2289        mean_diff = (out - out_contig).abs().mean()
2290        self.assertTrue(mean_diff.item() < 1e-7)
2291
2292    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Fused SDPA was not built for this system")
2293    def test_singleton_head_dim_stride_ne_1(self, device):
2294        query = torch.tensor([[[[1, 2]]]], dtype=torch.float16, device=device)
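        # The transpose below leaves head_dim == 1 with a non-unit stride, which the flash kernel must still accept.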
2295        query = query.transpose(-1, -2)
2296        key = torch.tensor([[[[1]]]], dtype=torch.float16, device=device)
2297        value = torch.tensor([[[[1]]]], dtype=torch.float16, device=device)
2298
2299        with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
2300            scaled_dot_product_attention(query, key, value)
2301
2302    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
2303    @parametrize("type", ["dense", "nested"])
2304    @parametrize("is_contiguous", [True, False])
2305    def test_scaled_dot_product_attention_fused_kernels_packed(self, device, type: str, is_contiguous: bool):
2306        if TEST_WITH_ROCM and type == 'nested':
2307            self.skipTest("ROCM does not support efficient attention on nested tensors, for now")
2308        make_tensor = partial(rand_sdpa_tensor, type=type, device=device, dtype=torch.float16, packed=True)
2309
2310        batch_size, seq_len, num_heads, head_dim = 32, 64, 16, 64
2311        shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
2312
2313        # Test Packed
2314        qkv = make_tensor(shape)
2315        query, key, value = qkv.chunk(3, dim=-1)
2316
2317        query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2318        value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2319        key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2320
2321        if is_contiguous:
2322            query = query.contiguous()
2323            key = key.contiguous()
2324            value = value.contiguous()
2325
2326        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
2327            actual = torch.nn.functional.scaled_dot_product_attention(
2328                query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
2329        with sdpa_kernel(backends=[SDPBackend.MATH]):
2330            math_ref = torch.nn.functional.scaled_dot_product_attention(
2331                query.contiguous(), key.contiguous(), value.contiguous(),
2332                attn_mask=None, dropout_p=0.0, is_causal=False)
2333
2334        self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=2e-3, rtol=1e-2)
2335
2336    @skipIfRocm  # Missing nested and EFFICIENT_ATTENTION
2337    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
2338    @parametrize("type", ["dense", "nested"])
2339    @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
2340                 PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION])
2341    def test_scaled_dot_product_attention_fused_kernels_packed_accuracy(self, device, type: str, fused_kernel: SDPBackend):
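        # Both helpers return the same random data in float32 and float16, so the fused
        # low-precision result can be compared against float32 and float16 math references.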
2342        def rand_nt(shape):
2343            batch, seq_len, num_heads, head_dim = shape
2344            tensors = [6 * torch.rand((seq_len, 3 * num_heads * head_dim), device=device, dtype=torch.float32) - 3
2345                       for _ in range(batch)]
2346            return (torch.nested.nested_tensor(tensors, device=device, dtype=torch.float32),
2347                    torch.nested.nested_tensor(tensors, device=device, dtype=torch.float16))
2348
2349        def rand_tensor(shape):
2350            batch, seq_len, num_heads, head_dim = shape
2351            tensor = 6 * torch.rand((batch, seq_len, 3 * num_heads * head_dim), device=device, dtype=torch.float32) - 3
2352            return tensor, tensor.to(dtype=torch.float16)
2353
2354        batch_size, seq_len, num_heads, head_dim = 16, 8, 4, 64
2355        shape = (batch_size, seq_len, num_heads, head_dim)
2356
2357        # Test Packed
2358        qkv, qkv_low_precision = rand_tensor(shape) if type == "dense" else rand_nt(shape)
2359        query, key, value = qkv.chunk(3, dim=-1)
2360        query_lp, key_lp, value_lp = qkv_low_precision.chunk(3, dim=-1)
2361
2362        query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2363        key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2364        value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2365
2366        query_lp = query_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2367        key_lp = key_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2368        value_lp = value_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2369
2370        with sdpa_kernel(backends=[fused_kernel]):
2371            actual = torch.nn.functional.scaled_dot_product_attention(
2372                query_lp, key_lp, value_lp, attn_mask=None, dropout_p=0.0, is_causal=False)
2373
2374        with sdpa_kernel(backends=[SDPBackend.MATH]):
2375            math_ref_lp = torch.nn.functional.scaled_dot_product_attention(
2376                query_lp.contiguous(), key_lp.contiguous(), value_lp.contiguous(),
2377                attn_mask=None, dropout_p=0.0, is_causal=False)
2378
2379            math_query = query.contiguous()
2380            math_key = key.contiguous()
2381            math_value = value.contiguous()
2382
2383            math_ref = torch.nn.functional.scaled_dot_product_attention(
2384                math_query, math_key, math_value, attn_mask=None, dropout_p=0.0, is_causal=False)
2385
2386        actual_test = actual
2387        math_ref_test = math_ref
2388        math_ref_lp_test = math_ref_lp
2389
2390        if actual_test.is_nested:
2391            actual_test = torch.nested.to_padded_tensor(actual_test.contiguous(), padding=0.0)
2392            math_ref_test = torch.nested.to_padded_tensor(math_ref_test, padding=0.0)
2393            math_ref_lp_test = torch.nested.to_padded_tensor(math_ref_lp_test, padding=0.0)
2394
2395        actual_test = actual_test.to(dtype=torch.float32).contiguous()
2396        math_ref_test = math_ref_test.to(dtype=torch.float32).contiguous()
2397        math_ref_lp_test = math_ref_lp_test.to(dtype=torch.float32).contiguous()
2398
2399        self.assertEqual(math_ref_test, math_ref_lp_test, atol=7e-3, rtol=7e-3)
2400        self.assertEqual(actual_test, math_ref_test, atol=5e-3, rtol=5e-3)
2401
2402    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Efficient Attention was not built for this system")
2403    @parametrize("contiguous_inputs", [True, False])
2404    @parametrize("is_causal", [True, False])
2405    def test_sdp_mem_efficient_grad_against_math(self, device, contiguous_inputs: bool, is_causal: bool):
2406        batch_size, seq_len, num_heads, head_dim = 4, 4, 2, 16
2407        make_tensor = partial(rand_sdpa_tensor, type="dense", device=device,
2408                              dtype=torch.float64, requires_grad=True, packed=True)
2409
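        # qkv is float64 for the math reference; qkv_lp is a float32 clone that is run through the efficient-attention kernel.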
2410        qkv = make_tensor(SdpaShape(batch_size, num_heads, seq_len, head_dim))
2411        qkv_lp = qkv.detach().clone().to(torch.float32).requires_grad_()
2412
2413        query, key, value = qkv.chunk(3, dim=-1)
2414        query_lp, key_lp, value_lp = qkv_lp.chunk(3, dim=-1)
2415
2416        query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2417        key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2418        value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2419
2420        query_lp = query_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2421        key_lp = key_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2422        value_lp = value_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2423
2424        if contiguous_inputs:
2425            query = query.contiguous()
2426            key = key.contiguous()
2427            value = value.contiguous()
2428
2429            query_lp = query_lp.contiguous()
2430            key_lp = key_lp.contiguous()
2431            value_lp = value_lp.contiguous()
2432
2433        with sdpa_kernel(backends=[SDPBackend.MATH]):
2434            out = torch.nn.functional.scaled_dot_product_attention(query, key, value, None, 0.0, is_causal)
2435
2436        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
2437            out_lp = torch.nn.functional.scaled_dot_product_attention(
2438                query_lp, key_lp, value_lp, None, 0.0, is_causal)
2439
2440        rand_upward = torch.rand_like(out)
2441        rand_upward_lp = rand_upward.to(torch.float32)
2442
2443        out.backward(rand_upward)
2444        out_lp.backward(rand_upward_lp)
2445
2446        # Cast up and compare
2447        self.assertEqual(qkv.grad, qkv_lp.grad.to(torch.float64), atol=1e-5, rtol=1e-5)
2448
2449    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention was not built for this system")
2450    @parametrize("contiguous_inputs", [True, False])
2451    @parametrize("is_causal", [True, False])
2452    @parametrize("dtype", [torch.float16, torch.bfloat16])
2453    def test_sdp_flash_attention_grad_against_math(self, device, contiguous_inputs: bool, is_causal: bool, dtype: torch.dtype):
2454        batch_size, seq_len, num_heads, head_dim = 4, 4, 2, 16
2455        make_tensor = partial(rand_sdpa_tensor, type="dense", device=device,
2456                              dtype=torch.float64, requires_grad=True, packed=True)
2457
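        # qkv is float64 for the math reference; qkv_lp is cast to the low-precision dtype that is run through the flash kernel.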
2458        qkv = make_tensor(SdpaShape(batch_size, num_heads, seq_len, head_dim))
2459        qkv_lp = qkv.detach().clone().to(dtype).requires_grad_()
2460
2461        query, key, value = qkv.chunk(3, dim=-1)
2462        query_lp, key_lp, value_lp = qkv_lp.chunk(3, dim=-1)
2463
2464        query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2465        key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2466        value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2467
2468        query_lp = query_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2469        key_lp = key_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2470        value_lp = value_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2471
2472        if contiguous_inputs:
2473            query = query.contiguous()
2474            key = key.contiguous()
2475            value = value.contiguous()
2476
2477            query_lp = query_lp.contiguous()
2478            key_lp = key_lp.contiguous()
2479            value_lp = value_lp.contiguous()
2480
2481        with sdpa_kernel(backends=[SDPBackend.MATH]):
2482            out = torch.nn.functional.scaled_dot_product_attention(query, key, value, None, 0.0, is_causal)
2483
2484        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
2485            out_lp = torch.nn.functional.scaled_dot_product_attention(
2486                query_lp, key_lp, value_lp, None, 0.0, is_causal)
2487
2488        rand_upward = torch.rand_like(out)
2489        rand_upward_lp = rand_upward.to(dtype)
2490
2491        out.backward(rand_upward)
2492        out_lp.backward(rand_upward_lp)
2493
2494        # Cast up and compare
2495        # Since the compute is done in low precision we have to relax the tolerances
2496        # bfloat16 needs a looser tolerance than float16
2497        atol = 7e-4 if dtype == torch.float16 else 7e-3
2498        rtol = 7e-4 if dtype == torch.float16 else 7e-3
2499        if TEST_WITH_ROCM:
2500            atol = 9e-4 if dtype == torch.float16 else 9e-3
2501        self.assertEqual(qkv.grad, qkv_lp.grad.to(torch.float64), atol=atol, rtol=rtol)
2502
2503    @skipIfRocm  # Missing nested and EFFICIENT_ATTENTION
2504    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Platform does not support fused SDPA")
2505    @parametrize("type", ["dense", "nested"])
2506    def test_fused_sdp_choice(self, device, type: str):
2507        batch_size, seq_len, num_heads, head_dim = 2, 128, 8, 64
2508        shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
2509        make_tensor = partial(rand_sdpa_tensor, device=device, dtype=torch.float16, packed=True, requires_grad=True)
2510
2511        qkv = make_tensor(shape, type=type)
2512        query, key, value = qkv.chunk(3, dim=-1)
2513
2514        query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2515        value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2516        key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2517
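        # With float16 inputs the dispatcher should prefer flash attention when available, otherwise efficient attention.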
2518        if PLATFORM_SUPPORTS_FLASH_ATTENTION:
2519            assert torch._fused_sdp_choice(query, key, value) == SDPBackend.FLASH_ATTENTION.value
2520        else:
2521            assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION.value
2522
2523        # Change dtype to float32 so that efficient attention should get chosen
2524        make_tensor = partial(rand_sdpa_tensor, device=device, dtype=torch.float32, packed=True)
2525
2526        qkv = make_tensor(shape, type=type)
2527        query, key, value = qkv.chunk(3, dim=-1)
2528
2529        query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2530        value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2531        key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2532
2533        assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION.value
2534
2535    @skipIfRocm  # Missing triton.float32 ("triton" prefix is to locate skipped UTs), and deterministic algo
2536    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA")
2537    @parametrize("warn_only", [True, False])
2538    def test_sdp_choice_with_determinism(self, device, warn_only):
2539        batch_size, seq_len, num_heads, head_dim = 1, 64, 8, 64
2540        shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
2541        make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=torch.float32, packed=False)
2542        query, key, value = make_tensor(shape), make_tensor(shape), make_tensor(shape)
2543
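        # Kernel selection should still pick efficient attention even with deterministic algorithms enabled.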
2544        with use_deterministic_algorithims(True, warn_only=warn_only):
2545            with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
2546                assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION.value
2547
2548    @skipIfRocm  # Missing deterministic algo
2549    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
2550    @parametrize("fused_kernel", PLATFORM_SPECIFIC_SDPA)
2551    @parametrize("warn_only", [True, False])
2552    def test_fused_backwards_throws_determinism_warning(self, device, warn_only, fused_kernel):
2553        batch_size, seq_len, num_heads, head_dim = 1, 64, 8, 64
2554        shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
2555        make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=torch.float16, packed=False, requires_grad=True)
2556        query, key, value = make_tensor(shape), make_tensor(shape), make_tensor(shape)
2557
2558        kernel_name = "Memory Efficient attention" if fused_kernel == SDPBackend.EFFICIENT_ATTENTION else "Flash Attention"
2559        warning_context = (
2560            self.assertWarnsRegex(
2561                UserWarning,
2562                f"{kernel_name} defaults to a non-deterministic algorithm.",
2563            )
2564            if warn_only
2565            else contextlib.nullcontext()
2566        )
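        # warn_only=True should surface the non-determinism warning from the fused backward; otherwise no warning is asserted.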
2567        with use_deterministic_algorithims(True, warn_only=warn_only):
2568            with sdpa_kernel(backends=[fused_kernel]):
2569                with warning_context:
2570                    torch.nn.functional.scaled_dot_product_attention(query, key, value).sum().backward()
2571
2572    @unittest.skip("This test is not behaving deterministically / non-deterministically on CI/CD")
2573    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Platform does not support fused SDPA")
2574    def test_mem_eff_backwards_determinism(self, device):
2575        # Need big seq_len to ensure that num_splits > 1
2576        dtype = torch.float32
2577        batch_size, seq_len, n_heads, head_dim = 1, 1024, 8, 64
2578        query = torch.rand(batch_size, n_heads, seq_len, head_dim,
2579                           device=device, dtype=dtype, requires_grad=True)
2580        key = torch.rand(batch_size, n_heads, seq_len, head_dim, device=device,
2581                         dtype=dtype, requires_grad=True)
2582        value = torch.rand(batch_size, n_heads, seq_len, head_dim,
2583                           device=device, dtype=dtype, requires_grad=True)
2584
2585        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
2586            # Run once to establish baseline
2587            out = F.scaled_dot_product_attention(query, key, value)
2588            upward_grad = torch.rand_like(out)
2589            out.backward(upward_grad)
2590            initial_query_grad = query.grad
2591
2592            # Re-run the op with the same upward grad and check that the backward is
2593            # not deterministic
2594            diff_answer_once = False
2595            for _ in range(100):
2596                query.grad = None
2597                out = F.scaled_dot_product_attention(query, key, value)
2598                out.backward(upward_grad)
2599                if not torch.equal(initial_query_grad, query.grad):
2600                    diff_answer_once = True
2601                    break
2602            self.assertTrue(diff_answer_once)
2603
2604        with use_deterministic_algorithims(True, warn_only=False):
2605            query.grad = None
2606            out = F.scaled_dot_product_attention(query, key, value)
2607            upward_grad = torch.rand_like(out)
2608            out.backward(upward_grad)
2609            initial_query_grad = query.grad
2610
2611            # Re-run the op with the same upward grad and check that the backward is
2612            # deterministic now that we have enforced it
2613            diff_answer_once = False
2614            for _ in range(100):
2615                query.grad = None
2616                out = F.scaled_dot_product_attention(query, key, value)
2617                out.backward(upward_grad)
2618                if not torch.equal(initial_query_grad, query.grad):
2619                    diff_answer_once = True
2620                    break
2621            self.assertFalse(diff_answer_once)
2622
2623    # verified passing successfully on H100
2624    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA")
2625    @unittest.skipIf(IS_JETSON, "causing sigkill on Jetson")
2626    @parametrize("batch_size", [1, 8])
2627    @parametrize("seq_len_q", [4, 8, 64, 128, 256, 512, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80
2628                 else [4, 8, 64, 128, 256, 512])
2629    @parametrize("seq_len_k", [4, 8, 64, 128, 256, 512, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80
2630                 else [4, 8, 64, 128, 256, 512])
2631    @parametrize("head_dim", [8, 16, 32, 64, 72, 96, 128] if MEM_EFF_CAPABILITY_MATCHES_SM80
2632                 else [8, 16, 32, 64])
2633    @parametrize("is_causal", [False, True])
2634    @parametrize("dropout_p", [0.0, 0.22])
2635    @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if MEM_EFF_CAPABILITY_MATCHES_SM80
2636                 else [torch.float16, torch.float32])
2637    @parametrize("scale", [None, "l1"])
2638    def test_mem_efficient_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
2639                                                       head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
2640                                                       scale: str):
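        # Reconstructs the dropout mask the mem-efficient kernel produces for a given seed/offset,
        # so the math reference below can apply the identical dropout pattern.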
2641        def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, device=device):
2642            mask = torch.empty((batch_size, n_heads, q_len, kv_len), device=device, dtype=torch.float32)
2643            rand_uniform = torch._fill_mem_eff_dropout_mask_(mask, p, seed, offset)
2644            mask = (rand_uniform > p).to(torch.float32)
2645            return mask
2646        if max(seq_len_q, seq_len_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory < 40 * 2**30:
2647            self.skipTest("Reference implementation OOM")
2649        if TEST_WITH_ROCM and seq_len_q * seq_len_k * head_dim * batch_size > 1024 * 1024 * 128:
2650            torch.cuda.empty_cache()  # Prevent memory fragmentation
2651        seed = 42
2652        scale = scale if scale is None else (1 / head_dim)
2653        n_heads = 4
2654        query = torch.rand(batch_size, n_heads, seq_len_q, head_dim,
2655                           device=device, dtype=dtype, requires_grad=True)
2656        key = torch.rand(batch_size, n_heads, seq_len_k, head_dim, device=device,
2657                         dtype=dtype, requires_grad=True)
2658        value = torch.rand(batch_size, n_heads, seq_len_k, head_dim,
2659                           device=device, dtype=dtype, requires_grad=True)
2660
2661        # Run the math kernel on low precision references
2662        query_ref_lp, key_ref_lp, value_ref_lp = query_key_value_clones(query, key, value, dtype=dtype)
2663
2664        higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
2665        query_ref, key_ref, value_ref = query_key_value_clones(query, key, value, dtype=higher_precision_dtype)
2666
2667        # Create real output
2668        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
2669            # Set the seed and run the kernel
2670            torch.manual_seed(seed)
2671            out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
2672
2673        if dropout_p == 0.0:
2674            with sdpa_kernel(backends=[SDPBackend.MATH]):
2675                # High Precision Math Reference
2676                out_ref = F.scaled_dot_product_attention(query_ref, key_ref, value_ref,
2677                                                         dropout_p=dropout_p, is_causal=is_causal, scale=scale)
2678                # Low Precision Math Reference
2679                out_lp_ref = F.scaled_dot_product_attention(query_ref_lp, key_ref_lp, value_ref_lp,
2680                                                            dropout_p=dropout_p, is_causal=is_causal, scale=scale)
2681        else:
2682            if seq_len_q > 1024:
2683                self.skipTest("Will call _fill_mem_eff_dropout_mask with too many threads!")
2684            # Create the dropout_mask
2685            torch.manual_seed(seed)
2686            dropout_mask = _get_mem_eff_drop_mask(batch_size, n_heads, seq_len_q, seq_len_k, dropout_p, seed, 0, device=device)
2687            # High Precision Math Reference
2688            out_ref = torch.ops.aten._scaled_dot_product_attention_math(
2689                query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal, scale=scale, dropout_mask=dropout_mask)[0]
2690            # Low Precision Math Reference
2691            out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
2692                query_ref_lp, key_ref_lp, value_ref_lp, dropout_p=dropout_p, is_causal=is_causal, scale=scale,
2693                dropout_mask=dropout_mask)[0]
2694
2695        upstream_grad = torch.rand_like(out, requires_grad=False)
2696
2697        out.backward(upstream_grad)
2698        out_ref.backward(upstream_grad.to(out_ref.dtype))
2699        out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype))
2700
2701        # [Note] Fused Tolerances
2702        # Establish the numerical error between the "true" high precision math output
2703        # and the low precision math reference. We use this error for the atol,
2704        # and we use the default rtol for the low precision type.
2705        # We then apply a fudge factor to each gradient tolerance to account
2706        # for the use of the fused kernel rather than the eager implementation.
2707        output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref)
2708
2709        # Fudge Factor when dropout is enabled
2710        dropout_fudge_factor = 1.0 if dropout_p == 0.0 else 2.0
2711
2712        query_fudge_factor = dropout_fudge_factor
2713        grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor)
2714
2715        # TODO: Investigate why grad_k needs larger tolerances
2716        key_fudge_factor = 8 * dropout_fudge_factor
2717        grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad, key_fudge_factor)
2718
2719        value_fudge_factor = 7 if not SM80OrLater and dtype == torch.float16 else 1.0
2720        if TEST_WITH_ROCM:
2721            value_fudge_factor = max(2.0, value_fudge_factor)
2722        grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad, value_fudge_factor)
2723
2724        self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol)
2725        self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype),
2726                         atol=grad_q_ref_atol, rtol=grad_q_ref_rtol)
2727        self.assertEqual(key.grad, key_ref.grad.to(key.grad.dtype),
2728                         atol=grad_k_ref_atol, rtol=grad_k_ref_rtol)
2729        self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
2730                         atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)
2731
2732
2733    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA")
2734    @unittest.skipIf(IS_JETSON, "causing sigkill on Jetson")
2735    @parametrize("batch_size", [1, 8])
2736    @parametrize("seq_len_q", [4, 8, 64, 128, 256, 312, 512, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80
2737                 else [4, 8, 64, 128, 152, 256, 512])
2738    @parametrize("seq_len_k", [4, 8, 64, 65, 128, 256, 408, 512, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80
2739                 else [4, 8, 37, 64, 128, 256, 512])
2740    @parametrize("head_dim", [8, 16, 32, 64, 72, 96, 128] if MEM_EFF_CAPABILITY_MATCHES_SM80
2741                 else [8, 16, 32, 64])
2742    @parametrize("is_causal", [False])
2743    @parametrize("dropout_p", [0.0, 0.22])
2744    @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if MEM_EFF_CAPABILITY_MATCHES_SM80
2745                 else [torch.float16, torch.float32])
2746    @parametrize("scale", [None, "l1"])
2747    def test_mem_efficient_attention_attn_mask_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int,
2748                                                                 seq_len_k: int, head_dim: int, is_causal: bool,
2749                                                                 dropout_p: float, dtype: torch.dtype,
2750                                                                 scale: str):
2751        def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, device=device):
2752            mask = torch.empty((batch_size, n_heads, q_len, kv_len), device=device, dtype=torch.float32)
2753            rand_uniform = torch._fill_mem_eff_dropout_mask_(mask, p, seed, offset)
2754            mask = (rand_uniform > p).to(torch.float32)
2755            return mask
2756        if max(seq_len_q, seq_len_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory < 40 * 2**30:
2757            self.skipTest("Reference implementation OOM")
2759        if TEST_WITH_ROCM and dtype == torch.float32:
2760            self.skipTest("Skip fp32 attn_mask gradients on ROCM, for now.")
2762        if TEST_WITH_ROCM and seq_len_q * seq_len_k * head_dim * batch_size > 1024 * 1024 * 128:
2763            torch.cuda.empty_cache()  # Prevent memory fragmentation
2764        seed = 42
2765        scale = scale if scale is None else (1 / head_dim)
2766        n_heads = 4
2767        query = torch.rand(batch_size, n_heads, seq_len_q, head_dim,
2768                           device=device, dtype=dtype, requires_grad=True)
2769        key = torch.rand(batch_size, n_heads, seq_len_k, head_dim, device=device,
2770                         dtype=dtype, requires_grad=True)
2771        value = torch.rand(batch_size, n_heads, seq_len_k, head_dim,
2772                           device=device, dtype=dtype, requires_grad=True)
2773
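        # The attn_mask requires grad so its gradient from the fused kernel can also be checked against the math reference.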
2774        attn_mask = torch.rand(seq_len_q, seq_len_k, device=device, dtype=dtype, requires_grad=True)
2775
2776        # Run the math kernel on low precision references
2777        query_ref_lp, key_ref_lp, value_ref_lp = query_key_value_clones(query, key, value, dtype=dtype)
2778        attn_mask_ref_lp = attn_mask.detach().to(dtype).requires_grad_(True)
2779
2780        higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
2781        query_ref, key_ref, value_ref = query_key_value_clones(query, key, value, dtype=higher_precision_dtype)
2782        attn_mask_ref = attn_mask.detach().to(higher_precision_dtype).requires_grad_(True)
2783
2784        # Create real output
2785        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
2786            # Set the seed and run the kernel
2787            torch.manual_seed(seed)
2788            out = F.scaled_dot_product_attention(query, key, value, attn_mask, dropout_p=dropout_p,
2789                                                 is_causal=is_causal, scale=scale)
2790
2791        if dropout_p == 0.0:
2792            with sdpa_kernel(backends=[SDPBackend.MATH]):
2793                # High Precision Math Reference
2794                out_ref = F.scaled_dot_product_attention(query_ref, key_ref, value_ref, attn_mask_ref,
2795                                                         dropout_p=dropout_p, is_causal=is_causal, scale=scale)
2796                # Low Precision Math Reference
2797                out_lp_ref = F.scaled_dot_product_attention(query_ref_lp, key_ref_lp, value_ref_lp, attn_mask_ref_lp,
2798                                                            dropout_p=dropout_p, is_causal=is_causal, scale=scale)
2799        else:
2800            if seq_len_q > 1024:
2801                self.skipTest("Will call _fill_mem_eff_dropout_mask with too many threads!")
2802            # Create the dropout_mask
2803            torch.manual_seed(seed)
2804            dropout_mask = _get_mem_eff_drop_mask(batch_size, n_heads, seq_len_q,
2805                                                  seq_len_k, dropout_p, seed, 0, device=device)
2806            # High Precision Math Reference
2807            out_ref = torch.ops.aten._scaled_dot_product_attention_math(
2808                query_ref, key_ref, value_ref, attn_mask_ref, dropout_p=dropout_p, is_causal=is_causal,
2809                scale=scale, dropout_mask=dropout_mask)[0]
2810            # Low Precision Math Reference
2811            out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
2812                query_ref_lp, key_ref_lp, value_ref_lp, attn_mask_ref_lp,
2813                dropout_p=dropout_p, is_causal=is_causal, scale=scale,
2814                dropout_mask=dropout_mask)[0]
2815
2816        upstream_grad = torch.rand_like(out, requires_grad=False)
2817
2818        out.backward(upstream_grad)
2819        out_ref.backward(upstream_grad.to(out_ref.dtype))
2820        out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype))
2821
2822        # [Note] Fused Tolerances
2823        # Establish the numerical error between the "true" high precision math output
2824        # and the low precision math reference. We use this error for the atol,
2825        # and we use the default rtol for the low precision type.
2826        # We then apply a fudge factor to each gradient tolerance to account
2827        # for the use of the fused kernel rather than the eager implementation.
2828        output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref)
2829
2830        # Fudge Factor when dropout is enabled
2831        dropout_fudge_factor = 1.0 if dropout_p == 0.0 else 1.75
2832        mask_fudge_factor = 1.0 if attn_mask is None else 1.5
2833
2834        query_fudge_factor = dropout_fudge_factor
2835        grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor)
2836
2837        # TODO: Investigate why grad_k needs larger tolerances
2838        key_fudge_factor = 8 * dropout_fudge_factor * mask_fudge_factor
2839        grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad, key_fudge_factor)
2840
2841        value_fudge_factor = 7 if not SM80OrLater and dtype == torch.float16 else 1.0
2842        if TEST_WITH_ROCM:
2843            value_fudge_factor = max(2.0, value_fudge_factor)
2844        grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad, value_fudge_factor)
2845
2846        attn_mask_fudge_factor = 12 if attn_mask.numel() > 512 else 22
2847        grad_attn_mask_atol, grad_attn_mask_rtol = get_tolerances(
2848            attn_mask_ref.grad, attn_mask_ref_lp.grad, attn_mask_fudge_factor)
2849
2850        self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol)
2851        self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype),
2852                         atol=grad_q_ref_atol, rtol=grad_q_ref_rtol)
2853        self.assertEqual(key.grad, key_ref.grad.to(key.grad.dtype),
2854                         atol=grad_k_ref_atol, rtol=grad_k_ref_rtol)
2855        self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
2856                         atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)
2857
2858        self.assertEqual(attn_mask.grad, attn_mask_ref.grad.to(attn_mask.grad.dtype),
2859                         atol=grad_attn_mask_atol, rtol=grad_attn_mask_rtol)
2860
2861    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
2862    @unittest.skipIf(IS_JETSON, "causing sigkill on Jetson")
2863    @parametrize("batch_size", [1, 8])
2864    @parametrize("seq_len_q", [4, 8, 64, 143, 256, 512, 1024, 2048])
2865    @parametrize("seq_len_k", [4, 8, 64, 128, 256, 587, 1024, 2048])
2866    @parametrize("head_dim", [8, 16, 21, 32, 64, 72, 96, 128, 160, 192, 203, 256])
2867    @parametrize("is_causal", [True, False])
2868    @parametrize("dropout_p", [0.0, 0.22, 0.48])
2869    @parametrize("dtype", [torch.float16, torch.bfloat16])
2870    @parametrize("scale", [None, "l1"])
2871    def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
2872                                               head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
2873                                               scale: str):
2874        if isSM8XDevice and head_dim in range(193, 256 + 1):
2875            self.skipTest("Flash attention on sm86, sm87, and sm89 for headdim > 192 currently disabled")
2876        if is_causal and seq_len_q != seq_len_k:
2877            self.skipTest("Flash V2 does not accept is_causal when seq_len_q != seq_len_k")
2878        if TEST_WITH_ROCM and seq_len_q >= 1024 and seq_len_k >= 1024 and batch_size > 1:
2879            torch.cuda.empty_cache()  # Prevent memory fragmentation
2880
2881        scale = scale if scale is None else (1 / head_dim)
2882        n_heads = 4
2883        query = torch.rand(batch_size, n_heads, seq_len_q, head_dim,
2884                           device=device, dtype=dtype, requires_grad=True)
2885        key = torch.rand(batch_size, n_heads, seq_len_k, head_dim, device=device,
2886                         dtype=dtype, requires_grad=True)
2887        value = torch.rand(batch_size, n_heads, seq_len_k, head_dim,
2888                           device=device, dtype=dtype, requires_grad=True)
2889
2890        # Run the math kernel on low precision references
2891        query_ref_lp, key_ref_lp, value_ref_lp = query_key_value_clones(query, key, value, dtype=dtype)
2892
2893        higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
2894        query_ref, key_ref, value_ref = query_key_value_clones(query, key, value, dtype=higher_precision_dtype)
2895
2896        is_dropout = dropout_p > 0.0
2897
2898        if not is_dropout:
2899            with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
2900                out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
2901            with sdpa_kernel(backends=[SDPBackend.MATH]):
2902                # High Precision Math Reference
2903                out_ref = F.scaled_dot_product_attention(
2904                    query_ref, key_ref, value_ref, is_causal=is_causal, scale=scale)
2905                # Low Precision Math Reference
2906                out_lp_ref = F.scaled_dot_product_attention(
2907                    query_ref_lp, key_ref_lp, value_ref_lp, is_causal=is_causal, scale=scale)
2908        else:
2909            # Problem: We pad sizes in the composite region of the top level SDPA, but we need the
2910            # debug mask when we have dropout. So we manually pad up here when testing dropout.
2911            q_padded, q_og_size = pad_last_dim(query, 8)
2912            k_padded, k_og_size = pad_last_dim(key, 8)
2913            v_padded, v_og_size = pad_last_dim(value, 8)
2914            # scale needs to be calculated on the og head_size
2915            if scale is None:
2916                scale = 1 / math.sqrt(q_og_size)
2917            output_tuple = torch.ops.aten._scaled_dot_product_flash_attention(
2918                q_padded, k_padded, v_padded, dropout_p=dropout_p, is_causal=is_causal, scale=scale, return_debug_mask=is_dropout)
2919            out = output_tuple[0]
2920            out = out[..., :v_og_size]
2921            # Build dropout_mask
2922            debug_mask = output_tuple[-1]
2923            query_padding_mask = torch.ones(
2924                batch_size, seq_len_q, device=device, dtype=torch.bool)
2925            key_padding_mask = torch.ones(
2926                batch_size, seq_len_k, device=device, dtype=torch.bool)
2927
2928            softmax_mask = self.convert_flash_attn_S_to_softmax(
2929                debug_mask, seq_len_q, seq_len_k, query_padding_mask, key_padding_mask,
2930                causal=is_causal)[:, :, :seq_len_q, :seq_len_k]
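            # Dropped positions are encoded as negative entries in the debug mask, so >= 0 selects the kept positions.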
2931            dropout_mask = softmax_mask >= 0
2932            # attn_unnorm = softmax_mask.abs()
2933            # attn = self.normalize_flash_attn_S(attn_unnorm, q_padded,
2934            #                                    k_padded, v_padded, query_padding_mask,
2935            #                                    key_padding_mask, None, True, is_causal, scale=scale)
2936
2937            # High Precision Math Reference
2938            out_ref = torch.ops.aten._scaled_dot_product_attention_math(
2939                query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal, scale=scale, dropout_mask=dropout_mask)[0]
2940            # Low Precision Math Reference
2941            out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
2942                query_ref_lp, key_ref_lp, value_ref_lp, dropout_p=dropout_p, is_causal=is_causal, scale=scale,
2943                dropout_mask=dropout_mask)[0]
2944
2945        upstream_grad = torch.rand_like(out, requires_grad=False)
2946
2947        # backward for flash attention on sm86, sm87, and sm89 for headdim >= 193 currently disabled
2948        if isSM8XDevice and head_dim in range(193, 256):
2949            self.assertRaises(RuntimeError, lambda: out.backward(upstream_grad))
2950            return
2951        out.backward(upstream_grad)
2952        out_ref.backward(upstream_grad.to(out_ref.dtype))
2953        out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype))
2954
2955        # See [Note] Fused Tolerances above
2956        output_fudge_factor = 3 if head_dim % 8 != 0 or TEST_WITH_ROCM else 1
2957        output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref, output_fudge_factor)
2958
2959        # TODO: Investigate why grad_q needs larger tolerances
2960        query_fudge_factor = 4
2961        grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor)
2962
2963        key_fudge_factor = 2
2964        grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad, key_fudge_factor)
2965
2966        value_fudge_factor = 2
2967        grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad, value_fudge_factor)
2968
2969        self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol)
2970        self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype),
2971                         atol=grad_q_ref_atol, rtol=grad_q_ref_rtol)
2972        self.assertEqual(key.grad, key_ref.grad.to(key.grad.dtype),
2973                         atol=grad_k_ref_atol, rtol=grad_k_ref_rtol)
2974        self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
2975                         atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)
2976
2977    @skipIfRocm  # FIXME: "capturing stream has unjoined work"
2978    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
2979    @parametrize("batch_size", [1, 8])
2980    @parametrize("seq_len_q", [256, 512, 1024])
2981    @parametrize("seq_len_k", [256, 512, 1024])
2982    @parametrize("head_dim", [32, 64])
2983    @parametrize("is_causal", [True, False])
2984    @parametrize("dropout_p", [0.0, 0.22])
2985    @parametrize("dtype", [torch.float16,])
2986    @parametrize("scale", [None, "l1"])
2987    @parametrize("fused_kernel", PLATFORM_SPECIFIC_SDPA)
2988    def test_fused_attention_vs_math_ref_grads_cudagraph(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
2989                                                         head_dim: int,
2990                                                         is_causal: bool,
2991                                                         dropout_p: float,
2992                                                         dtype: torch.dtype,
2993                                                         scale: str,
2994                                                         fused_kernel: SDPBackend):
2995        def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, dropout_p, seed, offset, device=device):
2996            mask = torch.empty((batch_size, n_heads, q_len, kv_len), device=device, dtype=torch.float32)
2997            rand_uniform = torch._fill_mem_eff_dropout_mask_(mask, dropout_p, seed, offset)
2998            mask = (rand_uniform > dropout_p).to(torch.float32)
2999            return mask
3000
3001        def get_dropout_mask(output, fused_kernel, batch_size, n_heads, q_len, kv_len, dropout_p, device=device):
3002            if fused_kernel == SDPBackend.EFFICIENT_ATTENTION:
3003                output_seed, output_offset = output[2], output[3]
3004                output_seed = output_seed.item()
3005                output_offset = output_offset.item()
3006                return _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len,
3007                                              dropout_p, output_seed, output_offset, device=device)
3008            else:
3009                # Build dropout_mask
3010                debug_mask = output[-1]
3011                query_padding_mask = torch.ones(
3012                    batch_size, seq_len_q, device=device, dtype=torch.bool)
3013                key_padding_mask = torch.ones(
3014                    batch_size, seq_len_k, device=device, dtype=torch.bool)
3015
3016                softmax_mask = self.convert_flash_attn_S_to_softmax(
3017                    debug_mask, seq_len_q, seq_len_k, query_padding_mask, key_padding_mask,
3018                    causal=is_causal)[:, :, :seq_len_q, :seq_len_k]
3019                dropout_mask = softmax_mask >= 0
3020                return dropout_mask
3021
3022        if fused_kernel == SDPBackend.FLASH_ATTENTION and is_causal and seq_len_q != seq_len_k:
3023            self.skipTest("Flash V2 does not accept is_causal when seq_len_q != seq_len_k")
3024
3025        seed = 42
3026        scale = scale if scale is None else (1 / head_dim)
3027        n_heads = 4
3028        query = torch.rand(batch_size, n_heads, seq_len_q, head_dim,
3029                           device=device, dtype=dtype, requires_grad=True)
3030        key = torch.rand(batch_size, n_heads, seq_len_k, head_dim, device=device,
3031                         dtype=dtype, requires_grad=True)
3032        value = torch.rand(batch_size, n_heads, seq_len_k, head_dim,
3033                           device=device, dtype=dtype, requires_grad=True)
3034
3035        fused_op = (torch.ops.aten._scaled_dot_product_efficient_attention
3036                    if fused_kernel == SDPBackend.EFFICIENT_ATTENTION else torch.ops.aten._scaled_dot_product_flash_attention)
3037        # Run the math kernel on low precision references
3038        query_ref_lp, key_ref_lp, value_ref_lp = query_key_value_clones(query, key, value, dtype=dtype)
3039
3040        higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
3041        query_ref, key_ref, value_ref = query_key_value_clones(query, key, value, dtype=higher_precision_dtype)
3042
3043        # Warm up the kernels on a side stream before capturing the CUDA graph
3044        s = torch.cuda.Stream()
3045        s.wait_stream(torch.cuda.current_stream())
3046        # Set the global seed before capture
3047        torch.manual_seed(seed)
3048        kwargs = {"dropout_p": dropout_p, "is_causal": is_causal, "scale": scale}
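        # The two fused ops take different extra arguments: efficient attention needs
        # compute_log_sumexp/attn_bias, flash attention only requests the debug mask when dropout is active.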
3049        if fused_kernel == SDPBackend.EFFICIENT_ATTENTION:
3050            kwargs["compute_log_sumexp"] = True
3051            kwargs["attn_bias"] = None
3052        if fused_kernel == SDPBackend.FLASH_ATTENTION:
3053            kwargs['return_debug_mask'] = dropout_p > 0.0
3054        with torch.cuda.stream(s):
3055            # Create real output
3056            output_tuple = fused_op(query, key, value, **kwargs)
3057
3058        torch.cuda.current_stream().wait_stream(s)
3059        out = output_tuple[0]
3060        upstream_grad = torch.rand_like(out, requires_grad=False)
3061        s.wait_stream(torch.cuda.current_stream())
3062        with torch.cuda.stream(s):
3063            out.backward(upstream_grad)
3064        for x in (query, key, value):
3065            x.grad = None
3066        g = torch.cuda.CUDAGraph()
3067        # Capture the forward pass in a CUDA graph
3068        with torch.cuda.graph(g):
3069            tmp = torch.rand_like(query, device=query.device)  # test non-zero intragraph offset
3070            # Create real output
3071            output_tuple = fused_op(query, key, value, **kwargs)
3072            assert all(not isinstance(o, torch.Tensor) or o.is_cuda for o in output_tuple)
3073        g.replay()
3074        out_first = output_tuple[0].clone()
3075        g.replay()
3076        out = output_tuple[0]
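        # Without dropout, graph replays must be bit-identical; with dropout each replay should sample a new mask and differ.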
3077        if dropout_p == 0.0:
3078            self.assertEqual(out_first, out, atol=0, rtol=0)
3079        else:
3080            # replays produce different results
3081            self.assertNotEqual(out_first, out)
3082
3083        with sdpa_kernel(backends=[SDPBackend.MATH]):
3084            if dropout_p == 0.0:
3085                # High Precision Math Reference
3086                out_ref = F.scaled_dot_product_attention(query_ref, key_ref, value_ref,
3087                                                         dropout_p=dropout_p, is_causal=is_causal, scale=scale)
3088                # Low Precision Math Reference
3089                out_lp_ref = F.scaled_dot_product_attention(query_ref_lp, key_ref_lp, value_ref_lp,
3090                                                            dropout_p=dropout_p, is_causal=is_causal, scale=scale)
3091            else:
3092                # Create the dropout_mask
3093                dropout_mask = get_dropout_mask(output_tuple, fused_kernel, batch_size,
3094                                                n_heads, seq_len_q, seq_len_k, dropout_p, device)
3095                # High Precision Math Reference
3096                out_ref = torch.ops.aten._scaled_dot_product_attention_math(
3097                    query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal,
3098                    scale=scale, dropout_mask=dropout_mask)[0]
3099                # Low Precision Math Reference
3100                out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
3101                    query_ref_lp, key_ref_lp, value_ref_lp, dropout_p=dropout_p, is_causal=is_causal, scale=scale,
3102                    dropout_mask=dropout_mask)[0]
3103
3104
3105        g1 = torch.cuda.CUDAGraph()
3106        with torch.cuda.graph(g1):
3107            out.backward(upstream_grad)
3108        g1.replay()
3109        out_ref.backward(upstream_grad.to(out_ref.dtype))
3110        out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype))
3111
3112        # [Note] Fused Tolerances
3113        # Establish the numerical error between the "true" high precision math output
3114        # and the low precision math reference. We use this error for the atol,
3115        # and we use the default rtol for the low precision type.
3116        # We then apply a fudge factor to each gradient tolerance to account
3117        # for the use of the fused kernel rather than the eager implementation.
3118        output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref)
3119
3120        # Fudge Factor when dropout is enabled
3121        dropout_fudge_factor = 1.0 if dropout_p == 0.0 else 1.5
3122
3123        query_fudge_factor = dropout_fudge_factor
3124        grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor)
3125
3126        # TODO: Investigate why grad_k needs larger tolerances
3127        key_fudge_factor = 8 * dropout_fudge_factor
3128        grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad, key_fudge_factor)
3129
3130        value_fudge_factor = 7 if not SM80OrLater and dtype == torch.float16 else 1.0
3131        grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad, value_fudge_factor)
3132
3133        self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol)
3134        self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype),
3135                         atol=grad_q_ref_atol, rtol=grad_q_ref_rtol)
3136        self.assertEqual(key.grad, key_ref.grad.to(key.grad.dtype),
3137                         atol=grad_k_ref_atol, rtol=grad_k_ref_rtol)
3138        self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
3139                         atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)
3140
3141    @skipIfRocm  # Nested Tensor
3142    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
3143    @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
3144                 PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION])
3145    def test_fused_kernels_seq_len_1_inputs(self, device, fused_kernel):
3146        rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float16)
3147        batch, num_heads, head_dim = 32, 16, 64
3148        seq_lens = torch.randint(low=1, high=32, size=(batch,))
3149        # make sure some seq_lens are 1
3150        num_ones = 10
3151        indices = torch.randint(low=0, high=batch, size=(num_ones,))
3152        seq_lens.scatter_(0, indices, 1)
3153
3154        shape = SdpaShape(batch, num_heads, seq_lens.tolist(), head_dim)
3155        query = rand_nested_tensor(shape)
3156        key = rand_nested_tensor(shape)
3157        value = rand_nested_tensor(shape)
3158
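        # Move num_heads in front of the ragged sequence dim: the fused kernels expect (batch, num_heads, seq_len, head_dim).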
3159        query = query.transpose(1, 2)
3160        key = key.transpose(1, 2)
3161        value = value.transpose(1, 2)
3162
3163        with sdpa_kernel(backends=[fused_kernel]):
3164            actual = torch.nn.functional.scaled_dot_product_attention(
3165                query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
3166        with sdpa_kernel(backends=[SDPBackend.MATH]):
3167            math_ref = torch.nn.functional.scaled_dot_product_attention(
3168                query.contiguous().to(torch.float32),
3169                key.contiguous().to(torch.float32),
3170                value.contiguous().to(torch.float32),
3171                attn_mask=None, dropout_p=0.0, is_causal=False)
3172
3173        self.assertEqual(actual.contiguous(), math_ref.contiguous().to(torch.float16), atol=1e-3, rtol=1e-2)
3174
3175
3176    @skipIfRocm  # Nested tensor
3177    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
3178    @parametrize("kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
3179                 PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION])
3180    @parametrize("expand_q_batch", [True, False])
3181    @parametrize("expand_k_batch", [True, False])
3182    @parametrize("expand_v_batch", [True, False])
3183    @parametrize("expand_q_num_heads", [True, False])
3184    @parametrize("expand_k_num_heads", [True, False])
3185    @parametrize("expand_v_num_heads", [True, False])
3186    def test_fused_kernels_nested_broadcasting(
3187        self,
3188        device,
3189        kernel,
3190        expand_q_batch,
3191        expand_k_batch,
3192        expand_v_batch,
3193        expand_q_num_heads,
3194        expand_k_num_heads,
3195        expand_v_num_heads,
3196    ):
3197        is_efficient = kernel == SDPBackend.EFFICIENT_ATTENTION
3198        dtype = torch.float32 if is_efficient else torch.float16
3199        rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=dtype)
3200        batch, num_heads, head_dim = 32, 8, 64
3201        head_dim_v = 32 if is_efficient else head_dim
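        # A broadcast batch dim ends up with a single nested component, so one shared seq_len suffices;
        # otherwise draw ragged per-batch lengths.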
3202        seq_lens_q = (torch.randint(low=1, high=5, size=(1,)).item()
3203                      if expand_q_batch
3204                      else torch.randint(low=1, high=32, size=(batch,)).tolist())
3205        seq_lens_kv = (torch.randint(low=1, high=5, size=(1,)).item()
3206                       if (expand_k_batch or expand_v_batch)
3207                       else torch.randint(low=1, high=32, size=(batch,)).tolist())
3208
3209        batch_q = 1 if expand_q_batch else batch
3210        batch_k = 1 if expand_k_batch else batch
3211        batch_v = 1 if expand_v_batch else batch
3212
3213        # handle case where all batch_sizes are 1
3214        batch = max(batch_q, batch_k, batch_v)
3215
3216        num_heads_q = 1 if expand_q_num_heads else num_heads
3217        num_heads_k = 1 if expand_k_num_heads else num_heads
3218        num_heads_v = 1 if expand_v_num_heads else num_heads
3219
3220        # handle case where all num_heads are 1
3221        num_heads = max(num_heads_q, num_heads_k, num_heads_v)
3222
3223        q_shape = SdpaShape(batch_q, num_heads_q, seq_lens_q, head_dim)
3224        k_shape = SdpaShape(batch_k, num_heads_k, seq_lens_kv, head_dim)
3225        v_shape = SdpaShape(batch_v, num_heads_v, seq_lens_kv, head_dim_v)
3226
3227        query = rand_nested_tensor(q_shape)
3228        key = rand_nested_tensor(k_shape)
3229        value = rand_nested_tensor(v_shape)
3230
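        # _broadcast materializes the implicit broadcast in fp32 so the math reference runs on fully
        # expanded nested inputs matching what the fused kernel broadcasts internally.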
3231        def _broadcast(t, batch_broadcasted, num_heads_broadcasted):
3232            if batch_broadcasted and num_heads_broadcasted:
3233                # (1, seq_len, 1, head_dim) -> (batch, seq_len, num_heads, head_dim)
3234                result = torch.nested.nested_tensor(
3235                    [t[0].expand(-1, num_heads, t.size(-1)) for _ in range(batch)], dtype=torch.float32)
3236            elif batch_broadcasted:
3237                # (1, seq_len, num_heads, head_dim) -> (batch, seq_len, num_heads, head_dim)
3238                result = torch.nested.nested_tensor([t[0] for _ in range(batch)], dtype=torch.float32)
3239            elif num_heads_broadcasted:
3240                # (batch, seq_len, 1, head_dim) -> (batch, seq_len, num_heads, head_dim)
3241                result = torch.nested.nested_tensor([x.expand(-1, num_heads, t.size(-1))
3242                                                    for x in t.unbind()], dtype=torch.float32)
3243            else:
3244                result = t.to(torch.float32)
3245            return result
3246
3247        query_expanded = _broadcast(query, expand_q_batch, expand_q_num_heads).transpose(1, 2)
3248        key_expanded = _broadcast(key, expand_k_batch, expand_k_num_heads).transpose(1, 2)
3249        value_expanded = _broadcast(value, expand_v_batch, expand_v_num_heads).transpose(1, 2)
3250
3251        query = query.transpose(1, 2)
3252        key = key.transpose(1, 2)
3253        value = value.transpose(1, 2)
3254
3255        with sdpa_kernel(backends=[kernel]):
3256            actual = torch.nn.functional.scaled_dot_product_attention(
3257                query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
3258        with sdpa_kernel(backends=[SDPBackend.MATH]):
3259            math_ref = torch.nn.functional.scaled_dot_product_attention(
3260                query_expanded.contiguous(), key_expanded.contiguous(), value_expanded.contiguous(),
3261                attn_mask=None, dropout_p=0.0, is_causal=False)
3262
3263        self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)
3264
3265    @skipIfRocm  # Nested tensor
3266    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Memory-efficient attention was not built for this system")
3267    def test_fused_kernels_nested_broadcasting_query_dense(self, device):
3268        rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float32)
3269        batch, num_heads, head_dim, head_dim_v = 32, 16, 64, 96
3270        seq_lens = torch.randint(low=1, high=32, size=(batch,)).tolist()
3271        q_shape = (1, 1, num_heads, head_dim)
3272        k_shape = SdpaShape(batch, num_heads, seq_lens, head_dim)
3273        v_shape = SdpaShape(batch, 1, seq_lens, head_dim_v)
3274
3275        # create a dense query
3276        query = torch.randn(q_shape, device=device, dtype=torch.float32)
3277        key = rand_nested_tensor(k_shape)
3278        value = rand_nested_tensor(v_shape)
3279
3280        # (1, 1, num_heads, head_dim) -> (batch, 1, num_heads, head_dim)
3281        query_expanded = torch.nested.nested_tensor([query.squeeze(0) for _ in range(batch)]).transpose(1, 2)
3282        # (batch, seq_lens, 1, head_dim_v) -> (batch, seq_lens, num_heads, head_dim_v)
3283        value_expanded = torch.nested.nested_tensor(
3284            [t.expand(-1, num_heads, head_dim_v) for t in value.unbind()]).transpose(1, 2)
3285
3286        query = query.transpose(1, 2)
3287        key = key.transpose(1, 2)
3288        value = value.transpose(1, 2)
3289
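        # The efficient kernel broadcasts the dense query and single-head value on its own; the math
        # reference needs the explicitly expanded equivalents built above.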
3290        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
3291            actual = torch.nn.functional.scaled_dot_product_attention(
3292                query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
3293        with sdpa_kernel(backends=[SDPBackend.MATH]):
3294            math_ref = torch.nn.functional.scaled_dot_product_attention(
3295                query_expanded.contiguous(), key.contiguous(), value_expanded.contiguous(),
3296                attn_mask=None, dropout_p=0.0, is_causal=False)
3297
3298        self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=1e-3, rtol=1e-2)
3299
3300    @onlyCUDA
3301    @skipIfRocm  # Nested tensor
3302    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Platform does not support FlashAttention")
3303    @parametrize("batch_size", [8, 32])
3304    @parametrize("max_seq_len_q", [32, 256])
3305    @parametrize("max_seq_len_kv", [32, 256])
3306    @parametrize("head_dim", [8, 64])
3307    @parametrize("dropout_p", [0.0, 0.1])
3308    @parametrize("dtype", [torch.float16])
3309    @parametrize("scale", [None, "l1"])
3310    @parametrize("is_causal", [True, False])
3311    def test_flash_attention_vs_math_ref_grads_nestedtensor(self, device, batch_size: int, max_seq_len_q: int, max_seq_len_kv: int,
3312                                                            head_dim: int, dropout_p: float, dtype: torch.dtype,
3313                                                            scale: str, is_causal: bool):
3314        if is_causal:
3315            # TODO we should support this; until then, nested query/key with is_causal=True is
3316            # unsupported by the fused kernels, so skip instead of silently passing.
3317            self.skipTest("Nested tensors for query / key are not supported when is_causal=True")
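        # scale=None exercises SDPA's default 1/sqrt(head_dim); the "l1" variant overrides it with 1/head_dim.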
3318        scale = scale if scale is None else (1 / head_dim)
3319        n_heads = 4
3320        seq_lens_q = torch.randint(low=1, high=max_seq_len_q, size=(batch_size,))
3321        # Set one entry to max length
3322        seq_lens_q[torch.randint(0, batch_size, size=(1,))] = max_seq_len_q
3323        seq_lens_kv = torch.randint(low=1, high=max_seq_len_kv, size=(batch_size,))
3324        seq_lens_kv[torch.randint(0, batch_size, size=(1,))] = max_seq_len_kv
3325
3326        def rand_nt(sequence_list, num_heads, head_dim):
3327            tensors = [torch.rand((num_heads, seq_len, head_dim)) for seq_len in sequence_list]
3328            return torch.nested.nested_tensor(tensors, requires_grad=True, device=device, dtype=dtype)
3329
3330        query = rand_nt(seq_lens_q, n_heads, head_dim)
3331        key = rand_nt(seq_lens_kv, n_heads, head_dim)
3332        value = rand_nt(seq_lens_kv, n_heads, head_dim)
3333
3334        # Clone inputs for the low-precision and high-precision math references
3335        query_ref_lp = query.clone().detach().requires_grad_(True)
3336        key_ref_lp = key.clone().detach().requires_grad_(True)
3337        value_ref_lp = value.clone().detach().requires_grad_(True)
3338
3339        query_ref = query.clone().detach().to(torch.float32).requires_grad_(True)
3340        key_ref = key.clone().detach().to(torch.float32).requires_grad_(True)
3341        value_ref = value.clone().detach().to(torch.float32).requires_grad_(True)
3342
3343        is_dropout = dropout_p > 0.0
3344
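        # Without dropout, compare SDPA outputs directly. With dropout, call the flash op with
        # return_debug_mask=True so the math references can reuse the kernel's sampled dropout mask.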
3345        if not is_dropout:
3346            with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
3347                out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
3348            with sdpa_kernel(backends=[SDPBackend.MATH]):
3349                # High Precision Math Reference
3350                out_ref = F.scaled_dot_product_attention(
3351                    query_ref, key_ref, value_ref, is_causal=is_causal, scale=scale)
3352                # Low Precision Math Reference
3353                out_lp_ref = F.scaled_dot_product_attention(
3354                    query_ref_lp, key_ref_lp, value_ref_lp, is_causal=is_causal, scale=scale)
3355        else:
3356            # Create real output
3357            output_tuple = torch.ops.aten._scaled_dot_product_flash_attention(
3358                query, key, value, dropout_p=dropout_p, is_causal=is_causal,
3359                scale=scale, return_debug_mask=is_dropout)
3360            out = output_tuple[0]
3361            debug_mask = output_tuple[-1]
3362
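            # Build query/key padding masks so the flash debug mask can be converted back into a
            # dense softmax-shaped mask over the padded sequence lengths.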
3363            query_padding_mask = torch.arange(max_seq_len_q).unsqueeze(0).expand(
3364                batch_size, max_seq_len_q
3365            ) < seq_lens_q.unsqueeze(-1)
3366            query_padding_mask = query_padding_mask.to(device)
3367
3368            key_padding_mask = torch.arange(max_seq_len_kv).unsqueeze(0).expand(
3369                batch_size, max_seq_len_kv
3370            ) < seq_lens_kv.unsqueeze(-1)
3371            key_padding_mask = key_padding_mask.to(device)
3372
3373            softmax_mask = self.convert_flash_attn_S_to_softmax(
3374                debug_mask, max_seq_len_q, max_seq_len_kv, query_padding_mask, key_padding_mask, causal=is_causal)
3375            dropout_mask = softmax_mask >= 0
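            # Slice the dense dropout mask down to each sample's true (seq_len_q, seq_len_kv) and
            # re-nest it so the math reference applies the exact mask the fused kernel sampled.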
3376            nt_stack = []
3377            for tensor_component in range(batch_size):
3378                batch_stack = []
3379                for head in range(n_heads):
3380                    batch_stack.append(dropout_mask[tensor_component, head,
3381                                                    0:seq_lens_q[tensor_component],
3382                                                    0:seq_lens_kv[tensor_component]].unsqueeze(0))
3383                nt_stack.append(torch.cat(batch_stack))
3384            nested_dropout_mask = torch.nested.nested_tensor(nt_stack)
3385            # High Precision Math Reference
3386            out_ref = torch.ops.aten._scaled_dot_product_attention_math(
3387                query_ref, key_ref, value_ref, dropout_p=dropout_p,
3388                is_causal=is_causal, scale=scale, dropout_mask=nested_dropout_mask)[0]
3389            # Low Precision Math Reference
3390            out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
3391                query_ref_lp, key_ref_lp, value_ref_lp, dropout_p=dropout_p, is_causal=is_causal, scale=scale,
3392                dropout_mask=nested_dropout_mask)[0]
3393
3394        upstream_grad = out.detach().clone().contiguous()
3395
3396        out.backward(upstream_grad)
3397        out_ref.backward(upstream_grad.to(out_ref.dtype))
3398        out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype))
3399
3400        # See [Note] Fused Tolerances above
3401        output_ref_atol, output_ref_rtol = calculate_nt_tolerances(out_ref, out_lp_ref, out.dtype)
3402        grad_q_ref_atol, grad_q_ref_rtol = calculate_nt_tolerances(query_ref.grad, query_ref_lp.grad,
3403                                                                   query.grad.dtype, fudge_factor=4)
3404        grad_k_ref_atol, grad_k_ref_rtol = calculate_nt_tolerances(key_ref.grad, key_ref_lp.grad, key.grad.dtype)
3405        grad_v_ref_atol, grad_v_ref_rtol = calculate_nt_tolerances(value_ref.grad, value_ref_lp.grad, value.grad.dtype)
3406
3407        self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol)
3408        self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype),
3409                         atol=grad_q_ref_atol, rtol=grad_q_ref_rtol)
3410        self.assertEqual(key.grad.contiguous(), key_ref.grad.contiguous().to(key.grad.dtype),
3411                         atol=grad_k_ref_atol, rtol=grad_k_ref_rtol)
3412        self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
3413                         atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)
3414
3415class TestAttnBias(NNTestCase):
3416
3417    def run_test(
3418        self,
3419        device,
3420        make_q,
3421        make_kv,
3422        attn_bias=None,
3423        forw_tolerances: Optional[Tolerances] = None,
3424        grad_tolerances: Optional[Tolerances] = None,
3425        backend=None,
3426    ):
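        """Compare eager SDPA given a materialized dense bias against SDPA (optionally compiled via
        `backend`) given the attn_bias object directly, checking forward outputs and input gradients."""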
3427        if backend is not None:
3428            torch._dynamo.reset()
3429
3430        query, key, value = make_q(), make_kv(), make_kv()
3431        query_prototype, key_prototype, value_prototype = query_key_value_clones(query, key, value)
3432
3433        realized = attn_bias._materialize(device) if attn_bias is not None else None
3434        pytorch_output = scaled_dot_product_attention(
3435            query, key, value, attn_mask=realized, dropout_p=0.0, is_causal=False
3436        )
3437
3438        sdpa_op = (
3439            torch.compile(scaled_dot_product_attention, backend=backend)
3440            if backend is not None
3441            else scaled_dot_product_attention
3442        )
3443        sdpa_output = sdpa_op(
3444            query_prototype,
3445            key_prototype,
3446            value_prototype,
3447            attn_mask=attn_bias,
3448            dropout_p=0.0,
3449            is_causal=False,
3450            scale=None,
3451        )
3452
3453        dOut = torch.randn_like(pytorch_output)
3454        pytorch_output.backward(dOut)
3455        sdpa_output.backward(dOut)
3456
3457        # Use default assert_close tolerances for dtypes
3458        if forw_tolerances is None:
3459            forw_tolerances = Tolerances(atol=None, rtol=None)
3460        if grad_tolerances is None:
3461            grad_tolerances = Tolerances(atol=None, rtol=None)
3462
3463        torch.testing.assert_close(pytorch_output, sdpa_output, rtol=forw_tolerances.rtol, atol=forw_tolerances.atol)
3464        torch.testing.assert_close(query.grad, query_prototype.grad, rtol=grad_tolerances.rtol, atol=grad_tolerances.atol)
3465        torch.testing.assert_close(key.grad, key_prototype.grad, rtol=grad_tolerances.rtol, atol=grad_tolerances.atol)
3466        torch.testing.assert_close(value.grad, value_prototype.grad, rtol=grad_tolerances.rtol, atol=grad_tolerances.atol)
3467
3468    @skipIfRocm  # No support for the second variant for now
3469    @parametrize("causal_variant", [CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT])
3470    @parametrize(
3471        "shape",
3472        [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)],
3473    )
3474    def test_causal_variants(self, device, causal_variant: CausalVariant, shape: Tuple[int, ...]):
3475        make_tensor = partial(
3476            torch.rand, device=device, dtype=torch.float16, requires_grad=True
3477        )
3478
3479        bsz, num_heads, seq_len_q, seq_len_kv, head_dim = shape
3480        make_q_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_q, head_dim))
3481        make_kv_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_kv, head_dim))
3482        if causal_variant == CausalVariant.LOWER_RIGHT and seq_len_q > seq_len_kv:
3483            self.skipTest(
3484                "Lower right causal mask will produce NaNs in the output when seq_len_q > seq_len_kv!"
3485            )
3486
3487        forw_tol = Tolerances(1e-3, 1e-3)
3488        grad_tol = Tolerances(5e-3, 5e-3)
3489
3490        if causal_variant == CausalVariant.UPPER_LEFT:
3491            attn_bias = causal_upper_left(seq_len_q, seq_len_kv)
3492        else:
3493            attn_bias = causal_lower_right(seq_len_q, seq_len_kv)
3494
3495        self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=None)
3496
3497    @skipIfRocm  # CausalVariant
3498    @parametrize("causal_variant", [CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT])
3499    @parametrize(
3500        "shape",
3501        [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)],
3502    )
3503    @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on windows")
3504    @skipIfTorchDynamo("This function already calls torch.compile.")
3505    def test_causal_variants_compile(self, device, causal_variant: CausalVariant, shape: Tuple[int, ...]):
3506        cnts = CompileCounterWithBackend("aot_eager")
3507        make_tensor = partial(
3508            torch.rand, device=device, dtype=torch.float16, requires_grad=True
3509        )
3510
3511        bsz, num_heads, seq_len_q, seq_len_kv, head_dim = shape
3512        make_q_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_q, head_dim))
3513        make_kv_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_kv, head_dim))
3514        if causal_variant == CausalVariant.LOWER_RIGHT and seq_len_q > seq_len_kv:
3515            self.skipTest(
3516                "Lower right causal mask will produce NaNs in the output when seq_len_q > seq_len_kv!"
3517            )
3518        forw_tol = Tolerances(1e-3, 1e-3)
3519        grad_tol = Tolerances(5e-3, 5e-3)
3520
3521        if causal_variant == CausalVariant.UPPER_LEFT:
3522            attn_bias = causal_upper_left(seq_len_q, seq_len_kv)
3523        else:
3524            attn_bias = causal_lower_right(seq_len_q, seq_len_kv)
3525
3526        self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=cnts)
3527        self.assertEqual(cnts.frame_count, 1, "Compiled graph should have 1 frame!")
3528
3529    @parametrize("shape", [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)])
3530    def test_is_causal_equals_upper_left(self, device, shape: Tuple[int, ...]):
3531        make_tensor = partial(
3532            torch.rand, device=device, dtype=torch.float16, requires_grad=True
3533        )
3534
3535        bsz, num_heads, seq_len_q, seq_len_kv, head_dim = shape
3536        make_q_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_q, head_dim))
3537        make_kv_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_kv, head_dim))
3538
3539        forw_tol = Tolerances(1e-3, 1e-3)
3540        grad_tol = Tolerances(5e-3, 5e-3)
3541
3542        query = make_q_tensor()
3543        key = make_kv_tensor()
3544        value = make_kv_tensor()
3545        attn_bias = causal_upper_left(seq_len_q, seq_len_kv)
3546
3547        out_attn_bias = scaled_dot_product_attention(query, key, value, attn_mask=attn_bias, dropout_p=0.0)
3548        out_is_causal = scaled_dot_product_attention(query, key, value, is_causal=True, dropout_p=0.0)
3549        torch.testing.assert_close(out_attn_bias, out_is_causal, rtol=forw_tol.rtol, atol=forw_tol.atol)
3550
3551    def test_is_causal_and_mask_fails(self, device):
3552        make_tensor = partial(
3553            torch.rand, device=device, dtype=torch.float16, requires_grad=True
3554        )
3555        make_q_tensor = partial(make_tensor, SdpaShape(16, 16, 128, 16))
3556        make_kv_tensor = partial(make_tensor, SdpaShape(16, 16, 128, 16))
3557
3558        query = make_q_tensor()
3559        key = make_kv_tensor()
3560        value = make_kv_tensor()
3561        attn_bias = causal_upper_left(128, 128)
3562
3563        with self.assertRaisesRegex(ValueError, "CausalBias should not be used with causal=True"):
3564            scaled_dot_product_attention(query, key, value, attn_mask=attn_bias, is_causal=True, dropout_p=0.0)
3565
3566if NOTEST_CPU:
3567    device_types = ("cuda", )
3568else:
3569    device_types = ("cpu", "cuda")
3570
3571instantiate_device_type_tests(TestTransformers, globals(), only_for=device_types)
3572instantiate_device_type_tests(TestSDPAFailureModes, globals(), only_for=device_types)
3573instantiate_device_type_tests(TestSDPA, globals(), only_for=device_types)
3574instantiate_device_type_tests(TestSDPACudaOnly, globals(), only_for=("cuda",))
3575instantiate_device_type_tests(TestAttnBias, globals(), only_for=device_types)
3576
3577if __name__ == '__main__':
3578    run_tests()
3579