1# Owner(s): ["module: fx.passes"] 2 3from dataclasses import dataclass 4import operator 5import logging 6import sys 7 8import torch 9from torch.fx._symbolic_trace import symbolic_trace 10 11from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner 12from torch.fx.passes.operator_support import OperatorSupport 13from torch.fx.passes.utils.fuser_utils import fuse_by_partitions 14from torch.fx.passes.utils.matcher_utils import SubgraphMatcher 15 16from torch.testing._internal.common_utils import run_tests, parametrize, instantiate_parametrized_tests 17from torch.testing._internal.jit_utils import JitTestCase 18 19logging.basicConfig(level=logging.WARNING) 20logger = logging.getLogger(__name__) 21 22class TestModule(torch.nn.Module): 23 def __init__(self) -> None: 24 super().__init__() 25 self.linear = torch.nn.Linear(4, 4) 26 self.linear2 = torch.nn.Linear(4, 4) 27 self.param = torch.nn.Parameter(torch.rand(4, 4)) 28 29 def forward(self, a, b, c): 30 add = a + b 31 32 linear_1 = self.linear(add) 33 34 add_1 = add + c 35 add_2 = add_1 + self.param 36 add_3 = add_1 + linear_1 37 add_4 = add_2 + add_3 38 39 linear_2 = self.linear2(add_4) 40 41 add_5 = linear_2 + add_4 42 add_6 = add_5 + a 43 relu = add_6.relu() 44 45 return add_4, add_6, relu 46 47class TestDeepModule(torch.nn.Module): 48 def __init__(self) -> None: 49 super().__init__() 50 self.linear = torch.nn.Linear(4, 4) 51 52 def forward(self, a, b, c): 53 o = a + b 54 o = o + 1.0 55 56 # testing to avoid DFS uses in passes. Since Python has max recursion depth. 57 for _ in range(sys.getrecursionlimit() + 1): 58 o = o - c 59 60 return o 61 62 63class TestPartitionFunctions: 64 @staticmethod 65 def forward1(a, b, c): 66 add = a + b 67 add_1 = add + b 68 add_2 = add_1 + c 69 relu_1 = add_2.relu() 70 add_3 = add_1 + add_2 71 add_4 = add_1 + relu_1 + add_3 72 relu_2 = add_4.relu() 73 add_5 = relu_2 + add_4 74 add_6 = add_5 + add_4 75 return add_4, add_6 76 77 @staticmethod 78 def forward2(a, b, _): 79 add = a + b 80 add_1 = add + b 81 relu_1 = add_1.relu() # blocked by this 82 add_3 = add_1 + relu_1 83 add_4 = add_1 + add_3 84 return add_4, add_1 85 86 @staticmethod 87 def forward3(a, b, c): 88 add = a + b 89 add_1 = a + c 90 add_2 = b + c 91 return add, add_1, add_2 92 93 @staticmethod 94 def forward4(a, b, c): 95 add = a + b 96 add_1 = a + c 97 add_2 = b + c 98 return torch.where(add > 0, add_1, add_2) 99 100 @staticmethod 101 def forward5(a, b, c): 102 # add should be fused right branch, as left branch is not supported 103 add = a + 1 104 # left branch 105 relu = add.relu() 106 # right branch 107 add_1 = add + 2 108 return relu, add_1 109 110 @staticmethod 111 def forward6(a, b, c): 112 # add should have its own partition, as neither branchs are supported 113 add = a + 1 114 # left branch 115 relu = add.relu() 116 # right branch 117 relu_1 = add.relu() 118 return relu, relu_1 119 120 @staticmethod 121 def forward7(a, b, c): 122 # both branches are supported, all adds should be fused together 123 add = a + 1 124 # left branch 125 add_1 = add + 2 126 # right branch is larger 127 add_2 = add + 1 128 add_3 = add_2 + 1 129 return add_3, add_1 130 131 @staticmethod 132 def forward8(a, b, c): 133 # both branches are in the same partition, add should join the same partition 134 add = a + 1 135 # left branch 136 add_1 = add + 2 137 # right branch 138 add_2 = add + 1 139 # left and right branch merges 140 add_3 = add_2 + add_1 141 142 return add_3 143 144 @staticmethod 145 def forward9(a, b, c): 146 add = a + 1 147 # branch 1 148 add_1 = add + 1 149 # branch 2 150 add_2 = add + 1 151 # branch_3 152 add_3 = add + 1 153 out = torch.stack([add_1, add_2, add_3]) 154 return out 155 156 @staticmethod 157 def forward10(a, b, c): 158 add = a + 1 159 # branch 1 160 add_1 = add + 1 161 # branch 2 162 add_2 = add + 1 163 # branch 3: depends on branch 2 164 add_3 = add + add_2 165 out = torch.stack([add_1, add_2, add_3]) 166 return out 167 168 @staticmethod 169 def forward11(a, b, c): 170 add = a + 1 171 # branch 1 172 add_1 = add.relu() 173 # branch 2 depends on branch 1 174 add_2 = add + add_1 175 # branch 3 176 add_3 = add.relu() 177 out = torch.stack([add_1, add_2, add_3]) 178 return out 179 180 @staticmethod 181 def forward12(a, b, c): 182 b0 = a + 1.0 183 c0 = a + 1.5 184 x0 = b0.relu() 185 x1 = c0.relu() 186 b1 = b0 + x1 187 c1 = c0 + 1.2 188 # c2 has dependency on x0 & b0, when we merge {c0, c1, c2} 189 # this dependency should be updated to the fusion group and reflected 190 # on the decision to not fuse b0 & b1, which forms a cyclic dependency in 191 # the new graph 192 c2 = x0 + c0 193 return b1, c2 194 195 @staticmethod 196 def forward13(a, b, c): 197 a0, a1, a2, a3 = a.split(1, 0) 198 b1 = a0 + b 199 c1 = a1 + c 200 return b1 + c1 201 202 @staticmethod 203 def forward14(a, b, c): 204 a0, a1 = torch.ops.aten.std_mean(a) 205 out = a0 + 1.0 206 return out 207 208 @staticmethod 209 def forward15(a, b, c): 210 a0 = torch.ops.aten.view(a, [2, 2]) 211 a1 = torch.ops.aten.permute(a0, [1, 0]) 212 a2 = a1 + 1.0 213 a3 = torch.ops.aten.permute(a2, [1, 0]) 214 a4 = a3 + 1.0 215 a5 = torch.ops.aten.permute(a4, [1, 0]) 216 return torch.ops.aten.permute(a5, [1, 0]) 217 218 @staticmethod 219 def forward16(a, b, c): 220 a0 = a - 1.0 221 a1 = torch.ops.aten.view(a0, [2, 2]) 222 a2 = torch.ops.aten.permute(a1, [1, 0]) 223 a3 = a2 + 1.0 224 a4 = torch.ops.aten.permute(a3, [1, 0]) 225 a5 = a4 + 1.0 226 a6 = torch.ops.aten.permute(a5, [1, 0]) 227 a7 = torch.ops.aten.permute(a6, [1, 0]) 228 return a7 - 1.0 229 230 @staticmethod 231 def forward17(a, b, c, d, e, f): 232 a0 = a + b 233 a1 = c + d 234 a2 = e + f 235 return a0, a1, a2 236 237 @staticmethod 238 def forward18(a, b, c): 239 a0, a1 = torch.ops.aten.var_mean(a) 240 return a0 241 242# A mock OperatorSupport class, where only operator.add is supported 243class MockOperatorSupport(OperatorSupport): 244 def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: 245 return (node.op == "call_function" and 246 node.target in {operator.add, operator.getitem, 247 torch.ops.aten.view, 248 torch.ops.aten.permute, 249 torch.ops.aten.std_mean}) 250 251@instantiate_parametrized_tests 252class TestFXGraphPasses(JitTestCase): 253 254 @parametrize("fn, expected_partition, bookend_non_compute_pass", [ 255 (TestPartitionFunctions.forward1, [["add_7", "add_6"], ["add_5", "add_4", "add_3"], ["add_2", "add_1", "add"]], False), 256 (TestPartitionFunctions.forward2, [["add_3", "add_2"], ["add_1", "add"]], False), 257 258 # 1 horizontal fusion with common producer 259 (TestPartitionFunctions.forward3, [["add_2", "add_1", "add"]], False), 260 (TestPartitionFunctions.forward4, [["add_2", "add_1", "add"]], False), 261 262 # 2 branches cases 263 (TestPartitionFunctions.forward5, [["add_1", "add"]], False), 264 (TestPartitionFunctions.forward6, [["add"]], False), 265 (TestPartitionFunctions.forward7, [["add_3", "add_2", "add", "add_1"]], False), 266 (TestPartitionFunctions.forward8, [["add_3", "add_2", "add", "add_1"]], False), 267 268 # 3 branch cases 269 (TestPartitionFunctions.forward9, [['add_3', 'add_2', 'add_1', 'add']], False), 270 (TestPartitionFunctions.forward10, [['add_3', 'add_2', 'add', 'add_1']], False), 271 (TestPartitionFunctions.forward11, [['add_1'], ['add']], False), 272 273 # 4 not necessarily the only partition, just to verify that there's no cyclic dependency after partition 274 (TestPartitionFunctions.forward12, [["add_2", "add_3", "add_4"], ["add", "add_1"]], False), 275 276 # 5 getitem special case 277 (TestPartitionFunctions.forward13, [["add_2", "add_1", "add"]], False), 278 (TestPartitionFunctions.forward14, [["add", "std_mean", "getitem", "getitem_1"]], False), 279 280 # 6 bookend non_compute pass 281 (TestPartitionFunctions.forward15, [["permute_1", "add_1", "add"]], True), 282 (TestPartitionFunctions.forward15, [['add_1', 'add', 'permute_1', 'view', 'permute_2', 'permute_3', 'permute']], False), 283 (TestPartitionFunctions.forward16, [["permute_1", "add_1", "add"]], True), 284 (TestPartitionFunctions.forward16, [['add_1', 'add', 'permute_1', 'view', 'permute_2', 'permute_3', 'permute']], False), 285 # should be empty partition, not a partiton with empty nodes 286 (TestPartitionFunctions.forward18, [], False), 287 ]) 288 def test_partitioner(self, fn, expected_partition, bookend_non_compute_pass): 289 traced = symbolic_trace(fn) 290 291 non_compute_ops = [] 292 if bookend_non_compute_pass: 293 non_compute_ops = ["torch.ops.aten.view", "torch.ops.aten.permute"] 294 295 supported_ops = MockOperatorSupport() 296 partitioner = CapabilityBasedPartitioner(traced, 297 supported_ops, 298 allows_single_node_partition=True, 299 non_compute_ops=non_compute_ops) 300 partitions = partitioner.propose_partitions() 301 if bookend_non_compute_pass: 302 partitioner.remove_bookend_non_compute_ops(partitions) 303 304 partitions_name = [[node.name for node in partition.nodes] for partition in partitions] 305 assert len(partitions_name) == len(expected_partition) 306 for i in range(len(partitions_name)): 307 assert set(partitions_name[i]) == set(expected_partition[i]) 308 309 fused_graph = partitioner.fuse_partitions(partitions) 310 311 a, b, c = torch.rand(4), torch.rand(4), torch.rand(4) 312 313 expected = fn(a, b, c) 314 result = fused_graph(a, b, c) 315 torch.testing.assert_close(expected, result) 316 317 @parametrize("fn, expected_partition", [ 318 (TestPartitionFunctions.forward17, [['add', 'add_1', 'add_2']]), 319 ]) 320 def test_partitioner_independent_output(self, fn, expected_partition): 321 traced = symbolic_trace(fn) 322 323 supported_ops = MockOperatorSupport() 324 partitioner = CapabilityBasedPartitioner(traced, 325 supported_ops, 326 allows_single_node_partition=True) 327 partitions = partitioner.propose_partitions() 328 partitions_name = [[node.name for node in partition.nodes] for partition in partitions] 329 assert len(partitions_name) == len(expected_partition) 330 for i in range(len(partitions_name)): 331 assert set(partitions_name[i]) == set(expected_partition[i]) 332 333 fused_graph = partitioner.fuse_partitions(partitions) 334 335 a, b, c, d, e, f = torch.rand(4), torch.rand(4), torch.rand(4), torch.rand(4), torch.rand(4), torch.rand(4) 336 337 expected = fn(a, b, c, d, e, f) 338 result = fused_graph(a, b, c, d, e, f) 339 torch.testing.assert_close(expected, result) 340 341 @parametrize("partition", [ 342 [['add', 'add_1'], ['add_5', 'add_6']], 343 [['add', 'add_1', 'add_2']], # vertical fusion 344 [['add_2', 'add_3']], # horizontal fusion 345 [['add_3', 'add_4']], 346 [['add_6', 'add_5']], # arbitray node order 347 [['add_4', 'add_1', 'add_3', 'add_2']], # arbitray node order 348 [['add_5', 'add_6'], ['add_1', 'add_2', 'add_3', 'add_4']], # arbitray partition order 349 [['add_5', 'linear2']], # includes call_function + call_module node 350 [['add_6', 'relu']], # includes call_function + call_module node 351 [['param', 'add_2']], # includes get_attr + call_module nodes 352 [['param', 'add_1', 'linear']], # includes get_attr + call_function + call_module nodes 353 [["add", "linear", "add_1", "param", "add_2", "add_3", "add_4", "linear2", "add_5", "add_6", "relu"]], # full graph 354 ]) 355 def test_fuser_util(self, partition): 356 m = TestModule() 357 gm = symbolic_trace(m) 358 359 nodes_by_name = {node.name : node for node in gm.graph.nodes} 360 361 partitions = [] 362 for node_names in partition: 363 partitions.append([nodes_by_name[name] for name in node_names]) 364 365 fused_graph = fuse_by_partitions(gm, partitions) 366 367 a, b, c = torch.rand(4), torch.rand(4), torch.rand(4) 368 369 expected = m(a, b, c) 370 result = fused_graph(a, b, c) 371 372 torch.testing.assert_close(expected, result) 373 374 @parametrize("partition", [ 375 [['add', 'add_1'], ['add_1', 'add_5', 'add_6']], # add_1 exists in multiple partitions 376 [['add', 'add_1', 'add_3']], # invalid partition: circular dependency 377 [['add_4', 'add_5']], # invalid partition: circular dependency 378 [['relu', 'add_5']], # invalid partition: circular dependency 379 ]) 380 def test_fuser_util_xfail(self, partition): 381 m = TestModule() 382 gm = symbolic_trace(m) 383 384 nodes_by_name = {node.name : node for node in gm.graph.nodes} 385 386 partitions = [] 387 for node_names in partition: 388 partitions.append([nodes_by_name[name] for name in node_names]) 389 390 with self.assertRaises(Exception): 391 fuse_by_partitions(gm, partitions) 392 393 def test_fuser_pass_deep_model(self): 394 m = TestDeepModule() 395 traced = symbolic_trace(m) 396 397 supported_ops = MockOperatorSupport() 398 partitioner = CapabilityBasedPartitioner(traced, 399 supported_ops, 400 allows_single_node_partition=True) 401 partitions = partitioner.propose_partitions() 402 403@dataclass 404class TestCase: 405 match_output: bool 406 match_placeholder: bool 407 num_matches: int 408 remove_overlapping_matches: bool = True 409 410class SingleNodePattern: 411 @staticmethod 412 def forward(x): 413 val = torch.neg(x) 414 return torch.add(val, val) 415 416 @staticmethod 417 def pattern(a): 418 return torch.neg(a) 419 420 test_cases = [ 421 # match_output, match_placeholder, num_matches 422 TestCase(False, False, 1), 423 TestCase(True, False, 0), 424 TestCase(False, True, 1), 425 TestCase(True, True, 0) 426 ] 427class SimplePattern: 428 @staticmethod 429 def forward(x, w1, w2): 430 m1 = torch.cat([w1, w2]).sum() 431 m2 = torch.cat([w2, w1]).sum() 432 m3 = torch.cat([m1, m2]).sum() 433 return x + torch.max(m1) + torch.max(m2) + m3 434 435 @staticmethod 436 def pattern(a, b): 437 return torch.cat([a, b]).sum() 438 439 test_cases = [ 440 # match_output, match_placeholder, num_matches 441 TestCase(False, False, 3), 442 TestCase(True, False, 0), 443 TestCase(False, True, 2), 444 TestCase(True, True, 0) 445 ] 446 447class SimpleFullGraphMatching: 448 @staticmethod 449 def forward(x): 450 a = torch.neg(x) 451 return torch.add(a, a) 452 453 @staticmethod 454 def pattern(x): 455 a = torch.neg(x) 456 return torch.add(a, a) 457 458 test_cases = [ 459 # match_output, match_placeholder, num_matches 460 TestCase(False, False, 1), 461 TestCase(True, False, 1), 462 TestCase(False, True, 1), 463 TestCase(True, True, 1) 464 ] 465 466class DiamondShapePatternTestCase: 467 @staticmethod 468 def forward(x): 469 a = torch.neg(x) 470 471 a = a.relu() 472 left = a.sigmoid() 473 right = a.relu() 474 out = left + right 475 476 return out 477 478 @staticmethod 479 def pattern(a): 480 a = a.relu() 481 left = a.sigmoid() 482 right = a.relu() 483 out = left + right 484 return out 485 486 test_cases = [ 487 # match_output, match_placeholder, num_matches 488 TestCase(False, False, 1), 489 TestCase(True, False, 1), 490 TestCase(False, True, 0), 491 TestCase(True, True, 0) 492 ] 493 494class NonFullyContainedMatches: 495 @staticmethod 496 def forward(x, w1, w2, b1, b2): 497 # fully contained matched subgraph 498 m1 = torch.cat([w1, w2]) 499 m2 = torch.cat([x, b2]) 500 t0 = torch.addmm(b1, m1, m2.t()) 501 t0_sum = torch.sum(t0) # use of t0 is not leaking 502 503 # leaking matched subgraph, m3 is leaked 504 m3 = torch.cat([w1, w2]) 505 m4 = torch.cat([x, b2]) 506 t1 = torch.addmm(b1, m3, m4.t()) 507 m3_sum = torch.sum(m3) 508 509 return t0_sum, m3_sum 510 511 @staticmethod 512 def pattern(x, w1, w2, b1, b2): 513 m1 = torch.cat([w1, w2]) 514 m2 = torch.cat([x, b2]) 515 return torch.addmm(b1, m1, m2.t()) 516 517 test_cases = [ 518 # match_output, match_placeholder, num_matches 519 TestCase(False, False, 1), 520 521 TestCase(True, False, 0), 522 523 TestCase(False, True, 1), # leaked used of placeholder is not leaking 524 ] 525 526class ChainRepeatedPattern: 527 @staticmethod 528 def forward(x): 529 x = torch.sigmoid(x) 530 x = torch.sigmoid(x) 531 x = torch.sigmoid(x) 532 return torch.sigmoid(x) 533 534 @staticmethod 535 def pattern(x): 536 return torch.sigmoid(torch.sigmoid(x)) 537 538 test_cases = [ 539 # match_output, match_placeholder, num_matches 540 TestCase(False, False, 3, remove_overlapping_matches=False), 541 TestCase(False, False, 2, remove_overlapping_matches=True), 542 TestCase(True, False, 1), 543 TestCase(False, True, 1), 544 TestCase(True, True, 0) 545 ] 546 547class QuantizationModel: 548 @staticmethod 549 def forward(x): 550 x += 3 551 x = x.dequantize() 552 x = torch.sigmoid(x) 553 x = x.to(torch.float16) 554 return x 555 556 @staticmethod 557 def pattern(x): 558 x = x.dequantize() 559 x = torch.sigmoid(x) 560 x = x.to(torch.float16) 561 return x 562 563 test_cases = [ 564 # match_output, match_placeholder, num_matches 565 TestCase(False, False, 1), 566 TestCase(True, False, 1), 567 TestCase(False, True, 0), 568 TestCase(True, True, 0) 569 ] 570 571class MultipleOutputsWithDependency: 572 @staticmethod 573 def forward(x): 574 y = x.relu() 575 z = y.sigmoid() 576 return z, y 577 578 @staticmethod 579 def pattern(a): 580 b = a.relu() 581 c = b.sigmoid() 582 return b, c # outputs have data dependency 583 584 test_cases = [ 585 # match_output, match_placeholder, num_matches 586 TestCase(False, False, 1), 587 TestCase(True, False, 0), 588 TestCase(False, True, 1), 589 TestCase(True, True, 0) 590 ] 591 592class MultipleOutputsWithoutDependency: 593 @staticmethod 594 def forward(x): 595 x = x + 1 596 597 # target subgraph to match 598 x = x.relu() 599 z = x.sum() 600 y = x.sigmoid() 601 602 out = y.sigmoid() + z.sum() 603 return out 604 605 @staticmethod 606 def pattern(a): 607 a = a.relu() 608 b = a.sigmoid() 609 c = a.sum() 610 return b, c 611 612 test_cases = [ 613 # match_output, match_placeholder, num_matches 614 TestCase(False, False, 1), 615 TestCase(True, False, 0), 616 TestCase(False, True, 0), 617 TestCase(True, True, 0) 618 ] 619 620class MultipleOutputsMultipleOverlappingMatches: 621 @staticmethod 622 def forward(x): 623 x = x + 1 624 625 # target subgraph to match 626 x = x.relu() 627 z = x.sum() 628 z1 = x.sum() 629 y = x.sigmoid() 630 y1 = x.sigmoid() 631 632 return z + z1 + y + y1 633 634 @staticmethod 635 def pattern(a): 636 a = a.relu() 637 b = a.sigmoid() 638 c = a.sum() 639 return a, b, c 640 641 test_cases = [ 642 # match_output, match_placeholder, num_matches 643 TestCase(False, False, 4, remove_overlapping_matches=False), 644 TestCase(False, False, 1, remove_overlapping_matches=True), 645 ] 646 647class MultipleOutputsMultipleNonOverlappingMatches: 648 @staticmethod 649 def forward(x): 650 x = x + 1 651 652 # target subgraph to match 653 x = x.relu() 654 z = x.sum() 655 y = x.sigmoid() 656 657 x = x.relu() 658 z1 = x.sum() 659 y1 = x.sigmoid() 660 661 return z + z1 + y + y1 662 663 @staticmethod 664 def pattern(a): 665 a = a.relu() 666 b = a.sigmoid() 667 c = a.sum() 668 return b, c 669 670 test_cases = [ 671 # match_output, match_placeholder, num_matches 672 TestCase(False, False, 1), 673 ] 674 675class MultipleOutputsIdenticalAnchor: 676 @staticmethod 677 def forward(x): 678 x = x + 1 679 680 # target subgraph to match 681 x = x.relu() 682 y = x.sigmoid() 683 y1 = x.sigmoid() 684 685 return y, y1 686 687 @staticmethod 688 def pattern(a): 689 a = a.relu() 690 b = a.sigmoid() 691 b1 = a.sigmoid() 692 return b, b1 693 694 test_cases = [ 695 # match_output, match_placeholder, num_matches 696 # (False, False, 2), # FIXME: currently still matches to 2, should fix to 1 697 TestCase(True, False, 1), 698 TestCase(False, True, 0), 699 ] 700 701 702class MultipleOutputsHorizontalPattern: 703 @staticmethod 704 def forward(x): 705 x = x + 1 706 707 # target subgraph to match 708 y1 = x.relu() 709 y2 = x.sigmoid() 710 711 return y1, y2 712 713 @staticmethod 714 def pattern(a): 715 b1 = a.relu() 716 b2 = a.sigmoid() 717 718 return b1, b2 719 720 test_cases = [ 721 # match_output, match_placeholder, num_matches 722 TestCase(False, False, 1), 723 TestCase(True, False, 1), 724 TestCase(False, True, 0), 725 TestCase(True, True, 0) 726 ] 727 728class MultiOutputWithWithInvalidMatches: 729 @staticmethod 730 def forward(x): 731 res0 = torch.nn.functional.linear(x, torch.rand(3, 3)) 732 res1 = torch.sigmoid(res0) 733 res2 = res0 * res1 734 res3 = torch.sum(res2, dim=1) 735 return res3 736 737 @staticmethod 738 def pattern(a, b, c): 739 lin_res = torch.nn.functional.linear(a, b) 740 mul_res = lin_res * c 741 return lin_res, mul_res 742 743 test_cases = [ 744 # match_output, match_placeholder, num_matches 745 TestCase(False, False, 0), 746 TestCase(True, False, 0), 747 TestCase(False, True, 0), 748 ] 749 750class QuantizationFp8Pattern: 751 @classmethod 752 def setup(cls): 753 cls.quantization = torch.library.Library("fp8_quantization", "DEF") # noqa: TOR901 754 cls.quantization.define("quantize_per_tensor_affine_fp8(Tensor self, int dtype, float scale) -> Tensor") 755 cls.quantization.define("dequantize_per_tensor_affine_fp8(Tensor self, int dtype, float scale) -> Tensor") 756 757 @classmethod 758 def tearDown(cls): 759 del cls.quantization 760 761 @staticmethod 762 def forward(self, arg0_1, arg1_1): 763 qt = torch.ops.fp8_quantization 764 _scale_0 = self._scale_0 765 quantize_per_tensor_affine_fp8 = qt.quantize_per_tensor_affine_fp8(arg0_1, 0, _scale_0) 766 dequantize_per_tensor_affine_fp8 = qt.dequantize_per_tensor_affine_fp8(quantize_per_tensor_affine_fp8, 0, _scale_0) 767 _scale_1 = self._scale_0 768 quantize_per_tensor_affine_fp8_1 = qt.quantize_per_tensor_affine_fp8(arg1_1, 0, _scale_1) 769 dequantize_per_tensor_affine_fp8_1 = qt.dequantize_per_tensor_affine_fp8(quantize_per_tensor_affine_fp8_1, 0, _scale_1) 770 add = torch.ops.aten.add.Tensor(dequantize_per_tensor_affine_fp8, dequantize_per_tensor_affine_fp8_1) 771 _scale_2 = self._scale_0 772 quantize_per_tensor_affine_fp8_2 = qt.quantize_per_tensor_affine_fp8(add, 0, _scale_2) 773 dequantize_per_tensor_affine_fp8_2 = qt.dequantize_per_tensor_affine_fp8(quantize_per_tensor_affine_fp8_2, 0, _scale_2) 774 return dequantize_per_tensor_affine_fp8_2 775 776 @staticmethod 777 def pattern(a, a_dtype, a_scale, b, b_dtype, b_scale, out_scale): 778 qt = torch.ops.fp8_quantization 779 a = qt.dequantize_per_tensor_affine_fp8(a, a_dtype, a_scale) 780 b = qt.dequantize_per_tensor_affine_fp8(b, b_dtype, b_scale) 781 output = torch.ops.aten.add.Tensor(a, b) 782 783 qt.dequantize_per_tensor_affine_fp8 784 785 output = qt.quantize_per_tensor_affine_fp8(output, a_dtype, out_scale) 786 return output 787 788 test_cases = [ 789 # match_output, match_placeholder, num_matches 790 TestCase(False, False, 1), 791 ] 792 793class NoAnchorFound: 794 # This test case is for pattern where no matching anchor is found in the target graph 795 # `anchor` is the starting point of the pattern matching, it's usually the boundary returning nodes 796 @staticmethod 797 def forward(x): 798 x = x + 1 799 return x 800 801 @staticmethod 802 def pattern(a): 803 b1 = a.relu() 804 return b1 805 806 test_cases = [ 807 # match_output, match_placeholder, num_matches 808 TestCase(False, False, 0), 809 TestCase(True, False, 0), 810 TestCase(False, True, 0), 811 TestCase(True, True, 0) 812 ] 813 814@instantiate_parametrized_tests 815class TestFXMatcherUtils(JitTestCase): 816 817 @parametrize("test_model", [ 818 SingleNodePattern, 819 SimplePattern, 820 SimpleFullGraphMatching, 821 DiamondShapePatternTestCase, 822 NonFullyContainedMatches, 823 ChainRepeatedPattern, 824 QuantizationModel, 825 MultipleOutputsWithDependency, 826 MultipleOutputsWithoutDependency, 827 MultipleOutputsMultipleOverlappingMatches, 828 MultipleOutputsMultipleNonOverlappingMatches, 829 MultipleOutputsIdenticalAnchor, 830 MultipleOutputsHorizontalPattern, 831 MultiOutputWithWithInvalidMatches, 832 QuantizationFp8Pattern, 833 NoAnchorFound, 834 ]) 835 def test_subgraph_matcher(self, test_model): 836 837 setup = getattr(test_model, "setup", None) 838 if callable(setup): 839 setup() 840 841 traced = symbolic_trace(test_model.forward) 842 pattern_traced = symbolic_trace(test_model.pattern) 843 844 for test_case in test_model.test_cases: 845 846 matcher = SubgraphMatcher(pattern_traced.graph, 847 match_output=test_case.match_output, 848 match_placeholder=test_case.match_placeholder, 849 remove_overlapping_matches=test_case.remove_overlapping_matches) 850 matches = matcher.match(traced.graph) 851 852 assert len(matches) == test_case.num_matches 853 854 for match in matches: 855 for node in pattern_traced.graph.nodes: 856 if not test_case.match_placeholder and node.op == "placeholder": 857 continue 858 if not test_case.match_output and node.op == "output": 859 continue 860 assert node in match.nodes_map 861 862 tearDown = getattr(test_model, "tearDown", None) 863 if callable(setup): 864 tearDown() 865 866 867if __name__ == "__main__": 868 run_tests() 869