# Owner(s): ["module: nn"]

import contextlib
from functools import partial
from collections import namedtuple
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.nn.attention.bias import CausalVariant, causal_lower_right, causal_upper_left
from torch.nn.parameter import Parameter
import unittest
from unittest.mock import patch, MagicMock, ANY
import math
import torch.optim as optim
from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCUDA, onlyCPU
from typing import List, Tuple, Optional
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
    TEST_WITH_ROCM,
    skipIfRocm,
    skipIfTorchDynamo,
    TEST_FAIRSEQ,
    run_tests,
    parametrize,
    freeze_rng_state,
    TEST_WITH_CROSSREF,
    slowTest,
    set_default_dtype,
    gradcheck,
    make_tensor,
    NOTEST_CPU,
    IS_WINDOWS,
    TEST_WITH_TORCHDYNAMO,
)
from torch._dynamo.testing import CompileCounterWithBackend


from torch.testing._internal.common_methods_invocations import wrapper_set_seed
from torch.testing._internal.common_cuda import (
    IS_JETSON, SM80OrLater, PLATFORM_SUPPORTS_FLASH_ATTENTION,
    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
    PLATFORM_SUPPORTS_FUSED_ATTENTION,
    PLATFORM_SUPPORTS_CUDNN_ATTENTION
)

if TEST_FAIRSEQ:
    import fairseq.models.transformer as fairseq_transformer

SdpaShape = namedtuple('Sdpa_Shape', ['batch', 'num_heads', 'seq_len', 'head_dim'])
Tolerances = namedtuple('Tolerances', ['atol', 'rtol'])

@contextlib.contextmanager
def use_deterministic_algorithims(mode: bool, warn_only: bool):
    r"""
    This context manager can be used to temporarily enable or disable deterministic algorithms.
    Upon exiting the context manager, the previous state of the flag will be restored.
    """
    previous_mode: bool = torch.are_deterministic_algorithms_enabled()
    previous_warn_only: bool = torch.is_deterministic_algorithms_warn_only_enabled()
    try:
        torch.use_deterministic_algorithms(mode, warn_only=warn_only)
        yield {}
    finally:
        torch.use_deterministic_algorithms(previous_mode, warn_only=previous_warn_only)


# Found in torch/testing/_comparison.py
default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float32: 1e-5}
default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float32: 1.3e-6}

isSM8XDevice = torch.cuda.is_available() and torch.cuda.get_device_capability() in [(8, 6), (8, 7), (8, 9)]
isSM90Device = torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0)
isSM5xDevice = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 5
isLessThanSM80Device = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8

def get_rtol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
    deviation = true_value - computed_value
    deviation = torch.abs(deviation / true_value)
    # Fill in the nans with the default rtol
    torch.nan_to_num_(deviation, nan=default_rtol[computed_value.dtype])
    return deviation.max().item()


def get_atol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
    deviation = true_value - computed_value
    atol = torch.abs(deviation).max().item()
    return atol


def get_tolerances(
    true_value: torch.Tensor,
    computed_value: torch.Tensor,
    fudge_factor: Optional[float] = None,
) -> Tuple[float, float]:
    """Returns the absolute and relative tolerances for comparing two tensors."""
    fudge_factor = fudge_factor if fudge_factor is not None else 1.0
    atol = get_atol(true_value, computed_value)
    rtol = get_rtol(true_value, computed_value)

    atol = fudge_factor * max(atol, default_atol[computed_value.dtype])
    rtol = fudge_factor * max(rtol, default_rtol[computed_value.dtype])
    # torch.isclose() has weird behavior around extreme rtol values, see:
    # https://github.com/pytorch/pytorch/issues/102400
    if rtol > 1e30:
        rtol = default_rtol[computed_value.dtype]
    return atol, rtol


def query_key_value_clones(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, dtype: torch.dtype = None):
    """ Clones the query, key, and value tensors and moves them to the specified dtype.
""" 114*da0073e9SAndroid Build Coastguard Worker if dtype is None: 115*da0073e9SAndroid Build Coastguard Worker dtype = query.dtype 116*da0073e9SAndroid Build Coastguard Worker query_ref = query.clone().detach().to(dtype).requires_grad_(query.requires_grad) 117*da0073e9SAndroid Build Coastguard Worker key_ref = key.clone().detach().to(dtype).requires_grad_(key.requires_grad) 118*da0073e9SAndroid Build Coastguard Worker value_ref = value.clone().detach().to(dtype).requires_grad_(value.requires_grad) 119*da0073e9SAndroid Build Coastguard Worker return query_ref, key_ref, value_ref 120*da0073e9SAndroid Build Coastguard Worker 121*da0073e9SAndroid Build Coastguard Workerdef get_platform_specific_sdpa(): 122*da0073e9SAndroid Build Coastguard Worker ret = [] 123*da0073e9SAndroid Build Coastguard Worker if PLATFORM_SUPPORTS_FLASH_ATTENTION: 124*da0073e9SAndroid Build Coastguard Worker ret.append(SDPBackend.FLASH_ATTENTION) 125*da0073e9SAndroid Build Coastguard Worker if PLATFORM_SUPPORTS_MEM_EFF_ATTENTION: 126*da0073e9SAndroid Build Coastguard Worker ret.append(SDPBackend.EFFICIENT_ATTENTION) 127*da0073e9SAndroid Build Coastguard Worker if PLATFORM_SUPPORTS_CUDNN_ATTENTION: 128*da0073e9SAndroid Build Coastguard Worker ret.append(SDPBackend.CUDNN_ATTENTION) 129*da0073e9SAndroid Build Coastguard Worker if not ret: 130*da0073e9SAndroid Build Coastguard Worker # Add a placeholder, an empty list causes "An empty arg_values was passed to @parametrize" 131*da0073e9SAndroid Build Coastguard Worker ret.append(SDPBackend.EFFICIENT_ATTENTION) 132*da0073e9SAndroid Build Coastguard Worker return ret 133*da0073e9SAndroid Build Coastguard Worker 134*da0073e9SAndroid Build Coastguard WorkerPLATFORM_SPECIFIC_SDPA = get_platform_specific_sdpa() 135*da0073e9SAndroid Build Coastguard Worker# Indicate the Efficient attention backend can support: 136*da0073e9SAndroid Build Coastguard Worker# 1. sequence longher than 512 137*da0073e9SAndroid Build Coastguard Worker# 2. head dimsion larger than 64 138*da0073e9SAndroid Build Coastguard WorkerMEM_EFF_CAPABILITY_MATCHES_SM80 = SM80OrLater or TEST_WITH_ROCM 139*da0073e9SAndroid Build Coastguard Worker 140*da0073e9SAndroid Build Coastguard Workerdef rand_sdpa_tensor(shape: SdpaShape, device: str, dtype: torch.dtype, type: str, 141*da0073e9SAndroid Build Coastguard Worker requires_grad: bool = False, packed: bool = False) -> torch.Tensor: 142*da0073e9SAndroid Build Coastguard Worker """Creates rand dense or nested tensor with given shape and type. 143*da0073e9SAndroid Build Coastguard Worker 144*da0073e9SAndroid Build Coastguard Worker Args: 145*da0073e9SAndroid Build Coastguard Worker shape (Tuple[int]): Shape of Tensor to construct 146*da0073e9SAndroid Build Coastguard Worker device (str): which device to create tensor on 147*da0073e9SAndroid Build Coastguard Worker dtype (torch.dtype): Tensors' dtype 148*da0073e9SAndroid Build Coastguard Worker type (str): Nested or Dense 149*da0073e9SAndroid Build Coastguard Worker requires_grad (bool, optional): Tensors grad status. Defaults to False. 150*da0073e9SAndroid Build Coastguard Worker packed (bool, optional): Whether to create a single QKV packed or not. Defaults to False. 

    Returns:
        torch.Tensor: A new tensor
    """
    batch, num_heads, seq_len, head_dim = shape.batch, shape.num_heads, shape.seq_len, shape.head_dim
    if type == "nested":
        if isinstance(seq_len, list):
            def _size(i):
                return (seq_len[i], num_heads, head_dim) if not packed else (seq_len[i], 3 * num_heads * head_dim)

            return torch.nested.nested_tensor([
                torch.randn(_size(i), device=device, dtype=dtype, requires_grad=requires_grad)
                for i in range(batch)])
        else:
            size = (seq_len, num_heads, head_dim) if not packed else (seq_len, 3 * num_heads * head_dim)
            return torch.nested.nested_tensor([
                torch.randn(size, device=device, dtype=dtype, requires_grad=requires_grad)
                for _ in range(batch)])
    else:
        assert (isinstance(seq_len, int))
        size = (batch, seq_len, num_heads, head_dim) if not packed else (batch, seq_len, 3 * num_heads * head_dim)
        return torch.randn(size, device=device, dtype=dtype, requires_grad=requires_grad)

def calculate_nt_tolerances(nt_ref_hp, nt_ref_lp, default_dtype, fudge_factor=1):
    # TODO use NT ops when we have implemented Max for NestedTensor instead of unrolling
    ref_atol = default_atol[default_dtype]
    ref_rtol = default_rtol[default_dtype]
    for tensor_component_ref, tensor_component_ref_lp in zip(nt_ref_hp.unbind(), nt_ref_lp.unbind()):
        ref_atol = max((fudge_factor * torch.abs(tensor_component_ref - tensor_component_ref_lp)).max().item(), ref_atol)
        ref_rtol = max(get_rtol(tensor_component_ref, tensor_component_ref_lp), ref_rtol)
    return ref_atol, ref_rtol

class TestTransformers(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    @onlyCUDA
    @unittest.skip("4D mask not supported yet - activate when 4D mask supported")
    def test_self_attn_TxT_attn_mask(self, device):
        embed_dim = 16
        num_heads = 4
        batch_size = 10
        tgt_len = 16

        query = torch.rand(batch_size, tgt_len, embed_dim, device=device)  # [N, T, D]
        attn_mask = torch.randint(0, 2, (tgt_len, tgt_len)).cuda().float()  # [T, T]
        attn_mask = attn_mask.masked_fill(attn_mask == 0, float('-inf')).masked_fill(attn_mask == 1, 0.0)

        attn_mask_4d = attn_mask.expand(batch_size, num_heads, tgt_len, tgt_len)

        mta_model = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).cuda()
        mta_model.eval()

        # Generate 3D results
        with torch.inference_mode():
            output_mask_4d = mta_model(query, query, query, attn_mask=attn_mask_4d)[0]
            output_mask_4d = output_mask_4d.transpose(0, 1)  # [N, T, D]

            output_mask_TxT = mta_model(query, query, query, attn_mask=attn_mask)[0]
            output_mask_TxT = output_mask_TxT.transpose(0, 1)  # [N, T, D]

            self.assertEqual(output_mask_4d, output_mask_TxT)

    @slowTest
    def test_train_with_pad_and_catch_error(self, device):
        iters = 100
        pad_mask = torch.tensor([[1, 1, 0, 0]], dtype=torch.bool).to(device)
        layer = nn.TransformerEncoderLayer(
            d_model=2,
            dim_feedforward=4,
            nhead=2,
            batch_first=True,
            activation="gelu",
            dropout=0,
        )
        criterion = nn.MSELoss()
        encoder = nn.TransformerEncoder(layer, 2).to(device)
        optimizer = optim.SGD(encoder.parameters(), lr=0.1, momentum=0.9)
        encoder.train()
        for i in range(iters):
            encoder.train()
            optimizer.zero_grad()
            inputs = torch.cat([torch.randn(1, 2, 2), torch.zeros(1, 2, 2)], dim=1).to(device)

            outputs = encoder(inputs, src_key_padding_mask=pad_mask)

            loss = criterion(outputs[:, 0:2, :], inputs[:, 0:2, :])
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                test = torch.cat([torch.randn(1, 2, 2), torch.zeros(1, 2, 2)], dim=1).to(device)

                # Expect uint8 type not supported
                ex = None
                try:
                    test_train_uint8 = encoder(test, src_key_padding_mask=pad_mask.to(torch.uint8))
                except AssertionError as e:
                    continue
                self.assertFalse(e, "Failed to catch unsupported uint8 type exception")  # noqa: F821

                test_train_bool = encoder(test, src_key_padding_mask=pad_mask)
                encoder.eval()

                # Expect long type not supported
                ex = None
                try:
                    test_eval_uint8 = encoder(test, src_key_padding_mask=pad_mask.to(torch.int64))
                except AssertionError as e:
                    continue
                self.assertFalse(e, "Failed to catch unsupported Long type exception")  # noqa: F821

                test_eval_bool = encoder(test, src_key_padding_mask=pad_mask)
                l1_bool = nn.L1Loss()(test_train_bool[:, 0:2, :], test_eval_bool[:, 0:2, :]).item()
                self.assertTrue(l1_bool < 1e-4, "Eval/Train difference in pad_mask BOOL")

    @parametrize("attn_mask_dim", [2, 3, None])
    @parametrize("key_padding_mask_dim", [2, None])
    @parametrize("mask_dtype", [torch.bool, torch.float32])
    def test_multiheadattention_fastpath_attn_mask(self, device, attn_mask_dim, key_padding_mask_dim, mask_dtype):
        with torch.no_grad():
            B = 2
            L = 4
            D = 8
            H = 4

            if attn_mask_dim == 2:
                attn_mask = make_tensor((L, L), dtype=mask_dtype, device=device)
            elif attn_mask_dim == 3:
                attn_mask = make_tensor((B * H, L, L), dtype=mask_dtype, device=device)
            elif attn_mask_dim is None:
                attn_mask = None

            if key_padding_mask_dim == 2:
                key_padding_mask = make_tensor((B, L), dtype=mask_dtype, device=device)
            elif key_padding_mask_dim is None:
                key_padding_mask = None

            mha = nn.MultiheadAttention(D, H, batch_first=True, device=device)
            X = torch.randn(B, L, D, device=device)

            mha.train()  # disable fast path
            out, _ = mha(X, X, X, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)
            mha.eval()  # enable fast path
            out_fp, _ = mha(X, X, X, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)
            self.assertEqual(out, out_fp)

    @parametrize("nhead", [1, 4, 8])
    def test_transformerencoderlayer_src_mask(self, device, nhead):
        batch_size = 2
        seqlen = 4
        d_model = 8
        dim_feedforward = 32

        model = torch.nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True).to(device)
        src = torch.rand(batch_size, seqlen, d_model).to(device)  # bs, seqlen, d_model
        src_mask = torch.zeros(seqlen, seqlen).to(torch.bool).to(device)

        model(src, src_mask=src_mask)
        model.eval()
        with torch.no_grad():
            model(src, src_mask=src_mask)

    @parametrize("use_torchscript", [False])
    @parametrize("enable_nested_tensor", [True, False])
    @parametrize("use_autocast", [True, False])
    @parametrize("d_model", [12, 256])
    def test_transformerencoder_fastpath(self, device, use_torchscript, enable_nested_tensor, use_autocast, d_model):
        """
        Test TransformerEncoder fastpath output matches slowpath output
        """
        torch.manual_seed(1234)
        nhead = 4
        dim_feedforward = d_model
        batch_first = True

        model = torch.nn.TransformerEncoder(
            torch.nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                batch_first=batch_first),
            num_layers=2,
            enable_nested_tensor=enable_nested_tensor
        ).to(device).eval()

        if use_torchscript:
            model = torch.jit.script(model)

        # each input is (input, mask)
        input_mask_pairs = [
            (
                torch.rand(3, 2, d_model),
                [
                    [0, 1],
                    [0, 1],
                    [1, 1]
                ]
            ),
            (
                torch.rand(2, 100, d_model),
                [
                    [0] * 98 + [1] * 2,
                    [0] * 90 + [1] * 10
                ]
            ),
            # softmax.cu switches from fast->slowpath at masked seqlen 1024. test 1024.
            (
                torch.rand(2, 1024, d_model),
                [
                    [0] * 1020 + [1] * 4,
                    [0] * 1024,
                ]
            ),
            (
                torch.rand(1, 1026, d_model),
                [[0] * 1024 + [1] * 2]
            ),
            # softmax.cu switches from fast->slowpath at masked seqlen 1024. test a range of masks above 1024.
            (
                torch.rand(4, 1040, d_model),
                [
                    [0] * 1024 + [1] * 16,
                    [0] * 1025 + [1] * 15,
                    [0] * 1031 + [1] * 9,
                    [0] * 1040,
                ]
            )
        ]
        input_mask_pairs = [
            (
                torch.tensor(pair[0], device=device, dtype=torch.get_default_dtype()),  # float input
                torch.tensor(pair[1], device=device, dtype=torch.bool)  # bool mask
            ) for pair in input_mask_pairs
        ]

        maybe_autocast = torch.autocast("cuda", dtype=torch.float16) if use_autocast else contextlib.nullcontext()
        with maybe_autocast:
            for input, src_key_padding_mask in input_mask_pairs:
                with torch.no_grad():
                    fastpath_output = model(input, src_key_padding_mask=src_key_padding_mask)
                    slowpath_output = model(input, src_key_padding_mask=src_key_padding_mask)  # reference
                # Make sure fastpath_output is the same shape as slowpath_output and the mask.
                # When enable_nested_tensor=True, fastpath_output may be smaller than the input tensor.
                # E.g. if the input has bs=1, seqlen=6, and we mask out 2 tokens, fastpath_output will have bs=1, seqlen=4.
                # Expand back to the old size to match.
                bs, true_seqlen, embed_dim = fastpath_output.shape
                expanded_seqlen = src_key_padding_mask.shape[1]
                fastpath_output_expanded = torch.zeros(bs, expanded_seqlen, embed_dim, device=device)
                fastpath_output_expanded[:, :true_seqlen, :] = fastpath_output
                # No guarantees on the outputs corresponding to masked tokens, so they may vary between slow/fast path. Set them all to 0.
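                # Note: src_key_padding_mask has shape (bs, expanded_seqlen); unsqueeze(-1) broadcasts it
                # to (bs, expanded_seqlen, 1), so masked_fill zeroes the full embedding vector at every
                # padded position in both outputs before the comparison below.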
                fastpath_output_expanded = fastpath_output_expanded.masked_fill(src_key_padding_mask.unsqueeze(-1), 0)
                slowpath_output = slowpath_output.masked_fill(src_key_padding_mask.unsqueeze(-1), 0)
                torch.testing.assert_close(fastpath_output_expanded, slowpath_output, rtol=1e-7, atol=1e-5)

    @parametrize("with_no_grad", [True, False])
    @parametrize("training", [True, False])
    @parametrize("enable_nested_tensor", [False])
    def test_transformerencoder_square_input(self, with_no_grad, training, enable_nested_tensor, device):
        """
        Test for edge cases when input of shape (batch size, sequence length, embedding dimension) has
        batch size == sequence length
        """
        model = torch.nn.TransformerEncoder(
            torch.nn.TransformerEncoderLayer(d_model=4, nhead=2, dim_feedforward=16, dropout=0.0, batch_first=True),
            num_layers=2,
            enable_nested_tensor=enable_nested_tensor
        ).to(device)

        with torch.no_grad():
            # set constant weights of the model
            for idx, p in enumerate(model.parameters()):
                x = p.data
                sz = x.view(-1).size(0)
                shape = x.shape
                x = torch.cos(torch.arange(0, sz).float().view(shape))
                p.data.copy_(x)

        if training:
            model = model.train()
        else:
            model = model.eval()
        x = torch.arange(0, 16).reshape(2, 2, 4).to(torch.get_default_dtype()).to(device)
        src_mask = torch.Tensor([[0, 1], [0, 0]]).to(torch.bool).to(device)

        if with_no_grad:
            cm = torch.no_grad()
        else:
            cm = contextlib.nullcontext()
        with cm:
            result = model(x, mask=src_mask)

        ref_output = torch.Tensor([[[2.420306205749512, 0.017629241570830, -0.607857942581177, -0.085519507527351],
                                    [2.420306205749512, 0.017629241570830, -0.607857942581177, -0.085519507527351]],
                                   [[2.419836044311523, 0.017548924311996, -0.608187675476074, -0.085347734391689],
                                    [2.419836044311523, 0.017548924311996, -0.608187675476074, -0.085347734391689]]]
                                   ).to(device)
        self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
        torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

    @parametrize("batch_first", [True, False])
    @parametrize("training", [True, False])
    @parametrize("enable_nested_tensor", [True, False])
    def test_transformerencoder(self, batch_first, training, enable_nested_tensor, device):
        def get_a_test_layer(activation, batch_first=False):
            d_model = 4
            nhead = 2
            dim_feedforward = 16
            dropout = 0.0

            layer = nn.TransformerEncoderLayer(
                d_model,
                nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                activation=activation,
                batch_first=batch_first,
            ).to(device)

            with torch.no_grad():
                # set constant weights of the model
                for idx, p in enumerate(layer.parameters()):
                    x = p.data
                    sz = x.view(-1).size(0)
                    shape = x.shape
                    x = torch.cos(torch.arange(0, sz).float().view(shape))
                    p.data.copy_(x)

            return layer

        # this is a deterministic test for TransformerEncoder
        activation = F.relu

        def _test(batch_first, training, enable_nested_tensor):
            def perm_fn(x):
                return x.transpose(1, 0) if batch_first else x

            encoder_layer = get_a_test_layer(activation=activation,
                                             batch_first=batch_first)

            model = nn.TransformerEncoder(
                encoder_layer, 1, enable_nested_tensor=enable_nested_tensor
            ).to(device)

            if not training:
                model = model.eval()

            # deterministic input
            encoder_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891],
                                                   [0.5387, 0.1655, 0.3565, 0.0471]],
                                                  [[0.8335, 0.2799, 0.5031, 0.2947],
                                                   [0.1402, 0.0318, 0.7636, 0.1346]],
                                                  [[0.6333, 0.9344, 0.1376, 0.9938],
                                                   [0.8924, 0.2872, 0.6692, 0.2944]],
                                                  [[0.9897, 0.6915, 0.3154, 0.1733],
                                                   [0.8645, 0.3513, 0.3064, 0.0767]],
                                                  [[0.8117, 0.2366, 0.4838, 0.7881],
                                                   [0.3718, 0.4945, 0.9511, 0.0864]]]
                                                 )).to(device)
            result = model(encoder_input)
            ref_output = perm_fn(torch.tensor([[[2.428589, 0.020835, -0.602055, -0.085249],
                                                [2.427987, 0.021213, -0.602496, -0.084103]],
                                               [[2.424689, 0.019155, -0.604793, -0.085672],
                                                [2.413863, 0.022211, -0.612486, -0.072490]],
                                               [[2.433774, 0.021598, -0.598343, -0.087548],
                                                [2.425104, 0.019748, -0.604515, -0.084839]],
                                               [[2.436185, 0.022682, -0.596625, -0.087261],
                                                [2.433556, 0.021891, -0.598509, -0.086832]],
                                               [[2.416246, 0.017512, -0.610712, -0.082961],
                                                [2.422901, 0.024187, -0.606178, -0.074929]]]
                                              )).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

            # all 0 src_mask
            src_mask = torch.zeros([5, 5]).to(device) == 1
            result = model(encoder_input, mask=src_mask)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

            # all 0
            mask = torch.zeros([2, 5]).to(device) == 1
            result = model(encoder_input, src_key_padding_mask=mask)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
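            # The in-place updates just below mark positions (0, 1), (1, 3) and (1, 4) as padded,
            # i.e. the key padding mask becomes (True = padded):
            #   [[False, True,  False, False, False],
            #    [False, False, False, True,  True ]]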

            mask[0, 1] = 1
            mask[1, 3] = 1
            mask[1, 4] = 1
            result = model(encoder_input, src_key_padding_mask=mask)
            ref_output = perm_fn(torch.tensor([[[2.429026, 0.020793, -0.601741, -0.085642],
                                                [2.428811, 0.021445, -0.601912, -0.084252]],
                                               [[2.425009, 0.019155, -0.604566, -0.085899],
                                                [2.415408, 0.02249, -0.611415, -0.073]],
                                               [[2.434199, 0.021682, -0.598039, -0.087699],
                                                [2.42598, 0.019941, -0.603896, -0.085091]],
                                               [[2.436457, 0.022736, -0.59643, -0.08736],
                                                [2.434021, 0.022093, -0.598179, -0.08679]],
                                               [[2.416531, 0.017498, -0.610513, -0.083181],
                                                [2.4242, 0.024653, -0.605266, -0.074959]]]
                                              )).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

            # test case 2, multiple layers no norm
            model = nn.TransformerEncoder(encoder_layer, 2, enable_nested_tensor=enable_nested_tensor).to(device)
            if not training:
                model = model.eval()
            result = model(encoder_input, src_key_padding_mask=mask)
            ref_output = perm_fn(torch.tensor([[[2.419051, 0.017446, -0.608738, -0.085003],
                                                [2.419102, 0.017452, -0.608703, -0.085026]],
                                               [[2.419043, 0.017445, -0.608744, -0.084999],
                                                [2.419052, 0.017446, -0.608738, -0.085004]],
                                               [[2.419067, 0.017448, -0.608727, -0.085010],
                                                [2.419098, 0.017452, -0.608706, -0.085024]],
                                               [[2.419072, 0.017449, -0.608724, -0.085012],
                                                [2.419119, 0.017455, -0.608691, -0.085034]],
                                               [[2.419019, 0.017442, -0.608761, -0.084989],
                                                [2.419075, 0.017449, -0.608722, -0.085014]]]
                                              )).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

            model = nn.TransformerEncoder(encoder_layer, 6, enable_nested_tensor=enable_nested_tensor).to(device)
            if not training:
                model = model.eval()
            result = model(encoder_input, src_key_padding_mask=mask)
            ref_output = perm_fn(torch.tensor([[[2.419101, 0.017453, -0.608703, -0.085025],
                                                [2.419101, 0.017453, -0.608704, -0.085025]],
                                               [[2.419101, 0.017453, -0.608703, -0.085025],
                                                [2.419101, 0.017453, -0.608704, -0.085025]],
                                               [[2.419101, 0.017453, -0.608703, -0.085025],
                                                [2.419101, 0.017453, -0.608704, -0.085025]],
                                               [[2.419101, 0.017453, -0.608703, -0.085025],
                                                [2.419101, 0.017453, -0.608704, -0.085025]],
                                               [[2.419101, 0.017453, -0.608703, -0.085025],
                                                [2.419101, 0.017453, -0.608704, -0.085025]]]
                                              )).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

            # test case 3, multiple layers with norm
            # d_model = 4
            norm = nn.LayerNorm(4)
            model = nn.TransformerEncoder(encoder_layer, 2, norm=norm,
                                          enable_nested_tensor=enable_nested_tensor).to(device)
            if not training:
                model = model.eval()
            result = model(encoder_input, src_key_padding_mask=mask)
            ref_output = perm_fn(torch.tensor([[[1.695949, -0.357635, -0.893077, -0.445238],
                                                [1.695955, -0.357639, -0.893050, -0.445266]],
                                               [[1.695948, -0.357634, -0.893082, -0.445233],
                                                [1.695950, -0.357635, -0.893077, -0.445238]],
                                               [[1.695951, -0.357636, -0.893069, -0.445246],
                                                [1.695955, -0.357639, -0.893052, -0.445264]],
                                               [[1.695952, -0.357636, -0.893066, -0.445249],
                                                [1.695957, -0.357641, -0.893041, -0.445276]],
                                               [[1.695946, -0.357632, -0.893095, -0.445220],
                                                [1.695952, -0.357637, -0.893065, -0.445251]]]
                                              )).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

            model = nn.TransformerEncoder(encoder_layer, 6, norm=norm,
                                          enable_nested_tensor=enable_nested_tensor).to(device)
            if not training:
                model = model.eval()
            result = model(encoder_input, src_key_padding_mask=mask)
            ref_output = perm_fn(torch.tensor([[[1.695955, -0.357639, -0.893051, -0.445265],
                                                [1.695955, -0.357639, -0.893051, -0.445265]],
                                               [[1.695955, -0.357639, -0.893051, -0.445265],
                                                [1.695955, -0.357639, -0.893051, -0.445265]],
                                               [[1.695955, -0.357639, -0.893051, -0.445265],
                                                [1.695955, -0.357639, -0.893051, -0.445265]],
                                               [[1.695955, -0.357639, -0.893051, -0.445265],
                                                [1.695955, -0.357639, -0.893051, -0.445265]],
                                               [[1.695955, -0.357639, -0.893051, -0.445265],
                                                [1.695955, -0.357639, -0.893051, -0.445265]]]
                                              )).to(device)
            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

        # TODO: remove set default dtype to double by making ref_output more precise.
        # Added because this test was copied from test_nn.py, which has default
        # dtype double. If default dtype is float, tests will say tensors not close because
        # ref output precision too low
        with set_default_dtype(torch.double):
            if training:
                cm = contextlib.nullcontext()
            else:
                cm = torch.no_grad()  # transformer fast path requires no grad
            with cm:
                _test(batch_first, training, enable_nested_tensor)

    @unittest.skipIf(sys.version_info < (3, 11), "not supported on pre-3.11 Python")
    def test_encoder_padding_and_src_mask_bool(self):
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=16,
            nhead=2,
            dim_feedforward=32,
            dropout=0.1,
            activation='relu',
            batch_first=True,
        )
        encoder_norm = nn.LayerNorm(16)
        encoder = nn.TransformerEncoder(
            encoder_layer, 2, encoder_norm
        )

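        # Both masks built below are boolean: src_mask is an upper-triangular causal mask of
        # shape (S, S) where True blocks attention, and padding_mask is derived from per-sequence
        # lengths, e.g. lengths [3, 2] over S=3 give [[False, False, False], [False, False, True]].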
        inputs = torch.randn(2, 3, 16)

        src_mask = torch.ones(3, 3, dtype=torch.bool).triu_(diagonal=1)
        input_seq_len = torch.tensor([3, 2])
        padding_mask = (
            torch.arange(3)[None, :].cpu() >= input_seq_len[:, None]
        )

        with (self.assertNoLogs(None) if not TEST_WITH_TORCHDYNAMO else contextlib.nullcontext()):
            encoder(
                inputs,
                mask=src_mask,
                src_key_padding_mask=padding_mask,
            )

    @unittest.skipIf(sys.version_info < (3, 11), "not supported on pre-3.11 Python")
    def test_decoder_padding_and_src_mask_bool(self):

        def transformer_decoder(inputs, input_seq_len, memory):
            decoder_layer = nn.TransformerDecoderLayer(
                d_model=16,
                nhead=2,
                dim_feedforward=32,
                dropout=0.1,
                activation='relu',
                batch_first=True,
            )
            decoder_norm = nn.LayerNorm(16)
            decoder = nn.TransformerDecoder(
                decoder_layer, 2, decoder_norm
            )

            src_mask = torch.ones(
                inputs.shape[1], inputs.shape[1], dtype=torch.bool
            ).triu_(diagonal=1)
            padding_mask = (
                torch.arange(inputs.shape[1])[None, :].cpu()
                >= input_seq_len[:, None]
            )

            return decoder(
                inputs,
                memory,
                tgt_mask=src_mask,
                tgt_key_padding_mask=padding_mask,
                memory_key_padding_mask=padding_mask,
            )

        inputs = torch.randn(2, 3, 16)
        memory = torch.randn(2, 3, 16)
        input_seq_len = torch.tensor([3, 2])

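        # Passing boolean causal and padding masks should run cleanly; assertNoLogs checks that
        # no mask-related warnings are logged during the decoder forward pass.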
self.assertNoLogs(None): 717*da0073e9SAndroid Build Coastguard Worker transformer_decoder(inputs, input_seq_len, memory) 718*da0073e9SAndroid Build Coastguard Worker 719*da0073e9SAndroid Build Coastguard Worker def test_encoder_is_causal(self): 720*da0073e9SAndroid Build Coastguard Worker 721*da0073e9SAndroid Build Coastguard Worker d_model = 3 722*da0073e9SAndroid Build Coastguard Worker layer = torch.nn.TransformerEncoderLayer(d_model, 1, 6, batch_first=True) 723*da0073e9SAndroid Build Coastguard Worker layer.eval() 724*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 5, d_model) 725*da0073e9SAndroid Build Coastguard Worker unmasked_output = layer(x) 726*da0073e9SAndroid Build Coastguard Worker mask = torch.nn.Transformer.generate_square_subsequent_mask(x.size(1)) 727*da0073e9SAndroid Build Coastguard Worker is_causal_output = layer(x, src_mask=mask, is_causal=True) 728*da0073e9SAndroid Build Coastguard Worker masked_output = layer(x, src_mask=mask) 729*da0073e9SAndroid Build Coastguard Worker 730*da0073e9SAndroid Build Coastguard Worker self.assertEqual(masked_output, is_causal_output) 731*da0073e9SAndroid Build Coastguard Worker 732*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 733*da0073e9SAndroid Build Coastguard Worker @parametrize("nb_heads", [1, 8]) 734*da0073e9SAndroid Build Coastguard Worker @parametrize("bias", [True, False]) 735*da0073e9SAndroid Build Coastguard Worker def test_mha_native_args(self, nb_heads, bias): 736*da0073e9SAndroid Build Coastguard Worker 737*da0073e9SAndroid Build Coastguard Worker B, L, F = 8, 100, 128 738*da0073e9SAndroid Build Coastguard Worker batch_first = True 739*da0073e9SAndroid Build Coastguard Worker fast_path = True 740*da0073e9SAndroid Build Coastguard Worker use_pad_mask = (bias % 2) == 1 741*da0073e9SAndroid Build Coastguard Worker 742*da0073e9SAndroid Build Coastguard Worker mha = nn.MultiheadAttention( 743*da0073e9SAndroid Build Coastguard Worker embed_dim=F, 744*da0073e9SAndroid Build Coastguard Worker num_heads=nb_heads, 745*da0073e9SAndroid Build Coastguard Worker batch_first=batch_first, 746*da0073e9SAndroid Build Coastguard Worker bias=bias 747*da0073e9SAndroid Build Coastguard Worker ).cuda() 748*da0073e9SAndroid Build Coastguard Worker mha.eval() 749*da0073e9SAndroid Build Coastguard Worker 750*da0073e9SAndroid Build Coastguard Worker ctx = torch.no_grad if fast_path else contextlib.nullcontext 751*da0073e9SAndroid Build Coastguard Worker with ctx(): 752*da0073e9SAndroid Build Coastguard Worker x = torch.randn(B, L, F).cuda() 753*da0073e9SAndroid Build Coastguard Worker if not batch_first: 754*da0073e9SAndroid Build Coastguard Worker x = x.transpose(0, 1) 755*da0073e9SAndroid Build Coastguard Worker 756*da0073e9SAndroid Build Coastguard Worker pad_mask = None 757*da0073e9SAndroid Build Coastguard Worker if use_pad_mask: 758*da0073e9SAndroid Build Coastguard Worker pad_mask = torch.zeros((B, L), dtype=torch.bool).cuda() 759*da0073e9SAndroid Build Coastguard Worker 760*da0073e9SAndroid Build Coastguard Worker mha(query=x, key=x, value=x, key_padding_mask=pad_mask) 761*da0073e9SAndroid Build Coastguard Worker 762*da0073e9SAndroid Build Coastguard Worker def test_kpm_mask_trailing_column_with_nested_tensor(self, device): 763*da0073e9SAndroid Build Coastguard Worker encoder_layer = nn.TransformerEncoderLayer( 764*da0073e9SAndroid Build Coastguard Worker d_model=256, 765*da0073e9SAndroid Build Coastguard Worker nhead=4, 766*da0073e9SAndroid Build Coastguard Worker dim_feedforward=512, 767*da0073e9SAndroid Build Coastguard 
Worker activation='gelu', 768*da0073e9SAndroid Build Coastguard Worker norm_first=False, 769*da0073e9SAndroid Build Coastguard Worker batch_first=False, 770*da0073e9SAndroid Build Coastguard Worker ) 771*da0073e9SAndroid Build Coastguard Worker transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=3, enable_nested_tensor=True).to(device) 772*da0073e9SAndroid Build Coastguard Worker 773*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10, 6, 256).to(device) 774*da0073e9SAndroid Build Coastguard Worker mask = torch.ones(6, 10) 775*da0073e9SAndroid Build Coastguard Worker mask[0, :] = 0 # unmask the first batch element entirely; the remaining rows stay fully masked 776*da0073e9SAndroid Build Coastguard Worker mask = mask.bool().to(device) 777*da0073e9SAndroid Build Coastguard Worker out = transformer_encoder(src=x, src_key_padding_mask=mask) 778*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.shape[1], 6) 779*da0073e9SAndroid Build Coastguard Worker 780*da0073e9SAndroid Build Coastguard Worker # CPU unit test has_torch_functions in test environment, 781*da0073e9SAndroid Build Coastguard Worker # preventing successful completion 782*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 783*da0073e9SAndroid Build Coastguard Worker def test_with_nested_tensor_input(self, device): 784*da0073e9SAndroid Build Coastguard Worker encoder_layer = nn.TransformerEncoderLayer( 785*da0073e9SAndroid Build Coastguard Worker d_model=256, 786*da0073e9SAndroid Build Coastguard Worker nhead=4, 787*da0073e9SAndroid Build Coastguard Worker dim_feedforward=512, 788*da0073e9SAndroid Build Coastguard Worker activation='gelu', 789*da0073e9SAndroid Build Coastguard Worker norm_first=False, 790*da0073e9SAndroid Build Coastguard Worker batch_first=True, 791*da0073e9SAndroid Build Coastguard Worker ) 792*da0073e9SAndroid Build Coastguard Worker transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=3, enable_nested_tensor=True).to(device) 793*da0073e9SAndroid Build Coastguard Worker 794*da0073e9SAndroid Build Coastguard Worker transformer_encoder.eval() 795*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 796*da0073e9SAndroid Build Coastguard Worker x = torch.randn(6, 10, 256).to(device) 797*da0073e9SAndroid Build Coastguard Worker mask = torch.ones(6, 10) 798*da0073e9SAndroid Build Coastguard Worker mask[0, 0:] = 0 # batch element 0 keeps all 10 positions (mask is inverted below) 799*da0073e9SAndroid Build Coastguard Worker mask[2, 2:] = 0 # batch element 2 keeps positions 2 onwards 800*da0073e9SAndroid Build Coastguard Worker mask[4, 4:] = 0 # batch element 4 keeps positions 4 onwards 801*da0073e9SAndroid Build Coastguard Worker mask[5, 8:] = 0 # batch element 5 keeps positions 8 onwards 802*da0073e9SAndroid Build Coastguard Worker mask = mask.bool().to(device) 803*da0073e9SAndroid Build Coastguard Worker x = torch._nested_tensor_from_mask(x, mask.logical_not(), mask_check=False) 804*da0073e9SAndroid Build Coastguard Worker out = transformer_encoder(src=x, src_key_padding_mask=None) 805*da0073e9SAndroid Build Coastguard Worker 806*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.is_nested, True) 807*da0073e9SAndroid Build Coastguard Worker 808*da0073e9SAndroid Build Coastguard Worker 809*da0073e9SAndroid Build Coastguard Worker 810*da0073e9SAndroid Build Coastguard Worker def test_script_encoder_subclass(self, device): 811*da0073e9SAndroid Build Coastguard Worker class MyCustomLayer(nn.TransformerEncoderLayer): 812*da0073e9SAndroid Build Coastguard Worker pass 813*da0073e9SAndroid Build
Coastguard Worker 814*da0073e9SAndroid Build Coastguard Worker encoder = nn.TransformerEncoder( 815*da0073e9SAndroid Build Coastguard Worker MyCustomLayer(d_model=256, nhead=8), num_layers=6 816*da0073e9SAndroid Build Coastguard Worker ).to(device=device) 817*da0073e9SAndroid Build Coastguard Worker torch.jit.script(encoder) 818*da0073e9SAndroid Build Coastguard Worker 819*da0073e9SAndroid Build Coastguard Worker # brazenly adapted from test_transformerencoderlayer_src_mask to test execution of 820*da0073e9SAndroid Build Coastguard Worker # torchscripted transformerencoderlayer subclass 821*da0073e9SAndroid Build Coastguard Worker def test_transformerencoderlayer_subclass(self, device): 822*da0073e9SAndroid Build Coastguard Worker class MyCustomLayer(nn.TransformerEncoderLayer): 823*da0073e9SAndroid Build Coastguard Worker pass 824*da0073e9SAndroid Build Coastguard Worker 825*da0073e9SAndroid Build Coastguard Worker nhead = 4 826*da0073e9SAndroid Build Coastguard Worker batch_size = 2 827*da0073e9SAndroid Build Coastguard Worker seqlen = 4 828*da0073e9SAndroid Build Coastguard Worker d_model = 8 829*da0073e9SAndroid Build Coastguard Worker dim_feedforward = 32 830*da0073e9SAndroid Build Coastguard Worker 831*da0073e9SAndroid Build Coastguard Worker model = MyCustomLayer( 832*da0073e9SAndroid Build Coastguard Worker d_model=d_model, 833*da0073e9SAndroid Build Coastguard Worker nhead=nhead, 834*da0073e9SAndroid Build Coastguard Worker dim_feedforward=dim_feedforward, 835*da0073e9SAndroid Build Coastguard Worker batch_first=True).to(device) 836*da0073e9SAndroid Build Coastguard Worker script_model = torch.jit.script(model) 837*da0073e9SAndroid Build Coastguard Worker 838*da0073e9SAndroid Build Coastguard Worker src = torch.rand(batch_size, seqlen, d_model).to(device) # bs, seqlen, d_model 839*da0073e9SAndroid Build Coastguard Worker src_mask = torch.zeros(seqlen, seqlen).to(torch.bool).to(device) 840*da0073e9SAndroid Build Coastguard Worker 841*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(42) 842*da0073e9SAndroid Build Coastguard Worker result = model(src, src_mask=src_mask) 843*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(42) 844*da0073e9SAndroid Build Coastguard Worker scripted_result = script_model(src, src_mask=src_mask) 845*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, scripted_result) 846*da0073e9SAndroid Build Coastguard Worker 847*da0073e9SAndroid Build Coastguard Worker model.eval() 848*da0073e9SAndroid Build Coastguard Worker script_model = torch.jit.script(model) 849*da0073e9SAndroid Build Coastguard Worker 850*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 851*da0073e9SAndroid Build Coastguard Worker result = model(src, src_mask=src_mask) 852*da0073e9SAndroid Build Coastguard Worker scripted_result = script_model(src, src_mask=src_mask) 853*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, scripted_result) 854*da0073e9SAndroid Build Coastguard Worker 855*da0073e9SAndroid Build Coastguard Worker 856*da0073e9SAndroid Build Coastguard Worker def test_transformerencoderlayer_subclass_model(self, device): 857*da0073e9SAndroid Build Coastguard Worker class MyCustomLayer(nn.TransformerEncoderLayer): 858*da0073e9SAndroid Build Coastguard Worker pass 859*da0073e9SAndroid Build Coastguard Worker 860*da0073e9SAndroid Build Coastguard Worker nhead = 4 861*da0073e9SAndroid Build Coastguard Worker batch_size = 2 862*da0073e9SAndroid Build Coastguard Worker seqlen = 4 863*da0073e9SAndroid Build Coastguard Worker 
d_model = 8 864*da0073e9SAndroid Build Coastguard Worker dim_feedforward = 32 865*da0073e9SAndroid Build Coastguard Worker 866*da0073e9SAndroid Build Coastguard Worker layer = MyCustomLayer( 867*da0073e9SAndroid Build Coastguard Worker d_model=d_model, 868*da0073e9SAndroid Build Coastguard Worker nhead=nhead, 869*da0073e9SAndroid Build Coastguard Worker dim_feedforward=dim_feedforward, 870*da0073e9SAndroid Build Coastguard Worker batch_first=True) 871*da0073e9SAndroid Build Coastguard Worker model = nn.TransformerEncoder( 872*da0073e9SAndroid Build Coastguard Worker layer, num_layers=6 873*da0073e9SAndroid Build Coastguard Worker ).to(device=device) 874*da0073e9SAndroid Build Coastguard Worker script_model = torch.jit.script(model) 875*da0073e9SAndroid Build Coastguard Worker 876*da0073e9SAndroid Build Coastguard Worker src = torch.rand(batch_size, seqlen, d_model).to(device) # bs, seqlen, d_model 877*da0073e9SAndroid Build Coastguard Worker src_mask = torch.zeros(seqlen, seqlen).to(torch.bool).to(device) 878*da0073e9SAndroid Build Coastguard Worker 879*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(42) 880*da0073e9SAndroid Build Coastguard Worker result = model(src, mask=src_mask) 881*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(42) 882*da0073e9SAndroid Build Coastguard Worker scripted_result = script_model(src, mask=src_mask) 883*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, scripted_result) 884*da0073e9SAndroid Build Coastguard Worker 885*da0073e9SAndroid Build Coastguard Worker model.eval() 886*da0073e9SAndroid Build Coastguard Worker script_model = torch.jit.script(model) 887*da0073e9SAndroid Build Coastguard Worker 888*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 889*da0073e9SAndroid Build Coastguard Worker result = model(src, mask=src_mask) 890*da0073e9SAndroid Build Coastguard Worker scripted_result = script_model(src, mask=src_mask) 891*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, scripted_result) 892*da0073e9SAndroid Build Coastguard Worker 893*da0073e9SAndroid Build Coastguard Worker 894*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 895*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_FAIRSEQ, "Fairseq not found") 896*da0073e9SAndroid Build Coastguard Worker def test_decoder_only_layer(self): 897*da0073e9SAndroid Build Coastguard Worker DEFAULT_PADDING_IDX = 0 898*da0073e9SAndroid Build Coastguard Worker 899*da0073e9SAndroid Build Coastguard Worker class FairseqDecoder(torch.nn.Module): 900*da0073e9SAndroid Build Coastguard Worker def __init__( 901*da0073e9SAndroid Build Coastguard Worker self, 902*da0073e9SAndroid Build Coastguard Worker embed_dim, 903*da0073e9SAndroid Build Coastguard Worker attention_heads, 904*da0073e9SAndroid Build Coastguard Worker ffn_embed_dim, 905*da0073e9SAndroid Build Coastguard Worker num_layers, 906*da0073e9SAndroid Build Coastguard Worker embedding_layer, # torch.nn.Embedding. 
Must have a padding_idx field 907*da0073e9SAndroid Build Coastguard Worker dropout=0, 908*da0073e9SAndroid Build Coastguard Worker normalize_before=False, 909*da0073e9SAndroid Build Coastguard Worker torch_encoder=None, # torch encoder that you can map weights from 910*da0073e9SAndroid Build Coastguard Worker activation="relu", 911*da0073e9SAndroid Build Coastguard Worker ): 912*da0073e9SAndroid Build Coastguard Worker super().__init__() 913*da0073e9SAndroid Build Coastguard Worker 914*da0073e9SAndroid Build Coastguard Worker cfg = fairseq_transformer.TransformerConfig() 915*da0073e9SAndroid Build Coastguard Worker cfg.decoder.embed_dim = embed_dim 916*da0073e9SAndroid Build Coastguard Worker cfg.decoder.output_dim = embed_dim 917*da0073e9SAndroid Build Coastguard Worker cfg.decoder.attention_heads = attention_heads 918*da0073e9SAndroid Build Coastguard Worker cfg.decoder.ffn_embed_dim = ffn_embed_dim 919*da0073e9SAndroid Build Coastguard Worker cfg.dropout = dropout 920*da0073e9SAndroid Build Coastguard Worker cfg.decoder.normalize_before = normalize_before 921*da0073e9SAndroid Build Coastguard Worker cfg.decoder.layers = num_layers 922*da0073e9SAndroid Build Coastguard Worker # make embedding behavior same as other encoders 923*da0073e9SAndroid Build Coastguard Worker cfg.no_token_positional_embeddings = True 924*da0073e9SAndroid Build Coastguard Worker cfg.no_scale_embedding = True 925*da0073e9SAndroid Build Coastguard Worker cfg.activation_fn = activation 926*da0073e9SAndroid Build Coastguard Worker 927*da0073e9SAndroid Build Coastguard Worker dictionary = {} # TODO: verify what this is 928*da0073e9SAndroid Build Coastguard Worker 929*da0073e9SAndroid Build Coastguard Worker self.decoder = fairseq_transformer.TransformerDecoder( 930*da0073e9SAndroid Build Coastguard Worker cfg, 931*da0073e9SAndroid Build Coastguard Worker dictionary, 932*da0073e9SAndroid Build Coastguard Worker embedding_layer, 933*da0073e9SAndroid Build Coastguard Worker no_encoder_attn=True, 934*da0073e9SAndroid Build Coastguard Worker output_projection=None, 935*da0073e9SAndroid Build Coastguard Worker ) 936*da0073e9SAndroid Build Coastguard Worker 937*da0073e9SAndroid Build Coastguard Worker if torch_encoder is not None: 938*da0073e9SAndroid Build Coastguard Worker self.decoder = torch_to_fairseq(torch_encoder, self.decoder) # noqa: F821 939*da0073e9SAndroid Build Coastguard Worker self.decoder = self.decoder.eval().cuda().half() 940*da0073e9SAndroid Build Coastguard Worker 941*da0073e9SAndroid Build Coastguard Worker def forward( 942*da0073e9SAndroid Build Coastguard Worker self, 943*da0073e9SAndroid Build Coastguard Worker tokens, 944*da0073e9SAndroid Build Coastguard Worker src_lengths=None, 945*da0073e9SAndroid Build Coastguard Worker with_triangle_mask=False, 946*da0073e9SAndroid Build Coastguard Worker incremental_state=None, 947*da0073e9SAndroid Build Coastguard Worker ): 948*da0073e9SAndroid Build Coastguard Worker return self.decoder( 949*da0073e9SAndroid Build Coastguard Worker prev_output_tokens=tokens, 950*da0073e9SAndroid Build Coastguard Worker encoder_out=None, 951*da0073e9SAndroid Build Coastguard Worker incremental_state=incremental_state, 952*da0073e9SAndroid Build Coastguard Worker features_only=True, 953*da0073e9SAndroid Build Coastguard Worker full_context_alignment=not with_triangle_mask, 954*da0073e9SAndroid Build Coastguard Worker alignment_layer=None, 955*da0073e9SAndroid Build Coastguard Worker alignment_heads=None, 956*da0073e9SAndroid Build Coastguard Worker src_lengths=src_lengths, 
957*da0073e9SAndroid Build Coastguard Worker return_all_hiddens=False, 958*da0073e9SAndroid Build Coastguard Worker )[0] 959*da0073e9SAndroid Build Coastguard Worker 960*da0073e9SAndroid Build Coastguard Worker @parametrize("input_dim,attn_mask_dim,is_causal", 961*da0073e9SAndroid Build Coastguard Worker [(3, None, False), (3, 2, False), (3, 2, True), (3, 3, False), (3, 3, True), 962*da0073e9SAndroid Build Coastguard Worker (4, None, False), (4, 2, False), (4, 2, True), (4, 4, False), (4, 4, True)], 963*da0073e9SAndroid Build Coastguard Worker name_fn=lambda input_dim, attn_dim, is_causal: ( 964*da0073e9SAndroid Build Coastguard Worker f"{input_dim}D_input_dim_" + ( 965*da0073e9SAndroid Build Coastguard Worker f"{attn_dim}D_{'causal_' if is_causal else ''}attn_mask" 966*da0073e9SAndroid Build Coastguard Worker if attn_dim is not None else "no_attn_mask"))) 967*da0073e9SAndroid Build Coastguard Worker @parametrize("dropout_p", [0.0, 0.2, 0.5]) 968*da0073e9SAndroid Build Coastguard Worker @sdpa_kernel(backends=[SDPBackend.MATH]) 969*da0073e9SAndroid Build Coastguard Worker def test_scaled_dot_product_attention(self, device, input_dim, attn_mask_dim, is_causal, dropout_p): 970*da0073e9SAndroid Build Coastguard Worker def sdp_ref( 971*da0073e9SAndroid Build Coastguard Worker q, 972*da0073e9SAndroid Build Coastguard Worker k, 973*da0073e9SAndroid Build Coastguard Worker v, 974*da0073e9SAndroid Build Coastguard Worker attn_mask=None, 975*da0073e9SAndroid Build Coastguard Worker dropout_p=0.0): 976*da0073e9SAndroid Build Coastguard Worker E = q.size(-1) 977*da0073e9SAndroid Build Coastguard Worker q = q / math.sqrt(E) 978*da0073e9SAndroid Build Coastguard Worker # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns) 979*da0073e9SAndroid Build Coastguard Worker if attn_mask is not None: 980*da0073e9SAndroid Build Coastguard Worker attn = torch.baddbmm(attn_mask, q, k.transpose(-2, -1)) 981*da0073e9SAndroid Build Coastguard Worker else: 982*da0073e9SAndroid Build Coastguard Worker attn = torch.bmm(q, k.transpose(-2, -1)) 983*da0073e9SAndroid Build Coastguard Worker 984*da0073e9SAndroid Build Coastguard Worker attn = torch.nn.functional.softmax(attn, dim=-1) 985*da0073e9SAndroid Build Coastguard Worker if dropout_p > 0.0: 986*da0073e9SAndroid Build Coastguard Worker attn = torch.nn.functional.dropout(attn, p=dropout_p) 987*da0073e9SAndroid Build Coastguard Worker # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E) 988*da0073e9SAndroid Build Coastguard Worker output = torch.bmm(attn, v) 989*da0073e9SAndroid Build Coastguard Worker return output 990*da0073e9SAndroid Build Coastguard Worker # TODO: Support cross-device / dtype testing properly when instantiate_device_type_tests() is used. 991*da0073e9SAndroid Build Coastguard Worker dtypes = [torch.double, torch.float] 992*da0073e9SAndroid Build Coastguard Worker for dtype in dtypes: 993*da0073e9SAndroid Build Coastguard Worker 994*da0073e9SAndroid Build Coastguard Worker def rand_tensor(*shape): 995*da0073e9SAndroid Build Coastguard Worker return torch.randn(shape, device=device, dtype=dtype) 996*da0073e9SAndroid Build Coastguard Worker 997*da0073e9SAndroid Build Coastguard Worker # This test compares python and C++ implementations of SDP. 
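            # Reading aid for sdp_ref above: ignoring dropout, it computes the standard
            # scaled dot-product attention,
            #   softmax(q @ k.transpose(-2, -1) / sqrt(E) + attn_mask) @ v
            # where the boolean mask (when present) is first converted into an additive
            # float bias and folded in via baddbmm. This is the reference that
            # scaled_dot_product_attention is checked against below.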
998*da0073e9SAndroid Build Coastguard Worker N, N_prime, L, S, E = 5, 2, 4, 3, 6 999*da0073e9SAndroid Build Coastguard Worker if input_dim == 3: 1000*da0073e9SAndroid Build Coastguard Worker query = rand_tensor(N, L, E) 1001*da0073e9SAndroid Build Coastguard Worker key = rand_tensor(N, S, E) 1002*da0073e9SAndroid Build Coastguard Worker value = rand_tensor(N, S, E) 1003*da0073e9SAndroid Build Coastguard Worker elif input_dim == 4: 1004*da0073e9SAndroid Build Coastguard Worker query = rand_tensor(N, N_prime, L, E) 1005*da0073e9SAndroid Build Coastguard Worker key = rand_tensor(N, N_prime, S, E) 1006*da0073e9SAndroid Build Coastguard Worker value = rand_tensor(N, N_prime, S, E) 1007*da0073e9SAndroid Build Coastguard Worker else: 1008*da0073e9SAndroid Build Coastguard Worker self.fail(f'Invalid input_dim {input_dim} encountered in SDP test') 1009*da0073e9SAndroid Build Coastguard Worker 1010*da0073e9SAndroid Build Coastguard Worker attn_mask = None 1011*da0073e9SAndroid Build Coastguard Worker if attn_mask_dim is not None: 1012*da0073e9SAndroid Build Coastguard Worker assert attn_mask_dim in [2, input_dim] 1013*da0073e9SAndroid Build Coastguard Worker mask_size = (L, S) if attn_mask_dim == 2 else ((N, L, S) if input_dim == 3 else (N, N_prime, L, S)) 1014*da0073e9SAndroid Build Coastguard Worker attn_mask = (torch.ones(mask_size, device=device, dtype=torch.bool).tril() if is_causal 1015*da0073e9SAndroid Build Coastguard Worker else torch.randint(0, 2, size=mask_size, device=device, dtype=torch.bool)) 1016*da0073e9SAndroid Build Coastguard Worker 1017*da0073e9SAndroid Build Coastguard Worker with freeze_rng_state(): 1018*da0073e9SAndroid Build Coastguard Worker # Python impl only supports float mask and 3D inputs. 1019*da0073e9SAndroid Build Coastguard Worker attn_mask_float = attn_mask 1020*da0073e9SAndroid Build Coastguard Worker if attn_mask_float is not None: 1021*da0073e9SAndroid Build Coastguard Worker attn_mask_float = torch.zeros_like(attn_mask, dtype=query.dtype) 1022*da0073e9SAndroid Build Coastguard Worker attn_mask_float.masked_fill_(attn_mask.logical_not(), float("-inf")) 1023*da0073e9SAndroid Build Coastguard Worker q, k, v = query.view(-1, L, E), key.view(-1, S, E), value.view(-1, S, E) 1024*da0073e9SAndroid Build Coastguard Worker a = attn_mask_float 1025*da0073e9SAndroid Build Coastguard Worker if a is not None and attn_mask_dim > 3: 1026*da0073e9SAndroid Build Coastguard Worker a = a.view(-1, L, S) 1027*da0073e9SAndroid Build Coastguard Worker expected = sdp_ref(q, k, v, attn_mask=a, dropout_p=dropout_p) 1028*da0073e9SAndroid Build Coastguard Worker if input_dim > 3: 1029*da0073e9SAndroid Build Coastguard Worker expected = expected.view(-1, N_prime, L, E) 1030*da0073e9SAndroid Build Coastguard Worker 1031*da0073e9SAndroid Build Coastguard Worker with freeze_rng_state(): 1032*da0073e9SAndroid Build Coastguard Worker if is_causal: 1033*da0073e9SAndroid Build Coastguard Worker # NB: Don't pass attn_mask here 1034*da0073e9SAndroid Build Coastguard Worker actual = torch.nn.functional.scaled_dot_product_attention( 1035*da0073e9SAndroid Build Coastguard Worker query, key, value, None, dropout_p, is_causal) 1036*da0073e9SAndroid Build Coastguard Worker 1037*da0073e9SAndroid Build Coastguard Worker # Error case: both explicit attn_mask and is_causal are set 1038*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 1039*da0073e9SAndroid Build Coastguard Worker "Explicit attn_mask should not be set when is_causal=True"): 1040*da0073e9SAndroid Build 
Coastguard Worker torch.nn.functional.scaled_dot_product_attention( 1041*da0073e9SAndroid Build Coastguard Worker query, key, value, attn_mask, dropout_p, is_causal) 1042*da0073e9SAndroid Build Coastguard Worker else: 1043*da0073e9SAndroid Build Coastguard Worker actual = torch.nn.functional.scaled_dot_product_attention( 1044*da0073e9SAndroid Build Coastguard Worker query, key, value, attn_mask, dropout_p, is_causal) 1045*da0073e9SAndroid Build Coastguard Worker 1046*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected) 1047*da0073e9SAndroid Build Coastguard Worker 1048*da0073e9SAndroid Build Coastguard Worker if attn_mask_dim is None: 1049*da0073e9SAndroid Build Coastguard Worker q = q.double().clone() 1050*da0073e9SAndroid Build Coastguard Worker k = k.double().clone() 1051*da0073e9SAndroid Build Coastguard Worker v = v.double().clone() 1052*da0073e9SAndroid Build Coastguard Worker q.requires_grad_() 1053*da0073e9SAndroid Build Coastguard Worker k.requires_grad_() 1054*da0073e9SAndroid Build Coastguard Worker v.requires_grad_() 1055*da0073e9SAndroid Build Coastguard Worker 1056*da0073e9SAndroid Build Coastguard Worker assert gradcheck(lambda *args, **kwargs: wrapper_set_seed(sdp_ref, *args, **kwargs), 1057*da0073e9SAndroid Build Coastguard Worker (q, k, v, attn_mask, dropout_p)) 1058*da0073e9SAndroid Build Coastguard Worker assert gradcheck(lambda *args, **kwargs: 1059*da0073e9SAndroid Build Coastguard Worker wrapper_set_seed(torch.nn.functional.scaled_dot_product_attention, *args, **kwargs), 1060*da0073e9SAndroid Build Coastguard Worker (q, k, v, attn_mask, dropout_p)) 1061*da0073e9SAndroid Build Coastguard Worker 1062*da0073e9SAndroid Build Coastguard Worker def test_incompatible_mask(self, device): 1063*da0073e9SAndroid Build Coastguard Worker def ones_tensor(*shape): 1064*da0073e9SAndroid Build Coastguard Worker return torch.ones(shape, dtype=torch.float32) 1065*da0073e9SAndroid Build Coastguard Worker S, L, E, H = 1, 2, 4, 1 1066*da0073e9SAndroid Build Coastguard Worker qkv = ones_tensor(S, L, E) 1067*da0073e9SAndroid Build Coastguard Worker 1068*da0073e9SAndroid Build Coastguard Worker mha = nn.MultiheadAttention(E, H) 1069*da0073e9SAndroid Build Coastguard Worker mha.in_proj_weight = Parameter(torch.ones((E * 3, E))) 1070*da0073e9SAndroid Build Coastguard Worker mha.out_proj.weight = Parameter(torch.ones((E, E))) 1071*da0073e9SAndroid Build Coastguard Worker qkv = qkv.to(float) 1072*da0073e9SAndroid Build Coastguard Worker kpm = ones_tensor(S, L) * float("-inf") 1073*da0073e9SAndroid Build Coastguard Worker am = ones_tensor(L, L).to(bool) 1074*da0073e9SAndroid Build Coastguard Worker 1075*da0073e9SAndroid Build Coastguard Worker def func(): 1076*da0073e9SAndroid Build Coastguard Worker return mha(qkv, qkv, qkv, need_weights=False, key_padding_mask=kpm, attn_mask=am) 1077*da0073e9SAndroid Build Coastguard Worker 1078*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, func) 1079*da0073e9SAndroid Build Coastguard Worker 1080*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(TEST_WITH_CROSSREF, 'Fastpath not available with crossref') 1081*da0073e9SAndroid Build Coastguard Worker @torch.no_grad() 1082*da0073e9SAndroid Build Coastguard Worker def test_mask_check_fastpath(self): 1083*da0073e9SAndroid Build Coastguard Worker """ 1084*da0073e9SAndroid Build Coastguard Worker Test that fastpath is executed independently of the masks that are passed. 
1085*da0073e9SAndroid Build Coastguard Worker If the passed key padding mask is left aligned or mask_check=False, test that nested tensors are used 1086*da0073e9SAndroid Build Coastguard Worker (sparsity fastpath), otherwise use fastpath with traditional tensors. 1087*da0073e9SAndroid Build Coastguard Worker Also test that fast path is executed with both key padding mask and attention mask passed at the same time. 1088*da0073e9SAndroid Build Coastguard Worker """ 1089*da0073e9SAndroid Build Coastguard Worker 1090*da0073e9SAndroid Build Coastguard Worker x = torch.Tensor([[[1, 2], [3, 4], [5, 6]]]).to(torch.float) 1091*da0073e9SAndroid Build Coastguard Worker 1092*da0073e9SAndroid Build Coastguard Worker def _test_fastpath(model, key_padding_mask, mock_return_value, attn_mask=None, nested_tensors=True): 1093*da0073e9SAndroid Build Coastguard Worker with patch('torch._transformer_encoder_layer_fwd') as fastpath_mock: 1094*da0073e9SAndroid Build Coastguard Worker fastpath_mock.return_value = mock_return_value 1095*da0073e9SAndroid Build Coastguard Worker model(x, src_key_padding_mask=key_padding_mask, mask=attn_mask) 1096*da0073e9SAndroid Build Coastguard Worker 1097*da0073e9SAndroid Build Coastguard Worker # If mock was called, fastpath was taken 1098*da0073e9SAndroid Build Coastguard Worker self.assertTrue(fastpath_mock.called) 1099*da0073e9SAndroid Build Coastguard Worker 1100*da0073e9SAndroid Build Coastguard Worker # If mock was called with nested tensors, sparsity fastpath was taken 1101*da0073e9SAndroid Build Coastguard Worker for call_args, _ in fastpath_mock.call_args_list: 1102*da0073e9SAndroid Build Coastguard Worker self.assertEqual(call_args[0].is_nested, nested_tensors) 1103*da0073e9SAndroid Build Coastguard Worker 1104*da0073e9SAndroid Build Coastguard Worker encoder_layer = torch.nn.TransformerEncoderLayer(d_model=2, nhead=2, dim_feedforward=8, batch_first=True) 1105*da0073e9SAndroid Build Coastguard Worker 1106*da0073e9SAndroid Build Coastguard Worker model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=True, mask_check=True) 1107*da0073e9SAndroid Build Coastguard Worker model.eval() 1108*da0073e9SAndroid Build Coastguard Worker 1109*da0073e9SAndroid Build Coastguard Worker aligned_key_padding_mask = torch.Tensor([[0, 0, 1]]).to(torch.bool) 1110*da0073e9SAndroid Build Coastguard Worker not_aligned_key_padding_mask = torch.Tensor([[1, 0, 1]]).to(torch.bool) 1111*da0073e9SAndroid Build Coastguard Worker attn_mask = torch.Tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]]).to(torch.bool) 1112*da0073e9SAndroid Build Coastguard Worker nested_tensor_return_value = torch.nested.nested_tensor([torch.ones((2, 2), dtype=torch.float)]) 1113*da0073e9SAndroid Build Coastguard Worker tensor_return_value = torch.ones((1, 3, 2), dtype=torch.float) 1114*da0073e9SAndroid Build Coastguard Worker 1115*da0073e9SAndroid Build Coastguard Worker # Left aligned mask results in sparsity fastpath 1116*da0073e9SAndroid Build Coastguard Worker _test_fastpath(model, aligned_key_padding_mask, nested_tensor_return_value, nested_tensors=True) 1117*da0073e9SAndroid Build Coastguard Worker 1118*da0073e9SAndroid Build Coastguard Worker # Not aligned mask results in fastpath 1119*da0073e9SAndroid Build Coastguard Worker _test_fastpath(model, not_aligned_key_padding_mask, tensor_return_value, nested_tensors=False) 1120*da0073e9SAndroid Build Coastguard Worker 1121*da0073e9SAndroid Build Coastguard Worker model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, 
enable_nested_tensor=False, mask_check=True) 1122*da0073e9SAndroid Build Coastguard Worker model.eval() 1123*da0073e9SAndroid Build Coastguard Worker 1124*da0073e9SAndroid Build Coastguard Worker # If nested tensor disabled, fastpath is always taken 1125*da0073e9SAndroid Build Coastguard Worker _test_fastpath(model, aligned_key_padding_mask, tensor_return_value, nested_tensors=False) 1126*da0073e9SAndroid Build Coastguard Worker _test_fastpath(model, not_aligned_key_padding_mask, tensor_return_value, nested_tensors=False) 1127*da0073e9SAndroid Build Coastguard Worker # Fast path is taken if both attention mask and key padding mask are present 1128*da0073e9SAndroid Build Coastguard Worker _test_fastpath(model, aligned_key_padding_mask, tensor_return_value, attn_mask=attn_mask, nested_tensors=False) 1129*da0073e9SAndroid Build Coastguard Worker 1130*da0073e9SAndroid Build Coastguard Worker model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=True, mask_check=False) 1131*da0073e9SAndroid Build Coastguard Worker model.eval() 1132*da0073e9SAndroid Build Coastguard Worker 1133*da0073e9SAndroid Build Coastguard Worker # Mask check disabled results in sparsity fastpath, independently of the mask 1134*da0073e9SAndroid Build Coastguard Worker _test_fastpath(model, aligned_key_padding_mask, nested_tensor_return_value, nested_tensors=True) 1135*da0073e9SAndroid Build Coastguard Worker _test_fastpath(model, not_aligned_key_padding_mask, nested_tensor_return_value, nested_tensors=True) 1136*da0073e9SAndroid Build Coastguard Worker 1137*da0073e9SAndroid Build Coastguard Worker # Test that MHA no longer fails when bias is None (bias=False) 1138*da0073e9SAndroid Build Coastguard Worker def test_bias_is_none(self): 1139*da0073e9SAndroid Build Coastguard Worker x = torch.rand((1, 5, 10)) 1140*da0073e9SAndroid Build Coastguard Worker model = torch.nn.modules.activation.MultiheadAttention(10, 1, bias=False, batch_first=True) 1141*da0073e9SAndroid Build Coastguard Worker model.eval() 1142*da0073e9SAndroid Build Coastguard Worker model(x, x, x) 1143*da0073e9SAndroid Build Coastguard Worker # completes without error 1144*da0073e9SAndroid Build Coastguard Worker 1145*da0073e9SAndroid Build Coastguard Worker def test_transformer_bias_is_none(self, device): 1146*da0073e9SAndroid Build Coastguard Worker batch_size = 2 1147*da0073e9SAndroid Build Coastguard Worker seqlen = 3 1148*da0073e9SAndroid Build Coastguard Worker d_model = 8 1149*da0073e9SAndroid Build Coastguard Worker nhead = 4 1150*da0073e9SAndroid Build Coastguard Worker 1151*da0073e9SAndroid Build Coastguard Worker encoder_layer = torch.nn.TransformerEncoderLayer(d_model, nhead, bias=False, batch_first=True, device=device) 1152*da0073e9SAndroid Build Coastguard Worker encoder_layer.eval() 1153*da0073e9SAndroid Build Coastguard Worker x = torch.randn(batch_size, seqlen, d_model, device=device) 1154*da0073e9SAndroid Build Coastguard Worker # runs without error 1155*da0073e9SAndroid Build Coastguard Worker encoder_layer(x) 1156*da0073e9SAndroid Build Coastguard Worker 1157*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex(UserWarning, "encoder_layer.self_attn was passed bias=False"): 1158*da0073e9SAndroid Build Coastguard Worker encoder = torch.nn.TransformerEncoder(encoder_layer, num_layers=1).eval() 1159*da0073e9SAndroid Build Coastguard Worker encoder(x) 1160*da0073e9SAndroid Build Coastguard Worker 1161*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex(UserWarning, "self_attn was passed bias=False"):
1162*da0073e9SAndroid Build Coastguard Worker transformer = torch.nn.Transformer( 1163*da0073e9SAndroid Build Coastguard Worker d_model=d_model, nhead=nhead, bias=False, batch_first=True, device=device 1164*da0073e9SAndroid Build Coastguard Worker ).eval() 1165*da0073e9SAndroid Build Coastguard Worker transformer(x, x) 1166*da0073e9SAndroid Build Coastguard Worker 1167*da0073e9SAndroid Build Coastguard Worker def test_train_with_is_causal(self, device): 1168*da0073e9SAndroid Build Coastguard Worker # training with is_causal 1169*da0073e9SAndroid Build Coastguard Worker S, L, E, H = 1, 2, 2, 1 1170*da0073e9SAndroid Build Coastguard Worker layer = nn.TransformerEncoderLayer( 1171*da0073e9SAndroid Build Coastguard Worker d_model=2, 1172*da0073e9SAndroid Build Coastguard Worker dim_feedforward=4, 1173*da0073e9SAndroid Build Coastguard Worker nhead=H, 1174*da0073e9SAndroid Build Coastguard Worker batch_first=True, 1175*da0073e9SAndroid Build Coastguard Worker activation="gelu", 1176*da0073e9SAndroid Build Coastguard Worker dropout=0, 1177*da0073e9SAndroid Build Coastguard Worker ) 1178*da0073e9SAndroid Build Coastguard Worker criterion = nn.MSELoss() 1179*da0073e9SAndroid Build Coastguard Worker encoder = nn.TransformerEncoder(layer, 2).to(device) 1180*da0073e9SAndroid Build Coastguard Worker optimizer = optim.SGD(encoder.parameters(), lr=0.1, momentum=0.9) 1181*da0073e9SAndroid Build Coastguard Worker encoder.train() 1182*da0073e9SAndroid Build Coastguard Worker 1183*da0073e9SAndroid Build Coastguard Worker encoder.train() 1184*da0073e9SAndroid Build Coastguard Worker optimizer.zero_grad() 1185*da0073e9SAndroid Build Coastguard Worker inputs = torch.randn(S, L, E).to(device) 1186*da0073e9SAndroid Build Coastguard Worker mask = torch.nn.Transformer.generate_square_subsequent_mask( 1187*da0073e9SAndroid Build Coastguard Worker inputs.size(1), device=device 1188*da0073e9SAndroid Build Coastguard Worker ) 1189*da0073e9SAndroid Build Coastguard Worker 1190*da0073e9SAndroid Build Coastguard Worker outputs = encoder(inputs, mask=mask, is_causal=True) 1191*da0073e9SAndroid Build Coastguard Worker 1192*da0073e9SAndroid Build Coastguard Worker loss = criterion(outputs[:, 0:2, :], inputs[:, 0:2, :]) 1193*da0073e9SAndroid Build Coastguard Worker loss.backward() 1194*da0073e9SAndroid Build Coastguard Worker optimizer.step() 1195*da0073e9SAndroid Build Coastguard Worker 1196*da0073e9SAndroid Build Coastguard Worker # inference with is_causal 1197*da0073e9SAndroid Build Coastguard Worker t_qvk = torch.randn((S, L, E), device=device, dtype=torch.float32) 1198*da0073e9SAndroid Build Coastguard Worker mha = nn.MultiheadAttention(E, H).to(device) 1199*da0073e9SAndroid Build Coastguard Worker mask = torch.nn.Transformer.generate_square_subsequent_mask( 1200*da0073e9SAndroid Build Coastguard Worker S, device=device 1201*da0073e9SAndroid Build Coastguard Worker ) 1202*da0073e9SAndroid Build Coastguard Worker 1203*da0073e9SAndroid Build Coastguard Worker attn_out, _ = mha(t_qvk, t_qvk, t_qvk, attn_mask=mask, is_causal=True) 1204*da0073e9SAndroid Build Coastguard Worker 1205*da0073e9SAndroid Build Coastguard Worker # Can't give only is_causal 1206*da0073e9SAndroid Build Coastguard Worker attn_mask = torch.randint(0, 2, size=(L, L), device=device, dtype=torch.bool) 1207*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 1208*da0073e9SAndroid Build Coastguard Worker _ = mha(t_qvk, t_qvk, t_qvk, is_causal=True) 1209*da0073e9SAndroid Build Coastguard Worker 1210*da0073e9SAndroid Build 
Coastguard Worker # Passing a causal mask sets is_causal to True 1211*da0073e9SAndroid Build Coastguard Worker causal_mask = torch.triu( 1212*da0073e9SAndroid Build Coastguard Worker torch.ones(L, L, device=inputs.device) * float('-inf'), diagonal=1 1213*da0073e9SAndroid Build Coastguard Worker ).to(torch.bool) 1214*da0073e9SAndroid Build Coastguard Worker 1215*da0073e9SAndroid Build Coastguard Worker mock_layer = MagicMock(torch.nn.MultiheadAttention(E, H), return_value=inputs) 1216*da0073e9SAndroid Build Coastguard Worker encoder.layers[1] = mock_layer 1217*da0073e9SAndroid Build Coastguard Worker outputs = encoder(inputs, mask=causal_mask) 1218*da0073e9SAndroid Build Coastguard Worker mock_layer.assert_called_with(ANY, src_mask=ANY, is_causal=True, src_key_padding_mask=ANY) 1219*da0073e9SAndroid Build Coastguard Worker 1220*da0073e9SAndroid Build Coastguard Worker # check expected numerical values with all kernels 1221*da0073e9SAndroid Build Coastguard Worker self.is_causal_kernels([SDPBackend.MATH], device) 1222*da0073e9SAndroid Build Coastguard Worker 1223*da0073e9SAndroid Build Coastguard Worker def is_causal_kernels(self, kernels, device): 1224*da0073e9SAndroid Build Coastguard Worker def ones_tensor(*shape): 1225*da0073e9SAndroid Build Coastguard Worker return torch.ones(shape, device=device, dtype=torch.float32).to(device) 1226*da0073e9SAndroid Build Coastguard Worker S, L, E, H = 1, 2, 4, 1 1227*da0073e9SAndroid Build Coastguard Worker qkv = ones_tensor(S, L, E) 1228*da0073e9SAndroid Build Coastguard Worker 1229*da0073e9SAndroid Build Coastguard Worker mha = nn.MultiheadAttention(E, H).to(device) 1230*da0073e9SAndroid Build Coastguard Worker mha.in_proj_weight = Parameter(torch.ones((E * 3, E), device=device)) 1231*da0073e9SAndroid Build Coastguard Worker mha.out_proj.weight = Parameter(torch.ones((E, E), device=device)) 1232*da0073e9SAndroid Build Coastguard Worker expected = torch.ones(size=(S, L, E)).to(device) * 16 1233*da0073e9SAndroid Build Coastguard Worker mask = torch.nn.Transformer.generate_square_subsequent_mask( 1234*da0073e9SAndroid Build Coastguard Worker qkv.size(1), device=device 1235*da0073e9SAndroid Build Coastguard Worker ) 1236*da0073e9SAndroid Build Coastguard Worker 1237*da0073e9SAndroid Build Coastguard Worker for kernel in kernels: 1238*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[kernel]): 1239*da0073e9SAndroid Build Coastguard Worker actual, _ = mha(qkv, qkv, qkv, attn_mask=mask, need_weights=False, is_causal=True) 1240*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.equal(actual, expected)) 1241*da0073e9SAndroid Build Coastguard Worker 1242*da0073e9SAndroid Build Coastguard Worker if kernel != SDPBackend.MATH: 1243*da0073e9SAndroid Build Coastguard Worker # fails with embedding size not multiple of 4 1244*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "No available kernel"): 1245*da0073e9SAndroid Build Coastguard Worker qkv_f, mha_f = ones_tensor(S, L, 2), nn.MultiheadAttention(2, H).to(device) 1246*da0073e9SAndroid Build Coastguard Worker mask = torch.nn.Transformer.generate_square_subsequent_mask( 1247*da0073e9SAndroid Build Coastguard Worker qkv_f.size(1), device=device 1248*da0073e9SAndroid Build Coastguard Worker ) 1249*da0073e9SAndroid Build Coastguard Worker _ = mha_f(qkv_f, qkv_f, qkv_f, attn_mask=mask, need_weights=False, is_causal=True) 1250*da0073e9SAndroid Build Coastguard Worker torch.cuda.synchronize() 1251*da0073e9SAndroid Build Coastguard Worker
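    # Sketch of the arithmetic behind `expected = ones * 16` in is_causal_kernels above,
    # assuming the default zero-initialized MultiheadAttention biases: with all-ones inputs
    # and all-ones projection weights, every element of q, k and v after the input projection
    # equals E == 4; the softmax attention weights form a row-stochastic matrix, so the
    # attended values stay at 4; the all-ones output projection then sums E copies of 4,
    # giving 4 * 4 == 16 in every output element. For example:
    #   E = 4
    #   x = torch.ones(1, 2, E)
    #   (x @ torch.ones(E, E)) @ torch.ones(E, E)   # every element is 16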
1252*da0073e9SAndroid Build Coastguard Worker @skipIfRocm # Missing EFFICIENT_ATTENTION 1253*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 1254*da0073e9SAndroid Build Coastguard Worker not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Platform does not support fused SDPA or is pre-SM80 hardware" 1255*da0073e9SAndroid Build Coastguard Worker ) 1256*da0073e9SAndroid Build Coastguard Worker def test_is_causal_gpu(self): 1257*da0073e9SAndroid Build Coastguard Worker device = 'cuda' 1258*da0073e9SAndroid Build Coastguard Worker self.is_causal_kernels([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION], device) 1259*da0073e9SAndroid Build Coastguard Worker 1260*da0073e9SAndroid Build Coastguard Worker def test_script_mha_in_proj_weight_none(self): 1261*da0073e9SAndroid Build Coastguard Worker mha = torch.nn.MultiheadAttention( 1262*da0073e9SAndroid Build Coastguard Worker embed_dim=128, num_heads=8, kdim=256, vdim=256 1263*da0073e9SAndroid Build Coastguard Worker ).eval() 1264*da0073e9SAndroid Build Coastguard Worker 1265*da0073e9SAndroid Build Coastguard Worker torch.jit.script(mha) 1266*da0073e9SAndroid Build Coastguard Worker 1267*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(TEST_WITH_CROSSREF, 'Fastpath not available with crossref') 1268*da0073e9SAndroid Build Coastguard Worker @torch.no_grad() 1269*da0073e9SAndroid Build Coastguard Worker def test_disable_fastpath(self, device): 1270*da0073e9SAndroid Build Coastguard Worker def _test_te_fastpath_called(model, args, kwargs=None, return_value=None, is_called=True): 1271*da0073e9SAndroid Build Coastguard Worker if kwargs is None: 1272*da0073e9SAndroid Build Coastguard Worker kwargs = {} 1273*da0073e9SAndroid Build Coastguard Worker with patch('torch._transformer_encoder_layer_fwd') as fastpath_mock: 1274*da0073e9SAndroid Build Coastguard Worker fastpath_mock.return_value = return_value 1275*da0073e9SAndroid Build Coastguard Worker output = model(*args, **kwargs) 1276*da0073e9SAndroid Build Coastguard Worker self.assertTrue(fastpath_mock.called == is_called) 1277*da0073e9SAndroid Build Coastguard Worker 1278*da0073e9SAndroid Build Coastguard Worker def _test_mha_fastpath_called(model, args, kwargs=None, return_value=None, is_called=True): 1279*da0073e9SAndroid Build Coastguard Worker if kwargs is None: 1280*da0073e9SAndroid Build Coastguard Worker kwargs = {} 1281*da0073e9SAndroid Build Coastguard Worker with patch('torch._native_multi_head_attention') as fastpath_mock: 1282*da0073e9SAndroid Build Coastguard Worker fastpath_mock.return_value = return_value 1283*da0073e9SAndroid Build Coastguard Worker output = model(*args, **kwargs) 1284*da0073e9SAndroid Build Coastguard Worker self.assertTrue(fastpath_mock.called == is_called) 1285*da0073e9SAndroid Build Coastguard Worker 1286*da0073e9SAndroid Build Coastguard Worker inp = torch.tensor([[[1, 2], [3, 4], [5, 6]]], dtype=torch.float32, device=device) 1287*da0073e9SAndroid Build Coastguard Worker aligned_key_padding_mask = torch.tensor([[0, 0, 1]], dtype=torch.bool, device=device) 1288*da0073e9SAndroid Build Coastguard Worker src_key_padding_mask = torch.tensor([[1, 0, 1]], dtype=torch.bool, device=device) 1289*da0073e9SAndroid Build Coastguard Worker attn_mask = torch.tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]], dtype=torch.bool, device=device) 1290*da0073e9SAndroid Build Coastguard Worker te_return_value = torch.ones((1, 3, 2), dtype=torch.float32) 1291*da0073e9SAndroid Build Coastguard Worker 1292*da0073e9SAndroid Build Coastguard Worker encoder_layer =
torch.nn.TransformerEncoderLayer(d_model=2, nhead=2, dim_feedforward=8, batch_first=True) 1293*da0073e9SAndroid Build Coastguard Worker te = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=True, mask_check=True) 1294*da0073e9SAndroid Build Coastguard Worker te = te.to(device).eval() 1295*da0073e9SAndroid Build Coastguard Worker 1296*da0073e9SAndroid Build Coastguard Worker t = torch.nn.Transformer(d_model=2, nhead=2, batch_first=True, device=device).eval() 1297*da0073e9SAndroid Build Coastguard Worker src = torch.tensor([[[0, 1], [2, 3], [4, 5]]], dtype=torch.float32, device=device) 1298*da0073e9SAndroid Build Coastguard Worker tgt = torch.tensor([[[0, 1], [2, 3], [4, 5], [6, 7]]], dtype=torch.float32, device=device) 1299*da0073e9SAndroid Build Coastguard Worker t_return_value = torch.ones((1, 3, 2), dtype=torch.float32, device=device) 1300*da0073e9SAndroid Build Coastguard Worker 1301*da0073e9SAndroid Build Coastguard Worker mha = nn.MultiheadAttention(2, 2, batch_first=True, device=device).eval() 1302*da0073e9SAndroid Build Coastguard Worker q = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.float32, device=device) 1303*da0073e9SAndroid Build Coastguard Worker mha_return_value = torch.ones((1, 3, 2), dtype=torch.float32, device=device) 1304*da0073e9SAndroid Build Coastguard Worker 1305*da0073e9SAndroid Build Coastguard Worker _test_te_fastpath_called( 1306*da0073e9SAndroid Build Coastguard Worker te, (inp,), kwargs={'src_key_padding_mask': src_key_padding_mask}, 1307*da0073e9SAndroid Build Coastguard Worker return_value=te_return_value, is_called=True 1308*da0073e9SAndroid Build Coastguard Worker ) 1309*da0073e9SAndroid Build Coastguard Worker _test_te_fastpath_called(t, (src, tgt), return_value=t_return_value, is_called=True) 1310*da0073e9SAndroid Build Coastguard Worker _test_mha_fastpath_called(mha, (q, q, q,), return_value=mha_return_value, is_called=True) 1311*da0073e9SAndroid Build Coastguard Worker 1312*da0073e9SAndroid Build Coastguard Worker torch.backends.mha.set_fastpath_enabled(False) 1313*da0073e9SAndroid Build Coastguard Worker _test_te_fastpath_called( 1314*da0073e9SAndroid Build Coastguard Worker te, (inp,), kwargs={'src_key_padding_mask': src_key_padding_mask}, 1315*da0073e9SAndroid Build Coastguard Worker return_value=te_return_value, is_called=False 1316*da0073e9SAndroid Build Coastguard Worker ) 1317*da0073e9SAndroid Build Coastguard Worker _test_te_fastpath_called(t, (src, tgt), return_value=t_return_value, is_called=False) 1318*da0073e9SAndroid Build Coastguard Worker _test_mha_fastpath_called(mha, (q, q, q,), return_value=mha_return_value, is_called=False) 1319*da0073e9SAndroid Build Coastguard Worker 1320*da0073e9SAndroid Build Coastguard Worker torch.backends.mha.set_fastpath_enabled(True) 1321*da0073e9SAndroid Build Coastguard Worker _test_te_fastpath_called( 1322*da0073e9SAndroid Build Coastguard Worker te, (inp,), kwargs={'src_key_padding_mask': src_key_padding_mask}, 1323*da0073e9SAndroid Build Coastguard Worker return_value=te_return_value, is_called=True 1324*da0073e9SAndroid Build Coastguard Worker ) 1325*da0073e9SAndroid Build Coastguard Worker _test_te_fastpath_called(t, (src, tgt), return_value=t_return_value, is_called=True) 1326*da0073e9SAndroid Build Coastguard Worker _test_mha_fastpath_called(mha, (q, q, q,), return_value=mha_return_value, is_called=True) 1327*da0073e9SAndroid Build Coastguard Worker 1328*da0073e9SAndroid Build Coastguard Worker 1329*da0073e9SAndroid Build Coastguard Workerclass 
TestSDPAFailureModes(NNTestCase): 1330*da0073e9SAndroid Build Coastguard Worker """ Used to test the failure modes of scaled_dot_product_attention 1331*da0073e9SAndroid Build Coastguard Worker """ 1332*da0073e9SAndroid Build Coastguard Worker _do_cuda_memory_leak_check = True 1333*da0073e9SAndroid Build Coastguard Worker _do_cuda_non_default_stream = True 1334*da0073e9SAndroid Build Coastguard Worker 1335*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 1336*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 1337*da0073e9SAndroid Build Coastguard Worker not PLATFORM_SUPPORTS_FLASH_ATTENTION or not isSM8XDevice, 1338*da0073e9SAndroid Build Coastguard Worker "Does not support fused SDPA or not SM86+ hardware", 1339*da0073e9SAndroid Build Coastguard Worker ) 1340*da0073e9SAndroid Build Coastguard Worker @parametrize("head_dim", [193, 204, 256]) 1341*da0073e9SAndroid Build Coastguard Worker @parametrize("dropout_p", [0.0, 0.2]) 1342*da0073e9SAndroid Build Coastguard Worker def test_flash_backward_failure_sm86plus(self, device, head_dim: int, dropout_p: float): 1343*da0073e9SAndroid Build Coastguard Worker dtype = torch.float16 1344*da0073e9SAndroid Build Coastguard Worker make_tensor = partial(torch.rand, device=device, dtype=dtype) 1345*da0073e9SAndroid Build Coastguard Worker # See check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89 in 1346*da0073e9SAndroid Build Coastguard Worker # pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.h 1347*da0073e9SAndroid Build Coastguard Worker size = (2, 2, 4, head_dim) 1348*da0073e9SAndroid Build Coastguard Worker q, k, v = make_tensor(size), make_tensor(size), make_tensor(size) 1349*da0073e9SAndroid Build Coastguard Worker 1350*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[SDPBackend.MATH]): 1351*da0073e9SAndroid Build Coastguard Worker math_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False) 1352*da0073e9SAndroid Build Coastguard Worker 1353*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): 1354*da0073e9SAndroid Build Coastguard Worker # Should not fail because inputs don't require grad 1355*da0073e9SAndroid Build Coastguard Worker flash_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False) 1356*da0073e9SAndroid Build Coastguard Worker 1357*da0073e9SAndroid Build Coastguard Worker self.assertEqual(math_ref, flash_ref, atol=1e-3, rtol=1e-3) 1358*da0073e9SAndroid Build Coastguard Worker 1359*da0073e9SAndroid Build Coastguard Worker # Should fail because inputs require grad 1360*da0073e9SAndroid Build Coastguard Worker q = make_tensor(size, requires_grad=True) 1361*da0073e9SAndroid Build Coastguard Worker k = make_tensor(size, requires_grad=True) 1362*da0073e9SAndroid Build Coastguard Worker v = make_tensor(size, requires_grad=True) 1363*da0073e9SAndroid Build Coastguard Worker if 192 < head_dim <= 224 or (head_dim > 224 and dropout_p != 0.0): 1364*da0073e9SAndroid Build Coastguard Worker self.assertRaises( 1365*da0073e9SAndroid Build Coastguard Worker RuntimeError, 1366*da0073e9SAndroid Build Coastguard Worker lambda: torch.nn.functional.scaled_dot_product_attention( 1367*da0073e9SAndroid Build Coastguard Worker q, k, v, None, dropout_p, False 1368*da0073e9SAndroid Build Coastguard Worker ), 1369*da0073e9SAndroid Build Coastguard Worker ) 1370*da0073e9SAndroid Build Coastguard Worker else: 1371*da0073e9SAndroid Build Coastguard Worker flash_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, 
None, dropout_p, False) 1372*da0073e9SAndroid Build Coastguard Worker 1373*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 1374*da0073e9SAndroid Build Coastguard Worker def test_dispatch_fails_no_backend(self, device): 1375*da0073e9SAndroid Build Coastguard Worker dtype = torch.float16 1376*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[SDPBackend.ERROR]): 1377*da0073e9SAndroid Build Coastguard Worker size = (2, 3, 4) 1378*da0073e9SAndroid Build Coastguard Worker q = torch.randn(size, device=device, dtype=dtype) 1379*da0073e9SAndroid Build Coastguard Worker k = torch.randn(size, device=device, dtype=dtype) 1380*da0073e9SAndroid Build Coastguard Worker v = torch.randn(size, device=device, dtype=dtype) 1381*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(RuntimeError, "No viable backend for scaled_dot_product_attention was found.", 1382*da0073e9SAndroid Build Coastguard Worker lambda: torch._fused_sdp_choice(q, k, v)) 1383*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(RuntimeError, "No viable backend for scaled_dot_product_attention was found.", 1384*da0073e9SAndroid Build Coastguard Worker lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v)) 1385*da0073e9SAndroid Build Coastguard Worker 1386*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 1387*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention") 1388*da0073e9SAndroid Build Coastguard Worker @parametrize( 1389*da0073e9SAndroid Build Coastguard Worker "kernel", 1390*da0073e9SAndroid Build Coastguard Worker PLATFORM_SPECIFIC_SDPA, 1391*da0073e9SAndroid Build Coastguard Worker ) 1392*da0073e9SAndroid Build Coastguard Worker def test_invalid_fused_inputs_dim_3(self, device, kernel: SDPBackend): 1393*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[kernel]): 1394*da0073e9SAndroid Build Coastguard Worker # Dim is not 4 1395*da0073e9SAndroid Build Coastguard Worker size = (2, 3, 8) 1396*da0073e9SAndroid Build Coastguard Worker dtype = torch.float16 1397*da0073e9SAndroid Build Coastguard Worker q = torch.randn(size, device=device, dtype=dtype) 1398*da0073e9SAndroid Build Coastguard Worker k = torch.randn(size, device=device, dtype=dtype) 1399*da0073e9SAndroid Build Coastguard Worker v = torch.randn(size, device=device, dtype=dtype) 1400*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex(UserWarning, "Both fused kernels requires query, key and value to be 4 dimensional"): 1401*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( 1402*da0073e9SAndroid Build Coastguard Worker q, k, v, None, 0.0, False)) 1403*da0073e9SAndroid Build Coastguard Worker 1404*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 1405*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention") 1406*da0073e9SAndroid Build Coastguard Worker @parametrize( 1407*da0073e9SAndroid Build Coastguard Worker "kernel", 1408*da0073e9SAndroid Build Coastguard Worker PLATFORM_SPECIFIC_SDPA, 1409*da0073e9SAndroid Build Coastguard Worker ) 1410*da0073e9SAndroid Build Coastguard Worker def test_invalid_fused_inputs_broadcast(self, device, kernel: SDPBackend): 1411*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[kernel]): 1412*da0073e9SAndroid Build Coastguard Worker # Fused Kernels don't support 
broadcasting for dense inputs 1413*da0073e9SAndroid Build Coastguard Worker dtype = torch.float16 1414*da0073e9SAndroid Build Coastguard Worker size = (2, 4, 3, 8) 1415*da0073e9SAndroid Build Coastguard Worker size_broadcast = (1, 4, 3, 8) 1416*da0073e9SAndroid Build Coastguard Worker q = torch.randn(size_broadcast, device=device, dtype=dtype) 1417*da0073e9SAndroid Build Coastguard Worker k = torch.randn(size, device=device, dtype=dtype) 1418*da0073e9SAndroid Build Coastguard Worker v = torch.randn(size, device=device, dtype=dtype) 1419*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( 1420*da0073e9SAndroid Build Coastguard Worker q, k, v, None, 0.0, False)) 1421*da0073e9SAndroid Build Coastguard Worker 1422*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 1423*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention") 1424*da0073e9SAndroid Build Coastguard Worker @parametrize("kernel", PLATFORM_SPECIFIC_SDPA) 1425*da0073e9SAndroid Build Coastguard Worker def test_invalid_sequence_lengths(self, device, kernel: SDPBackend): 1426*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[kernel]): 1427*da0073e9SAndroid Build Coastguard Worker # Passing in a q,k,v with 0 length sequences will error 1428*da0073e9SAndroid Build Coastguard Worker dtype = torch.float16 1429*da0073e9SAndroid Build Coastguard Worker make_tensor = partial(torch.rand, device=device, dtype=dtype) 1430*da0073e9SAndroid Build Coastguard Worker size = SdpaShape(2, 2, 0, 8) 1431*da0073e9SAndroid Build Coastguard Worker q, k, v = make_tensor(size), make_tensor(size), make_tensor(size) 1432*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex(UserWarning, "Both fused kernels do not support zero seq_len_q or seq_len_kv."): 1433*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( 1434*da0073e9SAndroid Build Coastguard Worker q, k, v, None, 0.0, False)) 1435*da0073e9SAndroid Build Coastguard Worker 1436*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 1437*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention") 1438*da0073e9SAndroid Build Coastguard Worker @parametrize("kernel", PLATFORM_SPECIFIC_SDPA) 1439*da0073e9SAndroid Build Coastguard Worker def test_invalid_last_dim_stride(self, device, kernel: SDPBackend): 1440*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[kernel]): 1441*da0073e9SAndroid Build Coastguard Worker # Passing in a q,k,v with last dim stride not equal to 1 will error 1442*da0073e9SAndroid Build Coastguard Worker dtype = torch.float16 1443*da0073e9SAndroid Build Coastguard Worker make_tensor = partial(torch.rand, device=device, dtype=dtype) 1444*da0073e9SAndroid Build Coastguard Worker size = SdpaShape(2, 2, 8, 8) 1445*da0073e9SAndroid Build Coastguard Worker q, k, v = make_tensor(size), make_tensor(size), make_tensor(size) 1446*da0073e9SAndroid Build Coastguard Worker q.as_strided_(size, [2, 2, 2, 2]) 1447*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex(UserWarning, "Both fused kernels require the last dimension of the input to have stride 1."): 1448*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( 
1449*da0073e9SAndroid Build Coastguard Worker q, k, v, None, 0.0, False))
1450*da0073e9SAndroid Build Coastguard Worker
1451*da0073e9SAndroid Build Coastguard Worker @onlyCUDA
1452*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention fused scaled dot product attention")
1453*da0073e9SAndroid Build Coastguard Worker @parametrize("kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
1454*da0073e9SAndroid Build Coastguard Worker def test_invalid_fused_inputs_head_dim(self, device, kernel: SDPBackend):
1455*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[kernel]):
1456*da0073e9SAndroid Build Coastguard Worker # The embed dim per head is not divisible by 8 for flash attention
1457*da0073e9SAndroid Build Coastguard Worker dtype = torch.float16
1458*da0073e9SAndroid Build Coastguard Worker make_tensor = partial(torch.rand, device=device, dtype=dtype)
1459*da0073e9SAndroid Build Coastguard Worker size = SdpaShape(2, 2, 3, 9) if kernel == SDPBackend.EFFICIENT_ATTENTION else SdpaShape(2, 2, 3, 257)
1460*da0073e9SAndroid Build Coastguard Worker if TEST_WITH_ROCM: # On ROCM, FA and EA share the backend GPU kernels
1461*da0073e9SAndroid Build Coastguard Worker size = SdpaShape(2, 2, 3, 257)
1462*da0073e9SAndroid Build Coastguard Worker q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
1463*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1464*da0073e9SAndroid Build Coastguard Worker q, k, v, None, 0.0, False))
1465*da0073e9SAndroid Build Coastguard Worker
1466*da0073e9SAndroid Build Coastguard Worker @onlyCUDA
1467*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
1468*da0073e9SAndroid Build Coastguard Worker @parametrize(
1469*da0073e9SAndroid Build Coastguard Worker "kernel",
1470*da0073e9SAndroid Build Coastguard Worker PLATFORM_SPECIFIC_SDPA,
1471*da0073e9SAndroid Build Coastguard Worker )
1472*da0073e9SAndroid Build Coastguard Worker def test_invalid_fused_inputs_invalid_dtype(self, device, kernel: SDPBackend):
1473*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[kernel]):
1474*da0073e9SAndroid Build Coastguard Worker # Invalid dtype for both Flash Attention and Mem Efficient Attention
1475*da0073e9SAndroid Build Coastguard Worker size = SdpaShape(2, 2, 3, 16)
1476*da0073e9SAndroid Build Coastguard Worker make_tensor = partial(torch.rand, device=device, dtype=torch.float64)
1477*da0073e9SAndroid Build Coastguard Worker q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
1478*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1479*da0073e9SAndroid Build Coastguard Worker q, k, v, None, 0.0, False))
1480*da0073e9SAndroid Build Coastguard Worker
1481*da0073e9SAndroid Build Coastguard Worker @onlyCUDA
1482*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention")
1483*da0073e9SAndroid Build Coastguard Worker @parametrize("kernel", [SDPBackend.FLASH_ATTENTION])
1484*da0073e9SAndroid Build Coastguard Worker def test_invalid_fused_inputs_attn_mask_present(self, device, kernel: SDPBackend):
1485*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[kernel]):
1486*da0073e9SAndroid Build Coastguard Worker
# Failures for unsupported SDP args
1487*da0073e9SAndroid Build Coastguard Worker size = SdpaShape(2, 2, 3, 16)
1488*da0073e9SAndroid Build Coastguard Worker make_tensor = partial(torch.rand, size, device=device, dtype=torch.float16)
1489*da0073e9SAndroid Build Coastguard Worker q, k, v = make_tensor(), make_tensor(), make_tensor()
1490*da0073e9SAndroid Build Coastguard Worker # Non-None attention mask
1491*da0073e9SAndroid Build Coastguard Worker mask = torch.ones((2, 2, 3, 3), device=device, dtype=q.dtype)
1492*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1493*da0073e9SAndroid Build Coastguard Worker q, k, v, mask, 0.0, False))
1494*da0073e9SAndroid Build Coastguard Worker
1495*da0073e9SAndroid Build Coastguard Worker @onlyCUDA
1496*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support fused SDPA or pre-SM80 hardware")
1497*da0073e9SAndroid Build Coastguard Worker def test_unaligned_tensors(self, device):
1498*da0073e9SAndroid Build Coastguard Worker # The alignment is dependent on arch so we specify SM80OrLater
1499*da0073e9SAndroid Build Coastguard Worker dtype = torch.float16
1500*da0073e9SAndroid Build Coastguard Worker size = SdpaShape(2, 2, 8, 5)
1501*da0073e9SAndroid Build Coastguard Worker make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
1502*da0073e9SAndroid Build Coastguard Worker q, k, v = make_tensor(), make_tensor(), make_tensor()
1503*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
1504*da0073e9SAndroid Build Coastguard Worker ctxmgr = self.assertRaises(RuntimeError) if not TEST_WITH_ROCM else contextlib.nullcontext()
1505*da0073e9SAndroid Build Coastguard Worker with ctxmgr:
1506*da0073e9SAndroid Build Coastguard Worker torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False)
1507*da0073e9SAndroid Build Coastguard Worker
1508*da0073e9SAndroid Build Coastguard Worker @onlyCUDA
1509*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support fused SDPA or pre-SM80 hardware")
1510*da0073e9SAndroid Build Coastguard Worker def test_flash_fail_fp32(self, device):
1511*da0073e9SAndroid Build Coastguard Worker dtype = torch.float
1512*da0073e9SAndroid Build Coastguard Worker size = SdpaShape(16, 16, 32, 32)
1513*da0073e9SAndroid Build Coastguard Worker make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
1514*da0073e9SAndroid Build Coastguard Worker q, k, v = make_tensor(), make_tensor(), make_tensor()
1515*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
1516*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex(UserWarning, "Expected query, key and value to all be of dtype: {Half, BFloat16}"):
1517*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1518*da0073e9SAndroid Build Coastguard Worker q, k, v, None, 0.0, False))
1519*da0073e9SAndroid Build Coastguard Worker
1520*da0073e9SAndroid Build Coastguard Worker @onlyCUDA
1521*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
1522*da0073e9SAndroid Build Coastguard Worker def test_flash_autocast_fp32_float16(self, device):
1523*da0073e9SAndroid Build Coastguard Worker dtype = torch.float
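        # The inputs are created in float32 on purpose: under torch.autocast the SDPA call is
        # expected to run in the autocast dtype (float16 here), which is what lets the flash
        # kernel accept inputs it would otherwise reject (compare test_flash_fail_fp32 above).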
1524*da0073e9SAndroid Build Coastguard Worker size = SdpaShape(16, 16, 32, 32) 1525*da0073e9SAndroid Build Coastguard Worker make_tensor = partial(torch.rand, size, device=device, dtype=dtype) 1526*da0073e9SAndroid Build Coastguard Worker q, k, v = make_tensor(), make_tensor(), make_tensor() 1527*da0073e9SAndroid Build Coastguard Worker with torch.autocast(device_type='cuda', dtype=torch.float16): 1528*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): 1529*da0073e9SAndroid Build Coastguard Worker _ = torch.nn.functional.scaled_dot_product_attention( 1530*da0073e9SAndroid Build Coastguard Worker q, k, v, None, 0.0, False) 1531*da0073e9SAndroid Build Coastguard Worker 1532*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 1533*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware") 1534*da0073e9SAndroid Build Coastguard Worker def test_flash_autocast_fp32_bfloat16(self, device): 1535*da0073e9SAndroid Build Coastguard Worker dtype = torch.float 1536*da0073e9SAndroid Build Coastguard Worker size = SdpaShape(16, 16, 32, 32) 1537*da0073e9SAndroid Build Coastguard Worker make_tensor = partial(torch.rand, size, device=device, dtype=dtype) 1538*da0073e9SAndroid Build Coastguard Worker q, k, v = make_tensor(), make_tensor(), make_tensor() 1539*da0073e9SAndroid Build Coastguard Worker with torch.autocast(device_type='cuda', dtype=torch.bfloat16): 1540*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): 1541*da0073e9SAndroid Build Coastguard Worker _ = torch.nn.functional.scaled_dot_product_attention( 1542*da0073e9SAndroid Build Coastguard Worker q, k, v, None, 0.0, False) 1543*da0073e9SAndroid Build Coastguard Worker 1544*da0073e9SAndroid Build Coastguard Worker # Note: do not truncate the list according to platforms. These tests should always raise errors. 
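    # Illustrative sketch (kept as a comment, not collected as a test): regardless of platform,
    # mixed-dtype inputs like the following are rejected by every backend, including the math
    # fallback, which is what the tests below assert:
    #
    #   q = torch.randn(1, 4, 8, 16, device="cuda", dtype=torch.float32)
    #   k = torch.randn(1, 4, 8, 16, device="cuda", dtype=torch.float16)
    #   v = torch.randn(1, 4, 8, 16, device="cuda", dtype=torch.float16)
    #   F.scaled_dot_product_attention(q, k, v)  # raises RuntimeError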
1545*da0073e9SAndroid Build Coastguard Worker @parametrize("kernel", [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]) 1546*da0073e9SAndroid Build Coastguard Worker def test_invalid_inputs_different_datatypes(self, device, kernel: SDPBackend): 1547*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[kernel]): 1548*da0073e9SAndroid Build Coastguard Worker # Different datatypes 1549*da0073e9SAndroid Build Coastguard Worker shape = (1, 4, 8, 16) 1550*da0073e9SAndroid Build Coastguard Worker query = torch.randn(shape, dtype=torch.float32, device=device) 1551*da0073e9SAndroid Build Coastguard Worker key = torch.randn(shape, dtype=torch.float16, device=device) 1552*da0073e9SAndroid Build Coastguard Worker value = torch.randn(shape, dtype=torch.float16, device=device) 1553*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(query, key, value)) 1554*da0073e9SAndroid Build Coastguard Worker 1555*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 1556*da0073e9SAndroid Build Coastguard Worker @parametrize("kernel", [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]) 1557*da0073e9SAndroid Build Coastguard Worker def test_invalid_inputs_different_devices(self, device, kernel: SDPBackend): 1558*da0073e9SAndroid Build Coastguard Worker # Different devices 1559*da0073e9SAndroid Build Coastguard Worker shape = (1, 4, 8, 16) 1560*da0073e9SAndroid Build Coastguard Worker query = torch.randn(shape, dtype=torch.float32, device=device) 1561*da0073e9SAndroid Build Coastguard Worker key = torch.randn(shape, dtype=torch.float16, device='cpu') 1562*da0073e9SAndroid Build Coastguard Worker value = torch.randn(shape, dtype=torch.float16, device='cpu') 1563*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(query, key, value)) 1564*da0073e9SAndroid Build Coastguard Worker 1565*da0073e9SAndroid Build Coastguard Worker @parametrize("kernel", [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]) 1566*da0073e9SAndroid Build Coastguard Worker def test_invalid_inputs_1_dimensional_inputs(self, device, kernel: SDPBackend): 1567*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[kernel]): 1568*da0073e9SAndroid Build Coastguard Worker # 1 dimensional input 1569*da0073e9SAndroid Build Coastguard Worker shape = (1, 4) 1570*da0073e9SAndroid Build Coastguard Worker query = torch.randn(4, dtype=torch.float16, device=device) 1571*da0073e9SAndroid Build Coastguard Worker key = torch.randn(shape, dtype=torch.float16, device=device) 1572*da0073e9SAndroid Build Coastguard Worker value = torch.randn(shape, dtype=torch.float16, device=device) 1573*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(query, key, value)) 1574*da0073e9SAndroid Build Coastguard Worker 1575*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 1576*da0073e9SAndroid Build Coastguard Worker @skipIfRocm # Missing EFFICIENT_ATTENTION 1577*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") 1578*da0073e9SAndroid Build Coastguard Worker def test_fused_kernels_nested_broadcasting_error_cases(self, device): 1579*da0073e9SAndroid Build Coastguard Worker # one of k,v needs to be broadcasted and other has non consistent seq_len dim 1580*da0073e9SAndroid Build Coastguard Worker rand_nested_tensor 
= partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float32) 1581*da0073e9SAndroid Build Coastguard Worker batch, num_heads, head_dim = 32, 8, 64 1582*da0073e9SAndroid Build Coastguard Worker seq_lens_q = torch.randint(low=1, high=32, size=(batch,)).tolist() 1583*da0073e9SAndroid Build Coastguard Worker seq_lens_v = torch.randint(low=1, high=32, size=(batch,)).tolist() 1584*da0073e9SAndroid Build Coastguard Worker 1585*da0073e9SAndroid Build Coastguard Worker q_shape = SdpaShape(batch, num_heads, seq_lens_q, head_dim) 1586*da0073e9SAndroid Build Coastguard Worker k_shape = SdpaShape(1, num_heads, 1, head_dim) 1587*da0073e9SAndroid Build Coastguard Worker v_shape = SdpaShape(batch, num_heads, seq_lens_v, head_dim) 1588*da0073e9SAndroid Build Coastguard Worker 1589*da0073e9SAndroid Build Coastguard Worker query = rand_nested_tensor(q_shape).transpose(1, 2) 1590*da0073e9SAndroid Build Coastguard Worker key = rand_nested_tensor(k_shape).transpose(1, 2) 1591*da0073e9SAndroid Build Coastguard Worker value = rand_nested_tensor(v_shape).transpose(1, 2) 1592*da0073e9SAndroid Build Coastguard Worker 1593*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]): 1594*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "No available kernel"): 1595*da0073e9SAndroid Build Coastguard Worker torch.nn.functional.scaled_dot_product_attention( 1596*da0073e9SAndroid Build Coastguard Worker query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) 1597*da0073e9SAndroid Build Coastguard Worker 1598*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 1599*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Fused SDPA was not built for this system") 1600*da0073e9SAndroid Build Coastguard Worker def test_nested_fails_on_padding_head_dim(self, device): 1601*da0073e9SAndroid Build Coastguard Worker dtype = torch.bfloat16 1602*da0073e9SAndroid Build Coastguard Worker seq_len_list = [2, 4, 5, 6, 7] 1603*da0073e9SAndroid Build Coastguard Worker shape = SdpaShape(5, 8, seq_len_list, 57) 1604*da0073e9SAndroid Build Coastguard Worker make_tensor = partial(rand_sdpa_tensor, shape=shape, type="nested", device=device, dtype=dtype) 1605*da0073e9SAndroid Build Coastguard Worker q, k, v = make_tensor(), make_tensor(), make_tensor() 1606*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): 1607*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex(UserWarning, "For NestedTensor inputs, Flash attention requires"): 1608*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( 1609*da0073e9SAndroid Build Coastguard Worker q, k, v, None, 0.0, False)) 1610*da0073e9SAndroid Build Coastguard Worker 1611*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 1612*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION or not isLessThanSM80Device, 1613*da0073e9SAndroid Build Coastguard Worker "Current platform does not support fused SDPA or is an SM80+ device.") 1614*da0073e9SAndroid Build Coastguard Worker def test_mem_efficient_fail_bfloat16_less_than_sm80(self, device): 1615*da0073e9SAndroid Build Coastguard Worker dtype = torch.bfloat16 1616*da0073e9SAndroid Build Coastguard Worker size = SdpaShape(16, 16, 32, 32) 1617*da0073e9SAndroid Build Coastguard Worker make_tensor = partial(torch.rand, size, device=device, 
dtype=dtype)
1618*da0073e9SAndroid Build Coastguard Worker q, k, v = make_tensor(), make_tensor(), make_tensor()
1619*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
1620*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex(UserWarning, "Expected query, key and value to all be of dtype: {Half, Float}"):
1621*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1622*da0073e9SAndroid Build Coastguard Worker q, k, v, None, 0.0, False))
1623*da0073e9SAndroid Build Coastguard Worker
1624*da0073e9SAndroid Build Coastguard Worker @onlyCUDA
1625*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention")
1626*da0073e9SAndroid Build Coastguard Worker def test_flash_attention_large_bf16_nan_values(self, device):
1627*da0073e9SAndroid Build Coastguard Worker query = torch.full((1, 1, 1, 64), 133120.0, dtype=torch.bfloat16, device="cuda")
1628*da0073e9SAndroid Build Coastguard Worker key = torch.full((1, 1, 1, 64), 133120.0, dtype=torch.bfloat16, device="cuda")
1629*da0073e9SAndroid Build Coastguard Worker value = torch.full((1, 1, 1, 64), 133120.0, dtype=torch.bfloat16, device="cuda")
1630*da0073e9SAndroid Build Coastguard Worker
1631*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
1632*da0073e9SAndroid Build Coastguard Worker out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
1633*da0073e9SAndroid Build Coastguard Worker
1634*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.isnan(out).any(), "Output should not contain NaNs!")
1635*da0073e9SAndroid Build Coastguard Worker
1636*da0073e9SAndroid Build Coastguard Worker @onlyCUDA
1637*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
1638*da0073e9SAndroid Build Coastguard Worker @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
1639*da0073e9SAndroid Build Coastguard Worker PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION])
1640*da0073e9SAndroid Build Coastguard Worker def test_fused_kernels_seq_len_0_inputs(self, device, fused_kernel):
1641*da0073e9SAndroid Build Coastguard Worker rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float16)
1642*da0073e9SAndroid Build Coastguard Worker batch, num_heads, head_dim = 32, 16, 64
1643*da0073e9SAndroid Build Coastguard Worker seq_lens = torch.randint(low=1, high=32, size=(batch,))
1644*da0073e9SAndroid Build Coastguard Worker # make sure some seq_lens are 0
1645*da0073e9SAndroid Build Coastguard Worker num_zeros = 10
1646*da0073e9SAndroid Build Coastguard Worker indices = torch.randint(low=0, high=batch, size=(num_zeros,))
1647*da0073e9SAndroid Build Coastguard Worker seq_lens.scatter_(0, indices, 0)
1648*da0073e9SAndroid Build Coastguard Worker
1649*da0073e9SAndroid Build Coastguard Worker shape = SdpaShape(batch, num_heads, seq_lens.tolist(), head_dim)
1650*da0073e9SAndroid Build Coastguard Worker query = rand_nested_tensor(shape)
1651*da0073e9SAndroid Build Coastguard Worker key = rand_nested_tensor(shape)
1652*da0073e9SAndroid Build Coastguard Worker value = rand_nested_tensor(shape)
1653*da0073e9SAndroid Build Coastguard Worker
1654*da0073e9SAndroid Build Coastguard Worker query = query.transpose(1, 2)
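        # These transposes (repeated for key and value below) move the nested inputs from their
        # sequence-major construction layout into the (batch, num_heads, seq_len, head_dim)
        # layout that scaled_dot_product_attention expects before dispatching to a fused kernel.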
1655*da0073e9SAndroid Build Coastguard Worker key = key.transpose(1, 2) 1656*da0073e9SAndroid Build Coastguard Worker value = value.transpose(1, 2) 1657*da0073e9SAndroid Build Coastguard Worker 1658*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[fused_kernel]): 1659*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "No available kernel"): 1660*da0073e9SAndroid Build Coastguard Worker torch.nn.functional.scaled_dot_product_attention( 1661*da0073e9SAndroid Build Coastguard Worker query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) 1662*da0073e9SAndroid Build Coastguard Worker 1663*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 1664*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Fused SDPA was not built for this system") 1665*da0073e9SAndroid Build Coastguard Worker def test_fused_kernels_nested_broadcasting_requires_grad_failure(self, device): 1666*da0073e9SAndroid Build Coastguard Worker rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float16, requires_grad=True) 1667*da0073e9SAndroid Build Coastguard Worker batch, num_heads, head_dim, head_dim_v = 32, 16, 64, 64 1668*da0073e9SAndroid Build Coastguard Worker seq_lens = torch.randint(low=1, high=32, size=(batch,)).tolist() 1669*da0073e9SAndroid Build Coastguard Worker q_shape = SdpaShape(1, num_heads, 1, head_dim) 1670*da0073e9SAndroid Build Coastguard Worker k_shape = SdpaShape(batch, num_heads, seq_lens, head_dim) 1671*da0073e9SAndroid Build Coastguard Worker v_shape = SdpaShape(batch, 1, seq_lens, head_dim_v) 1672*da0073e9SAndroid Build Coastguard Worker 1673*da0073e9SAndroid Build Coastguard Worker # create a dense query 1674*da0073e9SAndroid Build Coastguard Worker query = torch.randn(q_shape, device=device, dtype=torch.float16, requires_grad=True) 1675*da0073e9SAndroid Build Coastguard Worker key = rand_nested_tensor(k_shape) 1676*da0073e9SAndroid Build Coastguard Worker value = rand_nested_tensor(v_shape) 1677*da0073e9SAndroid Build Coastguard Worker 1678*da0073e9SAndroid Build Coastguard Worker query = query.transpose(1, 2) 1679*da0073e9SAndroid Build Coastguard Worker key = key.transpose(1, 2) 1680*da0073e9SAndroid Build Coastguard Worker value = value.transpose(1, 2) 1681*da0073e9SAndroid Build Coastguard Worker 1682*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): 1683*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex(UserWarning, "Both fused kernels do not support training with broadcasted NT inputs"): 1684*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "No available kernel"): 1685*da0073e9SAndroid Build Coastguard Worker out = torch.nn.functional.scaled_dot_product_attention( 1686*da0073e9SAndroid Build Coastguard Worker query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) 1687*da0073e9SAndroid Build Coastguard Worker 1688*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 1689*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention") 1690*da0073e9SAndroid Build Coastguard Worker def test_flash_attention_fail_with_non_square_causal_attention(self, device): 1691*da0073e9SAndroid Build Coastguard Worker dtype = torch.bfloat16 1692*da0073e9SAndroid Build Coastguard Worker q_shape = SdpaShape(1, 1, 8, 16) 1693*da0073e9SAndroid Build Coastguard Worker kv_shape = SdpaShape(1, 1, 12, 
16) 1694*da0073e9SAndroid Build Coastguard Worker make_q = partial(torch.rand, q_shape, device=device, dtype=dtype) 1695*da0073e9SAndroid Build Coastguard Worker make_kv = partial(torch.rand, kv_shape, device=device, dtype=dtype) 1696*da0073e9SAndroid Build Coastguard Worker q, k, v = make_q(), make_kv(), make_kv() 1697*da0073e9SAndroid Build Coastguard Worker warning_str = "Flash attention does not support the is_causal flag when seqlen_q != seqlen_k." 1698*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): 1699*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex(UserWarning, warning_str): 1700*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( 1701*da0073e9SAndroid Build Coastguard Worker q, k, v, None, 0.0, is_causal=True)) 1702*da0073e9SAndroid Build Coastguard Worker 1703*da0073e9SAndroid Build Coastguard Workerdef _get_block_size_n(device, head_dim, is_dropout, is_causal): 1704*da0073e9SAndroid Build Coastguard Worker # This should match the block sizes in the CUDA kernel 1705*da0073e9SAndroid Build Coastguard Worker assert head_dim <= 256 1706*da0073e9SAndroid Build Coastguard Worker major, minor = torch.cuda.get_device_capability(device) 1707*da0073e9SAndroid Build Coastguard Worker is_sm8x = major == 8 and minor > 0 # Only include sm86 and sm89, exclude sm80 (A100) 1708*da0073e9SAndroid Build Coastguard Worker is_sm80 = major == 8 and minor == 0 1709*da0073e9SAndroid Build Coastguard Worker is_sm90 = major == 9 and minor == 0 1710*da0073e9SAndroid Build Coastguard Worker if head_dim <= 32: 1711*da0073e9SAndroid Build Coastguard Worker return 128 1712*da0073e9SAndroid Build Coastguard Worker if head_dim <= 64: 1713*da0073e9SAndroid Build Coastguard Worker return 128 if not is_dropout else 64 1714*da0073e9SAndroid Build Coastguard Worker elif head_dim <= 96: 1715*da0073e9SAndroid Build Coastguard Worker return 64 1716*da0073e9SAndroid Build Coastguard Worker elif head_dim <= 128: 1717*da0073e9SAndroid Build Coastguard Worker if is_sm8x: 1718*da0073e9SAndroid Build Coastguard Worker return 64 if (not is_dropout and is_causal) else 32 1719*da0073e9SAndroid Build Coastguard Worker else: 1720*da0073e9SAndroid Build Coastguard Worker return 64 if not is_dropout else 32 1721*da0073e9SAndroid Build Coastguard Worker elif head_dim <= 160: 1722*da0073e9SAndroid Build Coastguard Worker if is_sm8x: 1723*da0073e9SAndroid Build Coastguard Worker return 64 1724*da0073e9SAndroid Build Coastguard Worker else: 1725*da0073e9SAndroid Build Coastguard Worker return 32 1726*da0073e9SAndroid Build Coastguard Worker elif head_dim <= 192: 1727*da0073e9SAndroid Build Coastguard Worker return 64 1728*da0073e9SAndroid Build Coastguard Worker elif head_dim <= 224: 1729*da0073e9SAndroid Build Coastguard Worker return 64 1730*da0073e9SAndroid Build Coastguard Worker elif head_dim <= 256: 1731*da0073e9SAndroid Build Coastguard Worker return 64 1732*da0073e9SAndroid Build Coastguard Worker 1733*da0073e9SAndroid Build Coastguard Worker 1734*da0073e9SAndroid Build Coastguard Workerdef pad_last_dim(input_tensor, alignment_size, slice: bool = False): 1735*da0073e9SAndroid Build Coastguard Worker last_dim_size = input_tensor.size(-1) 1736*da0073e9SAndroid Build Coastguard Worker if (last_dim_size % alignment_size == 0): 1737*da0073e9SAndroid Build Coastguard Worker return input_tensor, last_dim_size 1738*da0073e9SAndroid Build Coastguard Worker pad_count = alignment_size - 
(last_dim_size % alignment_size)
1739*da0073e9SAndroid Build Coastguard Worker padded_tensor = F.pad(input_tensor, (0, pad_count))
1740*da0073e9SAndroid Build Coastguard Worker if slice:
1741*da0073e9SAndroid Build Coastguard Worker return padded_tensor[..., :last_dim_size], last_dim_size
1742*da0073e9SAndroid Build Coastguard Worker return padded_tensor, last_dim_size
1743*da0073e9SAndroid Build Coastguard Worker
1744*da0073e9SAndroid Build Coastguard Worker
1745*da0073e9SAndroid Build Coastguard Workerclass TestSDPA(NNTestCase):
1746*da0073e9SAndroid Build Coastguard Worker """ Used to test generic functionality of scaled_dot_product_attention
1747*da0073e9SAndroid Build Coastguard Worker Summary:
1748*da0073e9SAndroid Build Coastguard Worker If you are adding a new test to this class, make sure that it runs
1749*da0073e9SAndroid Build Coastguard Worker for both cpu and cuda. If your test is only applicable to cuda,
1750*da0073e9SAndroid Build Coastguard Worker add it to TestSDPACudaOnly.
1751*da0073e9SAndroid Build Coastguard Worker """
1752*da0073e9SAndroid Build Coastguard Worker @parametrize("contiguous_inputs", [True, False])
1753*da0073e9SAndroid Build Coastguard Worker def test_sdp_math_gradcheck(self, device, contiguous_inputs: bool):
1754*da0073e9SAndroid Build Coastguard Worker
1755*da0073e9SAndroid Build Coastguard Worker batch_size, seq_len, num_heads, head_dim = 4, 4, 2, 16
1756*da0073e9SAndroid Build Coastguard Worker shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
1757*da0073e9SAndroid Build Coastguard Worker make_tensor = partial(rand_sdpa_tensor, type="dense", device=device,
1758*da0073e9SAndroid Build Coastguard Worker dtype=torch.float64, requires_grad=True, packed=True)
1759*da0073e9SAndroid Build Coastguard Worker
1760*da0073e9SAndroid Build Coastguard Worker qkv = make_tensor(shape)
1761*da0073e9SAndroid Build Coastguard Worker query, key, value = qkv.chunk(3, dim=-1)
1762*da0073e9SAndroid Build Coastguard Worker
1763*da0073e9SAndroid Build Coastguard Worker query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
1764*da0073e9SAndroid Build Coastguard Worker key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
1765*da0073e9SAndroid Build Coastguard Worker value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
1766*da0073e9SAndroid Build Coastguard Worker
1767*da0073e9SAndroid Build Coastguard Worker if contiguous_inputs:
1768*da0073e9SAndroid Build Coastguard Worker query = query.contiguous()
1769*da0073e9SAndroid Build Coastguard Worker key = key.contiguous()
1770*da0073e9SAndroid Build Coastguard Worker value = value.contiguous()
1771*da0073e9SAndroid Build Coastguard Worker
1772*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[SDPBackend.MATH]):
1773*da0073e9SAndroid Build Coastguard Worker assert gradcheck(lambda *args, **kwargs:
1774*da0073e9SAndroid Build Coastguard Worker wrapper_set_seed(torch.nn.functional.scaled_dot_product_attention, *args, **kwargs),
1775*da0073e9SAndroid Build Coastguard Worker (query, key, value, None, 0.0, False)
1776*da0073e9SAndroid Build Coastguard Worker )
1777*da0073e9SAndroid Build Coastguard Worker
1778*da0073e9SAndroid Build Coastguard Worker @onlyCPU
1779*da0073e9SAndroid Build Coastguard Worker @parametrize("type", ["dense", "nested"])
1780*da0073e9SAndroid Build Coastguard Worker @parametrize("dropout", [0.0, 0.7])
1781*da0073e9SAndroid Build Coastguard Worker @parametrize("dtype", [torch.float64, torch.float32, torch.bfloat16, torch.half])
1782*da0073e9SAndroid Build Coastguard Worker def test_fused_sdp_choice_cpu(self, device, type: str, dropout: float, dtype: torch.dtype): 1783*da0073e9SAndroid Build Coastguard Worker # Test that cpu and nestedtensor cpu return MATH backend 1784*da0073e9SAndroid Build Coastguard Worker make_tensor = partial(rand_sdpa_tensor, type=type, device=device, dtype=dtype) 1785*da0073e9SAndroid Build Coastguard Worker size = SdpaShape(2, 8, 128, 64) 1786*da0073e9SAndroid Build Coastguard Worker q, k, v = make_tensor(size), make_tensor(size), make_tensor(size) 1787*da0073e9SAndroid Build Coastguard Worker if type == "nested" \ 1788*da0073e9SAndroid Build Coastguard Worker or dropout > 0.0 \ 1789*da0073e9SAndroid Build Coastguard Worker or dtype not in [torch.float32, torch.float64, torch.bfloat16, torch.float16]: 1790*da0073e9SAndroid Build Coastguard Worker assert torch._fused_sdp_choice(q, k, v, dropout_p=dropout) == SDPBackend.MATH.value 1791*da0073e9SAndroid Build Coastguard Worker else: 1792*da0073e9SAndroid Build Coastguard Worker assert torch._fused_sdp_choice(q, k, v, dropout_p=dropout) == SDPBackend.FLASH_ATTENTION.value 1793*da0073e9SAndroid Build Coastguard Worker 1794*da0073e9SAndroid Build Coastguard Worker @onlyCPU 1795*da0073e9SAndroid Build Coastguard Worker @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION]) 1796*da0073e9SAndroid Build Coastguard Worker @parametrize("dtype", [torch.float64, torch.float32, torch.bfloat16, torch.float16]) 1797*da0073e9SAndroid Build Coastguard Worker @parametrize("batch_size", [2, 12]) 1798*da0073e9SAndroid Build Coastguard Worker @parametrize("seq_len", [267, 1030]) 1799*da0073e9SAndroid Build Coastguard Worker @parametrize("n_head", [1, 3]) 1800*da0073e9SAndroid Build Coastguard Worker @parametrize("head_dim", [8, 16]) 1801*da0073e9SAndroid Build Coastguard Worker @parametrize("causal", [True, False]) 1802*da0073e9SAndroid Build Coastguard Worker @parametrize("train", [True, False]) 1803*da0073e9SAndroid Build Coastguard Worker def test_scaled_dot_product_fused_attention_vs_math_cpu( 1804*da0073e9SAndroid Build Coastguard Worker self, 1805*da0073e9SAndroid Build Coastguard Worker device, 1806*da0073e9SAndroid Build Coastguard Worker fused_kernel, 1807*da0073e9SAndroid Build Coastguard Worker dtype, 1808*da0073e9SAndroid Build Coastguard Worker batch_size, 1809*da0073e9SAndroid Build Coastguard Worker seq_len, 1810*da0073e9SAndroid Build Coastguard Worker n_head, 1811*da0073e9SAndroid Build Coastguard Worker head_dim, 1812*da0073e9SAndroid Build Coastguard Worker causal, 1813*da0073e9SAndroid Build Coastguard Worker train, 1814*da0073e9SAndroid Build Coastguard Worker ): 1815*da0073e9SAndroid Build Coastguard Worker atol = 1e-5 1816*da0073e9SAndroid Build Coastguard Worker rtol = 5e-6 1817*da0073e9SAndroid Build Coastguard Worker if dtype is torch.bfloat16: 1818*da0073e9SAndroid Build Coastguard Worker atol = 5e-2 1819*da0073e9SAndroid Build Coastguard Worker rtol = 5e-2 1820*da0073e9SAndroid Build Coastguard Worker if dtype is torch.float16: 1821*da0073e9SAndroid Build Coastguard Worker atol = 1e-2 1822*da0073e9SAndroid Build Coastguard Worker rtol = 1e-2 1823*da0073e9SAndroid Build Coastguard Worker 1824*da0073e9SAndroid Build Coastguard Worker n_embd = n_head * head_dim 1825*da0073e9SAndroid Build Coastguard Worker make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=dtype, packed=True, requires_grad=False) 1826*da0073e9SAndroid Build Coastguard Worker shape = SdpaShape(batch_size, n_head, seq_len, head_dim) 
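        # With packed=True the helper returns one (batch, seq_len, 3 * n_embd) tensor per call;
        # q, k and v are recovered from it with the split(n_embd, dim=2) calls below, both for
        # the fused input and for its math-reference clone.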
1827*da0073e9SAndroid Build Coastguard Worker x = make_tensor(shape) 1828*da0073e9SAndroid Build Coastguard Worker x2 = x.clone() 1829*da0073e9SAndroid Build Coastguard Worker 1830*da0073e9SAndroid Build Coastguard Worker if train: 1831*da0073e9SAndroid Build Coastguard Worker x.requires_grad_(True) 1832*da0073e9SAndroid Build Coastguard Worker x2.requires_grad_(True) 1833*da0073e9SAndroid Build Coastguard Worker 1834*da0073e9SAndroid Build Coastguard Worker q, k, v = x.split(n_embd, dim=2) 1835*da0073e9SAndroid Build Coastguard Worker q2, k2, v2 = x2.split(n_embd, dim=2) 1836*da0073e9SAndroid Build Coastguard Worker 1837*da0073e9SAndroid Build Coastguard Worker if dtype in [torch.bfloat16, torch.float16]: 1838*da0073e9SAndroid Build Coastguard Worker q2 = q2.float() 1839*da0073e9SAndroid Build Coastguard Worker k2 = k2.float() 1840*da0073e9SAndroid Build Coastguard Worker v2 = v2.float() 1841*da0073e9SAndroid Build Coastguard Worker 1842*da0073e9SAndroid Build Coastguard Worker # (B, nh, T, hs) 1843*da0073e9SAndroid Build Coastguard Worker k = k.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2) 1844*da0073e9SAndroid Build Coastguard Worker q = q.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2) 1845*da0073e9SAndroid Build Coastguard Worker v = v.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2) 1846*da0073e9SAndroid Build Coastguard Worker k2 = k2.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2) 1847*da0073e9SAndroid Build Coastguard Worker q2 = q2.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2) 1848*da0073e9SAndroid Build Coastguard Worker v2 = v2.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2) 1849*da0073e9SAndroid Build Coastguard Worker 1850*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[fused_kernel]): 1851*da0073e9SAndroid Build Coastguard Worker actual = torch.nn.functional.scaled_dot_product_attention( 1852*da0073e9SAndroid Build Coastguard Worker q, k, v, attn_mask=None, dropout_p=0.0, is_causal=causal) 1853*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[SDPBackend.MATH]): 1854*da0073e9SAndroid Build Coastguard Worker math_ref = torch.nn.functional.scaled_dot_product_attention( 1855*da0073e9SAndroid Build Coastguard Worker q2, k2, v2, attn_mask=None, dropout_p=0.0, is_causal=causal) 1856*da0073e9SAndroid Build Coastguard Worker 1857*da0073e9SAndroid Build Coastguard Worker if dtype in [torch.bfloat16, torch.float16]: 1858*da0073e9SAndroid Build Coastguard Worker math_ref = math_ref.to(dtype) 1859*da0073e9SAndroid Build Coastguard Worker 1860*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, math_ref, atol=atol, rtol=rtol) 1861*da0073e9SAndroid Build Coastguard Worker 1862*da0073e9SAndroid Build Coastguard Worker if train: 1863*da0073e9SAndroid Build Coastguard Worker actual.sum().backward() 1864*da0073e9SAndroid Build Coastguard Worker math_ref.sum().backward() 1865*da0073e9SAndroid Build Coastguard Worker 1866*da0073e9SAndroid Build Coastguard Worker grad_x, grad_x2 = x.grad, x2.grad 1867*da0073e9SAndroid Build Coastguard Worker grad_q_actual, grad_k_actual, grad_v_actual = grad_x.split(n_embd, dim=2) 1868*da0073e9SAndroid Build Coastguard Worker grad_q_ref, grad_k_ref, grad_v_ref = grad_x2.split(n_embd, dim=2) 1869*da0073e9SAndroid Build Coastguard Worker 1870*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grad_q_actual, grad_q_ref, atol=atol, rtol=rtol) 1871*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grad_k_actual, grad_k_ref, 
atol=atol, rtol=rtol) 1872*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grad_v_actual, grad_v_ref, atol=atol, rtol=rtol) 1873*da0073e9SAndroid Build Coastguard Worker 1874*da0073e9SAndroid Build Coastguard Worker @onlyCPU 1875*da0073e9SAndroid Build Coastguard Worker @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION]) 1876*da0073e9SAndroid Build Coastguard Worker @parametrize("dtype", [torch.float64, torch.float32, torch.bfloat16, torch.float16]) 1877*da0073e9SAndroid Build Coastguard Worker @parametrize("batch_size", [2, 12]) 1878*da0073e9SAndroid Build Coastguard Worker @parametrize("q_seq_len", [267, 1030]) 1879*da0073e9SAndroid Build Coastguard Worker @parametrize("kv_seq_len", [514, 1179]) 1880*da0073e9SAndroid Build Coastguard Worker @parametrize("n_head", [1, 3]) 1881*da0073e9SAndroid Build Coastguard Worker @parametrize("head_dim", [8, 16]) 1882*da0073e9SAndroid Build Coastguard Worker @parametrize("mask_dim", [2, 4]) 1883*da0073e9SAndroid Build Coastguard Worker @parametrize("bool_mask", [0, 1]) 1884*da0073e9SAndroid Build Coastguard Worker @parametrize("train", [True, False]) 1885*da0073e9SAndroid Build Coastguard Worker def test_scaled_dot_product_fused_attention_mask_vs_math_cpu( 1886*da0073e9SAndroid Build Coastguard Worker self, 1887*da0073e9SAndroid Build Coastguard Worker device, 1888*da0073e9SAndroid Build Coastguard Worker fused_kernel, 1889*da0073e9SAndroid Build Coastguard Worker dtype, 1890*da0073e9SAndroid Build Coastguard Worker batch_size, 1891*da0073e9SAndroid Build Coastguard Worker q_seq_len, 1892*da0073e9SAndroid Build Coastguard Worker kv_seq_len, 1893*da0073e9SAndroid Build Coastguard Worker n_head, 1894*da0073e9SAndroid Build Coastguard Worker head_dim, 1895*da0073e9SAndroid Build Coastguard Worker mask_dim, 1896*da0073e9SAndroid Build Coastguard Worker bool_mask, 1897*da0073e9SAndroid Build Coastguard Worker train, 1898*da0073e9SAndroid Build Coastguard Worker ): 1899*da0073e9SAndroid Build Coastguard Worker tol = Tolerances(1e-5, 5e-6) 1900*da0073e9SAndroid Build Coastguard Worker if dtype is torch.bfloat16: 1901*da0073e9SAndroid Build Coastguard Worker tol = Tolerances(5e-2, 5e-2) 1902*da0073e9SAndroid Build Coastguard Worker if dtype is torch.float16: 1903*da0073e9SAndroid Build Coastguard Worker tol = Tolerances(1e-2, 1e-2) 1904*da0073e9SAndroid Build Coastguard Worker 1905*da0073e9SAndroid Build Coastguard Worker make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=dtype, requires_grad=False) 1906*da0073e9SAndroid Build Coastguard Worker q_shape = SdpaShape(batch_size, n_head, q_seq_len, head_dim) 1907*da0073e9SAndroid Build Coastguard Worker kv_shape = SdpaShape(batch_size, n_head, kv_seq_len, head_dim) 1908*da0073e9SAndroid Build Coastguard Worker q = make_tensor(q_shape) 1909*da0073e9SAndroid Build Coastguard Worker k = make_tensor(kv_shape) 1910*da0073e9SAndroid Build Coastguard Worker v = make_tensor(kv_shape) 1911*da0073e9SAndroid Build Coastguard Worker q2, k2, v2 = q.clone(), k.clone(), v.clone() 1912*da0073e9SAndroid Build Coastguard Worker 1913*da0073e9SAndroid Build Coastguard Worker if train: 1914*da0073e9SAndroid Build Coastguard Worker q.requires_grad_(True) 1915*da0073e9SAndroid Build Coastguard Worker k.requires_grad_(True) 1916*da0073e9SAndroid Build Coastguard Worker v.requires_grad_(True) 1917*da0073e9SAndroid Build Coastguard Worker q2.requires_grad_(True) 1918*da0073e9SAndroid Build Coastguard Worker k2.requires_grad_(True) 1919*da0073e9SAndroid Build Coastguard Worker 
v2.requires_grad_(True) 1920*da0073e9SAndroid Build Coastguard Worker 1921*da0073e9SAndroid Build Coastguard Worker if dtype in [torch.bfloat16, torch.float16]: 1922*da0073e9SAndroid Build Coastguard Worker q2, k2, v2 = q2.float(), k2.float(), v2.float() 1923*da0073e9SAndroid Build Coastguard Worker # (B, nh, T, hs) 1924*da0073e9SAndroid Build Coastguard Worker q = q.view(batch_size, q_seq_len, n_head, head_dim).transpose(1, 2) 1925*da0073e9SAndroid Build Coastguard Worker k = k.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2) 1926*da0073e9SAndroid Build Coastguard Worker v = v.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2) 1927*da0073e9SAndroid Build Coastguard Worker if mask_dim == 4: 1928*da0073e9SAndroid Build Coastguard Worker mask_shape = (batch_size, n_head, q_seq_len, kv_seq_len) 1929*da0073e9SAndroid Build Coastguard Worker else: 1930*da0073e9SAndroid Build Coastguard Worker mask_shape = (q_seq_len, kv_seq_len) 1931*da0073e9SAndroid Build Coastguard Worker if bool_mask: 1932*da0073e9SAndroid Build Coastguard Worker attn_mask = torch.randint(0, 2, size=mask_shape, dtype=torch.bool, device=device) 1933*da0073e9SAndroid Build Coastguard Worker else: 1934*da0073e9SAndroid Build Coastguard Worker attn_mask = torch.randn(mask_shape, dtype=dtype, device=device) 1935*da0073e9SAndroid Build Coastguard Worker q2 = q2.view(batch_size, q_seq_len, n_head, head_dim).transpose(1, 2) 1936*da0073e9SAndroid Build Coastguard Worker k2 = k2.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2) 1937*da0073e9SAndroid Build Coastguard Worker v2 = v2.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2) 1938*da0073e9SAndroid Build Coastguard Worker 1939*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[fused_kernel]): 1940*da0073e9SAndroid Build Coastguard Worker actual = torch.nn.functional.scaled_dot_product_attention( 1941*da0073e9SAndroid Build Coastguard Worker q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False) 1942*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[SDPBackend.MATH]): 1943*da0073e9SAndroid Build Coastguard Worker if not bool_mask and dtype in [torch.bfloat16, torch.float16]: 1944*da0073e9SAndroid Build Coastguard Worker attn_mask = attn_mask.float() 1945*da0073e9SAndroid Build Coastguard Worker math_ref = torch.nn.functional.scaled_dot_product_attention( 1946*da0073e9SAndroid Build Coastguard Worker q2, k2, v2, attn_mask=attn_mask, dropout_p=0.0, is_causal=False) 1947*da0073e9SAndroid Build Coastguard Worker 1948*da0073e9SAndroid Build Coastguard Worker if dtype in [torch.bfloat16, torch.float16]: 1949*da0073e9SAndroid Build Coastguard Worker math_ref = math_ref.to(dtype) 1950*da0073e9SAndroid Build Coastguard Worker 1951*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, math_ref, atol=tol.atol, rtol=tol.rtol) 1952*da0073e9SAndroid Build Coastguard Worker 1953*da0073e9SAndroid Build Coastguard Worker if train: 1954*da0073e9SAndroid Build Coastguard Worker actual.sum().backward() 1955*da0073e9SAndroid Build Coastguard Worker math_ref.sum().backward() 1956*da0073e9SAndroid Build Coastguard Worker 1957*da0073e9SAndroid Build Coastguard Worker grad_q_actual, grad_k_actual, grad_v_actual = q.grad, k.grad, v.grad 1958*da0073e9SAndroid Build Coastguard Worker grad_q_ref, grad_k_ref, grad_v_ref = q2.grad, k2.grad, v2.grad 1959*da0073e9SAndroid Build Coastguard Worker 1960*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grad_q_actual, grad_q_ref, atol=tol.atol, 
rtol=tol.rtol)
1961*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grad_k_actual, grad_k_ref, atol=tol.atol, rtol=tol.rtol)
1962*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grad_v_actual, grad_v_ref, atol=tol.atol, rtol=tol.rtol)
1963*da0073e9SAndroid Build Coastguard Worker
1964*da0073e9SAndroid Build Coastguard Worker @onlyCPU
1965*da0073e9SAndroid Build Coastguard Worker def test_scaled_dot_product_fused_attention_with_inf(self, device):
1966*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/127055.
1967*da0073e9SAndroid Build Coastguard Worker full = torch.full((600, 600), float("-inf"), device=device)
1968*da0073e9SAndroid Build Coastguard Worker mask = torch.triu(full, diagonal=1) + torch.tril(full, diagonal=-10)
1969*da0073e9SAndroid Build Coastguard Worker make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=torch.float32, requires_grad=False)
1970*da0073e9SAndroid Build Coastguard Worker input_shape = SdpaShape(1, 600, 2, 8)
1971*da0073e9SAndroid Build Coastguard Worker q = make_tensor(input_shape)
1972*da0073e9SAndroid Build Coastguard Worker k = make_tensor(input_shape)
1973*da0073e9SAndroid Build Coastguard Worker v = make_tensor(input_shape)
1974*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[SDPBackend.MATH]):
1975*da0073e9SAndroid Build Coastguard Worker math_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
1976*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
1977*da0073e9SAndroid Build Coastguard Worker actual = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
1978*da0073e9SAndroid Build Coastguard Worker self.assertEqual(math_ref, actual)
1979*da0073e9SAndroid Build Coastguard Worker
1980*da0073e9SAndroid Build Coastguard Worker @parametrize("kernel", [SDPBackend.MATH])
1981*da0073e9SAndroid Build Coastguard Worker def test_scaled_dot_product_attention_math_with_negative_scale(self, device, kernel: SDPBackend):
1982*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/105190.
1983*da0073e9SAndroid Build Coastguard Worker def ref(x):
1984*da0073e9SAndroid Build Coastguard Worker v1 = torch.matmul(x, x.transpose(-1, -2))
1985*da0073e9SAndroid Build Coastguard Worker v2 = v1 / -0.0001
1986*da0073e9SAndroid Build Coastguard Worker v3 = v2.softmax(dim=-1)
1987*da0073e9SAndroid Build Coastguard Worker v4 = torch.matmul(v3, x)
1988*da0073e9SAndroid Build Coastguard Worker return v4
1989*da0073e9SAndroid Build Coastguard Worker
1990*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 3, 64, 64, device=device)
1991*da0073e9SAndroid Build Coastguard Worker ref_result = ref(x)
1992*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[kernel]):
1993*da0073e9SAndroid Build Coastguard Worker sdp_math = torch.nn.functional.scaled_dot_product_attention(x, x, x, scale=-1.0 / 0.0001)
1994*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref_result, sdp_math)
1995*da0073e9SAndroid Build Coastguard Worker
1996*da0073e9SAndroid Build Coastguard Workerclass TestSDPACudaOnly(NNTestCase):
1997*da0073e9SAndroid Build Coastguard Worker """ Used to test CUDA only functionality of scaled_dot_product_attention
1998*da0073e9SAndroid Build Coastguard Worker Quirks:
1999*da0073e9SAndroid Build Coastguard Worker There is some trickiness with this function.
Its runtime behavior 2000*da0073e9SAndroid Build Coastguard Worker is dependent on the CUDA architecture you are testing it on. See 2001*da0073e9SAndroid Build Coastguard Worker `PLATFORM_SUPPORTS_FUSED_ATTENTION` at the top of the file. 2002*da0073e9SAndroid Build Coastguard Worker Summary: 2003*da0073e9SAndroid Build Coastguard Worker Math: always supported 2004*da0073e9SAndroid Build Coastguard Worker FlashAttention: Supported on sm80 or newer hardware 2005*da0073e9SAndroid Build Coastguard Worker MemEfficientAttention: Supported on sm50 or newer hardware 2006*da0073e9SAndroid Build Coastguard Worker """ 2007*da0073e9SAndroid Build Coastguard Worker _do_cuda_memory_leak_check = True 2008*da0073e9SAndroid Build Coastguard Worker _do_cuda_non_default_stream = True 2009*da0073e9SAndroid Build Coastguard Worker 2010*da0073e9SAndroid Build Coastguard Worker # TODO USED FOR TESTING THE SCORES, e.g. testing ALIBI we don't need this now 2011*da0073e9SAndroid Build Coastguard Worker def normalize_flash_attn_S( 2012*da0073e9SAndroid Build Coastguard Worker self, 2013*da0073e9SAndroid Build Coastguard Worker attn_unnorm, 2014*da0073e9SAndroid Build Coastguard Worker q, 2015*da0073e9SAndroid Build Coastguard Worker k, 2016*da0073e9SAndroid Build Coastguard Worker v, 2017*da0073e9SAndroid Build Coastguard Worker query_padding_mask=None, 2018*da0073e9SAndroid Build Coastguard Worker key_padding_mask=None, 2019*da0073e9SAndroid Build Coastguard Worker attn_bias=None, 2020*da0073e9SAndroid Build Coastguard Worker is_dropout=False, 2021*da0073e9SAndroid Build Coastguard Worker causal=False, 2022*da0073e9SAndroid Build Coastguard Worker window_size=(-1, -1), # -1 means infinite window size 2023*da0073e9SAndroid Build Coastguard Worker scale=None, 2024*da0073e9SAndroid Build Coastguard Worker ): 2025*da0073e9SAndroid Build Coastguard Worker """ 2026*da0073e9SAndroid Build Coastguard Worker Arguments: 2027*da0073e9SAndroid Build Coastguard Worker q: (batch_size, seqlen_q, nheads, head_dim) 2028*da0073e9SAndroid Build Coastguard Worker k, v: (batch_size, seqlen_k, nheads, head_dim) 2029*da0073e9SAndroid Build Coastguard Worker key_padding_mask: (batch_size, seqlen_q) 2030*da0073e9SAndroid Build Coastguard Worker attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) 2031*da0073e9SAndroid Build Coastguard Worker Output: 2032*da0073e9SAndroid Build Coastguard Worker softmax_lse: (batch_size, nheads, seqlen_q) 2033*da0073e9SAndroid Build Coastguard Worker softmax_max: (batch_size, nheads, seqlen_q) 2034*da0073e9SAndroid Build Coastguard Worker """ 2035*da0073e9SAndroid Build Coastguard Worker q = q.transpose(1, 2) 2036*da0073e9SAndroid Build Coastguard Worker k = k.transpose(1, 2) 2037*da0073e9SAndroid Build Coastguard Worker v = v.transpose(1, 2) 2038*da0073e9SAndroid Build Coastguard Worker if causal: 2039*da0073e9SAndroid Build Coastguard Worker window_size = (window_size[0], 0) 2040*da0073e9SAndroid Build Coastguard Worker q, k, v = q.float(), k.float(), v.float() 2041*da0073e9SAndroid Build Coastguard Worker _, seqlen_q, _, head_dim = q.shape 2042*da0073e9SAndroid Build Coastguard Worker seqlen_k = k.shape[1] 2043*da0073e9SAndroid Build Coastguard Worker b = q.shape[0] 2044*da0073e9SAndroid Build Coastguard Worker from torch.nn.attention.bias import _calculate_scale 2045*da0073e9SAndroid Build Coastguard Worker scale = _calculate_scale(head_dim, scale) 2046*da0073e9SAndroid Build Coastguard Worker scores = torch.matmul(q.transpose(1, 2) * scale, k.permute(0, 2, 3, 1)) 2047*da0073e9SAndroid 
Build Coastguard Worker if key_padding_mask is not None: 2048*da0073e9SAndroid Build Coastguard Worker scores.masked_fill_(~key_padding_mask.view(b, 1, 1, -1), float("-inf")) 2049*da0073e9SAndroid Build Coastguard Worker if window_size[0] >= 0 or window_size[1] >= 0: 2050*da0073e9SAndroid Build Coastguard Worker local_mask = self.construct_local_mask( 2051*da0073e9SAndroid Build Coastguard Worker seqlen_q, 2052*da0073e9SAndroid Build Coastguard Worker seqlen_k, 2053*da0073e9SAndroid Build Coastguard Worker window_size, 2054*da0073e9SAndroid Build Coastguard Worker query_padding_mask, 2055*da0073e9SAndroid Build Coastguard Worker key_padding_mask, 2056*da0073e9SAndroid Build Coastguard Worker q.device, 2057*da0073e9SAndroid Build Coastguard Worker ) 2058*da0073e9SAndroid Build Coastguard Worker scores.masked_fill_(local_mask, float("-inf")) 2059*da0073e9SAndroid Build Coastguard Worker if attn_bias is not None: 2060*da0073e9SAndroid Build Coastguard Worker scores = scores + attn_bias.to(dtype=scores.dtype) 2061*da0073e9SAndroid Build Coastguard Worker block_size_n = _get_block_size_n(scores.device, head_dim, is_dropout, causal) 2062*da0073e9SAndroid Build Coastguard Worker scores_block = scores.split(block_size_n, dim=-1) 2063*da0073e9SAndroid Build Coastguard Worker lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1) 2064*da0073e9SAndroid Build Coastguard Worker lse = torch.logsumexp(lse_block, dim=-1) 2065*da0073e9SAndroid Build Coastguard Worker # lse could be -inf (i.e. all values in scores are -inf), and we want to set those to inf 2066*da0073e9SAndroid Build Coastguard Worker # so that when we do torch.exp(m - lse), we get 0.0 instead of NaN. 2067*da0073e9SAndroid Build Coastguard Worker lse[lse == float("-inf")] = float("inf") 2068*da0073e9SAndroid Build Coastguard Worker scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1) 2069*da0073e9SAndroid Build Coastguard Worker cummax_block = torch.cummax(scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1) 2070*da0073e9SAndroid Build Coastguard Worker attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1) 2071*da0073e9SAndroid Build Coastguard Worker attn_norm = torch.cat( 2072*da0073e9SAndroid Build Coastguard Worker [ 2073*da0073e9SAndroid Build Coastguard Worker a * (torch.exp(m - lse)).unsqueeze(-1) 2074*da0073e9SAndroid Build Coastguard Worker for a, m in zip(attn_unnorm_block, cummax_block) 2075*da0073e9SAndroid Build Coastguard Worker ], 2076*da0073e9SAndroid Build Coastguard Worker dim=-1, 2077*da0073e9SAndroid Build Coastguard Worker ) 2078*da0073e9SAndroid Build Coastguard Worker if query_padding_mask is not None: 2079*da0073e9SAndroid Build Coastguard Worker attn_norm.masked_fill_(~query_padding_mask.view(b, 1, -1, 1), 0.0) 2080*da0073e9SAndroid Build Coastguard Worker # attn_norm.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) 2081*da0073e9SAndroid Build Coastguard Worker return attn_norm.to(dtype=attn_unnorm.dtype) 2082*da0073e9SAndroid Build Coastguard Worker 2083*da0073e9SAndroid Build Coastguard Worker def construct_local_mask(self, seqlen_q, seqlen_k, window_size, query_padding_mask, key_padding_mask, device): 2084*da0073e9SAndroid Build Coastguard Worker # row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") 2085*da0073e9SAndroid Build Coastguard Worker row_idx = torch.arange(seqlen_q, device=device, dtype=torch.long).view(-1, 1) 2086*da0073e9SAndroid Build Coastguard Worker col_idx = 
torch.arange(seqlen_k, device=device, dtype=torch.long) 2087*da0073e9SAndroid Build Coastguard Worker sk = ( 2088*da0073e9SAndroid Build Coastguard Worker seqlen_k 2089*da0073e9SAndroid Build Coastguard Worker if key_padding_mask is None 2090*da0073e9SAndroid Build Coastguard Worker else key_padding_mask.sum(-1).view(-1, 1, 1, 1) 2091*da0073e9SAndroid Build Coastguard Worker # else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") 2092*da0073e9SAndroid Build Coastguard Worker ) 2093*da0073e9SAndroid Build Coastguard Worker sq = ( 2094*da0073e9SAndroid Build Coastguard Worker seqlen_q 2095*da0073e9SAndroid Build Coastguard Worker if query_padding_mask is None 2096*da0073e9SAndroid Build Coastguard Worker else query_padding_mask.sum(-1).view(-1, 1, 1, 1) 2097*da0073e9SAndroid Build Coastguard Worker # else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") 2098*da0073e9SAndroid Build Coastguard Worker ) 2099*da0073e9SAndroid Build Coastguard Worker if window_size[0] < 0: 2100*da0073e9SAndroid Build Coastguard Worker return col_idx > row_idx + sk - sq + window_size[1] 2101*da0073e9SAndroid Build Coastguard Worker else: 2102*da0073e9SAndroid Build Coastguard Worker sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk 2103*da0073e9SAndroid Build Coastguard Worker return torch.logical_or( 2104*da0073e9SAndroid Build Coastguard Worker col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), 2105*da0073e9SAndroid Build Coastguard Worker col_idx < row_idx + sk - sq - window_size[0], 2106*da0073e9SAndroid Build Coastguard Worker ) 2107*da0073e9SAndroid Build Coastguard Worker 2108*da0073e9SAndroid Build Coastguard Worker def convert_flash_attn_S_to_softmax( 2109*da0073e9SAndroid Build Coastguard Worker self, 2110*da0073e9SAndroid Build Coastguard Worker S, 2111*da0073e9SAndroid Build Coastguard Worker seqlen_q, 2112*da0073e9SAndroid Build Coastguard Worker seqlen_k, 2113*da0073e9SAndroid Build Coastguard Worker query_padding_mask, 2114*da0073e9SAndroid Build Coastguard Worker key_padding_mask, 2115*da0073e9SAndroid Build Coastguard Worker causal=False, 2116*da0073e9SAndroid Build Coastguard Worker window_size=(-1, -1), # -1 means infinite window size 2117*da0073e9SAndroid Build Coastguard Worker ): 2118*da0073e9SAndroid Build Coastguard Worker """FlashAttention stores the S matrix in a different way. 
2119*da0073e9SAndroid Build Coastguard Worker Arguments: 2120*da0073e9SAndroid Build Coastguard Worker S: (batch_size, nheads, seqlen_q, seqlen_k) 2121*da0073e9SAndroid Build Coastguard Worker query_padding_mask: (batch_size, seqlen_q) 2122*da0073e9SAndroid Build Coastguard Worker key_padding_mask: (batch_size, seqlen_k) 2123*da0073e9SAndroid Build Coastguard Worker """ 2124*da0073e9SAndroid Build Coastguard Worker if TEST_WITH_ROCM: 2125*da0073e9SAndroid Build Coastguard Worker return S 2126*da0073e9SAndroid Build Coastguard Worker b = S.shape[0] 2127*da0073e9SAndroid Build Coastguard Worker 2128*da0073e9SAndroid Build Coastguard Worker if causal: 2129*da0073e9SAndroid Build Coastguard Worker window_size = (window_size[0], 0) 2130*da0073e9SAndroid Build Coastguard Worker seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:] 2131*da0073e9SAndroid Build Coastguard Worker S_converted = S 2132*da0073e9SAndroid Build Coastguard Worker if window_size[0] >= 0 or window_size[1] >= 0: 2133*da0073e9SAndroid Build Coastguard Worker local_mask = self.construct_local_mask( 2134*da0073e9SAndroid Build Coastguard Worker seqlen_q, 2135*da0073e9SAndroid Build Coastguard Worker seqlen_k, 2136*da0073e9SAndroid Build Coastguard Worker window_size, 2137*da0073e9SAndroid Build Coastguard Worker query_padding_mask, 2138*da0073e9SAndroid Build Coastguard Worker key_padding_mask, 2139*da0073e9SAndroid Build Coastguard Worker S.device, 2140*da0073e9SAndroid Build Coastguard Worker ) 2141*da0073e9SAndroid Build Coastguard Worker local_mask = F.pad( 2142*da0073e9SAndroid Build Coastguard Worker local_mask, 2143*da0073e9SAndroid Build Coastguard Worker (0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q), 2144*da0073e9SAndroid Build Coastguard Worker value=True, 2145*da0073e9SAndroid Build Coastguard Worker ) 2146*da0073e9SAndroid Build Coastguard Worker S_converted = S_converted.masked_fill(local_mask, 0.0) 2147*da0073e9SAndroid Build Coastguard Worker 2148*da0073e9SAndroid Build Coastguard Worker # Need to zero out things not in attention_mask in case S was initialized with random values 2149*da0073e9SAndroid Build Coastguard Worker # and some of those values aren't overwritten. 
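        # Concretely: the kernel stores S with seqlen_q/seqlen_k rounded up (presumably to a
        # multiple of its block size), so with e.g. seqlen_k = 14 the rounded-up columns
        # 14..seqlen_k_rounded-1 were never written by the kernel and may hold arbitrary values.
        # That is why the padded rows/columns are zeroed below before slicing back to
        # [:seqlen_q, :seqlen_k].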
2150*da0073e9SAndroid Build Coastguard Worker seqlen_q_og = ( 2151*da0073e9SAndroid Build Coastguard Worker query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded 2152*da0073e9SAndroid Build Coastguard Worker ) 2153*da0073e9SAndroid Build Coastguard Worker if query_padding_mask is not None: 2154*da0073e9SAndroid Build Coastguard Worker query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q_rounded - seqlen_q_og)) 2155*da0073e9SAndroid Build Coastguard Worker # S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) 2156*da0073e9SAndroid Build Coastguard Worker S_converted = S_converted.masked_fill(~query_padding_mask.view(b, 1, -1, 1), 0.0) 2157*da0073e9SAndroid Build Coastguard Worker seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k 2158*da0073e9SAndroid Build Coastguard Worker if key_padding_mask is not None: 2159*da0073e9SAndroid Build Coastguard Worker key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k_rounded - seqlen_k_og)) 2160*da0073e9SAndroid Build Coastguard Worker S_converted = S_converted.masked_fill(~key_padding_mask.view(b, 1, 1, -1), 0.0) 2161*da0073e9SAndroid Build Coastguard Worker # S_converted = S_converted.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0) 2162*da0073e9SAndroid Build Coastguard Worker S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q_rounded)) 2163*da0073e9SAndroid Build Coastguard Worker S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded)) 2164*da0073e9SAndroid Build Coastguard Worker return S_converted[:, :, :seqlen_q, :seqlen_k] 2165*da0073e9SAndroid Build Coastguard Worker 2166*da0073e9SAndroid Build Coastguard Worker @skipIfRocm # No cuDNN Attention 2167*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system") 2168*da0073e9SAndroid Build Coastguard Worker def test_cudnn_attention_different_dk_dv(self, device): 2169*da0073e9SAndroid Build Coastguard Worker dtype = torch.bfloat16 2170*da0073e9SAndroid Build Coastguard Worker make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True) 2171*da0073e9SAndroid Build Coastguard Worker batch, num_heads, head_dim_k, head_dim_v = 32, 16, 128, 64 2172*da0073e9SAndroid Build Coastguard Worker seq_len = 640 2173*da0073e9SAndroid Build Coastguard Worker q_shape = SdpaShape(batch, num_heads, seq_len, head_dim_k) 2174*da0073e9SAndroid Build Coastguard Worker k_shape = SdpaShape(batch, num_heads, seq_len, head_dim_k) 2175*da0073e9SAndroid Build Coastguard Worker v_shape = SdpaShape(batch, num_heads, seq_len, head_dim_v) 2176*da0073e9SAndroid Build Coastguard Worker query, key, value = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape) 2177*da0073e9SAndroid Build Coastguard Worker 2178*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]): 2179*da0073e9SAndroid Build Coastguard Worker actual = torch.nn.functional.scaled_dot_product_attention( 2180*da0073e9SAndroid Build Coastguard Worker query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) 2181*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[SDPBackend.MATH]): 2182*da0073e9SAndroid Build Coastguard Worker math_ref = torch.nn.functional.scaled_dot_product_attention( 2183*da0073e9SAndroid Build Coastguard Worker query.contiguous().to(torch.float32), 2184*da0073e9SAndroid Build Coastguard Worker 
key.contiguous().to(torch.float32), 2185*da0073e9SAndroid Build Coastguard Worker value.contiguous().to(torch.float32), 2186*da0073e9SAndroid Build Coastguard Worker attn_mask=None, dropout_p=0.0, is_causal=False) 2187*da0073e9SAndroid Build Coastguard Worker 2188*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2) 2189*da0073e9SAndroid Build Coastguard Worker 2190*da0073e9SAndroid Build Coastguard Worker 2191*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") 2192*da0073e9SAndroid Build Coastguard Worker @parametrize("mask_dim", [1, 2, 3, 4]) 2193*da0073e9SAndroid Build Coastguard Worker def test_mem_efficient_attention_mask_variants(self, device, mask_dim: int): 2194*da0073e9SAndroid Build Coastguard Worker dtype = torch.float16 2195*da0073e9SAndroid Build Coastguard Worker make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True) 2196*da0073e9SAndroid Build Coastguard Worker batch, num_heads, head_dim = 8, 8, 64 2197*da0073e9SAndroid Build Coastguard Worker seq_len_q, seq_len_kv = 64, 32 2198*da0073e9SAndroid Build Coastguard Worker query = make_tensor(SdpaShape(batch, num_heads, seq_len_q, head_dim)) 2199*da0073e9SAndroid Build Coastguard Worker kv_shape = SdpaShape(batch, num_heads, seq_len_kv, head_dim) 2200*da0073e9SAndroid Build Coastguard Worker key, value = make_tensor(kv_shape), make_tensor(kv_shape) 2201*da0073e9SAndroid Build Coastguard Worker 2202*da0073e9SAndroid Build Coastguard Worker if mask_dim == 1: 2203*da0073e9SAndroid Build Coastguard Worker mask = torch.randn((seq_len_kv,), device=device, dtype=dtype) 2204*da0073e9SAndroid Build Coastguard Worker elif mask_dim == 2: 2205*da0073e9SAndroid Build Coastguard Worker mask = torch.randn((seq_len_q, seq_len_kv), device=device, dtype=dtype) 2206*da0073e9SAndroid Build Coastguard Worker elif mask_dim == 3: 2207*da0073e9SAndroid Build Coastguard Worker mask = torch.randn((num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype) 2208*da0073e9SAndroid Build Coastguard Worker elif mask_dim == 4: 2209*da0073e9SAndroid Build Coastguard Worker mask = torch.randn((batch, num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype) 2210*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]): 2211*da0073e9SAndroid Build Coastguard Worker out = F.scaled_dot_product_attention(query, key, value, mask) 2212*da0073e9SAndroid Build Coastguard Worker out.sum().backward() 2213*da0073e9SAndroid Build Coastguard Worker 2214*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") 2215*da0073e9SAndroid Build Coastguard Worker @parametrize("dtype", [torch.float, torch.float16]) 2216*da0073e9SAndroid Build Coastguard Worker def test_mem_eff_attention_pad_mask(self, device, dtype): 2217*da0073e9SAndroid Build Coastguard Worker make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True) 2218*da0073e9SAndroid Build Coastguard Worker batch, num_heads, head_dim = 8, 8, 64 2219*da0073e9SAndroid Build Coastguard Worker seq_len_q, seq_len_kv = 64, 15 2220*da0073e9SAndroid Build Coastguard Worker query = make_tensor(SdpaShape(batch, num_heads, seq_len_q, head_dim)) 2221*da0073e9SAndroid Build Coastguard Worker kv_shape = SdpaShape(batch, num_heads, seq_len_kv, head_dim) 2222*da0073e9SAndroid Build
Coastguard Worker key, value = make_tensor(kv_shape), make_tensor(kv_shape) 2223*da0073e9SAndroid Build Coastguard Worker mask = torch.randn((batch, num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype) 2224*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]): 2225*da0073e9SAndroid Build Coastguard Worker out = F.scaled_dot_product_attention(query, key, value, mask) 2226*da0073e9SAndroid Build Coastguard Worker out.sum().backward() 2227*da0073e9SAndroid Build Coastguard Worker 2228*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") 2229*da0073e9SAndroid Build Coastguard Worker @parametrize("dtype", [torch.float, torch.float16]) 2230*da0073e9SAndroid Build Coastguard Worker def test_mem_eff_attention_non_contiguous_mask(self, device, dtype): 2231*da0073e9SAndroid Build Coastguard Worker make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True) 2232*da0073e9SAndroid Build Coastguard Worker batch, num_heads, head_dim = 8, 8, 64 2233*da0073e9SAndroid Build Coastguard Worker seq_len_q, seq_len_kv = 64, 16 2234*da0073e9SAndroid Build Coastguard Worker query = make_tensor(SdpaShape(batch, num_heads, seq_len_q, head_dim)) 2235*da0073e9SAndroid Build Coastguard Worker kv_shape = SdpaShape(batch, num_heads, seq_len_kv, head_dim) 2236*da0073e9SAndroid Build Coastguard Worker key, value = make_tensor(kv_shape), make_tensor(kv_shape) 2237*da0073e9SAndroid Build Coastguard Worker mask = torch.randn((batch, num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype) 2238*da0073e9SAndroid Build Coastguard Worker mask = torch.as_strided(mask, (batch, num_heads, seq_len_q, seq_len_kv), (0, 0, 0, 1)) 2239*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]): 2240*da0073e9SAndroid Build Coastguard Worker out = F.scaled_dot_product_attention(query, key, value, mask) 2241*da0073e9SAndroid Build Coastguard Worker out.sum().backward() 2242*da0073e9SAndroid Build Coastguard Worker 2243*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") 2244*da0073e9SAndroid Build Coastguard Worker @parametrize("dtype", [torch.float, torch.float16]) 2245*da0073e9SAndroid Build Coastguard Worker def test_mem_eff_attention_long_sequence_mask(self, device, dtype): 2246*da0073e9SAndroid Build Coastguard Worker if torch.cuda.get_device_properties('cuda').total_memory < 80 * 2**30: 2247*da0073e9SAndroid Build Coastguard Worker unittest.skip("This test requires substantial GPU memory.") 2248*da0073e9SAndroid Build Coastguard Worker return 2249*da0073e9SAndroid Build Coastguard Worker make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True) 2250*da0073e9SAndroid Build Coastguard Worker batch, num_heads, head_dim = 1, 32, 64 2251*da0073e9SAndroid Build Coastguard Worker seq_len_q, seq_len_kv = 8192, 8192 2252*da0073e9SAndroid Build Coastguard Worker query = make_tensor(SdpaShape(batch, num_heads, seq_len_q, head_dim)) 2253*da0073e9SAndroid Build Coastguard Worker kv_shape = SdpaShape(batch, num_heads, seq_len_kv, head_dim) 2254*da0073e9SAndroid Build Coastguard Worker key, value = make_tensor(kv_shape), make_tensor(kv_shape) 2255*da0073e9SAndroid Build Coastguard Worker mask = torch.randn((batch, num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype) 2256*da0073e9SAndroid Build Coastguard Worker
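        # Note on the size check above: the dense mask alone holds 1 * 32 * 8192 * 8192 ≈ 2.1e9
        # elements (~4 GiB in fp16, ~8 GiB in fp32), and the kernel and autograd may allocate
        # comparably sized intermediates, which is presumably why this test is gated on a
        # large-memory GPU.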
with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]): 2257*da0073e9SAndroid Build Coastguard Worker out = F.scaled_dot_product_attention(query, key, value, mask) 2258*da0073e9SAndroid Build Coastguard Worker out.sum().backward() 2259*da0073e9SAndroid Build Coastguard Worker 2260*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") 2261*da0073e9SAndroid Build Coastguard Worker def test_mem_eff_attention_non_contig_mask_bug(self, device): 2262*da0073e9SAndroid Build Coastguard Worker # Without the fix this produces `AssertionError: assert 0.07352933287620544 < 1e-07` 2263*da0073e9SAndroid Build Coastguard Worker # Shapes taken from repro 2264*da0073e9SAndroid Build Coastguard Worker query_size = (3, 16, 1, 128) 2265*da0073e9SAndroid Build Coastguard Worker query_strides = (2304, 128, 2048, 1) 2266*da0073e9SAndroid Build Coastguard Worker key_size = (3, 16, 14, 128) 2267*da0073e9SAndroid Build Coastguard Worker key_strides = (3584, 0, 256, 1) 2268*da0073e9SAndroid Build Coastguard Worker value_size = (3, 16, 14, 128) 2269*da0073e9SAndroid Build Coastguard Worker value_strides = (3584, 0, 256, 1) 2270*da0073e9SAndroid Build Coastguard Worker attention_mask_size = (3, 1, 1, 14) 2271*da0073e9SAndroid Build Coastguard Worker attn_mask_strides = (14, 14, 14, 1) 2272*da0073e9SAndroid Build Coastguard Worker 2273*da0073e9SAndroid Build Coastguard Worker # Calculate the number of elements needed for each tensor 2274*da0073e9SAndroid Build Coastguard Worker query_num_elements = max(size * stride for size, stride in zip(query_size, query_strides)) 2275*da0073e9SAndroid Build Coastguard Worker key_num_elements = max(size * stride for size, stride in zip(key_size, key_strides)) 2276*da0073e9SAndroid Build Coastguard Worker value_num_elements = max(size * stride for size, stride in zip(value_size, value_strides)) 2277*da0073e9SAndroid Build Coastguard Worker attention_mask_num_elements = max(size * stride for size, stride in zip(attention_mask_size, attn_mask_strides)) 2278*da0073e9SAndroid Build Coastguard Worker 2279*da0073e9SAndroid Build Coastguard Worker # Create the tensors with the specified sizes and strides 2280*da0073e9SAndroid Build Coastguard Worker query = torch.randn(query_num_elements, device=device).as_strided(query_size, query_strides) 2281*da0073e9SAndroid Build Coastguard Worker key = torch.randn(key_num_elements, device=device).as_strided(key_size, key_strides) 2282*da0073e9SAndroid Build Coastguard Worker value = torch.randn(value_num_elements, device=device).as_strided(value_size, value_strides) 2283*da0073e9SAndroid Build Coastguard Worker bias = torch.randn(attention_mask_num_elements, device=device).as_strided(attention_mask_size, attn_mask_strides) 2284*da0073e9SAndroid Build Coastguard Worker 2285*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]): 2286*da0073e9SAndroid Build Coastguard Worker out = F.scaled_dot_product_attention(query, key, value, bias) 2287*da0073e9SAndroid Build Coastguard Worker out_contig = F.scaled_dot_product_attention(query, key, value, bias.contiguous()) 2288*da0073e9SAndroid Build Coastguard Worker 2289*da0073e9SAndroid Build Coastguard Worker max_diff = (out - out_contig).abs().mean() 2290*da0073e9SAndroid Build Coastguard Worker self.assertTrue(max_diff.item() < 1e-7) 2291*da0073e9SAndroid Build Coastguard Worker 2292*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not 
PLATFORM_SUPPORTS_FLASH_ATTENTION, "Fused SDPA was not built for this system") 2293*da0073e9SAndroid Build Coastguard Worker def test_singelton_head_dim_stride_ne_1(self, device): 2294*da0073e9SAndroid Build Coastguard Worker query = torch.tensor([[[[1, 2]]]], dtype=torch.float16, device=device) 2295*da0073e9SAndroid Build Coastguard Worker query = query.transpose(-1, -2) 2296*da0073e9SAndroid Build Coastguard Worker key = torch.tensor([[[[1]]]], dtype=torch.float16, device=device) 2297*da0073e9SAndroid Build Coastguard Worker value = torch.tensor([[[[1]]]], dtype=torch.float16, device=device) 2298*da0073e9SAndroid Build Coastguard Worker 2299*da0073e9SAndroid Build Coastguard Worker with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False): 2300*da0073e9SAndroid Build Coastguard Worker scaled_dot_product_attention(query, key, value) 2301*da0073e9SAndroid Build Coastguard Worker 2302*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") 2303*da0073e9SAndroid Build Coastguard Worker @parametrize("type", ["dense", "nested"]) 2304*da0073e9SAndroid Build Coastguard Worker @parametrize("is_contiguous", [True, False]) 2305*da0073e9SAndroid Build Coastguard Worker def test_scaled_dot_product_attention_fused_kernels_packed(self, device, type: str, is_contiguous: bool): 2306*da0073e9SAndroid Build Coastguard Worker if TEST_WITH_ROCM and type == 'nested': 2307*da0073e9SAndroid Build Coastguard Worker self.skipTest("ROCM does not support efficient attention on nested tensors, for now") 2308*da0073e9SAndroid Build Coastguard Worker make_tensor = partial(rand_sdpa_tensor, type=type, device=device, dtype=torch.float16, packed=True) 2309*da0073e9SAndroid Build Coastguard Worker 2310*da0073e9SAndroid Build Coastguard Worker batch_size, seq_len, num_heads, head_dim = 32, 64, 16, 64 2311*da0073e9SAndroid Build Coastguard Worker shape = SdpaShape(batch_size, num_heads, seq_len, head_dim) 2312*da0073e9SAndroid Build Coastguard Worker 2313*da0073e9SAndroid Build Coastguard Worker # Test Packed 2314*da0073e9SAndroid Build Coastguard Worker qkv = make_tensor(shape) 2315*da0073e9SAndroid Build Coastguard Worker query, key, value = qkv.chunk(3, dim=-1) 2316*da0073e9SAndroid Build Coastguard Worker 2317*da0073e9SAndroid Build Coastguard Worker query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2318*da0073e9SAndroid Build Coastguard Worker value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2319*da0073e9SAndroid Build Coastguard Worker key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2320*da0073e9SAndroid Build Coastguard Worker 2321*da0073e9SAndroid Build Coastguard Worker if is_contiguous: 2322*da0073e9SAndroid Build Coastguard Worker query = query.contiguous() 2323*da0073e9SAndroid Build Coastguard Worker key = key.contiguous() 2324*da0073e9SAndroid Build Coastguard Worker value = value.contiguous() 2325*da0073e9SAndroid Build Coastguard Worker 2326*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]): 2327*da0073e9SAndroid Build Coastguard Worker actual = torch.nn.functional.scaled_dot_product_attention( 2328*da0073e9SAndroid Build Coastguard Worker query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) 2329*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[SDPBackend.MATH]): 2330*da0073e9SAndroid Build Coastguard Worker math_ref = 
torch.nn.functional.scaled_dot_product_attention( 2331*da0073e9SAndroid Build Coastguard Worker query.contiguous(), key.contiguous(), value.contiguous(), 2332*da0073e9SAndroid Build Coastguard Worker attn_mask=None, dropout_p=0.0, is_causal=False) 2333*da0073e9SAndroid Build Coastguard Worker 2334*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=2e-3, rtol=1e-2) 2335*da0073e9SAndroid Build Coastguard Worker 2336*da0073e9SAndroid Build Coastguard Worker @skipIfRocm # Missing nested and EFFICIENT_ATTENTION 2337*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system") 2338*da0073e9SAndroid Build Coastguard Worker @parametrize("type", ["dense", "nested"]) 2339*da0073e9SAndroid Build Coastguard Worker @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if 2340*da0073e9SAndroid Build Coastguard Worker PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION]) 2341*da0073e9SAndroid Build Coastguard Worker def test_scaled_dot_product_attention_fused_kernels_packed_accuracy(self, device, type: str, fused_kernel: SDPBackend): 2342*da0073e9SAndroid Build Coastguard Worker def rand_nt(shape): 2343*da0073e9SAndroid Build Coastguard Worker batch, seq_len, num_heads, head_dim = shape 2344*da0073e9SAndroid Build Coastguard Worker tensors = [6 * torch.rand((seq_len, 3 * num_heads * head_dim), device=device, dtype=torch.float32) - 3 2345*da0073e9SAndroid Build Coastguard Worker for _ in range(batch)] 2346*da0073e9SAndroid Build Coastguard Worker return (torch.nested.nested_tensor(tensors, device=device, dtype=torch.float32), 2347*da0073e9SAndroid Build Coastguard Worker torch.nested.nested_tensor(tensors, device=device, dtype=torch.float16)) 2348*da0073e9SAndroid Build Coastguard Worker 2349*da0073e9SAndroid Build Coastguard Worker def rand_tensor(shape): 2350*da0073e9SAndroid Build Coastguard Worker batch, seq_len, num_heads, head_dim = shape 2351*da0073e9SAndroid Build Coastguard Worker tensor = 6 * torch.rand((batch, seq_len, 3 * num_heads * head_dim), device=device, dtype=torch.float32) - 3 2352*da0073e9SAndroid Build Coastguard Worker return tensor, tensor.to(dtype=torch.float16) 2353*da0073e9SAndroid Build Coastguard Worker 2354*da0073e9SAndroid Build Coastguard Worker batch_size, seq_len, num_heads, head_dim = 16, 8, 4, 64 2355*da0073e9SAndroid Build Coastguard Worker shape = (batch_size, seq_len, num_heads, head_dim) 2356*da0073e9SAndroid Build Coastguard Worker 2357*da0073e9SAndroid Build Coastguard Worker # Test Packed 2358*da0073e9SAndroid Build Coastguard Worker qkv, qkv_low_precision = rand_tensor(shape) if type == "dense" else rand_nt(shape) 2359*da0073e9SAndroid Build Coastguard Worker query, key, value = qkv.chunk(3, dim=-1) 2360*da0073e9SAndroid Build Coastguard Worker query_lp, key_lp, value_lp = qkv_low_precision.chunk(3, dim=-1) 2361*da0073e9SAndroid Build Coastguard Worker 2362*da0073e9SAndroid Build Coastguard Worker query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2363*da0073e9SAndroid Build Coastguard Worker key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2364*da0073e9SAndroid Build Coastguard Worker value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2365*da0073e9SAndroid Build Coastguard Worker 2366*da0073e9SAndroid Build Coastguard Worker query_lp = query_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2367*da0073e9SAndroid Build Coastguard Worker key_lp = key_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2368*da0073e9SAndroid Build Coastguard Worker value_lp = value_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2369*da0073e9SAndroid Build Coastguard Worker 2370*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[fused_kernel]): 2371*da0073e9SAndroid Build Coastguard Worker actual = torch.nn.functional.scaled_dot_product_attention( 2372*da0073e9SAndroid Build Coastguard Worker query_lp, key_lp, value_lp, attn_mask=None, dropout_p=0.0, is_causal=False) 2373*da0073e9SAndroid Build Coastguard Worker 2374*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[SDPBackend.MATH]): 2375*da0073e9SAndroid Build Coastguard Worker math_ref_lp = torch.nn.functional.scaled_dot_product_attention( 2376*da0073e9SAndroid Build Coastguard Worker query_lp.contiguous(), key_lp.contiguous(), value_lp.contiguous(), 2377*da0073e9SAndroid Build Coastguard Worker attn_mask=None, dropout_p=0.0, is_causal=False) 2378*da0073e9SAndroid Build Coastguard Worker 2379*da0073e9SAndroid Build Coastguard Worker math_query = query.contiguous() 2380*da0073e9SAndroid Build Coastguard Worker math_key = key.contiguous() 2381*da0073e9SAndroid Build Coastguard Worker math_value = value.contiguous() 2382*da0073e9SAndroid Build Coastguard Worker 2383*da0073e9SAndroid Build Coastguard Worker math_ref = torch.nn.functional.scaled_dot_product_attention( 2384*da0073e9SAndroid Build Coastguard Worker math_query, math_key, math_value, attn_mask=None, dropout_p=0.0, is_causal=False) 2385*da0073e9SAndroid Build Coastguard Worker 2386*da0073e9SAndroid Build Coastguard Worker actual_test = actual 2387*da0073e9SAndroid Build Coastguard Worker math_ref_test = math_ref 2388*da0073e9SAndroid Build Coastguard Worker math_ref_lp_test = math_ref_lp 2389*da0073e9SAndroid Build Coastguard Worker 2390*da0073e9SAndroid Build Coastguard Worker if actual_test.is_nested: 2391*da0073e9SAndroid Build Coastguard Worker actual_test = torch.nested.to_padded_tensor(actual_test.contiguous(), padding=0.0) 2392*da0073e9SAndroid Build Coastguard Worker math_ref_test = torch.nested.to_padded_tensor(math_ref_test, padding=0.0) 2393*da0073e9SAndroid Build Coastguard Worker math_ref_lp_test = torch.nested.to_padded_tensor(math_ref_lp_test, padding=0.0) 2394*da0073e9SAndroid Build Coastguard Worker 2395*da0073e9SAndroid Build Coastguard Worker actual_test = actual_test.to(dtype=torch.float32).contiguous() 2396*da0073e9SAndroid Build Coastguard Worker math_ref_test = math_ref_test.to(dtype=torch.float32).contiguous() 2397*da0073e9SAndroid Build Coastguard Worker math_ref_lp_test = math_ref_lp_test.to(dtype=torch.float32).contiguous() 2398*da0073e9SAndroid Build Coastguard Worker 2399*da0073e9SAndroid Build Coastguard Worker self.assertEqual(math_ref_test, math_ref_lp_test, atol=7e-3, rtol=7e-3) 2400*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual_test, math_ref_test, atol=5e-3, rtol=5e-3) 2401*da0073e9SAndroid Build Coastguard Worker 2402*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Efficient Attention was not built for this system") 2403*da0073e9SAndroid Build Coastguard Worker @parametrize("contiguous_inputs", [True, False]) 2404*da0073e9SAndroid Build Coastguard Worker @parametrize("is_causal", [True, False]) 2405*da0073e9SAndroid Build Coastguard Worker def test_sdp_mem_efficient_grad_against_math(self, device, contiguous_inputs: bool, 
is_causal: bool): 2406*da0073e9SAndroid Build Coastguard Worker batch_size, seq_len, num_heads, head_dim = 4, 4, 2, 16 2407*da0073e9SAndroid Build Coastguard Worker make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, 2408*da0073e9SAndroid Build Coastguard Worker dtype=torch.float64, requires_grad=True, packed=True) 2409*da0073e9SAndroid Build Coastguard Worker 2410*da0073e9SAndroid Build Coastguard Worker qkv = make_tensor(SdpaShape(batch_size, num_heads, seq_len, head_dim)) 2411*da0073e9SAndroid Build Coastguard Worker qkv_lp = qkv.detach().clone().to(torch.float32).requires_grad_() 2412*da0073e9SAndroid Build Coastguard Worker 2413*da0073e9SAndroid Build Coastguard Worker query, key, value = qkv.chunk(3, dim=-1) 2414*da0073e9SAndroid Build Coastguard Worker query_lp, key_lp, value_lp = qkv_lp.chunk(3, dim=-1) 2415*da0073e9SAndroid Build Coastguard Worker 2416*da0073e9SAndroid Build Coastguard Worker query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2417*da0073e9SAndroid Build Coastguard Worker key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2418*da0073e9SAndroid Build Coastguard Worker value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2419*da0073e9SAndroid Build Coastguard Worker 2420*da0073e9SAndroid Build Coastguard Worker query_lp = query_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2421*da0073e9SAndroid Build Coastguard Worker key_lp = key_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2422*da0073e9SAndroid Build Coastguard Worker value_lp = value_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2423*da0073e9SAndroid Build Coastguard Worker 2424*da0073e9SAndroid Build Coastguard Worker if contiguous_inputs: 2425*da0073e9SAndroid Build Coastguard Worker query = query.contiguous() 2426*da0073e9SAndroid Build Coastguard Worker key = key.contiguous() 2427*da0073e9SAndroid Build Coastguard Worker value = value.contiguous() 2428*da0073e9SAndroid Build Coastguard Worker 2429*da0073e9SAndroid Build Coastguard Worker query_lp = query_lp.contiguous() 2430*da0073e9SAndroid Build Coastguard Worker key_lp = key_lp.contiguous() 2431*da0073e9SAndroid Build Coastguard Worker value_lp = value_lp.contiguous() 2432*da0073e9SAndroid Build Coastguard Worker 2433*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[SDPBackend.MATH]): 2434*da0073e9SAndroid Build Coastguard Worker out = torch.nn.functional.scaled_dot_product_attention(query, key, value, None, 0.0, is_causal) 2435*da0073e9SAndroid Build Coastguard Worker 2436*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]): 2437*da0073e9SAndroid Build Coastguard Worker out_lp = torch.nn.functional.scaled_dot_product_attention( 2438*da0073e9SAndroid Build Coastguard Worker query_lp, key_lp, value_lp, None, 0.0, is_causal) 2439*da0073e9SAndroid Build Coastguard Worker 2440*da0073e9SAndroid Build Coastguard Worker rand_upward = torch.rand_like(out) 2441*da0073e9SAndroid Build Coastguard Worker rand_upward_lp = rand_upward.to(torch.float32) 2442*da0073e9SAndroid Build Coastguard Worker 2443*da0073e9SAndroid Build Coastguard Worker out.backward(rand_upward) 2444*da0073e9SAndroid Build Coastguard Worker out_lp.backward(rand_upward_lp) 2445*da0073e9SAndroid Build Coastguard Worker 2446*da0073e9SAndroid Build Coastguard Worker # Cast up and compare 2447*da0073e9SAndroid Build Coastguard Worker self.assertEqual(qkv.grad, qkv_lp.grad.to(torch.float64), atol=1e-5, rtol=1e-5) 
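    # The *_grad_against_math tests above and below all follow the same pattern: run SDPA once
    # through a fused backend on low precision inputs and once through the math backend on high
    # precision copies, push the same upstream gradient through both, cast the low precision
    # gradients up, and compare. A minimal sketch of that pattern (hypothetical helper name and
    # fixed shapes; assumes a CUDA device with the efficient-attention backend available):
    def _sketch_fused_vs_math_grads(self, device="cuda"):
        shape = (4, 2, 8, 16)  # (batch, heads, seq_len, head_dim)
        q_hp = torch.rand(*shape, device=device, dtype=torch.float64, requires_grad=True)
        k_hp = torch.rand(*shape, device=device, dtype=torch.float64, requires_grad=True)
        v_hp = torch.rand(*shape, device=device, dtype=torch.float64, requires_grad=True)
        # Lower precision copies that go through the fused kernel
        q_lp, k_lp, v_lp = (t.detach().clone().to(torch.float32).requires_grad_() for t in (q_hp, k_hp, v_hp))

        with sdpa_kernel(backends=[SDPBackend.MATH]):
            out_hp = F.scaled_dot_product_attention(q_hp, k_hp, v_hp)
        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
            out_lp = F.scaled_dot_product_attention(q_lp, k_lp, v_lp)

        # Same upstream gradient for both graphs
        grad_out = torch.rand_like(out_hp)
        out_hp.backward(grad_out)
        out_lp.backward(grad_out.to(torch.float32))

        # Cast the lower precision gradients up before comparing against the math reference
        torch.testing.assert_close(q_lp.grad.to(torch.float64), q_hp.grad, atol=1e-5, rtol=1e-5)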
2448*da0073e9SAndroid Build Coastguard Worker 2449*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention was not built for this system") 2450*da0073e9SAndroid Build Coastguard Worker @parametrize("contiguous_inputs", [True, False]) 2451*da0073e9SAndroid Build Coastguard Worker @parametrize("is_causal", [True, False]) 2452*da0073e9SAndroid Build Coastguard Worker @parametrize("dtype", [torch.float16, torch.bfloat16]) 2453*da0073e9SAndroid Build Coastguard Worker def test_sdp_flash_attention_grad_against_math(self, device, contiguous_inputs: bool, is_causal: bool, dtype: torch.dtype): 2454*da0073e9SAndroid Build Coastguard Worker batch_size, seq_len, num_heads, head_dim = 4, 4, 2, 16 2455*da0073e9SAndroid Build Coastguard Worker make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, 2456*da0073e9SAndroid Build Coastguard Worker dtype=torch.float64, requires_grad=True, packed=True) 2457*da0073e9SAndroid Build Coastguard Worker 2458*da0073e9SAndroid Build Coastguard Worker qkv = make_tensor(SdpaShape(batch_size, num_heads, seq_len, head_dim)) 2459*da0073e9SAndroid Build Coastguard Worker qkv_lp = qkv.detach().clone().to(dtype).requires_grad_() 2460*da0073e9SAndroid Build Coastguard Worker 2461*da0073e9SAndroid Build Coastguard Worker query, key, value = qkv.chunk(3, dim=-1) 2462*da0073e9SAndroid Build Coastguard Worker query_lp, key_lp, value_lp = qkv_lp.chunk(3, dim=-1) 2463*da0073e9SAndroid Build Coastguard Worker 2464*da0073e9SAndroid Build Coastguard Worker query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2465*da0073e9SAndroid Build Coastguard Worker key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2466*da0073e9SAndroid Build Coastguard Worker value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2467*da0073e9SAndroid Build Coastguard Worker 2468*da0073e9SAndroid Build Coastguard Worker query_lp = query_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2469*da0073e9SAndroid Build Coastguard Worker key_lp = key_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2470*da0073e9SAndroid Build Coastguard Worker value_lp = value_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2471*da0073e9SAndroid Build Coastguard Worker 2472*da0073e9SAndroid Build Coastguard Worker if contiguous_inputs: 2473*da0073e9SAndroid Build Coastguard Worker query = query.contiguous() 2474*da0073e9SAndroid Build Coastguard Worker key = key.contiguous() 2475*da0073e9SAndroid Build Coastguard Worker value = value.contiguous() 2476*da0073e9SAndroid Build Coastguard Worker 2477*da0073e9SAndroid Build Coastguard Worker query_lp = query_lp.contiguous() 2478*da0073e9SAndroid Build Coastguard Worker key_lp = key_lp.contiguous() 2479*da0073e9SAndroid Build Coastguard Worker value_lp = value_lp.contiguous() 2480*da0073e9SAndroid Build Coastguard Worker 2481*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[SDPBackend.MATH]): 2482*da0073e9SAndroid Build Coastguard Worker out = torch.nn.functional.scaled_dot_product_attention(query, key, value, None, 0.0, is_causal) 2483*da0073e9SAndroid Build Coastguard Worker 2484*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): 2485*da0073e9SAndroid Build Coastguard Worker out_lp = torch.nn.functional.scaled_dot_product_attention( 2486*da0073e9SAndroid Build Coastguard Worker query_lp, key_lp, value_lp, None, 0.0, is_causal) 2487*da0073e9SAndroid Build Coastguard 
Worker 2488*da0073e9SAndroid Build Coastguard Worker rand_upward = torch.rand_like(out) 2489*da0073e9SAndroid Build Coastguard Worker rand_upward_lp = rand_upward.to(dtype) 2490*da0073e9SAndroid Build Coastguard Worker 2491*da0073e9SAndroid Build Coastguard Worker out.backward(rand_upward) 2492*da0073e9SAndroid Build Coastguard Worker out_lp.backward(rand_upward_lp) 2493*da0073e9SAndroid Build Coastguard Worker 2494*da0073e9SAndroid Build Coastguard Worker # Cast up and compare 2495*da0073e9SAndroid Build Coastguard Worker # Since we are doing the compute in low precision we have to bump the tolerance 2496*da0073e9SAndroid Build Coastguard Worker # Loosen the tolerance further for bfloat16 2497*da0073e9SAndroid Build Coastguard Worker atol = 7e-4 if dtype == torch.float16 else 7e-3 2498*da0073e9SAndroid Build Coastguard Worker rtol = 7e-4 if dtype == torch.float16 else 7e-3 2499*da0073e9SAndroid Build Coastguard Worker if TEST_WITH_ROCM: 2500*da0073e9SAndroid Build Coastguard Worker atol = 9e-4 if dtype == torch.float16 else 9e-3 2501*da0073e9SAndroid Build Coastguard Worker self.assertEqual(qkv.grad, qkv_lp.grad.to(torch.float64), atol=atol, rtol=rtol) 2502*da0073e9SAndroid Build Coastguard Worker 2503*da0073e9SAndroid Build Coastguard Worker @skipIfRocm # Missing nested and EFFICIENT_ATTENTION 2504*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Platform does not support fused SDPA") 2505*da0073e9SAndroid Build Coastguard Worker @parametrize("type", ["dense", "nested"]) 2506*da0073e9SAndroid Build Coastguard Worker def test_fused_sdp_choice(self, device, type: str): 2507*da0073e9SAndroid Build Coastguard Worker batch_size, seq_len, num_heads, head_dim = 2, 128, 8, 64 2508*da0073e9SAndroid Build Coastguard Worker shape = SdpaShape(batch_size, num_heads, seq_len, head_dim) 2509*da0073e9SAndroid Build Coastguard Worker make_tensor = partial(rand_sdpa_tensor, device=device, dtype=torch.float16, packed=True, requires_grad=True) 2510*da0073e9SAndroid Build Coastguard Worker 2511*da0073e9SAndroid Build Coastguard Worker qkv = make_tensor(shape, type=type) 2512*da0073e9SAndroid Build Coastguard Worker query, key, value = qkv.chunk(3, dim=-1) 2513*da0073e9SAndroid Build Coastguard Worker 2514*da0073e9SAndroid Build Coastguard Worker query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2515*da0073e9SAndroid Build Coastguard Worker value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2516*da0073e9SAndroid Build Coastguard Worker key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2517*da0073e9SAndroid Build Coastguard Worker 2518*da0073e9SAndroid Build Coastguard Worker if PLATFORM_SUPPORTS_FLASH_ATTENTION: 2519*da0073e9SAndroid Build Coastguard Worker assert torch._fused_sdp_choice(query, key, value) == SDPBackend.FLASH_ATTENTION.value 2520*da0073e9SAndroid Build Coastguard Worker else: 2521*da0073e9SAndroid Build Coastguard Worker assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION.value 2522*da0073e9SAndroid Build Coastguard Worker 2523*da0073e9SAndroid Build Coastguard Worker # Change dtype to float32 so that efficient attention gets chosen 2524*da0073e9SAndroid Build Coastguard Worker make_tensor = partial(rand_sdpa_tensor, device=device, dtype=torch.float32, packed=True) 2525*da0073e9SAndroid Build Coastguard Worker 2526*da0073e9SAndroid Build Coastguard Worker qkv = make_tensor(shape, type=type) 2527*da0073e9SAndroid Build Coastguard Worker query, key, value =
qkv.chunk(3, dim=-1) 2528*da0073e9SAndroid Build Coastguard Worker 2529*da0073e9SAndroid Build Coastguard Worker query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2530*da0073e9SAndroid Build Coastguard Worker value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2531*da0073e9SAndroid Build Coastguard Worker key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) 2532*da0073e9SAndroid Build Coastguard Worker 2533*da0073e9SAndroid Build Coastguard Worker assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION.value 2534*da0073e9SAndroid Build Coastguard Worker 2535*da0073e9SAndroid Build Coastguard Worker @skipIfRocm # Missing triton.float32 ("triton" prefix is to locate skipped UTs), and deterministic algo 2536*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA") 2537*da0073e9SAndroid Build Coastguard Worker @parametrize("warn_only", [True, False]) 2538*da0073e9SAndroid Build Coastguard Worker def test_sdp_choice_with_determinism(self, device, warn_only): 2539*da0073e9SAndroid Build Coastguard Worker batch_size, seq_len, num_heads, head_dim = 1, 64, 8, 64 2540*da0073e9SAndroid Build Coastguard Worker shape = SdpaShape(batch_size, num_heads, seq_len, head_dim) 2541*da0073e9SAndroid Build Coastguard Worker make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=torch.float32, packed=False) 2542*da0073e9SAndroid Build Coastguard Worker query, key, value = make_tensor(shape), make_tensor(shape), make_tensor(shape) 2543*da0073e9SAndroid Build Coastguard Worker 2544*da0073e9SAndroid Build Coastguard Worker with use_deterministic_algorithims(True, warn_only=warn_only): 2545*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]): 2546*da0073e9SAndroid Build Coastguard Worker assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION.value 2547*da0073e9SAndroid Build Coastguard Worker 2548*da0073e9SAndroid Build Coastguard Worker @skipIfRocm # Missing deterministic algo 2549*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system") 2550*da0073e9SAndroid Build Coastguard Worker @parametrize("fused_kernel", PLATFORM_SPECIFIC_SDPA) 2551*da0073e9SAndroid Build Coastguard Worker @parametrize("warn_only", [True, False]) 2552*da0073e9SAndroid Build Coastguard Worker def test_fused_backwards_throws_determinism_warning(self, device, warn_only, fused_kernel): 2553*da0073e9SAndroid Build Coastguard Worker batch_size, seq_len, num_heads, head_dim = 1, 64, 8, 64 2554*da0073e9SAndroid Build Coastguard Worker shape = SdpaShape(batch_size, num_heads, seq_len, head_dim) 2555*da0073e9SAndroid Build Coastguard Worker make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=torch.float16, packed=False, requires_grad=True) 2556*da0073e9SAndroid Build Coastguard Worker query, key, value = make_tensor(shape), make_tensor(shape), make_tensor(shape) 2557*da0073e9SAndroid Build Coastguard Worker 2558*da0073e9SAndroid Build Coastguard Worker kernel_name = "Memory Efficient attention" if fused_kernel == SDPBackend.EFFICIENT_ATTENTION else "Flash Attention" 2559*da0073e9SAndroid Build Coastguard Worker warning_context = ( 2560*da0073e9SAndroid Build Coastguard Worker self.assertWarnsRegex( 2561*da0073e9SAndroid Build Coastguard Worker UserWarning, 
2562*da0073e9SAndroid Build Coastguard Worker f"{kernel_name} defaults to a non-deterministic algorithm.", 2563*da0073e9SAndroid Build Coastguard Worker ) 2564*da0073e9SAndroid Build Coastguard Worker if warn_only 2565*da0073e9SAndroid Build Coastguard Worker else contextlib.nullcontext() 2566*da0073e9SAndroid Build Coastguard Worker ) 2567*da0073e9SAndroid Build Coastguard Worker with use_deterministic_algorithims(True, warn_only=warn_only): 2568*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[fused_kernel]): 2569*da0073e9SAndroid Build Coastguard Worker with warning_context: 2570*da0073e9SAndroid Build Coastguard Worker torch.nn.functional.scaled_dot_product_attention(query, key, value).sum().backward() 2571*da0073e9SAndroid Build Coastguard Worker 2572*da0073e9SAndroid Build Coastguard Worker @unittest.skip("The determinism/non-determinism checks in this test are flaky on CI/CD") 2573*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Platform does not support fused SDPA") 2574*da0073e9SAndroid Build Coastguard Worker def test_mem_eff_backwards_determinism(self, device): 2575*da0073e9SAndroid Build Coastguard Worker # Need big seq_len to ensure that num_splits > 1 2576*da0073e9SAndroid Build Coastguard Worker dtype = torch.float32 2577*da0073e9SAndroid Build Coastguard Worker batch_size, seq_len, n_heads, head_dim = 1, 1024, 8, 64 2578*da0073e9SAndroid Build Coastguard Worker query = torch.rand(batch_size, n_heads, seq_len, head_dim, 2579*da0073e9SAndroid Build Coastguard Worker device=device, dtype=dtype, requires_grad=True) 2580*da0073e9SAndroid Build Coastguard Worker key = torch.rand(batch_size, n_heads, seq_len, head_dim, device=device, 2581*da0073e9SAndroid Build Coastguard Worker dtype=dtype, requires_grad=True) 2582*da0073e9SAndroid Build Coastguard Worker value = torch.rand(batch_size, n_heads, seq_len, head_dim, 2583*da0073e9SAndroid Build Coastguard Worker device=device, dtype=dtype, requires_grad=True) 2584*da0073e9SAndroid Build Coastguard Worker 2585*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]): 2586*da0073e9SAndroid Build Coastguard Worker # Run once to establish baseline 2587*da0073e9SAndroid Build Coastguard Worker out = F.scaled_dot_product_attention(query, key, value) 2588*da0073e9SAndroid Build Coastguard Worker upward_grad = torch.rand_like(out) 2589*da0073e9SAndroid Build Coastguard Worker out.backward(upward_grad) 2590*da0073e9SAndroid Build Coastguard Worker intial_query_grad = query.grad 2591*da0073e9SAndroid Build Coastguard Worker 2592*da0073e9SAndroid Build Coastguard Worker # Re-run the op with the same upward grad and check that the backward is 2593*da0073e9SAndroid Build Coastguard Worker # not deterministic 2594*da0073e9SAndroid Build Coastguard Worker diff_anwser_once = False 2595*da0073e9SAndroid Build Coastguard Worker for _ in range(100): 2596*da0073e9SAndroid Build Coastguard Worker query.grad = None 2597*da0073e9SAndroid Build Coastguard Worker out = F.scaled_dot_product_attention(query, key, value) 2598*da0073e9SAndroid Build Coastguard Worker out.backward(upward_grad) 2599*da0073e9SAndroid Build Coastguard Worker if not torch.equal(intial_query_grad, query.grad): 2600*da0073e9SAndroid Build Coastguard Worker diff_anwser_once = True 2601*da0073e9SAndroid Build Coastguard Worker break 2602*da0073e9SAndroid Build Coastguard Worker self.assertTrue(diff_anwser_once) 2603*da0073e9SAndroid Build Coastguard Worker
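            # Second phase: repeat the same check with deterministic algorithms enforced.
            # use_deterministic_algorithims is presumably a helper context manager defined
            # earlier in this file that toggles torch.use_deterministic_algorithms for the
            # duration of the block; with it active, every re-run below must reproduce the
            # baseline gradient bit-for-bit.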
2604*da0073e9SAndroid Build Coastguard Worker with use_deterministic_algorithims(True, warn_only=False): 2605*da0073e9SAndroid Build Coastguard Worker query.grad = None 2606*da0073e9SAndroid Build Coastguard Worker out = F.scaled_dot_product_attention(query, key, value) 2607*da0073e9SAndroid Build Coastguard Worker upward_grad = torch.rand_like(out) 2608*da0073e9SAndroid Build Coastguard Worker out.backward(upward_grad) 2609*da0073e9SAndroid Build Coastguard Worker intial_query_grad = query.grad 2610*da0073e9SAndroid Build Coastguard Worker 2611*da0073e9SAndroid Build Coastguard Worker # Re-run the op with the same upward grad and check that the backward is 2612*da0073e9SAndroid Build Coastguard Worker # deterministic now that we have enforced it 2613*da0073e9SAndroid Build Coastguard Worker diff_anwser_once = False 2614*da0073e9SAndroid Build Coastguard Worker for _ in range(100): 2615*da0073e9SAndroid Build Coastguard Worker query.grad = None 2616*da0073e9SAndroid Build Coastguard Worker out = F.scaled_dot_product_attention(query, key, value) 2617*da0073e9SAndroid Build Coastguard Worker out.backward(upward_grad) 2618*da0073e9SAndroid Build Coastguard Worker if not torch.equal(intial_query_grad, query.grad): 2619*da0073e9SAndroid Build Coastguard Worker diff_anwser_once = True 2620*da0073e9SAndroid Build Coastguard Worker break 2621*da0073e9SAndroid Build Coastguard Worker self.assertFalse(diff_anwser_once) 2622*da0073e9SAndroid Build Coastguard Worker 2623*da0073e9SAndroid Build Coastguard Worker # verified passing successfully on H100 2624*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA") 2625*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_JETSON, "causing sigkill on Jetson") 2626*da0073e9SAndroid Build Coastguard Worker @parametrize("batch_size", [1, 8]) 2627*da0073e9SAndroid Build Coastguard Worker @parametrize("seq_len_q", [4, 8, 64, 128, 256, 512, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 2628*da0073e9SAndroid Build Coastguard Worker else [4, 8, 64, 128, 256, 512]) 2629*da0073e9SAndroid Build Coastguard Worker @parametrize("seq_len_k", [4, 8, 64, 128, 256, 512, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 2630*da0073e9SAndroid Build Coastguard Worker else [4, 8, 64, 128, 256, 512]) 2631*da0073e9SAndroid Build Coastguard Worker @parametrize("head_dim", [8, 16, 32, 64, 72, 96, 128] if MEM_EFF_CAPABILITY_MATCHES_SM80 2632*da0073e9SAndroid Build Coastguard Worker else [8, 16, 32, 64]) 2633*da0073e9SAndroid Build Coastguard Worker @parametrize("is_causal", [False, True]) 2634*da0073e9SAndroid Build Coastguard Worker @parametrize("dropout_p", [0.0, 0.22]) 2635*da0073e9SAndroid Build Coastguard Worker @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if MEM_EFF_CAPABILITY_MATCHES_SM80 2636*da0073e9SAndroid Build Coastguard Worker else [torch.float16, torch.float32]) 2637*da0073e9SAndroid Build Coastguard Worker @parametrize("scale", [None, "l1"]) 2638*da0073e9SAndroid Build Coastguard Worker def test_mem_efficient_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int, 2639*da0073e9SAndroid Build Coastguard Worker head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype, 2640*da0073e9SAndroid Build Coastguard Worker scale: str): 2641*da0073e9SAndroid Build Coastguard Worker def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, device=device): 2642*da0073e9SAndroid Build Coastguard Worker mask = 
torch.empty((batch_size, n_heads, q_len, kv_len), device=device, dtype=torch.float32) 2643*da0073e9SAndroid Build Coastguard Worker rand_uniform = torch._fill_mem_eff_dropout_mask_(mask, p, seed, offset) 2644*da0073e9SAndroid Build Coastguard Worker mask = (rand_uniform > p).to(torch.float32) 2645*da0073e9SAndroid Build Coastguard Worker return mask 2646*da0073e9SAndroid Build Coastguard Worker if max(seq_len_q, seq_len_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory < 40 * 2**30: 2647*da0073e9SAndroid Build Coastguard Worker unittest.skip("Reference implementation OOM") 2648*da0073e9SAndroid Build Coastguard Worker return 2649*da0073e9SAndroid Build Coastguard Worker if TEST_WITH_ROCM and seq_len_q * seq_len_k * head_dim * batch_size > 1024 * 1024 * 128: 2650*da0073e9SAndroid Build Coastguard Worker torch.cuda.empty_cache() # Prevent memory fragmentation 2651*da0073e9SAndroid Build Coastguard Worker seed = 42 2652*da0073e9SAndroid Build Coastguard Worker scale = scale if scale is None else (1 / head_dim) 2653*da0073e9SAndroid Build Coastguard Worker n_heads = 4 2654*da0073e9SAndroid Build Coastguard Worker query = torch.rand(batch_size, n_heads, seq_len_q, head_dim, 2655*da0073e9SAndroid Build Coastguard Worker device=device, dtype=dtype, requires_grad=True) 2656*da0073e9SAndroid Build Coastguard Worker key = torch.rand(batch_size, n_heads, seq_len_k, head_dim, device=device, 2657*da0073e9SAndroid Build Coastguard Worker dtype=dtype, requires_grad=True) 2658*da0073e9SAndroid Build Coastguard Worker value = torch.rand(batch_size, n_heads, seq_len_k, head_dim, 2659*da0073e9SAndroid Build Coastguard Worker device=device, dtype=dtype, requires_grad=True) 2660*da0073e9SAndroid Build Coastguard Worker 2661*da0073e9SAndroid Build Coastguard Worker # Run the math kernel on low precision references 2662*da0073e9SAndroid Build Coastguard Worker query_ref_lp, key_ref_lp, value_ref_lp = query_key_value_clones(query, key, value, dtype=dtype) 2663*da0073e9SAndroid Build Coastguard Worker 2664*da0073e9SAndroid Build Coastguard Worker higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32 2665*da0073e9SAndroid Build Coastguard Worker query_ref, key_ref, value_ref = query_key_value_clones(query, key, value, dtype=higher_precision_dtype) 2666*da0073e9SAndroid Build Coastguard Worker 2667*da0073e9SAndroid Build Coastguard Worker # Create real output 2668*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]): 2669*da0073e9SAndroid Build Coastguard Worker # Set the seed and run the kernel 2670*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(seed) 2671*da0073e9SAndroid Build Coastguard Worker out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale) 2672*da0073e9SAndroid Build Coastguard Worker 2673*da0073e9SAndroid Build Coastguard Worker if dropout_p == 0.0: 2674*da0073e9SAndroid Build Coastguard Worker with sdpa_kernel(backends=[SDPBackend.MATH]): 2675*da0073e9SAndroid Build Coastguard Worker # High Precision Math Reference 2676*da0073e9SAndroid Build Coastguard Worker out_ref = F.scaled_dot_product_attention(query_ref, key_ref, value_ref, 2677*da0073e9SAndroid Build Coastguard Worker dropout_p=dropout_p, is_causal=is_causal, scale=scale) 2678*da0073e9SAndroid Build Coastguard Worker # Low Precision Math Reference 2679*da0073e9SAndroid Build Coastguard Worker out_lp_ref = F.scaled_dot_product_attention(query_ref_lp, key_ref_lp, value_ref_lp, 
2680*da0073e9SAndroid Build Coastguard Worker dropout_p=dropout_p, is_causal=is_causal, scale=scale) 2681*da0073e9SAndroid Build Coastguard Worker else: 2682*da0073e9SAndroid Build Coastguard Worker if seq_len_q > 1024: 2683*da0073e9SAndroid Build Coastguard Worker self.skipTest("Will call _fill_mem_eff_dropout_mask with too many threads!") 2684*da0073e9SAndroid Build Coastguard Worker # Create the dropout_mask 2685*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(seed) 2686*da0073e9SAndroid Build Coastguard Worker dropout_mask = _get_mem_eff_drop_mask(batch_size, n_heads, seq_len_q, seq_len_k, dropout_p, seed, 0, device=device) 2687*da0073e9SAndroid Build Coastguard Worker # High Precision Math Reference 2688*da0073e9SAndroid Build Coastguard Worker out_ref = torch.ops.aten._scaled_dot_product_attention_math( 2689*da0073e9SAndroid Build Coastguard Worker query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal, scale=scale, dropout_mask=dropout_mask)[0] 2690*da0073e9SAndroid Build Coastguard Worker # Low Precision Math Reference 2691*da0073e9SAndroid Build Coastguard Worker out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math( 2692*da0073e9SAndroid Build Coastguard Worker query_ref_lp, key_ref_lp, value_ref_lp, dropout_p=dropout_p, is_causal=is_causal, scale=scale, 2693*da0073e9SAndroid Build Coastguard Worker dropout_mask=dropout_mask)[0] 2694*da0073e9SAndroid Build Coastguard Worker 2695*da0073e9SAndroid Build Coastguard Worker upstream_grad = torch.rand_like(out, requires_grad=False) 2696*da0073e9SAndroid Build Coastguard Worker 2697*da0073e9SAndroid Build Coastguard Worker out.backward(upstream_grad) 2698*da0073e9SAndroid Build Coastguard Worker out_ref.backward(upstream_grad.to(out_ref.dtype)) 2699*da0073e9SAndroid Build Coastguard Worker out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype)) 2700*da0073e9SAndroid Build Coastguard Worker 2701*da0073e9SAndroid Build Coastguard Worker # [Note] Fused Tolerances 2702*da0073e9SAndroid Build Coastguard Worker # Establish the numerical error between the "true" high precision math output 2703*da0073e9SAndroid Build Coastguard Worker # and the low precision math reference. We use this error for the atol 2704*da0073e9SAndroid Build Coastguard Worker # and the default rtol for the low precision type. 2705*da0073e9SAndroid Build Coastguard Worker # We then provide a per-gradient fudge factor to account 2706*da0073e9SAndroid Build Coastguard Worker # for the use of the fused kernel rather than the eager implementation.
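        # Roughly, for each compared pair this amounts to something like the following (a sketch
        # of the intent, not the exact body of get_tolerances, which is defined earlier in this
        # file):
        #   atol = fudge_factor * max(default_atol(dtype), (ref - lp_ref).abs().max())
        #   rtol = default_rtol(dtype)
        # so the fused kernel is only required to match the math kernel about as closely as the
        # low precision math reference matches the high precision one.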
        output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref)

        # Fudge Factor when dropout is enabled
        dropout_fudge_factor = 1.0 if dropout_p == 0.0 else 2.0

        query_fudge_factor = dropout_fudge_factor
        grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor)

        # TODO: Investigate why grad_k needs larger tolerances
        key_fudge_factor = 8 * dropout_fudge_factor
        grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad, key_fudge_factor)

        value_fudge_factor = 7 if not SM80OrLater and dtype == torch.float16 else 1.0
        if TEST_WITH_ROCM:
            value_fudge_factor = max(2.0, value_fudge_factor)
        grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad, value_fudge_factor)

        self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol)
        self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype),
                         atol=grad_q_ref_atol, rtol=grad_q_ref_rtol)
        self.assertEqual(key.grad, key_ref.grad.to(key.grad.dtype),
                         atol=grad_k_ref_atol, rtol=grad_k_ref_rtol)
        self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
                         atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)

    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA")
    @unittest.skipIf(IS_JETSON, "causing sigkill on Jetson")
    @parametrize("batch_size", [1, 8])
    @parametrize("seq_len_q", [4, 8, 64, 128, 256, 312, 512, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80
                 else [4, 8, 64, 128, 152, 256, 512])
    @parametrize("seq_len_k", [4, 8, 64, 65, 128, 256, 408, 512, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80
                 else [4, 8, 37, 64, 128, 256, 512])
    @parametrize("head_dim", [8, 16, 32, 64, 72, 96, 128] if MEM_EFF_CAPABILITY_MATCHES_SM80
                 else [8, 16, 32, 64])
    @parametrize("is_causal", [False])
    @parametrize("dropout_p", [0.0, 0.22])
    @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if MEM_EFF_CAPABILITY_MATCHES_SM80
                 else [torch.float16, torch.float32])
    @parametrize("scale", [None, "l1"])
    def test_mem_efficient_attention_attn_mask_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int,
                                                                 seq_len_k: int, head_dim: int, is_causal: bool,
                                                                 dropout_p: float, dtype: torch.dtype,
                                                                 scale: str):
        def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, device=device):
            mask = torch.empty((batch_size, n_heads, q_len, kv_len), device=device, dtype=torch.float32)
            rand_uniform = torch._fill_mem_eff_dropout_mask_(mask, p, seed, offset)
            mask = (rand_uniform > p).to(torch.float32)
            return mask
        if max(seq_len_q, seq_len_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory < 40 * 2**30:
            unittest.skip("Reference implementation OOM")
            return
        if TEST_WITH_ROCM and dtype == torch.float32:
            unittest.skip("Skip fp32 attn_mask gradients on ROCM, for now.")
            return
        if TEST_WITH_ROCM and seq_len_q * seq_len_k * head_dim * batch_size > 1024 * 1024 * 128:
            torch.cuda.empty_cache()  # Prevent memory fragmentation
        seed = 42
        scale = scale if scale is None else (1 / head_dim)
        n_heads = 4
        query = torch.rand(batch_size, n_heads, seq_len_q, head_dim,
                           device=device, dtype=dtype, requires_grad=True)
        key = torch.rand(batch_size, n_heads, seq_len_k, head_dim, device=device,
                         dtype=dtype, requires_grad=True)
        value = torch.rand(batch_size, n_heads, seq_len_k, head_dim,
                           device=device, dtype=dtype, requires_grad=True)

        attn_mask = torch.rand(seq_len_q, seq_len_k, device=device, dtype=dtype, requires_grad=True)

        # Run the math kernel on low precision references
        query_ref_lp, key_ref_lp, value_ref_lp = query_key_value_clones(query, key, value, dtype=dtype)
        attn_mask_ref_lp = attn_mask.detach().to(dtype).requires_grad_(True)

        higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
        query_ref, key_ref, value_ref = query_key_value_clones(query, key, value, dtype=higher_precision_dtype)
        attn_mask_ref = attn_mask.detach().to(higher_precision_dtype).requires_grad_(True)

        # Create real output
        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
            # Set the seed and run the kernel
            torch.manual_seed(seed)
            out = F.scaled_dot_product_attention(query, key, value, attn_mask, dropout_p=dropout_p,
                                                 is_causal=is_causal, scale=scale)

        if dropout_p == 0.0:
            with sdpa_kernel(backends=[SDPBackend.MATH]):
                # High Precision Math Reference
                out_ref = F.scaled_dot_product_attention(query_ref, key_ref, value_ref, attn_mask_ref,
                                                         dropout_p=dropout_p, is_causal=is_causal, scale=scale)
                # Low Precision Math Reference
                out_lp_ref = F.scaled_dot_product_attention(query_ref_lp, key_ref_lp, value_ref_lp, attn_mask_ref_lp,
                                                            dropout_p=dropout_p, is_causal=is_causal, scale=scale)
        else:
            if seq_len_q > 1024:
                self.skipTest("Will call _fill_mem_eff_dropout_mask with too many threads!")
            # Create the dropout_mask
            torch.manual_seed(seed)
            dropout_mask = _get_mem_eff_drop_mask(batch_size, n_heads, seq_len_q,
                                                  seq_len_k, dropout_p, seed, 0, device=device)
            # High Precision Math Reference
            out_ref = torch.ops.aten._scaled_dot_product_attention_math(
                query_ref, key_ref, value_ref, attn_mask_ref, dropout_p=dropout_p, is_causal=is_causal,
                scale=scale, dropout_mask=dropout_mask)[0]
            # Low Precision Math Reference
            out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
                query_ref_lp, key_ref_lp, value_ref_lp, attn_mask_ref_lp,
                dropout_p=dropout_p, is_causal=is_causal, scale=scale,
                dropout_mask=dropout_mask)[0]

        upstream_grad = torch.rand_like(out, requires_grad=False)

        out.backward(upstream_grad)
        out_ref.backward(upstream_grad.to(out_ref.dtype))
        out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype))

        # [Note] Fused Tolerances
        # Establish the numerical error between the "true" high precision math output
        # and the low precision math reference. We use this reference for the atol,
        # and we use the default rtol for the low precision type.
        # We then provide a per-gradient fudge factor to account
        # for the use of the fused kernel rather than the eager implementation.
        output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref)

        # Fudge Factor when dropout is enabled
        dropout_fudge_factor = 1.0 if dropout_p == 0.0 else 1.75
        mask_fudge_factor = 1.0 if attn_mask is None else 1.5

        query_fudge_factor = dropout_fudge_factor
        grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor)

        # TODO: Investigate why grad_k needs larger tolerances
        key_fudge_factor = 8 * dropout_fudge_factor * mask_fudge_factor
        grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad, key_fudge_factor)

        value_fudge_factor = 7 if not SM80OrLater and dtype == torch.float16 else 1.0
        if TEST_WITH_ROCM:
            value_fudge_factor = max(2.0, value_fudge_factor)
        grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad, value_fudge_factor)

        mask_fudge_factor = 12 if attn_mask.numel() > 512 else 22
        grad_attn_mask_atol, grad_attn_mask_rtol = get_tolerances(
            attn_mask_ref.grad, attn_mask_ref_lp.grad, mask_fudge_factor)

        self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol)
        self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype),
                         atol=grad_q_ref_atol, rtol=grad_q_ref_rtol)
        self.assertEqual(key.grad, key_ref.grad.to(key.grad.dtype),
                         atol=grad_k_ref_atol, rtol=grad_k_ref_rtol)
        self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
                         atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)

        self.assertEqual(attn_mask.grad, attn_mask_ref.grad.to(attn_mask.grad.dtype),
                         atol=grad_attn_mask_atol, rtol=grad_attn_mask_rtol)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
    @unittest.skipIf(IS_JETSON, "causing sigkill on Jetson")
    @parametrize("batch_size", [1, 8])
    @parametrize("seq_len_q", [4, 8, 64, 143, 256, 512, 1024, 2048])
    @parametrize("seq_len_k", [4, 8, 64, 128, 256, 587, 1024, 2048])
    @parametrize("head_dim", [8, 16, 21, 32, 64, 72, 96, 128, 160, 192, 203, 256])
    @parametrize("is_causal", [True, False])
    @parametrize("dropout_p", [0.0, 0.22, 0.48])
    @parametrize("dtype", [torch.float16, torch.bfloat16])
    @parametrize("scale", [None, "l1"])
    def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
                                               head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
                                               scale: str):
        if isSM8XDevice and head_dim in range(193, 256 + 1):
            self.skipTest("Flash attention on sm86, sm87, and sm89 for headdim > 192 currently disabled")
        if is_causal and seq_len_q != seq_len_k:
            self.skipTest("Flash V2 does not accept is_causal when seq_len_q != seq_len_k")
        if TEST_WITH_ROCM and seq_len_q >= 1024 and seq_len_k >= 1024 and batch_size > 1:
            torch.cuda.empty_cache()  # Prevent memory fragmentation

        scale = scale if scale is None else (1 / head_dim)
        n_heads = 4
        query = torch.rand(batch_size, n_heads, seq_len_q, head_dim,
                           device=device, dtype=dtype, requires_grad=True)
        key = torch.rand(batch_size, n_heads, seq_len_k, head_dim, device=device,
                         dtype=dtype, requires_grad=True)
        value = torch.rand(batch_size, n_heads, seq_len_k, head_dim,
                           device=device, dtype=dtype, requires_grad=True)

        # Run the math kernel on low precision references
        query_ref_lp, key_ref_lp, value_ref_lp = query_key_value_clones(query, key, value, dtype=dtype)
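        # The high precision clones (fp32 or fp64) serve as the ground truth against which the tolerances are derived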
        higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
        query_ref, key_ref, value_ref = query_key_value_clones(query, key, value, dtype=higher_precision_dtype)

        is_dropout = dropout_p > 0.0

        if not is_dropout:
            with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
                out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
            with sdpa_kernel(backends=[SDPBackend.MATH]):
                # High Precision Math Reference
                out_ref = F.scaled_dot_product_attention(
                    query_ref, key_ref, value_ref, is_causal=is_causal, scale=scale)
                # Low Precision Math Reference
                out_lp_ref = F.scaled_dot_product_attention(
                    query_ref_lp, key_ref_lp, value_ref_lp, is_causal=is_causal, scale=scale)
        else:
            # Problem: We pad sizes in the composite region of the top level SDPA, but we need the
            # debug mask when we have dropout. So we manually pad up here when testing dropout.
            q_padded, q_og_size = pad_last_dim(query, 8)
            k_padded, k_og_size = pad_last_dim(key, 8)
            v_padded, v_og_size = pad_last_dim(value, 8)
            # scale needs to be calculated on the og head_size
            if scale is None:
                scale = 1 / math.sqrt(q_og_size)
            output_tuple = torch.ops.aten._scaled_dot_product_flash_attention(
                q_padded, k_padded, v_padded, dropout_p=dropout_p, is_causal=is_causal, scale=scale, return_debug_mask=is_dropout)
            out = output_tuple[0]
            out = out[..., :v_og_size]
            # Build dropout_mask
            dbug_mask = output_tuple[-1]
            query_padding_mask = torch.ones(
                batch_size, seq_len_q, device=device, dtype=torch.bool)
            key_padding_mask = torch.ones(
                batch_size, seq_len_k, device=device, dtype=torch.bool)

            softmax_mask = self.convert_flash_attn_S_to_softmax(
                dbug_mask, seq_len_q, seq_len_k, query_padding_mask, key_padding_mask,
                causal=is_causal)[:, :, :seq_len_q, :seq_len_k]
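            # Positions dropped by the kernel come back as negative entries in the converted softmax,
            # so `>= 0` recovers the keep-mask.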
            dropout_mask = softmax_mask >= 0
            # attn_unnorm = softmax_mask.abs()
            # attn = self.normalize_flash_attn_S(attn_unnorm, q_padded,
            #                                    k_padded, v_padded, query_padding_mask,
            #                                    key_padding_mask, None, True, is_causal, scale=scale)

            # High Precision Math Reference
            out_ref = torch.ops.aten._scaled_dot_product_attention_math(
                query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal, scale=scale, dropout_mask=dropout_mask)[0]
            # Low Precision Math Reference
            out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
                query_ref_lp, key_ref_lp, value_ref_lp, dropout_p=dropout_p, is_causal=is_causal, scale=scale,
                dropout_mask=dropout_mask)[0]

        upstream_grad = torch.rand_like(out, requires_grad=False)

        # backward for flash attention on sm86, sm87, and sm89 for headdim >= 193 currently disabled
        if isSM8XDevice and head_dim in range(193, 256):
            self.assertRaises(RuntimeError, lambda: out.backward(upstream_grad))
            return
        out.backward(upstream_grad)
        out_ref.backward(upstream_grad.to(out_ref.dtype))
        out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype))

        # See [Note] Fused Tolerances above
        output_fudge_factor = 3 if head_dim % 8 != 0 or TEST_WITH_ROCM else 1
        output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref, output_fudge_factor)

        # TODO: Investigate why grad_q needs larger tolerances
        query_fudge_factor = 4
        grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor)

        key_fudge_factor = 2
        grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad, key_fudge_factor)

        value_fudge_factor = 2
        grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad, value_fudge_factor)

        self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol)
        self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype),
                         atol=grad_q_ref_atol, rtol=grad_q_ref_rtol)
        self.assertEqual(key.grad, key_ref.grad.to(key.grad.dtype),
                         atol=grad_k_ref_atol, rtol=grad_k_ref_rtol)
        self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
                         atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)

    @skipIfRocm  # FIXME: "capturing stream has unjoined work"
    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
    @parametrize("batch_size", [1, 8])
    @parametrize("seq_len_q", [256, 512, 1024])
    @parametrize("seq_len_k", [256, 512, 1024])
    @parametrize("head_dim", [32, 64])
    @parametrize("is_causal", [True, False])
    @parametrize("dropout_p", [0.0, 0.22])
    @parametrize("dtype", [torch.float16,])
    @parametrize("scale", [None, "l1"])
    @parametrize("fused_kernel", PLATFORM_SPECIFIC_SDPA)
    def test_fused_attention_vs_math_ref_grads_cudagraph(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
                                                         head_dim: int,
                                                         is_causal: bool,
                                                         dropout_p: float,
                                                         dtype: torch.dtype,
                                                         scale: str,
                                                         fused_kernel: SDPBackend):
        def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, dropout_p, seed, offset, device=device):
            mask = torch.empty((batch_size, n_heads, q_len, kv_len), device=device, dtype=torch.float32)
            rand_uniform = torch._fill_mem_eff_dropout_mask_(mask, dropout_p, seed, offset)
            mask = (rand_uniform > dropout_p).to(torch.float32)
            return mask

        def get_dropout_mask(output, fused_kernel, batch_size, n_heads, q_len, kv_len, dropout_p, device=device):
            if fused_kernel == SDPBackend.EFFICIENT_ATTENTION:
                output_seed, output_offset = output_tuple[2], output_tuple[3]
                output_seed = output_seed.item()
                output_offset = output_offset.item()
                return _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len,
                                              dropout_p, output_seed, output_offset, device=device)
            else:
                # Build dropout_mask
                dbug_mask = output_tuple[-1]
                query_padding_mask = torch.ones(
                    batch_size, seq_len_q, device=device, dtype=torch.bool)
                key_padding_mask = torch.ones(
                    batch_size, seq_len_k, device=device, dtype=torch.bool)

                softmax_mask = self.convert_flash_attn_S_to_softmax(
                    dbug_mask, seq_len_q, seq_len_k, query_padding_mask, key_padding_mask,
                    causal=is_causal)[:, :, :seq_len_q, :seq_len_k]
                dropout_mask = softmax_mask >= 0
                return dropout_mask

        if fused_kernel == SDPBackend.FLASH_ATTENTION and is_causal and seq_len_q != seq_len_k:
            self.skipTest("Flash V2 does not accept is_causal when seq_len_q != seq_len_k")

        seed = 42
        scale = scale if scale is None else (1 / head_dim)
        n_heads = 4
        query = torch.rand(batch_size, n_heads, seq_len_q, head_dim,
                           device=device, dtype=dtype, requires_grad=True)
        key = torch.rand(batch_size, n_heads, seq_len_k, head_dim, device=device,
                         dtype=dtype, requires_grad=True)
        value = torch.rand(batch_size, n_heads, seq_len_k, head_dim,
                           device=device, dtype=dtype, requires_grad=True)

        fused_op = (torch.ops.aten._scaled_dot_product_efficient_attention
                    if fused_kernel == SDPBackend.EFFICIENT_ATTENTION else torch.ops.aten._scaled_dot_product_flash_attention)
        # Run the math kernel on low precision references
        query_ref_lp, key_ref_lp, value_ref_lp = query_key_value_clones(query, key, value, dtype=dtype)

        higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
        query_ref, key_ref, value_ref = query_key_value_clones(query, key, value, dtype=higher_precision_dtype)

        # warmup
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        # Set the global seed before capture
        torch.manual_seed(seed)
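        # The two fused ops take different extra kwargs: the mem-efficient op expects attn_bias and
        # compute_log_sumexp, while the flash op can return a debug mask when dropout is active.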
        kwargs = {"dropout_p": dropout_p, "is_causal": is_causal, "scale": scale}
        if fused_kernel == SDPBackend.EFFICIENT_ATTENTION:
            kwargs["compute_log_sumexp"] = True
            kwargs["attn_bias"] = None
        if fused_kernel == SDPBackend.FLASH_ATTENTION:
            kwargs['return_debug_mask'] = dropout_p > 0.0
        with torch.cuda.stream(s):
            # Create real output
            output_tuple = fused_op(query, key, value, **kwargs)

        torch.cuda.current_stream().wait_stream(s)
        out = output_tuple[0]
        upstream_grad = torch.rand_like(out, requires_grad=False)
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            out.backward(upstream_grad)
        for x in (query, key, value):
            x.grad = None
        g = torch.cuda.CUDAGraph()
        # Capture the fused op in a CUDA graph
        with torch.cuda.graph(g):
            tmp = torch.rand_like(query, device=query.device)  # test non-zero intragraph offset
            # Create real output
            output_tuple = fused_op(query, key, value, **kwargs)
            assert all(not isinstance(o, torch.Tensor) or o.is_cuda for o in output_tuple)
        g.replay()
        out_first = output_tuple[0].clone()
        g.replay()
        out = output_tuple[0]
        if dropout_p == 0.0:
            self.assertEqual(out_first, out, atol=0, rtol=0)
        else:
            # replays produce different results
            self.assertNotEqual(out_first, out)

        with sdpa_kernel(backends=[SDPBackend.MATH]):
            if dropout_p == 0.0:
                # High Precision Math Reference
                out_ref = F.scaled_dot_product_attention(query_ref, key_ref, value_ref,
                                                         dropout_p=dropout_p, is_causal=is_causal, scale=scale)
                # Low Precision Math Reference
                out_lp_ref = F.scaled_dot_product_attention(query_ref_lp, key_ref_lp, value_ref_lp,
                                                            dropout_p=dropout_p, is_causal=is_causal, scale=scale)
            else:
                # Create the dropout_mask
                dropout_mask = get_dropout_mask(output_tuple, fused_kernel, batch_size,
                                                n_heads, seq_len_q, seq_len_k, dropout_p, device)
                # High Precision Math Reference
                out_ref = torch.ops.aten._scaled_dot_product_attention_math(
                    query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal,
                    scale=scale, dropout_mask=dropout_mask)[0]
                # Low Precision Math Reference
                out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
                    query_ref_lp, key_ref_lp, value_ref_lp, dropout_p=dropout_p, is_causal=is_causal, scale=scale,
                    dropout_mask=dropout_mask)[0]

        g1 = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g1):
            out.backward(upstream_grad)
        g1.replay()
        out_ref.backward(upstream_grad.to(out_ref.dtype))
        out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype))

        # [Note] Fused Tolerances
        # Establish the numerical error between the "true" high precision math output
        # and the low precision math reference. We use this reference for the atol,
        # and we use the default rtol for the low precision type.
        # We then provide a per-gradient fudge factor to account
        # for the use of the fused kernel rather than the eager implementation.
        output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref)

        # Fudge Factor when dropout is enabled
        dropout_fudge_factor = 1.0 if dropout_p == 0.0 else 1.5

        query_fudge_factor = dropout_fudge_factor
        grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor)

        # TODO: Investigate why grad_k needs larger tolerances
        key_fudge_factor = 8 * dropout_fudge_factor
        grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad, key_fudge_factor)

        value_fudge_factor = 7 if not SM80OrLater and dtype == torch.float16 else 1.0
        grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad, value_fudge_factor)

        self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol)
        self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype),
                         atol=grad_q_ref_atol, rtol=grad_q_ref_rtol)
        self.assertEqual(key.grad, key_ref.grad.to(key.grad.dtype),
                         atol=grad_k_ref_atol, rtol=grad_k_ref_rtol)
        self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
                         atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)

    @skipIfRocm  # Nested Tensor
    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
    @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
                 PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION])
    def test_fused_kernels_seq_len_1_inputs(self, device, fused_kernel):
        rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float16)
        batch, num_heads, head_dim = 32, 16, 64
        seq_lens = torch.randint(low=1, high=32, size=(batch,))
        # make sure some seq_lens are 1
        num_ones = 10
        indices = torch.randint(low=0, high=batch, size=(num_ones,))
        seq_lens.scatter_(0, indices, 1)

        shape = SdpaShape(batch, num_heads, seq_lens.tolist(), head_dim)
        query = rand_nested_tensor(shape)
        key = rand_nested_tensor(shape)
        value = rand_nested_tensor(shape)

        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        with sdpa_kernel(backends=[fused_kernel]):
            actual = torch.nn.functional.scaled_dot_product_attention(
                query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
        with sdpa_kernel(backends=[SDPBackend.MATH]):
            math_ref = torch.nn.functional.scaled_dot_product_attention(
                query.contiguous().to(torch.float32),
                key.contiguous().to(torch.float32),
                value.contiguous().to(torch.float32),
                attn_mask=None, dropout_p=0.0, is_causal=False)

        self.assertEqual(actual.contiguous(), math_ref.contiguous().to(torch.float16), atol=1e-3, rtol=1e-2)

    @skipIfRocm  # Nested tensor
    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
    @parametrize("kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
                 PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION])
    @parametrize("expand_q_batch", [True, False])
    @parametrize("expand_k_batch", [True, False])
    @parametrize("expand_v_batch", [True, False])
    @parametrize("expand_q_num_heads", [True, False])
    @parametrize("expand_k_num_heads", [True, False])
    @parametrize("expand_v_num_heads", [True, False])
    def test_fused_kernels_nested_broadcasting(
        self,
        device,
        kernel,
        expand_q_batch,
        expand_k_batch,
        expand_v_batch,
        expand_q_num_heads,
        expand_k_num_heads,
        expand_v_num_heads,
    ):
        is_efficient = kernel == SDPBackend.EFFICIENT_ATTENTION
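        # The mem-efficient kernel is exercised in fp32 with a smaller value head_dim;
        # flash attention runs in fp16 with matching head dims.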
        dtype = torch.float32 if is_efficient else torch.float16
        rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=dtype)
        batch, num_heads, head_dim = 32, 8, 64
        head_dim_v = 32 if is_efficient else head_dim
        seq_lens_q = (torch.randint(low=1, high=5, size=(1,)).item()
                      if expand_q_batch
                      else torch.randint(low=1, high=32, size=(batch,)).tolist())
        seq_lens_kv = (torch.randint(low=1, high=5, size=(1,)).item()
                       if (expand_k_batch or expand_v_batch)
                       else torch.randint(low=1, high=32, size=(batch,)).tolist())

        batch_q = 1 if expand_q_batch else batch
        batch_k = 1 if expand_k_batch else batch
        batch_v = 1 if expand_v_batch else batch

        # handle case where all batch_sizes are 1
        batch = max(batch_q, batch_k, batch_v)

        num_heads_q = 1 if expand_q_num_heads else num_heads
        num_heads_k = 1 if expand_k_num_heads else num_heads
        num_heads_v = 1 if expand_v_num_heads else num_heads

        # handle case where all num_heads are 1
        num_heads = max(num_heads_q, num_heads_k, num_heads_v)

        q_shape = SdpaShape(batch_q, num_heads_q, seq_lens_q, head_dim)
        k_shape = SdpaShape(batch_k, num_heads_k, seq_lens_kv, head_dim)
        v_shape = SdpaShape(batch_v, num_heads_v, seq_lens_kv, head_dim_v)

        query = rand_nested_tensor(q_shape)
        key = rand_nested_tensor(k_shape)
        value = rand_nested_tensor(v_shape)

        def _broadcast(t, batch_broadcasted, num_heads_broadcasted):
            if batch_broadcasted and num_heads_broadcasted:
                # (1, seq_len, 1, head_dim) -> (batch, seq_len, num_heads, head_dim)
                result = torch.nested.nested_tensor(
                    [t[0].expand(-1, num_heads, t.size(-1)) for _ in range(batch)], dtype=torch.float32)
            elif batch_broadcasted:
                # (1, seq_len, num_heads, head_dim) -> (batch, seq_len, num_heads, head_dim)
                result = torch.nested.nested_tensor([t[0] for _ in range(batch)], dtype=torch.float32)
            elif num_heads_broadcasted:
                # (batch, seq_len, 1, head_dim) -> (batch, seq_len, num_heads, head_dim)
                result = torch.nested.nested_tensor([x.expand(-1, num_heads, t.size(-1))
                                                     for x in t.unbind()], dtype=torch.float32)
            else:
                result = t.to(torch.float32)
            return result

        query_expanded = _broadcast(query, expand_q_batch, expand_q_num_heads).transpose(1, 2)
        key_expanded = _broadcast(key, expand_k_batch, expand_k_num_heads).transpose(1, 2)
        value_expanded = _broadcast(value, expand_v_batch, expand_v_num_heads).transpose(1, 2)

        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        with sdpa_kernel(backends=[kernel]):
            actual = torch.nn.functional.scaled_dot_product_attention(
                query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
        with sdpa_kernel(backends=[SDPBackend.MATH]):
            math_ref = torch.nn.functional.scaled_dot_product_attention(
                query_expanded.contiguous(), key_expanded.contiguous(), value_expanded.contiguous(),
                attn_mask=None, dropout_p=0.0, is_causal=False)

        self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)

    @skipIfRocm  # Nested tensor
    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
    def test_fused_kernels_nested_broadcasting_query_dense(self, device):
        rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float32)
        batch, num_heads, head_dim, head_dim_v = 32, 16, 64, 96
        seq_lens = torch.randint(low=1, high=32, size=(batch,)).tolist()
        q_shape = (1, 1, num_heads, head_dim)
        k_shape = SdpaShape(batch, num_heads, seq_lens, head_dim)
        v_shape = SdpaShape(batch, 1, seq_lens, head_dim_v)

        # create a dense query
        query = torch.randn(q_shape, device=device, dtype=torch.float32)
        key = rand_nested_tensor(k_shape)
        value = rand_nested_tensor(v_shape)

        # (1, 1, num_heads, head_dim) -> (batch, 1, num_heads, head_dim)
        query_expanded = torch.nested.nested_tensor([query.squeeze(0) for _ in range(batch)]).transpose(1, 2)
        # (batch, seq_lens, 1, head_dim) -> (batch, seq_lens, num_heads, head_dim)
        value_expanded = torch.nested.nested_tensor(
            [t.expand(-1, num_heads, head_dim_v) for t in value.unbind()]).transpose(1, 2)

        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
            actual = torch.nn.functional.scaled_dot_product_attention(
                query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
        with sdpa_kernel(backends=[SDPBackend.MATH]):
            math_ref = torch.nn.functional.scaled_dot_product_attention(
                query_expanded.contiguous(), key.contiguous(), value_expanded.contiguous(),
                attn_mask=None, dropout_p=0.0, is_causal=False)

        self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=1e-3, rtol=1e-2)

    @onlyCUDA
    @skipIfRocm  # Nested tensor
    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
    @parametrize("batch_size", [8, 32])
    @parametrize("max_seq_len_q", [32, 256])
    @parametrize("max_seq_len_kv", [32, 256])
    @parametrize("head_dim", [8, 64])
    @parametrize("dropout_p", [0.0, 0.1])
    @parametrize("dtype", [torch.float16])
    @parametrize("scale", [None, "l1"])
    @parametrize("is_causal", [True, False])
    def test_flash_attention_vs_math_ref_grads_nestedtensor(self, device, batch_size: int, max_seq_len_q: int, max_seq_len_kv: int,
                                                            head_dim: int, dropout_p: float, dtype: torch.dtype,
                                                            scale: str, is_causal: bool):
        if is_causal:
            # TODO we should support this
            self.assertRaisesRegex(RuntimeError, "Nested tensors for query / key are not supported when is_causal=True")
            return
        scale = scale if scale is None else (1 / head_dim)
        n_heads = 4
        seq_lens_q = torch.randint(low=1, high=max_seq_len_q, size=(batch_size,))
        # Set one entry to max length
        seq_lens_q[torch.randint(0, batch_size, size=(1,))] = max_seq_len_q
        seq_lens_kv = torch.randint(low=1, high=max_seq_len_kv, size=(batch_size,))
        seq_lens_kv[torch.randint(0, batch_size, size=(1,))] = max_seq_len_kv

        def rand_nt(sequence_list, num_heads, head_dim):
            tensors = [torch.rand((num_heads, seq_len, head_dim)) for seq_len in sequence_list]
            return torch.nested.nested_tensor(tensors, requires_grad=True, device=device, dtype=dtype)

        query = rand_nt(seq_lens_q, n_heads, head_dim)
        key = rand_nt(seq_lens_kv, n_heads, head_dim)
        value = rand_nt(seq_lens_kv, n_heads, head_dim)

        # Run the math kernel on low precision references
        query_ref_lp = query.clone().detach().requires_grad_(True)
        key_ref_lp = key.clone().detach().requires_grad_(True)
        value_ref_lp = value.clone().detach().requires_grad_(True)

        query_ref = query.clone().detach().to(torch.float32).requires_grad_(True)
        key_ref = key.clone().detach().to(torch.float32).requires_grad_(True)
        value_ref = value.clone().detach().to(torch.float32).requires_grad_(True)

        is_dropout = dropout_p > 0.0

        if not is_dropout:
            with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
                out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
            with sdpa_kernel(backends=[SDPBackend.MATH]):
                # High Precision Math Reference
                out_ref = F.scaled_dot_product_attention(
                    query_ref, key_ref, value_ref, is_causal=is_causal, scale=scale)
                # Low Precision Math Reference
                out_lp_ref = F.scaled_dot_product_attention(
                    query_ref_lp, key_ref_lp, value_ref_lp, is_causal=is_causal, scale=scale)
        else:
            # Create real output
            output_tuple = torch.ops.aten._scaled_dot_product_flash_attention(
                query, key, value, dropout_p=dropout_p, is_causal=is_causal,
                scale=scale, return_debug_mask=is_dropout)
            out = output_tuple[0]
            dbug_mask = output_tuple[-1]

            query_padding_mask = torch.arange(max_seq_len_q).unsqueeze(0).expand(
                batch_size, max_seq_len_q
            ) < seq_lens_q.unsqueeze(-1)
            query_padding_mask = query_padding_mask.to("cuda")

            key_padding_mask = torch.arange(max_seq_len_kv).unsqueeze(0).expand(
                batch_size, max_seq_len_kv
            ) < seq_lens_kv.unsqueeze(-1)
            key_padding_mask = key_padding_mask.to("cuda")

            softmax_mask = self.convert_flash_attn_S_to_softmax(
                dbug_mask, max_seq_len_q, max_seq_len_kv, query_padding_mask, key_padding_mask, causal=is_causal)
            dropout_mask = softmax_mask >= 0
            nt_stack = []
            for tensor_component in range(batch_size):
                batch_stack = []
                for head in range(n_heads):
                    batch_stack.append(dropout_mask[tensor_component, head,
                                                    0:seq_lens_q[tensor_component],
                                                    0:seq_lens_kv[tensor_component]].unsqueeze(0))
                nt_stack.append(torch.cat(batch_stack))
            nested_dropout_mask = torch.nested.nested_tensor(nt_stack)
            # High Precision Math Reference
            out_ref = torch.ops.aten._scaled_dot_product_attention_math(
                query_ref, key_ref, value_ref, dropout_p=dropout_p,
                is_causal=is_causal, scale=scale, dropout_mask=nested_dropout_mask)[0]
            # Low Precision Math Reference
            out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
                query_ref_lp, key_ref_lp, value_ref_lp, dropout_p=dropout_p, is_causal=is_causal, scale=scale,
                dropout_mask=nested_dropout_mask)[0]
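        # Clone the fused output to use as the upstream gradient so it carries the same nested sizes
        # as out and the references.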
        upstream_grad = out.detach().clone().contiguous()

        out.backward(upstream_grad)
        out_ref.backward(upstream_grad.to(out_ref.dtype))
        out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype))

        # See [Note] Fused Tolerances above
        output_ref_atol, output_ref_rtol = calculate_nt_tolerances(out_ref, out_lp_ref, out.dtype)
        grad_q_ref_atol, grad_q_ref_rtol = calculate_nt_tolerances(query_ref.grad, query_ref_lp.grad,
                                                                   query.grad.dtype, fudge_factor=4)
        grad_k_ref_atol, grad_k_ref_rtol = calculate_nt_tolerances(key_ref.grad, key_ref_lp.grad, key.grad.dtype)
        grad_v_ref_atol, grad_v_ref_rtol = calculate_nt_tolerances(value_ref.grad, value_ref_lp.grad, value.grad.dtype)

        self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol)
        self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype),
                         atol=grad_q_ref_atol, rtol=grad_q_ref_rtol)
        self.assertEqual(key.grad.contiguous(), key_ref.grad.contiguous().to(key.grad.dtype),
                         atol=grad_k_ref_atol, rtol=grad_k_ref_rtol)
        self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
                         atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)


class TestAttnBias(NNTestCase):

    def run_test(
        self,
        device,
        make_q,
        make_kv,
        attn_bias=None,
        forw_tolerances: Optional[Tolerances] = None,
        grad_tolerances: Optional[Tolerances] = None,
        backend=None,
    ):
        if backend is not None:
            torch._dynamo.reset()

        query, key, value = make_q(), make_kv(), make_kv()
        query_prototype, key_prototype, value_prototype = query_key_value_clones(query, key, value)

        realized = attn_bias._materialize(device) if attn_bias is not None else None
        pytorch_output = scaled_dot_product_attention(
            query, key, value, attn_mask=realized, dropout_p=0.0, is_causal=False
        )

        sdpa_op = (
            torch.compile(scaled_dot_product_attention, backend=backend)
            if backend is not None
            else scaled_dot_product_attention
        )
        sdpa_output = sdpa_op(
            query_prototype,
            key_prototype,
            value_prototype,
            attn_mask=attn_bias,
            dropout_p=0.0,
            is_causal=False,
            scale=None,
        )

        dOut = torch.randn_like(pytorch_output)
        pytorch_output.backward(dOut)
        sdpa_output.backward(dOut)

        # Use default assert_close tolerances for dtypes
        if forw_tolerances is None:
            forw_tolerances = Tolerances(atol=None, rtol=None)
        if grad_tolerances is None:
            grad_tolerances = Tolerances(atol=None, rtol=None)

        torch.testing.assert_close(pytorch_output, sdpa_output, rtol=forw_tolerances.rtol, atol=forw_tolerances.atol)
        torch.testing.assert_close(query.grad, query_prototype.grad, rtol=grad_tolerances.rtol, atol=grad_tolerances.atol)
        torch.testing.assert_close(key.grad, key_prototype.grad, rtol=grad_tolerances.rtol, atol=grad_tolerances.atol)
        torch.testing.assert_close(value.grad, value_prototype.grad, rtol=grad_tolerances.rtol, atol=grad_tolerances.atol)

    @skipIfRocm  # No support for the second variant for now
    @parametrize("causal_variant", [CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT])
    @parametrize(
        "shape",
        [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)],
    )
    def test_causal_variants(self, device, causal_variant: CausalVariant, shape: List[Tuple[int]]):
        make_tensor = partial(
            torch.rand, device=device, dtype=torch.float16, requires_grad=True
        )

        bsz, num_heads, seq_len_q, seq_len_kv, head_dim = shape
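        # `shape` unpacks as (batch, num_heads, seq_len_q, seq_len_kv, head_dim);
        # the SdpaShape helpers below regroup it into the (batch, num_heads,
        # seq_len, head_dim) layout that the query/key/value factories consume.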
        make_q_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_q, head_dim))
        make_kv_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_kv, head_dim))
        if causal_variant == CausalVariant.LOWER_RIGHT and seq_len_q > seq_len_kv:
            self.skipTest(
                "Lower right causal mask will produce NaNs in the output when seq_len_q > seq_len_kv!"
            )

        forw_tol = Tolerances(1e-3, 1e-3)
        grad_tol = Tolerances(5e-3, 5e-3)

        if causal_variant == CausalVariant.UPPER_LEFT:
            attn_bias = causal_upper_left(seq_len_q, seq_len_kv)
        else:
            attn_bias = causal_lower_right(seq_len_q, seq_len_kv)

        self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=None)

    @skipIfRocm  # CausalVariant
    @parametrize("causal_variant", [CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT])
    @parametrize(
        "shape",
        [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)],
    )
    @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on windows")
    @skipIfTorchDynamo("This function already calls torch.compile.")
    def test_causal_variants_compile(self, device, causal_variant: CausalVariant, shape: List[Tuple[int]]):
        cnts = CompileCounterWithBackend("aot_eager")
        make_tensor = partial(
            torch.rand, device=device, dtype=torch.float16, requires_grad=True
        )

        bsz, num_heads, seq_len_q, seq_len_kv, head_dim = shape
        make_q_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_q, head_dim))
        make_kv_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_kv, head_dim))
        if causal_variant == CausalVariant.LOWER_RIGHT and seq_len_q > seq_len_kv:
            self.skipTest(
                "Lower right causal mask will produce NaNs in the output when seq_len_q > seq_len_kv!"
            )
        forw_tol = Tolerances(1e-3, 1e-3)
        grad_tol = Tolerances(5e-3, 5e-3)

        if causal_variant == CausalVariant.UPPER_LEFT:
            attn_bias = causal_upper_left(seq_len_q, seq_len_kv)
        else:
            attn_bias = causal_lower_right(seq_len_q, seq_len_kv)

        self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=cnts)
        self.assertEqual(cnts.frame_count, 1, "Compiled graph should have 1 frame!")

    @parametrize("shape", [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)])
    def test_is_causal_equals_upper_left(self, device, shape: List[Tuple[int]]):
        make_tensor = partial(
            torch.rand, device=device, dtype=torch.float16, requires_grad=True
        )

        bsz, num_heads, seq_len_q, seq_len_kv, head_dim = shape
        make_q_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_q, head_dim))
        make_kv_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_kv, head_dim))

        forw_tol = Tolerances(1e-3, 1e-3)
        grad_tol = Tolerances(5e-3, 5e-3)

        query = make_q_tensor()
        key = make_kv_tensor()
        value = make_kv_tensor()
        attn_bias = causal_upper_left(seq_len_q, seq_len_kv)

        out_attn_bias = scaled_dot_product_attention(query, key, value, attn_mask=attn_bias, dropout_p=0.0)
        out_is_causal = scaled_dot_product_attention(query, key, value, is_causal=True, dropout_p=0.0)
        torch.testing.assert_close(out_attn_bias, out_is_causal, rtol=forw_tol.rtol, atol=forw_tol.atol)

    def test_is_causal_and_mask_fails(self, device):
        make_tensor = partial(
            torch.rand, device=device, dtype=torch.float16, requires_grad=True
        )
        make_q_tensor = partial(make_tensor, SdpaShape(16, 16, 128, 16))
        make_kv_tensor = partial(make_tensor, SdpaShape(16, 16, 128, 16))
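
        # Supplying an explicit CausalBias together with is_causal=True is
        # ambiguous, so SDPA is expected to reject the combination; the context
        # manager below asserts that the ValueError is raised.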
        query = make_q_tensor()
        key = make_kv_tensor()
        value = make_kv_tensor()
        attn_bias = causal_upper_left(128, 128)

        with self.assertRaisesRegex(ValueError, "CausalBias should not be used with causal=True"):
            scaled_dot_product_attention(query, key, value, attn_mask=attn_bias, is_causal=True, dropout_p=0.0)


if NOTEST_CPU:
    device_types = ("cuda", )
else:
    device_types = ("cpu", "cuda")

instantiate_device_type_tests(TestTransformers, globals(), only_for=device_types)
instantiate_device_type_tests(TestSDPAFailureModes, globals(), only_for=device_types)
instantiate_device_type_tests(TestSDPA, globals(), only_for=device_types)
instantiate_device_type_tests(TestSDPACudaOnly, globals(), only_for=("cuda",))
instantiate_device_type_tests(TestAttnBias, globals(), only_for=device_types)

if __name__ == '__main__':
    run_tests()
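

# ---------------------------------------------------------------------------
# Illustrative sketch only (never invoked by the tests above): a minimal look
# at how the two CausalVariant biases exercised in TestAttnBias differ once
# realized as dense boolean masks. It reuses this module's imports and the
# private CausalBias._materialize helper in the same way run_test() does;
# `_demo_causal_variants` is a name introduced just for this sketch and is not
# part of any PyTorch API.
def _demo_causal_variants(seq_len_q: int = 3, seq_len_kv: int = 5):
    # UPPER_LEFT anchors the causal diagonal at the top-left corner of the
    # (seq_len_q, seq_len_kv) score matrix, i.e. query i may attend to keys
    # 0..i:
    #   [[T, F, F, F, F],
    #    [T, T, F, F, F],
    #    [T, T, T, F, F]]
    upper_left = causal_upper_left(seq_len_q, seq_len_kv)._materialize("cpu")
    # LOWER_RIGHT anchors the diagonal at the bottom-right corner, so query i
    # may additionally attend to the (seq_len_kv - seq_len_q) extra leading keys:
    #   [[T, T, T, F, F],
    #    [T, T, T, T, F],
    #    [T, T, T, T, T]]
    lower_right = causal_lower_right(seq_len_q, seq_len_kv)._materialize("cpu")
    return upper_left, lower_right
# The is_causal=True fast path follows the UPPER_LEFT convention even for
# rectangular shapes, which is exactly what test_is_causal_equals_upper_left
# above asserts; when seq_len_q == seq_len_kv the two variants coincide.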