1# Owner(s): ["module: fx"] 2 3import operator 4 5import torch 6import torch.fx 7from torch.fx.experimental import const_fold 8from torch.fx.passes.shape_prop import _extract_tensor_metadata, ShapeProp 9from torch.testing._internal.common_utils import TestCase 10 11 12class TestConstFold(TestCase): 13 def _get_attr(self, node): 14 mod = node.graph.owning_module 15 target = str(node.target) 16 target_atoms = target.split(".") 17 curr_obj = mod 18 for i, atom in enumerate(target_atoms): 19 if not hasattr(curr_obj, atom): 20 raise RuntimeError( 21 f"Node referenced nonexistent target '{'.'.join(target_atoms[:i])}'; " 22 f" original whole target: '{target}'" 23 ) 24 curr_obj = getattr(curr_obj, atom) 25 return curr_obj 26 27 def _verify_const_fold_mod(self, mod_folded: const_fold.FoldedGraphModule): 28 self.assertTrue(mod_folded.const_subgraph_module is not None) 29 30 # Check that we don't have the const or non-const fold graphs in the gm, and 31 # that we do have the const folded get_attr. 32 found_folded_attrs = False 33 for n in mod_folded.graph.nodes: 34 if n.op == "get_attr" and n.target.startswith("_FX_CONST_FOLDED_ATTRS"): 35 found_folded_attrs = True 36 elif n.op == "call_module": 37 self.assertTrue(n.target not in {"submod_0", "submod_1"}) 38 self.assertTrue(found_folded_attrs) 39 40 def test_const_fold_basic_one_attr_no_name_collision(self): 41 r""" 42 Perform constant folding conversion, from original mod to split constant folding 43 module with two split subgraphs, where there's a single attr to fold and 44 a single output attr result to replace. 45 46 attr1 attr1 47 | | | | 48 x add add 49 \ / | 50 sub y output (becomes attr add_1) 51 \ / ==> -------+------- (const/base subgraph split) 52 mul attr2 x / (input from previous subgraph 53 \ / \ / is attr) 54 add sub y 55 | \ / 56 output mul attr2 57 \ / 58 add 59 | 60 output 61 """ 62 63 class ConstFoldTestModule(torch.nn.Module): 64 def __init__(self) -> None: 65 super().__init__() 66 self.attr_1 = torch.nn.Parameter(torch.tensor([[-0.9]])) 67 self.attr_2 = torch.nn.Parameter(torch.tensor([[17.1]])) 68 69 def forward(self, x, y): 70 a = self.attr_1 + self.attr_1 71 x = x - a 72 return x * y + self.attr_2 73 74 mod = ConstFoldTestModule() 75 mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) 76 self._verify_const_fold_mod(mod_folded) 77 78 # Now run both folded and non-folded to check results equal. 79 in_x, in_y = torch.tensor([[-0.45]]), torch.tensor([0.9]) 80 base_result = mod(in_x, in_y) 81 fold_result = mod_folded(in_x, in_y) 82 self.assertTrue(torch.equal(fold_result, base_result)) 83 84 def test_const_fold_basic_one_attr_name_collision(self): 85 r""" 86 Perform constant folding conversion, from original mod to split constant folding 87 module with two split subgraphs, where there's a single attr to fold and 88 a single output attr result to replace. Name the attrs such that they will 89 collide by name with folded attrs. 90 91 add_1 add_1 92 | | | | 93 x add add 94 \ / | 95 sub y output (becomes attr add_1) 96 \ / ==> -------+------- (const/base subgraph split) 97 mul add_2 x / (input from previous subgraph 98 \ / \ / is attr) 99 add sub y 100 | \ / 101 output mul add_2 102 \ / 103 add 104 | 105 output 106 """ 107 108 class ConstFoldTestModule(torch.nn.Module): 109 def __init__(self) -> None: 110 super().__init__() 111 # Note: Named as such to result in name collision. 112 self.add_1__CF = torch.nn.Parameter(torch.tensor([[1.0]])) 113 self.add_2__CF = torch.nn.Parameter(torch.tensor([[17.1]])) 114 115 def forward(self, x, y): 116 a = self.add_1__CF + self.add_1__CF 117 x = x - a 118 return x * y + self.add_2__CF 119 120 mod = ConstFoldTestModule() 121 mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) 122 self._verify_const_fold_mod(mod_folded) 123 124 # Now run both folded and non-folded to check results equal. 125 in_x, in_y = torch.tensor([[5.0]]), torch.tensor([4.0]) 126 base_result = mod(in_x, in_y) 127 fold_result = mod_folded(in_x, in_y) 128 self.assertTrue(torch.equal(fold_result, base_result)) 129 130 def test_const_fold_basic_placeholder_reordered(self): 131 """ 132 Test code path where placeholder comes after normal op node in FX 133 """ 134 135 class ConstFoldTestModule(torch.nn.Module): 136 def forward(self, x, y): 137 return x * 2 + y 138 139 mod = ConstFoldTestModule() 140 mod = torch.fx.symbolic_trace(mod) 141 yy = None 142 for n in mod.graph.nodes: 143 if n.op == "placeholder" and n.target == "y": 144 yy = n 145 elif yy is not None and n.op == "call_function": 146 yy.prepend(n) 147 break 148 149 mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) 150 151 self.assertTrue(mod_folded.const_subgraph_module is None) 152 # Now run both folded and non-folded to check results equal. 153 in_x = torch.tensor([[-0.45]]) 154 in_y = torch.tensor([[0.45]]) 155 base_result = mod(in_x, in_y) 156 fold_result = mod_folded(in_x, in_y) 157 self.assertTrue(torch.equal(fold_result, base_result)) 158 159 def test_const_fold_noop(self): 160 r""" 161 Check that a graph with no constant folding is handled correctly. 162 163 x attr1 164 \ / 165 sub 166 | 167 output 168 """ 169 170 class ConstFoldTestModule(torch.nn.Module): 171 def __init__(self) -> None: 172 super().__init__() 173 self.attr1 = torch.nn.Parameter(torch.tensor([[-0.9]])) 174 175 def forward(self, x): 176 return x - self.attr1 177 178 mod = ConstFoldTestModule() 179 mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) 180 181 # Check that the folded graph module is None, since there was no folding to do. 182 self.assertTrue(mod_folded.const_subgraph_module is None) 183 184 # Now run both folded and non-folded to check results equal. 185 in_x = torch.tensor([[-0.45]]) 186 base_result = mod(in_x) 187 fold_result = mod_folded(in_x) 188 self.assertTrue(torch.equal(fold_result, base_result)) 189 190 def test_const_fold_basic_two_attr_three_input(self): 191 r""" 192 Perform constant folding conversion, from original mod to split constant 193 folding module with two split subgraphs, where there are two attrs to 194 fold into a single output, and there are three placeholder inputs. 195 196 attr1 attr2 attr1 attr2 197 \ / \ / 198 x add add 199 \ / | 200 sub y output (becomes attr add_1) 201 \ / ==> -------+------- (const/base subgraph split) 202 mul z x / (input from previous subgraph 203 \ / \ / is attr) 204 div sub y 205 | \ / 206 output mul z 207 \ / 208 div 209 | 210 output 211 """ 212 213 class ConstFoldTestModule(torch.nn.Module): 214 def __init__(self) -> None: 215 super().__init__() 216 self.attr1 = torch.nn.Parameter(torch.tensor([[-0.9]])) 217 self.attr1 = torch.nn.Parameter(torch.tensor([[1.32]])) 218 219 def forward(self, x, y, z): 220 a = self.attr1 + self.attr1 221 sub = x - a 222 mul = sub * y 223 return mul / z 224 225 mod = ConstFoldTestModule() 226 mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) 227 self._verify_const_fold_mod(mod_folded) 228 229 # Now run both folded and non-folded to check results equal. 230 in_x, in_y, in_z = ( 231 torch.tensor([[-0.45]]), 232 torch.tensor([0.9]), 233 torch.tensor([1.1]), 234 ) 235 base_result = mod(in_x, in_y, in_z) 236 fold_result = mod_folded(in_x, in_y, in_z) 237 self.assertTrue(torch.equal(fold_result, base_result)) 238 239 def test_const_fold_basic_two_attr(self): 240 r""" 241 Perform constant folding conversion, from original mod to split constant 242 folding module with two split subgraphs, where there are two attrs to 243 fold into a single output. 244 245 attr1 attr2 attr1 attr2 246 \ / \ / 247 x add add (becomes attr add_1) 248 \ / ==> -------+------- (const/base subgraph split) 249 sub x | (input from previous subgraph is attr) 250 | \ / 251 output sub 252 | 253 output 254 """ 255 256 class ConstFoldTestModule(torch.nn.Module): 257 def __init__(self) -> None: 258 super().__init__() 259 self.attr1 = torch.nn.Parameter(torch.randn(2, 3)) 260 self.attr2 = torch.nn.Parameter(torch.randn(2, 3)) 261 262 def forward(self, x): 263 y = self.attr1 + self.attr2 264 return x + y 265 266 mod = ConstFoldTestModule() 267 mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) 268 self._verify_const_fold_mod(mod_folded) 269 270 # Now run both folded and non-folded to check results equal. 271 in_x = torch.randn(2, 3) 272 fold_result = mod_folded(in_x) 273 base_result = mod(in_x) 274 self.assertTrue(torch.equal(fold_result, base_result)) 275 276 def test_const_fold_multi_const_folded_attrs(self): 277 r""" 278 Perform constant folding conversion, from original mod to split constant 279 folding module with two split subgraphs, where there are two attrs to 280 fold into two new attrs. 281 282 attr1 attr2 attr1 attr2 283 / \ | / \ | 284 permute | sum permute | sum 285 \ / / \ / | 286 x add y / add | 287 \ / \ / | | 288 sub add output output (become attrs add_1 and mul_1) 289 \ / ==> --------+-------+------ (const/base subgraph split) 290 \ / x | y | (inputs from previous subgraph 291 add \ / \ / are attrs) 292 | sub add 293 linear \ / 294 | add 295 sigmoid | 296 | linear 297 output | 298 sigmoid 299 | 300 output 301 """ 302 303 class ConstFoldTestModule(torch.nn.Module): 304 def __init__(self) -> None: 305 super().__init__() 306 self.attr1 = torch.nn.Parameter(torch.randn(4, 4)) 307 self.attr2 = torch.nn.Parameter(torch.randn(4, 4)) 308 self.lin = torch.nn.Linear(4, 4) 309 310 def forward(self, x, y): 311 a = self.attr1 + self.attr1.permute(1, 0) 312 x = x - a 313 amax = torch.sum(self.attr2, dim=1) 314 y = y + amax 315 return torch.sigmoid(self.lin(x + y)) 316 317 mod = ConstFoldTestModule() 318 mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) 319 self._verify_const_fold_mod(mod_folded) 320 321 # Now run both folded and non-folded to check results equal. 322 in_x, in_y = torch.randn(4, 4), torch.randn(4) 323 fold_result = mod_folded(in_x, in_y) 324 base_result = mod(in_x, in_y) 325 self.assertTrue(torch.equal(fold_result, base_result)) 326 327 def test_const_fold_submod_hierarchy(self): 328 r""" 329 Perform constant folding conversion, from original mod to split constant folding 330 module where one of the folded attrs comes from a submod deeper in the hierarchy 331 of the base module. 332 """ 333 334 class TracedThroughModule(torch.nn.Module): 335 def __init__(self) -> None: 336 super().__init__() 337 self.internal_attr = torch.nn.Parameter(torch.randn(2, 3)) 338 339 def forward(self): 340 return self.internal_attr 341 342 class ConstFoldTestModule(torch.nn.Module): 343 def __init__(self) -> None: 344 super().__init__() 345 self.my_mod = TracedThroughModule() 346 self.attr = torch.nn.Parameter(torch.randn(2, 3)) 347 348 def forward(self, x): 349 return self.attr + self.my_mod() + x 350 351 mod = ConstFoldTestModule() 352 mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) 353 self._verify_const_fold_mod(mod_folded) 354 355 # Now run both folded and non-folded to check results equal. 356 in_x = torch.randn(2, 3) 357 fold_result = mod_folded(in_x) 358 base_result = mod(in_x) 359 self.assertTrue(torch.equal(fold_result, base_result)) 360 361 def test_retain_node_meta(self): 362 r""" 363 Perform constant folding conversion, and validate that node meta is retained. 364 """ 365 366 class ConstFoldTestModule(torch.nn.Module): 367 def __init__(self) -> None: 368 super().__init__() 369 self.attr = torch.nn.Parameter(torch.randn(2, 3)) 370 371 def forward(self, x): 372 a = self.attr + self.attr 373 return x - a 374 375 mod = ConstFoldTestModule() 376 gm = torch.fx.symbolic_trace(mod) 377 378 # Add a count for each node to check after we const fold. 379 for idx, node in enumerate(gm.graph.nodes): 380 if node.op != "output": 381 node.meta["meta_idx"] = idx 382 383 # Pre-folding: 384 # idx 0: placeholder 385 # idx 1: get_attr (will no longer be used, hence removed) 386 # idx 2: add (will be folded into a get_attr) 387 # idx 3: sub 388 389 gm_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(gm) 390 self._verify_const_fold_mod(gm_folded) 391 392 # Post-folding: 393 # idx 0: placeholder 394 # idx 2: get_attr (replaced original add; original get_attr was removed) 395 # idx 3: sub 396 397 # Check the expected indices are still here. 398 for node in gm_folded.graph.nodes: 399 if node.op == "placeholder": 400 self.assertEqual(node.meta["meta_idx"], 0) 401 elif node.op == "get_attr": 402 self.assertEqual(node.meta["meta_idx"], 2) 403 elif node.op == "call_function" and node.target == operator.sub: 404 self.assertEqual(node.meta["meta_idx"], 3) 405 else: 406 self.assertEqual(node.op, "output") 407 408 # Now run both folded and non-folded to check results equal. 409 in_x = torch.randn(2, 3) 410 fold_result = gm_folded(in_x) 411 base_result = mod(in_x) 412 self.assertTrue(torch.equal(fold_result, base_result)) 413 414 def test_const_fold_has_inlined_call_module_node(self): 415 class ConstFoldTestModule(torch.nn.Module): 416 def __init__(self) -> None: 417 super().__init__() 418 self.attr = torch.nn.Parameter(torch.randn(2, 3)) 419 self.mod = torch.nn.Identity() 420 self.mod.relu = torch.nn.ReLU() 421 422 def forward(self, x): 423 a = self.attr + self.attr 424 return self.mod.relu(x - a) 425 426 mod = ConstFoldTestModule() 427 gm_folded = const_fold.split_const_subgraphs(mod) 428 429 # Now run both folded and non-folded to check results equal. 430 in_x = torch.randn(2, 3) 431 fold_result = gm_folded(in_x) 432 base_result = mod(in_x) 433 self.assertTrue(torch.equal(fold_result, base_result)) 434 435 def test_const_fold_module_attr(self): 436 class ConstFoldTestModule(torch.nn.Module): 437 def __init__(self) -> None: 438 super().__init__() 439 self.const = torch.nn.Parameter(torch.randn(2, 3)) 440 self.mod = torch.nn.Identity() 441 self.mod.attr = torch.nn.Parameter(torch.randn(2, 3)) 442 443 def forward(self, x): 444 a = self.const + self.mod.attr 445 x = x + a 446 return x + self.mod.attr 447 448 mod = ConstFoldTestModule() 449 gm_folded = const_fold.split_const_subgraphs(mod) 450 451 # Now run both folded and non-folded to check results equal. 452 in_x = torch.randn(2, 3) 453 fold_result = gm_folded(in_x) 454 base_result = mod(in_x) 455 self.assertTrue(torch.equal(fold_result, base_result)) 456 457 def test_const_fold_unused_placeholder(self): 458 class ConstFoldTestModule(torch.nn.Module): 459 def __init__(self) -> None: 460 super().__init__() 461 self.const = torch.nn.Parameter(torch.randn(2, 3)) 462 463 def forward(self, x, y, z): 464 a = self.const + self.const 465 return y + a 466 467 mod = ConstFoldTestModule() 468 gm_folded = const_fold.split_const_subgraphs(mod) 469 470 # Now run both folded and non-folded to check results equal. 471 in_x = torch.randn(2, 3) 472 fold_result = gm_folded(in_x, in_x, in_x) 473 base_result = mod(in_x, in_x, in_x) 474 self.assertTrue(torch.equal(fold_result, base_result)) 475 476 def test_dict_output(self): 477 class ConstFoldTestModule(torch.nn.Module): 478 def __init__(self) -> None: 479 super().__init__() 480 self.const = torch.nn.Parameter(torch.randn(2, 3)) 481 482 def forward(self, x): 483 a = self.const + self.const 484 return {"result": x + a} 485 486 mod = ConstFoldTestModule() 487 gm_folded = const_fold.split_const_subgraphs(mod) 488 489 # Now run both folded and non-folded to check results equal. 490 in_x = torch.randn(2, 3) 491 fold_result = gm_folded(in_x) 492 base_result = mod(in_x) 493 self.assertTrue(torch.equal(fold_result["result"], base_result["result"])) 494 495 def test_two_outputs(self): 496 class ConstFoldTestModule(torch.nn.Module): 497 def __init__(self) -> None: 498 super().__init__() 499 self.const = torch.nn.Parameter(torch.randn(2, 3)) 500 501 def forward(self, x): 502 a = self.const + self.const 503 return x, x + a 504 505 mod = ConstFoldTestModule() 506 gm_folded = const_fold.split_const_subgraphs(mod) 507 508 # Now run both folded and non-folded to check results equal. 509 in_x = torch.randn(2, 3) 510 fold_result = gm_folded(in_x) 511 base_result = mod(in_x) 512 self.assertTrue(torch.equal(fold_result[0], base_result[0])) 513 self.assertTrue(torch.equal(fold_result[1], base_result[1])) 514 515 def test_three_outputs(self): 516 class ConstFoldTestModule(torch.nn.Module): 517 def __init__(self) -> None: 518 super().__init__() 519 self.const = torch.nn.Parameter(torch.randn(2, 3)) 520 521 def forward(self, x): 522 a = self.const + self.const 523 return x, x + a, x + a 524 525 mod = ConstFoldTestModule() 526 gm_folded = const_fold.split_const_subgraphs(mod) 527 528 # Now run both folded and non-folded to check results equal. 529 in_x = torch.randn(2, 3) 530 fold_result = gm_folded(in_x) 531 base_result = mod(in_x) 532 self.assertTrue(torch.equal(fold_result[0], base_result[0])) 533 self.assertTrue(torch.equal(fold_result[1], base_result[1])) 534 self.assertTrue(torch.equal(fold_result[2], base_result[2])) 535 536 def test_check_inline_non_const(self): 537 r""" 538 Perform constant folding conversion and check that the non-const module is inlined 539 correctly. 540 """ 541 542 class ConstFoldTestModule(torch.nn.Module): 543 def __init__(self) -> None: 544 super().__init__() 545 self.attr = torch.nn.Parameter(torch.randn(2, 3)) 546 547 def forward(self, x): 548 a = self.attr + self.attr 549 return (x - a * x) / 2 550 551 mod = ConstFoldTestModule() 552 gm = torch.fx.symbolic_trace(mod) 553 554 gm_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(gm) 555 self._verify_const_fold_mod(gm_folded) 556 557 # Check there are no call modules, because they've been inlined or extracted for 558 # const folding. 559 for node in gm_folded.graph.nodes: 560 self.assertNotEqual(node.op, "call_module") 561 562 # Now run both folded and non-folded to check results equal. 563 in_x = torch.randn(2, 3) 564 fold_result = gm_folded(in_x) 565 base_result = mod(in_x) 566 self.assertTrue(torch.equal(fold_result, base_result)) 567 568 def test_check_inline_non_const_mult_return(self): 569 r""" 570 Perform constant folding conversion and check that the non-const module is inlined 571 correctly. 572 """ 573 574 class ConstFoldTestModule(torch.nn.Module): 575 def __init__(self) -> None: 576 super().__init__() 577 self.attr = torch.nn.Parameter(torch.randn(2, 3)) 578 579 def forward(self, x): 580 a = self.attr + self.attr 581 return x - a, x / 2 582 583 mod = ConstFoldTestModule() 584 gm = torch.fx.symbolic_trace(mod) 585 586 gm_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(gm) 587 self._verify_const_fold_mod(gm_folded) 588 589 # Check there are no call modules, because they've been inlined or extracted for 590 # const folding. 591 for node in gm_folded.graph.nodes: 592 self.assertNotEqual(node.op, "call_module") 593 594 # Now run both folded and non-folded to check results equal. 595 in_x = torch.randn(2, 3) 596 fold_result = gm_folded(in_x) 597 base_result = mod(in_x) 598 self.assertTrue(torch.equal(fold_result[0], base_result[0])) 599 self.assertTrue(torch.equal(fold_result[1], base_result[1])) 600 601 def test_check_skip_folding_quant_dequant_pattern(self): 602 r""" 603 Set up skip_folding_quant_dequant function to skip quant/dequant pattern. 604 This example shows how to use skip_folding_node_fn. 605 """ 606 607 class ConstFoldTestModule(torch.nn.Module): 608 def __init__(self) -> None: 609 super().__init__() 610 self.weight = torch.nn.Parameter(torch.randn(4, 4)) 611 self.bias = torch.nn.Parameter(torch.randn(4)) 612 self.relu = torch.nn.ReLU() 613 614 def forward(self, x): 615 quant_weight = torch.quantize_per_tensor( 616 self.weight, 0.5, 3, torch.quint8 617 ) 618 dequant_weight = torch.dequantize(quant_weight) 619 output = torch.nn.functional.linear(x, dequant_weight, self.bias) 620 return self.relu(output) 621 622 mod = ConstFoldTestModule() 623 in_x = torch.randn(2, 4) 624 gm = torch.fx.symbolic_trace(mod) 625 626 def skip_folding_quant_dequant(node: torch.fx.Node): 627 if node.target != torch.quantize_per_tensor: 628 return False 629 # If quantize_per_node -> dequantize, then skip folding. 630 for user in node.users: 631 if user.target == torch.dequantize: 632 return True 633 return False 634 635 gm_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs( 636 gm, skip_folding_node_fn=skip_folding_quant_dequant 637 ) 638 639 # Check that the folded graph module is None, since there was no folding to do. 640 self.assertTrue(gm_folded.const_subgraph_module is None) 641 642 # Now run both folded and non-folded to check results equal. 643 fold_result = gm_folded(in_x) 644 base_result = mod(in_x) 645 self.assertTrue(torch.equal(fold_result, base_result)) 646 647 def test_fold_module(self): 648 r""" 649 Perform constant folding with a call_module node. 650 """ 651 652 class ConstFoldTestModule(torch.nn.Module): 653 def __init__(self) -> None: 654 super().__init__() 655 self.lin_input = torch.nn.Parameter(torch.randn(4, 4)) 656 self.lin = torch.nn.Linear(4, 4) 657 658 def forward(self, x): 659 return self.lin(self.lin_input) + x 660 661 mod = ConstFoldTestModule() 662 mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod) 663 self._verify_const_fold_mod(mod_folded) 664 665 # Now run both folded and non-folded to check results equal. 666 inp = torch.randn(4, 4) 667 self.assertTrue(torch.equal(mod_folded(inp), mod(inp))) 668 669 def test_const_fold_tensor_meta(self): 670 self._test_const_fold_tensor_meta(True) 671 self._test_const_fold_tensor_meta(False) 672 673 def _test_const_fold_tensor_meta(self, requires_grad): 674 """ 675 Verify tensor_meta is handled correctly. 676 """ 677 678 class ConstFoldTestModule(torch.nn.Module): 679 def __init__(self) -> None: 680 super().__init__() 681 self.attr_1 = torch.nn.Parameter(torch.tensor([[-0.9]]), requires_grad) 682 self.attr_2 = torch.nn.Parameter(torch.tensor([[17.1]]), requires_grad) 683 684 def forward(self, x, y): 685 a = self.attr_1 + self.attr_1 686 x = x - a 687 return x * y + self.attr_2 688 689 mod = ConstFoldTestModule() 690 gm = torch.fx.symbolic_trace(mod) 691 in_x, in_y = torch.tensor([[-0.45]]), torch.tensor([0.9]) 692 ShapeProp(gm).propagate(in_x, in_y) 693 mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs( 694 gm, device_for_folded_attrs="cpu" 695 ) 696 self._verify_const_fold_mod(mod_folded) 697 698 mod_folded.run_folding() 699 700 for n in mod_folded.graph.nodes: 701 if n.op == "get_attr": 702 attr = self._get_attr(n) 703 self.assertEqual(_extract_tensor_metadata(attr), n.meta["tensor_meta"]) 704 705 # Now run both folded and non-folded to check results equal. 706 base_result = mod(in_x, in_y) 707 fold_result = mod_folded(in_x, in_y) 708 self.assertTrue(torch.equal(fold_result, base_result)) 709