1# Owner(s): ["module: inductor"] 2import functools 3import itertools 4import math 5 6import torch 7import torch._inductor.config 8import torch.utils.checkpoint 9from torch._dynamo.debug_utils import aot_graph_input_parser 10from torch._dynamo.utils import counters 11from torch._inductor.test_case import run_tests, TestCase 12from torch._inductor.utils import run_and_get_code 13from torch.testing._internal.common_cuda import ( 14 PLATFORM_SUPPORTS_FUSED_ATTENTION, 15 SM80OrLater, 16) 17from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm 18from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA 19 20 21def checkpoint_wrapper(fn): 22 def inner(*args): 23 return torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=True) 24 25 return inner 26 27 28class TestSDPAPatternRewriterTemplate(TestCase): 29 use_static_shapes = True 30 31 def _clone_inputs(self, inputs): 32 def clone(x): 33 if not isinstance(x, torch.Tensor): 34 return x 35 return x.clone() 36 37 return [clone(x) for x in inputs] 38 39 def _check_common( 40 self, 41 dot_prod_attention, 42 args1=None, 43 contains=True, 44 atol=1e-5, 45 has_fuse_pattern=True, 46 has_dropout=False, 47 check_train=True, 48 override_check_equal=False, 49 dtype=torch.float, 50 rtol=1.3e-6, 51 ): 52 if args1 is None: 53 tensor_shape = (4, 2, 16, 32) 54 args1 = [ 55 torch.randn(tensor_shape, device=self.device, dtype=dtype), 56 torch.randn(tensor_shape, device=self.device, dtype=dtype), 57 torch.randn(tensor_shape, device=self.device, dtype=dtype), 58 ] 59 else: 60 args1 = list(args1) 61 args2 = self._clone_inputs(args1) 62 63 for training in [False, True] if check_train else [False]: 64 for x in itertools.chain(args1[:], args2[:]): 65 if isinstance(x, torch.Tensor) and x.is_floating_point(): 66 x.requires_grad = training 67 68 if not self.use_static_shapes: 69 torch._dynamo.mark_dynamic(args2[0], 0) 70 torch._dynamo.mark_dynamic(args2[1], 0) 71 torch._dynamo.mark_dynamic(args2[2], 0) 72 73 dropout_arg = [training] if has_dropout else [] 74 torch.manual_seed(1234) 75 result1 = dot_prod_attention(*(args1 + dropout_arg)) 76 77 counters.clear() 78 torch.manual_seed(1234) 79 result2, source_code = run_and_get_code( 80 torch.compile(dot_prod_attention, fullgraph=True), 81 *(args2 + dropout_arg), 82 ) 83 source_code = "\n".join(source_code) 84 if has_fuse_pattern: 85 self.assertGreaterEqual(counters["inductor"]["fuse_attention"], 1) 86 if contains: 87 # many of the patterns get re-expanded in dispatcher 88 self.assertIn( 89 "aten._scaled_dot_product", 90 source_code, 91 ) 92 93 # some tests configured with very low dropout where we still want to check equality 94 if not has_dropout or override_check_equal: 95 self.assertEqual(result1, result2, atol=atol, rtol=1.3e-6) 96 97 if training: 98 result1.sum().backward() 99 result2.sum().backward() 100 for arg1, arg2 in zip(args1, args2): 101 if ( 102 isinstance(arg1, torch.Tensor) 103 and arg1.is_floating_point() 104 and (not has_dropout or override_check_equal) 105 ): 106 self.assertEqual(arg1.grad, arg2.grad, atol=atol, rtol=rtol) 107 108 @skipIfRocm 109 def _test_sdpa_rewriter_1(self): 110 def dot_prod_attention( 111 query: torch.Tensor, key: torch.Tensor, value: torch.Tensor 112 ) -> torch.Tensor: 113 """Input tensors assumed to have shape (batch_size, n_head, seq_len, embed_dim)""" 114 return ( 115 torch.matmul(query, key.transpose(-2, -1)) 116 .div(math.sqrt(key.shape[-1])) 117 .softmax(dim=-1) 118 .matmul(value) 119 ) 120 121 for dtype in [torch.float, torch.half]: 122 atol = 
    @skipIfRocm
    def _test_sdpa_rewriter_1(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, n_head, seq_len, embed_dim)"""
            return (
                torch.matmul(query, key.transpose(-2, -1))
                .div(math.sqrt(key.shape[-1]))
                .softmax(dim=-1)
                .matmul(value)
            )

        for dtype in [torch.float, torch.half]:
            atol = 0.001
            rtol = 1.3e-6 if dtype == torch.float else 0.7
            if self.device == "cpu" and dtype == torch.half:
                atol = 2e-3
                rtol = 1e-2
            self._check_common(dot_prod_attention, dtype=dtype, atol=atol, rtol=rtol)
            self._check_common(
                checkpoint_wrapper(dot_prod_attention),
                dtype=dtype,
                atol=atol,
                rtol=rtol,
            )

    @skipIfRocm
    @torch._inductor.config.patch("freezing", True)
    def _test_sdpa_rewriter_1_freezing(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, n_head, seq_len, embed_dim)"""
            return (
                torch.matmul(query, key.transpose(-2, -1))
                .div(math.sqrt(key.shape[-1]))
                .softmax(dim=-1)
                .matmul(value)
            )

        for dtype in [torch.float, torch.half]:
            atol = 0.001
            rtol = 1.3e-6 if dtype == torch.float else 0.7
            if self.device == "cpu" and dtype == torch.half:
                atol = 2e-3
                rtol = 1e-2
            with torch.no_grad():
                self._check_common(
                    dot_prod_attention,
                    dtype=dtype,
                    atol=atol,
                    rtol=rtol,
                    check_train=False,
                )

    @skipIfRocm
    def _test_insignificant_strides(self):
        f32 = torch.float32

        # repro taken from https://github.com/pytorch/pytorch/issues/124289
        # constant_pad_nd is a single element tensor that gets expanded

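        # aot_graph_input_parser reads the string shape/dtype annotations on
        # forward() below to materialize matching example inputs.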
        def forward(
            permute_3: "f32[1, 32, 1, 128]",
            permute_4: "f32[1, 32, 1, 128]",
            permute_5: "f32[1, 32, 1, 128]",
            permute_6: "f32[1, 1, 64]",
            mul_2: "f32[1, 1, 1, 1]",
        ):
            cat = torch.ops.aten.cat.default([permute_6, permute_6], 2)
            permute_6 = None
            cos = torch.ops.aten.cos.default(cat)
            sin = torch.ops.aten.sin.default(cat)
            unsqueeze_10 = torch.ops.aten.unsqueeze.default(cos, 1)
            cos = None
            unsqueeze_11 = torch.ops.aten.unsqueeze.default(sin, 1)
            sin = None
            mul_5 = torch.ops.aten.mul.Tensor(permute_3, unsqueeze_10)
            slice_10 = torch.ops.aten.slice.Tensor(permute_3, 3, 0, 64)
            slice_11 = torch.ops.aten.slice.Tensor(
                permute_3, 3, 64, 9223372036854775807
            )
            permute_3 = None
            neg = torch.ops.aten.neg.default(slice_11)
            slice_11 = None
            cat_1 = torch.ops.aten.cat.default([neg, slice_10], 3)
            neg = slice_10 = None
            mul_6 = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_11)
            cat_1 = None
            add_1 = torch.ops.aten.add.Tensor(mul_5, mul_6)
            mul_5 = mul_6 = None
            mul_7 = torch.ops.aten.mul.Tensor(permute_4, unsqueeze_10)
            unsqueeze_10 = None
            slice_12 = torch.ops.aten.slice.Tensor(permute_4, 3, 0, 64)
            slice_13 = torch.ops.aten.slice.Tensor(
                permute_4, 3, 64, 9223372036854775807
            )
            permute_4 = None
            neg_1 = torch.ops.aten.neg.default(slice_13)
            slice_13 = None
            cat_2 = torch.ops.aten.cat.default([neg_1, slice_12], 3)
            neg_1 = slice_12 = None
            mul_8 = torch.ops.aten.mul.Tensor(cat_2, unsqueeze_11)
            cat_2 = unsqueeze_11 = None
            add_2 = torch.ops.aten.add.Tensor(mul_7, mul_8)
            mul_7 = mul_8 = None
            slice_14 = torch.ops.aten.slice.Tensor(mul_2, 0, 0, 9223372036854775807)
            mul_2 = None
            slice_15 = torch.ops.aten.slice.Tensor(slice_14, 1, 0, 9223372036854775807)
            slice_14 = None
            slice_16 = torch.ops.aten.slice.Tensor(slice_15, 2, 0, 9223372036854775807)
            slice_15 = None
            constant_pad_nd = torch.ops.aten.constant_pad_nd.default(
                slice_16, [0, 7], 0.0
            )
            slice_16 = None
            slice_17 = torch.ops.aten.slice.Tensor(constant_pad_nd, -1, 0, 1)
            constant_pad_nd = None
            expand_5 = torch.ops.aten.expand.default(slice_17, [1, 32, 1, 1])
            _scaled_dot_product_efficient_attention = (
                torch.ops.aten._scaled_dot_product_efficient_attention.default(
                    add_1, add_2, permute_5, expand_5, True
                )
            )
            return _scaled_dot_product_efficient_attention

        kwargs = aot_graph_input_parser(forward, device="cuda")
        # runs successfully
        out_eager = forward(**kwargs)
        out_c = torch.compile(forward)(**kwargs)
        # don't compare philox_seed/offset
        torch.testing.assert_close(out_eager[0:2], out_c[0:2])

    def _test_pattern_fails_with_reuse(self):
        """
        This test checks that the replacement is not done
        when an intermediate result is being used / returned downstream
        """

        @torch.compile(fullgraph=True)
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
        ) -> torch.Tensor:
            attn_weights = (
                torch.matmul(query, key.transpose(-2, -1))
                .div(math.sqrt(key.shape[-1]))
                .softmax(dim=-1)
            )
            return attn_weights.matmul(value), attn_weights

        tensor_shape = (2, 4, 8, 16)
        args = [
            torch.randn(tensor_shape, device=self.device),
            torch.randn(tensor_shape, device=self.device),
            torch.randn(tensor_shape, device=self.device),
        ]
        _, (source_code,) = run_and_get_code(dot_prod_attention, *args)
        self.assertNotIn("aten._scaled_dot_product_efficient_attention", source_code)

    @skipIfRocm
    def _test_sdpa_rewriter_2(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
        ) -> torch.Tensor:
            return (
                torch.matmul(query, key.transpose(-2, -1))
                .mul(1.0 / math.sqrt(key.shape[-1]))
                .softmax(dim=-1)
                .matmul(value)
            )

        self._check_common(dot_prod_attention)
        self._check_common(checkpoint_wrapper(dot_prod_attention))

    @skipIfRocm  # AssertionError: expected size 4==4, stride 32==64 at dim=0
    def _test_sdpa_rewriter_3(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, training: bool
        ) -> torch.Tensor:
            return torch.nn.functional.dropout(
                torch.matmul(query, key.transpose(-2, -1)).div(3.0).softmax(dim=-1),
                p=0.4,
                training=training,
                inplace=False,
            ).matmul(value)

        self._check_common(dot_prod_attention, contains=False, has_dropout=True)
        self._check_common(
            checkpoint_wrapper(dot_prod_attention), contains=False, has_dropout=True
        )

    @skipIfRocm  # AssertionError: expected size 4==4, stride 32==64 at dim=0
    def _test_sdpa_rewriter_4(self):
        def dot_prod_attention(
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            training: bool,
        ) -> torch.Tensor:
            return torch.nn.functional.dropout(
                torch.matmul(query, key.transpose(-2, -1)).mul(0.4).softmax(dim=-1),
                p=0.2,
                inplace=False,
                training=training,
            ).matmul(value)

        self._check_common(dot_prod_attention, contains=False, has_dropout=True)
        self._check_common(
            checkpoint_wrapper(dot_prod_attention), contains=False, has_dropout=True
        )

    def _test_sdpa_rewriter_5(self):
        def sfdp_pattern_5_v1(query, key, value):
            attn_mask = torch.ones(
                query.size(-2), key.size(-2), dtype=torch.bool, device=query.device
            ).tril(diagonal=0)
            attn_mask = attn_mask.masked_fill(
                torch.logical_not(attn_mask), -float("inf")
            )
            attn_weight = torch.softmax(
                (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask,
                dim=-1,
            )
            return attn_weight @ value

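        # Variant where the boolean mask is all False, so adding it to the
        # scores is a no-op (see the issue linked below).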
        def sfdp_pattern_5_v2(query, key, value):
            # https://github.com/pytorch/pytorch/issues/100318.
            attn_mask = torch.zeros(
                query.size(-2), key.size(-2), dtype=torch.bool, device=query.device
            ).bool()
            attn_weight = torch.softmax(
                (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask,
                dim=-1,
            )
            return attn_weight @ value

        self._check_common(sfdp_pattern_5_v1, contains=False)
        self._check_common(checkpoint_wrapper(sfdp_pattern_5_v1), contains=False)
        self._check_common(sfdp_pattern_5_v2, contains=False)
        self._check_common(checkpoint_wrapper(sfdp_pattern_5_v2), contains=False)

    @skipIfRocm
    def _test_sdpa_rewriter_6(self):
        def sfdp_pattern_6(query, key, value, training):
            attn_mask = torch.ones(
                query.size(-2), key.size(-2), dtype=torch.bool, device=query.device
            ).tril(diagonal=0)
            attn_mask = attn_mask.masked_fill(
                torch.logical_not(attn_mask), -float("inf")
            )
            attn_weight = torch.softmax(
                (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask,
                dim=-1,
            )
            attn_weight = torch.nn.functional.dropout(attn_weight, 0.5, training)
            return attn_weight @ value

        self._check_common(sfdp_pattern_6, contains=False, has_dropout=True)
        self._check_common(
            checkpoint_wrapper(sfdp_pattern_6), contains=False, has_dropout=True
        )

    @skipIfRocm
    def _test_sdpa_rewriter_7(self):
        def sfdp_pattern_7(query, key, value, training):
            q = query.permute(0, 2, 1, 3)
            k = key.permute(0, 2, 1, 3)
            v = value.permute(0, 2, 1, 3)
            div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
            div = div.to(torch.float32)
            attn_weight = torch.softmax(div, dim=-1)
            # very low dropout to make test pass
            attn_weight = torch.dropout(attn_weight, 0.00000000001, training)
            attn_weight = attn_weight.to(torch.float16)
            return attn_weight @ v

        args = (
            torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
            torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
            torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
        )
        self._check_common(
            sfdp_pattern_7,
            args,
            contains=SM80OrLater,
            has_dropout=True,
            override_check_equal=True,
            atol=2e-3,
        )

        args = (
            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
        )
        self._check_common(
            checkpoint_wrapper(sfdp_pattern_7),
            args,
            contains=SM80OrLater,
            has_dropout=True,
            override_check_equal=True,
            atol=2e-3,
        )

    @skipIfRocm
    def _test_sdpa_rewriter_8(self):
        def sfdp_pattern_8(query, key, value):
            q = query.permute(0, 2, 1, 3)
            k = key.permute(0, 2, 1, 3)
            v = value.permute(0, 2, 1, 3)
            div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
            div = div.to(torch.float32)
            attn_weight = torch.softmax(div, dim=-1)
            attn_weight = attn_weight.to(torch.float16)
            return attn_weight @ v

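        # fp16 inputs with an fp32 softmax that is cast back to fp16: the
        # rewriter has to match through the dtype conversions.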
        args = (
            torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
            torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
            torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
        )
        self._check_common(sfdp_pattern_8, args, atol=2e-3)

        args = (
            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
        )
        self._check_common(checkpoint_wrapper(sfdp_pattern_8), args, atol=2e-3)

    @skipIfRocm
    def _test_sdpa_rewriter_9(self):
        def sfdp_pattern_9(query, key, value, training):
            q = query.permute(0, 2, 1, 3)
            k = key.permute(0, 2, 1, 3)
            v = value.permute(0, 2, 1, 3)
            q = q / math.sqrt(q.size(-1))
            div = q @ k.transpose(-2, -1)
            div = div.to(torch.float32)
            attn_weight = torch.softmax(div, dim=-1)
            # very low dropout to make test pass
            attn_weight = torch.dropout(attn_weight, 0.00000000001, training)
            attn_weight = attn_weight.to(torch.float16)
            return attn_weight @ v

        args = (
            torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
            torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
            torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
        )
        self._check_common(
            sfdp_pattern_9,
            args,
            contains=SM80OrLater,
            has_dropout=True,
            override_check_equal=True,
            atol=2e-3,
        )
        args = (
            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
        )
        self._check_common(
            checkpoint_wrapper(sfdp_pattern_9),
            args,
            contains=SM80OrLater,
            has_dropout=True,
            override_check_equal=True,
            atol=2e-3,
        )

    @skipIfRocm
    def _test_sdpa_rewriter_10(self):
        def sfdp_pattern_10(query, key, value):
            q = query.permute(0, 2, 1, 3)
            k = key.permute(0, 2, 1, 3)
            v = value.permute(0, 2, 1, 3)
            q = q / math.sqrt(q.size(-1))
            div = q @ k.transpose(-2, -1)
            div = div.to(torch.float32)
            attn_weight = torch.softmax(div, dim=-1)
            attn_weight = attn_weight.to(torch.float16)
            return attn_weight @ v

        args = (
            torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
            torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
            torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
        )
        self._check_common(sfdp_pattern_10, args, atol=2e-3)

        args = (
            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
        )
        self._check_common(checkpoint_wrapper(sfdp_pattern_10), args, atol=2e-3)

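    # Negative test: a tensor-valued scale factor is not covered by the SDPA
    # patterns, so no fusion is expected (has_fuse_pattern=False).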
    def _test_pattern_fails_with_tensor_factor(self):
        # https://github.com/pytorch/pytorch/issues/99124
        class Model(torch.nn.Module):
            def __init__(self, is_inv_factor):
                super().__init__()
                self.is_inv_factor = is_inv_factor

            def forward(self, query, key, value, scale_factor) -> torch.Tensor:
                # Dividing by scale_factor makes scale_factor gradients very
                # unstable
                scale_factor = scale_factor.detach()
                y = torch.matmul(query, key.transpose(-2, -1))
                if self.is_inv_factor:
                    y = y.div(scale_factor)
                else:
                    y = y.mul(scale_factor)
                return y.softmax(dim=-1).matmul(value)

        tensor_shape = (2, 4, 4, 4)
        for is_inv_factor in [True, False]:
            args = [
                torch.randn(tensor_shape, device=self.device),
                torch.randn(tensor_shape, device=self.device),
                torch.randn(tensor_shape, device=self.device),
                torch.randn((4, 1, 1), device=self.device),
            ]
            model = Model(is_inv_factor).eval()
            # The training path has an accuracy gap compared with eager mode.
            self._check_common(
                model, args1=args, contains=False, atol=1e-3, has_fuse_pattern=False
            )

    def _test_pattern_fails_with_unsupported_mask(self):
        if not self.use_static_shapes:
            self.skipTest("Causes shape specialization. TODO: investigate")

        # https://github.com/pytorch/pytorch/issues/100315
        class Model(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()

            def forward(self, query, key, value, attn_mask) -> torch.Tensor:
                attn_weight = torch.softmax(
                    query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
                    + attn_mask,
                    dim=-1,
                )
                return attn_weight @ value

        tensor_shape = (2, 4, 4, 4)

        unsupported_masks = [
            torch.randn((2, 4, 4, 4), device=self.device).to(dtype=torch.int),
            2.0,
        ]
        for attn_mask in unsupported_masks:
            args = [
                torch.randn(tensor_shape, device=self.device),
                torch.randn(tensor_shape, device=self.device),
                torch.randn(tensor_shape, device=self.device),
                attn_mask,
            ]
            model = Model().eval()
            # The training path has an accuracy gap compared with eager mode.
            self._check_common(
                model, args1=args, contains=False, atol=1e-4, has_fuse_pattern=False
            )

    @skipIfRocm
    def _test_sdpa_rewriter_11(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)"""
            q = query.transpose(1, 2)
            k = key.transpose(1, 2)
            v = value.transpose(1, 2)
            return (
                torch.matmul(q, k.transpose(-2, -1))
                .div(math.sqrt(key.shape[-1]))
                .softmax(dim=-1)
                .matmul(v)
            )

        self._check_common(dot_prod_attention)

    def _test_sdpa_rewriter_12(self):
        def dot_prod_attention(
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            training: bool,
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)"""
            q = query.transpose(1, 2)
            k = key.transpose(1, 2)
            v = value.transpose(1, 2)
            return torch.nn.functional.dropout(
                torch.matmul(q, k.transpose(-2, -1))
                .div(math.sqrt(key.shape[-1]))
                .softmax(dim=-1)
                .matmul(v),
                p=0.4,
                training=training,
                inplace=False,
            )

        self._check_common(dot_prod_attention, contains=False, has_dropout=True)

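    # The _test_sdpa_prev_* variants keep a .clone() between the softmax and
    # the final matmul; the rewriter should still match them (inference only).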
    @skipIfRocm
    def _test_sdpa_prev_13(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, n_head, seq_len, embed_dim)"""
            return (
                torch.matmul(query, key.transpose(-2, -1))
                .div(math.sqrt(key.shape[-1]))
                .softmax(dim=-1)
                .clone()
                .matmul(value)
            )

        self._check_common(dot_prod_attention, check_train=False)
        self._check_common(checkpoint_wrapper(dot_prod_attention), check_train=False)

    @skipIfRocm
    def _test_sdpa_prev_14(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
        ) -> torch.Tensor:
            return (
                torch.matmul(query, key.transpose(-2, -1))
                .mul(1.0 / math.sqrt(key.shape[-1]))
                .softmax(dim=-1)
                .clone()
                .matmul(value)
            )

        self._check_common(dot_prod_attention, check_train=False)
        self._check_common(checkpoint_wrapper(dot_prod_attention), check_train=False)

    @skipIfRocm
    def _test_sdpa_prev_15(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)"""
            q = query.transpose(1, 2)
            k = key.transpose(1, 2)
            v = value.transpose(1, 2)
            return (
                torch.matmul(q, k.transpose(-2, -1))
                .div(math.sqrt(key.shape[-1]))
                .softmax(dim=-1)
                .clone()
                .matmul(v)
            )

        self._check_common(dot_prod_attention, check_train=False)

    @skipIfRocm
    def _test_sdpa_rewriter_13(self, dtype):
        def dot_prod_attention(
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            training: bool,
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)"""
            attn_weight = torch.bmm(query, key.transpose(1, 2)).softmax(dim=-1)
            attn_weight = torch.nn.functional.dropout(
                attn_weight, p=0.5, training=training
            )
            return torch.bmm(attn_weight, value)

        tensor_shape = (4, 8, 16)
        args = [
            torch.randn(tensor_shape, device=self.device, dtype=dtype),
            torch.randn(tensor_shape, device=self.device, dtype=dtype),
            torch.randn(tensor_shape, device=self.device, dtype=dtype),
        ]

        self._check_common(
            dot_prod_attention,
            check_train=False,
            args1=args,
            has_dropout=True,
            override_check_equal=True,
            atol=1e-2,
            rtol=1e-2,
        )

    @skipIfRocm
    def _test_sdpa_rewriter_14(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)"""
            attn_mask = torch.ones(
                query.size(1), key.size(1), dtype=torch.bool, device=query.device
            ).tril(diagonal=0)
            attn_mask = attn_mask.masked_fill(
                torch.logical_not(attn_mask), -float("inf")
            )
            q = query.permute(0, 2, 1, 3)
            k = key.permute(0, 2, 1, 3)
            v = value.permute(0, 2, 1, 3)
            return (
                (torch.matmul(q, k.transpose(-2, -1)).div(3.0) + attn_mask)
                .softmax(dim=-1)
                .matmul(v)
            )

        self._check_common(dot_prod_attention)

    @skipIfRocm
    def _test_sdpa_rewriter_15(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)"""
            q = query.transpose(1, 2)
            k = key.transpose(1, 2)
            v = value.transpose(1, 2)
            bs = q.size(0)
            k_len = k.size(-2)
            attn_mask = torch.ones(
                bs, k_len, dtype=torch.bool, device=query.device
            ).tril(diagonal=0)
            scores = torch.matmul(q, k.transpose(-2, -1)) / 3.0
            attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores)
            scores = scores.masked_fill(attn_mask, -float("inf"))
            weights = torch.nn.functional.softmax(scores, dim=-1)
            return torch.matmul(weights, v)

        self._check_common(dot_prod_attention, check_train=False)

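    # With p=0.4 dropout the compiled kernel's RNG does not match eager, so the
    # cases below run with has_dropout=True and skip the output-equality check.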
    @skipIfRocm
    def _test_sdpa_rewriter_16(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, training
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)"""
            attn_mask = torch.ones(
                query.size(1), key.size(1), dtype=torch.bool, device=query.device
            ).tril(diagonal=0)
            attn_mask = attn_mask.masked_fill(
                torch.logical_not(attn_mask), -float("inf")
            )
            q = query.permute(0, 2, 1, 3)
            k = key.permute(0, 2, 1, 3)
            v = value.permute(0, 2, 1, 3)
            return torch.nn.functional.dropout(
                (torch.matmul(q, k.transpose(-2, -1)).div(3.0) + attn_mask).softmax(
                    dim=-1
                ),
                p=0.4,
                training=training,
                inplace=False,
            ).matmul(v)

        self._check_common(dot_prod_attention, contains=False, has_dropout=True)

        # also check batch_size=1 because the graph is slightly different
        tensor_shape = (1, 2, 16, 32)
        args = [
            torch.randn(tensor_shape, device=self.device),
            torch.randn(tensor_shape, device=self.device),
            torch.randn(tensor_shape, device=self.device),
        ]
        self._check_common(
            dot_prod_attention, args1=args, contains=False, has_dropout=True
        )

    @skipIfRocm
    def _test_sdpa_rewriter_16_fp32_mask(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, training
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)"""
            attn_mask = torch.randn(
                query.size(1), key.size(1), dtype=torch.float, device=query.device
            ).tril(diagonal=0)
            q = query.permute(0, 2, 1, 3)
            k = key.permute(0, 2, 1, 3)
            v = value.permute(0, 2, 1, 3)
            return torch.nn.functional.dropout(
                (torch.matmul(q, k.transpose(-2, -1)).div(3.0) + attn_mask).softmax(
                    dim=-1
                ),
                p=0.4,
                training=training,
                inplace=False,
            ).matmul(v)

        self._check_common(dot_prod_attention, contains=False, has_dropout=True)

        # also check batch_size=1 because the graph is slightly different
        tensor_shape = (1, 2, 16, 32)
        args = [
            torch.randn(tensor_shape, device=self.device),
            torch.randn(tensor_shape, device=self.device),
            torch.randn(tensor_shape, device=self.device),
        ]
        self._check_common(
            dot_prod_attention, args1=args, contains=False, has_dropout=True
        )

    @skipIfRocm
    def _test_sdpa_rewriter_17(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, training
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)"""
            q = query.transpose(1, 2)
            k = key.transpose(1, 2)
            v = value.transpose(1, 2)
            bs = q.size(0)
            k_len = k.size(-2)
            attn_mask = torch.ones(
                bs, k_len, dtype=torch.bool, device=query.device
            ).tril(diagonal=0)
            scores = torch.matmul(q, k.transpose(-2, -1)) / 3.0
            attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores)
            scores = scores.masked_fill(attn_mask, -float("inf"))
            weights = torch.nn.functional.softmax(scores, dim=-1)
            weights = torch.nn.functional.dropout(
                weights,
                p=0.4,
                training=training,
                inplace=False,
            )
            return torch.matmul(weights, v)

        self._check_common(dot_prod_attention, check_train=False, has_dropout=True)

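    # rewriter_18/19 follow the hf_GPT2 attention graph: scores are divided by
    # a scalar tensor and masked via torch.where(causal_mask, ..., finfo.min)
    # before the softmax.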
    @skipIfRocm
    def _test_sdpa_rewriter_18(self):
        def dot_prod_attention(
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            causal_mask: torch.Tensor,
        ) -> torch.Tensor:
            # for hf_GPT2 with dropout
            query = query.permute([0, 2, 1, 3])
            key = key.permute([0, 2, 1, 3])
            value = value.permute([0, 2, 1, 3])
            attn_weights = torch.matmul(query, key.permute(0, 1, 3, 2))
            inv_scale = torch.full(
                (), math.sqrt(value.size(-1)), dtype=query.dtype, device=query.device
            )
            attn_weights = attn_weights.div(inv_scale)
            causal_mask_value = torch.full(
                (), torch.finfo(query.dtype).min, dtype=query.dtype, device=query.device
            )
            attn_weights = torch.where(causal_mask, attn_weights, causal_mask_value)
            return (
                (
                    torch.nn.functional.dropout(
                        attn_weights.softmax(dim=-1), 0.0
                    ).matmul(value)
                ),
                key.permute([0, 2, 1, 3]),
                value.permute([0, 2, 1, 3]),
            )

        tensor_shape = (4, 2, 16, 32)
        causal_mask = torch.ones(2, 2, dtype=torch.bool, device=self.device).tril(
            diagonal=0
        )
        args = [
            torch.randn(tensor_shape, device=self.device),
            torch.randn(tensor_shape, device=self.device),
            torch.randn(tensor_shape, device=self.device),
            causal_mask,
        ]
        self._check_common(
            dot_prod_attention,
            args1=args,
            contains=False,
            has_dropout=False,
            check_train=False,
        )

        # also check batch_size=1 because the graph is slightly different
        tensor_shape = (1, 2, 16, 32)
        args = [
            torch.randn(tensor_shape, device=self.device),
            torch.randn(tensor_shape, device=self.device),
            torch.randn(tensor_shape, device=self.device),
            causal_mask,
        ]
        self._check_common(
            dot_prod_attention,
            args1=args,
            contains=False,
            has_dropout=False,
            check_train=False,
        )

    @skipIfRocm
    def _test_sdpa_rewriter_19(self):
        def dot_prod_attention(
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            causal_mask: torch.Tensor,
            attn_mask: torch.Tensor,
            training,
        ) -> torch.Tensor:
            attn_weights = torch.matmul(query, key.permute(0, 1, 3, 2))
            inv_scale = torch.full(
                (),
                math.sqrt(value.size(-1)),
                dtype=attn_weights.dtype,
                device=attn_weights.device,
            )
            attn_weights = attn_weights.div(inv_scale)
            causal_mask_value = torch.full(
                (), torch.finfo(query.dtype).min, dtype=query.dtype, device=query.device
            )
            attn_weights = torch.where(causal_mask, attn_weights, causal_mask_value)
            attn_weights = attn_weights + attn_mask
            attn_weights = attn_weights.softmax(dim=-1).type(value.dtype)
            return torch.nn.functional.dropout(
                attn_weights,
                p=0.4,
                training=training,
                inplace=False,
            ).matmul(value)

        tensor_shape = (4, 2, 16, 32)
        causal_mask = torch.ones(16, 16, dtype=torch.bool, device=self.device).tril(
            diagonal=0
        )
        attn_mask = torch.randn((16, 16), dtype=torch.float, device=self.device)
        args = [
            torch.randn(tensor_shape, device=self.device),
            torch.randn(tensor_shape, device=self.device),
            torch.randn(tensor_shape, device=self.device),
            causal_mask,
            attn_mask,
        ]
        self._check_common(
            dot_prod_attention,
            args1=args,
            contains=False,
            has_dropout=True,
            check_train=False,
        )


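# The template's helpers are prefixed with an underscore so they are not
# collected directly; the device-specific classes below bind them as concrete
# test_* methods for CUDA and CPU, plus dynamic-shape variants that set
# use_static_shapes = False.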
if HAS_CUDA and PLATFORM_SUPPORTS_FUSED_ATTENTION:

    class SDPAPatternRewriterCudaTests(TestSDPAPatternRewriterTemplate):
        device = "cuda"
        test_sdpa_rewriter_1_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_1
        )
        test_sdpa_rewriter_1_freezing = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_1_freezing
        )
        test_insignificant_strides = (
            TestSDPAPatternRewriterTemplate._test_insignificant_strides
        )
        test_pattern_fails_with_reuse_cuda = (
            TestSDPAPatternRewriterTemplate._test_pattern_fails_with_reuse
        )
        test_sdpa_rewriter_2_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_2
        )
        test_sdpa_rewriter_3_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_3
        )
        test_sdpa_rewriter_4_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_4
        )
        test_sdpa_rewriter_5_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_5
        )
        test_sdpa_rewriter_6_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_6
        )
        test_sdpa_rewriter_7_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_7
        )
        test_sdpa_rewriter_8_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_8
        )
        test_sdpa_rewriter_9_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_9
        )
        test_sdpa_rewriter_10_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_10
        )
        test_pattern_fails_with_tensor_factor_cuda = (
            TestSDPAPatternRewriterTemplate._test_pattern_fails_with_tensor_factor
        )
        test_pattern_fails_with_unsupported_mask_cuda = (
            TestSDPAPatternRewriterTemplate._test_pattern_fails_with_unsupported_mask
        )
        test_sdpa_rewriter_11_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_11
        )
        test_sdpa_rewriter_12_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_12
        )
        test_sdpa_prev_13_cuda = TestSDPAPatternRewriterTemplate._test_sdpa_prev_13
        test_sdpa_prev_14_cuda = TestSDPAPatternRewriterTemplate._test_sdpa_prev_14
        test_sdpa_prev_15_cuda = TestSDPAPatternRewriterTemplate._test_sdpa_prev_15
        test_sdpa_rewriter_13_cuda = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_13, dtype=torch.half
        )
        test_sdpa_rewriter_14_cuda = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_14
        )
        test_sdpa_rewriter_15_cuda = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_15
        )
        test_sdpa_rewriter_17_cuda = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_17
        )
        test_sdpa_rewriter_19_cuda = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_19
        )

    class SDPAPatternRewriterCudaDynamicTests(SDPAPatternRewriterCudaTests):
        use_static_shapes = False


if HAS_CPU:

    class SDPAPatternRewriterCpuTests(TestSDPAPatternRewriterTemplate):
        device = "cpu"
        test_sdpa_rewriter_1_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_1
        test_pattern_fails_with_reuse_cpu = (
            TestSDPAPatternRewriterTemplate._test_pattern_fails_with_reuse
        )
        test_sdpa_rewriter_2_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_2
        test_sdpa_rewriter_5_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_5
        test_pattern_fails_with_tensor_factor_cpu = (
            TestSDPAPatternRewriterTemplate._test_pattern_fails_with_tensor_factor
        )
        test_pattern_fails_with_unsupported_mask_cpu = (
            TestSDPAPatternRewriterTemplate._test_pattern_fails_with_unsupported_mask
        )
        test_sdpa_rewriter_11_cpu = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_11
        )
        test_sdpa_rewriter_12_cpu = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_12
        )
        test_sdpa_prev_13_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_prev_13
        test_sdpa_prev_14_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_prev_14
        test_sdpa_prev_15_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_prev_15
        test_sdpa_rewriter_13_cpu = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_13, dtype=torch.float32
        )
        test_sdpa_rewriter_14_cpu = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_14
        )
        test_sdpa_rewriter_15_cpu = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_15
        )
        test_sdpa_rewriter_16_cpu = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_16
        )
        test_sdpa_rewriter_16_fp32_mask_cpu = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_16_fp32_mask
        )
        test_sdpa_rewriter_17_cpu = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_17
        )
        test_sdpa_rewriter_18_cpu = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_18
        )
        test_sdpa_rewriter_19_cpu = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_19
        )

    class SDPAPatternRewriterCpuDynamicTests(SDPAPatternRewriterCpuTests):
        use_static_shapes = False


if __name__ == "__main__":
    if IS_LINUX:
        run_tests()