1# Owner(s): ["module: fx"] 2 3import os 4import sys 5 6import torch 7from torch.fx import subgraph_rewriter, symbolic_trace 8from torch.fx.annotate import annotate 9 10# Make the helper files in test/ importable 11from torch.fx.experimental.rewriter import RewritingTracer 12 13 14pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 15sys.path.append(pytorch_test_dir) 16from torch.testing._internal.jit_utils import JitTestCase 17 18 19if __name__ == "__main__": 20 raise RuntimeError( 21 "This test file is not meant to be run directly, use:\n\n" 22 "\tpython test/test_fx.py TESTNAME\n\n" 23 "instead." 24 ) 25 26 27@torch.fx.wrap 28def wrapped_gemm_bias_mul(a, b, bias): 29 lin_res = torch.nn.functional.linear(a, b, bias=bias) 30 mul_res = lin_res * a 31 return lin_res, mul_res 32 33 34@torch.fx.wrap 35def wrapped_gemm_bias_mul_with_c(a, b, bias, c): 36 lin_res = torch.nn.functional.linear(a, b, bias=bias) 37 mul_res = lin_res * c 38 return lin_res, mul_res 39 40 41class TestSubgraphRewriter(JitTestCase): 42 def test_subgraph_rewriter_preserves_logic(self): 43 class M(torch.nn.Module): 44 def forward(self, x): 45 val = torch.neg(x) + torch.relu(x) 46 return torch.add(val, val) 47 48 def pattern(x): 49 return torch.neg(x) + torch.relu(x) 50 51 def comparison(x): 52 val = torch.neg(x) + torch.relu(x) 53 return torch.add(val, val) 54 55 traced = symbolic_trace(M()) 56 comparison_fn = symbolic_trace(comparison) 57 58 x = torch.rand(1, 3) 59 60 # Replace `pattern` with the same pattern (shouldn't change 61 # the underlying logic) 62 subgraph_rewriter.replace_pattern(traced, pattern, pattern) 63 64 traced.graph.lint() 65 66 ref_output = comparison_fn(x) 67 test_output = traced.forward(x) 68 self.assertEqual(ref_output, test_output) 69 70 def test_subgraph_rewriter_with_oneliner_pattern(self): 71 class M(torch.nn.Module): 72 def forward(self, x): 73 val = torch.neg(x) 74 return torch.add(val, val) 75 76 def pattern(x): 77 return torch.neg(x) 78 79 def replacement(x): 80 return torch.relu(x) 81 82 def comparison(x): 83 val = torch.relu(x) 84 return torch.add(val, val) 85 86 traced = symbolic_trace(M()) 87 comparison_fn = symbolic_trace(comparison) 88 89 x = torch.rand(1, 3) 90 91 subgraph_rewriter.replace_pattern(traced, pattern, replacement) 92 93 traced.graph.lint() 94 95 ref_output = comparison_fn(x) 96 test_output = traced.forward(x) 97 self.assertEqual(ref_output, test_output) 98 99 def test_subgraph_rewriter_with_trivial_replacement(self): 100 class M(torch.nn.Module): 101 def forward(self, x): 102 val = torch.neg(x) 103 val = torch.add(val, val) 104 return torch.add(val, val) 105 106 def pattern(x): 107 return torch.add(x, x) 108 109 def replacement(x): 110 return x 111 112 def comparison(x): 113 return torch.neg(x) 114 115 traced = symbolic_trace(M()) 116 comparison_fn = symbolic_trace(comparison) 117 118 x = torch.randn(1, 5) 119 120 matches = subgraph_rewriter.replace_pattern_with_filters( 121 traced, pattern, replacement, [] 122 ) 123 124 traced.graph.lint() 125 126 ref_output = comparison_fn(x) 127 test_output = traced.forward(x) 128 no_replacements = len(matches) == 2 and len(matches[1].replacements) == 0 129 self.assertEqual(ref_output, test_output) 130 self.assertTrue(no_replacements) 131 132 def test_subgraph_rewriter_single_pattern_match(self): 133 class M(torch.nn.Module): 134 def forward(self, x): 135 val = torch.neg(x) + torch.relu(x) 136 return torch.add(val, val) 137 138 def pattern(x): 139 return torch.neg(x) + torch.relu(x) 140 141 def replacement(x): 142 return torch.relu(x) 143 144 def comparison(x): 145 val = torch.relu(x) 146 return torch.add(val, val) 147 148 traced = symbolic_trace(M()) 149 comparison_fn = symbolic_trace(comparison) 150 151 x = torch.rand(1, 3) 152 153 subgraph_rewriter.replace_pattern(traced, pattern, replacement) 154 155 traced.graph.lint() 156 157 ref_output = comparison_fn(x) 158 test_output = traced.forward(x) 159 self.assertEqual(ref_output, test_output) 160 161 def test_subgraph_rewriter_multiple_pattern_match(self): 162 class M(torch.nn.Module): 163 def forward(self, x, w1, w2): 164 m1 = torch.cat([w1, w2]).sum() 165 m2 = torch.cat([w1, w2]).sum() 166 return x + torch.max(m1) + torch.max(m2) 167 168 def pattern(w1, w2): 169 return torch.cat([w1, w2]).sum() 170 171 def replacement(w1, w2): 172 return torch.stack([w1, w2]) 173 174 def comparison(x, w1, w2): 175 m1 = torch.stack([w1, w2]) 176 m2 = torch.stack([w1, w2]) 177 return x + torch.max(m1) + torch.max(m2) 178 179 traced = symbolic_trace(M()) 180 comparison_fn = symbolic_trace(comparison) 181 182 x = torch.rand(1, 3) 183 w1 = torch.rand(1, 3) 184 w2 = torch.rand(1, 3) 185 186 subgraph_rewriter.replace_pattern(traced, pattern, replacement) 187 188 traced.graph.lint() 189 190 ref_outs = comparison_fn(x, w1, w2) 191 test_outs = traced.forward(x, w1, w2) 192 self.assertEqual(ref_outs, test_outs) 193 194 def test_subgraph_rewriter_graph_argument_order(self): 195 class M(torch.nn.Module): 196 def forward(self, x, y): 197 return torch.mm(x, y) 198 199 def pattern(x, y): 200 return torch.mm(x, y) 201 202 def comparison(x, y): 203 return torch.mm(x, y) 204 205 traced = symbolic_trace(M()) 206 comparison_fn = symbolic_trace(comparison) 207 208 x = torch.randn(3, 4) 209 y = torch.randn(4, 5) 210 211 subgraph_rewriter.replace_pattern(traced, pattern, pattern) 212 213 traced.graph.lint() 214 215 ref_outs = comparison_fn(x, y) 216 test_outs = traced.forward(x, y) 217 self.assertEqual(ref_outs, test_outs) 218 219 def test_subgraph_rewriter_correct_output_replacement(self): 220 class M(torch.nn.Module): 221 def forward(self, x, y): 222 val = torch.neg(y) + torch.relu(x) 223 return torch.add(val, val) 224 225 def pattern(x): 226 return torch.relu(x) 227 228 def replacement(x): 229 return torch.neg(x) 230 231 def comparison(x, y): 232 val = torch.neg(y) + torch.neg(x) 233 return torch.add(val, val) 234 235 traced = symbolic_trace(M()) 236 comparison_fn = symbolic_trace(comparison) 237 238 x = torch.randn(4, 4) 239 y = torch.randn(4, 4) 240 241 subgraph_rewriter.replace_pattern(traced, pattern, replacement) 242 243 traced.graph.lint() 244 245 ref_outs = comparison_fn(x, y) 246 test_outs = traced.forward(x, y) 247 self.assertEqual(ref_outs, test_outs) 248 249 def test_subgraph_rewriter_traced_as_callable(self): 250 class M(torch.nn.Module): 251 def forward(self, x): 252 val = torch.neg(x) + torch.relu(x) 253 return torch.add(val, val) 254 255 class Pattern(torch.nn.Module): 256 def forward(self, x): 257 return torch.neg(x) + torch.relu(x) 258 259 class Replacement(torch.nn.Module): 260 def forward(self, x): 261 return torch.sigmoid(x) 262 263 def comparison(x): 264 val = torch.sigmoid(x) 265 return torch.add(val, val) 266 267 traced = symbolic_trace(M()) 268 traced_pattern = symbolic_trace(Pattern()) 269 traced_replacement = symbolic_trace(Replacement()) 270 comparison_fn = symbolic_trace(comparison) 271 272 x = torch.randn(3, 4) 273 274 subgraph_rewriter.replace_pattern(traced, traced_pattern, traced_replacement) 275 276 traced.graph.lint() 277 278 ref_outs = comparison_fn(x) 279 test_outs = traced.forward(x) 280 self.assertEqual(ref_outs, test_outs) 281 282 def test_subgraph_rewriter_pattern_is_entire_graph(self): 283 class M(torch.nn.Module): 284 def forward(self, x): 285 a = torch.neg(x) 286 return torch.add(a, a) 287 288 def pattern(x): 289 a = torch.neg(x) 290 return torch.add(a, a) 291 292 def replacement(x): 293 a = torch.sigmoid(x) 294 return torch.cat([a, a]) 295 296 traced = symbolic_trace(M()) 297 comparison_fn = symbolic_trace(replacement) 298 299 x = torch.randn(3, 4) 300 301 subgraph_rewriter.replace_pattern(traced, pattern, replacement) 302 303 traced.graph.lint() 304 305 ref_outs = comparison_fn(x) 306 test_outs = traced.forward(x) 307 self.assertEqual(ref_outs, test_outs) 308 309 def test_subgraph_rewriter_pattern_output_pattern_node_can_have_users_that_are_not_matched( 310 self, 311 ): 312 class M(torch.nn.Module): 313 def forward(self, x): 314 y = torch.relu(x) 315 return torch.neg(y) - y 316 317 def pattern(x): 318 return torch.relu(x) 319 320 def replacement(x): 321 return torch.sigmoid(x) 322 323 def comparison(x): 324 y = torch.sigmoid(x) 325 return torch.neg(y) - y 326 327 traced = symbolic_trace(M()) 328 comparison_fn = symbolic_trace(comparison) 329 330 x = torch.randn(3, 4) 331 332 subgraph_rewriter.replace_pattern(traced, pattern, replacement) 333 334 traced.graph.lint() 335 336 ref_outs = comparison_fn(x) 337 test_outs = traced.forward(x) 338 self.assertEqual(ref_outs, test_outs) 339 340 def test_subgraph_rewriter_internal_pattern_nodes_cannot_have_users_that_are_not_matched( 341 self, 342 ): 343 class M(torch.nn.Module): 344 def forward(self, x, w1, w2, b1, b2): 345 m0 = torch.cat([w1, w2]) 346 m1 = torch.cat([w1, w2]) 347 m2 = torch.cat([x, b2]) 348 t0 = torch.addmm(b1, m1, m2.t()) 349 t1 = torch.sum(w1, 1) 350 t2 = torch.addmm(b1, m1, m2.t()) 351 return torch.sum(t1), torch.sum(t2) 352 353 def pattern(x, w1, w2, b1, b2): 354 m1 = torch.cat([w1, w2]) 355 m2 = torch.cat([x, b2]) 356 return torch.addmm(b1, m1, m2.t()) 357 358 def replacement(x, w1, w2, b1, b2): 359 return torch.cat([x, w1, w2]) 360 361 traced = symbolic_trace(M()) 362 363 # Result should be [] since no matches can be found 364 res = subgraph_rewriter.replace_pattern(traced, pattern, replacement) 365 366 traced.graph.lint() 367 368 self.assertEqual(res, []) 369 370 def test_subgraph_rewriter_placeholder_matching(self): 371 """ 372 This tests that a placeholder Node can be matched to a Node with 373 a different number of input Nodes. In the example below, the 374 original traced Module looks like this: 375 376 opcode target args kwargs 377 ------------- ---------------------------------------------------------- ------------------------ -------- 378 placeholder x () {} 379 call_function <built-in function add> (x, 3) {} 380 call_method dequantize (add,) {} 381 call_function <built-in method sigmoid of type object at 0x7f7c1f440fe0> (dequantize,) {} 382 call_method to (sigmoid, torch.float16) {} 383 output output (to,) {} 384 385 while the pattern we want to match looks like this: 386 387 opcode target args kwargs 388 ------------- ---------------------------------------------------------- ------------------------ -------- 389 placeholder x () {} 390 call_method dequantize (x,) {} 391 call_function <built-in method sigmoid of type object at 0x7f7c1f440fe0> (dequantize,) {} 392 call_method to (sigmoid, torch.float16) {} 393 output output (to,) {} 394 395 Here, we want to be able to match the original graph's 396 `call_function.add` Node with the pattern graph's 397 `placeholder.x` Node. 398 399 Credit to Jerry Zhang (GitHub: jerryzh168) for this test case 400 """ 401 402 class M(torch.nn.Module): 403 def __init__(self) -> None: 404 super().__init__() 405 self.dtype = torch.float16 406 407 def forward(self, x): 408 x += 3 409 x = x.dequantize() 410 x = torch.sigmoid(x) 411 dtype = self.dtype 412 x = x.to(dtype) 413 return x 414 415 def pattern(x): 416 x = x.dequantize() 417 x = torch.sigmoid(x) 418 x = x.to(torch.float16) 419 return x 420 421 def replacement(x): 422 return x 423 424 def comparison(x): 425 return x + 3 426 427 traced = symbolic_trace(M()) 428 comparison_fn = symbolic_trace(comparison) 429 430 x = torch.randn(3, 4) 431 432 subgraph_rewriter.replace_pattern(traced, pattern, replacement) 433 434 traced.graph.lint() 435 436 ref_outs = comparison_fn(x) 437 test_outs = traced.forward(x) 438 self.assertEqual(ref_outs, test_outs) 439 440 def test_subgraph_rewriter_replaces_referenced_submodules(self): 441 class M(torch.nn.Module): 442 def __init__(self) -> None: 443 super().__init__() 444 self.sigmoid = torch.nn.Sigmoid() 445 self.submod = torch.nn.ReLU() 446 447 def forward(self, x): 448 x = x + 1 449 return self.submod(self.sigmoid(x)) 450 451 class Pattern(torch.nn.Module): 452 def __init__(self) -> None: 453 super().__init__() 454 self.sigmoid = torch.nn.Sigmoid() 455 self.submod = torch.nn.ReLU() 456 457 def forward(self, x): 458 return self.submod(self.sigmoid(x)) 459 460 class Replacement(torch.nn.Module): 461 def __init__(self) -> None: 462 super().__init__() 463 self.tanh = torch.nn.Tanh() 464 self.submod = torch.nn.ReLU() 465 466 def forward(self, x): 467 return self.submod(self.tanh(x)) 468 469 class Comparison(torch.nn.Module): 470 def __init__(self) -> None: 471 super().__init__() 472 self.tanh = torch.nn.Tanh() 473 self.submod = torch.nn.ReLU() 474 475 def forward(self, x): 476 x = x + 1 477 return self.submod(self.tanh(x)) 478 479 traced = symbolic_trace(M()) 480 comparison = Comparison() 481 482 x = torch.randn(3, 4) 483 484 subgraph_rewriter.replace_pattern(traced, Pattern(), Replacement()) 485 486 traced.graph.lint() 487 488 ref_outs = comparison(x) 489 test_outs = traced.forward(x) 490 self.assertEqual(ref_outs, test_outs) 491 492 traced.get_submodule("tanh") 493 with self.assertRaisesRegex(AttributeError, "has no attribute"): 494 traced.get_submodule("sigmoid") 495 496 submod = traced.get_submodule("submod") 497 self.assertEqual(type(submod), torch.nn.ReLU) 498 499 def test_subgraph_rewriter_annotations_int(self): 500 class M1(torch.nn.Module): 501 def forward(self, x): 502 y: int = x 503 return torch.add(x, y) 504 505 class M2(torch.nn.Module): 506 def forward(self, x): 507 y = annotate(x, int) 508 return torch.add(x, y) 509 510 ast_rewriter = RewritingTracer() 511 graph = ast_rewriter.trace(M1()) 512 513 module = M2() 514 symbolic_traced: torch.fx.GraphModule = symbolic_trace(module) 515 for n, m in zip(symbolic_traced.graph.nodes, graph.nodes): 516 if n.op == "placeholder": 517 assert n.type == int 518 assert m.type == int 519 520 def test_subgraph_rewriter_replace_consecutive_submodules(self): 521 def f(x): 522 x = torch.sigmoid(x) 523 x = torch.sigmoid(x) 524 return torch.sigmoid(x) 525 526 def pattern(x): 527 return torch.sigmoid(x) 528 529 def replacement(x): 530 return torch.exp(x) 531 532 def comparison(x): 533 x = torch.exp(x) 534 x = torch.exp(x) 535 return torch.exp(x) 536 537 traced = symbolic_trace(f) 538 comparison_fn = symbolic_trace(comparison) 539 540 x = torch.randn(3, 4) 541 542 subgraph_rewriter.replace_pattern(traced, pattern, replacement) 543 544 traced.graph.lint() 545 546 ref_outs = comparison_fn(x) 547 test_outs = traced.forward(x) 548 self.assertEqual(ref_outs, test_outs) 549 550 def test_subgraph_rewriter_with_overlapping_matches(self): 551 def f(x): 552 x = torch.sigmoid(x) 553 x = torch.sigmoid(x) 554 x = torch.sigmoid(x) 555 return torch.sigmoid(x) 556 557 def pattern(x): 558 x = torch.sigmoid(x) 559 x = torch.sigmoid(x) 560 return x 561 562 def replacement(x): 563 return torch.neg(x) 564 565 def comparison(x): 566 x = torch.neg(x) 567 return torch.neg(x) 568 569 traced = symbolic_trace(f) 570 comparison_fn = symbolic_trace(comparison) 571 572 x = torch.randn(3, 4) 573 574 subgraph_rewriter.replace_pattern(traced, pattern, replacement) 575 576 traced.graph.lint() 577 578 ref_outs = comparison_fn(x) 579 test_outs = traced.forward(x) 580 self.assertEqual(ref_outs, test_outs) 581 582 def test_subgraph_rewriter_replace_with_multiple_outputs(self): 583 def f(x): 584 y = torch.sigmoid(x) 585 z = torch.relu(x) 586 return y + z 587 588 def pattern(a): 589 b = torch.sigmoid(a) 590 c = torch.relu(a) 591 return b, c 592 593 def replacement(x): 594 return torch.exp(x), torch.abs(x) 595 596 def comparison(x): 597 y = torch.exp(x) 598 z = torch.abs(x) 599 return y + z 600 601 traced = symbolic_trace(f) 602 comparison_fn = symbolic_trace(comparison) 603 604 x = torch.randn(3, 4) 605 606 subgraph_rewriter.replace_pattern(traced, pattern, replacement) 607 608 traced.graph.lint() 609 610 ref_outs = comparison_fn(x) 611 test_outs = traced.forward(x) 612 self.assertEqual(ref_outs, test_outs) 613 614 def test_subgraph_rewriter_replace_with_duplicated_outputs(self): 615 def f(x1, x2): 616 x = x1 - x2 617 y = torch.sigmoid(x) 618 z = torch.relu(x) 619 return y + z 620 621 def pattern(a1, a2): 622 a = a1 - a2 623 b = torch.sigmoid(a) 624 c = torch.relu(a) 625 return b, c, a 626 627 def replacement(x1, x2): 628 y1 = torch.exp(x1) 629 y2 = torch.abs(x2) 630 return y2, y2, y1 631 632 def comparison(x1, x2): 633 y2 = torch.abs(x2) 634 return y2 + y2 635 636 traced = symbolic_trace(f) 637 comparison_fn = symbolic_trace(comparison) 638 639 x1 = torch.randn(3, 4) 640 x2 = torch.randn(3, 4) 641 642 subgraph_rewriter.replace_pattern(traced, pattern, replacement) 643 644 traced.graph.lint() 645 646 ref_outs = comparison_fn(x1, x2) 647 test_outs = traced.forward(x1, x2) 648 self.assertEqual(ref_outs, test_outs) 649 650 def test_subgraph_rewriter_with_unused_args(self): 651 class M(torch.nn.Module): 652 def forward(self, x, y, z): 653 return x + y 654 655 def pattern(x, y): 656 return x + y 657 658 def replacement(x, y): 659 return x - y 660 661 def comparison(x1, x2, x3): 662 return x1 - x2 663 664 traced = symbolic_trace(M()) 665 comparison_fn = symbolic_trace(comparison) 666 667 x1 = torch.randn(3, 4) 668 x2 = torch.randn(3, 4) 669 x3 = torch.randn(3, 4) 670 671 subgraph_rewriter.replace_pattern(traced, pattern, replacement) 672 673 traced.graph.lint() 674 placeholder_nodes = [n for n in traced.graph.nodes if n.op == "placeholder"] 675 assert len(placeholder_nodes) == 3 676 677 ref_outs = comparison_fn(x1, x2, x3) 678 test_outs = traced.forward(x1, x2, x3) 679 self.assertEqual(ref_outs, test_outs) 680 681 def test_subgraph_rewriter_call_method(self): 682 class M(torch.nn.Module): 683 def forward(self, x): 684 x = x.dequantize() 685 x = x.sigmoid() 686 x = x.to(torch.float16) 687 return x 688 689 def pattern(x): 690 x = x.dequantize() 691 x = x.sigmoid() 692 x = x.to(torch.float16) 693 return x 694 695 def replacement(x): 696 return x 697 698 traced = symbolic_trace(M()) 699 comparison_fn = symbolic_trace(replacement) 700 701 x1 = torch.randn(3, 4) 702 703 subgraph_rewriter.replace_pattern(traced, pattern, replacement) 704 705 traced.graph.lint() 706 707 ref_outs = comparison_fn(x1) 708 test_outs = traced.forward(x1) 709 self.assertEqual(ref_outs, test_outs) 710 711 def test_subgraph_rewriter_nodes_with_kwargs(self): 712 class M(torch.nn.Module): 713 def __init__(self) -> None: 714 super().__init__() 715 self.w0 = torch.nn.Parameter(torch.empty([128, 128])) 716 self.b0 = torch.nn.Parameter(torch.empty([128])) 717 718 def forward(self, in0): 719 lin_res = torch.nn.functional.linear(in0, self.w0, bias=self.b0) 720 mul_res = in0 * lin_res 721 sum_res = mul_res + in0 722 return sum_res 723 724 def pattern(a, b, bias): 725 lin_res = torch.nn.functional.linear(a, b, bias=bias) 726 mul_res = a * lin_res 727 return lin_res, mul_res 728 729 def replacement(a, b, bias): 730 lin_res, mul_res = wrapped_gemm_bias_mul(a, b, bias) 731 return lin_res, mul_res 732 733 traced = symbolic_trace(M()) 734 matches = subgraph_rewriter.replace_pattern(traced, pattern, replacement) 735 736 self.assertEqual(len(matches), 1) 737 738 found_repalcement_node = False 739 for node in traced.graph.nodes: 740 if node.target == wrapped_gemm_bias_mul: 741 found_repalcement_node = True 742 break 743 744 self.assertTrue(found_repalcement_node) 745 746 def test_subgraph_rewriter_local_revert(self): 747 # Following model will have 3 anchors as the matching candidate with the given pattern 748 # Anchor 1 and 3 is a real match, but anchor 2 is not. 749 # The subgraph rewriter should be able to revert the changes made while matching anchor 2. 750 # Final match with anchor 3 should be successful. 751 752 class M(torch.nn.Module): 753 def __init__(self) -> None: 754 super().__init__() 755 self.w0 = torch.nn.Parameter(torch.empty([128, 128])) 756 self.b0 = torch.nn.Parameter(torch.empty([128])) 757 self.w1 = torch.nn.Parameter(torch.empty([128, 128])) 758 self.b1 = torch.nn.Parameter(torch.empty([128])) 759 self.w2 = torch.nn.Parameter(torch.empty([128, 128])) 760 self.b2 = torch.nn.Parameter(torch.empty([128])) 761 self.w3 = torch.nn.Parameter(torch.empty([128, 128])) 762 self.b3 = torch.nn.Parameter(torch.empty([128])) 763 self.w4 = torch.nn.Parameter(torch.empty([128, 128])) 764 self.b4 = torch.nn.Parameter(torch.empty([128])) 765 766 def forward(self, in0, in1): 767 lin_res_1 = torch.nn.functional.linear(in1, self.w0, bias=self.b0) 768 lin_res_2 = torch.nn.functional.linear(lin_res_1, self.w1, bias=self.b1) 769 # potential match at anchor 1 770 mul_res_1 = in1 * lin_res_2 771 sum_res_1 = mul_res_1 + in1 772 lin_res_3 = torch.nn.functional.linear(sum_res_1, self.w2, bias=self.b2) 773 sigmoid_res_1 = torch.sigmoid(lin_res_3) 774 # potential match at anchor 2 775 mul_res_2 = lin_res_3 * sigmoid_res_1 776 lin_res_4 = torch.nn.functional.linear(in0, self.w3, bias=self.b3) 777 lin_res_5 = torch.nn.functional.linear(lin_res_4, self.w4, bias=self.b4) 778 # potential match at anchor 3 779 mul_res_3 = in0 * lin_res_5 780 sum_res_2 = mul_res_3 + in0 781 cat_res = torch.cat( 782 [mul_res_2, sum_res_2], 783 dim=1, 784 ) 785 return cat_res 786 787 def gemm_bias_mul_pattern_with_c(a, b, bias, c): 788 lin_res = torch.nn.functional.linear(a, b, bias=bias) 789 mul_res = c * lin_res 790 return lin_res, mul_res 791 792 def gemm_bias_mul_replacement_with_c(a, b, bias, c): 793 lin_res, mul_res = wrapped_gemm_bias_mul_with_c(a, b, bias, c) 794 return lin_res, mul_res 795 796 traced = symbolic_trace(M()) 797 matches = subgraph_rewriter.replace_pattern( 798 traced, gemm_bias_mul_pattern_with_c, gemm_bias_mul_replacement_with_c 799 ) 800 801 self.assertEqual(len(matches), 2) 802 803 repalcement_node_found = 0 804 for node in traced.graph.nodes: 805 if node.target == wrapped_gemm_bias_mul_with_c: 806 repalcement_node_found += 1 807 808 self.assertEqual(repalcement_node_found, 2) 809 810 def test_replace_pattern_with_filters(self): 811 class M(torch.nn.Module): 812 def forward(self, x, scale, zero_point): 813 # Match, second input to add is a scalar 814 x = x.dequantize() 815 x = torch.add(x, 2) 816 x = x.relu() 817 x = torch.quantize_per_tensor(x, scale, zero_point, torch.quint8) 818 819 y = x + 1 820 # NOT a match, second input to add is NOT a scalar 821 x = x.dequantize() 822 x = torch.add(x, y) 823 x = x.relu() 824 x = torch.quantize_per_tensor(x, scale, zero_point, torch.quint8) 825 826 return x 827 828 def BinaryOpScalarReLUPattern(x, num, scale, zero_point): 829 x = x.dequantize() 830 x = torch.add(x, num) 831 x = x.relu() 832 x = torch.quantize_per_tensor(x, scale, zero_point, torch.quint8) 833 return x 834 835 def BinaryOpScalarReLUReplacement(x, num, scale, zero_point): 836 x = torch.mul(x, num) 837 return x 838 839 def second_input_is_scalar(match, original_graph, pattern_graph): 840 """check the node that's matched to the second input of the pattern graph 841 is a scalar number 842 """ 843 input_idx = 0 844 for node in pattern_graph.nodes: 845 if node.op == "placeholder": 846 if input_idx == 1: 847 num_node = node 848 input_idx += 1 849 return isinstance(match.nodes_map[num_node], (int, float)) 850 851 def check_replacement_nodes(self, traced, matches): 852 replacement_nodes_in_graph = [ 853 node for node in traced.graph.nodes if node.target == torch.mul 854 ] 855 replacement_nodes_in_res = [r for m in matches for r in m.replacements] 856 self.assertEqual( 857 len(replacement_nodes_in_graph), len(replacement_nodes_in_res) 858 ) 859 self.assertEqual(replacement_nodes_in_graph, replacement_nodes_in_res) 860 return len(replacement_nodes_in_graph) 861 862 # match without filter, should find 2 match 863 traced = symbolic_trace(M()) 864 matches = subgraph_rewriter.replace_pattern_with_filters( 865 traced, BinaryOpScalarReLUPattern, BinaryOpScalarReLUReplacement, None 866 ) 867 self.assertEqual(len(matches), 2) 868 self.assertEqual(check_replacement_nodes(self, traced, matches), 2) 869 870 # match with filter, should find 1 match 871 traced = symbolic_trace(M()) 872 matches = subgraph_rewriter.replace_pattern_with_filters( 873 traced, 874 BinaryOpScalarReLUPattern, 875 BinaryOpScalarReLUReplacement, 876 [second_input_is_scalar], 877 ) 878 self.assertEqual(len(matches), 1) 879 self.assertEqual(check_replacement_nodes(self, traced, matches), 1) 880 881 def test_matching_pattern_with_list_type_arg(self): 882 class M(torch.nn.Module): 883 def forward(self, x): 884 return torch.ops.aten._reshape_alias_copy.default(x, [1, 2], [3, 4]) 885 886 def pattern(x, arg0, arg1): 887 return torch.ops.aten._reshape_alias_copy.default(x, arg0, arg1) 888 889 def replacement(x, arg0, arg1): 890 return torch.ops.aten._reshape_alias_copy.default(x, arg1, arg0) 891 892 traced = symbolic_trace(M()) 893 matches = subgraph_rewriter.replace_pattern(traced, pattern, replacement) 894 895 self.assertEqual(len(matches), 1) 896 897 self.assertExpectedInline( 898 traced.code.strip(), 899 """\ 900def forward(self, x): 901 _reshape_alias_copy_default_1 = torch.ops.aten._reshape_alias_copy.default(x, [3, 4], [1, 2]); x = None 902 return _reshape_alias_copy_default_1""", 903 ) # noqa: B950 904 905 def test_replacement_with_attrs(self): 906 class M(torch.nn.Module): 907 def __init__(self) -> None: 908 super().__init__() 909 self.a = torch.tensor([1]) 910 self.b = torch.tensor([2]) 911 912 def forward(self, x): 913 return x + self.a - self.b 914 915 class Pattern(torch.nn.Module): 916 def __init__(self) -> None: 917 super().__init__() 918 self.a = torch.tensor([1]) 919 920 def forward(self, x): 921 return x + self.a 922 923 class Replacement(torch.nn.Module): 924 def __init__(self) -> None: 925 super().__init__() 926 self.c = torch.tensor([3]) 927 928 def forward(self, x): 929 return x - self.c 930 931 traced = symbolic_trace(M()) 932 matches = subgraph_rewriter.replace_pattern(traced, Pattern(), Replacement()) 933 self.assertEqual(len(matches), 1) 934 935 def test_matching_variable_arguments(self): 936 class M(torch.nn.Module): 937 def forward(self, x): 938 return torch.ops.aten.max_pool2d_with_indices.default( 939 x, [2, 2], stride=[2, 2] 940 ) 941 942 def pattern(x, kernel_size, stride): 943 # default padding is [0, 0] 944 return torch.ops.aten.max_pool2d_with_indices.default( 945 x, kernel_size, stride, padding=[0, 0] 946 ) 947 948 traced = symbolic_trace(M()) 949 matches = subgraph_rewriter.replace_pattern(traced, pattern, pattern) 950 951 self.assertEqual(len(matches), 1) 952 953 def test_replaced_nodes(self): 954 class M(torch.nn.Module): 955 def forward(self, x, y): 956 return torch.add(x, y) 957 958 def pattern(x, y): 959 return torch.add(x, y) 960 961 def replacement(x, y): 962 return torch.sub(torch.mul(x, y), y) 963 964 traced = symbolic_trace(M()) 965 matches = subgraph_rewriter.replace_pattern_with_filters( 966 traced, pattern, replacement 967 ) 968 969 def check_replacement_nodes(self, traced, matches): 970 replacement_nodes_in_graph = [ 971 node 972 for node in traced.graph.nodes 973 if node.target in {torch.sub, torch.mul} 974 ] 975 replacement_nodes_in_res = [r for m in matches for r in m.replacements] 976 self.assertEqual( 977 len(replacement_nodes_in_graph), len(replacement_nodes_in_res) 978 ) 979 self.assertEqual(replacement_nodes_in_graph, replacement_nodes_in_res) 980 return len(replacement_nodes_in_graph) 981 982 self.assertEqual(check_replacement_nodes(self, traced, matches), 2) 983