# Owner(s): ["module: nn"]

import contextlib
from functools import partial
from collections import namedtuple
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.nn.attention.bias import CausalVariant, causal_lower_right, causal_upper_left
from torch.nn.parameter import Parameter
import unittest
from unittest.mock import patch, MagicMock, ANY
import math
import torch.optim as optim
from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCUDA, onlyCPU
from typing import List, Tuple, Optional
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
    TEST_WITH_ROCM,
    skipIfRocm,
    skipIfTorchDynamo,
    TEST_FAIRSEQ,
    run_tests,
    parametrize,
    freeze_rng_state,
    TEST_WITH_CROSSREF,
    slowTest,
    set_default_dtype,
    gradcheck,
    make_tensor,
    NOTEST_CPU,
    IS_WINDOWS,
    TEST_WITH_TORCHDYNAMO,
)
from torch._dynamo.testing import CompileCounterWithBackend


from torch.testing._internal.common_methods_invocations import wrapper_set_seed
from torch.testing._internal.common_cuda import (
    IS_JETSON, SM80OrLater, PLATFORM_SUPPORTS_FLASH_ATTENTION,
    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
    PLATFORM_SUPPORTS_FUSED_ATTENTION,
    PLATFORM_SUPPORTS_CUDNN_ATTENTION
)

if TEST_FAIRSEQ:
    import fairseq.models.transformer as fairseq_transformer

SdpaShape = namedtuple('Sdpa_Shape', ['batch', 'num_heads', 'seq_len', 'head_dim'])
Tolerances = namedtuple('Tolerances', ['atol', 'rtol'])

@contextlib.contextmanager
def use_deterministic_algorithims(mode: bool, warn_only: bool):
    r"""
    This context manager can be used to temporarily enable or disable deterministic algorithms.
    Upon exiting the context manager, the previous state of the flag will be restored.
    """
    previous_mode: bool = torch.are_deterministic_algorithms_enabled()
    previous_warn_only: bool = torch.is_deterministic_algorithms_warn_only_enabled()
    try:
        torch.use_deterministic_algorithms(mode, warn_only=warn_only)
        yield {}
    finally:
        torch.use_deterministic_algorithms(previous_mode, warn_only=previous_warn_only)


# Found in torch/testing/_comparison.py
default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float32: 1e-5}
default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float32: 1.3e-6}

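# Compute-capability flags for the current CUDA device (if any), used to gate architecture-specific tests.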
isSM8XDevice = torch.cuda.is_available() and torch.cuda.get_device_capability() in [(8, 6), (8, 7), (8, 9)]
isSM90Device = torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0)
isSM5xDevice = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 5
isLessThanSM80Device = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8

def get_rtol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
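    """Returns the maximum elementwise relative deviation of computed_value from true_value."""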
    deviation = true_value - computed_value
    deviation = torch.abs(deviation / true_value)
    # Fill in the nans with the default rtol
    torch.nan_to_num_(deviation, nan=default_rtol[computed_value.dtype])
    return deviation.max().item()


def get_atol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
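    """Returns the maximum elementwise absolute deviation of computed_value from true_value."""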
    deviation = true_value - computed_value
    atol = torch.abs(deviation).max().item()
    return atol


def get_tolerances(
    true_value: torch.Tensor,
    computed_value: torch.Tensor,
    fudge_factor: Optional[float] = None,
) -> Tuple[float, float]:
    """Returns the absolute and relative tolerances for comparing two tensors."""
    fudge_factor = fudge_factor if fudge_factor is not None else 1.0
    atol = get_atol(true_value, computed_value)
    rtol = get_rtol(true_value, computed_value)

    atol = fudge_factor * max(atol, default_atol[computed_value.dtype])
    rtol = fudge_factor * max(rtol, default_rtol[computed_value.dtype])
    # torch.isclose() has weird behavior around very large rtol values, see:
    # https://github.com/pytorch/pytorch/issues/102400
    if rtol > 1e30:
        rtol = default_rtol[computed_value.dtype]
    return atol, rtol


def query_key_value_clones(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, dtype: Optional[torch.dtype] = None):
    """Clones the query, key, and value tensors and casts them to the specified dtype."""
    if dtype is None:
        dtype = query.dtype
    query_ref = query.clone().detach().to(dtype).requires_grad_(query.requires_grad)
    key_ref = key.clone().detach().to(dtype).requires_grad_(key.requires_grad)
    value_ref = value.clone().detach().to(dtype).requires_grad_(value.requires_grad)
    return query_ref, key_ref, value_ref

def get_platform_specific_sdpa():
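    """Returns the list of fused SDPA backends supported on this platform (with a placeholder if none)."""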
    ret = []
    if PLATFORM_SUPPORTS_FLASH_ATTENTION:
        ret.append(SDPBackend.FLASH_ATTENTION)
    if PLATFORM_SUPPORTS_MEM_EFF_ATTENTION:
        ret.append(SDPBackend.EFFICIENT_ATTENTION)
    if PLATFORM_SUPPORTS_CUDNN_ATTENTION:
        ret.append(SDPBackend.CUDNN_ATTENTION)
    if not ret:
        # Add a placeholder, an empty list causes "An empty arg_values was passed to @parametrize"
        ret.append(SDPBackend.EFFICIENT_ATTENTION)
    return ret

PLATFORM_SPECIFIC_SDPA = get_platform_specific_sdpa()
# Indicates whether the efficient attention backend can support:
# 1. sequences longer than 512
# 2. head dimension larger than 64
MEM_EFF_CAPABILITY_MATCHES_SM80 = SM80OrLater or TEST_WITH_ROCM

def rand_sdpa_tensor(shape: SdpaShape, device: str, dtype: torch.dtype, type: str,
                     requires_grad: bool = False, packed: bool = False) -> torch.Tensor:
    """Creates a random dense or nested tensor with the given shape and type.

    Args:
        shape (Tuple[int]): Shape of Tensor to construct
        device (str): which device to create tensor on
        dtype (torch.dtype): Tensors' dtype
        type (str): Nested or Dense
        requires_grad (bool, optional): Tensor's grad status. Defaults to False.
        packed (bool, optional): Whether to create a single packed QKV tensor or not. Defaults to False.

    Returns:
        torch.Tensor: A new tensor
    """
    batch, num_heads, seq_len, head_dim = shape.batch, shape.num_heads, shape.seq_len, shape.head_dim
    if type == "nested":
        if isinstance(seq_len, list):
            def _size(i):
                return (seq_len[i], num_heads, head_dim) if not packed else (seq_len[i], 3 * num_heads * head_dim)

            return torch.nested.nested_tensor([
                torch.randn(_size(i), device=device, dtype=dtype, requires_grad=requires_grad)
                for i in range(batch)])
        else:
            size = (seq_len, num_heads, head_dim) if not packed else (seq_len, 3 * num_heads * head_dim)
            return torch.nested.nested_tensor([
                torch.randn(size, device=device, dtype=dtype, requires_grad=requires_grad)
                for _ in range(batch)])
    else:
        assert isinstance(seq_len, int)
        size = (batch, seq_len, num_heads, head_dim) if not packed else (batch, seq_len, 3 * num_heads * head_dim)
        return torch.randn(size, device=device, dtype=dtype, requires_grad=requires_grad)

def calculate_nt_tolerances(nt_ref_hp, nt_ref_lp, default_dtype, fudge_factor=1):
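    """Computes atol/rtol tolerances from high- and low-precision nested tensor references."""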
    # TODO use NT ops when we have implemented Max for NestedTensor instead of unrolling
    ref_atol = default_atol[default_dtype]
    ref_rtol = default_rtol[default_dtype]
    for tensor_component_ref, tensor_component_ref_lp in zip(nt_ref_hp.unbind(), nt_ref_lp.unbind()):
        ref_atol = max((fudge_factor * torch.abs(tensor_component_ref - tensor_component_ref_lp)).max().item(), ref_atol)
        ref_rtol = max(get_rtol(tensor_component_ref, tensor_component_ref_lp), ref_rtol)
    return ref_atol, ref_rtol


class TestTransformers(NNTestCase):
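    # NNTestCase hooks: check for CUDA memory leaks and run CUDA tests on a non-default stream.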
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    @onlyCUDA
    @unittest.skip("4D mask not supported yet - enable when 4D masks are supported")
    def test_self_attn_TxT_attn_mask(self, device):
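        """Checks that a broadcast 4D attention mask gives the same output as the equivalent TxT mask."""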
        embed_dim = 16
        num_heads = 4
        batch_size = 10
        tgt_len = 16

        query = torch.rand(batch_size, tgt_len, embed_dim, device=device)  # [N, T, D]
        attn_mask = torch.randint(0, 2, (tgt_len, tgt_len)).cuda().float()  # [T, T]
        attn_mask = attn_mask.masked_fill(attn_mask == 0, float('-inf')).masked_fill(attn_mask == 1, 0.0)

        attn_mask_4d = attn_mask.expand(batch_size, num_heads, tgt_len, tgt_len)

        mta_model = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).cuda()
        mta_model.eval()

        # Compare results computed with the 4D mask against the equivalent TxT mask
        with torch.inference_mode():
            output_mask_4d = mta_model(query, query, query, attn_mask=attn_mask_4d)[0]
            output_mask_4d = output_mask_4d.transpose(0, 1)  # [N, T, D]

            output_mask_TxT = mta_model(query, query, query, attn_mask=attn_mask)[0]
            output_mask_TxT = output_mask_TxT.transpose(0, 1)  # [N, T, D]

            self.assertEqual(output_mask_4d, output_mask_TxT)

    @slowTest
    def test_train_with_pad_and_catch_error(self, device):
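        """Trains a TransformerEncoder with a padding mask and checks that unsupported mask dtypes are rejected."""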
        iters = 100
        pad_mask = torch.tensor([[1, 1, 0, 0]], dtype=torch.bool).to(device)
        layer = nn.TransformerEncoderLayer(
            d_model=2,
            dim_feedforward=4,
            nhead=2,
            batch_first=True,
            activation="gelu",
            dropout=0,
        )
        criterion = nn.MSELoss()
        encoder = nn.TransformerEncoder(layer, 2).to(device)
        optimizer = optim.SGD(encoder.parameters(), lr=0.1, momentum=0.9)
        encoder.train()
        for i in range(iters):
            encoder.train()
            optimizer.zero_grad()
            inputs = torch.cat([torch.randn(1, 2, 2), torch.zeros(1, 2, 2)], dim=1).to(device)

            outputs = encoder(inputs, src_key_padding_mask=pad_mask)

            loss = criterion(outputs[:, 0:2, :], inputs[:, 0:2, :])
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                test = torch.cat([torch.randn(1, 2, 2), torch.zeros(1, 2, 2)], dim=1).to(device)

                # Expect uint8 type not supported
                try:
                    test_train_uint8 = encoder(test, src_key_padding_mask=pad_mask.to(torch.uint8))
                except AssertionError:
                    continue
                self.fail("Failed to catch unsupported uint8 type exception")

                test_train_bool = encoder(test, src_key_padding_mask=pad_mask)
                encoder.eval()

                # Expect long type not supported
                try:
                    test_eval_int64 = encoder(test, src_key_padding_mask=pad_mask.to(torch.int64))
                except AssertionError:
                    continue
                self.fail("Failed to catch unsupported Long type exception")

                test_eval_bool = encoder(test, src_key_padding_mask=pad_mask)
                l1_bool = nn.L1Loss()(test_train_bool[:, 0:2, :], test_eval_bool[:, 0:2, :]).item()
                self.assertTrue(l1_bool < 1e-4, "Eval/Train difference in pad_mask BOOL")

    @parametrize("attn_mask_dim", [2, 3, None])
    @parametrize("key_padding_mask_dim", [2, None])
    @parametrize("mask_dtype", [torch.bool, torch.float32])
    def test_multiheadattention_fastpath_attn_mask(self, device, attn_mask_dim, key_padding_mask_dim, mask_dtype):
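        """Checks that MultiheadAttention gives matching outputs on the fast path (eval) and slow path (train)."""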
        with torch.no_grad():
            B = 2
            L = 4
            D = 8
            H = 4

            if attn_mask_dim == 2:
                attn_mask = make_tensor((L, L), dtype=mask_dtype, device=device)
            elif attn_mask_dim == 3:
                attn_mask = make_tensor((B * H, L, L), dtype=mask_dtype, device=device)
            elif attn_mask_dim is None:
                attn_mask = None

            if key_padding_mask_dim == 2:
                key_padding_mask = make_tensor((B, L), dtype=mask_dtype, device=device)
            elif key_padding_mask_dim is None:
                key_padding_mask = None

            mha = nn.MultiheadAttention(D, H, batch_first=True, device=device)
            X = torch.randn(B, L, D, device=device)

            mha.train()  # disable fast path
            out, _ = mha(X, X, X, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)
            mha.eval()  # enable fast path
            out_fp, _ = mha(X, X, X, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)
            self.assertEqual(out, out_fp)

    @parametrize("nhead", [1, 4, 8])
    def test_transformerencoderlayer_src_mask(self, device, nhead):
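        """Smoke-tests TransformerEncoderLayer with an all-False boolean src_mask in train and eval modes."""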
        batch_size = 2
        seqlen = 4
        d_model = 8
        dim_feedforward = 32

        model = torch.nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True).to(device)
        src = torch.rand(batch_size, seqlen, d_model).to(device)  # bs, seqlen, d_model
        src_mask = torch.zeros(seqlen, seqlen).to(torch.bool).to(device)

        model(src, src_mask=src_mask)
        model.eval()
        with torch.no_grad():
            model(src, src_mask=src_mask)

    @parametrize("use_torchscript", [False])
    @parametrize("enable_nested_tensor", [True, False])
    @parametrize("use_autocast", [True, False])
    @parametrize("d_model", [12, 256])
    def test_transformerencoder_fastpath(self, device, use_torchscript, enable_nested_tensor, use_autocast, d_model):
        """
        Test TransformerEncoder fastpath output matches slowpath output
        """
        torch.manual_seed(1234)
        nhead = 4
        dim_feedforward = d_model
        batch_first = True

        model = torch.nn.TransformerEncoder(
            torch.nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                batch_first=batch_first),
            num_layers=2,
            enable_nested_tensor=enable_nested_tensor
        ).to(device).eval()

        if use_torchscript:
            model = torch.jit.script(model)

        # each input is (input, mask)
        input_mask_pairs = [
            (
                torch.rand(3, 2, d_model),
                [
                    [0, 1],
                    [0, 1],
                    [1, 1]
                ]
            ),
            (
                torch.rand(2, 100, d_model),
                [
                    [0] * 98 + [1] * 2,
                    [0] * 90 + [1] * 10
                ]
            ),
            # softmax.cu switches from fast->slowpath at masked seqlen 1024. test 1024.
            (
                torch.rand(2, 1024, d_model),
                [
                    [0] * 1020 + [1] * 4,
                    [0] * 1024,
                ]
            ),
            (
                torch.rand(1, 1026, d_model),
                [[0] * 1024 + [1] * 2]
            ),
            # softmax.cu switches from fast->slowpath at masked seqlen 1024. test range of masks above 1024.
            (
                torch.rand(4, 1040, d_model),
                [
                    [0] * 1024 + [1] * 16,
                    [0] * 1025 + [1] * 15,
                    [0] * 1031 + [1] * 9,
                    [0] * 1040,
                ]
            )
        ]
        input_mask_pairs = [
            (
                torch.tensor(pair[0], device=device, dtype=torch.get_default_dtype()),  # float input
                torch.tensor(pair[1], device=device, dtype=torch.bool)  # bool mask
            ) for pair in input_mask_pairs
        ]

        maybe_autocast = torch.autocast("cuda", dtype=torch.float16) if use_autocast else contextlib.nullcontext()
        with maybe_autocast:
            for input, src_key_padding_mask in input_mask_pairs:
                with torch.no_grad():
                    fastpath_output = model(input, src_key_padding_mask=src_key_padding_mask)
                slowpath_output = model(input, src_key_padding_mask=src_key_padding_mask)  # reference
                # Make sure fastpath_output is same shape as slowpath_output and mask.
                # When enable_nested_tensor=true, fastpath_output may be smaller than input tensor.
                # Eg if input bs=1, seqlen=6, and we mask out 2 tokens, fastpath_output will have bs=1, seqlen=4.
                # Expand back to old size to match.
                bs, true_seqlen, embed_dim = fastpath_output.shape
                expanded_seqlen = src_key_padding_mask.shape[1]
                fastpath_output_expanded = torch.zeros(bs, expanded_seqlen, embed_dim, device=device)
                fastpath_output_expanded[:, :true_seqlen, :] = fastpath_output
                # no guarantees on outputs corresponding to masked tokens, so they may vary between slow/fast path. set all to 0.
                fastpath_output_expanded = fastpath_output_expanded.masked_fill(src_key_padding_mask.unsqueeze(-1), 0)
                slowpath_output = slowpath_output.masked_fill(src_key_padding_mask.unsqueeze(-1), 0)
                torch.testing.assert_close(fastpath_output_expanded, slowpath_output, rtol=1e-7, atol=1e-5)

    @parametrize("with_no_grad", [True, False])
    @parametrize("training", [True, False])
    @parametrize("enable_nested_tensor", [False])
    def test_transformerencoder_square_input(self, with_no_grad, training, enable_nested_tensor, device):
        """
        Test for edge cases when input of shape (batch size, sequence length, embedding dimension) has
        batch size == sequence length
        """
        model = torch.nn.TransformerEncoder(
            torch.nn.TransformerEncoderLayer(d_model=4, nhead=2, dim_feedforward=16, dropout=0.0, batch_first=True),
            num_layers=2,
            enable_nested_tensor=enable_nested_tensor
        ).to(device)

        with torch.no_grad():
            # set constant weights of the model
            for idx, p in enumerate(model.parameters()):
                x = p.data
                sz = x.view(-1).size(0)
                shape = x.shape
                x = torch.cos(torch.arange(0, sz).float().view(shape))
                p.data.copy_(x)

        if training:
            model = model.train()
        else:
            model = model.eval()
        x = torch.arange(0, 16).reshape(2, 2, 4).to(torch.get_default_dtype()).to(device)
        src_mask = torch.Tensor([[0, 1], [0, 0]]).to(torch.bool).to(device)

        if with_no_grad:
            cm = torch.no_grad()
        else:
            cm = contextlib.nullcontext()
        with cm:
            result = model(x, mask=src_mask)

        ref_output = torch.Tensor([[[2.420306205749512, 0.017629241570830, -0.607857942581177, -0.085519507527351],
                                    [2.420306205749512, 0.017629241570830, -0.607857942581177, -0.085519507527351]],
                                   [[2.419836044311523, 0.017548924311996, -0.608187675476074, -0.085347734391689],
                                    [2.419836044311523, 0.017548924311996, -0.608187675476074, -0.085347734391689]]]
                                  ).to(device)
        self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
        torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

    @parametrize("batch_first", [True, False])
    @parametrize("training", [True, False])
    @parametrize("enable_nested_tensor", [True, False])
    def test_transformerencoder(self, batch_first, training, enable_nested_tensor, device):
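        """Deterministic check of TransformerEncoder outputs against precomputed reference values."""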
        def get_a_test_layer(activation, batch_first=False):
            d_model = 4
            nhead = 2
            dim_feedforward = 16
            dropout = 0.0

            layer = nn.TransformerEncoderLayer(
                d_model,
                nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                activation=activation,
                batch_first=batch_first,
            ).to(device)

            with torch.no_grad():
                # set constant weights of the model
                for idx, p in enumerate(layer.parameters()):
                    x = p.data
                    sz = x.view(-1).size(0)
                    shape = x.shape
                    x = torch.cos(torch.arange(0, sz).float().view(shape))
                    p.data.copy_(x)

            return layer

        # this is a deterministic test for TransformerEncoder
        activation = F.relu

        def _test(batch_first, training, enable_nested_tensor):
            def perm_fn(x):
                return x.transpose(1, 0) if batch_first else x

            encoder_layer = get_a_test_layer(activation=activation,
                                             batch_first=batch_first)

            model = nn.TransformerEncoder(
                encoder_layer, 1, enable_nested_tensor=enable_nested_tensor
            ).to(device)

            if not training:
                model = model.eval()

            # deterministic input
            encoder_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891],
                                                   [0.5387, 0.1655, 0.3565, 0.0471]],
                                                  [[0.8335, 0.2799, 0.5031, 0.2947],
                                                   [0.1402, 0.0318, 0.7636, 0.1346]],
                                                  [[0.6333, 0.9344, 0.1376, 0.9938],
                                                   [0.8924, 0.2872, 0.6692, 0.2944]],
                                                  [[0.9897, 0.6915, 0.3154, 0.1733],
                                                   [0.8645, 0.3513, 0.3064, 0.0767]],
                                                  [[0.8117, 0.2366, 0.4838, 0.7881],
                                                   [0.3718, 0.4945, 0.9511, 0.0864]]]
                                                 )).to(device)
            result = model(encoder_input)
            ref_output = perm_fn(torch.tensor([[[2.428589, 0.020835, -0.602055, -0.085249],
                                                [2.427987, 0.021213, -0.602496, -0.084103]],
                                               [[2.424689, 0.019155, -0.604793, -0.085672],
                                                [2.413863, 0.022211, -0.612486, -0.072490]],
                                               [[2.433774, 0.021598, -0.598343, -0.087548],
                                                [2.425104, 0.019748, -0.604515, -0.084839]],
                                               [[2.436185, 0.022682, -0.596625, -0.087261],
                                                [2.433556, 0.021891, -0.598509, -0.086832]],
                                               [[2.416246, 0.017512, -0.610712, -0.082961],
                                                [2.422901, 0.024187, -0.606178, -0.074929]]]
                                              )).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

            # all 0 src_mask
            src_mask = torch.zeros([5, 5]).to(device) == 1
            result = model(encoder_input, mask=src_mask)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

            # all 0 src_key_padding_mask
            mask = torch.zeros([2, 5]).to(device) == 1
            result = model(encoder_input, src_key_padding_mask=mask)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

            mask[0, 1] = 1
            mask[1, 3] = 1
            mask[1, 4] = 1
            result = model(encoder_input, src_key_padding_mask=mask)
            ref_output = perm_fn(torch.tensor([[[2.429026, 0.020793, -0.601741, -0.085642],
                                                [2.428811, 0.021445, -0.601912, -0.084252]],
                                               [[2.425009, 0.019155, -0.604566, -0.085899],
                                                [2.415408, 0.02249, -0.611415, -0.073]],
                                               [[2.434199, 0.021682, -0.598039, -0.087699],
                                                [2.42598, 0.019941, -0.603896, -0.085091]],
                                               [[2.436457, 0.022736, -0.59643, -0.08736],
                                                [2.434021, 0.022093, -0.598179, -0.08679]],
                                               [[2.416531, 0.017498, -0.610513, -0.083181],
                                                [2.4242, 0.024653, -0.605266, -0.074959]]]
                                              )).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

            # test case 2, multiple layers no norm
            model = nn.TransformerEncoder(encoder_layer, 2, enable_nested_tensor=enable_nested_tensor).to(device)
            if not training:
                model = model.eval()
            result = model(encoder_input, src_key_padding_mask=mask)
            ref_output = perm_fn(torch.tensor([[[2.419051, 0.017446, -0.608738, -0.085003],
                                                [2.419102, 0.017452, -0.608703, -0.085026]],
                                               [[2.419043, 0.017445, -0.608744, -0.084999],
                                                [2.419052, 0.017446, -0.608738, -0.085004]],
                                               [[2.419067, 0.017448, -0.608727, -0.085010],
                                                [2.419098, 0.017452, -0.608706, -0.085024]],
                                               [[2.419072, 0.017449, -0.608724, -0.085012],
                                                [2.419119, 0.017455, -0.608691, -0.085034]],
                                               [[2.419019, 0.017442, -0.608761, -0.084989],
                                                [2.419075, 0.017449, -0.608722, -0.085014]]]
                                              )).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

            model = nn.TransformerEncoder(encoder_layer, 6, enable_nested_tensor=enable_nested_tensor).to(device)
            if not training:
                model = model.eval()
            result = model(encoder_input, src_key_padding_mask=mask)
            ref_output = perm_fn(torch.tensor([[[2.419101, 0.017453, -0.608703, -0.085025],
                                                [2.419101, 0.017453, -0.608704, -0.085025]],
                                               [[2.419101, 0.017453, -0.608703, -0.085025],
                                                [2.419101, 0.017453, -0.608704, -0.085025]],
                                               [[2.419101, 0.017453, -0.608703, -0.085025],
                                                [2.419101, 0.017453, -0.608704, -0.085025]],
                                               [[2.419101, 0.017453, -0.608703, -0.085025],
                                                [2.419101, 0.017453, -0.608704, -0.085025]],
                                               [[2.419101, 0.017453, -0.608703, -0.085025],
                                                [2.419101, 0.017453, -0.608704, -0.085025]]]
                                              )).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

            # test case 3, multiple layers with norm
            # d_model = 4
            norm = nn.LayerNorm(4)
            model = nn.TransformerEncoder(encoder_layer, 2, norm=norm,
                                          enable_nested_tensor=enable_nested_tensor).to(device)
            if not training:
                model = model.eval()
            result = model(encoder_input, src_key_padding_mask=mask)
            ref_output = perm_fn(torch.tensor([[[1.695949, -0.357635, -0.893077, -0.445238],
                                                [1.695955, -0.357639, -0.893050, -0.445266]],
                                               [[1.695948, -0.357634, -0.893082, -0.445233],
                                                [1.695950, -0.357635, -0.893077, -0.445238]],
                                               [[1.695951, -0.357636, -0.893069, -0.445246],
                                                [1.695955, -0.357639, -0.893052, -0.445264]],
                                               [[1.695952, -0.357636, -0.893066, -0.445249],
                                                [1.695957, -0.357641, -0.893041, -0.445276]],
                                               [[1.695946, -0.357632, -0.893095, -0.445220],
                                                [1.695952, -0.357637, -0.893065, -0.445251]]]
                                              )).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

            model = nn.TransformerEncoder(encoder_layer, 6, norm=norm,
                                          enable_nested_tensor=enable_nested_tensor).to(device)
            if not training:
                model = model.eval()
            result = model(encoder_input, src_key_padding_mask=mask)
            ref_output = perm_fn(torch.tensor([[[1.695955, -0.357639, -0.893051, -0.445265],
                                                [1.695955, -0.357639, -0.893051, -0.445265]],
                                               [[1.695955, -0.357639, -0.893051, -0.445265],
                                                [1.695955, -0.357639, -0.893051, -0.445265]],
                                               [[1.695955, -0.357639, -0.893051, -0.445265],
                                                [1.695955, -0.357639, -0.893051, -0.445265]],
                                               [[1.695955, -0.357639, -0.893051, -0.445265],
                                                [1.695955, -0.357639, -0.893051, -0.445265]],
                                               [[1.695955, -0.357639, -0.893051, -0.445265],
                                                [1.695955, -0.357639, -0.893051, -0.445265]]]
                                              )).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

637*da0073e9SAndroid Build Coastguard Worker        # TODO: remove set default dtype to double by making ref_output more precise.
638*da0073e9SAndroid Build Coastguard Worker        # Added because this test was copied from test_nn.py, which has default
639*da0073e9SAndroid Build Coastguard Worker        # dtype double. If default dtype is float, tests will say tensors not close because
640*da0073e9SAndroid Build Coastguard Worker        # ref output precision too low
641*da0073e9SAndroid Build Coastguard Worker        with set_default_dtype(torch.double):
642*da0073e9SAndroid Build Coastguard Worker            if training:
643*da0073e9SAndroid Build Coastguard Worker                cm = contextlib.nullcontext()
644*da0073e9SAndroid Build Coastguard Worker            else:
645*da0073e9SAndroid Build Coastguard Worker                cm = torch.no_grad()  # transformer fast path requires no grad
646*da0073e9SAndroid Build Coastguard Worker            with cm:
647*da0073e9SAndroid Build Coastguard Worker                _test(batch_first, training, enable_nested_tensor)
648*da0073e9SAndroid Build Coastguard Worker
649*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(sys.version_info < (3, 11), "not supported on pre-3.11 Python")
650*da0073e9SAndroid Build Coastguard Worker    def test_encoder_padding_and_src_mask_bool(self):
651*da0073e9SAndroid Build Coastguard Worker        encoder_layer = nn.TransformerEncoderLayer(
652*da0073e9SAndroid Build Coastguard Worker            d_model=16,
653*da0073e9SAndroid Build Coastguard Worker            nhead=2,
654*da0073e9SAndroid Build Coastguard Worker            dim_feedforward=32,
655*da0073e9SAndroid Build Coastguard Worker            dropout=0.1,
656*da0073e9SAndroid Build Coastguard Worker            activation='relu',
657*da0073e9SAndroid Build Coastguard Worker            batch_first=True,
658*da0073e9SAndroid Build Coastguard Worker        )
659*da0073e9SAndroid Build Coastguard Worker        encoder_norm = nn.LayerNorm(16)
660*da0073e9SAndroid Build Coastguard Worker        encoder = nn.TransformerEncoder(
661*da0073e9SAndroid Build Coastguard Worker            encoder_layer, 2, encoder_norm
662*da0073e9SAndroid Build Coastguard Worker        )
663*da0073e9SAndroid Build Coastguard Worker
664*da0073e9SAndroid Build Coastguard Worker        inputs = torch.randn(2, 3, 16)
665*da0073e9SAndroid Build Coastguard Worker
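        # Boolean causal mask (True above the diagonal = masked out) plus a key padding
        # mask derived from per-sequence lengths (True = padding position).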
666*da0073e9SAndroid Build Coastguard Worker        src_mask = torch.ones(3, 3, dtype=torch.bool).triu_(diagonal=1)
667*da0073e9SAndroid Build Coastguard Worker        input_seq_len = torch.tensor([3, 2])
668*da0073e9SAndroid Build Coastguard Worker        padding_mask = (
669*da0073e9SAndroid Build Coastguard Worker            torch.arange(3)[None, :].cpu() >= input_seq_len[:, None]
670*da0073e9SAndroid Build Coastguard Worker        )
671*da0073e9SAndroid Build Coastguard Worker
672*da0073e9SAndroid Build Coastguard Worker        with (self.assertNoLogs(None) if not TEST_WITH_TORCHDYNAMO else contextlib.nullcontext()):
673*da0073e9SAndroid Build Coastguard Worker            encoder(
674*da0073e9SAndroid Build Coastguard Worker                inputs,
675*da0073e9SAndroid Build Coastguard Worker                mask=src_mask,
676*da0073e9SAndroid Build Coastguard Worker                src_key_padding_mask=padding_mask,
677*da0073e9SAndroid Build Coastguard Worker            )
678*da0073e9SAndroid Build Coastguard Worker
679*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(sys.version_info < (3, 11), "not supported on pre-3.11 Python")
680*da0073e9SAndroid Build Coastguard Worker    def test_decoder_padding_and_src_mask_bool(self):
681*da0073e9SAndroid Build Coastguard Worker
682*da0073e9SAndroid Build Coastguard Worker        def transformer_decoder(inputs, input_seq_len, memory):
683*da0073e9SAndroid Build Coastguard Worker            decoder_layer = nn.TransformerDecoderLayer(
684*da0073e9SAndroid Build Coastguard Worker                d_model=16,
685*da0073e9SAndroid Build Coastguard Worker                nhead=2,
686*da0073e9SAndroid Build Coastguard Worker                dim_feedforward=32,
687*da0073e9SAndroid Build Coastguard Worker                dropout=0.1,
688*da0073e9SAndroid Build Coastguard Worker                activation='relu',
689*da0073e9SAndroid Build Coastguard Worker                batch_first=True,
690*da0073e9SAndroid Build Coastguard Worker            )
691*da0073e9SAndroid Build Coastguard Worker            decoder_norm = nn.LayerNorm(16)
692*da0073e9SAndroid Build Coastguard Worker            decoder = nn.TransformerDecoder(
693*da0073e9SAndroid Build Coastguard Worker                decoder_layer, 2, decoder_norm
694*da0073e9SAndroid Build Coastguard Worker            )
695*da0073e9SAndroid Build Coastguard Worker
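            # Same construction as the encoder test: a boolean causal mask for the
            # target plus padding masks derived from per-sequence lengths.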
696*da0073e9SAndroid Build Coastguard Worker            src_mask = torch.ones(
697*da0073e9SAndroid Build Coastguard Worker                inputs.shape[1], inputs.shape[1], dtype=torch.bool
698*da0073e9SAndroid Build Coastguard Worker            ).triu_(diagonal=1)
699*da0073e9SAndroid Build Coastguard Worker            padding_mask = (
700*da0073e9SAndroid Build Coastguard Worker                torch.arange(inputs.shape[1])[None, :].cpu()
701*da0073e9SAndroid Build Coastguard Worker                >= input_seq_len[:, None]
702*da0073e9SAndroid Build Coastguard Worker            )
703*da0073e9SAndroid Build Coastguard Worker
704*da0073e9SAndroid Build Coastguard Worker            return decoder(
705*da0073e9SAndroid Build Coastguard Worker                inputs,
706*da0073e9SAndroid Build Coastguard Worker                memory,
707*da0073e9SAndroid Build Coastguard Worker                tgt_mask=src_mask,
708*da0073e9SAndroid Build Coastguard Worker                tgt_key_padding_mask=padding_mask,
709*da0073e9SAndroid Build Coastguard Worker                memory_key_padding_mask=padding_mask,
710*da0073e9SAndroid Build Coastguard Worker            )
711*da0073e9SAndroid Build Coastguard Worker
712*da0073e9SAndroid Build Coastguard Worker        inputs = torch.randn(2, 3, 16)
713*da0073e9SAndroid Build Coastguard Worker        memory = torch.randn(2, 3, 16)
714*da0073e9SAndroid Build Coastguard Worker        input_seq_len = torch.tensor([3, 2])
715*da0073e9SAndroid Build Coastguard Worker
716*da0073e9SAndroid Build Coastguard Worker        with self.assertNoLogs(None):
717*da0073e9SAndroid Build Coastguard Worker            transformer_decoder(inputs, input_seq_len, memory)
718*da0073e9SAndroid Build Coastguard Worker
719*da0073e9SAndroid Build Coastguard Worker    def test_encoder_is_causal(self):
720*da0073e9SAndroid Build Coastguard Worker
721*da0073e9SAndroid Build Coastguard Worker        d_model = 3
722*da0073e9SAndroid Build Coastguard Worker        layer = torch.nn.TransformerEncoderLayer(d_model, 1, 6, batch_first=True)
723*da0073e9SAndroid Build Coastguard Worker        layer.eval()
724*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 5, d_model)
725*da0073e9SAndroid Build Coastguard Worker        unmasked_output = layer(x)
726*da0073e9SAndroid Build Coastguard Worker        mask = torch.nn.Transformer.generate_square_subsequent_mask(x.size(1))
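        # generate_square_subsequent_mask already encodes causality, so passing it with
        # or without is_causal=True should give identical outputs.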
727*da0073e9SAndroid Build Coastguard Worker        is_causal_output = layer(x, src_mask=mask, is_causal=True)
728*da0073e9SAndroid Build Coastguard Worker        masked_output = layer(x, src_mask=mask)
729*da0073e9SAndroid Build Coastguard Worker
730*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(masked_output, is_causal_output)
731*da0073e9SAndroid Build Coastguard Worker
732*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
733*da0073e9SAndroid Build Coastguard Worker    @parametrize("nb_heads", [1, 8])
734*da0073e9SAndroid Build Coastguard Worker    @parametrize("bias", [True, False])
735*da0073e9SAndroid Build Coastguard Worker    def test_mha_native_args(self, nb_heads, bias):
736*da0073e9SAndroid Build Coastguard Worker
737*da0073e9SAndroid Build Coastguard Worker        B, L, F = 8, 100, 128
738*da0073e9SAndroid Build Coastguard Worker        batch_first = True
739*da0073e9SAndroid Build Coastguard Worker        fast_path = True
740*da0073e9SAndroid Build Coastguard Worker        use_pad_mask = bias
741*da0073e9SAndroid Build Coastguard Worker
742*da0073e9SAndroid Build Coastguard Worker        mha = nn.MultiheadAttention(
743*da0073e9SAndroid Build Coastguard Worker            embed_dim=F,
744*da0073e9SAndroid Build Coastguard Worker            num_heads=nb_heads,
745*da0073e9SAndroid Build Coastguard Worker            batch_first=batch_first,
746*da0073e9SAndroid Build Coastguard Worker            bias=bias
747*da0073e9SAndroid Build Coastguard Worker        ).cuda()
748*da0073e9SAndroid Build Coastguard Worker        mha.eval()
749*da0073e9SAndroid Build Coastguard Worker
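        # The MHA fastpath only runs in inference (no-grad) mode, so wrap the forward
        # pass accordingly.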
750*da0073e9SAndroid Build Coastguard Worker        ctx = torch.no_grad if fast_path else contextlib.nullcontext
751*da0073e9SAndroid Build Coastguard Worker        with ctx():
752*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(B, L, F).cuda()
753*da0073e9SAndroid Build Coastguard Worker            if not batch_first:
754*da0073e9SAndroid Build Coastguard Worker                x = x.transpose(0, 1)
755*da0073e9SAndroid Build Coastguard Worker
756*da0073e9SAndroid Build Coastguard Worker            pad_mask = None
757*da0073e9SAndroid Build Coastguard Worker            if use_pad_mask:
758*da0073e9SAndroid Build Coastguard Worker                pad_mask = torch.zeros((B, L), dtype=torch.bool).cuda()
759*da0073e9SAndroid Build Coastguard Worker
760*da0073e9SAndroid Build Coastguard Worker            mha(query=x, key=x, value=x, key_padding_mask=pad_mask)
761*da0073e9SAndroid Build Coastguard Worker
762*da0073e9SAndroid Build Coastguard Worker    def test_kpm_mask_trailing_column_with_nested_tensor(self, device):
763*da0073e9SAndroid Build Coastguard Worker        encoder_layer = nn.TransformerEncoderLayer(
764*da0073e9SAndroid Build Coastguard Worker            d_model=256,
765*da0073e9SAndroid Build Coastguard Worker            nhead=4,
766*da0073e9SAndroid Build Coastguard Worker            dim_feedforward=512,
767*da0073e9SAndroid Build Coastguard Worker            activation='gelu',
768*da0073e9SAndroid Build Coastguard Worker            norm_first=False,
769*da0073e9SAndroid Build Coastguard Worker            batch_first=False,
770*da0073e9SAndroid Build Coastguard Worker        )
771*da0073e9SAndroid Build Coastguard Worker        transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=3, enable_nested_tensor=True).to(device)
772*da0073e9SAndroid Build Coastguard Worker
773*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10, 6, 256).to(device)
774*da0073e9SAndroid Build Coastguard Worker        mask = torch.ones(6, 10)
775*da0073e9SAndroid Build Coastguard Worker        mask[0, :] = 0  # mark everything as padding except the first batch element (True = pad)
776*da0073e9SAndroid Build Coastguard Worker        mask = mask.bool().to(device)
777*da0073e9SAndroid Build Coastguard Worker        out = transformer_encoder(src=x, src_key_padding_mask=mask)
778*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out.shape[1], 6)
779*da0073e9SAndroid Build Coastguard Worker
780*da0073e9SAndroid Build Coastguard Worker    # The CPU unit test hits has_torch_function in the test environment,
781*da0073e9SAndroid Build Coastguard Worker    #   preventing successful completion, so it is restricted to CUDA.
782*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
783*da0073e9SAndroid Build Coastguard Worker    def test_with_nested_tensor_input(self, device):
784*da0073e9SAndroid Build Coastguard Worker        encoder_layer = nn.TransformerEncoderLayer(
785*da0073e9SAndroid Build Coastguard Worker            d_model=256,
786*da0073e9SAndroid Build Coastguard Worker            nhead=4,
787*da0073e9SAndroid Build Coastguard Worker            dim_feedforward=512,
788*da0073e9SAndroid Build Coastguard Worker            activation='gelu',
789*da0073e9SAndroid Build Coastguard Worker            norm_first=False,
790*da0073e9SAndroid Build Coastguard Worker            batch_first=True,
791*da0073e9SAndroid Build Coastguard Worker        )
792*da0073e9SAndroid Build Coastguard Worker        transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=3, enable_nested_tensor=True).to(device)
793*da0073e9SAndroid Build Coastguard Worker
794*da0073e9SAndroid Build Coastguard Worker        transformer_encoder.eval()
795*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
796*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(6, 10, 256).to(device)
797*da0073e9SAndroid Build Coastguard Worker            mask = torch.ones(6, 10)
798*da0073e9SAndroid Build Coastguard Worker            mask[0, 0:] = 0  # zero out a varying number of trailing positions per batch element
799*da0073e9SAndroid Build Coastguard Worker            mask[2, 2:] = 0
800*da0073e9SAndroid Build Coastguard Worker            mask[4, 4:] = 0
801*da0073e9SAndroid Build Coastguard Worker            mask[5, 8:] = 0
802*da0073e9SAndroid Build Coastguard Worker            mask = mask.bool().to(device)
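            # Pack the batch into a nested tensor so the padding is encoded in the tensor
            # itself; the forward pass below then passes src_key_padding_mask=None.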
803*da0073e9SAndroid Build Coastguard Worker            x = torch._nested_tensor_from_mask(x, mask.logical_not(), mask_check=False)
804*da0073e9SAndroid Build Coastguard Worker            out = transformer_encoder(src=x, src_key_padding_mask=None)
805*da0073e9SAndroid Build Coastguard Worker
806*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out.is_nested, True)
807*da0073e9SAndroid Build Coastguard Worker
808*da0073e9SAndroid Build Coastguard Worker
809*da0073e9SAndroid Build Coastguard Worker
810*da0073e9SAndroid Build Coastguard Worker    def test_script_encoder_subclass(self, device):
811*da0073e9SAndroid Build Coastguard Worker        class MyCustomLayer(nn.TransformerEncoderLayer):
812*da0073e9SAndroid Build Coastguard Worker            pass
813*da0073e9SAndroid Build Coastguard Worker
814*da0073e9SAndroid Build Coastguard Worker        encoder = nn.TransformerEncoder(
815*da0073e9SAndroid Build Coastguard Worker            MyCustomLayer(d_model=256, nhead=8), num_layers=6
816*da0073e9SAndroid Build Coastguard Worker        ).to(device=device)
817*da0073e9SAndroid Build Coastguard Worker        torch.jit.script(encoder)
818*da0073e9SAndroid Build Coastguard Worker
819*da0073e9SAndroid Build Coastguard Worker    # brazenly adapted from test_transformerencoderlayer_src_mask to test execution of
820*da0073e9SAndroid Build Coastguard Worker    # torchscripted transformerencoderlayer subclass
821*da0073e9SAndroid Build Coastguard Worker    def test_transformerencoderlayer_subclass(self, device):
822*da0073e9SAndroid Build Coastguard Worker        class MyCustomLayer(nn.TransformerEncoderLayer):
823*da0073e9SAndroid Build Coastguard Worker            pass
824*da0073e9SAndroid Build Coastguard Worker
825*da0073e9SAndroid Build Coastguard Worker        nhead = 4
826*da0073e9SAndroid Build Coastguard Worker        batch_size = 2
827*da0073e9SAndroid Build Coastguard Worker        seqlen = 4
828*da0073e9SAndroid Build Coastguard Worker        d_model = 8
829*da0073e9SAndroid Build Coastguard Worker        dim_feedforward = 32
830*da0073e9SAndroid Build Coastguard Worker
831*da0073e9SAndroid Build Coastguard Worker        model = MyCustomLayer(
832*da0073e9SAndroid Build Coastguard Worker            d_model=d_model,
833*da0073e9SAndroid Build Coastguard Worker            nhead=nhead,
834*da0073e9SAndroid Build Coastguard Worker            dim_feedforward=dim_feedforward,
835*da0073e9SAndroid Build Coastguard Worker            batch_first=True).to(device)
836*da0073e9SAndroid Build Coastguard Worker        script_model = torch.jit.script(model)
837*da0073e9SAndroid Build Coastguard Worker
838*da0073e9SAndroid Build Coastguard Worker        src = torch.rand(batch_size, seqlen, d_model).to(device)  # bs, seqlen, d_model
839*da0073e9SAndroid Build Coastguard Worker        src_mask = torch.zeros(seqlen, seqlen).to(torch.bool).to(device)
840*da0073e9SAndroid Build Coastguard Worker
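        # Seed the RNG identically before each forward so dropout matches between the
        # eager and scripted modules.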
841*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(42)
842*da0073e9SAndroid Build Coastguard Worker        result = model(src, src_mask=src_mask)
843*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(42)
844*da0073e9SAndroid Build Coastguard Worker        scripted_result = script_model(src, src_mask=src_mask)
845*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result, scripted_result)
846*da0073e9SAndroid Build Coastguard Worker
847*da0073e9SAndroid Build Coastguard Worker        model.eval()
848*da0073e9SAndroid Build Coastguard Worker        script_model = torch.jit.script(model)
849*da0073e9SAndroid Build Coastguard Worker
850*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
851*da0073e9SAndroid Build Coastguard Worker            result = model(src, src_mask=src_mask)
852*da0073e9SAndroid Build Coastguard Worker            scripted_result = script_model(src, src_mask=src_mask)
853*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, scripted_result)
854*da0073e9SAndroid Build Coastguard Worker
855*da0073e9SAndroid Build Coastguard Worker
856*da0073e9SAndroid Build Coastguard Worker    def test_transformerencoderlayer_subclass_model(self, device):
857*da0073e9SAndroid Build Coastguard Worker        class MyCustomLayer(nn.TransformerEncoderLayer):
858*da0073e9SAndroid Build Coastguard Worker            pass
859*da0073e9SAndroid Build Coastguard Worker
860*da0073e9SAndroid Build Coastguard Worker        nhead = 4
861*da0073e9SAndroid Build Coastguard Worker        batch_size = 2
862*da0073e9SAndroid Build Coastguard Worker        seqlen = 4
863*da0073e9SAndroid Build Coastguard Worker        d_model = 8
864*da0073e9SAndroid Build Coastguard Worker        dim_feedforward = 32
865*da0073e9SAndroid Build Coastguard Worker
866*da0073e9SAndroid Build Coastguard Worker        layer = MyCustomLayer(
867*da0073e9SAndroid Build Coastguard Worker            d_model=d_model,
868*da0073e9SAndroid Build Coastguard Worker            nhead=nhead,
869*da0073e9SAndroid Build Coastguard Worker            dim_feedforward=dim_feedforward,
870*da0073e9SAndroid Build Coastguard Worker            batch_first=True)
871*da0073e9SAndroid Build Coastguard Worker        model = nn.TransformerEncoder(
872*da0073e9SAndroid Build Coastguard Worker            layer, num_layers=6
873*da0073e9SAndroid Build Coastguard Worker        ).to(device=device)
874*da0073e9SAndroid Build Coastguard Worker        script_model = torch.jit.script(model)
875*da0073e9SAndroid Build Coastguard Worker
876*da0073e9SAndroid Build Coastguard Worker        src = torch.rand(batch_size, seqlen, d_model).to(device)  # bs, seqlen, d_model
877*da0073e9SAndroid Build Coastguard Worker        src_mask = torch.zeros(seqlen, seqlen).to(torch.bool).to(device)
878*da0073e9SAndroid Build Coastguard Worker
879*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(42)
880*da0073e9SAndroid Build Coastguard Worker        result = model(src, mask=src_mask)
881*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(42)
882*da0073e9SAndroid Build Coastguard Worker        scripted_result = script_model(src, mask=src_mask)
883*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result, scripted_result)
884*da0073e9SAndroid Build Coastguard Worker
885*da0073e9SAndroid Build Coastguard Worker        model.eval()
886*da0073e9SAndroid Build Coastguard Worker        script_model = torch.jit.script(model)
887*da0073e9SAndroid Build Coastguard Worker
888*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
889*da0073e9SAndroid Build Coastguard Worker            result = model(src, mask=src_mask)
890*da0073e9SAndroid Build Coastguard Worker            scripted_result = script_model(src, mask=src_mask)
891*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, scripted_result)
892*da0073e9SAndroid Build Coastguard Worker
893*da0073e9SAndroid Build Coastguard Worker
894*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
895*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_FAIRSEQ, "Fairseq not found")
896*da0073e9SAndroid Build Coastguard Worker    def test_decoder_only_layer(self):
897*da0073e9SAndroid Build Coastguard Worker        DEFAULT_PADDING_IDX = 0
898*da0073e9SAndroid Build Coastguard Worker
899*da0073e9SAndroid Build Coastguard Worker        class FairseqDecoder(torch.nn.Module):
900*da0073e9SAndroid Build Coastguard Worker            def __init__(
901*da0073e9SAndroid Build Coastguard Worker                self,
902*da0073e9SAndroid Build Coastguard Worker                embed_dim,
903*da0073e9SAndroid Build Coastguard Worker                attention_heads,
904*da0073e9SAndroid Build Coastguard Worker                ffn_embed_dim,
905*da0073e9SAndroid Build Coastguard Worker                num_layers,
906*da0073e9SAndroid Build Coastguard Worker                embedding_layer,  # torch.nn.Embedding. Must have a padding_idx field
907*da0073e9SAndroid Build Coastguard Worker                dropout=0,
908*da0073e9SAndroid Build Coastguard Worker                normalize_before=False,
909*da0073e9SAndroid Build Coastguard Worker                torch_encoder=None,  # torch encoder that you can map weights from
910*da0073e9SAndroid Build Coastguard Worker                activation="relu",
911*da0073e9SAndroid Build Coastguard Worker            ):
912*da0073e9SAndroid Build Coastguard Worker                super().__init__()
913*da0073e9SAndroid Build Coastguard Worker
914*da0073e9SAndroid Build Coastguard Worker                cfg = fairseq_transformer.TransformerConfig()
915*da0073e9SAndroid Build Coastguard Worker                cfg.decoder.embed_dim = embed_dim
916*da0073e9SAndroid Build Coastguard Worker                cfg.decoder.output_dim = embed_dim
917*da0073e9SAndroid Build Coastguard Worker                cfg.decoder.attention_heads = attention_heads
918*da0073e9SAndroid Build Coastguard Worker                cfg.decoder.ffn_embed_dim = ffn_embed_dim
919*da0073e9SAndroid Build Coastguard Worker                cfg.dropout = dropout
920*da0073e9SAndroid Build Coastguard Worker                cfg.decoder.normalize_before = normalize_before
921*da0073e9SAndroid Build Coastguard Worker                cfg.decoder.layers = num_layers
922*da0073e9SAndroid Build Coastguard Worker                # make embedding behavior same as other encoders
923*da0073e9SAndroid Build Coastguard Worker                cfg.no_token_positional_embeddings = True
924*da0073e9SAndroid Build Coastguard Worker                cfg.no_scale_embedding = True
925*da0073e9SAndroid Build Coastguard Worker                cfg.activation_fn = activation
926*da0073e9SAndroid Build Coastguard Worker
927*da0073e9SAndroid Build Coastguard Worker                dictionary = {}  # TODO: verify what this is
928*da0073e9SAndroid Build Coastguard Worker
929*da0073e9SAndroid Build Coastguard Worker                self.decoder = fairseq_transformer.TransformerDecoder(
930*da0073e9SAndroid Build Coastguard Worker                    cfg,
931*da0073e9SAndroid Build Coastguard Worker                    dictionary,
932*da0073e9SAndroid Build Coastguard Worker                    embedding_layer,
933*da0073e9SAndroid Build Coastguard Worker                    no_encoder_attn=True,
934*da0073e9SAndroid Build Coastguard Worker                    output_projection=None,
935*da0073e9SAndroid Build Coastguard Worker                )
936*da0073e9SAndroid Build Coastguard Worker
937*da0073e9SAndroid Build Coastguard Worker                if torch_encoder is not None:
938*da0073e9SAndroid Build Coastguard Worker                    self.decoder = torch_to_fairseq(torch_encoder, self.decoder)  # noqa: F821
939*da0073e9SAndroid Build Coastguard Worker                self.decoder = self.decoder.eval().cuda().half()
940*da0073e9SAndroid Build Coastguard Worker
941*da0073e9SAndroid Build Coastguard Worker            def forward(
942*da0073e9SAndroid Build Coastguard Worker                self,
943*da0073e9SAndroid Build Coastguard Worker                tokens,
944*da0073e9SAndroid Build Coastguard Worker                src_lengths=None,
945*da0073e9SAndroid Build Coastguard Worker                with_triangle_mask=False,
946*da0073e9SAndroid Build Coastguard Worker                incremental_state=None,
947*da0073e9SAndroid Build Coastguard Worker            ):
948*da0073e9SAndroid Build Coastguard Worker                return self.decoder(
949*da0073e9SAndroid Build Coastguard Worker                    prev_output_tokens=tokens,
950*da0073e9SAndroid Build Coastguard Worker                    encoder_out=None,
951*da0073e9SAndroid Build Coastguard Worker                    incremental_state=incremental_state,
952*da0073e9SAndroid Build Coastguard Worker                    features_only=True,
953*da0073e9SAndroid Build Coastguard Worker                    full_context_alignment=not with_triangle_mask,
954*da0073e9SAndroid Build Coastguard Worker                    alignment_layer=None,
955*da0073e9SAndroid Build Coastguard Worker                    alignment_heads=None,
956*da0073e9SAndroid Build Coastguard Worker                    src_lengths=src_lengths,
957*da0073e9SAndroid Build Coastguard Worker                    return_all_hiddens=False,
958*da0073e9SAndroid Build Coastguard Worker                )[0]
959*da0073e9SAndroid Build Coastguard Worker
960*da0073e9SAndroid Build Coastguard Worker    @parametrize("input_dim,attn_mask_dim,is_causal",
961*da0073e9SAndroid Build Coastguard Worker                 [(3, None, False), (3, 2, False), (3, 2, True), (3, 3, False), (3, 3, True),
962*da0073e9SAndroid Build Coastguard Worker                  (4, None, False), (4, 2, False), (4, 2, True), (4, 4, False), (4, 4, True)],
963*da0073e9SAndroid Build Coastguard Worker                 name_fn=lambda input_dim, attn_dim, is_causal: (
964*da0073e9SAndroid Build Coastguard Worker                     f"{input_dim}D_input_dim_" + (
965*da0073e9SAndroid Build Coastguard Worker                         f"{attn_dim}D_{'causal_' if is_causal else ''}attn_mask"
966*da0073e9SAndroid Build Coastguard Worker                         if attn_dim is not None else "no_attn_mask")))
967*da0073e9SAndroid Build Coastguard Worker    @parametrize("dropout_p", [0.0, 0.2, 0.5])
968*da0073e9SAndroid Build Coastguard Worker    @sdpa_kernel(backends=[SDPBackend.MATH])
969*da0073e9SAndroid Build Coastguard Worker    def test_scaled_dot_product_attention(self, device, input_dim, attn_mask_dim, is_causal, dropout_p):
970*da0073e9SAndroid Build Coastguard Worker        def sdp_ref(
971*da0073e9SAndroid Build Coastguard Worker                q,
972*da0073e9SAndroid Build Coastguard Worker                k,
973*da0073e9SAndroid Build Coastguard Worker                v,
974*da0073e9SAndroid Build Coastguard Worker                attn_mask=None,
975*da0073e9SAndroid Build Coastguard Worker                dropout_p=0.0):
976*da0073e9SAndroid Build Coastguard Worker            E = q.size(-1)
977*da0073e9SAndroid Build Coastguard Worker            q = q / math.sqrt(E)
978*da0073e9SAndroid Build Coastguard Worker            # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
979*da0073e9SAndroid Build Coastguard Worker            if attn_mask is not None:
980*da0073e9SAndroid Build Coastguard Worker                attn = torch.baddbmm(attn_mask, q, k.transpose(-2, -1))
981*da0073e9SAndroid Build Coastguard Worker            else:
982*da0073e9SAndroid Build Coastguard Worker                attn = torch.bmm(q, k.transpose(-2, -1))
983*da0073e9SAndroid Build Coastguard Worker
984*da0073e9SAndroid Build Coastguard Worker            attn = torch.nn.functional.softmax(attn, dim=-1)
985*da0073e9SAndroid Build Coastguard Worker            if dropout_p > 0.0:
986*da0073e9SAndroid Build Coastguard Worker                attn = torch.nn.functional.dropout(attn, p=dropout_p)
987*da0073e9SAndroid Build Coastguard Worker            # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
988*da0073e9SAndroid Build Coastguard Worker            output = torch.bmm(attn, v)
989*da0073e9SAndroid Build Coastguard Worker            return output
990*da0073e9SAndroid Build Coastguard Worker        # TODO: Support cross-device / dtype testing properly when instantiate_device_type_tests() is used.
991*da0073e9SAndroid Build Coastguard Worker        dtypes = [torch.double, torch.float]
992*da0073e9SAndroid Build Coastguard Worker        for dtype in dtypes:
993*da0073e9SAndroid Build Coastguard Worker
994*da0073e9SAndroid Build Coastguard Worker            def rand_tensor(*shape):
995*da0073e9SAndroid Build Coastguard Worker                return torch.randn(shape, device=device, dtype=dtype)
996*da0073e9SAndroid Build Coastguard Worker
997*da0073e9SAndroid Build Coastguard Worker            # This test compares python and C++ implementations of SDP.
998*da0073e9SAndroid Build Coastguard Worker            N, N_prime, L, S, E = 5, 2, 4, 3, 6
999*da0073e9SAndroid Build Coastguard Worker            if input_dim == 3:
1000*da0073e9SAndroid Build Coastguard Worker                query = rand_tensor(N, L, E)
1001*da0073e9SAndroid Build Coastguard Worker                key = rand_tensor(N, S, E)
1002*da0073e9SAndroid Build Coastguard Worker                value = rand_tensor(N, S, E)
1003*da0073e9SAndroid Build Coastguard Worker            elif input_dim == 4:
1004*da0073e9SAndroid Build Coastguard Worker                query = rand_tensor(N, N_prime, L, E)
1005*da0073e9SAndroid Build Coastguard Worker                key = rand_tensor(N, N_prime, S, E)
1006*da0073e9SAndroid Build Coastguard Worker                value = rand_tensor(N, N_prime, S, E)
1007*da0073e9SAndroid Build Coastguard Worker            else:
1008*da0073e9SAndroid Build Coastguard Worker                self.fail(f'Invalid input_dim {input_dim} encountered in SDP test')
1009*da0073e9SAndroid Build Coastguard Worker
1010*da0073e9SAndroid Build Coastguard Worker            attn_mask = None
1011*da0073e9SAndroid Build Coastguard Worker            if attn_mask_dim is not None:
1012*da0073e9SAndroid Build Coastguard Worker                assert attn_mask_dim in [2, input_dim]
1013*da0073e9SAndroid Build Coastguard Worker                mask_size = (L, S) if attn_mask_dim == 2 else ((N, L, S) if input_dim == 3 else (N, N_prime, L, S))
1014*da0073e9SAndroid Build Coastguard Worker                attn_mask = (torch.ones(mask_size, device=device, dtype=torch.bool).tril() if is_causal
1015*da0073e9SAndroid Build Coastguard Worker                             else torch.randint(0, 2, size=mask_size, device=device, dtype=torch.bool))
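            # Boolean mask semantics: True = attend, False = mask out, so a
            # lower-triangular mask reproduces causal attention.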
1016*da0073e9SAndroid Build Coastguard Worker
1017*da0073e9SAndroid Build Coastguard Worker            with freeze_rng_state():
1018*da0073e9SAndroid Build Coastguard Worker                # Python impl only supports float mask and 3D inputs.
1019*da0073e9SAndroid Build Coastguard Worker                attn_mask_float = attn_mask
1020*da0073e9SAndroid Build Coastguard Worker                if attn_mask_float is not None:
1021*da0073e9SAndroid Build Coastguard Worker                    attn_mask_float = torch.zeros_like(attn_mask, dtype=query.dtype)
1022*da0073e9SAndroid Build Coastguard Worker                    attn_mask_float.masked_fill_(attn_mask.logical_not(), float("-inf"))
1023*da0073e9SAndroid Build Coastguard Worker                q, k, v = query.view(-1, L, E), key.view(-1, S, E), value.view(-1, S, E)
1024*da0073e9SAndroid Build Coastguard Worker                a = attn_mask_float
1025*da0073e9SAndroid Build Coastguard Worker                if a is not None and attn_mask_dim > 3:
1026*da0073e9SAndroid Build Coastguard Worker                    a = a.view(-1, L, S)
1027*da0073e9SAndroid Build Coastguard Worker                expected = sdp_ref(q, k, v, attn_mask=a, dropout_p=dropout_p)
1028*da0073e9SAndroid Build Coastguard Worker                if input_dim > 3:
1029*da0073e9SAndroid Build Coastguard Worker                    expected = expected.view(-1, N_prime, L, E)
1030*da0073e9SAndroid Build Coastguard Worker
1031*da0073e9SAndroid Build Coastguard Worker            with freeze_rng_state():
1032*da0073e9SAndroid Build Coastguard Worker                if is_causal:
1033*da0073e9SAndroid Build Coastguard Worker                    # NB: Don't pass attn_mask here
1034*da0073e9SAndroid Build Coastguard Worker                    actual = torch.nn.functional.scaled_dot_product_attention(
1035*da0073e9SAndroid Build Coastguard Worker                        query, key, value, None, dropout_p, is_causal)
1036*da0073e9SAndroid Build Coastguard Worker
1037*da0073e9SAndroid Build Coastguard Worker                    # Error case: both explicit attn_mask and is_causal are set
1038*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(RuntimeError,
1039*da0073e9SAndroid Build Coastguard Worker                                                "Explicit attn_mask should not be set when is_causal=True"):
1040*da0073e9SAndroid Build Coastguard Worker                        torch.nn.functional.scaled_dot_product_attention(
1041*da0073e9SAndroid Build Coastguard Worker                            query, key, value, attn_mask, dropout_p, is_causal)
1042*da0073e9SAndroid Build Coastguard Worker                else:
1043*da0073e9SAndroid Build Coastguard Worker                    actual = torch.nn.functional.scaled_dot_product_attention(
1044*da0073e9SAndroid Build Coastguard Worker                        query, key, value, attn_mask, dropout_p, is_causal)
1045*da0073e9SAndroid Build Coastguard Worker
1046*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual, expected)
1047*da0073e9SAndroid Build Coastguard Worker
1048*da0073e9SAndroid Build Coastguard Worker        if attn_mask_dim is None:
1049*da0073e9SAndroid Build Coastguard Worker            q = q.double().clone()
1050*da0073e9SAndroid Build Coastguard Worker            k = k.double().clone()
1051*da0073e9SAndroid Build Coastguard Worker            v = v.double().clone()
1052*da0073e9SAndroid Build Coastguard Worker            q.requires_grad_()
1053*da0073e9SAndroid Build Coastguard Worker            k.requires_grad_()
1054*da0073e9SAndroid Build Coastguard Worker            v.requires_grad_()
1055*da0073e9SAndroid Build Coastguard Worker
1056*da0073e9SAndroid Build Coastguard Worker            assert gradcheck(lambda *args, **kwargs: wrapper_set_seed(sdp_ref, *args, **kwargs),
1057*da0073e9SAndroid Build Coastguard Worker                             (q, k, v, attn_mask, dropout_p))
1058*da0073e9SAndroid Build Coastguard Worker            assert gradcheck(lambda *args, **kwargs:
1059*da0073e9SAndroid Build Coastguard Worker                             wrapper_set_seed(torch.nn.functional.scaled_dot_product_attention, *args, **kwargs),
1060*da0073e9SAndroid Build Coastguard Worker                             (q, k, v, attn_mask, dropout_p))
1061*da0073e9SAndroid Build Coastguard Worker
1062*da0073e9SAndroid Build Coastguard Worker        def test_incompatible_mask(self, device):
1063*da0073e9SAndroid Build Coastguard Worker            def ones_tensor(*shape):
1064*da0073e9SAndroid Build Coastguard Worker                return torch.ones(shape, dtype=torch.float32)
1065*da0073e9SAndroid Build Coastguard Worker            S, L, E, H = 1, 2, 4, 1
1066*da0073e9SAndroid Build Coastguard Worker            qkv = ones_tensor(S, L, E)
1067*da0073e9SAndroid Build Coastguard Worker
1068*da0073e9SAndroid Build Coastguard Worker            mha = nn.MultiheadAttention(E, H)
1069*da0073e9SAndroid Build Coastguard Worker            mha.in_proj_weight = Parameter(torch.ones((E * 3, E)))
1070*da0073e9SAndroid Build Coastguard Worker            mha.out_proj.weight = Parameter(torch.ones((E, E)))
1071*da0073e9SAndroid Build Coastguard Worker            qkv = qkv.to(float)
1072*da0073e9SAndroid Build Coastguard Worker            kpm = ones_tensor(S, L) * float("-inf")
1073*da0073e9SAndroid Build Coastguard Worker            am = ones_tensor(L, L).to(bool)
1074*da0073e9SAndroid Build Coastguard Worker
1075*da0073e9SAndroid Build Coastguard Worker            def func():
1076*da0073e9SAndroid Build Coastguard Worker                return mha(qkv, qkv, qkv, need_weights=False, key_padding_mask=kpm, attn_mask=am)
1077*da0073e9SAndroid Build Coastguard Worker
1078*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(RuntimeError, func)
1079*da0073e9SAndroid Build Coastguard Worker
1080*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(TEST_WITH_CROSSREF, 'Fastpath not available with crossref')
1081*da0073e9SAndroid Build Coastguard Worker    @torch.no_grad()
1082*da0073e9SAndroid Build Coastguard Worker    def test_mask_check_fastpath(self):
1083*da0073e9SAndroid Build Coastguard Worker        """
1084*da0073e9SAndroid Build Coastguard Worker        Test that the fastpath is executed regardless of which masks are passed.
1085*da0073e9SAndroid Build Coastguard Worker        If the passed key padding mask is left-aligned or mask_check=False, verify that nested tensors are used
1086*da0073e9SAndroid Build Coastguard Worker        (sparsity fastpath); otherwise the fastpath runs with regular tensors.
1087*da0073e9SAndroid Build Coastguard Worker        Also verify that the fastpath is executed when both a key padding mask and an attention mask are passed.
1088*da0073e9SAndroid Build Coastguard Worker        """
1089*da0073e9SAndroid Build Coastguard Worker
1090*da0073e9SAndroid Build Coastguard Worker        x = torch.Tensor([[[1, 2], [3, 4], [5, 6]]]).to(torch.float)
1091*da0073e9SAndroid Build Coastguard Worker
1092*da0073e9SAndroid Build Coastguard Worker        def _test_fastpath(model, key_padding_mask, mock_return_value, attn_mask=None, nested_tensors=True):
1093*da0073e9SAndroid Build Coastguard Worker            with patch('torch._transformer_encoder_layer_fwd') as fastpath_mock:
1094*da0073e9SAndroid Build Coastguard Worker                fastpath_mock.return_value = mock_return_value
1095*da0073e9SAndroid Build Coastguard Worker                model(x, src_key_padding_mask=key_padding_mask, mask=attn_mask)
1096*da0073e9SAndroid Build Coastguard Worker
1097*da0073e9SAndroid Build Coastguard Worker                # If mock was called, fastpath was taken
1098*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(fastpath_mock.called)
1099*da0073e9SAndroid Build Coastguard Worker
1100*da0073e9SAndroid Build Coastguard Worker                # If mock was called with nested tensors, sparsity fastpath was taken
1101*da0073e9SAndroid Build Coastguard Worker                for call_args, _ in fastpath_mock.call_args_list:
1102*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(call_args[0].is_nested, nested_tensors)
1103*da0073e9SAndroid Build Coastguard Worker
1104*da0073e9SAndroid Build Coastguard Worker        encoder_layer = torch.nn.TransformerEncoderLayer(d_model=2, nhead=2, dim_feedforward=8, batch_first=True)
1105*da0073e9SAndroid Build Coastguard Worker
1106*da0073e9SAndroid Build Coastguard Worker        model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=True, mask_check=True)
1107*da0073e9SAndroid Build Coastguard Worker        model.eval()
1108*da0073e9SAndroid Build Coastguard Worker
1109*da0073e9SAndroid Build Coastguard Worker        aligned_key_padding_mask = torch.Tensor([[0, 0, 1]]).to(torch.bool)
1110*da0073e9SAndroid Build Coastguard Worker        not_aligned_key_padding_mask = torch.Tensor([[1, 0, 1]]).to(torch.bool)
1111*da0073e9SAndroid Build Coastguard Worker        attn_mask = torch.Tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]]).to(torch.bool)
1112*da0073e9SAndroid Build Coastguard Worker        nested_tensor_return_value = torch.nested.nested_tensor([torch.ones((2, 2), dtype=torch.float)])
1113*da0073e9SAndroid Build Coastguard Worker        tensor_return_value = torch.ones((1, 3, 2), dtype=torch.float)
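        # Placeholder values returned by the mocked fastpath; the assertions only check
        # whether the mock was called and whether its inputs were nested.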
1114*da0073e9SAndroid Build Coastguard Worker
1115*da0073e9SAndroid Build Coastguard Worker        # A left-aligned mask results in the sparsity (nested tensor) fastpath
1116*da0073e9SAndroid Build Coastguard Worker        _test_fastpath(model, aligned_key_padding_mask, nested_tensor_return_value, nested_tensors=True)
1117*da0073e9SAndroid Build Coastguard Worker
1118*da0073e9SAndroid Build Coastguard Worker        # A mask that is not left-aligned results in the regular fastpath
1119*da0073e9SAndroid Build Coastguard Worker        _test_fastpath(model, not_aligned_key_padding_mask, tensor_return_value, nested_tensors=False)
1120*da0073e9SAndroid Build Coastguard Worker
1121*da0073e9SAndroid Build Coastguard Worker        model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=False, mask_check=True)
1122*da0073e9SAndroid Build Coastguard Worker        model.eval()
1123*da0073e9SAndroid Build Coastguard Worker
1124*da0073e9SAndroid Build Coastguard Worker        # If nested tensors are disabled, the regular fastpath is always taken
1125*da0073e9SAndroid Build Coastguard Worker        _test_fastpath(model, aligned_key_padding_mask, tensor_return_value, nested_tensors=False)
1126*da0073e9SAndroid Build Coastguard Worker        _test_fastpath(model, not_aligned_key_padding_mask, tensor_return_value, nested_tensors=False)
1127*da0073e9SAndroid Build Coastguard Worker        # The fastpath is taken if both an attention mask and a key padding mask are present
1128*da0073e9SAndroid Build Coastguard Worker        _test_fastpath(model, aligned_key_padding_mask, tensor_return_value, attn_mask=attn_mask, nested_tensors=False)
1129*da0073e9SAndroid Build Coastguard Worker
1130*da0073e9SAndroid Build Coastguard Worker        model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=True, mask_check=False)
1131*da0073e9SAndroid Build Coastguard Worker        model.eval()
1132*da0073e9SAndroid Build Coastguard Worker
1133*da0073e9SAndroid Build Coastguard Worker        # Disabling the mask check results in the sparsity fastpath, regardless of the mask
1134*da0073e9SAndroid Build Coastguard Worker        _test_fastpath(model, aligned_key_padding_mask, nested_tensor_return_value, nested_tensors=True)
1135*da0073e9SAndroid Build Coastguard Worker        _test_fastpath(model, not_aligned_key_padding_mask, nested_tensor_return_value, nested_tensors=True)
1136*da0073e9SAndroid Build Coastguard Worker
1137*da0073e9SAndroid Build Coastguard Worker    # Test failing MHA when bias was NoneType
1138*da0073e9SAndroid Build Coastguard Worker    def test_bias_is_none(self):
1139*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((1, 5, 10))
1140*da0073e9SAndroid Build Coastguard Worker        model = torch.nn.modules.activation.MultiheadAttention(10, 1, bias=False, batch_first=True)
1141*da0073e9SAndroid Build Coastguard Worker        model.eval()
1142*da0073e9SAndroid Build Coastguard Worker        model(x, x, x)
1143*da0073e9SAndroid Build Coastguard Worker        # completes without error
1144*da0073e9SAndroid Build Coastguard Worker
1145*da0073e9SAndroid Build Coastguard Worker    def test_transformer_bias_is_none(self, device):
1146*da0073e9SAndroid Build Coastguard Worker        batch_size = 2
1147*da0073e9SAndroid Build Coastguard Worker        seqlen = 3
1148*da0073e9SAndroid Build Coastguard Worker        d_model = 8
1149*da0073e9SAndroid Build Coastguard Worker        nhead = 4
1150*da0073e9SAndroid Build Coastguard Worker
1151*da0073e9SAndroid Build Coastguard Worker        encoder_layer = torch.nn.TransformerEncoderLayer(d_model, nhead, bias=False, batch_first=True, device=device)
1152*da0073e9SAndroid Build Coastguard Worker        encoder_layer.eval()
1153*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(batch_size, seqlen, d_model, device=device)
1154*da0073e9SAndroid Build Coastguard Worker        # runs without error
1155*da0073e9SAndroid Build Coastguard Worker        encoder_layer(x)
1156*da0073e9SAndroid Build Coastguard Worker
1157*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsRegex(UserWarning, "encoder_layer.self_attn was passed bias=False"):
1158*da0073e9SAndroid Build Coastguard Worker            encoder = torch.nn.TransformerEncoder(encoder_layer, num_layers=1).eval()
1159*da0073e9SAndroid Build Coastguard Worker            encoder(x)
1160*da0073e9SAndroid Build Coastguard Worker
1161*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsRegex(UserWarning, "self_attn was passed bias=False"):
1162*da0073e9SAndroid Build Coastguard Worker            transformer = torch.nn.Transformer(
1163*da0073e9SAndroid Build Coastguard Worker                d_model=d_model, nhead=nhead, bias=False, batch_first=True, device=device
1164*da0073e9SAndroid Build Coastguard Worker            ).eval()
1165*da0073e9SAndroid Build Coastguard Worker            transformer(x, x)
1166*da0073e9SAndroid Build Coastguard Worker
1167*da0073e9SAndroid Build Coastguard Worker    def test_train_with_is_causal(self, device):
1168*da0073e9SAndroid Build Coastguard Worker        # training with is_causal
1169*da0073e9SAndroid Build Coastguard Worker        S, L, E, H = 1, 2, 2, 1
1170*da0073e9SAndroid Build Coastguard Worker        layer = nn.TransformerEncoderLayer(
1171*da0073e9SAndroid Build Coastguard Worker            d_model=2,
1172*da0073e9SAndroid Build Coastguard Worker            dim_feedforward=4,
1173*da0073e9SAndroid Build Coastguard Worker            nhead=H,
1174*da0073e9SAndroid Build Coastguard Worker            batch_first=True,
1175*da0073e9SAndroid Build Coastguard Worker            activation="gelu",
1176*da0073e9SAndroid Build Coastguard Worker            dropout=0,
1177*da0073e9SAndroid Build Coastguard Worker        )
1178*da0073e9SAndroid Build Coastguard Worker        criterion = nn.MSELoss()
1179*da0073e9SAndroid Build Coastguard Worker        encoder = nn.TransformerEncoder(layer, 2).to(device)
1180*da0073e9SAndroid Build Coastguard Worker        optimizer = optim.SGD(encoder.parameters(), lr=0.1, momentum=0.9)
1181*da0073e9SAndroid Build Coastguard Worker        encoder.train()
1182*da0073e9SAndroid Build Coastguard Worker
1184*da0073e9SAndroid Build Coastguard Worker        optimizer.zero_grad()
1185*da0073e9SAndroid Build Coastguard Worker        inputs = torch.randn(S, L, E).to(device)
1186*da0073e9SAndroid Build Coastguard Worker        mask = torch.nn.Transformer.generate_square_subsequent_mask(
1187*da0073e9SAndroid Build Coastguard Worker            inputs.size(1), device=device
1188*da0073e9SAndroid Build Coastguard Worker        )
1189*da0073e9SAndroid Build Coastguard Worker
1190*da0073e9SAndroid Build Coastguard Worker        outputs = encoder(inputs, mask=mask, is_causal=True)
1191*da0073e9SAndroid Build Coastguard Worker
1192*da0073e9SAndroid Build Coastguard Worker        loss = criterion(outputs[:, 0:2, :], inputs[:, 0:2, :])
1193*da0073e9SAndroid Build Coastguard Worker        loss.backward()
1194*da0073e9SAndroid Build Coastguard Worker        optimizer.step()
1195*da0073e9SAndroid Build Coastguard Worker
1196*da0073e9SAndroid Build Coastguard Worker        # inference with is_causal
1197*da0073e9SAndroid Build Coastguard Worker        t_qvk = torch.randn((S, L, E), device=device, dtype=torch.float32)
1198*da0073e9SAndroid Build Coastguard Worker        mha = nn.MultiheadAttention(E, H).to(device)
1199*da0073e9SAndroid Build Coastguard Worker        mask = torch.nn.Transformer.generate_square_subsequent_mask(
1200*da0073e9SAndroid Build Coastguard Worker            S, device=device
1201*da0073e9SAndroid Build Coastguard Worker        )
1202*da0073e9SAndroid Build Coastguard Worker
1203*da0073e9SAndroid Build Coastguard Worker        attn_out, _ = mha(t_qvk, t_qvk, t_qvk, attn_mask=mask, is_causal=True)
1204*da0073e9SAndroid Build Coastguard Worker
1205*da0073e9SAndroid Build Coastguard Worker        # Passing is_causal=True without an attn_mask is an error
1206*da0073e9SAndroid Build Coastguard Worker        attn_mask = torch.randint(0, 2, size=(L, L), device=device, dtype=torch.bool)
1207*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
1208*da0073e9SAndroid Build Coastguard Worker            _ = mha(t_qvk, t_qvk, t_qvk, is_causal=True)
1209*da0073e9SAndroid Build Coastguard Worker
1210*da0073e9SAndroid Build Coastguard Worker        # Passing a causal mask sets is_causal to True
1211*da0073e9SAndroid Build Coastguard Worker        causal_mask = torch.triu(
1212*da0073e9SAndroid Build Coastguard Worker            torch.ones(L, L, device=inputs.device) * float('-inf'), diagonal=1
1213*da0073e9SAndroid Build Coastguard Worker        ).to(torch.bool)
1214*da0073e9SAndroid Build Coastguard Worker
1215*da0073e9SAndroid Build Coastguard Worker        mock_layer = MagicMock(torch.nn.MultiheadAttention(E, H), return_value=inputs)
1216*da0073e9SAndroid Build Coastguard Worker        encoder.layers[1] = mock_layer
1217*da0073e9SAndroid Build Coastguard Worker        outputs = encoder(inputs, mask=causal_mask)
1218*da0073e9SAndroid Build Coastguard Worker        mock_layer.assert_called_with(ANY, src_mask=ANY, is_causal=True, src_key_padding_mask=ANY)
1219*da0073e9SAndroid Build Coastguard Worker
1220*da0073e9SAndroid Build Coastguard Worker        # check expected numerical values with all kernels
1221*da0073e9SAndroid Build Coastguard Worker        self.is_causal_kernels([SDPBackend.MATH], device)
1222*da0073e9SAndroid Build Coastguard Worker
1223*da0073e9SAndroid Build Coastguard Worker    def is_causal_kernels(self, kernels, device):
1224*da0073e9SAndroid Build Coastguard Worker        def ones_tensor(*shape):
1225*da0073e9SAndroid Build Coastguard Worker            return torch.ones(shape, device=device, dtype=torch.float32)
1226*da0073e9SAndroid Build Coastguard Worker        S, L, E, H = 1, 2, 4, 1
1227*da0073e9SAndroid Build Coastguard Worker        qkv = ones_tensor(S, L, E)
1228*da0073e9SAndroid Build Coastguard Worker
1229*da0073e9SAndroid Build Coastguard Worker        mha = nn.MultiheadAttention(E, H).to(device)
1230*da0073e9SAndroid Build Coastguard Worker        mha.in_proj_weight = Parameter(torch.ones((E * 3, E), device=device))
1231*da0073e9SAndroid Build Coastguard Worker        mha.out_proj.weight = Parameter(torch.ones((E, E), device=device))
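        # With all-ones inputs (E=4), all-ones projection weights and zero-initialized
        # biases, every output element works out to 4 * 4 = 16.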
1232*da0073e9SAndroid Build Coastguard Worker        expected = torch.ones(size=(S, L, E)).to(device) * 16
1233*da0073e9SAndroid Build Coastguard Worker        mask = torch.nn.Transformer.generate_square_subsequent_mask(
1234*da0073e9SAndroid Build Coastguard Worker            qkv.size(1), device=device
1235*da0073e9SAndroid Build Coastguard Worker        )
1236*da0073e9SAndroid Build Coastguard Worker
1237*da0073e9SAndroid Build Coastguard Worker        for kernel in kernels:
1238*da0073e9SAndroid Build Coastguard Worker            with sdpa_kernel(backends=[kernel]):
1239*da0073e9SAndroid Build Coastguard Worker                actual, _ = mha(qkv, qkv, qkv, attn_mask=mask, need_weights=False, is_causal=True)
1240*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(torch.equal(actual, expected))
1241*da0073e9SAndroid Build Coastguard Worker
1242*da0073e9SAndroid Build Coastguard Worker                if kernel != SDPBackend.MATH:
1243*da0073e9SAndroid Build Coastguard Worker                    # fails with embedding size not multiple of 4
1244*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(RuntimeError, "No available kernel"):
1245*da0073e9SAndroid Build Coastguard Worker                        qkv_f, mha_f = ones_tensor(S, L, 2), nn.MultiheadAttention(2, H).to(device)
1246*da0073e9SAndroid Build Coastguard Worker                        mask = torch.nn.Transformer.generate_square_subsequent_mask(
1247*da0073e9SAndroid Build Coastguard Worker                            qkv_f.size(1), device=device
1248*da0073e9SAndroid Build Coastguard Worker                        )
1249*da0073e9SAndroid Build Coastguard Worker                        _ = mha_f(qkv_f, qkv_f, qkv_f, attn_mask=mask, need_weights=False, is_causal=True)
1250*da0073e9SAndroid Build Coastguard Worker                        torch.cuda.synchronize()
1251*da0073e9SAndroid Build Coastguard Worker
1252*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm  # Missing EFFICIENT_ATTENTION
1253*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
1254*da0073e9SAndroid Build Coastguard Worker        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Platform does not support fused SDPA or is pre-SM80 hardware"
1255*da0073e9SAndroid Build Coastguard Worker    )
1256*da0073e9SAndroid Build Coastguard Worker    def test_is_causal_gpu(self):
1257*da0073e9SAndroid Build Coastguard Worker        device = 'cuda'
1258*da0073e9SAndroid Build Coastguard Worker        self.is_causal_kernels([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION], device)
1259*da0073e9SAndroid Build Coastguard Worker
1260*da0073e9SAndroid Build Coastguard Worker    def test_script_mha_in_proj_weight_none(self):
1261*da0073e9SAndroid Build Coastguard Worker        mha = torch.nn.MultiheadAttention(
1262*da0073e9SAndroid Build Coastguard Worker            embed_dim=128, num_heads=8, kdim=256, vdim=256
1263*da0073e9SAndroid Build Coastguard Worker        ).eval()
1264*da0073e9SAndroid Build Coastguard Worker
1265*da0073e9SAndroid Build Coastguard Worker        torch.jit.script(mha)
1266*da0073e9SAndroid Build Coastguard Worker
1267*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(TEST_WITH_CROSSREF, 'Fastpath not available with crossref')
1268*da0073e9SAndroid Build Coastguard Worker    @torch.no_grad()
1269*da0073e9SAndroid Build Coastguard Worker    def test_disable_fastpath(self, device):
1270*da0073e9SAndroid Build Coastguard Worker        def _test_te_fastpath_called(model, args, kwargs=None, return_value=None, is_called=True):
1271*da0073e9SAndroid Build Coastguard Worker            if kwargs is None:
1272*da0073e9SAndroid Build Coastguard Worker                kwargs = {}
1273*da0073e9SAndroid Build Coastguard Worker            with patch('torch._transformer_encoder_layer_fwd') as fastpath_mock:
1274*da0073e9SAndroid Build Coastguard Worker                fastpath_mock.return_value = return_value
1275*da0073e9SAndroid Build Coastguard Worker                output = model(*args, **kwargs)
1276*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(fastpath_mock.called == is_called)
1277*da0073e9SAndroid Build Coastguard Worker
1278*da0073e9SAndroid Build Coastguard Worker        def _test_mha_fastpath_called(model, args, kwargs=None, return_value=None, is_called=True):
1279*da0073e9SAndroid Build Coastguard Worker            if kwargs is None:
1280*da0073e9SAndroid Build Coastguard Worker                kwargs = {}
1281*da0073e9SAndroid Build Coastguard Worker            with patch('torch._native_multi_head_attention') as fastpath_mock:
1282*da0073e9SAndroid Build Coastguard Worker                fastpath_mock.return_value = return_value
1283*da0073e9SAndroid Build Coastguard Worker                output = model(*args, **kwargs)
1284*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(fastpath_mock.called == is_called)
1285*da0073e9SAndroid Build Coastguard Worker
1286*da0073e9SAndroid Build Coastguard Worker        inp = torch.tensor([[[1, 2], [3, 4], [5, 6]]], dtype=torch.float32, device=device)
1287*da0073e9SAndroid Build Coastguard Worker        aligned_key_padding_mask = torch.tensor([[0, 0, 1]], dtype=torch.bool, device=device)
1288*da0073e9SAndroid Build Coastguard Worker        src_key_padding_mask = torch.tensor([[1, 0, 1]], dtype=torch.bool, device=device)
1289*da0073e9SAndroid Build Coastguard Worker        attn_mask = torch.tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]], dtype=torch.bool, device=device)
1290*da0073e9SAndroid Build Coastguard Worker        te_return_value = torch.ones((1, 3, 2), dtype=torch.float32)
1291*da0073e9SAndroid Build Coastguard Worker
1292*da0073e9SAndroid Build Coastguard Worker        encoder_layer = torch.nn.TransformerEncoderLayer(d_model=2, nhead=2, dim_feedforward=8, batch_first=True)
1293*da0073e9SAndroid Build Coastguard Worker        te = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=True, mask_check=True)
1294*da0073e9SAndroid Build Coastguard Worker        te = te.to(device).eval()
1295*da0073e9SAndroid Build Coastguard Worker
1296*da0073e9SAndroid Build Coastguard Worker        t = torch.nn.Transformer(d_model=2, nhead=2, batch_first=True, device=device).eval()
1297*da0073e9SAndroid Build Coastguard Worker        src = torch.tensor([[[0, 1], [2, 3], [4, 5]]], dtype=torch.float32, device=device)
1298*da0073e9SAndroid Build Coastguard Worker        tgt = torch.tensor([[[0, 1], [2, 3], [4, 5], [6, 7]]], dtype=torch.float32, device=device)
1299*da0073e9SAndroid Build Coastguard Worker        t_return_value = torch.ones((1, 3, 2), dtype=torch.float32, device=device)
1300*da0073e9SAndroid Build Coastguard Worker
1301*da0073e9SAndroid Build Coastguard Worker        mha = nn.MultiheadAttention(2, 2, batch_first=True, device=device).eval()
1302*da0073e9SAndroid Build Coastguard Worker        q = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.float32, device=device)
1303*da0073e9SAndroid Build Coastguard Worker        mha_return_value = torch.ones((1, 3, 2), dtype=torch.float32, device=device)
1304*da0073e9SAndroid Build Coastguard Worker
1305*da0073e9SAndroid Build Coastguard Worker        _test_te_fastpath_called(
1306*da0073e9SAndroid Build Coastguard Worker            te, (inp,), kwargs={'src_key_padding_mask': src_key_padding_mask},
1307*da0073e9SAndroid Build Coastguard Worker            return_value=te_return_value, is_called=True
1308*da0073e9SAndroid Build Coastguard Worker        )
1309*da0073e9SAndroid Build Coastguard Worker        _test_te_fastpath_called(t, (src, tgt), return_value=t_return_value, is_called=True)
1310*da0073e9SAndroid Build Coastguard Worker        _test_mha_fastpath_called(mha, (q, q, q,), return_value=mha_return_value, is_called=True)
1311*da0073e9SAndroid Build Coastguard Worker
1312*da0073e9SAndroid Build Coastguard Worker        torch.backends.mha.set_fastpath_enabled(False)
1313*da0073e9SAndroid Build Coastguard Worker        _test_te_fastpath_called(
1314*da0073e9SAndroid Build Coastguard Worker            te, (inp,), kwargs={'src_key_padding_mask': src_key_padding_mask},
1315*da0073e9SAndroid Build Coastguard Worker            return_value=te_return_value, is_called=False
1316*da0073e9SAndroid Build Coastguard Worker        )
1317*da0073e9SAndroid Build Coastguard Worker        _test_te_fastpath_called(t, (src, tgt), return_value=t_return_value, is_called=False)
1318*da0073e9SAndroid Build Coastguard Worker        _test_mha_fastpath_called(mha, (q, q, q,), return_value=mha_return_value, is_called=False)
1319*da0073e9SAndroid Build Coastguard Worker
1320*da0073e9SAndroid Build Coastguard Worker        torch.backends.mha.set_fastpath_enabled(True)
1321*da0073e9SAndroid Build Coastguard Worker        _test_te_fastpath_called(
1322*da0073e9SAndroid Build Coastguard Worker            te, (inp,), kwargs={'src_key_padding_mask': src_key_padding_mask},
1323*da0073e9SAndroid Build Coastguard Worker            return_value=te_return_value, is_called=True
1324*da0073e9SAndroid Build Coastguard Worker        )
1325*da0073e9SAndroid Build Coastguard Worker        _test_te_fastpath_called(t, (src, tgt), return_value=t_return_value, is_called=True)
1326*da0073e9SAndroid Build Coastguard Worker        _test_mha_fastpath_called(mha, (q, q, q,), return_value=mha_return_value, is_called=True)
1327*da0073e9SAndroid Build Coastguard Worker
1328*da0073e9SAndroid Build Coastguard Worker
1329*da0073e9SAndroid Build Coastguard Workerclass TestSDPAFailureModes(NNTestCase):
1330*da0073e9SAndroid Build Coastguard Worker    """ Used to test the failure modes of scaled_dot_product_attention
1331*da0073e9SAndroid Build Coastguard Worker    """
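    # A minimal illustrative sketch (not an additional test) of the pattern the
    # tests below follow: pin scaled_dot_product_attention to a single backend
    # with sdpa_kernel and assert that unsupported inputs raise. Tensor names
    # and shapes here are placeholders.
    #
    #     with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
    #         self.assertRaises(
    #             RuntimeError,
    #             lambda: F.scaled_dot_product_attention(q, k, v))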
1332*da0073e9SAndroid Build Coastguard Worker    _do_cuda_memory_leak_check = True
1333*da0073e9SAndroid Build Coastguard Worker    _do_cuda_non_default_stream = True
1334*da0073e9SAndroid Build Coastguard Worker
1335*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1336*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
1337*da0073e9SAndroid Build Coastguard Worker        not PLATFORM_SUPPORTS_FLASH_ATTENTION or not isSM8XDevice,
1338*da0073e9SAndroid Build Coastguard Worker        "Does not support fused SDPA or is not SM86+ hardware",
1339*da0073e9SAndroid Build Coastguard Worker    )
1340*da0073e9SAndroid Build Coastguard Worker    @parametrize("head_dim", [193, 204, 256])
1341*da0073e9SAndroid Build Coastguard Worker    @parametrize("dropout_p", [0.0, 0.2])
1342*da0073e9SAndroid Build Coastguard Worker    def test_flash_backward_failure_sm86plus(self, device, head_dim: int, dropout_p: float):
1343*da0073e9SAndroid Build Coastguard Worker        dtype = torch.float16
1344*da0073e9SAndroid Build Coastguard Worker        make_tensor = partial(torch.rand, device=device, dtype=dtype)
1345*da0073e9SAndroid Build Coastguard Worker        # See check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89 in
1346*da0073e9SAndroid Build Coastguard Worker        # pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.h
1347*da0073e9SAndroid Build Coastguard Worker        size = (2, 2, 4, head_dim)
1348*da0073e9SAndroid Build Coastguard Worker        q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
1349*da0073e9SAndroid Build Coastguard Worker
1350*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[SDPBackend.MATH]):
1351*da0073e9SAndroid Build Coastguard Worker            math_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False)
1352*da0073e9SAndroid Build Coastguard Worker
1353*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
1354*da0073e9SAndroid Build Coastguard Worker            # Should not fail because inputs don't require grad
1355*da0073e9SAndroid Build Coastguard Worker            flash_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False)
1356*da0073e9SAndroid Build Coastguard Worker
1357*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(math_ref, flash_ref, atol=1e-3, rtol=1e-3)
1358*da0073e9SAndroid Build Coastguard Worker
1359*da0073e9SAndroid Build Coastguard Worker            # Should fail because inputs require grad
1360*da0073e9SAndroid Build Coastguard Worker            q = make_tensor(size, requires_grad=True)
1361*da0073e9SAndroid Build Coastguard Worker            k = make_tensor(size, requires_grad=True)
1362*da0073e9SAndroid Build Coastguard Worker            v = make_tensor(size, requires_grad=True)
1363*da0073e9SAndroid Build Coastguard Worker            if 192 < head_dim <= 224 or (head_dim > 224 and dropout_p != 0.0):
1364*da0073e9SAndroid Build Coastguard Worker                self.assertRaises(
1365*da0073e9SAndroid Build Coastguard Worker                    RuntimeError,
1366*da0073e9SAndroid Build Coastguard Worker                    lambda: torch.nn.functional.scaled_dot_product_attention(
1367*da0073e9SAndroid Build Coastguard Worker                        q, k, v, None, dropout_p, False
1368*da0073e9SAndroid Build Coastguard Worker                    ),
1369*da0073e9SAndroid Build Coastguard Worker                )
1370*da0073e9SAndroid Build Coastguard Worker            else:
1371*da0073e9SAndroid Build Coastguard Worker                flash_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, None, dropout_p, False)
1372*da0073e9SAndroid Build Coastguard Worker
1373*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1374*da0073e9SAndroid Build Coastguard Worker    def test_dispatch_fails_no_backend(self, device):
1375*da0073e9SAndroid Build Coastguard Worker        dtype = torch.float16
1376*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[SDPBackend.ERROR]):
1377*da0073e9SAndroid Build Coastguard Worker            size = (2, 3, 4)
1378*da0073e9SAndroid Build Coastguard Worker            q = torch.randn(size, device=device, dtype=dtype)
1379*da0073e9SAndroid Build Coastguard Worker            k = torch.randn(size, device=device, dtype=dtype)
1380*da0073e9SAndroid Build Coastguard Worker            v = torch.randn(size, device=device, dtype=dtype)
1381*da0073e9SAndroid Build Coastguard Worker            self.assertRaisesRegex(RuntimeError, "No viable backend for scaled_dot_product_attention was found.",
1382*da0073e9SAndroid Build Coastguard Worker                                   lambda: torch._fused_sdp_choice(q, k, v))
1383*da0073e9SAndroid Build Coastguard Worker            self.assertRaisesRegex(RuntimeError, "No viable backend for scaled_dot_product_attention was found.",
1384*da0073e9SAndroid Build Coastguard Worker                                   lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v))
1385*da0073e9SAndroid Build Coastguard Worker
1386*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1387*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
1388*da0073e9SAndroid Build Coastguard Worker    @parametrize(
1389*da0073e9SAndroid Build Coastguard Worker        "kernel",
1390*da0073e9SAndroid Build Coastguard Worker        PLATFORM_SPECIFIC_SDPA,
1391*da0073e9SAndroid Build Coastguard Worker    )
1392*da0073e9SAndroid Build Coastguard Worker    def test_invalid_fused_inputs_dim_3(self, device, kernel: SDPBackend):
1393*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[kernel]):
1394*da0073e9SAndroid Build Coastguard Worker            # Dim is not 4
1395*da0073e9SAndroid Build Coastguard Worker            size = (2, 3, 8)
1396*da0073e9SAndroid Build Coastguard Worker            dtype = torch.float16
1397*da0073e9SAndroid Build Coastguard Worker            q = torch.randn(size, device=device, dtype=dtype)
1398*da0073e9SAndroid Build Coastguard Worker            k = torch.randn(size, device=device, dtype=dtype)
1399*da0073e9SAndroid Build Coastguard Worker            v = torch.randn(size, device=device, dtype=dtype)
1400*da0073e9SAndroid Build Coastguard Worker            with self.assertWarnsRegex(UserWarning, "Both fused kernels requires query, key and value to be 4 dimensional"):
1401*da0073e9SAndroid Build Coastguard Worker                self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1402*da0073e9SAndroid Build Coastguard Worker                    q, k, v, None, 0.0, False))
1403*da0073e9SAndroid Build Coastguard Worker
1404*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1405*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
1406*da0073e9SAndroid Build Coastguard Worker    @parametrize(
1407*da0073e9SAndroid Build Coastguard Worker        "kernel",
1408*da0073e9SAndroid Build Coastguard Worker        PLATFORM_SPECIFIC_SDPA,
1409*da0073e9SAndroid Build Coastguard Worker    )
1410*da0073e9SAndroid Build Coastguard Worker    def test_invalid_fused_inputs_broadcast(self, device, kernel: SDPBackend):
1411*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[kernel]):
1412*da0073e9SAndroid Build Coastguard Worker            # Fused kernels don't support broadcasting for dense inputs
1413*da0073e9SAndroid Build Coastguard Worker            dtype = torch.float16
1414*da0073e9SAndroid Build Coastguard Worker            size = (2, 4, 3, 8)
1415*da0073e9SAndroid Build Coastguard Worker            size_broadcast = (1, 4, 3, 8)
1416*da0073e9SAndroid Build Coastguard Worker            q = torch.randn(size_broadcast, device=device, dtype=dtype)
1417*da0073e9SAndroid Build Coastguard Worker            k = torch.randn(size, device=device, dtype=dtype)
1418*da0073e9SAndroid Build Coastguard Worker            v = torch.randn(size, device=device, dtype=dtype)
1419*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1420*da0073e9SAndroid Build Coastguard Worker                q, k, v, None, 0.0, False))
1421*da0073e9SAndroid Build Coastguard Worker
1422*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1423*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
1424*da0073e9SAndroid Build Coastguard Worker    @parametrize("kernel", PLATFORM_SPECIFIC_SDPA)
1425*da0073e9SAndroid Build Coastguard Worker    def test_invalid_sequence_lengths(self, device, kernel: SDPBackend):
1426*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[kernel]):
1427*da0073e9SAndroid Build Coastguard Worker            # Passing in q, k, v with zero-length sequences will error
1428*da0073e9SAndroid Build Coastguard Worker            dtype = torch.float16
1429*da0073e9SAndroid Build Coastguard Worker            make_tensor = partial(torch.rand, device=device, dtype=dtype)
1430*da0073e9SAndroid Build Coastguard Worker            size = SdpaShape(2, 2, 0, 8)
1431*da0073e9SAndroid Build Coastguard Worker            q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
1432*da0073e9SAndroid Build Coastguard Worker            with self.assertWarnsRegex(UserWarning, "Both fused kernels do not support zero seq_len_q or seq_len_kv."):
1433*da0073e9SAndroid Build Coastguard Worker                self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1434*da0073e9SAndroid Build Coastguard Worker                    q, k, v, None, 0.0, False))
1435*da0073e9SAndroid Build Coastguard Worker
1436*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1437*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
1438*da0073e9SAndroid Build Coastguard Worker    @parametrize("kernel", PLATFORM_SPECIFIC_SDPA)
1439*da0073e9SAndroid Build Coastguard Worker    def test_invalid_last_dim_stride(self, device, kernel: SDPBackend):
1440*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[kernel]):
1441*da0073e9SAndroid Build Coastguard Worker            # Passing in q, k, v with a last-dim stride not equal to 1 will error
1442*da0073e9SAndroid Build Coastguard Worker            dtype = torch.float16
1443*da0073e9SAndroid Build Coastguard Worker            make_tensor = partial(torch.rand, device=device, dtype=dtype)
1444*da0073e9SAndroid Build Coastguard Worker            size = SdpaShape(2, 2, 8, 8)
1445*da0073e9SAndroid Build Coastguard Worker            q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
1446*da0073e9SAndroid Build Coastguard Worker            q.as_strided_(size, [2, 2, 2, 2])
1447*da0073e9SAndroid Build Coastguard Worker            with self.assertWarnsRegex(UserWarning, "Both fused kernels require the last dimension of the input to have stride 1."):
1448*da0073e9SAndroid Build Coastguard Worker                self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1449*da0073e9SAndroid Build Coastguard Worker                    q, k, v, None, 0.0, False))
1450*da0073e9SAndroid Build Coastguard Worker
1451*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1452*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention")
1453*da0073e9SAndroid Build Coastguard Worker    @parametrize("kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
1454*da0073e9SAndroid Build Coastguard Worker    def test_invalid_fused_inputs_head_dim(self, device, kernel: SDPBackend):
1455*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[kernel]):
1456*da0073e9SAndroid Build Coastguard Worker            # The embed dim per head is invalid for the selected fused kernel
1457*da0073e9SAndroid Build Coastguard Worker            dtype = torch.float16
1458*da0073e9SAndroid Build Coastguard Worker            make_tensor = partial(torch.rand, device=device, dtype=dtype)
1459*da0073e9SAndroid Build Coastguard Worker            size = SdpaShape(2, 2, 3, 9) if kernel == SDPBackend.EFFICIENT_ATTENTION else SdpaShape(2, 2, 3, 257)
1460*da0073e9SAndroid Build Coastguard Worker            if TEST_WITH_ROCM:  # On ROCM, FA and EA share the backend GPU kernels
1461*da0073e9SAndroid Build Coastguard Worker                size = SdpaShape(2, 2, 3, 257)
1462*da0073e9SAndroid Build Coastguard Worker            q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
1463*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1464*da0073e9SAndroid Build Coastguard Worker                q, k, v, None, 0.0, False))
1465*da0073e9SAndroid Build Coastguard Worker
1466*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1467*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
1468*da0073e9SAndroid Build Coastguard Worker    @parametrize(
1469*da0073e9SAndroid Build Coastguard Worker        "kernel",
1470*da0073e9SAndroid Build Coastguard Worker        PLATFORM_SPECIFIC_SDPA,
1471*da0073e9SAndroid Build Coastguard Worker    )
1472*da0073e9SAndroid Build Coastguard Worker    def test_invalid_fused_inputs_invalid_dtype(self, device, kernel: SDPBackend):
1473*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[kernel]):
1474*da0073e9SAndroid Build Coastguard Worker            # Invalid dtype for both Flash Attention and Mem Efficient Attention
1475*da0073e9SAndroid Build Coastguard Worker            size = SdpaShape(2, 2, 3, 16)
1476*da0073e9SAndroid Build Coastguard Worker            make_tensor = partial(torch.rand, device=device, dtype=torch.float64)
1477*da0073e9SAndroid Build Coastguard Worker            q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
1478*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1479*da0073e9SAndroid Build Coastguard Worker                q, k, v, None, 0.0, False))
1480*da0073e9SAndroid Build Coastguard Worker
1481*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1482*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention")
1483*da0073e9SAndroid Build Coastguard Worker    @parametrize("kernel", [SDPBackend.FLASH_ATTENTION])
1484*da0073e9SAndroid Build Coastguard Worker    def test_invalid_fused_inputs_attn_mask_present(self, device, kernel: SDPBackend):
1485*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[kernel]):
1486*da0073e9SAndroid Build Coastguard Worker            # Failures for unsupported SDP args
1487*da0073e9SAndroid Build Coastguard Worker            size = SdpaShape(2, 2, 3, 16)
1488*da0073e9SAndroid Build Coastguard Worker            make_tensor = partial(torch.rand, size, device=device, dtype=torch.float16)
1489*da0073e9SAndroid Build Coastguard Worker            q, k, v = make_tensor(), make_tensor(), make_tensor()
1490*da0073e9SAndroid Build Coastguard Worker            # Non-None attention mask
1491*da0073e9SAndroid Build Coastguard Worker            mask = torch.ones((2, 2, 3, 3), device=device, dtype=q.dtype)
1492*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1493*da0073e9SAndroid Build Coastguard Worker                q, k, v, mask, 0.0, False))
1494*da0073e9SAndroid Build Coastguard Worker
1495*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1496*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support fused SDPA or pre-SM80 hardware")
1497*da0073e9SAndroid Build Coastguard Worker    def test_unaligned_tensors(self, device):
1498*da0073e9SAndroid Build Coastguard Worker        # The alignment is dependent on arch so we specify SM80OrLater
1499*da0073e9SAndroid Build Coastguard Worker        dtype = torch.float16
1500*da0073e9SAndroid Build Coastguard Worker        size = SdpaShape(2, 2, 8, 5)
1501*da0073e9SAndroid Build Coastguard Worker        make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
1502*da0073e9SAndroid Build Coastguard Worker        q, k, v = make_tensor(), make_tensor(), make_tensor()
1503*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
1504*da0073e9SAndroid Build Coastguard Worker            ctxmgr = self.assertRaises(RuntimeError) if not TEST_WITH_ROCM else contextlib.nullcontext()
1505*da0073e9SAndroid Build Coastguard Worker            with ctxmgr:
1506*da0073e9SAndroid Build Coastguard Worker                torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False)
1507*da0073e9SAndroid Build Coastguard Worker
1508*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1509*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support fused SDPA or pre-SM80 hardware")
1510*da0073e9SAndroid Build Coastguard Worker    def test_flash_fail_fp32(self, device):
1511*da0073e9SAndroid Build Coastguard Worker        dtype = torch.float
1512*da0073e9SAndroid Build Coastguard Worker        size = SdpaShape(16, 16, 32, 32)
1513*da0073e9SAndroid Build Coastguard Worker        make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
1514*da0073e9SAndroid Build Coastguard Worker        q, k, v = make_tensor(), make_tensor(), make_tensor()
1515*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
1516*da0073e9SAndroid Build Coastguard Worker            with self.assertWarnsRegex(UserWarning, "Expected query, key and value to all be of dtype: {Half, BFloat16}"):
1517*da0073e9SAndroid Build Coastguard Worker                self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1518*da0073e9SAndroid Build Coastguard Worker                    q, k, v, None, 0.0, False))
1519*da0073e9SAndroid Build Coastguard Worker
1520*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1521*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
1522*da0073e9SAndroid Build Coastguard Worker    def test_flash_autocast_fp32_float16(self, device):
1523*da0073e9SAndroid Build Coastguard Worker        dtype = torch.float
1524*da0073e9SAndroid Build Coastguard Worker        size = SdpaShape(16, 16, 32, 32)
1525*da0073e9SAndroid Build Coastguard Worker        make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
1526*da0073e9SAndroid Build Coastguard Worker        q, k, v = make_tensor(), make_tensor(), make_tensor()
1527*da0073e9SAndroid Build Coastguard Worker        with torch.autocast(device_type='cuda', dtype=torch.float16):
1528*da0073e9SAndroid Build Coastguard Worker            with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
1529*da0073e9SAndroid Build Coastguard Worker                _ = torch.nn.functional.scaled_dot_product_attention(
1530*da0073e9SAndroid Build Coastguard Worker                    q, k, v, None, 0.0, False)
1531*da0073e9SAndroid Build Coastguard Worker
1532*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1533*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
1534*da0073e9SAndroid Build Coastguard Worker    def test_flash_autocast_fp32_bfloat16(self, device):
1535*da0073e9SAndroid Build Coastguard Worker        dtype = torch.float
1536*da0073e9SAndroid Build Coastguard Worker        size = SdpaShape(16, 16, 32, 32)
1537*da0073e9SAndroid Build Coastguard Worker        make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
1538*da0073e9SAndroid Build Coastguard Worker        q, k, v = make_tensor(), make_tensor(), make_tensor()
1539*da0073e9SAndroid Build Coastguard Worker        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
1540*da0073e9SAndroid Build Coastguard Worker            with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
1541*da0073e9SAndroid Build Coastguard Worker                _ = torch.nn.functional.scaled_dot_product_attention(
1542*da0073e9SAndroid Build Coastguard Worker                    q, k, v, None, 0.0, False)
1543*da0073e9SAndroid Build Coastguard Worker
1544*da0073e9SAndroid Build Coastguard Worker    # Note: do not truncate the list according to platforms. These tests should always raise errors.
1545*da0073e9SAndroid Build Coastguard Worker    @parametrize("kernel", [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
1546*da0073e9SAndroid Build Coastguard Worker    def test_invalid_inputs_different_datatypes(self, device, kernel: SDPBackend):
1547*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[kernel]):
1548*da0073e9SAndroid Build Coastguard Worker            # Different datatypes
1549*da0073e9SAndroid Build Coastguard Worker            shape = (1, 4, 8, 16)
1550*da0073e9SAndroid Build Coastguard Worker            query = torch.randn(shape, dtype=torch.float32, device=device)
1551*da0073e9SAndroid Build Coastguard Worker            key = torch.randn(shape, dtype=torch.float16, device=device)
1552*da0073e9SAndroid Build Coastguard Worker            value = torch.randn(shape, dtype=torch.float16, device=device)
1553*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(query, key, value))
1554*da0073e9SAndroid Build Coastguard Worker
1555*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1556*da0073e9SAndroid Build Coastguard Worker    @parametrize("kernel", [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
1557*da0073e9SAndroid Build Coastguard Worker    def test_invalid_inputs_different_devices(self, device, kernel: SDPBackend):
1558*da0073e9SAndroid Build Coastguard Worker        # Different devices
1559*da0073e9SAndroid Build Coastguard Worker        shape = (1, 4, 8, 16)
1560*da0073e9SAndroid Build Coastguard Worker        query = torch.randn(shape, dtype=torch.float32, device=device)
1561*da0073e9SAndroid Build Coastguard Worker        key = torch.randn(shape, dtype=torch.float16, device='cpu')
1562*da0073e9SAndroid Build Coastguard Worker        value = torch.randn(shape, dtype=torch.float16, device='cpu')
1563*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(query, key, value))
1564*da0073e9SAndroid Build Coastguard Worker
1565*da0073e9SAndroid Build Coastguard Worker    @parametrize("kernel", [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
1566*da0073e9SAndroid Build Coastguard Worker    def test_invalid_inputs_1_dimensional_inputs(self, device, kernel: SDPBackend):
1567*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[kernel]):
1568*da0073e9SAndroid Build Coastguard Worker            # 1 dimensional input
1569*da0073e9SAndroid Build Coastguard Worker            shape = (1, 4)
1570*da0073e9SAndroid Build Coastguard Worker            query = torch.randn(4, dtype=torch.float16, device=device)
1571*da0073e9SAndroid Build Coastguard Worker            key = torch.randn(shape, dtype=torch.float16, device=device)
1572*da0073e9SAndroid Build Coastguard Worker            value = torch.randn(shape, dtype=torch.float16, device=device)
1573*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(query, key, value))
1574*da0073e9SAndroid Build Coastguard Worker
1575*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1576*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm  # Missing EFFICIENT_ATTENTION
1577*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
1578*da0073e9SAndroid Build Coastguard Worker    def test_fused_kernels_nested_broadcasting_error_cases(self, device):
1579*da0073e9SAndroid Build Coastguard Worker        # one of k, v needs to be broadcast and the other has an inconsistent seq_len dim
1580*da0073e9SAndroid Build Coastguard Worker        rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float32)
1581*da0073e9SAndroid Build Coastguard Worker        batch, num_heads, head_dim = 32, 8, 64
1582*da0073e9SAndroid Build Coastguard Worker        seq_lens_q = torch.randint(low=1, high=32, size=(batch,)).tolist()
1583*da0073e9SAndroid Build Coastguard Worker        seq_lens_v = torch.randint(low=1, high=32, size=(batch,)).tolist()
1584*da0073e9SAndroid Build Coastguard Worker
1585*da0073e9SAndroid Build Coastguard Worker        q_shape = SdpaShape(batch, num_heads, seq_lens_q, head_dim)
1586*da0073e9SAndroid Build Coastguard Worker        k_shape = SdpaShape(1, num_heads, 1, head_dim)
1587*da0073e9SAndroid Build Coastguard Worker        v_shape = SdpaShape(batch, num_heads, seq_lens_v, head_dim)
1588*da0073e9SAndroid Build Coastguard Worker
1589*da0073e9SAndroid Build Coastguard Worker        query = rand_nested_tensor(q_shape).transpose(1, 2)
1590*da0073e9SAndroid Build Coastguard Worker        key = rand_nested_tensor(k_shape).transpose(1, 2)
1591*da0073e9SAndroid Build Coastguard Worker        value = rand_nested_tensor(v_shape).transpose(1, 2)
1592*da0073e9SAndroid Build Coastguard Worker
1593*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
1594*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "No available kernel"):
1595*da0073e9SAndroid Build Coastguard Worker                torch.nn.functional.scaled_dot_product_attention(
1596*da0073e9SAndroid Build Coastguard Worker                    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
1597*da0073e9SAndroid Build Coastguard Worker
1598*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1599*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Fused SDPA was not built for this system")
1600*da0073e9SAndroid Build Coastguard Worker    def test_nested_fails_on_padding_head_dim(self, device):
1601*da0073e9SAndroid Build Coastguard Worker        dtype = torch.bfloat16
1602*da0073e9SAndroid Build Coastguard Worker        seq_len_list = [2, 4, 5, 6, 7]
1603*da0073e9SAndroid Build Coastguard Worker        shape = SdpaShape(5, 8, seq_len_list, 57)
1604*da0073e9SAndroid Build Coastguard Worker        make_tensor = partial(rand_sdpa_tensor, shape=shape, type="nested", device=device, dtype=dtype)
1605*da0073e9SAndroid Build Coastguard Worker        q, k, v = make_tensor(), make_tensor(), make_tensor()
1606*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
1607*da0073e9SAndroid Build Coastguard Worker            with self.assertWarnsRegex(UserWarning, "For NestedTensor inputs, Flash attention requires"):
1608*da0073e9SAndroid Build Coastguard Worker                self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1609*da0073e9SAndroid Build Coastguard Worker                    q, k, v, None, 0.0, False))
1610*da0073e9SAndroid Build Coastguard Worker
1611*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1612*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION or not isLessThanSM80Device,
1613*da0073e9SAndroid Build Coastguard Worker                     "Current platform does not support fused SDPA or is an SM80+ device.")
1614*da0073e9SAndroid Build Coastguard Worker    def test_mem_efficient_fail_bfloat16_less_than_sm80(self, device):
1615*da0073e9SAndroid Build Coastguard Worker        dtype = torch.bfloat16
1616*da0073e9SAndroid Build Coastguard Worker        size = SdpaShape(16, 16, 32, 32)
1617*da0073e9SAndroid Build Coastguard Worker        make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
1618*da0073e9SAndroid Build Coastguard Worker        q, k, v = make_tensor(), make_tensor(), make_tensor()
1619*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
1620*da0073e9SAndroid Build Coastguard Worker            with self.assertWarnsRegex(UserWarning, "Expected query, key and value to all be of dtype: {Half, Float}"):
1621*da0073e9SAndroid Build Coastguard Worker                self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1622*da0073e9SAndroid Build Coastguard Worker                    q, k, v, None, 0.0, False))
1623*da0073e9SAndroid Build Coastguard Worker
1624*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1625*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention")
1626*da0073e9SAndroid Build Coastguard Worker    def test_flash_attention_large_bf16_nan_values(self, device):
1627*da0073e9SAndroid Build Coastguard Worker        query = torch.full((1, 1, 1, 64), 133120.0, dtype=torch.bfloat16, device="cuda")
1628*da0073e9SAndroid Build Coastguard Worker        key = torch.full((1, 1, 1, 64), 133120.0, dtype=torch.bfloat16, device="cuda")
1629*da0073e9SAndroid Build Coastguard Worker        value = torch.full((1, 1, 1, 64), 133120.0, dtype=torch.bfloat16, device="cuda")
1630*da0073e9SAndroid Build Coastguard Worker
1631*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
1632*da0073e9SAndroid Build Coastguard Worker            out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
1633*da0073e9SAndroid Build Coastguard Worker
1634*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch.isnan(out).any(), "Output should not contain NaNs!")
1635*da0073e9SAndroid Build Coastguard Worker
1636*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1637*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
1638*da0073e9SAndroid Build Coastguard Worker    @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
1639*da0073e9SAndroid Build Coastguard Worker                 PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION])
1640*da0073e9SAndroid Build Coastguard Worker    def test_fused_kernels_seq_len_0_inputs(self, device, fused_kernel):
1641*da0073e9SAndroid Build Coastguard Worker        rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float16)
1642*da0073e9SAndroid Build Coastguard Worker        batch, num_heads, head_dim = 32, 16, 64
1643*da0073e9SAndroid Build Coastguard Worker        seq_lens = torch.randint(low=1, high=32, size=(batch,))
1644*da0073e9SAndroid Build Coastguard Worker        # make sure some seq_lens are 0
1645*da0073e9SAndroid Build Coastguard Worker        num_zeros = 10
1646*da0073e9SAndroid Build Coastguard Worker        indices = torch.randint(low=0, high=batch, size=(num_zeros,))
1647*da0073e9SAndroid Build Coastguard Worker        seq_lens.scatter_(0, indices, 0)
1648*da0073e9SAndroid Build Coastguard Worker
1649*da0073e9SAndroid Build Coastguard Worker        shape = SdpaShape(batch, num_heads, seq_lens.tolist(), head_dim)
1650*da0073e9SAndroid Build Coastguard Worker        query = rand_nested_tensor(shape)
1651*da0073e9SAndroid Build Coastguard Worker        key = rand_nested_tensor(shape)
1652*da0073e9SAndroid Build Coastguard Worker        value = rand_nested_tensor(shape)
1653*da0073e9SAndroid Build Coastguard Worker
1654*da0073e9SAndroid Build Coastguard Worker        query = query.transpose(1, 2)
1655*da0073e9SAndroid Build Coastguard Worker        key = key.transpose(1, 2)
1656*da0073e9SAndroid Build Coastguard Worker        value = value.transpose(1, 2)
1657*da0073e9SAndroid Build Coastguard Worker
1658*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[fused_kernel]):
1659*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "No available kernel"):
1660*da0073e9SAndroid Build Coastguard Worker                torch.nn.functional.scaled_dot_product_attention(
1661*da0073e9SAndroid Build Coastguard Worker                    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
1662*da0073e9SAndroid Build Coastguard Worker
1663*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1664*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Fused SDPA was not built for this system")
1665*da0073e9SAndroid Build Coastguard Worker    def test_fused_kernels_nested_broadcasting_requires_grad_failure(self, device):
1666*da0073e9SAndroid Build Coastguard Worker        rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float16, requires_grad=True)
1667*da0073e9SAndroid Build Coastguard Worker        batch, num_heads, head_dim, head_dim_v = 32, 16, 64, 64
1668*da0073e9SAndroid Build Coastguard Worker        seq_lens = torch.randint(low=1, high=32, size=(batch,)).tolist()
1669*da0073e9SAndroid Build Coastguard Worker        q_shape = SdpaShape(1, num_heads, 1, head_dim)
1670*da0073e9SAndroid Build Coastguard Worker        k_shape = SdpaShape(batch, num_heads, seq_lens, head_dim)
1671*da0073e9SAndroid Build Coastguard Worker        v_shape = SdpaShape(batch, 1, seq_lens, head_dim_v)
1672*da0073e9SAndroid Build Coastguard Worker
1673*da0073e9SAndroid Build Coastguard Worker        # create a dense query
1674*da0073e9SAndroid Build Coastguard Worker        query = torch.randn(q_shape, device=device, dtype=torch.float16, requires_grad=True)
1675*da0073e9SAndroid Build Coastguard Worker        key = rand_nested_tensor(k_shape)
1676*da0073e9SAndroid Build Coastguard Worker        value = rand_nested_tensor(v_shape)
1677*da0073e9SAndroid Build Coastguard Worker
1678*da0073e9SAndroid Build Coastguard Worker        query = query.transpose(1, 2)
1679*da0073e9SAndroid Build Coastguard Worker        key = key.transpose(1, 2)
1680*da0073e9SAndroid Build Coastguard Worker        value = value.transpose(1, 2)
1681*da0073e9SAndroid Build Coastguard Worker
1682*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
1683*da0073e9SAndroid Build Coastguard Worker            with self.assertWarnsRegex(UserWarning, "Both fused kernels do not support training with broadcasted NT inputs"):
1684*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, "No available kernel"):
1685*da0073e9SAndroid Build Coastguard Worker                    out = torch.nn.functional.scaled_dot_product_attention(
1686*da0073e9SAndroid Build Coastguard Worker                        query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
1687*da0073e9SAndroid Build Coastguard Worker
1688*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1689*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention")
1690*da0073e9SAndroid Build Coastguard Worker    def test_flash_attention_fail_with_non_square_causal_attention(self, device):
1691*da0073e9SAndroid Build Coastguard Worker        dtype = torch.bfloat16
1692*da0073e9SAndroid Build Coastguard Worker        q_shape = SdpaShape(1, 1, 8, 16)
1693*da0073e9SAndroid Build Coastguard Worker        kv_shape = SdpaShape(1, 1, 12, 16)
1694*da0073e9SAndroid Build Coastguard Worker        make_q = partial(torch.rand, q_shape, device=device, dtype=dtype)
1695*da0073e9SAndroid Build Coastguard Worker        make_kv = partial(torch.rand, kv_shape, device=device, dtype=dtype)
1696*da0073e9SAndroid Build Coastguard Worker        q, k, v = make_q(), make_kv(), make_kv()
1697*da0073e9SAndroid Build Coastguard Worker        warning_str = "Flash attention does not support the is_causal flag when seqlen_q != seqlen_k."
1698*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
1699*da0073e9SAndroid Build Coastguard Worker            with self.assertWarnsRegex(UserWarning, warning_str):
1700*da0073e9SAndroid Build Coastguard Worker                self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1701*da0073e9SAndroid Build Coastguard Worker                    q, k, v, None, 0.0, is_causal=True))
1702*da0073e9SAndroid Build Coastguard Worker
1703*da0073e9SAndroid Build Coastguard Workerdef _get_block_size_n(device, head_dim, is_dropout, is_causal):
1704*da0073e9SAndroid Build Coastguard Worker    # This should match the block sizes in the CUDA kernel
1705*da0073e9SAndroid Build Coastguard Worker    assert head_dim <= 256
1706*da0073e9SAndroid Build Coastguard Worker    major, minor = torch.cuda.get_device_capability(device)
1707*da0073e9SAndroid Build Coastguard Worker    is_sm8x = major == 8 and minor > 0  # Only include sm86 and sm89, exclude sm80 (A100)
1708*da0073e9SAndroid Build Coastguard Worker    is_sm80 = major == 8 and minor == 0
1709*da0073e9SAndroid Build Coastguard Worker    is_sm90 = major == 9 and minor == 0
1710*da0073e9SAndroid Build Coastguard Worker    if head_dim <= 32:
1711*da0073e9SAndroid Build Coastguard Worker        return 128
1712*da0073e9SAndroid Build Coastguard Worker    if head_dim <= 64:
1713*da0073e9SAndroid Build Coastguard Worker        return 128 if not is_dropout else 64
1714*da0073e9SAndroid Build Coastguard Worker    elif head_dim <= 96:
1715*da0073e9SAndroid Build Coastguard Worker        return 64
1716*da0073e9SAndroid Build Coastguard Worker    elif head_dim <= 128:
1717*da0073e9SAndroid Build Coastguard Worker        if is_sm8x:
1718*da0073e9SAndroid Build Coastguard Worker            return 64 if (not is_dropout and is_causal) else 32
1719*da0073e9SAndroid Build Coastguard Worker        else:
1720*da0073e9SAndroid Build Coastguard Worker            return 64 if not is_dropout else 32
1721*da0073e9SAndroid Build Coastguard Worker    elif head_dim <= 160:
1722*da0073e9SAndroid Build Coastguard Worker        if is_sm8x:
1723*da0073e9SAndroid Build Coastguard Worker            return 64
1724*da0073e9SAndroid Build Coastguard Worker        else:
1725*da0073e9SAndroid Build Coastguard Worker            return 32
1726*da0073e9SAndroid Build Coastguard Worker    elif head_dim <= 192:
1727*da0073e9SAndroid Build Coastguard Worker        return 64
1728*da0073e9SAndroid Build Coastguard Worker    elif head_dim <= 224:
1729*da0073e9SAndroid Build Coastguard Worker        return 64
1730*da0073e9SAndroid Build Coastguard Worker    elif head_dim <= 256:
1731*da0073e9SAndroid Build Coastguard Worker        return 64
1732*da0073e9SAndroid Build Coastguard Worker
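# Illustrative readings of the mapping above (assumed inputs, comments only):
# head_dim=96 falls in the <= 96 bucket and returns 64 on every arch, while
# head_dim=64 returns 128 without dropout and 64 with dropout.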
1733*da0073e9SAndroid Build Coastguard Worker
1734*da0073e9SAndroid Build Coastguard Workerdef pad_last_dim(input_tensor, alignment_size, slice: bool = False):
1735*da0073e9SAndroid Build Coastguard Worker    last_dim_size = input_tensor.size(-1)
1736*da0073e9SAndroid Build Coastguard Worker    if (last_dim_size % alignment_size == 0):
1737*da0073e9SAndroid Build Coastguard Worker        return input_tensor, last_dim_size
1738*da0073e9SAndroid Build Coastguard Worker    pad_count = alignment_size - (last_dim_size % alignment_size)
1739*da0073e9SAndroid Build Coastguard Worker    padded_tensor = F.pad(input_tensor, (0, pad_count))
1740*da0073e9SAndroid Build Coastguard Worker    if slice:
1741*da0073e9SAndroid Build Coastguard Worker        return padded_tensor[..., :last_dim_size], last_dim_size
1742*da0073e9SAndroid Build Coastguard Worker    return padded_tensor, last_dim_size
1743*da0073e9SAndroid Build Coastguard Worker
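# A short usage sketch for pad_last_dim (values are arbitrary): a last dimension
# of 5 padded to an alignment of 8 gains 3 zero columns, and slice=True trims
# the returned view back to the original width while keeping the padded storage.
#
#     t = torch.rand(2, 2, 3, 5)
#     padded, orig_width = pad_last_dim(t, 8)              # padded.size(-1) == 8
#     sliced, orig_width = pad_last_dim(t, 8, slice=True)  # sliced.size(-1) == 5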
1744*da0073e9SAndroid Build Coastguard Worker
1745*da0073e9SAndroid Build Coastguard Workerclass TestSDPA(NNTestCase):
1746*da0073e9SAndroid Build Coastguard Worker    """ Used to test generic functionality of scaled_dot_product_attention
1747*da0073e9SAndroid Build Coastguard Worker    Summary:
1748*da0073e9SAndroid Build Coastguard Worker        If you are adding a new test to this class, make sure that it runs
1749*da0073e9SAndroid Build Coastguard Worker        for both cpu and cuda. If your test is only applicable to cuda,
1750*da0073e9SAndroid Build Coastguard Worker        add it to TestSDPACudaOnly.
1751*da0073e9SAndroid Build Coastguard Worker    """
1752*da0073e9SAndroid Build Coastguard Worker    @parametrize("contiguous_inputs", [True, False])
1753*da0073e9SAndroid Build Coastguard Worker    def test_sdp_math_gradcheck(self, device, contiguous_inputs: bool):
1754*da0073e9SAndroid Build Coastguard Worker
1755*da0073e9SAndroid Build Coastguard Worker        batch_size, seq_len, num_heads, head_dim = 4, 4, 2, 16
1756*da0073e9SAndroid Build Coastguard Worker        shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
1757*da0073e9SAndroid Build Coastguard Worker        make_tensor = partial(rand_sdpa_tensor, type="dense", device=device,
1758*da0073e9SAndroid Build Coastguard Worker                              dtype=torch.float64, requires_grad=True, packed=True)
1759*da0073e9SAndroid Build Coastguard Worker
1760*da0073e9SAndroid Build Coastguard Worker        qkv = make_tensor(shape)
1761*da0073e9SAndroid Build Coastguard Worker        query, key, value = qkv.chunk(3, dim=-1)
1762*da0073e9SAndroid Build Coastguard Worker
1763*da0073e9SAndroid Build Coastguard Worker        query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
1764*da0073e9SAndroid Build Coastguard Worker        key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
1765*da0073e9SAndroid Build Coastguard Worker        value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
1766*da0073e9SAndroid Build Coastguard Worker
1767*da0073e9SAndroid Build Coastguard Worker        if contiguous_inputs:
1768*da0073e9SAndroid Build Coastguard Worker            query = query.contiguous()
1769*da0073e9SAndroid Build Coastguard Worker            key = key.contiguous()
1770*da0073e9SAndroid Build Coastguard Worker            value = value.contiguous()
1771*da0073e9SAndroid Build Coastguard Worker
1772*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[SDPBackend.MATH]):
1773*da0073e9SAndroid Build Coastguard Worker            assert gradcheck(lambda *args, **kwargs:
1774*da0073e9SAndroid Build Coastguard Worker                             wrapper_set_seed(torch.nn.functional.scaled_dot_product_attention, *args, **kwargs),
1775*da0073e9SAndroid Build Coastguard Worker                             (query, key, value, None, 0.0, False)
1776*da0073e9SAndroid Build Coastguard Worker                             )
1777*da0073e9SAndroid Build Coastguard Worker
1778*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
1779*da0073e9SAndroid Build Coastguard Worker    @parametrize("type", ["dense", "nested"])
1780*da0073e9SAndroid Build Coastguard Worker    @parametrize("dropout", [0.0, 0.7])
1781*da0073e9SAndroid Build Coastguard Worker    @parametrize("dtype", [torch.float64, torch.float32, torch.bfloat16, torch.half])
1782*da0073e9SAndroid Build Coastguard Worker    def test_fused_sdp_choice_cpu(self, device, type: str, dropout: float, dtype: torch.dtype):
1783*da0073e9SAndroid Build Coastguard Worker        # Test the backend chosen on cpu: MATH for nested tensors, dropout, or unsupported dtypes, FLASH_ATTENTION otherwise
1784*da0073e9SAndroid Build Coastguard Worker        make_tensor = partial(rand_sdpa_tensor, type=type, device=device, dtype=dtype)
1785*da0073e9SAndroid Build Coastguard Worker        size = SdpaShape(2, 8, 128, 64)
1786*da0073e9SAndroid Build Coastguard Worker        q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
1787*da0073e9SAndroid Build Coastguard Worker        if type == "nested" \
1788*da0073e9SAndroid Build Coastguard Worker                or dropout > 0.0 \
1789*da0073e9SAndroid Build Coastguard Worker                or dtype not in [torch.float32, torch.float64, torch.bfloat16, torch.float16]:
1790*da0073e9SAndroid Build Coastguard Worker            assert torch._fused_sdp_choice(q, k, v, dropout_p=dropout) == SDPBackend.MATH.value
1791*da0073e9SAndroid Build Coastguard Worker        else:
1792*da0073e9SAndroid Build Coastguard Worker            assert torch._fused_sdp_choice(q, k, v, dropout_p=dropout) == SDPBackend.FLASH_ATTENTION.value
1793*da0073e9SAndroid Build Coastguard Worker
1794*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
1795*da0073e9SAndroid Build Coastguard Worker    @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION])
1796*da0073e9SAndroid Build Coastguard Worker    @parametrize("dtype", [torch.float64, torch.float32, torch.bfloat16, torch.float16])
1797*da0073e9SAndroid Build Coastguard Worker    @parametrize("batch_size", [2, 12])
1798*da0073e9SAndroid Build Coastguard Worker    @parametrize("seq_len", [267, 1030])
1799*da0073e9SAndroid Build Coastguard Worker    @parametrize("n_head", [1, 3])
1800*da0073e9SAndroid Build Coastguard Worker    @parametrize("head_dim", [8, 16])
1801*da0073e9SAndroid Build Coastguard Worker    @parametrize("causal", [True, False])
1802*da0073e9SAndroid Build Coastguard Worker    @parametrize("train", [True, False])
1803*da0073e9SAndroid Build Coastguard Worker    def test_scaled_dot_product_fused_attention_vs_math_cpu(
1804*da0073e9SAndroid Build Coastguard Worker        self,
1805*da0073e9SAndroid Build Coastguard Worker        device,
1806*da0073e9SAndroid Build Coastguard Worker        fused_kernel,
1807*da0073e9SAndroid Build Coastguard Worker        dtype,
1808*da0073e9SAndroid Build Coastguard Worker        batch_size,
1809*da0073e9SAndroid Build Coastguard Worker        seq_len,
1810*da0073e9SAndroid Build Coastguard Worker        n_head,
1811*da0073e9SAndroid Build Coastguard Worker        head_dim,
1812*da0073e9SAndroid Build Coastguard Worker        causal,
1813*da0073e9SAndroid Build Coastguard Worker        train,
1814*da0073e9SAndroid Build Coastguard Worker    ):
1815*da0073e9SAndroid Build Coastguard Worker        atol = 1e-5
1816*da0073e9SAndroid Build Coastguard Worker        rtol = 5e-6
1817*da0073e9SAndroid Build Coastguard Worker        if dtype is torch.bfloat16:
1818*da0073e9SAndroid Build Coastguard Worker            atol = 5e-2
1819*da0073e9SAndroid Build Coastguard Worker            rtol = 5e-2
1820*da0073e9SAndroid Build Coastguard Worker        if dtype is torch.float16:
1821*da0073e9SAndroid Build Coastguard Worker            atol = 1e-2
1822*da0073e9SAndroid Build Coastguard Worker            rtol = 1e-2
1823*da0073e9SAndroid Build Coastguard Worker
1824*da0073e9SAndroid Build Coastguard Worker        n_embd = n_head * head_dim
1825*da0073e9SAndroid Build Coastguard Worker        make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=dtype, packed=True, requires_grad=False)
1826*da0073e9SAndroid Build Coastguard Worker        shape = SdpaShape(batch_size, n_head, seq_len, head_dim)
1827*da0073e9SAndroid Build Coastguard Worker        x = make_tensor(shape)
1828*da0073e9SAndroid Build Coastguard Worker        x2 = x.clone()
1829*da0073e9SAndroid Build Coastguard Worker
1830*da0073e9SAndroid Build Coastguard Worker        if train:
1831*da0073e9SAndroid Build Coastguard Worker            x.requires_grad_(True)
1832*da0073e9SAndroid Build Coastguard Worker            x2.requires_grad_(True)
1833*da0073e9SAndroid Build Coastguard Worker
1834*da0073e9SAndroid Build Coastguard Worker        q, k, v = x.split(n_embd, dim=2)
1835*da0073e9SAndroid Build Coastguard Worker        q2, k2, v2 = x2.split(n_embd, dim=2)
1836*da0073e9SAndroid Build Coastguard Worker
1837*da0073e9SAndroid Build Coastguard Worker        if dtype in [torch.bfloat16, torch.float16]:
1838*da0073e9SAndroid Build Coastguard Worker            q2 = q2.float()
1839*da0073e9SAndroid Build Coastguard Worker            k2 = k2.float()
1840*da0073e9SAndroid Build Coastguard Worker            v2 = v2.float()
1841*da0073e9SAndroid Build Coastguard Worker
1842*da0073e9SAndroid Build Coastguard Worker        # (B, nh, T, hs)
1843*da0073e9SAndroid Build Coastguard Worker        k = k.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2)
1844*da0073e9SAndroid Build Coastguard Worker        q = q.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2)
1845*da0073e9SAndroid Build Coastguard Worker        v = v.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2)
1846*da0073e9SAndroid Build Coastguard Worker        k2 = k2.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2)
1847*da0073e9SAndroid Build Coastguard Worker        q2 = q2.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2)
1848*da0073e9SAndroid Build Coastguard Worker        v2 = v2.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2)
1849*da0073e9SAndroid Build Coastguard Worker
1850*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[fused_kernel]):
1851*da0073e9SAndroid Build Coastguard Worker            actual = torch.nn.functional.scaled_dot_product_attention(
1852*da0073e9SAndroid Build Coastguard Worker                q, k, v, attn_mask=None, dropout_p=0.0, is_causal=causal)
1853*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[SDPBackend.MATH]):
1854*da0073e9SAndroid Build Coastguard Worker            math_ref = torch.nn.functional.scaled_dot_product_attention(
1855*da0073e9SAndroid Build Coastguard Worker                q2, k2, v2, attn_mask=None, dropout_p=0.0, is_causal=causal)
1856*da0073e9SAndroid Build Coastguard Worker
1857*da0073e9SAndroid Build Coastguard Worker        if dtype in [torch.bfloat16, torch.float16]:
1858*da0073e9SAndroid Build Coastguard Worker            math_ref = math_ref.to(dtype)
1859*da0073e9SAndroid Build Coastguard Worker
1860*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(actual, math_ref, atol=atol, rtol=rtol)
1861*da0073e9SAndroid Build Coastguard Worker
1862*da0073e9SAndroid Build Coastguard Worker        if train:
1863*da0073e9SAndroid Build Coastguard Worker            actual.sum().backward()
1864*da0073e9SAndroid Build Coastguard Worker            math_ref.sum().backward()
1865*da0073e9SAndroid Build Coastguard Worker
1866*da0073e9SAndroid Build Coastguard Worker            grad_x, grad_x2 = x.grad, x2.grad
1867*da0073e9SAndroid Build Coastguard Worker            grad_q_actual, grad_k_actual, grad_v_actual = grad_x.split(n_embd, dim=2)
1868*da0073e9SAndroid Build Coastguard Worker            grad_q_ref, grad_k_ref, grad_v_ref = grad_x2.split(n_embd, dim=2)
1869*da0073e9SAndroid Build Coastguard Worker
1870*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(grad_q_actual, grad_q_ref, atol=atol, rtol=rtol)
1871*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(grad_k_actual, grad_k_ref, atol=atol, rtol=rtol)
1872*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(grad_v_actual, grad_v_ref, atol=atol, rtol=rtol)
1873*da0073e9SAndroid Build Coastguard Worker
1874*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
1875*da0073e9SAndroid Build Coastguard Worker    @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION])
1876*da0073e9SAndroid Build Coastguard Worker    @parametrize("dtype", [torch.float64, torch.float32, torch.bfloat16, torch.float16])
1877*da0073e9SAndroid Build Coastguard Worker    @parametrize("batch_size", [2, 12])
1878*da0073e9SAndroid Build Coastguard Worker    @parametrize("q_seq_len", [267, 1030])
1879*da0073e9SAndroid Build Coastguard Worker    @parametrize("kv_seq_len", [514, 1179])
1880*da0073e9SAndroid Build Coastguard Worker    @parametrize("n_head", [1, 3])
1881*da0073e9SAndroid Build Coastguard Worker    @parametrize("head_dim", [8, 16])
1882*da0073e9SAndroid Build Coastguard Worker    @parametrize("mask_dim", [2, 4])
1883*da0073e9SAndroid Build Coastguard Worker    @parametrize("bool_mask", [0, 1])
1884*da0073e9SAndroid Build Coastguard Worker    @parametrize("train", [True, False])
1885*da0073e9SAndroid Build Coastguard Worker    def test_scaled_dot_product_fused_attention_mask_vs_math_cpu(
1886*da0073e9SAndroid Build Coastguard Worker        self,
1887*da0073e9SAndroid Build Coastguard Worker        device,
1888*da0073e9SAndroid Build Coastguard Worker        fused_kernel,
1889*da0073e9SAndroid Build Coastguard Worker        dtype,
1890*da0073e9SAndroid Build Coastguard Worker        batch_size,
1891*da0073e9SAndroid Build Coastguard Worker        q_seq_len,
1892*da0073e9SAndroid Build Coastguard Worker        kv_seq_len,
1893*da0073e9SAndroid Build Coastguard Worker        n_head,
1894*da0073e9SAndroid Build Coastguard Worker        head_dim,
1895*da0073e9SAndroid Build Coastguard Worker        mask_dim,
1896*da0073e9SAndroid Build Coastguard Worker        bool_mask,
1897*da0073e9SAndroid Build Coastguard Worker        train,
1898*da0073e9SAndroid Build Coastguard Worker    ):
1899*da0073e9SAndroid Build Coastguard Worker        tol = Tolerances(1e-5, 5e-6)
1900*da0073e9SAndroid Build Coastguard Worker        if dtype is torch.bfloat16:
1901*da0073e9SAndroid Build Coastguard Worker            tol = Tolerances(5e-2, 5e-2)
1902*da0073e9SAndroid Build Coastguard Worker        if dtype is torch.float16:
1903*da0073e9SAndroid Build Coastguard Worker            tol = Tolerances(1e-2, 1e-2)
1904*da0073e9SAndroid Build Coastguard Worker
1905*da0073e9SAndroid Build Coastguard Worker        make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=dtype, requires_grad=False)
1906*da0073e9SAndroid Build Coastguard Worker        q_shape = SdpaShape(batch_size, n_head, q_seq_len, head_dim)
1907*da0073e9SAndroid Build Coastguard Worker        kv_shape = SdpaShape(batch_size, n_head, kv_seq_len, head_dim)
1908*da0073e9SAndroid Build Coastguard Worker        q = make_tensor(q_shape)
1909*da0073e9SAndroid Build Coastguard Worker        k = make_tensor(kv_shape)
1910*da0073e9SAndroid Build Coastguard Worker        v = make_tensor(kv_shape)
1911*da0073e9SAndroid Build Coastguard Worker        q2, k2, v2 = q.clone(), k.clone(), v.clone()
1912*da0073e9SAndroid Build Coastguard Worker
1913*da0073e9SAndroid Build Coastguard Worker        if train:
1914*da0073e9SAndroid Build Coastguard Worker            q.requires_grad_(True)
1915*da0073e9SAndroid Build Coastguard Worker            k.requires_grad_(True)
1916*da0073e9SAndroid Build Coastguard Worker            v.requires_grad_(True)
1917*da0073e9SAndroid Build Coastguard Worker            q2.requires_grad_(True)
1918*da0073e9SAndroid Build Coastguard Worker            k2.requires_grad_(True)
1919*da0073e9SAndroid Build Coastguard Worker            v2.requires_grad_(True)
1920*da0073e9SAndroid Build Coastguard Worker
1921*da0073e9SAndroid Build Coastguard Worker        if dtype in [torch.bfloat16, torch.float16]:
1922*da0073e9SAndroid Build Coastguard Worker            q2, k2, v2 = q2.float(), k2.float(), v2.float()
1923*da0073e9SAndroid Build Coastguard Worker        # (B, nh, T, hs)
1924*da0073e9SAndroid Build Coastguard Worker        q = q.view(batch_size, q_seq_len, n_head, head_dim).transpose(1, 2)
1925*da0073e9SAndroid Build Coastguard Worker        k = k.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
1926*da0073e9SAndroid Build Coastguard Worker        v = v.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
1927*da0073e9SAndroid Build Coastguard Worker        if mask_dim == 4:
1928*da0073e9SAndroid Build Coastguard Worker            mask_shape = (batch_size, n_head, q_seq_len, kv_seq_len)
1929*da0073e9SAndroid Build Coastguard Worker        else:
1930*da0073e9SAndroid Build Coastguard Worker            mask_shape = (q_seq_len, kv_seq_len)
1931*da0073e9SAndroid Build Coastguard Worker        if bool_mask:
1932*da0073e9SAndroid Build Coastguard Worker            attn_mask = torch.randint(0, 2, size=mask_shape, dtype=torch.bool, device=device)
1933*da0073e9SAndroid Build Coastguard Worker        else:
1934*da0073e9SAndroid Build Coastguard Worker            attn_mask = torch.randn(mask_shape, dtype=dtype, device=device)
1935*da0073e9SAndroid Build Coastguard Worker        q2 = q2.view(batch_size, q_seq_len, n_head, head_dim).transpose(1, 2)
1936*da0073e9SAndroid Build Coastguard Worker        k2 = k2.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
1937*da0073e9SAndroid Build Coastguard Worker        v2 = v2.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
1938*da0073e9SAndroid Build Coastguard Worker
1939*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[fused_kernel]):
1940*da0073e9SAndroid Build Coastguard Worker            actual = torch.nn.functional.scaled_dot_product_attention(
1941*da0073e9SAndroid Build Coastguard Worker                q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
1942*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[SDPBackend.MATH]):
1943*da0073e9SAndroid Build Coastguard Worker            if not bool_mask and dtype in [torch.bfloat16, torch.float16]:
1944*da0073e9SAndroid Build Coastguard Worker                attn_mask = attn_mask.float()
1945*da0073e9SAndroid Build Coastguard Worker            math_ref = torch.nn.functional.scaled_dot_product_attention(
1946*da0073e9SAndroid Build Coastguard Worker                q2, k2, v2, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
1947*da0073e9SAndroid Build Coastguard Worker
1948*da0073e9SAndroid Build Coastguard Worker        if dtype in [torch.bfloat16, torch.float16]:
1949*da0073e9SAndroid Build Coastguard Worker            math_ref = math_ref.to(dtype)
1950*da0073e9SAndroid Build Coastguard Worker
1951*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(actual, math_ref, atol=tol.atol, rtol=tol.rtol)
1952*da0073e9SAndroid Build Coastguard Worker
1953*da0073e9SAndroid Build Coastguard Worker        if train:
1954*da0073e9SAndroid Build Coastguard Worker            actual.sum().backward()
1955*da0073e9SAndroid Build Coastguard Worker            math_ref.sum().backward()
1956*da0073e9SAndroid Build Coastguard Worker
1957*da0073e9SAndroid Build Coastguard Worker            grad_q_actual, grad_k_actual, grad_v_actual = q.grad, k.grad, v.grad
1958*da0073e9SAndroid Build Coastguard Worker            grad_q_ref, grad_k_ref, grad_v_ref = q2.grad, k2.grad, v2.grad
1959*da0073e9SAndroid Build Coastguard Worker
1960*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(grad_q_actual, grad_q_ref, atol=tol.atol, rtol=tol.rtol)
1961*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(grad_k_actual, grad_k_ref, atol=tol.atol, rtol=tol.rtol)
1962*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(grad_v_actual, grad_v_ref, atol=tol.atol, rtol=tol.rtol)
1963*da0073e9SAndroid Build Coastguard Worker
1964*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
1965*da0073e9SAndroid Build Coastguard Worker    def test_scaled_dot_product_fused_attention_with_inf(self, device):
1966*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/127055.
1967*da0073e9SAndroid Build Coastguard Worker        full = torch.full((600, 600), float("-inf"), device=device)
1968*da0073e9SAndroid Build Coastguard Worker        mask = torch.triu(full, diagonal=1) + torch.tril(full, diagonal=-10)
1969*da0073e9SAndroid Build Coastguard Worker        make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=torch.float32, requires_grad=False)
1970*da0073e9SAndroid Build Coastguard Worker        input_shape = SdpaShape(1, 600, 2, 8)
1971*da0073e9SAndroid Build Coastguard Worker        q = make_tensor(input_shape)
1972*da0073e9SAndroid Build Coastguard Worker        k = make_tensor(input_shape)
1973*da0073e9SAndroid Build Coastguard Worker        v = make_tensor(input_shape)
1974*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[SDPBackend.MATH]):
1975*da0073e9SAndroid Build Coastguard Worker            math_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
1976*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
1977*da0073e9SAndroid Build Coastguard Worker            actual = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
1978*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(math_ref, actual)
1979*da0073e9SAndroid Build Coastguard Worker
1980*da0073e9SAndroid Build Coastguard Worker    @parametrize("kernel", [SDPBackend.MATH])
1981*da0073e9SAndroid Build Coastguard Worker    def test_scaled_dot_product_attention_math_with_negative_scale(self, device, kernel: SDPBackend):
1982*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/105190.
1983*da0073e9SAndroid Build Coastguard Worker        def ref(x):
1984*da0073e9SAndroid Build Coastguard Worker            v1 = torch.matmul(x, x.transpose(-1, -2))
1985*da0073e9SAndroid Build Coastguard Worker            v2 = v1 / -0.0001
1986*da0073e9SAndroid Build Coastguard Worker            v3 = v2.softmax(dim=-1)
1987*da0073e9SAndroid Build Coastguard Worker            v4 = torch.matmul(v3, x)
1988*da0073e9SAndroid Build Coastguard Worker            return v4
1989*da0073e9SAndroid Build Coastguard Worker
1990*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 3, 64, 64, device=device)
1991*da0073e9SAndroid Build Coastguard Worker        ref_result = ref(x)
1992*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[kernel]):
1993*da0073e9SAndroid Build Coastguard Worker            sdp_math = torch.nn.functional.scaled_dot_product_attention(x, x, x, scale=-1.0 / 0.0001)
1994*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref_result, sdp_math)
1995*da0073e9SAndroid Build Coastguard Worker
1996*da0073e9SAndroid Build Coastguard Workerclass TestSDPACudaOnly(NNTestCase):
1997*da0073e9SAndroid Build Coastguard Worker    """ Used to test CUDA-only functionality of scaled_dot_product_attention.
1998*da0073e9SAndroid Build Coastguard Worker    Quirks:
1999*da0073e9SAndroid Build Coastguard Worker        There is some trickiness with this function: its runtime behavior
2000*da0073e9SAndroid Build Coastguard Worker        depends on the CUDA architecture you are testing it on. See
2001*da0073e9SAndroid Build Coastguard Worker        `PLATFORM_SUPPORTS_FUSED_ATTENTION` at the top of the file.
2002*da0073e9SAndroid Build Coastguard Worker        Summary:
2003*da0073e9SAndroid Build Coastguard Worker            Math: always supported
2004*da0073e9SAndroid Build Coastguard Worker            FlashAttention: Supported on sm80 or newer hardware
2005*da0073e9SAndroid Build Coastguard Worker            MemEfficientAttention: Supported on sm50 or newer hardware
2006*da0073e9SAndroid Build Coastguard Worker    """
2007*da0073e9SAndroid Build Coastguard Worker    _do_cuda_memory_leak_check = True
2008*da0073e9SAndroid Build Coastguard Worker    _do_cuda_non_default_stream = True
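
    # Illustrative sketch only (not part of the original tests, and never called): the
    # support matrix in the class docstring above, expressed as a compute-capability
    # check. The (8, 0) / (5, 0) thresholds and the dict keys are assumptions that
    # mirror the docstring, not an authoritative capability query.
    @staticmethod
    def _sketch_backend_support_by_capability():
        major, minor = torch.cuda.get_device_capability()
        return {
            "math": True,                                          # always supported
            "flash_attention": (major, minor) >= (8, 0),           # sm80 or newer
            "mem_efficient_attention": (major, minor) >= (5, 0),   # sm50 or newer
        }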
2009*da0073e9SAndroid Build Coastguard Worker
2010*da0073e9SAndroid Build Coastguard Worker    # TODO: used for testing the scores (e.g. testing ALiBi); we don't need this right now
2011*da0073e9SAndroid Build Coastguard Worker    def normalize_flash_attn_S(
2012*da0073e9SAndroid Build Coastguard Worker        self,
2013*da0073e9SAndroid Build Coastguard Worker        attn_unnorm,
2014*da0073e9SAndroid Build Coastguard Worker        q,
2015*da0073e9SAndroid Build Coastguard Worker        k,
2016*da0073e9SAndroid Build Coastguard Worker        v,
2017*da0073e9SAndroid Build Coastguard Worker        query_padding_mask=None,
2018*da0073e9SAndroid Build Coastguard Worker        key_padding_mask=None,
2019*da0073e9SAndroid Build Coastguard Worker        attn_bias=None,
2020*da0073e9SAndroid Build Coastguard Worker        is_dropout=False,
2021*da0073e9SAndroid Build Coastguard Worker        causal=False,
2022*da0073e9SAndroid Build Coastguard Worker        window_size=(-1, -1),  # -1 means infinite window size
2023*da0073e9SAndroid Build Coastguard Worker        scale=None,
2024*da0073e9SAndroid Build Coastguard Worker    ):
2025*da0073e9SAndroid Build Coastguard Worker        """
2026*da0073e9SAndroid Build Coastguard Worker        Arguments:
2027*da0073e9SAndroid Build Coastguard Worker            q: (batch_size, nheads, seqlen_q, head_dim)
2028*da0073e9SAndroid Build Coastguard Worker            k, v: (batch_size, nheads, seqlen_k, head_dim)
2029*da0073e9SAndroid Build Coastguard Worker            key_padding_mask: (batch_size, seqlen_k)
2030*da0073e9SAndroid Build Coastguard Worker            attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
2031*da0073e9SAndroid Build Coastguard Worker        Output:
2032*da0073e9SAndroid Build Coastguard Worker            attn_norm: (batch_size, nheads, seqlen_q, seqlen_k), the renormalized
2033*da0073e9SAndroid Build Coastguard Worker                attention probabilities, in the same dtype as `attn_unnorm`
2034*da0073e9SAndroid Build Coastguard Worker        """
2035*da0073e9SAndroid Build Coastguard Worker        q = q.transpose(1, 2)
2036*da0073e9SAndroid Build Coastguard Worker        k = k.transpose(1, 2)
2037*da0073e9SAndroid Build Coastguard Worker        v = v.transpose(1, 2)
2038*da0073e9SAndroid Build Coastguard Worker        if causal:
2039*da0073e9SAndroid Build Coastguard Worker            window_size = (window_size[0], 0)
2040*da0073e9SAndroid Build Coastguard Worker        q, k, v = q.float(), k.float(), v.float()
2041*da0073e9SAndroid Build Coastguard Worker        _, seqlen_q, _, head_dim = q.shape
2042*da0073e9SAndroid Build Coastguard Worker        seqlen_k = k.shape[1]
2043*da0073e9SAndroid Build Coastguard Worker        b = q.shape[0]
2044*da0073e9SAndroid Build Coastguard Worker        from torch.nn.attention.bias import _calculate_scale
2045*da0073e9SAndroid Build Coastguard Worker        scale = _calculate_scale(head_dim, scale)
2046*da0073e9SAndroid Build Coastguard Worker        scores = torch.matmul(q.transpose(1, 2) * scale, k.permute(0, 2, 3, 1))
2047*da0073e9SAndroid Build Coastguard Worker        if key_padding_mask is not None:
2048*da0073e9SAndroid Build Coastguard Worker            scores.masked_fill_(~key_padding_mask.view(b, 1, 1, -1), float("-inf"))
2049*da0073e9SAndroid Build Coastguard Worker        if window_size[0] >= 0 or window_size[1] >= 0:
2050*da0073e9SAndroid Build Coastguard Worker            local_mask = self.construct_local_mask(
2051*da0073e9SAndroid Build Coastguard Worker                seqlen_q,
2052*da0073e9SAndroid Build Coastguard Worker                seqlen_k,
2053*da0073e9SAndroid Build Coastguard Worker                window_size,
2054*da0073e9SAndroid Build Coastguard Worker                query_padding_mask,
2055*da0073e9SAndroid Build Coastguard Worker                key_padding_mask,
2056*da0073e9SAndroid Build Coastguard Worker                q.device,
2057*da0073e9SAndroid Build Coastguard Worker            )
2058*da0073e9SAndroid Build Coastguard Worker            scores.masked_fill_(local_mask, float("-inf"))
2059*da0073e9SAndroid Build Coastguard Worker        if attn_bias is not None:
2060*da0073e9SAndroid Build Coastguard Worker            scores = scores + attn_bias.to(dtype=scores.dtype)
2061*da0073e9SAndroid Build Coastguard Worker        block_size_n = _get_block_size_n(scores.device, head_dim, is_dropout, causal)
2062*da0073e9SAndroid Build Coastguard Worker        scores_block = scores.split(block_size_n, dim=-1)
2063*da0073e9SAndroid Build Coastguard Worker        lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1)
2064*da0073e9SAndroid Build Coastguard Worker        lse = torch.logsumexp(lse_block, dim=-1)
2065*da0073e9SAndroid Build Coastguard Worker        # lse could be -inf (i.e. all values in scores are -inf), and we want to set those to inf
2066*da0073e9SAndroid Build Coastguard Worker        # so that when we do torch.exp(m - lse), we get 0.0 instead of NaN.
2067*da0073e9SAndroid Build Coastguard Worker        lse[lse == float("-inf")] = float("inf")
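        # e.g. for a fully-masked row, m == -inf and lse == -inf before the fix above;
        # setting lse to +inf makes torch.exp(m - lse) == exp(-inf) == 0.0 instead of
        # exp(-inf - (-inf)) == exp(nan) == nan.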
2068*da0073e9SAndroid Build Coastguard Worker        scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1)
2069*da0073e9SAndroid Build Coastguard Worker        cummax_block = torch.cummax(scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1)
2070*da0073e9SAndroid Build Coastguard Worker        attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1)
2071*da0073e9SAndroid Build Coastguard Worker        attn_norm = torch.cat(
2072*da0073e9SAndroid Build Coastguard Worker            [
2073*da0073e9SAndroid Build Coastguard Worker                a * (torch.exp(m - lse)).unsqueeze(-1)
2074*da0073e9SAndroid Build Coastguard Worker                for a, m in zip(attn_unnorm_block, cummax_block)
2075*da0073e9SAndroid Build Coastguard Worker            ],
2076*da0073e9SAndroid Build Coastguard Worker            dim=-1,
2077*da0073e9SAndroid Build Coastguard Worker        )
2078*da0073e9SAndroid Build Coastguard Worker        if query_padding_mask is not None:
2079*da0073e9SAndroid Build Coastguard Worker            attn_norm.masked_fill_(~query_padding_mask.view(b, 1, -1, 1), 0.0)
2080*da0073e9SAndroid Build Coastguard Worker            # attn_norm.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
2081*da0073e9SAndroid Build Coastguard Worker        return attn_norm.to(dtype=attn_unnorm.dtype)
2082*da0073e9SAndroid Build Coastguard Worker
2083*da0073e9SAndroid Build Coastguard Worker    def construct_local_mask(self, seqlen_q, seqlen_k, window_size, query_padding_mask, key_padding_mask, device):
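        # Illustrative note on the window semantics: with no padding and
        # seqlen_q == seqlen_k, window_size=(2, 0) keeps key j visible to query i only
        # when i - 2 <= j <= i, i.e. each query attends to itself and its two preceding
        # keys. Callers skip this mask entirely when both window entries are -1.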
2084*da0073e9SAndroid Build Coastguard Worker        # row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
2085*da0073e9SAndroid Build Coastguard Worker        row_idx = torch.arange(seqlen_q, device=device, dtype=torch.long).view(-1, 1)
2086*da0073e9SAndroid Build Coastguard Worker        col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
2087*da0073e9SAndroid Build Coastguard Worker        sk = (
2088*da0073e9SAndroid Build Coastguard Worker            seqlen_k
2089*da0073e9SAndroid Build Coastguard Worker            if key_padding_mask is None
2090*da0073e9SAndroid Build Coastguard Worker            else key_padding_mask.sum(-1).view(-1, 1, 1, 1)
2091*da0073e9SAndroid Build Coastguard Worker            # else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
2092*da0073e9SAndroid Build Coastguard Worker        )
2093*da0073e9SAndroid Build Coastguard Worker        sq = (
2094*da0073e9SAndroid Build Coastguard Worker            seqlen_q
2095*da0073e9SAndroid Build Coastguard Worker            if query_padding_mask is None
2096*da0073e9SAndroid Build Coastguard Worker            else query_padding_mask.sum(-1).view(-1, 1, 1, 1)
2097*da0073e9SAndroid Build Coastguard Worker            # else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
2098*da0073e9SAndroid Build Coastguard Worker        )
2099*da0073e9SAndroid Build Coastguard Worker        if window_size[0] < 0:
2100*da0073e9SAndroid Build Coastguard Worker            return col_idx > row_idx + sk - sq + window_size[1]
2101*da0073e9SAndroid Build Coastguard Worker        else:
2102*da0073e9SAndroid Build Coastguard Worker            sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
2103*da0073e9SAndroid Build Coastguard Worker            return torch.logical_or(
2104*da0073e9SAndroid Build Coastguard Worker                col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
2105*da0073e9SAndroid Build Coastguard Worker                col_idx < row_idx + sk - sq - window_size[0],
2106*da0073e9SAndroid Build Coastguard Worker            )
2107*da0073e9SAndroid Build Coastguard Worker
2108*da0073e9SAndroid Build Coastguard Worker    def convert_flash_attn_S_to_softmax(
2109*da0073e9SAndroid Build Coastguard Worker        self,
2110*da0073e9SAndroid Build Coastguard Worker        S,
2111*da0073e9SAndroid Build Coastguard Worker        seqlen_q,
2112*da0073e9SAndroid Build Coastguard Worker        seqlen_k,
2113*da0073e9SAndroid Build Coastguard Worker        query_padding_mask,
2114*da0073e9SAndroid Build Coastguard Worker        key_padding_mask,
2115*da0073e9SAndroid Build Coastguard Worker        causal=False,
2116*da0073e9SAndroid Build Coastguard Worker        window_size=(-1, -1),  # -1 means infinite window size
2117*da0073e9SAndroid Build Coastguard Worker    ):
2118*da0073e9SAndroid Build Coastguard Worker        """FlashAttention stores the S matrix in a different way.
2119*da0073e9SAndroid Build Coastguard Worker        Arguments:
2120*da0073e9SAndroid Build Coastguard Worker            S: (batch_size, nheads, seqlen_q, seqlen_k)
2121*da0073e9SAndroid Build Coastguard Worker            query_padding_mask: (batch_size, seqlen_q)
2122*da0073e9SAndroid Build Coastguard Worker            key_padding_mask: (batch_size, seqlen_k)
2123*da0073e9SAndroid Build Coastguard Worker        """
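        # The S buffer arrives padded to "rounded" sequence lengths (read from S.shape
        # below); the padded rows/columns are zeroed and cropped before returning, so
        # the result has shape (batch_size, nheads, seqlen_q, seqlen_k).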
2124*da0073e9SAndroid Build Coastguard Worker        if TEST_WITH_ROCM:
2125*da0073e9SAndroid Build Coastguard Worker            return S
2126*da0073e9SAndroid Build Coastguard Worker        b = S.shape[0]
2127*da0073e9SAndroid Build Coastguard Worker
2128*da0073e9SAndroid Build Coastguard Worker        if causal:
2129*da0073e9SAndroid Build Coastguard Worker            window_size = (window_size[0], 0)
2130*da0073e9SAndroid Build Coastguard Worker        seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:]
2131*da0073e9SAndroid Build Coastguard Worker        S_converted = S
2132*da0073e9SAndroid Build Coastguard Worker        if window_size[0] >= 0 or window_size[1] >= 0:
2133*da0073e9SAndroid Build Coastguard Worker            local_mask = self.construct_local_mask(
2134*da0073e9SAndroid Build Coastguard Worker                seqlen_q,
2135*da0073e9SAndroid Build Coastguard Worker                seqlen_k,
2136*da0073e9SAndroid Build Coastguard Worker                window_size,
2137*da0073e9SAndroid Build Coastguard Worker                query_padding_mask,
2138*da0073e9SAndroid Build Coastguard Worker                key_padding_mask,
2139*da0073e9SAndroid Build Coastguard Worker                S.device,
2140*da0073e9SAndroid Build Coastguard Worker            )
2141*da0073e9SAndroid Build Coastguard Worker            local_mask = F.pad(
2142*da0073e9SAndroid Build Coastguard Worker                local_mask,
2143*da0073e9SAndroid Build Coastguard Worker                (0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q),
2144*da0073e9SAndroid Build Coastguard Worker                value=True,
2145*da0073e9SAndroid Build Coastguard Worker            )
2146*da0073e9SAndroid Build Coastguard Worker            S_converted = S_converted.masked_fill(local_mask, 0.0)
2147*da0073e9SAndroid Build Coastguard Worker
2148*da0073e9SAndroid Build Coastguard Worker        # Need to zero out things not in attention_mask in case S was initialized with random values
2149*da0073e9SAndroid Build Coastguard Worker        # and some of those values aren't overwritten.
2150*da0073e9SAndroid Build Coastguard Worker        seqlen_q_og = (
2151*da0073e9SAndroid Build Coastguard Worker            query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded
2152*da0073e9SAndroid Build Coastguard Worker        )
2153*da0073e9SAndroid Build Coastguard Worker        if query_padding_mask is not None:
2154*da0073e9SAndroid Build Coastguard Worker            query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q_rounded - seqlen_q_og))
2155*da0073e9SAndroid Build Coastguard Worker            # S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
2156*da0073e9SAndroid Build Coastguard Worker            S_converted = S_converted.masked_fill(~query_padding_mask.view(b, 1, -1, 1), 0.0)
2157*da0073e9SAndroid Build Coastguard Worker        seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k
2158*da0073e9SAndroid Build Coastguard Worker        if key_padding_mask is not None:
2159*da0073e9SAndroid Build Coastguard Worker            key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k_rounded - seqlen_k_og))
2160*da0073e9SAndroid Build Coastguard Worker            S_converted = S_converted.masked_fill(~key_padding_mask.view(b, 1, 1, -1), 0.0)
2161*da0073e9SAndroid Build Coastguard Worker            # S_converted = S_converted.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0)
2162*da0073e9SAndroid Build Coastguard Worker        S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q_rounded))
2163*da0073e9SAndroid Build Coastguard Worker        S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded))
2164*da0073e9SAndroid Build Coastguard Worker        return S_converted[:, :, :seqlen_q, :seqlen_k]
2165*da0073e9SAndroid Build Coastguard Worker
2166*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm  # No cuDNN Attention
2167*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
2168*da0073e9SAndroid Build Coastguard Worker    def test_cudnn_attention_different_dk_dv(self, device):
2169*da0073e9SAndroid Build Coastguard Worker        dtype = torch.bfloat16
2170*da0073e9SAndroid Build Coastguard Worker        make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True)
2171*da0073e9SAndroid Build Coastguard Worker        batch, num_heads, head_dim_k, head_dim_v = 32, 16, 128, 64
2172*da0073e9SAndroid Build Coastguard Worker        seq_len = 640
2173*da0073e9SAndroid Build Coastguard Worker        q_shape = SdpaShape(batch, num_heads, seq_len, head_dim_k)
2174*da0073e9SAndroid Build Coastguard Worker        k_shape = SdpaShape(batch, num_heads, seq_len, head_dim_k)
2175*da0073e9SAndroid Build Coastguard Worker        v_shape = SdpaShape(batch, num_heads, seq_len, head_dim_v)
2176*da0073e9SAndroid Build Coastguard Worker        query, key, value = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape)
2177*da0073e9SAndroid Build Coastguard Worker
2178*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]):
2179*da0073e9SAndroid Build Coastguard Worker            actual = torch.nn.functional.scaled_dot_product_attention(
2180*da0073e9SAndroid Build Coastguard Worker                query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
2181*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[SDPBackend.MATH]):
2182*da0073e9SAndroid Build Coastguard Worker            math_ref = torch.nn.functional.scaled_dot_product_attention(
2183*da0073e9SAndroid Build Coastguard Worker                query.contiguous().to(torch.float32),
2184*da0073e9SAndroid Build Coastguard Worker                key.contiguous().to(torch.float32),
2185*da0073e9SAndroid Build Coastguard Worker                value.contiguous().to(torch.float32),
2186*da0073e9SAndroid Build Coastguard Worker                attn_mask=None, dropout_p=0.0, is_causal=False)
2187*da0073e9SAndroid Build Coastguard Worker
2188*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)
2189*da0073e9SAndroid Build Coastguard Worker
2190*da0073e9SAndroid Build Coastguard Worker
2191*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
2192*da0073e9SAndroid Build Coastguard Worker    @parametrize("mask_dim", [1, 2, 3, 4])
2193*da0073e9SAndroid Build Coastguard Worker    def test_mem_efficient_attention_mask_variants(self, device, mask_dim: int):
2194*da0073e9SAndroid Build Coastguard Worker        dtype = torch.float16
2195*da0073e9SAndroid Build Coastguard Worker        make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True)
2196*da0073e9SAndroid Build Coastguard Worker        batch, num_heads, head_dim = 8, 8, 64
2197*da0073e9SAndroid Build Coastguard Worker        seq_len_q, seq_len_kv = 64, 32
2198*da0073e9SAndroid Build Coastguard Worker        query = make_tensor(SdpaShape(batch, num_heads, seq_len_q, head_dim))
2199*da0073e9SAndroid Build Coastguard Worker        kv_shape = SdpaShape(batch, num_heads, seq_len_kv, head_dim)
2200*da0073e9SAndroid Build Coastguard Worker        key, value = make_tensor(kv_shape), make_tensor(kv_shape)
2201*da0073e9SAndroid Build Coastguard Worker
2202*da0073e9SAndroid Build Coastguard Worker        if mask_dim == 1:
2203*da0073e9SAndroid Build Coastguard Worker            mask = torch.randn((seq_len_kv,), device=device, dtype=dtype)
2204*da0073e9SAndroid Build Coastguard Worker        elif mask_dim == 2:
2205*da0073e9SAndroid Build Coastguard Worker            mask = torch.randn((seq_len_q, seq_len_kv), device=device, dtype=dtype)
2206*da0073e9SAndroid Build Coastguard Worker        elif mask_dim == 3:
2207*da0073e9SAndroid Build Coastguard Worker            mask = torch.randn((num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype)
2208*da0073e9SAndroid Build Coastguard Worker        elif mask_dim == 4:
2209*da0073e9SAndroid Build Coastguard Worker            mask = torch.randn((batch, num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype)
2210*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
2211*da0073e9SAndroid Build Coastguard Worker            out = F.scaled_dot_product_attention(query, key, value, mask)
2212*da0073e9SAndroid Build Coastguard Worker        out.sum().backward()
2213*da0073e9SAndroid Build Coastguard Worker
2214*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
2215*da0073e9SAndroid Build Coastguard Worker    @parametrize("dtype", [torch.float, torch.float16])
2216*da0073e9SAndroid Build Coastguard Worker    def test_mem_eff_attention_pad_mask(self, device, dtype):
2217*da0073e9SAndroid Build Coastguard Worker        make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True)
2218*da0073e9SAndroid Build Coastguard Worker        batch, num_heads, head_dim = 8, 8, 64
2219*da0073e9SAndroid Build Coastguard Worker        seq_len_q, seq_len_kv = 64, 15
2220*da0073e9SAndroid Build Coastguard Worker        query = make_tensor(SdpaShape(batch, num_heads, seq_len_q, head_dim))
2221*da0073e9SAndroid Build Coastguard Worker        kv_shape = SdpaShape(batch, num_heads, seq_len_kv, head_dim)
2222*da0073e9SAndroid Build Coastguard Worker        key, value = make_tensor(kv_shape), make_tensor(kv_shape)
2223*da0073e9SAndroid Build Coastguard Worker        mask = torch.randn((batch, num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype)
2224*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
2225*da0073e9SAndroid Build Coastguard Worker            out = F.scaled_dot_product_attention(query, key, value, mask)
2226*da0073e9SAndroid Build Coastguard Worker        out.sum().backward()
2227*da0073e9SAndroid Build Coastguard Worker
2228*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
2229*da0073e9SAndroid Build Coastguard Worker    @parametrize("dtype", [torch.float, torch.float16])
2230*da0073e9SAndroid Build Coastguard Worker    def test_mem_eff_attention_non_contiguous_mask(self, device, dtype):
2231*da0073e9SAndroid Build Coastguard Worker        make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True)
2232*da0073e9SAndroid Build Coastguard Worker        batch, num_heads, head_dim = 8, 8, 64
2233*da0073e9SAndroid Build Coastguard Worker        seq_len_q, seq_len_kv = 64, 16
2234*da0073e9SAndroid Build Coastguard Worker        query = make_tensor(SdpaShape(batch, num_heads, seq_len_q, head_dim))
2235*da0073e9SAndroid Build Coastguard Worker        kv_shape = SdpaShape(batch, num_heads, seq_len_kv, head_dim)
2236*da0073e9SAndroid Build Coastguard Worker        key, value = make_tensor(kv_shape), make_tensor(kv_shape)
2237*da0073e9SAndroid Build Coastguard Worker        mask = torch.randn((batch, num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype)
2238*da0073e9SAndroid Build Coastguard Worker        mask = torch.as_strided(mask, (batch, num_heads, seq_len_q, seq_len_kv), (0, 0, 0, 1))
2239*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
2240*da0073e9SAndroid Build Coastguard Worker            out = F.scaled_dot_product_attention(query, key, value, mask)
2241*da0073e9SAndroid Build Coastguard Worker        out.sum().backward()
2242*da0073e9SAndroid Build Coastguard Worker
2243*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
2244*da0073e9SAndroid Build Coastguard Worker    @parametrize("dtype", [torch.float, torch.float16])
2245*da0073e9SAndroid Build Coastguard Worker    def test_mem_eff_attention_long_sequence_mask(self, device, dtype):
2246*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.get_device_properties('cuda').total_memory < 80 * 2**30:
2247*da0073e9SAndroid Build Coastguard Worker            self.skipTest("This test requires substantial GPU memory.")
2249*da0073e9SAndroid Build Coastguard Worker        make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True)
2250*da0073e9SAndroid Build Coastguard Worker        batch, num_heads, head_dim = 1, 32, 64
2251*da0073e9SAndroid Build Coastguard Worker        seq_len_q, seq_len_kv = 8192, 8192
2252*da0073e9SAndroid Build Coastguard Worker        query = make_tensor(SdpaShape(batch, num_heads, seq_len_q, head_dim))
2253*da0073e9SAndroid Build Coastguard Worker        kv_shape = SdpaShape(batch, num_heads, seq_len_kv, head_dim)
2254*da0073e9SAndroid Build Coastguard Worker        key, value = make_tensor(kv_shape), make_tensor(kv_shape)
2255*da0073e9SAndroid Build Coastguard Worker        mask = torch.randn((batch, num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype)
2256*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
2257*da0073e9SAndroid Build Coastguard Worker            out = F.scaled_dot_product_attention(query, key, value, mask)
2258*da0073e9SAndroid Build Coastguard Worker        out.sum().backward()
2259*da0073e9SAndroid Build Coastguard Worker
2260*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
2261*da0073e9SAndroid Build Coastguard Worker    def test_mem_eff_attention_non_contig_mask_bug(self, device):
2262*da0073e9SAndroid Build Coastguard Worker        # Without the fix this produces `AssertionError: assert 0.07352933287620544 < 1e-07`
2263*da0073e9SAndroid Build Coastguard Worker        # Shapes taken from repro
2264*da0073e9SAndroid Build Coastguard Worker        query_size = (3, 16, 1, 128)
2265*da0073e9SAndroid Build Coastguard Worker        query_strides = (2304, 128, 2048, 1)
2266*da0073e9SAndroid Build Coastguard Worker        key_size = (3, 16, 14, 128)
2267*da0073e9SAndroid Build Coastguard Worker        key_strides = (3584, 0, 256, 1)
2268*da0073e9SAndroid Build Coastguard Worker        value_size = (3, 16, 14, 128)
2269*da0073e9SAndroid Build Coastguard Worker        value_strides = (3584, 0, 256, 1)
2270*da0073e9SAndroid Build Coastguard Worker        attention_mask_size = (3, 1, 1, 14)
2271*da0073e9SAndroid Build Coastguard Worker        attn_mask_strides = (14, 14, 14, 1)
2272*da0073e9SAndroid Build Coastguard Worker
2273*da0073e9SAndroid Build Coastguard Worker        # Calculate the number of elements needed for each tensor
2274*da0073e9SAndroid Build Coastguard Worker        query_num_elements = max(size * stride for size, stride in zip(query_size, query_strides))
2275*da0073e9SAndroid Build Coastguard Worker        key_num_elements = max(size * stride for size, stride in zip(key_size, key_strides))
2276*da0073e9SAndroid Build Coastguard Worker        value_num_elements = max(size * stride for size, stride in zip(value_size, value_strides))
2277*da0073e9SAndroid Build Coastguard Worker        attention_mask_num_elements = max(size * stride for size, stride in zip(attention_mask_size, attn_mask_strides))
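        # For these particular layouts, max(size * stride) is at least the minimal
        # 1 + sum((size - 1) * stride) elements each tensor needs, so the as_strided
        # views below stay in bounds (a slight over-allocation is fine for this repro).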
2278*da0073e9SAndroid Build Coastguard Worker
2279*da0073e9SAndroid Build Coastguard Worker        # Create the tensors with the specified sizes and strides
2280*da0073e9SAndroid Build Coastguard Worker        query = torch.randn(query_num_elements, device=device).as_strided(query_size, query_strides)
2281*da0073e9SAndroid Build Coastguard Worker        key = torch.randn(key_num_elements, device=device).as_strided(key_size, key_strides)
2282*da0073e9SAndroid Build Coastguard Worker        value = torch.randn(value_num_elements, device=device).as_strided(value_size, value_strides)
2283*da0073e9SAndroid Build Coastguard Worker        bias = torch.randn(attention_mask_num_elements, device=device).as_strided(attention_mask_size, attn_mask_strides)
2284*da0073e9SAndroid Build Coastguard Worker
2285*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
2286*da0073e9SAndroid Build Coastguard Worker            out = F.scaled_dot_product_attention(query, key, value, bias)
2287*da0073e9SAndroid Build Coastguard Worker            out_contig = F.scaled_dot_product_attention(query, key, value, bias.contiguous())
2288*da0073e9SAndroid Build Coastguard Worker
2289*da0073e9SAndroid Build Coastguard Worker        mean_diff = (out - out_contig).abs().mean()
2290*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(mean_diff.item() < 1e-7)
2291*da0073e9SAndroid Build Coastguard Worker
2292*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Fused SDPA was not built for this system")
2293*da0073e9SAndroid Build Coastguard Worker    def test_singleton_head_dim_stride_ne_1(self, device):
2294*da0073e9SAndroid Build Coastguard Worker        query = torch.tensor([[[[1, 2]]]], dtype=torch.float16, device=device)
2295*da0073e9SAndroid Build Coastguard Worker        query = query.transpose(-1, -2)
2296*da0073e9SAndroid Build Coastguard Worker        key = torch.tensor([[[[1]]]], dtype=torch.float16, device=device)
2297*da0073e9SAndroid Build Coastguard Worker        value = torch.tensor([[[[1]]]], dtype=torch.float16, device=device)
2298*da0073e9SAndroid Build Coastguard Worker
2299*da0073e9SAndroid Build Coastguard Worker        with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
2300*da0073e9SAndroid Build Coastguard Worker            scaled_dot_product_attention(query, key, value)
2301*da0073e9SAndroid Build Coastguard Worker
2302*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
2303*da0073e9SAndroid Build Coastguard Worker    @parametrize("type", ["dense", "nested"])
2304*da0073e9SAndroid Build Coastguard Worker    @parametrize("is_contiguous", [True, False])
2305*da0073e9SAndroid Build Coastguard Worker    def test_scaled_dot_product_attention_fused_kernels_packed(self, device, type: str, is_contiguous: bool):
2306*da0073e9SAndroid Build Coastguard Worker        if TEST_WITH_ROCM and type == 'nested':
2307*da0073e9SAndroid Build Coastguard Worker            self.skipTest("ROCM does not support efficient attention on nested tensors, for now")
2308*da0073e9SAndroid Build Coastguard Worker        make_tensor = partial(rand_sdpa_tensor, type=type, device=device, dtype=torch.float16, packed=True)
2309*da0073e9SAndroid Build Coastguard Worker
2310*da0073e9SAndroid Build Coastguard Worker        batch_size, seq_len, num_heads, head_dim = 32, 64, 16, 64
2311*da0073e9SAndroid Build Coastguard Worker        shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
2312*da0073e9SAndroid Build Coastguard Worker
2313*da0073e9SAndroid Build Coastguard Worker        # Test Packed
2314*da0073e9SAndroid Build Coastguard Worker        qkv = make_tensor(shape)
2315*da0073e9SAndroid Build Coastguard Worker        query, key, value = qkv.chunk(3, dim=-1)
2316*da0073e9SAndroid Build Coastguard Worker
2317*da0073e9SAndroid Build Coastguard Worker        query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2318*da0073e9SAndroid Build Coastguard Worker        value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2319*da0073e9SAndroid Build Coastguard Worker        key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2320*da0073e9SAndroid Build Coastguard Worker
2321*da0073e9SAndroid Build Coastguard Worker        if is_contiguous:
2322*da0073e9SAndroid Build Coastguard Worker            query = query.contiguous()
2323*da0073e9SAndroid Build Coastguard Worker            key = key.contiguous()
2324*da0073e9SAndroid Build Coastguard Worker            value = value.contiguous()
2325*da0073e9SAndroid Build Coastguard Worker
2326*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
2327*da0073e9SAndroid Build Coastguard Worker            actual = torch.nn.functional.scaled_dot_product_attention(
2328*da0073e9SAndroid Build Coastguard Worker                query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
2329*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[SDPBackend.MATH]):
2330*da0073e9SAndroid Build Coastguard Worker            math_ref = torch.nn.functional.scaled_dot_product_attention(
2331*da0073e9SAndroid Build Coastguard Worker                query.contiguous(), key.contiguous(), value.contiguous(),
2332*da0073e9SAndroid Build Coastguard Worker                attn_mask=None, dropout_p=0.0, is_causal=False)
2333*da0073e9SAndroid Build Coastguard Worker
2334*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=2e-3, rtol=1e-2)
2335*da0073e9SAndroid Build Coastguard Worker
2336*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm  # Missing nested and EFFICIENT_ATTENTION
2337*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
2338*da0073e9SAndroid Build Coastguard Worker    @parametrize("type", ["dense", "nested"])
2339*da0073e9SAndroid Build Coastguard Worker    @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
2340*da0073e9SAndroid Build Coastguard Worker                 PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION])
2341*da0073e9SAndroid Build Coastguard Worker    def test_scaled_dot_product_attention_fused_kernels_packed_accuracy(self, device, type: str, fused_kernel: SDPBackend):
2342*da0073e9SAndroid Build Coastguard Worker        def rand_nt(shape):
2343*da0073e9SAndroid Build Coastguard Worker            batch, seq_len, num_heads, head_dim = shape
2344*da0073e9SAndroid Build Coastguard Worker            tensors = [6 * torch.rand((seq_len, 3 * num_heads * head_dim), device=device, dtype=torch.float32) - 3
2345*da0073e9SAndroid Build Coastguard Worker                       for _ in range(batch)]
2346*da0073e9SAndroid Build Coastguard Worker            return (torch.nested.nested_tensor(tensors, device=device, dtype=torch.float32),
2347*da0073e9SAndroid Build Coastguard Worker                    torch.nested.nested_tensor(tensors, device=device, dtype=torch.float16))
2348*da0073e9SAndroid Build Coastguard Worker
2349*da0073e9SAndroid Build Coastguard Worker        def rand_tensor(shape):
2350*da0073e9SAndroid Build Coastguard Worker            batch, seq_len, num_heads, head_dim = shape
2351*da0073e9SAndroid Build Coastguard Worker            tensor = 6 * torch.rand((batch, seq_len, 3 * num_heads * head_dim), device=device, dtype=torch.float32) - 3
2352*da0073e9SAndroid Build Coastguard Worker            return tensor, tensor.to(dtype=torch.float16)
2353*da0073e9SAndroid Build Coastguard Worker
2354*da0073e9SAndroid Build Coastguard Worker        batch_size, seq_len, num_heads, head_dim = 16, 8, 4, 64
2355*da0073e9SAndroid Build Coastguard Worker        shape = (batch_size, seq_len, num_heads, head_dim)
2356*da0073e9SAndroid Build Coastguard Worker
2357*da0073e9SAndroid Build Coastguard Worker        # Test Packed
2358*da0073e9SAndroid Build Coastguard Worker        qkv, qkv_low_precision = rand_tensor(shape) if type == "dense" else rand_nt(shape)
2359*da0073e9SAndroid Build Coastguard Worker        query, key, value = qkv.chunk(3, dim=-1)
2360*da0073e9SAndroid Build Coastguard Worker        query_lp, key_lp, value_lp = qkv_low_precision.chunk(3, dim=-1)
2361*da0073e9SAndroid Build Coastguard Worker
2362*da0073e9SAndroid Build Coastguard Worker        query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2363*da0073e9SAndroid Build Coastguard Worker        key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2364*da0073e9SAndroid Build Coastguard Worker        value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2365*da0073e9SAndroid Build Coastguard Worker
2366*da0073e9SAndroid Build Coastguard Worker        query_lp = query_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2367*da0073e9SAndroid Build Coastguard Worker        key_lp = key_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2368*da0073e9SAndroid Build Coastguard Worker        value_lp = value_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2369*da0073e9SAndroid Build Coastguard Worker
2370*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[fused_kernel]):
2371*da0073e9SAndroid Build Coastguard Worker            actual = torch.nn.functional.scaled_dot_product_attention(
2372*da0073e9SAndroid Build Coastguard Worker                query_lp, key_lp, value_lp, attn_mask=None, dropout_p=0.0, is_causal=False)
2373*da0073e9SAndroid Build Coastguard Worker
2374*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[SDPBackend.MATH]):
2375*da0073e9SAndroid Build Coastguard Worker            math_ref_lp = torch.nn.functional.scaled_dot_product_attention(
2376*da0073e9SAndroid Build Coastguard Worker                query_lp.contiguous(), key_lp.contiguous(), value_lp.contiguous(),
2377*da0073e9SAndroid Build Coastguard Worker                attn_mask=None, dropout_p=0.0, is_causal=False)
2378*da0073e9SAndroid Build Coastguard Worker
2379*da0073e9SAndroid Build Coastguard Worker            math_query = query.contiguous()
2380*da0073e9SAndroid Build Coastguard Worker            math_key = key.contiguous()
2381*da0073e9SAndroid Build Coastguard Worker            math_value = value.contiguous()
2382*da0073e9SAndroid Build Coastguard Worker
2383*da0073e9SAndroid Build Coastguard Worker            math_ref = torch.nn.functional.scaled_dot_product_attention(
2384*da0073e9SAndroid Build Coastguard Worker                math_query, math_key, math_value, attn_mask=None, dropout_p=0.0, is_causal=False)
2385*da0073e9SAndroid Build Coastguard Worker
2386*da0073e9SAndroid Build Coastguard Worker        actual_test = actual
2387*da0073e9SAndroid Build Coastguard Worker        math_ref_test = math_ref
2388*da0073e9SAndroid Build Coastguard Worker        math_ref_lp_test = math_ref_lp
2389*da0073e9SAndroid Build Coastguard Worker
2390*da0073e9SAndroid Build Coastguard Worker        if actual_test.is_nested:
2391*da0073e9SAndroid Build Coastguard Worker            actual_test = torch.nested.to_padded_tensor(actual_test.contiguous(), padding=0.0)
2392*da0073e9SAndroid Build Coastguard Worker            math_ref_test = torch.nested.to_padded_tensor(math_ref_test, padding=0.0)
2393*da0073e9SAndroid Build Coastguard Worker            math_ref_lp_test = torch.nested.to_padded_tensor(math_ref_lp_test, padding=0.0)
2394*da0073e9SAndroid Build Coastguard Worker
2395*da0073e9SAndroid Build Coastguard Worker        actual_test = actual_test.to(dtype=torch.float32).contiguous()
2396*da0073e9SAndroid Build Coastguard Worker        math_ref_test = math_ref_test.to(dtype=torch.float32).contiguous()
2397*da0073e9SAndroid Build Coastguard Worker        math_ref_lp_test = math_ref_lp_test.to(dtype=torch.float32).contiguous()
2398*da0073e9SAndroid Build Coastguard Worker
2399*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(math_ref_test, math_ref_lp_test, atol=7e-3, rtol=7e-3)
2400*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(actual_test, math_ref_test, atol=5e-3, rtol=5e-3)
2401*da0073e9SAndroid Build Coastguard Worker
2402*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Efficient Attention was not built for this system")
2403*da0073e9SAndroid Build Coastguard Worker    @parametrize("contiguous_inputs", [True, False])
2404*da0073e9SAndroid Build Coastguard Worker    @parametrize("is_causal", [True, False])
2405*da0073e9SAndroid Build Coastguard Worker    def test_sdp_mem_efficient_grad_against_math(self, device, contiguous_inputs: bool, is_causal: bool):
2406*da0073e9SAndroid Build Coastguard Worker        batch_size, seq_len, num_heads, head_dim = 4, 4, 2, 16
2407*da0073e9SAndroid Build Coastguard Worker        make_tensor = partial(rand_sdpa_tensor, type="dense", device=device,
2408*da0073e9SAndroid Build Coastguard Worker                              dtype=torch.float64, requires_grad=True, packed=True)
2409*da0073e9SAndroid Build Coastguard Worker
2410*da0073e9SAndroid Build Coastguard Worker        qkv = make_tensor(SdpaShape(batch_size, num_heads, seq_len, head_dim))
2411*da0073e9SAndroid Build Coastguard Worker        qkv_lp = qkv.detach().clone().to(torch.float32).requires_grad_()
2412*da0073e9SAndroid Build Coastguard Worker
2413*da0073e9SAndroid Build Coastguard Worker        query, key, value = qkv.chunk(3, dim=-1)
2414*da0073e9SAndroid Build Coastguard Worker        query_lp, key_lp, value_lp = qkv_lp.chunk(3, dim=-1)
2415*da0073e9SAndroid Build Coastguard Worker
2416*da0073e9SAndroid Build Coastguard Worker        query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2417*da0073e9SAndroid Build Coastguard Worker        key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2418*da0073e9SAndroid Build Coastguard Worker        value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2419*da0073e9SAndroid Build Coastguard Worker
2420*da0073e9SAndroid Build Coastguard Worker        query_lp = query_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2421*da0073e9SAndroid Build Coastguard Worker        key_lp = key_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2422*da0073e9SAndroid Build Coastguard Worker        value_lp = value_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2423*da0073e9SAndroid Build Coastguard Worker
2424*da0073e9SAndroid Build Coastguard Worker        if contiguous_inputs:
2425*da0073e9SAndroid Build Coastguard Worker            query = query.contiguous()
2426*da0073e9SAndroid Build Coastguard Worker            key = key.contiguous()
2427*da0073e9SAndroid Build Coastguard Worker            value = value.contiguous()
2428*da0073e9SAndroid Build Coastguard Worker
2429*da0073e9SAndroid Build Coastguard Worker            query_lp = query_lp.contiguous()
2430*da0073e9SAndroid Build Coastguard Worker            key_lp = key_lp.contiguous()
2431*da0073e9SAndroid Build Coastguard Worker            value_lp = value_lp.contiguous()
2432*da0073e9SAndroid Build Coastguard Worker
2433*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[SDPBackend.MATH]):
2434*da0073e9SAndroid Build Coastguard Worker            out = torch.nn.functional.scaled_dot_product_attention(query, key, value, None, 0.0, is_causal)
2435*da0073e9SAndroid Build Coastguard Worker
2436*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
2437*da0073e9SAndroid Build Coastguard Worker            out_lp = torch.nn.functional.scaled_dot_product_attention(
2438*da0073e9SAndroid Build Coastguard Worker                query_lp, key_lp, value_lp, None, 0.0, is_causal)
2439*da0073e9SAndroid Build Coastguard Worker
2440*da0073e9SAndroid Build Coastguard Worker        rand_upward = torch.rand_like(out)
2441*da0073e9SAndroid Build Coastguard Worker        rand_upward_lp = rand_upward.to(torch.float32)
2442*da0073e9SAndroid Build Coastguard Worker
2443*da0073e9SAndroid Build Coastguard Worker        out.backward(rand_upward)
2444*da0073e9SAndroid Build Coastguard Worker        out_lp.backward(rand_upward_lp)
2445*da0073e9SAndroid Build Coastguard Worker
2446*da0073e9SAndroid Build Coastguard Worker        # Cast up and compare
2447*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(qkv.grad, qkv_lp.grad.to(torch.float64), atol=1e-5, rtol=1e-5)
2448*da0073e9SAndroid Build Coastguard Worker
2449*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention was not built for this system")
2450*da0073e9SAndroid Build Coastguard Worker    @parametrize("contiguous_inputs", [True, False])
2451*da0073e9SAndroid Build Coastguard Worker    @parametrize("is_causal", [True, False])
2452*da0073e9SAndroid Build Coastguard Worker    @parametrize("dtype", [torch.float16, torch.bfloat16])
2453*da0073e9SAndroid Build Coastguard Worker    def test_sdp_flash_attention_grad_against_math(self, device, contiguous_inputs: bool, is_causal: bool, dtype: torch.dtype):
2454*da0073e9SAndroid Build Coastguard Worker        batch_size, seq_len, num_heads, head_dim = 4, 4, 2, 16
2455*da0073e9SAndroid Build Coastguard Worker        make_tensor = partial(rand_sdpa_tensor, type="dense", device=device,
2456*da0073e9SAndroid Build Coastguard Worker                              dtype=torch.float64, requires_grad=True, packed=True)
2457*da0073e9SAndroid Build Coastguard Worker
2458*da0073e9SAndroid Build Coastguard Worker        qkv = make_tensor(SdpaShape(batch_size, num_heads, seq_len, head_dim))
2459*da0073e9SAndroid Build Coastguard Worker        qkv_lp = qkv.detach().clone().to(dtype).requires_grad_()
2460*da0073e9SAndroid Build Coastguard Worker
2461*da0073e9SAndroid Build Coastguard Worker        query, key, value = qkv.chunk(3, dim=-1)
2462*da0073e9SAndroid Build Coastguard Worker        query_lp, key_lp, value_lp = qkv_lp.chunk(3, dim=-1)
2463*da0073e9SAndroid Build Coastguard Worker
2464*da0073e9SAndroid Build Coastguard Worker        query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2465*da0073e9SAndroid Build Coastguard Worker        key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2466*da0073e9SAndroid Build Coastguard Worker        value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2467*da0073e9SAndroid Build Coastguard Worker
2468*da0073e9SAndroid Build Coastguard Worker        query_lp = query_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2469*da0073e9SAndroid Build Coastguard Worker        key_lp = key_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2470*da0073e9SAndroid Build Coastguard Worker        value_lp = value_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2471*da0073e9SAndroid Build Coastguard Worker
2472*da0073e9SAndroid Build Coastguard Worker        if contiguous_inputs:
2473*da0073e9SAndroid Build Coastguard Worker            query = query.contiguous()
2474*da0073e9SAndroid Build Coastguard Worker            key = key.contiguous()
2475*da0073e9SAndroid Build Coastguard Worker            value = value.contiguous()
2476*da0073e9SAndroid Build Coastguard Worker
2477*da0073e9SAndroid Build Coastguard Worker            query_lp = query_lp.contiguous()
2478*da0073e9SAndroid Build Coastguard Worker            key_lp = key_lp.contiguous()
2479*da0073e9SAndroid Build Coastguard Worker            value_lp = value_lp.contiguous()
2480*da0073e9SAndroid Build Coastguard Worker
2481*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[SDPBackend.MATH]):
2482*da0073e9SAndroid Build Coastguard Worker            out = torch.nn.functional.scaled_dot_product_attention(query, key, value, None, 0.0, is_causal)
2483*da0073e9SAndroid Build Coastguard Worker
2484*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
2485*da0073e9SAndroid Build Coastguard Worker            out_lp = torch.nn.functional.scaled_dot_product_attention(
2486*da0073e9SAndroid Build Coastguard Worker                query_lp, key_lp, value_lp, None, 0.0, is_causal)
2487*da0073e9SAndroid Build Coastguard Worker
2488*da0073e9SAndroid Build Coastguard Worker        rand_upward = torch.rand_like(out)
2489*da0073e9SAndroid Build Coastguard Worker        rand_upward_lp = rand_upward.to(dtype)
2490*da0073e9SAndroid Build Coastguard Worker
2491*da0073e9SAndroid Build Coastguard Worker        out.backward(rand_upward)
2492*da0073e9SAndroid Build Coastguard Worker        out_lp.backward(rand_upward_lp)
2493*da0073e9SAndroid Build Coastguard Worker
2494*da0073e9SAndroid Build Coastguard Worker        # Cast up and compare
2495*da0073e9SAndroid Build Coastguard Worker        # Since the compute is done in fp16 we have to bump the tolerance
2496*da0073e9SAndroid Build Coastguard Worker        # Loosen the tolerance further for bfloat16
2497*da0073e9SAndroid Build Coastguard Worker        atol = 7e-4 if dtype == torch.float16 else 7e-3
2498*da0073e9SAndroid Build Coastguard Worker        rtol = 7e-4 if dtype == torch.float16 else 7e-3
2499*da0073e9SAndroid Build Coastguard Worker        if TEST_WITH_ROCM:
2500*da0073e9SAndroid Build Coastguard Worker            atol = 9e-4 if dtype == torch.float16 else 9e-3
2501*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(qkv.grad, qkv_lp.grad.to(torch.float64), atol=atol, rtol=rtol)
2502*da0073e9SAndroid Build Coastguard Worker
2503*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm  # Missing nested and EFFICIENT_ATTENTION
2504*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Platform does not support fused SDPA")
2505*da0073e9SAndroid Build Coastguard Worker    @parametrize("type", ["dense", "nested"])
2506*da0073e9SAndroid Build Coastguard Worker    def test_fused_sdp_choice(self, device, type: str):
2507*da0073e9SAndroid Build Coastguard Worker        batch_size, seq_len, num_heads, head_dim = 2, 128, 8, 64
2508*da0073e9SAndroid Build Coastguard Worker        shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
2509*da0073e9SAndroid Build Coastguard Worker        make_tensor = partial(rand_sdpa_tensor, device=device, dtype=torch.float16, packed=True, requires_grad=True)
2510*da0073e9SAndroid Build Coastguard Worker
2511*da0073e9SAndroid Build Coastguard Worker        qkv = make_tensor(shape, type=type)
2512*da0073e9SAndroid Build Coastguard Worker        query, key, value = qkv.chunk(3, dim=-1)
2513*da0073e9SAndroid Build Coastguard Worker
2514*da0073e9SAndroid Build Coastguard Worker        query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2515*da0073e9SAndroid Build Coastguard Worker        value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2516*da0073e9SAndroid Build Coastguard Worker        key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2517*da0073e9SAndroid Build Coastguard Worker
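        # For fp16 inputs the dispatcher should pick flash attention when the platform
        # supports it, and fall back to the mem-efficient kernel otherwise.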
2518*da0073e9SAndroid Build Coastguard Worker        if PLATFORM_SUPPORTS_FLASH_ATTENTION:
2519*da0073e9SAndroid Build Coastguard Worker            assert torch._fused_sdp_choice(query, key, value) == SDPBackend.FLASH_ATTENTION.value
2520*da0073e9SAndroid Build Coastguard Worker        else:
2521*da0073e9SAndroid Build Coastguard Worker            assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION.value
2522*da0073e9SAndroid Build Coastguard Worker
2523*da0073e9SAndroid Build Coastguard Worker        # Change dtype to float32 so that efficient attention should get chosen
2524*da0073e9SAndroid Build Coastguard Worker        make_tensor = partial(rand_sdpa_tensor, device=device, dtype=torch.float32, packed=True)
2525*da0073e9SAndroid Build Coastguard Worker
2526*da0073e9SAndroid Build Coastguard Worker        qkv = make_tensor(shape, type=type)
2527*da0073e9SAndroid Build Coastguard Worker        query, key, value = qkv.chunk(3, dim=-1)
2528*da0073e9SAndroid Build Coastguard Worker
2529*da0073e9SAndroid Build Coastguard Worker        query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2530*da0073e9SAndroid Build Coastguard Worker        value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2531*da0073e9SAndroid Build Coastguard Worker        key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2532*da0073e9SAndroid Build Coastguard Worker
2533*da0073e9SAndroid Build Coastguard Worker        assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION.value
2534*da0073e9SAndroid Build Coastguard Worker
2535*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm  # Missing triton.float32 ("triton" prefix is to locate skipped UTs), and deterministic algo
2536*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA")
2537*da0073e9SAndroid Build Coastguard Worker    @parametrize("warn_only", [True, False])
2538*da0073e9SAndroid Build Coastguard Worker    def test_sdp_choice_with_determinism(self, device, warn_only):
2539*da0073e9SAndroid Build Coastguard Worker        batch_size, seq_len, num_heads, head_dim = 1, 64, 8, 64
2540*da0073e9SAndroid Build Coastguard Worker        shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
2541*da0073e9SAndroid Build Coastguard Worker        make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=torch.float32, packed=False)
2542*da0073e9SAndroid Build Coastguard Worker        query, key, value = make_tensor(shape), make_tensor(shape), make_tensor(shape)
2543*da0073e9SAndroid Build Coastguard Worker
2544*da0073e9SAndroid Build Coastguard Worker        with use_deterministic_algorithims(True, warn_only=warn_only):
2545*da0073e9SAndroid Build Coastguard Worker            with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
2546*da0073e9SAndroid Build Coastguard Worker                assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION.value
2547*da0073e9SAndroid Build Coastguard Worker
2548*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm  # Missing deterministic algo
2549*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
2550*da0073e9SAndroid Build Coastguard Worker    @parametrize("fused_kernel", PLATFORM_SPECIFIC_SDPA)
2551*da0073e9SAndroid Build Coastguard Worker    @parametrize("warn_only", [True, False])
2552*da0073e9SAndroid Build Coastguard Worker    def test_fused_backwards_throws_determinism_warning(self, device, warn_only, fused_kernel):
2553*da0073e9SAndroid Build Coastguard Worker        batch_size, seq_len, num_heads, head_dim = 1, 64, 8, 64
2554*da0073e9SAndroid Build Coastguard Worker        shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
2555*da0073e9SAndroid Build Coastguard Worker        make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=torch.float16, packed=False, requires_grad=True)
2556*da0073e9SAndroid Build Coastguard Worker        query, key, value = make_tensor(shape), make_tensor(shape), make_tensor(shape)
2557*da0073e9SAndroid Build Coastguard Worker
2558*da0073e9SAndroid Build Coastguard Worker        kernel_name = "Memory Efficient attention" if fused_kernel == SDPBackend.EFFICIENT_ATTENTION else "Flash Attention"
2559*da0073e9SAndroid Build Coastguard Worker        warning_context = (
2560*da0073e9SAndroid Build Coastguard Worker            self.assertWarnsRegex(
2561*da0073e9SAndroid Build Coastguard Worker                UserWarning,
2562*da0073e9SAndroid Build Coastguard Worker                f"{kernel_name} defaults to a non-deterministic algorithm.",
2563*da0073e9SAndroid Build Coastguard Worker            )
2564*da0073e9SAndroid Build Coastguard Worker            if warn_only
2565*da0073e9SAndroid Build Coastguard Worker            else contextlib.nullcontext()
2566*da0073e9SAndroid Build Coastguard Worker        )
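        # With warn_only=True the fused backward should warn that the kernel defaults
        # to a non-deterministic algorithm; with warn_only=False no warning check is
        # performed (nullcontext).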
2567*da0073e9SAndroid Build Coastguard Worker        with use_deterministic_algorithims(True, warn_only=warn_only):
2568*da0073e9SAndroid Build Coastguard Worker            with sdpa_kernel(backends=[fused_kernel]):
2569*da0073e9SAndroid Build Coastguard Worker                with warning_context:
2570*da0073e9SAndroid Build Coastguard Worker                    torch.nn.functional.scaled_dot_product_attention(query, key, value).sum().backward()
2571*da0073e9SAndroid Build Coastguard Worker
2572*da0073e9SAndroid Build Coastguard Worker    @unittest.skip("This test does not behave consistently (deterministically vs. non-deterministically) on CI/CD")
2573*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Platform does not support fused SDPA")
2574*da0073e9SAndroid Build Coastguard Worker    def test_mem_eff_backwards_determinism(self, device):
2575*da0073e9SAndroid Build Coastguard Worker        # Need big seq_len to ensure that num_splits > 1
2576*da0073e9SAndroid Build Coastguard Worker        dtype = torch.float32
2577*da0073e9SAndroid Build Coastguard Worker        batch_size, seq_len, n_heads, head_dim = 1, 1024, 8, 64
2578*da0073e9SAndroid Build Coastguard Worker        query = torch.rand(batch_size, n_heads, seq_len, head_dim,
2579*da0073e9SAndroid Build Coastguard Worker                           device=device, dtype=dtype, requires_grad=True)
2580*da0073e9SAndroid Build Coastguard Worker        key = torch.rand(batch_size, n_heads, seq_len, head_dim, device=device,
2581*da0073e9SAndroid Build Coastguard Worker                         dtype=dtype, requires_grad=True)
2582*da0073e9SAndroid Build Coastguard Worker        value = torch.rand(batch_size, n_heads, seq_len, head_dim,
2583*da0073e9SAndroid Build Coastguard Worker                           device=device, dtype=dtype, requires_grad=True)
2584*da0073e9SAndroid Build Coastguard Worker
2585*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
2586*da0073e9SAndroid Build Coastguard Worker            # Run once to establish baseline
2587*da0073e9SAndroid Build Coastguard Worker            out = F.scaled_dot_product_attention(query, key, value)
2588*da0073e9SAndroid Build Coastguard Worker            upward_grad = torch.rand_like(out)
2589*da0073e9SAndroid Build Coastguard Worker            out.backward(upward_grad)
2590*da0073e9SAndroid Build Coastguard Worker            initial_query_grad = query.grad
2591*da0073e9SAndroid Build Coastguard Worker
2592*da0073e9SAndroid Build Coastguard Worker            # Re-run the op with the same upward grad and check that the backward is
2593*da0073e9SAndroid Build Coastguard Worker            # not deterministic
2594*da0073e9SAndroid Build Coastguard Worker            diff_answer_once = False
2595*da0073e9SAndroid Build Coastguard Worker            for _ in range(100):
2596*da0073e9SAndroid Build Coastguard Worker                query.grad = None
2597*da0073e9SAndroid Build Coastguard Worker                out = F.scaled_dot_product_attention(query, key, value)
2598*da0073e9SAndroid Build Coastguard Worker                out.backward(upward_grad)
2599*da0073e9SAndroid Build Coastguard Worker                if not torch.equal(initial_query_grad, query.grad):
2600*da0073e9SAndroid Build Coastguard Worker                    diff_answer_once = True
2601*da0073e9SAndroid Build Coastguard Worker                    break
2602*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(diff_answer_once)
2603*da0073e9SAndroid Build Coastguard Worker
2604*da0073e9SAndroid Build Coastguard Worker        with use_deterministic_algorithims(True, warn_only=False):
2605*da0073e9SAndroid Build Coastguard Worker            query.grad = None
2606*da0073e9SAndroid Build Coastguard Worker            out = F.scaled_dot_product_attention(query, key, value)
2607*da0073e9SAndroid Build Coastguard Worker            upward_grad = torch.rand_like(out)
2608*da0073e9SAndroid Build Coastguard Worker            out.backward(upward_grad)
2609*da0073e9SAndroid Build Coastguard Worker            initial_query_grad = query.grad
2610*da0073e9SAndroid Build Coastguard Worker
2611*da0073e9SAndroid Build Coastguard Worker            # Re-run the op with the same upward grad and check that the backward is
2612*da0073e9SAndroid Build Coastguard Worker            # deterministic now that we have enforced it
2613*da0073e9SAndroid Build Coastguard Worker            diff_answer_once = False
2614*da0073e9SAndroid Build Coastguard Worker            for _ in range(100):
2615*da0073e9SAndroid Build Coastguard Worker                query.grad = None
2616*da0073e9SAndroid Build Coastguard Worker                out = F.scaled_dot_product_attention(query, key, value)
2617*da0073e9SAndroid Build Coastguard Worker                out.backward(upward_grad)
2618*da0073e9SAndroid Build Coastguard Worker                if not torch.equal(initial_query_grad, query.grad):
2619*da0073e9SAndroid Build Coastguard Worker                    diff_answer_once = True
2620*da0073e9SAndroid Build Coastguard Worker                    break
2621*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(diff_answer_once)
2622*da0073e9SAndroid Build Coastguard Worker
2623*da0073e9SAndroid Build Coastguard Worker    # verified passing successfully on H100
2624*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA")
2625*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_JETSON, "causing sigkill on Jetson")
2626*da0073e9SAndroid Build Coastguard Worker    @parametrize("batch_size", [1, 8])
2627*da0073e9SAndroid Build Coastguard Worker    @parametrize("seq_len_q", [4, 8, 64, 128, 256, 512, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80
2628*da0073e9SAndroid Build Coastguard Worker                 else [4, 8, 64, 128, 256, 512])
2629*da0073e9SAndroid Build Coastguard Worker    @parametrize("seq_len_k", [4, 8, 64, 128, 256, 512, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80
2630*da0073e9SAndroid Build Coastguard Worker                 else [4, 8, 64, 128, 256, 512])
2631*da0073e9SAndroid Build Coastguard Worker    @parametrize("head_dim", [8, 16, 32, 64, 72, 96, 128] if MEM_EFF_CAPABILITY_MATCHES_SM80
2632*da0073e9SAndroid Build Coastguard Worker                 else [8, 16, 32, 64])
2633*da0073e9SAndroid Build Coastguard Worker    @parametrize("is_causal", [False, True])
2634*da0073e9SAndroid Build Coastguard Worker    @parametrize("dropout_p", [0.0, 0.22])
2635*da0073e9SAndroid Build Coastguard Worker    @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if MEM_EFF_CAPABILITY_MATCHES_SM80
2636*da0073e9SAndroid Build Coastguard Worker                 else [torch.float16, torch.float32])
2637*da0073e9SAndroid Build Coastguard Worker    @parametrize("scale", [None, "l1"])
2638*da0073e9SAndroid Build Coastguard Worker    def test_mem_efficient_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
2639*da0073e9SAndroid Build Coastguard Worker                                                       head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
2640*da0073e9SAndroid Build Coastguard Worker                                                       scale: str):
2641*da0073e9SAndroid Build Coastguard Worker        def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, device=device):
2642*da0073e9SAndroid Build Coastguard Worker            mask = torch.empty((batch_size, n_heads, q_len, kv_len), device=device, dtype=torch.float32)
2643*da0073e9SAndroid Build Coastguard Worker            rand_uniform = torch._fill_mem_eff_dropout_mask_(mask, p, seed, offset)
2644*da0073e9SAndroid Build Coastguard Worker            mask = (rand_uniform > p).to(torch.float32)
2645*da0073e9SAndroid Build Coastguard Worker            return mask
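        # The helper above regenerates the dropout mask from the same seed/offset the
        # mem-efficient kernel consumes, so the math reference below can apply an
        # identical mask when dropout_p > 0.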
2646*da0073e9SAndroid Build Coastguard Worker        if max(seq_len_q, seq_len_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory < 40 * 2**30:
2647*da0073e9SAndroid Build Coastguard Worker            self.skipTest("Reference implementation OOM")
2649*da0073e9SAndroid Build Coastguard Worker        if TEST_WITH_ROCM and seq_len_q * seq_len_k * head_dim * batch_size > 1024 * 1024 * 128:
2650*da0073e9SAndroid Build Coastguard Worker            torch.cuda.empty_cache()  # Prevent memory fragmentation
2651*da0073e9SAndroid Build Coastguard Worker        seed = 42
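        # scale is parametrized as None (use the kernel default of 1/sqrt(head_dim)) or
        # "l1", in which case we override it with 1/head_dim.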
2652*da0073e9SAndroid Build Coastguard Worker        scale = scale if scale is None else (1 / head_dim)
2653*da0073e9SAndroid Build Coastguard Worker        n_heads = 4
2654*da0073e9SAndroid Build Coastguard Worker        query = torch.rand(batch_size, n_heads, seq_len_q, head_dim,
2655*da0073e9SAndroid Build Coastguard Worker                           device=device, dtype=dtype, requires_grad=True)
2656*da0073e9SAndroid Build Coastguard Worker        key = torch.rand(batch_size, n_heads, seq_len_k, head_dim, device=device,
2657*da0073e9SAndroid Build Coastguard Worker                         dtype=dtype, requires_grad=True)
2658*da0073e9SAndroid Build Coastguard Worker        value = torch.rand(batch_size, n_heads, seq_len_k, head_dim,
2659*da0073e9SAndroid Build Coastguard Worker                           device=device, dtype=dtype, requires_grad=True)
2660*da0073e9SAndroid Build Coastguard Worker
2661*da0073e9SAndroid Build Coastguard Worker        # Run the math kernel on low precision references
2662*da0073e9SAndroid Build Coastguard Worker        query_ref_lp, key_ref_lp, value_ref_lp = query_key_value_clones(query, key, value, dtype=dtype)
2663*da0073e9SAndroid Build Coastguard Worker
2664*da0073e9SAndroid Build Coastguard Worker        higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
2665*da0073e9SAndroid Build Coastguard Worker        query_ref, key_ref, value_ref = query_key_value_clones(query, key, value, dtype=higher_precision_dtype)
2666*da0073e9SAndroid Build Coastguard Worker
2667*da0073e9SAndroid Build Coastguard Worker        # Create real output
2668*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
2669*da0073e9SAndroid Build Coastguard Worker            # Set the seed and run the kernel
2670*da0073e9SAndroid Build Coastguard Worker            torch.manual_seed(seed)
2671*da0073e9SAndroid Build Coastguard Worker            out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
2672*da0073e9SAndroid Build Coastguard Worker
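        # Without dropout the public math backend serves as the reference; with dropout
        # the references must consume the exact dropout_mask, which only the private
        # _scaled_dot_product_attention_math op accepts.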
2673*da0073e9SAndroid Build Coastguard Worker        if dropout_p == 0.0:
2674*da0073e9SAndroid Build Coastguard Worker            with sdpa_kernel(backends=[SDPBackend.MATH]):
2675*da0073e9SAndroid Build Coastguard Worker                # High Precision Math Reference
2676*da0073e9SAndroid Build Coastguard Worker                out_ref = F.scaled_dot_product_attention(query_ref, key_ref, value_ref,
2677*da0073e9SAndroid Build Coastguard Worker                                                         dropout_p=dropout_p, is_causal=is_causal, scale=scale)
2678*da0073e9SAndroid Build Coastguard Worker                # Low Precision Math Reference
2679*da0073e9SAndroid Build Coastguard Worker                out_lp_ref = F.scaled_dot_product_attention(query_ref_lp, key_ref_lp, value_ref_lp,
2680*da0073e9SAndroid Build Coastguard Worker                                                            dropout_p=dropout_p, is_causal=is_causal, scale=scale)
2681*da0073e9SAndroid Build Coastguard Worker        else:
2682*da0073e9SAndroid Build Coastguard Worker            if seq_len_q > 1024:
2683*da0073e9SAndroid Build Coastguard Worker                self.skipTest("Will call _fill_mem_eff_dropout_mask with too many threads!")
2684*da0073e9SAndroid Build Coastguard Worker            # Create the dropout_mask
2685*da0073e9SAndroid Build Coastguard Worker            torch.manual_seed(seed)
2686*da0073e9SAndroid Build Coastguard Worker            dropout_mask = _get_mem_eff_drop_mask(batch_size, n_heads, seq_len_q, seq_len_k, dropout_p, seed, 0, device=device)
2687*da0073e9SAndroid Build Coastguard Worker            # High Precision Math Reference
2688*da0073e9SAndroid Build Coastguard Worker            out_ref = torch.ops.aten._scaled_dot_product_attention_math(
2689*da0073e9SAndroid Build Coastguard Worker                query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal, scale=scale, dropout_mask=dropout_mask)[0]
2690*da0073e9SAndroid Build Coastguard Worker            # Low Precision Math Reference
2691*da0073e9SAndroid Build Coastguard Worker            out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
2692*da0073e9SAndroid Build Coastguard Worker                query_ref_lp, key_ref_lp, value_ref_lp, dropout_p=dropout_p, is_causal=is_causal, scale=scale,
2693*da0073e9SAndroid Build Coastguard Worker                dropout_mask=dropout_mask)[0]
2694*da0073e9SAndroid Build Coastguard Worker
2695*da0073e9SAndroid Build Coastguard Worker        upstream_grad = torch.rand_like(out, requires_grad=False)
2696*da0073e9SAndroid Build Coastguard Worker
2697*da0073e9SAndroid Build Coastguard Worker        out.backward(upstream_grad)
2698*da0073e9SAndroid Build Coastguard Worker        out_ref.backward(upstream_grad.to(out_ref.dtype))
2699*da0073e9SAndroid Build Coastguard Worker        out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype))
2700*da0073e9SAndroid Build Coastguard Worker
2701*da0073e9SAndroid Build Coastguard Worker        # [Note] Fused Tolerances
2702*da0073e9SAndroid Build Coastguard Worker        # Establish the numerical error between the "true" high precision math output
2703*da0073e9SAndroid Build Coastguard Worker        # and the low precision math reference. We use this error for the atol,
2704*da0073e9SAndroid Build Coastguard Worker        # and we use the default rtol for the low precision type.
2705*da0073e9SAndroid Build Coastguard Worker        # We then apply a per-gradient fudge factor to account for the use of
2706*da0073e9SAndroid Build Coastguard Worker        # the fused kernel rather than the eager implementation.
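        # get_tolerances derives per-tensor atol/rtol from the gap between the two
        # references, scaled by the optional fudge factor.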
2707*da0073e9SAndroid Build Coastguard Worker        output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref)
2708*da0073e9SAndroid Build Coastguard Worker
2709*da0073e9SAndroid Build Coastguard Worker        # Fudge Factor when dropout is enabled
2710*da0073e9SAndroid Build Coastguard Worker        dropout_fudge_factor = 1.0 if dropout_p == 0.0 else 2.0
2711*da0073e9SAndroid Build Coastguard Worker
2712*da0073e9SAndroid Build Coastguard Worker        query_fudge_factor = dropout_fudge_factor
2713*da0073e9SAndroid Build Coastguard Worker        grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor)
2714*da0073e9SAndroid Build Coastguard Worker
2715*da0073e9SAndroid Build Coastguard Worker        # TODO: Investigate why grad_k needs larger tolerances
2716*da0073e9SAndroid Build Coastguard Worker        key_fudge_factor = 8 * dropout_fudge_factor
2717*da0073e9SAndroid Build Coastguard Worker        grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad, key_fudge_factor)
2718*da0073e9SAndroid Build Coastguard Worker
2719*da0073e9SAndroid Build Coastguard Worker        value_fudge_factor = 7 if not SM80OrLater and dtype == torch.float16 else 1.0
2720*da0073e9SAndroid Build Coastguard Worker        if TEST_WITH_ROCM:
2721*da0073e9SAndroid Build Coastguard Worker            value_fudge_factor = max(2.0, value_fudge_factor)
2722*da0073e9SAndroid Build Coastguard Worker        grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad, value_fudge_factor)
2723*da0073e9SAndroid Build Coastguard Worker
2724*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol)
2725*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype),
2726*da0073e9SAndroid Build Coastguard Worker                         atol=grad_q_ref_atol, rtol=grad_q_ref_rtol)
2727*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(key.grad, key_ref.grad.to(key.grad.dtype),
2728*da0073e9SAndroid Build Coastguard Worker                         atol=grad_k_ref_atol, rtol=grad_k_ref_rtol)
2729*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
2730*da0073e9SAndroid Build Coastguard Worker                         atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)
2731*da0073e9SAndroid Build Coastguard Worker
2732*da0073e9SAndroid Build Coastguard Worker
2733*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA")
2734*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_JETSON, "causing sigkill on Jetson")
2735*da0073e9SAndroid Build Coastguard Worker    @parametrize("batch_size", [1, 8])
2736*da0073e9SAndroid Build Coastguard Worker    @parametrize("seq_len_q", [4, 8, 64, 128, 256, 312, 512, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80
2737*da0073e9SAndroid Build Coastguard Worker                 else [4, 8, 64, 128, 152, 256, 512])
2738*da0073e9SAndroid Build Coastguard Worker    @parametrize("seq_len_k", [4, 8, 64, 65, 128, 256, 408, 512, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80
2739*da0073e9SAndroid Build Coastguard Worker                 else [4, 8, 37, 64, 128, 256, 512])
2740*da0073e9SAndroid Build Coastguard Worker    @parametrize("head_dim", [8, 16, 32, 64, 72, 96, 128] if MEM_EFF_CAPABILITY_MATCHES_SM80
2741*da0073e9SAndroid Build Coastguard Worker                 else [8, 16, 32, 64])
2742*da0073e9SAndroid Build Coastguard Worker    @parametrize("is_causal", [False])
2743*da0073e9SAndroid Build Coastguard Worker    @parametrize("dropout_p", [0.0, 0.22])
2744*da0073e9SAndroid Build Coastguard Worker    @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if MEM_EFF_CAPABILITY_MATCHES_SM80
2745*da0073e9SAndroid Build Coastguard Worker                 else [torch.float16, torch.float32])
2746*da0073e9SAndroid Build Coastguard Worker    @parametrize("scale", [None, "l1"])
2747*da0073e9SAndroid Build Coastguard Worker    def test_mem_efficient_attention_attn_mask_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int,
2748*da0073e9SAndroid Build Coastguard Worker                                                                 seq_len_k: int, head_dim: int, is_causal: bool,
2749*da0073e9SAndroid Build Coastguard Worker                                                                 dropout_p: float, dtype: torch.dtype,
2750*da0073e9SAndroid Build Coastguard Worker                                                                 scale: str):
2751*da0073e9SAndroid Build Coastguard Worker        def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, device=device):
2752*da0073e9SAndroid Build Coastguard Worker            mask = torch.empty((batch_size, n_heads, q_len, kv_len), device=device, dtype=torch.float32)
2753*da0073e9SAndroid Build Coastguard Worker            rand_uniform = torch._fill_mem_eff_dropout_mask_(mask, p, seed, offset)
2754*da0073e9SAndroid Build Coastguard Worker            mask = (rand_uniform > p).to(torch.float32)
2755*da0073e9SAndroid Build Coastguard Worker            return mask
2756*da0073e9SAndroid Build Coastguard Worker        if max(seq_len_q, seq_len_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory < 40 * 2**30:
2757*da0073e9SAndroid Build Coastguard Worker            self.skipTest("Reference implementation OOM")
2759*da0073e9SAndroid Build Coastguard Worker        if TEST_WITH_ROCM and dtype == torch.float32:
2760*da0073e9SAndroid Build Coastguard Worker            self.skipTest("Skip fp32 attn_mask gradients on ROCM, for now.")
2762*da0073e9SAndroid Build Coastguard Worker        if TEST_WITH_ROCM and seq_len_q * seq_len_k * head_dim * batch_size > 1024 * 1024 * 128:
2763*da0073e9SAndroid Build Coastguard Worker            torch.cuda.empty_cache()  # Prevent memory fragmentation
2764*da0073e9SAndroid Build Coastguard Worker        seed = 42
2765*da0073e9SAndroid Build Coastguard Worker        scale = scale if scale is None else (1 / head_dim)
2766*da0073e9SAndroid Build Coastguard Worker        n_heads = 4
2767*da0073e9SAndroid Build Coastguard Worker        query = torch.rand(batch_size, n_heads, seq_len_q, head_dim,
2768*da0073e9SAndroid Build Coastguard Worker                           device=device, dtype=dtype, requires_grad=True)
2769*da0073e9SAndroid Build Coastguard Worker        key = torch.rand(batch_size, n_heads, seq_len_k, head_dim, device=device,
2770*da0073e9SAndroid Build Coastguard Worker                         dtype=dtype, requires_grad=True)
2771*da0073e9SAndroid Build Coastguard Worker        value = torch.rand(batch_size, n_heads, seq_len_k, head_dim,
2772*da0073e9SAndroid Build Coastguard Worker                           device=device, dtype=dtype, requires_grad=True)
2773*da0073e9SAndroid Build Coastguard Worker
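        # The additive attn_mask participates in autograd; its gradient through the
        # fused kernel is checked against the math references at the end of the test.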
2774*da0073e9SAndroid Build Coastguard Worker        attn_mask = torch.rand(seq_len_q, seq_len_k, device=device, dtype=dtype, requires_grad=True)
2775*da0073e9SAndroid Build Coastguard Worker
2776*da0073e9SAndroid Build Coastguard Worker        # Run the math kernel on low precision references
2777*da0073e9SAndroid Build Coastguard Worker        query_ref_lp, key_ref_lp, value_ref_lp = query_key_value_clones(query, key, value, dtype=dtype)
2778*da0073e9SAndroid Build Coastguard Worker        attn_mask_ref_lp = attn_mask.detach().to(dtype).requires_grad_(True)
2779*da0073e9SAndroid Build Coastguard Worker
2780*da0073e9SAndroid Build Coastguard Worker        higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
2781*da0073e9SAndroid Build Coastguard Worker        query_ref, key_ref, value_ref = query_key_value_clones(query, key, value, dtype=higher_precision_dtype)
2782*da0073e9SAndroid Build Coastguard Worker        attn_mask_ref = attn_mask.detach().to(higher_precision_dtype).requires_grad_(True)
2783*da0073e9SAndroid Build Coastguard Worker
2784*da0073e9SAndroid Build Coastguard Worker        # Create real output
2785*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
2786*da0073e9SAndroid Build Coastguard Worker            # Set the seed and run the kernel
2787*da0073e9SAndroid Build Coastguard Worker            torch.manual_seed(seed)
2788*da0073e9SAndroid Build Coastguard Worker            out = F.scaled_dot_product_attention(query, key, value, attn_mask, dropout_p=dropout_p,
2789*da0073e9SAndroid Build Coastguard Worker                                                 is_causal=is_causal, scale=scale)
2790*da0073e9SAndroid Build Coastguard Worker
2791*da0073e9SAndroid Build Coastguard Worker        if dropout_p == 0.0:
2792*da0073e9SAndroid Build Coastguard Worker            with sdpa_kernel(backends=[SDPBackend.MATH]):
2793*da0073e9SAndroid Build Coastguard Worker                # High Precision Math Reference
2794*da0073e9SAndroid Build Coastguard Worker                out_ref = F.scaled_dot_product_attention(query_ref, key_ref, value_ref, attn_mask_ref,
2795*da0073e9SAndroid Build Coastguard Worker                                                         dropout_p=dropout_p, is_causal=is_causal, scale=scale)
2796*da0073e9SAndroid Build Coastguard Worker                # Low Precision Math Reference
2797*da0073e9SAndroid Build Coastguard Worker                out_lp_ref = F.scaled_dot_product_attention(query_ref_lp, key_ref_lp, value_ref_lp, attn_mask_ref_lp,
2798*da0073e9SAndroid Build Coastguard Worker                                                            dropout_p=dropout_p, is_causal=is_causal, scale=scale)
2799*da0073e9SAndroid Build Coastguard Worker        else:
2800*da0073e9SAndroid Build Coastguard Worker            if seq_len_q > 1024:
2801*da0073e9SAndroid Build Coastguard Worker                self.skipTest("Will call _fill_mem_eff_dropout_mask with too many threads!")
2802*da0073e9SAndroid Build Coastguard Worker            # Create the dropout_mask
2803*da0073e9SAndroid Build Coastguard Worker            torch.manual_seed(seed)
2804*da0073e9SAndroid Build Coastguard Worker            dropout_mask = _get_mem_eff_drop_mask(batch_size, n_heads, seq_len_q,
2805*da0073e9SAndroid Build Coastguard Worker                                                  seq_len_k, dropout_p, seed, 0, device=device)
2806*da0073e9SAndroid Build Coastguard Worker            # High Precision Math Reference
2807*da0073e9SAndroid Build Coastguard Worker            out_ref = torch.ops.aten._scaled_dot_product_attention_math(
2808*da0073e9SAndroid Build Coastguard Worker                query_ref, key_ref, value_ref, attn_mask_ref, dropout_p=dropout_p, is_causal=is_causal,
2809*da0073e9SAndroid Build Coastguard Worker                scale=scale, dropout_mask=dropout_mask)[0]
2810*da0073e9SAndroid Build Coastguard Worker            # Low Precision Math Reference
2811*da0073e9SAndroid Build Coastguard Worker            out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
2812*da0073e9SAndroid Build Coastguard Worker                query_ref_lp, key_ref_lp, value_ref_lp, attn_mask_ref_lp,
2813*da0073e9SAndroid Build Coastguard Worker                dropout_p=dropout_p, is_causal=is_causal, scale=scale,
2814*da0073e9SAndroid Build Coastguard Worker                dropout_mask=dropout_mask)[0]
2815*da0073e9SAndroid Build Coastguard Worker
2816*da0073e9SAndroid Build Coastguard Worker        upstream_grad = torch.rand_like(out, requires_grad=False)
2817*da0073e9SAndroid Build Coastguard Worker
2818*da0073e9SAndroid Build Coastguard Worker        out.backward(upstream_grad)
2819*da0073e9SAndroid Build Coastguard Worker        out_ref.backward(upstream_grad.to(out_ref.dtype))
2820*da0073e9SAndroid Build Coastguard Worker        out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype))
2821*da0073e9SAndroid Build Coastguard Worker
2822*da0073e9SAndroid Build Coastguard Worker        # [Note] Fused Tolerances
2823*da0073e9SAndroid Build Coastguard Worker        # Establish the numerical error between the "true" high precision math output
2824*da0073e9SAndroid Build Coastguard Worker        # and the low precision math reference. We use this error for the atol,
2825*da0073e9SAndroid Build Coastguard Worker        # and we use the default rtol for the low precision type.
2826*da0073e9SAndroid Build Coastguard Worker        # We then apply a per-gradient fudge factor to account for the use of
2827*da0073e9SAndroid Build Coastguard Worker        # the fused kernel rather than the eager implementation.
2828*da0073e9SAndroid Build Coastguard Worker        output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref)
2829*da0073e9SAndroid Build Coastguard Worker
2830*da0073e9SAndroid Build Coastguard Worker        # Fudge Factor when dropout is enabled
2831*da0073e9SAndroid Build Coastguard Worker        dropout_fudge_factor = 1.0 if dropout_p == 0.0 else 1.75
2832*da0073e9SAndroid Build Coastguard Worker        mask_fudge_factor = 1.0 if attn_mask is None else 1.5
2833*da0073e9SAndroid Build Coastguard Worker
2834*da0073e9SAndroid Build Coastguard Worker        query_fudge_factor = dropout_fudge_factor
2835*da0073e9SAndroid Build Coastguard Worker        grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor)
2836*da0073e9SAndroid Build Coastguard Worker
2837*da0073e9SAndroid Build Coastguard Worker        # TODO: Investigate why grad_k needs larger tolerances
2838*da0073e9SAndroid Build Coastguard Worker        key_fudge_factor = 8 * dropout_fudge_factor * mask_fudge_factor
2839*da0073e9SAndroid Build Coastguard Worker        grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad, key_fudge_factor)
2840*da0073e9SAndroid Build Coastguard Worker
2841*da0073e9SAndroid Build Coastguard Worker        value_fudge_factor = 7 if not SM80OrLater and dtype == torch.float16 else 1.0
2842*da0073e9SAndroid Build Coastguard Worker        if TEST_WITH_ROCM:
2843*da0073e9SAndroid Build Coastguard Worker            value_fudge_factor = max(2.0, value_fudge_factor)
2844*da0073e9SAndroid Build Coastguard Worker        grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad, value_fudge_factor)
2845*da0073e9SAndroid Build Coastguard Worker
2846*da0073e9SAndroid Build Coastguard Worker        mask_fudge_factor = 12 if attn_mask.numel() > 512 else 22
2847*da0073e9SAndroid Build Coastguard Worker        grad_attn_mask_atol, grad_attn_mask_rtol = get_tolerances(
2848*da0073e9SAndroid Build Coastguard Worker            attn_mask_ref.grad, attn_mask_ref_lp.grad, mask_fudge_factor)
2849*da0073e9SAndroid Build Coastguard Worker
2850*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol)
2851*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype),
2852*da0073e9SAndroid Build Coastguard Worker                         atol=grad_q_ref_atol, rtol=grad_q_ref_rtol)
2853*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(key.grad, key_ref.grad.to(key.grad.dtype),
2854*da0073e9SAndroid Build Coastguard Worker                         atol=grad_k_ref_atol, rtol=grad_k_ref_rtol)
2855*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
2856*da0073e9SAndroid Build Coastguard Worker                         atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)
2857*da0073e9SAndroid Build Coastguard Worker
2858*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(attn_mask.grad, attn_mask_ref.grad.to(attn_mask.grad.dtype),
2859*da0073e9SAndroid Build Coastguard Worker                         atol=grad_attn_mask_atol, rtol=grad_attn_mask_rtol)
2860*da0073e9SAndroid Build Coastguard Worker
2861*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
2862*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_JETSON, "causing sigkill on Jetson")
2863*da0073e9SAndroid Build Coastguard Worker    @parametrize("batch_size", [1, 8])
2864*da0073e9SAndroid Build Coastguard Worker    @parametrize("seq_len_q", [4, 8, 64, 143, 256, 512, 1024, 2048])
2865*da0073e9SAndroid Build Coastguard Worker    @parametrize("seq_len_k", [4, 8, 64, 128, 256, 587, 1024, 2048])
2866*da0073e9SAndroid Build Coastguard Worker    @parametrize("head_dim", [8, 16, 21, 32, 64, 72, 96, 128, 160, 192, 203, 256])
2867*da0073e9SAndroid Build Coastguard Worker    @parametrize("is_causal", [True, False])
2868*da0073e9SAndroid Build Coastguard Worker    @parametrize("dropout_p", [0.0, 0.22, 0.48])
2869*da0073e9SAndroid Build Coastguard Worker    @parametrize("dtype", [torch.float16, torch.bfloat16])
2870*da0073e9SAndroid Build Coastguard Worker    @parametrize("scale", [None, "l1"])
2871*da0073e9SAndroid Build Coastguard Worker    def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
2872*da0073e9SAndroid Build Coastguard Worker                                               head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
2873*da0073e9SAndroid Build Coastguard Worker                                               scale: str):
2874*da0073e9SAndroid Build Coastguard Worker        if isSM8XDevice and head_dim in range(193, 256 + 1):
2875*da0073e9SAndroid Build Coastguard Worker            self.skipTest("Flash attention on sm86, sm87, and sm89 for headdim > 192 currently disabled")
2876*da0073e9SAndroid Build Coastguard Worker        if is_causal and seq_len_q != seq_len_k:
2877*da0073e9SAndroid Build Coastguard Worker            self.skipTest("Flash V2 does not accept is_causal when seq_len_q != seq_len_k")
2878*da0073e9SAndroid Build Coastguard Worker        if TEST_WITH_ROCM and seq_len_q >= 1024 and seq_len_k >= 1024 and batch_size > 1:
2879*da0073e9SAndroid Build Coastguard Worker            torch.cuda.empty_cache()  # Prevent memory fragmentation
2880*da0073e9SAndroid Build Coastguard Worker
2881*da0073e9SAndroid Build Coastguard Worker        scale = scale if scale is None else (1 / head_dim)
2882*da0073e9SAndroid Build Coastguard Worker        n_heads = 4
2883*da0073e9SAndroid Build Coastguard Worker        query = torch.rand(batch_size, n_heads, seq_len_q, head_dim,
2884*da0073e9SAndroid Build Coastguard Worker                           device=device, dtype=dtype, requires_grad=True)
2885*da0073e9SAndroid Build Coastguard Worker        key = torch.rand(batch_size, n_heads, seq_len_k, head_dim, device=device,
2886*da0073e9SAndroid Build Coastguard Worker                         dtype=dtype, requires_grad=True)
2887*da0073e9SAndroid Build Coastguard Worker        value = torch.rand(batch_size, n_heads, seq_len_k, head_dim,
2888*da0073e9SAndroid Build Coastguard Worker                           device=device, dtype=dtype, requires_grad=True)
2889*da0073e9SAndroid Build Coastguard Worker
2890*da0073e9SAndroid Build Coastguard Worker        # Run the math kernel on low precision references
2891*da0073e9SAndroid Build Coastguard Worker        query_ref_lp, key_ref_lp, value_ref_lp = query_key_value_clones(query, key, value, dtype=dtype)
2892*da0073e9SAndroid Build Coastguard Worker
2893*da0073e9SAndroid Build Coastguard Worker        higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
2894*da0073e9SAndroid Build Coastguard Worker        query_ref, key_ref, value_ref = query_key_value_clones(query, key, value, dtype=higher_precision_dtype)
2895*da0073e9SAndroid Build Coastguard Worker
2896*da0073e9SAndroid Build Coastguard Worker        is_dropout = dropout_p > 0.0
2897*da0073e9SAndroid Build Coastguard Worker
2898*da0073e9SAndroid Build Coastguard Worker        if not is_dropout:
2899*da0073e9SAndroid Build Coastguard Worker            with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
2900*da0073e9SAndroid Build Coastguard Worker                out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
2901*da0073e9SAndroid Build Coastguard Worker            with sdpa_kernel(backends=[SDPBackend.MATH]):
2902*da0073e9SAndroid Build Coastguard Worker                # High Precision Math Reference
2903*da0073e9SAndroid Build Coastguard Worker                out_ref = F.scaled_dot_product_attention(
2904*da0073e9SAndroid Build Coastguard Worker                    query_ref, key_ref, value_ref, is_causal=is_causal, scale=scale)
2905*da0073e9SAndroid Build Coastguard Worker                # Low Precision Math Reference
2906*da0073e9SAndroid Build Coastguard Worker                out_lp_ref = F.scaled_dot_product_attention(
2907*da0073e9SAndroid Build Coastguard Worker                    query_ref_lp, key_ref_lp, value_ref_lp, is_causal=is_causal, scale=scale)
2908*da0073e9SAndroid Build Coastguard Worker        else:
2909*da0073e9SAndroid Build Coastguard Worker            # Problem: We pad sizes in the composite region of the top level SDPA, but we need the
2910*da0073e9SAndroid Build Coastguard Worker            # debug mask when we have dropout. So we manually pad up here when testing dropout.
2911*da0073e9SAndroid Build Coastguard Worker            q_padded, q_og_size = pad_last_dim(query, 8)
2912*da0073e9SAndroid Build Coastguard Worker            k_padded, k_og_size = pad_last_dim(key, 8)
2913*da0073e9SAndroid Build Coastguard Worker            v_padded, v_og_size = pad_last_dim(value, 8)
2914*da0073e9SAndroid Build Coastguard Worker            # scale needs to be calculated on the original head_size
2915*da0073e9SAndroid Build Coastguard Worker            if scale is None:
2916*da0073e9SAndroid Build Coastguard Worker                scale = 1 / math.sqrt(q_og_size)
2917*da0073e9SAndroid Build Coastguard Worker            output_tuple = torch.ops.aten._scaled_dot_product_flash_attention(
2918*da0073e9SAndroid Build Coastguard Worker                q_padded, k_padded, v_padded, dropout_p=dropout_p, is_causal=is_causal, scale=scale, return_debug_mask=is_dropout)
2919*da0073e9SAndroid Build Coastguard Worker            out = output_tuple[0]
2920*da0073e9SAndroid Build Coastguard Worker            out = out[..., :v_og_size]
2921*da0073e9SAndroid Build Coastguard Worker            # Build dropout_mask
2922*da0073e9SAndroid Build Coastguard Worker            debug_mask = output_tuple[-1]
2923*da0073e9SAndroid Build Coastguard Worker            query_padding_mask = torch.ones(
2924*da0073e9SAndroid Build Coastguard Worker                batch_size, seq_len_q, device=device, dtype=torch.bool)
2925*da0073e9SAndroid Build Coastguard Worker            key_padding_mask = torch.ones(
2926*da0073e9SAndroid Build Coastguard Worker                batch_size, seq_len_k, device=device, dtype=torch.bool)
2927*da0073e9SAndroid Build Coastguard Worker
2928*da0073e9SAndroid Build Coastguard Worker            softmax_mask = self.convert_flash_attn_S_to_softmax(
2929*da0073e9SAndroid Build Coastguard Worker                debug_mask, seq_len_q, seq_len_k, query_padding_mask, key_padding_mask,
2930*da0073e9SAndroid Build Coastguard Worker                causal=is_causal)[:, :, :seq_len_q, :seq_len_k]
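            # The debug mask encodes dropped positions as negative values, so a >= 0
            # comparison recovers the boolean keep-mask consumed by the math reference.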
2931*da0073e9SAndroid Build Coastguard Worker            dropout_mask = softmax_mask >= 0
2932*da0073e9SAndroid Build Coastguard Worker            # attn_unnorm = softmax_mask.abs()
2933*da0073e9SAndroid Build Coastguard Worker            # attn = self.normalize_flash_attn_S(attn_unnorm, q_padded,
2934*da0073e9SAndroid Build Coastguard Worker            #                                    k_padded, v_padded, query_padding_mask,
2935*da0073e9SAndroid Build Coastguard Worker            #                                    key_padding_mask, None, True, is_causal, scale=scale)
2936*da0073e9SAndroid Build Coastguard Worker
2937*da0073e9SAndroid Build Coastguard Worker            # High Precision Math Reference
2938*da0073e9SAndroid Build Coastguard Worker            out_ref = torch.ops.aten._scaled_dot_product_attention_math(
2939*da0073e9SAndroid Build Coastguard Worker                query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal, scale=scale, dropout_mask=dropout_mask)[0]
2940*da0073e9SAndroid Build Coastguard Worker            # Low Precision Math Reference
2941*da0073e9SAndroid Build Coastguard Worker            out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
2942*da0073e9SAndroid Build Coastguard Worker                query_ref_lp, key_ref_lp, value_ref_lp, dropout_p=dropout_p, is_causal=is_causal, scale=scale,
2943*da0073e9SAndroid Build Coastguard Worker                dropout_mask=dropout_mask)[0]
2944*da0073e9SAndroid Build Coastguard Worker
2945*da0073e9SAndroid Build Coastguard Worker        upstream_grad = torch.rand_like(out, requires_grad=False)
2946*da0073e9SAndroid Build Coastguard Worker
2947*da0073e9SAndroid Build Coastguard Worker        # backward for flash attention on sm86, sm87, and sm89 for headdim >= 193 currently disabled
2948*da0073e9SAndroid Build Coastguard Worker        if isSM8XDevice and head_dim in range(193, 256):
2949*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(RuntimeError, lambda: out.backward(upstream_grad))
2950*da0073e9SAndroid Build Coastguard Worker            return
2951*da0073e9SAndroid Build Coastguard Worker        out.backward(upstream_grad)
2952*da0073e9SAndroid Build Coastguard Worker        out_ref.backward(upstream_grad.to(out_ref.dtype))
2953*da0073e9SAndroid Build Coastguard Worker        out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype))
2954*da0073e9SAndroid Build Coastguard Worker
2955*da0073e9SAndroid Build Coastguard Worker        # See [Note] Fused Tolerances above
2956*da0073e9SAndroid Build Coastguard Worker        output_fudge_factor = 3 if head_dim % 8 != 0 or TEST_WITH_ROCM else 1
2957*da0073e9SAndroid Build Coastguard Worker        output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref, output_fudge_factor)
2958*da0073e9SAndroid Build Coastguard Worker
2959*da0073e9SAndroid Build Coastguard Worker        # TODO: Investigate why grad_q needs larger tolerances
2960*da0073e9SAndroid Build Coastguard Worker        query_fudge_factor = 4
2961*da0073e9SAndroid Build Coastguard Worker        grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor)
2962*da0073e9SAndroid Build Coastguard Worker
2963*da0073e9SAndroid Build Coastguard Worker        key_fudge_factor = 2
2964*da0073e9SAndroid Build Coastguard Worker        grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad, key_fudge_factor)
2965*da0073e9SAndroid Build Coastguard Worker
2966*da0073e9SAndroid Build Coastguard Worker        value_fudge_factor = 2
2967*da0073e9SAndroid Build Coastguard Worker        grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad, value_fudge_factor)
2968*da0073e9SAndroid Build Coastguard Worker
2969*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol)
2970*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype),
2971*da0073e9SAndroid Build Coastguard Worker                         atol=grad_q_ref_atol, rtol=grad_q_ref_rtol)
2972*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(key.grad, key_ref.grad.to(key.grad.dtype),
2973*da0073e9SAndroid Build Coastguard Worker                         atol=grad_k_ref_atol, rtol=grad_k_ref_rtol)
2974*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
2975*da0073e9SAndroid Build Coastguard Worker                         atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)
2976*da0073e9SAndroid Build Coastguard Worker
2977*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm  # FIXME: "capturing stream has unjoined work"
2978*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
2979*da0073e9SAndroid Build Coastguard Worker    @parametrize("batch_size", [1, 8])
2980*da0073e9SAndroid Build Coastguard Worker    @parametrize("seq_len_q", [256, 512, 1024])
2981*da0073e9SAndroid Build Coastguard Worker    @parametrize("seq_len_k", [256, 512, 1024])
2982*da0073e9SAndroid Build Coastguard Worker    @parametrize("head_dim", [32, 64])
2983*da0073e9SAndroid Build Coastguard Worker    @parametrize("is_causal", [True, False])
2984*da0073e9SAndroid Build Coastguard Worker    @parametrize("dropout_p", [0.0, 0.22])
2985*da0073e9SAndroid Build Coastguard Worker    @parametrize("dtype", [torch.float16,])
2986*da0073e9SAndroid Build Coastguard Worker    @parametrize("scale", [None, "l1"])
2987*da0073e9SAndroid Build Coastguard Worker    @parametrize("fused_kernel", PLATFORM_SPECIFIC_SDPA)
2988*da0073e9SAndroid Build Coastguard Worker    def test_fused_attention_vs_math_ref_grads_cudagraph(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
2989*da0073e9SAndroid Build Coastguard Worker                                                         head_dim: int,
2990*da0073e9SAndroid Build Coastguard Worker                                                         is_causal: bool,
2991*da0073e9SAndroid Build Coastguard Worker                                                         dropout_p: float,
2992*da0073e9SAndroid Build Coastguard Worker                                                         dtype: torch.dtype,
2993*da0073e9SAndroid Build Coastguard Worker                                                         scale: str,
2994*da0073e9SAndroid Build Coastguard Worker                                                         fused_kernel: SDPBackend):
2995*da0073e9SAndroid Build Coastguard Worker        def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, dropout_p, seed, offset, device=device):
2996*da0073e9SAndroid Build Coastguard Worker            mask = torch.empty((batch_size, n_heads, q_len, kv_len), device=device, dtype=torch.float32)
2997*da0073e9SAndroid Build Coastguard Worker            rand_uniform = torch._fill_mem_eff_dropout_mask_(mask, dropout_p, seed, offset)
2998*da0073e9SAndroid Build Coastguard Worker            mask = (rand_uniform > dropout_p).to(torch.float32)
2999*da0073e9SAndroid Build Coastguard Worker            return mask
3000*da0073e9SAndroid Build Coastguard Worker
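        # For the mem-efficient kernel the RNG seed/offset returned in the output tuple
        # let us rebuild the dropout mask; for flash attention we decode the returned
        # debug mask instead.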
3001*da0073e9SAndroid Build Coastguard Worker        def get_dropout_mask(output, fused_kernel, batch_size, n_heads, q_len, kv_len, dropout_p, device=device):
3002*da0073e9SAndroid Build Coastguard Worker            if fused_kernel == SDPBackend.EFFICIENT_ATTENTION:
3003*da0073e9SAndroid Build Coastguard Worker                output_seed, output_offset = output_tuple[2], output_tuple[3]
3004*da0073e9SAndroid Build Coastguard Worker                output_seed = output_seed.item()
3005*da0073e9SAndroid Build Coastguard Worker                output_offset = output_offset.item()
3006*da0073e9SAndroid Build Coastguard Worker                return _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len,
3007*da0073e9SAndroid Build Coastguard Worker                                              dropout_p, output_seed, output_offset, device=device)
3008*da0073e9SAndroid Build Coastguard Worker            else:
3009*da0073e9SAndroid Build Coastguard Worker                # Build dropout_mask
3010*da0073e9SAndroid Build Coastguard Worker                debug_mask = output_tuple[-1]
3011*da0073e9SAndroid Build Coastguard Worker                query_padding_mask = torch.ones(
3012*da0073e9SAndroid Build Coastguard Worker                    batch_size, seq_len_q, device=device, dtype=torch.bool)
3013*da0073e9SAndroid Build Coastguard Worker                key_padding_mask = torch.ones(
3014*da0073e9SAndroid Build Coastguard Worker                    batch_size, seq_len_k, device=device, dtype=torch.bool)
3015*da0073e9SAndroid Build Coastguard Worker
3016*da0073e9SAndroid Build Coastguard Worker                softmax_mask = self.convert_flash_attn_S_to_softmax(
3017*da0073e9SAndroid Build Coastguard Worker                    debug_mask, seq_len_q, seq_len_k, query_padding_mask, key_padding_mask,
3018*da0073e9SAndroid Build Coastguard Worker                    causal=is_causal)[:, :, :seq_len_q, :seq_len_k]
3019*da0073e9SAndroid Build Coastguard Worker                dropout_mask = softmax_mask >= 0
3020*da0073e9SAndroid Build Coastguard Worker                return dropout_mask
3021*da0073e9SAndroid Build Coastguard Worker
3022*da0073e9SAndroid Build Coastguard Worker        if fused_kernel == SDPBackend.FLASH_ATTENTION and is_causal and seq_len_q != seq_len_k:
3023*da0073e9SAndroid Build Coastguard Worker            self.skipTest("Flash V2 does not accept is_causal when seq_len_q != seq_len_k")
3024*da0073e9SAndroid Build Coastguard Worker
3025*da0073e9SAndroid Build Coastguard Worker        seed = 42
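        # `scale` is parametrized as None (use the kernel's default scaling) or "l1" (use 1 / head_dim)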
3026*da0073e9SAndroid Build Coastguard Worker        scale = scale if scale is None else (1 / head_dim)
3027*da0073e9SAndroid Build Coastguard Worker        n_heads = 4
3028*da0073e9SAndroid Build Coastguard Worker        query = torch.rand(batch_size, n_heads, seq_len_q, head_dim,
3029*da0073e9SAndroid Build Coastguard Worker                           device=device, dtype=dtype, requires_grad=True)
3030*da0073e9SAndroid Build Coastguard Worker        key = torch.rand(batch_size, n_heads, seq_len_k, head_dim, device=device,
3031*da0073e9SAndroid Build Coastguard Worker                         dtype=dtype, requires_grad=True)
3032*da0073e9SAndroid Build Coastguard Worker        value = torch.rand(batch_size, n_heads, seq_len_k, head_dim,
3033*da0073e9SAndroid Build Coastguard Worker                           device=device, dtype=dtype, requires_grad=True)
3034*da0073e9SAndroid Build Coastguard Worker
3035*da0073e9SAndroid Build Coastguard Worker        fused_op = (torch.ops.aten._scaled_dot_product_efficient_attention
3036*da0073e9SAndroid Build Coastguard Worker                    if fused_kernel == SDPBackend.EFFICIENT_ATTENTION else torch.ops.aten._scaled_dot_product_flash_attention)
3037*da0073e9SAndroid Build Coastguard Worker        # Run the math kernel on low precision references
3038*da0073e9SAndroid Build Coastguard Worker        query_ref_lp, key_ref_lp, value_ref_lp = query_key_value_clones(query, key, value, dtype=dtype)
3039*da0073e9SAndroid Build Coastguard Worker
3040*da0073e9SAndroid Build Coastguard Worker        higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
3041*da0073e9SAndroid Build Coastguard Worker        query_ref, key_ref, value_ref = query_key_value_clones(query, key, value, dtype=higher_precision_dtype)
3042*da0073e9SAndroid Build Coastguard Worker
3043*da0073e9SAndroid Build Coastguard Worker        # warmup
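        # (CUDA graph capture needs a prior warmup run on a side stream so that any
        # lazy initialization happens outside of the capture.)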
3044*da0073e9SAndroid Build Coastguard Worker        s = torch.cuda.Stream()
3045*da0073e9SAndroid Build Coastguard Worker        s.wait_stream(torch.cuda.current_stream())
3046*da0073e9SAndroid Build Coastguard Worker        # Set the global seed before capture
3047*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(seed)
3048*da0073e9SAndroid Build Coastguard Worker        kwargs = {"dropout_p": dropout_p, "is_causal": is_causal, "scale": scale}
3049*da0073e9SAndroid Build Coastguard Worker        if fused_kernel == SDPBackend.EFFICIENT_ATTENTION:
3050*da0073e9SAndroid Build Coastguard Worker            kwargs["compute_log_sumexp"] = True
3051*da0073e9SAndroid Build Coastguard Worker            kwargs["attn_bias"] = None
3052*da0073e9SAndroid Build Coastguard Worker        if fused_kernel == SDPBackend.FLASH_ATTENTION:
3053*da0073e9SAndroid Build Coastguard Worker            kwargs['return_debug_mask'] = dropout_p > 0.0
3054*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.stream(s):
3055*da0073e9SAndroid Build Coastguard Worker            # Create real output
3056*da0073e9SAndroid Build Coastguard Worker            output_tuple = fused_op(query, key, value, **kwargs)
3057*da0073e9SAndroid Build Coastguard Worker
3058*da0073e9SAndroid Build Coastguard Worker        torch.cuda.current_stream().wait_stream(s)
3059*da0073e9SAndroid Build Coastguard Worker        out = output_tuple[0]
3060*da0073e9SAndroid Build Coastguard Worker        upstream_grad = torch.rand_like(out, requires_grad=False)
3061*da0073e9SAndroid Build Coastguard Worker        s.wait_stream(torch.cuda.current_stream())
3062*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.stream(s):
3063*da0073e9SAndroid Build Coastguard Worker            out.backward(upstream_grad)
3064*da0073e9SAndroid Build Coastguard Worker        for x in (query, key, value):
3065*da0073e9SAndroid Build Coastguard Worker            x.grad = None
3066*da0073e9SAndroid Build Coastguard Worker        g = torch.cuda.CUDAGraph()
3067*da0073e9SAndroid Build Coastguard Worker        # Capture the fused op in a CUDA graph
3068*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.graph(g):
3069*da0073e9SAndroid Build Coastguard Worker            tmp = torch.rand_like(query, device=query.device)  # test non-zero intragraph offset
3070*da0073e9SAndroid Build Coastguard Worker            # Create real output
3071*da0073e9SAndroid Build Coastguard Worker            output_tuple = fused_op(query, key, value, **kwargs)
3072*da0073e9SAndroid Build Coastguard Worker            assert all(not isinstance(o, torch.Tensor) or o.is_cuda for o in output_tuple)
3073*da0073e9SAndroid Build Coastguard Worker        g.replay()
3074*da0073e9SAndroid Build Coastguard Worker        out_first = output_tuple[0].clone()
3075*da0073e9SAndroid Build Coastguard Worker        g.replay()
3076*da0073e9SAndroid Build Coastguard Worker        out = output_tuple[0]
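        # Without dropout the captured graph is deterministic, so two replays must
        # match exactly; with dropout each replay draws fresh random numbers.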
3077*da0073e9SAndroid Build Coastguard Worker        if dropout_p == 0.0:
3078*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out_first, out, atol=0, rtol=0)
3079*da0073e9SAndroid Build Coastguard Worker        else:
3080*da0073e9SAndroid Build Coastguard Worker            # replays produce different results
3081*da0073e9SAndroid Build Coastguard Worker            self.assertNotEqual(out_first, out)
3082*da0073e9SAndroid Build Coastguard Worker
3083*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[SDPBackend.MATH]):
3084*da0073e9SAndroid Build Coastguard Worker            if dropout_p == 0.0:
3085*da0073e9SAndroid Build Coastguard Worker                # High Precision Math Reference
3086*da0073e9SAndroid Build Coastguard Worker                out_ref = F.scaled_dot_product_attention(query_ref, key_ref, value_ref,
3087*da0073e9SAndroid Build Coastguard Worker                                                         dropout_p=dropout_p, is_causal=is_causal, scale=scale)
3088*da0073e9SAndroid Build Coastguard Worker                # Low Precision Math Reference
3089*da0073e9SAndroid Build Coastguard Worker                out_lp_ref = F.scaled_dot_product_attention(query_ref_lp, key_ref_lp, value_ref_lp,
3090*da0073e9SAndroid Build Coastguard Worker                                                            dropout_p=dropout_p, is_causal=is_causal, scale=scale)
3091*da0073e9SAndroid Build Coastguard Worker            else:
3092*da0073e9SAndroid Build Coastguard Worker                # Create the dropout_mask
3093*da0073e9SAndroid Build Coastguard Worker                dropout_mask = get_dropout_mask(output_tuple, fused_kernel, batch_size,
3094*da0073e9SAndroid Build Coastguard Worker                                                n_heads, seq_len_q, seq_len_k, dropout_p, device)
3095*da0073e9SAndroid Build Coastguard Worker                # High Precision Math Reference
3096*da0073e9SAndroid Build Coastguard Worker                out_ref = torch.ops.aten._scaled_dot_product_attention_math(
3097*da0073e9SAndroid Build Coastguard Worker                    query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal,
3098*da0073e9SAndroid Build Coastguard Worker                    scale=scale, dropout_mask=dropout_mask)[0]
3099*da0073e9SAndroid Build Coastguard Worker                # Low Precision Math Reference
3100*da0073e9SAndroid Build Coastguard Worker                out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
3101*da0073e9SAndroid Build Coastguard Worker                    query_ref_lp, key_ref_lp, value_ref_lp, dropout_p=dropout_p, is_causal=is_causal, scale=scale,
3102*da0073e9SAndroid Build Coastguard Worker                    dropout_mask=dropout_mask)[0]
3103*da0073e9SAndroid Build Coastguard Worker
3104*da0073e9SAndroid Build Coastguard Worker
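        # The backward pass is likewise captured in its own CUDA graph and replayed
        # once to populate the input gradients.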
3105*da0073e9SAndroid Build Coastguard Worker        g1 = torch.cuda.CUDAGraph()
3106*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.graph(g1):
3107*da0073e9SAndroid Build Coastguard Worker            out.backward(upstream_grad)
3108*da0073e9SAndroid Build Coastguard Worker        g1.replay()
3109*da0073e9SAndroid Build Coastguard Worker        out_ref.backward(upstream_grad.to(out_ref.dtype))
3110*da0073e9SAndroid Build Coastguard Worker        out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype))
3111*da0073e9SAndroid Build Coastguard Worker
3112*da0073e9SAndroid Build Coastguard Worker        # [Note] Fused Tolerances
3113*da0073e9SAndroid Build Coastguard Worker        # Establish the numerical error between the "true" high-precision math output
3114*da0073e9SAndroid Build Coastguard Worker        # and the low-precision math reference. We use that error for the atol and
3115*da0073e9SAndroid Build Coastguard Worker        # the default rtol of the low-precision dtype.
3116*da0073e9SAndroid Build Coastguard Worker        # We then apply a per-gradient fudge factor to account for using the fused
3117*da0073e9SAndroid Build Coastguard Worker        # kernel rather than the eager implementation.
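        # (get_tolerances is a helper defined earlier in this file; per the note above
        # it is expected to turn the high- vs low-precision difference, scaled by the
        # fudge factor, into an atol paired with the dtype's default rtol.)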
3118*da0073e9SAndroid Build Coastguard Worker        output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref)
3119*da0073e9SAndroid Build Coastguard Worker
3120*da0073e9SAndroid Build Coastguard Worker        # Fudge Factor when dropout is enabled
3121*da0073e9SAndroid Build Coastguard Worker        dropout_fudge_factor = 1.0 if dropout_p == 0.0 else 1.5
3122*da0073e9SAndroid Build Coastguard Worker
3123*da0073e9SAndroid Build Coastguard Worker        query_fudge_factor = dropout_fudge_factor
3124*da0073e9SAndroid Build Coastguard Worker        grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor)
3125*da0073e9SAndroid Build Coastguard Worker
3126*da0073e9SAndroid Build Coastguard Worker        # TODO: Investigate why grad_k needs larger tolerances
3127*da0073e9SAndroid Build Coastguard Worker        key_fudge_factor = 8 * dropout_fudge_factor
3128*da0073e9SAndroid Build Coastguard Worker        grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad, key_fudge_factor)
3129*da0073e9SAndroid Build Coastguard Worker
3130*da0073e9SAndroid Build Coastguard Worker        value_fudge_factor = 7 if (not SM80OrLater and dtype == torch.float16) else 1.0
3131*da0073e9SAndroid Build Coastguard Worker        grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad, value_fudge_factor)
3132*da0073e9SAndroid Build Coastguard Worker
3133*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol)
3134*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype),
3135*da0073e9SAndroid Build Coastguard Worker                         atol=grad_q_ref_atol, rtol=grad_q_ref_rtol)
3136*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(key.grad, key_ref.grad.to(key.grad.dtype),
3137*da0073e9SAndroid Build Coastguard Worker                         atol=grad_k_ref_atol, rtol=grad_k_ref_rtol)
3138*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
3139*da0073e9SAndroid Build Coastguard Worker                         atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)
3140*da0073e9SAndroid Build Coastguard Worker
3141*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm  # Nested Tensor
3142*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
3143*da0073e9SAndroid Build Coastguard Worker    @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
3144*da0073e9SAndroid Build Coastguard Worker                 PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION])
3145*da0073e9SAndroid Build Coastguard Worker    def test_fused_kernels_seq_len_1_inputs(self, device, fused_kernel):
3146*da0073e9SAndroid Build Coastguard Worker        rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float16)
3147*da0073e9SAndroid Build Coastguard Worker        batch, num_heads, head_dim = 32, 16, 64
3148*da0073e9SAndroid Build Coastguard Worker        seq_lens = torch.randint(low=1, high=32, size=(batch,))
3149*da0073e9SAndroid Build Coastguard Worker        # make sure some seq_lens are 1
3150*da0073e9SAndroid Build Coastguard Worker        num_ones = 10
3151*da0073e9SAndroid Build Coastguard Worker        indices = torch.randint(low=0, high=batch, size=(num_ones,))
3152*da0073e9SAndroid Build Coastguard Worker        seq_lens.scatter_(0, indices, 1)
3153*da0073e9SAndroid Build Coastguard Worker
3154*da0073e9SAndroid Build Coastguard Worker        shape = SdpaShape(batch, num_heads, seq_lens.tolist(), head_dim)
3155*da0073e9SAndroid Build Coastguard Worker        query = rand_nested_tensor(shape)
3156*da0073e9SAndroid Build Coastguard Worker        key = rand_nested_tensor(shape)
3157*da0073e9SAndroid Build Coastguard Worker        value = rand_nested_tensor(shape)
3158*da0073e9SAndroid Build Coastguard Worker
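        # Transpose to the (batch, num_heads, seq_len, head_dim) layout expected by
        # scaled_dot_product_attention.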
3159*da0073e9SAndroid Build Coastguard Worker        query = query.transpose(1, 2)
3160*da0073e9SAndroid Build Coastguard Worker        key = key.transpose(1, 2)
3161*da0073e9SAndroid Build Coastguard Worker        value = value.transpose(1, 2)
3162*da0073e9SAndroid Build Coastguard Worker
3163*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[fused_kernel]):
3164*da0073e9SAndroid Build Coastguard Worker            actual = torch.nn.functional.scaled_dot_product_attention(
3165*da0073e9SAndroid Build Coastguard Worker                query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
3166*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[SDPBackend.MATH]):
3167*da0073e9SAndroid Build Coastguard Worker            math_ref = torch.nn.functional.scaled_dot_product_attention(
3168*da0073e9SAndroid Build Coastguard Worker                query.contiguous().to(torch.float32),
3169*da0073e9SAndroid Build Coastguard Worker                key.contiguous().to(torch.float32),
3170*da0073e9SAndroid Build Coastguard Worker                value.contiguous().to(torch.float32),
3171*da0073e9SAndroid Build Coastguard Worker                attn_mask=None, dropout_p=0.0, is_causal=False)
3172*da0073e9SAndroid Build Coastguard Worker
3173*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(actual.contiguous(), math_ref.contiguous().to(torch.float16), atol=1e-3, rtol=1e-2)
3174*da0073e9SAndroid Build Coastguard Worker
3175*da0073e9SAndroid Build Coastguard Worker
3176*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm  # Nested tensor
3177*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
3178*da0073e9SAndroid Build Coastguard Worker    @parametrize("kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
3179*da0073e9SAndroid Build Coastguard Worker                 PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION])
3180*da0073e9SAndroid Build Coastguard Worker    @parametrize("expand_q_batch", [True, False])
3181*da0073e9SAndroid Build Coastguard Worker    @parametrize("expand_k_batch", [True, False])
3182*da0073e9SAndroid Build Coastguard Worker    @parametrize("expand_v_batch", [True, False])
3183*da0073e9SAndroid Build Coastguard Worker    @parametrize("expand_q_num_heads", [True, False])
3184*da0073e9SAndroid Build Coastguard Worker    @parametrize("expand_k_num_heads", [True, False])
3185*da0073e9SAndroid Build Coastguard Worker    @parametrize("expand_v_num_heads", [True, False])
3186*da0073e9SAndroid Build Coastguard Worker    def test_fused_kernels_nested_broadcasting(
3187*da0073e9SAndroid Build Coastguard Worker        self,
3188*da0073e9SAndroid Build Coastguard Worker        device,
3189*da0073e9SAndroid Build Coastguard Worker        kernel,
3190*da0073e9SAndroid Build Coastguard Worker        expand_q_batch,
3191*da0073e9SAndroid Build Coastguard Worker        expand_k_batch,
3192*da0073e9SAndroid Build Coastguard Worker        expand_v_batch,
3193*da0073e9SAndroid Build Coastguard Worker        expand_q_num_heads,
3194*da0073e9SAndroid Build Coastguard Worker        expand_k_num_heads,
3195*da0073e9SAndroid Build Coastguard Worker        expand_v_num_heads,
3196*da0073e9SAndroid Build Coastguard Worker    ):
3197*da0073e9SAndroid Build Coastguard Worker        is_efficient = kernel == SDPBackend.EFFICIENT_ATTENTION
3198*da0073e9SAndroid Build Coastguard Worker        dtype = torch.float32 if is_efficient else torch.float16
3199*da0073e9SAndroid Build Coastguard Worker        rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=dtype)
3200*da0073e9SAndroid Build Coastguard Worker        batch, num_heads, head_dim = 32, 8, 64
3201*da0073e9SAndroid Build Coastguard Worker        head_dim_v = 32 if is_efficient else head_dim
3202*da0073e9SAndroid Build Coastguard Worker        seq_lens_q = (torch.randint(low=1, high=5, size=(1,)).item()
3203*da0073e9SAndroid Build Coastguard Worker                      if expand_q_batch
3204*da0073e9SAndroid Build Coastguard Worker                      else torch.randint(low=1, high=32, size=(batch,)).tolist())
3205*da0073e9SAndroid Build Coastguard Worker        seq_lens_kv = (torch.randint(low=1, high=5, size=(1,)).item()
3206*da0073e9SAndroid Build Coastguard Worker                       if (expand_k_batch or expand_v_batch)
3207*da0073e9SAndroid Build Coastguard Worker                       else torch.randint(low=1, high=32, size=(batch,)).tolist())
3208*da0073e9SAndroid Build Coastguard Worker
3209*da0073e9SAndroid Build Coastguard Worker        batch_q = 1 if expand_q_batch else batch
3210*da0073e9SAndroid Build Coastguard Worker        batch_k = 1 if expand_k_batch else batch
3211*da0073e9SAndroid Build Coastguard Worker        batch_v = 1 if expand_v_batch else batch
3212*da0073e9SAndroid Build Coastguard Worker
3213*da0073e9SAndroid Build Coastguard Worker        # handle case where all batch_sizes are 1
3214*da0073e9SAndroid Build Coastguard Worker        batch = max(batch_q, batch_k, batch_v)
3215*da0073e9SAndroid Build Coastguard Worker
3216*da0073e9SAndroid Build Coastguard Worker        num_heads_q = 1 if expand_q_num_heads else num_heads
3217*da0073e9SAndroid Build Coastguard Worker        num_heads_k = 1 if expand_k_num_heads else num_heads
3218*da0073e9SAndroid Build Coastguard Worker        num_heads_v = 1 if expand_v_num_heads else num_heads
3219*da0073e9SAndroid Build Coastguard Worker
3220*da0073e9SAndroid Build Coastguard Worker        # handle case where all num_heads are 1
3221*da0073e9SAndroid Build Coastguard Worker        num_heads = max(num_heads_q, num_heads_k, num_heads_v)
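        # Inputs whose batch / num_heads dimension is 1 are broadcast against the
        # full sizes; `batch` and `num_heads` now hold the broadcast (expanded)
        # sizes used to build the reference inputs below.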
3222*da0073e9SAndroid Build Coastguard Worker
3223*da0073e9SAndroid Build Coastguard Worker        q_shape = SdpaShape(batch_q, num_heads_q, seq_lens_q, head_dim)
3224*da0073e9SAndroid Build Coastguard Worker        k_shape = SdpaShape(batch_k, num_heads_k, seq_lens_kv, head_dim)
3225*da0073e9SAndroid Build Coastguard Worker        v_shape = SdpaShape(batch_v, num_heads_v, seq_lens_kv, head_dim_v)
3226*da0073e9SAndroid Build Coastguard Worker
3227*da0073e9SAndroid Build Coastguard Worker        query = rand_nested_tensor(q_shape)
3228*da0073e9SAndroid Build Coastguard Worker        key = rand_nested_tensor(k_shape)
3229*da0073e9SAndroid Build Coastguard Worker        value = rand_nested_tensor(v_shape)
3230*da0073e9SAndroid Build Coastguard Worker
3231*da0073e9SAndroid Build Coastguard Worker        def _broadcast(t, batch_broadcasted, num_heads_broadcasted):
3232*da0073e9SAndroid Build Coastguard Worker            if batch_broadcasted and num_heads_broadcasted:
3233*da0073e9SAndroid Build Coastguard Worker                # (1, seq_len, 1, head_dim) -> (batch, seq_len, num_heads, head_dim)
3234*da0073e9SAndroid Build Coastguard Worker                result = torch.nested.nested_tensor(
3235*da0073e9SAndroid Build Coastguard Worker                    [t[0].expand(-1, num_heads, t.size(-1)) for _ in range(batch)], dtype=torch.float32)
3236*da0073e9SAndroid Build Coastguard Worker            elif batch_broadcasted:
3237*da0073e9SAndroid Build Coastguard Worker                # (1, seq_len, num_heads, head_dim) -> (batch, seq_len, num_heads, head_dim)
3238*da0073e9SAndroid Build Coastguard Worker                result = torch.nested.nested_tensor([t[0] for _ in range(batch)], dtype=torch.float32)
3239*da0073e9SAndroid Build Coastguard Worker            elif num_heads_broadcasted:
3240*da0073e9SAndroid Build Coastguard Worker                # (batch, seq_len, 1, head_dim) -> (batch, seq_len, num_heads, head_dim)
3241*da0073e9SAndroid Build Coastguard Worker                result = torch.nested.nested_tensor([x.expand(-1, num_heads, t.size(-1))
3242*da0073e9SAndroid Build Coastguard Worker                                                    for x in t.unbind()], dtype=torch.float32)
3243*da0073e9SAndroid Build Coastguard Worker            else:
3244*da0073e9SAndroid Build Coastguard Worker                result = t.to(torch.float32)
3245*da0073e9SAndroid Build Coastguard Worker            return result
3246*da0073e9SAndroid Build Coastguard Worker
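        # The fused kernel receives the original (possibly size-1) nested inputs,
        # while the math reference receives explicitly expanded float32 copies.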
3247*da0073e9SAndroid Build Coastguard Worker        query_expanded = _broadcast(query, expand_q_batch, expand_q_num_heads).transpose(1, 2)
3248*da0073e9SAndroid Build Coastguard Worker        key_expanded = _broadcast(key, expand_k_batch, expand_k_num_heads).transpose(1, 2)
3249*da0073e9SAndroid Build Coastguard Worker        value_expanded = _broadcast(value, expand_v_batch, expand_v_num_heads).transpose(1, 2)
3250*da0073e9SAndroid Build Coastguard Worker
3251*da0073e9SAndroid Build Coastguard Worker        query = query.transpose(1, 2)
3252*da0073e9SAndroid Build Coastguard Worker        key = key.transpose(1, 2)
3253*da0073e9SAndroid Build Coastguard Worker        value = value.transpose(1, 2)
3254*da0073e9SAndroid Build Coastguard Worker
3255*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[kernel]):
3256*da0073e9SAndroid Build Coastguard Worker            actual = torch.nn.functional.scaled_dot_product_attention(
3257*da0073e9SAndroid Build Coastguard Worker                query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
3258*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[SDPBackend.MATH]):
3259*da0073e9SAndroid Build Coastguard Worker            math_ref = torch.nn.functional.scaled_dot_product_attention(
3260*da0073e9SAndroid Build Coastguard Worker                query_expanded.contiguous(), key_expanded.contiguous(), value_expanded.contiguous(),
3261*da0073e9SAndroid Build Coastguard Worker                attn_mask=None, dropout_p=0.0, is_causal=False)
3262*da0073e9SAndroid Build Coastguard Worker
3263*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)
3264*da0073e9SAndroid Build Coastguard Worker
3265*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm  # Nested tensor
3266*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
3267*da0073e9SAndroid Build Coastguard Worker    def test_fused_kernels_nested_broadcasting_query_dense(self, device):
3268*da0073e9SAndroid Build Coastguard Worker        rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float32)
3269*da0073e9SAndroid Build Coastguard Worker        batch, num_heads, head_dim, head_dim_v = 32, 16, 64, 96
3270*da0073e9SAndroid Build Coastguard Worker        seq_lens = torch.randint(low=1, high=32, size=(batch,)).tolist()
3271*da0073e9SAndroid Build Coastguard Worker        q_shape = (1, 1, num_heads, head_dim)
3272*da0073e9SAndroid Build Coastguard Worker        k_shape = SdpaShape(batch, num_heads, seq_lens, head_dim)
3273*da0073e9SAndroid Build Coastguard Worker        v_shape = SdpaShape(batch, 1, seq_lens, head_dim_v)
3274*da0073e9SAndroid Build Coastguard Worker
3275*da0073e9SAndroid Build Coastguard Worker        # create a dense query
3276*da0073e9SAndroid Build Coastguard Worker        query = torch.randn(q_shape, device=device, dtype=torch.float32)
3277*da0073e9SAndroid Build Coastguard Worker        key = rand_nested_tensor(k_shape)
3278*da0073e9SAndroid Build Coastguard Worker        value = rand_nested_tensor(v_shape)
3279*da0073e9SAndroid Build Coastguard Worker
3280*da0073e9SAndroid Build Coastguard Worker        # (1, 1, num_heads, head_dim) -> (batch, 1, num_heads, head_dim)
3281*da0073e9SAndroid Build Coastguard Worker        query_expanded = torch.nested.nested_tensor([query.squeeze(0) for _ in range(batch)]).transpose(1, 2)
3282*da0073e9SAndroid Build Coastguard Worker        # (batch, seq_lens, 1, head_dim) -> (batch, seq_lens, num_heads, head_dim)
3283*da0073e9SAndroid Build Coastguard Worker        value_expanded = torch.nested.nested_tensor(
3284*da0073e9SAndroid Build Coastguard Worker            [t.expand(-1, num_heads, head_dim_v) for t in value.unbind()]).transpose(1, 2)
3285*da0073e9SAndroid Build Coastguard Worker
3286*da0073e9SAndroid Build Coastguard Worker        query = query.transpose(1, 2)
3287*da0073e9SAndroid Build Coastguard Worker        key = key.transpose(1, 2)
3288*da0073e9SAndroid Build Coastguard Worker        value = value.transpose(1, 2)
3289*da0073e9SAndroid Build Coastguard Worker
3290*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
3291*da0073e9SAndroid Build Coastguard Worker            actual = torch.nn.functional.scaled_dot_product_attention(
3292*da0073e9SAndroid Build Coastguard Worker                query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
3293*da0073e9SAndroid Build Coastguard Worker        with sdpa_kernel(backends=[SDPBackend.MATH]):
3294*da0073e9SAndroid Build Coastguard Worker            math_ref = torch.nn.functional.scaled_dot_product_attention(
3295*da0073e9SAndroid Build Coastguard Worker                query_expanded.contiguous(), key.contiguous(), value_expanded.contiguous(),
3296*da0073e9SAndroid Build Coastguard Worker                attn_mask=None, dropout_p=0.0, is_causal=False)
3297*da0073e9SAndroid Build Coastguard Worker
3298*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=1e-3, rtol=1e-2)
3299*da0073e9SAndroid Build Coastguard Worker
3300*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
3301*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm  # Nested tensor
3302*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Platform does not support FlashAttention (SDPA not built or pre-SM80 hardware)")
3303*da0073e9SAndroid Build Coastguard Worker    @parametrize("batch_size", [8, 32])
3304*da0073e9SAndroid Build Coastguard Worker    @parametrize("max_seq_len_q", [32, 256])
3305*da0073e9SAndroid Build Coastguard Worker    @parametrize("max_seq_len_kv", [32, 256])
3306*da0073e9SAndroid Build Coastguard Worker    @parametrize("head_dim", [8, 64])
3307*da0073e9SAndroid Build Coastguard Worker    @parametrize("dropout_p", [0.0, 0.1])
3308*da0073e9SAndroid Build Coastguard Worker    @parametrize("dtype", [torch.float16])
3309*da0073e9SAndroid Build Coastguard Worker    @parametrize("scale", [None, "l1"])
3310*da0073e9SAndroid Build Coastguard Worker    @parametrize("is_causal", [True, False])
3311*da0073e9SAndroid Build Coastguard Worker    def test_flash_attention_vs_math_ref_grads_nestedtensor(self, device, batch_size: int, max_seq_len_q: int, max_seq_len_kv: int,
3312*da0073e9SAndroid Build Coastguard Worker                                                            head_dim: int, dropout_p: float, dtype: torch.dtype,
3313*da0073e9SAndroid Build Coastguard Worker                                                            scale: str, is_causal: bool):
3314*da0073e9SAndroid Build Coastguard Worker        if is_causal:
3315*da0073e9SAndroid Build Coastguard Worker            # TODO we should support this. Nested tensors for query / key currently raise
3316*da0073e9SAndroid Build Coastguard Worker            # "Nested tensors for query / key are not supported when is_causal=True".
3317*da0073e9SAndroid Build Coastguard Worker            return
3318*da0073e9SAndroid Build Coastguard Worker        scale = scale if scale is None else (1 / head_dim)
3319*da0073e9SAndroid Build Coastguard Worker        n_heads = 4
3320*da0073e9SAndroid Build Coastguard Worker        seq_lens_q = torch.randint(low=1, high=max_seq_len_q, size=(batch_size,))
3321*da0073e9SAndroid Build Coastguard Worker        # Set one entry to max length
3322*da0073e9SAndroid Build Coastguard Worker        seq_lens_q[torch.randint(0, batch_size, size=(1,))] = max_seq_len_q
3323*da0073e9SAndroid Build Coastguard Worker        seq_lens_kv = torch.randint(low=1, high=max_seq_len_kv, size=(batch_size,))
3324*da0073e9SAndroid Build Coastguard Worker        seq_lens_kv[torch.randint(0, batch_size, size=(1,))] = max_seq_len_kv
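        # Pinning one entry per batch to the max keeps the nested tensors' maximum
        # sequence lengths equal to the parametrized max_seq_len_q / max_seq_len_kv.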
3325*da0073e9SAndroid Build Coastguard Worker
3326*da0073e9SAndroid Build Coastguard Worker        def rand_nt(sequence_list, num_heads, head_dim):
3327*da0073e9SAndroid Build Coastguard Worker            tensors = [torch.rand((num_heads, seq_len, head_dim)) for seq_len in sequence_list]
3328*da0073e9SAndroid Build Coastguard Worker            return torch.nested.nested_tensor(tensors, requires_grad=True, device=device, dtype=dtype)
3329*da0073e9SAndroid Build Coastguard Worker
3330*da0073e9SAndroid Build Coastguard Worker        query = rand_nt(seq_lens_q, n_heads, head_dim)
3331*da0073e9SAndroid Build Coastguard Worker        key = rand_nt(seq_lens_kv, n_heads, head_dim)
3332*da0073e9SAndroid Build Coastguard Worker        value = rand_nt(seq_lens_kv, n_heads, head_dim)
3333*da0073e9SAndroid Build Coastguard Worker
3334*da0073e9SAndroid Build Coastguard Worker        # Run the math kernel on low precision references
3335*da0073e9SAndroid Build Coastguard Worker        query_ref_lp = query.clone().detach().requires_grad_(True)
3336*da0073e9SAndroid Build Coastguard Worker        key_ref_lp = key.clone().detach().requires_grad_(True)
3337*da0073e9SAndroid Build Coastguard Worker        value_ref_lp = value.clone().detach().requires_grad_(True)
3338*da0073e9SAndroid Build Coastguard Worker
3339*da0073e9SAndroid Build Coastguard Worker        query_ref = query.clone().detach().to(torch.float32).requires_grad_(True)
3340*da0073e9SAndroid Build Coastguard Worker        key_ref = key.clone().detach().to(torch.float32).requires_grad_(True)
3341*da0073e9SAndroid Build Coastguard Worker        value_ref = value.clone().detach().to(torch.float32).requires_grad_(True)
3342*da0073e9SAndroid Build Coastguard Worker
3343*da0073e9SAndroid Build Coastguard Worker        is_dropout = dropout_p > 0.0
3344*da0073e9SAndroid Build Coastguard Worker
3345*da0073e9SAndroid Build Coastguard Worker        if not is_dropout:
3346*da0073e9SAndroid Build Coastguard Worker            with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
3347*da0073e9SAndroid Build Coastguard Worker                out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
3348*da0073e9SAndroid Build Coastguard Worker            with sdpa_kernel(backends=[SDPBackend.MATH]):
3349*da0073e9SAndroid Build Coastguard Worker                # High Precision Math Reference
3350*da0073e9SAndroid Build Coastguard Worker                out_ref = F.scaled_dot_product_attention(
3351*da0073e9SAndroid Build Coastguard Worker                    query_ref, key_ref, value_ref, is_causal=is_causal, scale=scale)
3352*da0073e9SAndroid Build Coastguard Worker                # Low Precision Math Reference
3353*da0073e9SAndroid Build Coastguard Worker                out_lp_ref = F.scaled_dot_product_attention(
3354*da0073e9SAndroid Build Coastguard Worker                    query_ref_lp, key_ref_lp, value_ref_lp, is_causal=is_causal, scale=scale)
3355*da0073e9SAndroid Build Coastguard Worker        else:
3356*da0073e9SAndroid Build Coastguard Worker            # Create real output
3357*da0073e9SAndroid Build Coastguard Worker            output_tuple = torch.ops.aten._scaled_dot_product_flash_attention(
3358*da0073e9SAndroid Build Coastguard Worker                query, key, value, dropout_p=dropout_p, is_causal=is_causal,
3359*da0073e9SAndroid Build Coastguard Worker                scale=scale, return_debug_mask=is_dropout)
3360*da0073e9SAndroid Build Coastguard Worker            out = output_tuple[0]
3361*da0073e9SAndroid Build Coastguard Worker            debug_mask = output_tuple[-1]
3362*da0073e9SAndroid Build Coastguard Worker
3363*da0073e9SAndroid Build Coastguard Worker            query_padding_mask = torch.arange(max_seq_len_q).unsqueeze(0).expand(
3364*da0073e9SAndroid Build Coastguard Worker                batch_size, max_seq_len_q
3365*da0073e9SAndroid Build Coastguard Worker            ) < seq_lens_q.unsqueeze(-1)
3366*da0073e9SAndroid Build Coastguard Worker            query_padding_mask = query_padding_mask.to(device)
3367*da0073e9SAndroid Build Coastguard Worker
3368*da0073e9SAndroid Build Coastguard Worker            key_padding_mask = torch.arange(max_seq_len_kv).unsqueeze(0).expand(
3369*da0073e9SAndroid Build Coastguard Worker                batch_size, max_seq_len_kv
3370*da0073e9SAndroid Build Coastguard Worker            ) < seq_lens_kv.unsqueeze(-1)
3371*da0073e9SAndroid Build Coastguard Worker            key_padding_mask = key_padding_mask.to(device)
3372*da0073e9SAndroid Build Coastguard Worker
3373*da0073e9SAndroid Build Coastguard Worker            softmax_mask = self.convert_flash_attn_S_to_softmax(
3374*da0073e9SAndroid Build Coastguard Worker                debug_mask, max_seq_len_q, max_seq_len_kv, query_padding_mask, key_padding_mask, causal=is_causal)
3375*da0073e9SAndroid Build Coastguard Worker            dropout_mask = softmax_mask >= 0
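            # Slice the dense dropout mask back down to each batch element's true
            # (seq_len_q, seq_len_kv) extent and repack it as a nested tensor so the
            # math reference applies the exact same dropout pattern as the flash kernel.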
3376*da0073e9SAndroid Build Coastguard Worker            nt_stack = []
3377*da0073e9SAndroid Build Coastguard Worker            for tensor_component in range(batch_size):
3378*da0073e9SAndroid Build Coastguard Worker                batch_stack = []
3379*da0073e9SAndroid Build Coastguard Worker                for head in range(n_heads):
3380*da0073e9SAndroid Build Coastguard Worker                    batch_stack.append(dropout_mask[tensor_component, head,
3381*da0073e9SAndroid Build Coastguard Worker                                                    0:seq_lens_q[tensor_component],
3382*da0073e9SAndroid Build Coastguard Worker                                                    0:seq_lens_kv[tensor_component]].unsqueeze(0))
3383*da0073e9SAndroid Build Coastguard Worker                nt_stack.append(torch.cat(batch_stack))
3384*da0073e9SAndroid Build Coastguard Worker            nested_dropout_mask = torch.nested.nested_tensor(nt_stack)
3385*da0073e9SAndroid Build Coastguard Worker            # High Precision Math Reference
3386*da0073e9SAndroid Build Coastguard Worker            out_ref = torch.ops.aten._scaled_dot_product_attention_math(
3387*da0073e9SAndroid Build Coastguard Worker                query_ref, key_ref, value_ref, dropout_p=dropout_p,
3388*da0073e9SAndroid Build Coastguard Worker                is_causal=is_causal, scale=scale, dropout_mask=nested_dropout_mask)[0]
3389*da0073e9SAndroid Build Coastguard Worker            # Low Precision Math Reference
3390*da0073e9SAndroid Build Coastguard Worker            out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
3391*da0073e9SAndroid Build Coastguard Worker                query_ref_lp, key_ref_lp, value_ref_lp, dropout_p=dropout_p, is_causal=is_causal, scale=scale,
3392*da0073e9SAndroid Build Coastguard Worker                dropout_mask=nested_dropout_mask)[0]
3393*da0073e9SAndroid Build Coastguard Worker
3394*da0073e9SAndroid Build Coastguard Worker        upstream_grad = out.detach().clone().contiguous()
3395*da0073e9SAndroid Build Coastguard Worker
3396*da0073e9SAndroid Build Coastguard Worker        out.backward(upstream_grad)
3397*da0073e9SAndroid Build Coastguard Worker        out_ref.backward(upstream_grad.to(out_ref.dtype))
3398*da0073e9SAndroid Build Coastguard Worker        out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype))
3399*da0073e9SAndroid Build Coastguard Worker
3400*da0073e9SAndroid Build Coastguard Worker        # See [Note] Fused Tolerances above
3401*da0073e9SAndroid Build Coastguard Worker        output_ref_atol, output_ref_rtol = calculate_nt_tolerances(out_ref, out_lp_ref, out.dtype)
3402*da0073e9SAndroid Build Coastguard Worker        grad_q_ref_atol, grad_q_ref_rtol = calculate_nt_tolerances(query_ref.grad, query_ref_lp.grad,
3403*da0073e9SAndroid Build Coastguard Worker                                                                   query.grad.dtype, fudge_factor=4)
3404*da0073e9SAndroid Build Coastguard Worker        grad_k_ref_atol, grad_k_ref_rtol = calculate_nt_tolerances(key_ref.grad, key_ref_lp.grad, key.grad.dtype)
3405*da0073e9SAndroid Build Coastguard Worker        grad_v_ref_atol, grad_v_ref_rtol = calculate_nt_tolerances(value_ref.grad, value_ref_lp.grad, value.grad.dtype)
3406*da0073e9SAndroid Build Coastguard Worker
3407*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol)
3408*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype),
3409*da0073e9SAndroid Build Coastguard Worker                         atol=grad_q_ref_atol, rtol=grad_q_ref_rtol)
3410*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(key.grad.contiguous(), key_ref.grad.contiguous().to(key.grad.dtype),
3411*da0073e9SAndroid Build Coastguard Worker                         atol=grad_k_ref_atol, rtol=grad_k_ref_rtol)
3412*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
3413*da0073e9SAndroid Build Coastguard Worker                         atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)
3414*da0073e9SAndroid Build Coastguard Worker
3415*da0073e9SAndroid Build Coastguard Workerclass TestAttnBias(NNTestCase):
3416*da0073e9SAndroid Build Coastguard Worker
3417*da0073e9SAndroid Build Coastguard Worker    def run_test(
3418*da0073e9SAndroid Build Coastguard Worker        self,
3419*da0073e9SAndroid Build Coastguard Worker        device,
3420*da0073e9SAndroid Build Coastguard Worker        make_q,
3421*da0073e9SAndroid Build Coastguard Worker        make_kv,
3422*da0073e9SAndroid Build Coastguard Worker        attn_bias=None,
3423*da0073e9SAndroid Build Coastguard Worker        forw_tolerances: Optional[Tolerances] = None,
3424*da0073e9SAndroid Build Coastguard Worker        grad_tolerances: Optional[Tolerances] = None,
3425*da0073e9SAndroid Build Coastguard Worker        backend=None,
3426*da0073e9SAndroid Build Coastguard Worker    ):
3427*da0073e9SAndroid Build Coastguard Worker        if backend is not None:
3428*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.reset()
3429*da0073e9SAndroid Build Coastguard Worker
3430*da0073e9SAndroid Build Coastguard Worker        query, key, value = make_q(), make_kv(), make_kv()
3431*da0073e9SAndroid Build Coastguard Worker        query_prototype, key_prototype, value_prototype = query_key_value_clones(query, key, value)
3432*da0073e9SAndroid Build Coastguard Worker
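        # The eager reference materializes the attention bias to a dense mask, while
        # the op under test receives the (possibly CausalBias) object directly,
        # optionally through torch.compile.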
3433*da0073e9SAndroid Build Coastguard Worker        realized = attn_bias._materialize(device) if attn_bias is not None else None
3434*da0073e9SAndroid Build Coastguard Worker        pytorch_output = scaled_dot_product_attention(
3435*da0073e9SAndroid Build Coastguard Worker            query, key, value, attn_mask=realized, dropout_p=0.0, is_causal=False
3436*da0073e9SAndroid Build Coastguard Worker        )
3437*da0073e9SAndroid Build Coastguard Worker
3438*da0073e9SAndroid Build Coastguard Worker        sdpa_op = (
3439*da0073e9SAndroid Build Coastguard Worker            torch.compile(scaled_dot_product_attention, backend=backend)
3440*da0073e9SAndroid Build Coastguard Worker            if backend is not None
3441*da0073e9SAndroid Build Coastguard Worker            else scaled_dot_product_attention
3442*da0073e9SAndroid Build Coastguard Worker        )
3443*da0073e9SAndroid Build Coastguard Worker        sdpa_output = sdpa_op(
3444*da0073e9SAndroid Build Coastguard Worker            query_prototype,
3445*da0073e9SAndroid Build Coastguard Worker            key_prototype,
3446*da0073e9SAndroid Build Coastguard Worker            value_prototype,
3447*da0073e9SAndroid Build Coastguard Worker            attn_mask=attn_bias,
3448*da0073e9SAndroid Build Coastguard Worker            dropout_p=0.0,
3449*da0073e9SAndroid Build Coastguard Worker            is_causal=False,
3450*da0073e9SAndroid Build Coastguard Worker            scale=None,
3451*da0073e9SAndroid Build Coastguard Worker        )
3452*da0073e9SAndroid Build Coastguard Worker
3453*da0073e9SAndroid Build Coastguard Worker        dOut = torch.randn_like(pytorch_output)
3454*da0073e9SAndroid Build Coastguard Worker        pytorch_output.backward(dOut)
3455*da0073e9SAndroid Build Coastguard Worker        sdpa_output.backward(dOut)
3456*da0073e9SAndroid Build Coastguard Worker
3457*da0073e9SAndroid Build Coastguard Worker        # Use default assert_close tolerances for dtypes
3458*da0073e9SAndroid Build Coastguard Worker        if forw_tolerances is None:
3459*da0073e9SAndroid Build Coastguard Worker            forw_tolerances = Tolerances(atol=None, rtol=None)
3460*da0073e9SAndroid Build Coastguard Worker        if grad_tolerances is None:
3461*da0073e9SAndroid Build Coastguard Worker            grad_tolerances = Tolerances(atol=None, rtol=None)
3462*da0073e9SAndroid Build Coastguard Worker
3463*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(pytorch_output, sdpa_output, rtol=forw_tolerances.rtol, atol=forw_tolerances.atol)
3464*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(query.grad, query_prototype.grad, rtol=grad_tolerances.rtol, atol=grad_tolerances.atol)
3465*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(key.grad, key_prototype.grad, rtol=grad_tolerances.rtol, atol=grad_tolerances.atol)
3466*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(value.grad, value_prototype.grad, rtol=grad_tolerances.rtol, atol=grad_tolerances.atol)
3467*da0073e9SAndroid Build Coastguard Worker
3468*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm  # No support for the second variant for now
3469*da0073e9SAndroid Build Coastguard Worker    @parametrize("causal_variant", [CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT])
3470*da0073e9SAndroid Build Coastguard Worker    @parametrize(
3471*da0073e9SAndroid Build Coastguard Worker        "shape",
3472*da0073e9SAndroid Build Coastguard Worker        [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)],
3473*da0073e9SAndroid Build Coastguard Worker    )
3474*da0073e9SAndroid Build Coastguard Worker    def test_causal_variants(self, device, causal_variant: CausalVariant, shape: Tuple[int, ...]):
3475*da0073e9SAndroid Build Coastguard Worker        make_tensor = partial(
3476*da0073e9SAndroid Build Coastguard Worker            torch.rand, device=device, dtype=torch.float16, requires_grad=True
3477*da0073e9SAndroid Build Coastguard Worker        )
3478*da0073e9SAndroid Build Coastguard Worker
3479*da0073e9SAndroid Build Coastguard Worker        bsz, num_heads, seq_len_q, seq_len_kv, head_dim = shape
3480*da0073e9SAndroid Build Coastguard Worker        make_q_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_q, head_dim))
3481*da0073e9SAndroid Build Coastguard Worker        make_kv_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_kv, head_dim))
3482*da0073e9SAndroid Build Coastguard Worker        if causal_variant == CausalVariant.LOWER_RIGHT and seq_len_q > seq_len_kv:
3483*da0073e9SAndroid Build Coastguard Worker            self.skipTest(
3484*da0073e9SAndroid Build Coastguard Worker                "Lower right causal mask will produce NaNs in the output when seq_len_q > seq_len_kv!"
3485*da0073e9SAndroid Build Coastguard Worker            )
3486*da0073e9SAndroid Build Coastguard Worker
3487*da0073e9SAndroid Build Coastguard Worker        forw_tol = Tolerances(1e-3, 1e-3)
3488*da0073e9SAndroid Build Coastguard Worker        grad_tol = Tolerances(5e-3, 5e-3)
3489*da0073e9SAndroid Build Coastguard Worker
3490*da0073e9SAndroid Build Coastguard Worker        if causal_variant == CausalVariant.UPPER_LEFT:
3491*da0073e9SAndroid Build Coastguard Worker            attn_bias = causal_upper_left(seq_len_q, seq_len_kv)
3492*da0073e9SAndroid Build Coastguard Worker        else:
3493*da0073e9SAndroid Build Coastguard Worker            attn_bias = causal_lower_right(seq_len_q, seq_len_kv)
3494*da0073e9SAndroid Build Coastguard Worker
3495*da0073e9SAndroid Build Coastguard Worker        self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=None)
3496*da0073e9SAndroid Build Coastguard Worker
3497*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm  # CausalVariant
3498*da0073e9SAndroid Build Coastguard Worker    @parametrize("causal_variant", [CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT])
3499*da0073e9SAndroid Build Coastguard Worker    @parametrize(
3500*da0073e9SAndroid Build Coastguard Worker        "shape",
3501*da0073e9SAndroid Build Coastguard Worker        [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)],
3502*da0073e9SAndroid Build Coastguard Worker    )
3503*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on windows")
3504*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("This function already calls torch.compile.")
3505*da0073e9SAndroid Build Coastguard Worker    def test_causal_variants_compile(self, device, causal_variant: CausalVariant, shape: Tuple[int, ...]):
3506*da0073e9SAndroid Build Coastguard Worker        cnts = CompileCounterWithBackend("aot_eager")
3507*da0073e9SAndroid Build Coastguard Worker        make_tensor = partial(
3508*da0073e9SAndroid Build Coastguard Worker            torch.rand, device=device, dtype=torch.float16, requires_grad=True
3509*da0073e9SAndroid Build Coastguard Worker        )
3510*da0073e9SAndroid Build Coastguard Worker
3511*da0073e9SAndroid Build Coastguard Worker        bsz, num_heads, seq_len_q, seq_len_kv, head_dim = shape
3512*da0073e9SAndroid Build Coastguard Worker        make_q_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_q, head_dim))
3513*da0073e9SAndroid Build Coastguard Worker        make_kv_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_kv, head_dim))
3514*da0073e9SAndroid Build Coastguard Worker        if causal_variant == CausalVariant.LOWER_RIGHT and seq_len_q > seq_len_kv:
3515*da0073e9SAndroid Build Coastguard Worker            self.skipTest(
3516*da0073e9SAndroid Build Coastguard Worker                "Lower right causal mask will produce NaNs in the output when seq_len_q > seq_len_kv!"
3517*da0073e9SAndroid Build Coastguard Worker            )
3518*da0073e9SAndroid Build Coastguard Worker        forw_tol = Tolerances(1e-3, 1e-3)
3519*da0073e9SAndroid Build Coastguard Worker        grad_tol = Tolerances(5e-3, 5e-3)
3520*da0073e9SAndroid Build Coastguard Worker
3521*da0073e9SAndroid Build Coastguard Worker        if causal_variant == CausalVariant.UPPER_LEFT:
3522*da0073e9SAndroid Build Coastguard Worker            attn_bias = causal_upper_left(seq_len_q, seq_len_kv)
3523*da0073e9SAndroid Build Coastguard Worker        else:
3524*da0073e9SAndroid Build Coastguard Worker            attn_bias = causal_lower_right(seq_len_q, seq_len_kv)
3525*da0073e9SAndroid Build Coastguard Worker
3526*da0073e9SAndroid Build Coastguard Worker        self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=cnts)
3527*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1, "Compiled graph should have 1 frame!")
3528*da0073e9SAndroid Build Coastguard Worker
3529*da0073e9SAndroid Build Coastguard Worker    @parametrize("shape", [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)])
3530*da0073e9SAndroid Build Coastguard Worker    def test_is_causal_equals_upper_left(self, device, shape: Tuple[int, ...]):
3531*da0073e9SAndroid Build Coastguard Worker        make_tensor = partial(
3532*da0073e9SAndroid Build Coastguard Worker            torch.rand, device=device, dtype=torch.float16, requires_grad=True
3533*da0073e9SAndroid Build Coastguard Worker        )
3534*da0073e9SAndroid Build Coastguard Worker
3535*da0073e9SAndroid Build Coastguard Worker        bsz, num_heads, seq_len_q, seq_len_kv, head_dim = shape
3536*da0073e9SAndroid Build Coastguard Worker        make_q_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_q, head_dim))
3537*da0073e9SAndroid Build Coastguard Worker        make_kv_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_kv, head_dim))
3538*da0073e9SAndroid Build Coastguard Worker
3539*da0073e9SAndroid Build Coastguard Worker        forw_tol = Tolerances(1e-3, 1e-3)
3540*da0073e9SAndroid Build Coastguard Worker        grad_tol = Tolerances(5e-3, 5e-3)
3541*da0073e9SAndroid Build Coastguard Worker
3542*da0073e9SAndroid Build Coastguard Worker        query = make_q_tensor()
3543*da0073e9SAndroid Build Coastguard Worker        key = make_kv_tensor()
3544*da0073e9SAndroid Build Coastguard Worker        value = make_kv_tensor()
3545*da0073e9SAndroid Build Coastguard Worker        attn_bias = causal_upper_left(seq_len_q, seq_len_kv)
3546*da0073e9SAndroid Build Coastguard Worker
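        # is_causal=True uses the same upper-left alignment as causal_upper_left, so
        # both paths should produce matching outputs within forw_tol.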
3547*da0073e9SAndroid Build Coastguard Worker        out_attn_bias = scaled_dot_product_attention(query, key, value, attn_mask=attn_bias, dropout_p=0.0)
3548*da0073e9SAndroid Build Coastguard Worker        out_is_causal = scaled_dot_product_attention(query, key, value, is_causal=True, dropout_p=0.0)
3549*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(out_attn_bias, out_is_causal, rtol=forw_tol.rtol, atol=forw_tol.atol)
3550*da0073e9SAndroid Build Coastguard Worker
3551*da0073e9SAndroid Build Coastguard Worker    def test_is_causal_and_mask_fails(self, device):
3552*da0073e9SAndroid Build Coastguard Worker        make_tensor = partial(
3553*da0073e9SAndroid Build Coastguard Worker            torch.rand, device=device, dtype=torch.float16, requires_grad=True
3554*da0073e9SAndroid Build Coastguard Worker        )
3555*da0073e9SAndroid Build Coastguard Worker        make_q_tensor = partial(make_tensor, SdpaShape(16, 16, 128, 16))
3556*da0073e9SAndroid Build Coastguard Worker        make_kv_tensor = partial(make_tensor, SdpaShape(16, 16, 128, 16))
3557*da0073e9SAndroid Build Coastguard Worker
3558*da0073e9SAndroid Build Coastguard Worker        query = make_q_tensor()
3559*da0073e9SAndroid Build Coastguard Worker        key = make_kv_tensor()
3560*da0073e9SAndroid Build Coastguard Worker        value = make_kv_tensor()
3561*da0073e9SAndroid Build Coastguard Worker        attn_bias = causal_upper_left(128, 128)
3562*da0073e9SAndroid Build Coastguard Worker
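        # Supplying a CausalBias attn_mask together with is_causal=True is ambiguous
        # and should be rejected.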
3563*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "CausalBias should not be used with causal=True"):
3564*da0073e9SAndroid Build Coastguard Worker            scaled_dot_product_attention(query, key, value, attn_mask=attn_bias, is_causal=True, dropout_p=0.0)
3565*da0073e9SAndroid Build Coastguard Worker
3566*da0073e9SAndroid Build Coastguard Workerif NOTEST_CPU:
3567*da0073e9SAndroid Build Coastguard Worker    device_types = ("cuda", )
3568*da0073e9SAndroid Build Coastguard Workerelse:
3569*da0073e9SAndroid Build Coastguard Worker    device_types = ("cpu", "cuda")
3570*da0073e9SAndroid Build Coastguard Worker
3571*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestTransformers, globals(), only_for=device_types)
3572*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestSDPAFailureModes, globals(), only_for=device_types)
3573*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestSDPA, globals(), only_for=device_types)
3574*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestSDPACudaOnly, globals(), only_for=("cuda",))
3575*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestAttnBias, globals(), only_for=device_types)
3576*da0073e9SAndroid Build Coastguard Worker
3577*da0073e9SAndroid Build Coastguard Workerif __name__ == '__main__':
3578*da0073e9SAndroid Build Coastguard Worker    run_tests()