1# Owner(s): ["oncall: jit"] 2 3import io 4import unittest 5from itertools import product 6from typing import Any 7 8import torch 9import torch.nn as nn 10import torch.nn.functional as F 11from torch.jit._recursive import wrap_cpp_module 12from torch.testing import FileCheck 13from torch.testing._internal.common_cuda import TEST_CUDA, TEST_CUDNN 14from torch.testing._internal.common_quantization import skipIfNoFBGEMM 15from torch.testing._internal.common_quantized import override_quantized_engine 16from torch.testing._internal.common_utils import ( 17 set_default_dtype, 18 skipCUDAMemoryLeakCheckIf, 19 skipIfTorchDynamo, 20 TEST_WITH_ROCM, 21) 22from torch.testing._internal.jit_utils import JitTestCase 23from torch.utils import mkldnn as mkldnn_utils 24 25 26try: 27 import torchvision 28 29 HAS_TORCHVISION = True 30except ImportError: 31 HAS_TORCHVISION = False 32skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") 33 34if __name__ == "__main__": 35 raise RuntimeError( 36 "This test file is not meant to be run directly, use:\n\n" 37 "\tpython test/test_jit.py TESTNAME\n\n" 38 "instead." 39 ) 40 41TEST_ROCM = torch.cuda.is_available() and torch.version.hip is not None 42 43 44def removeExceptions(graph): 45 for n in graph.findAllNodes("prim::RaiseException"): 46 n.destroy() 47 48 49class TestFreezing(JitTestCase): 50 def test_freeze_module(self): 51 class M(nn.Module): 52 def __init__(self) -> None: 53 super().__init__() 54 self.a = 1 # folded 55 self.b = 1.2 # folded 56 self.c = "hello" # folded 57 self.c2 = "hi\xA1" # not folded 58 self.d = [1, 1] # folded 59 self.e = [1.0, 1.1] # folded 60 self.f = ["hello", "world"] # folded 61 self.f2 = [(1, "Over \u0e55\u0e57 57")] 62 self.g = ( 63 [1, 2], 64 3.2, 65 "4.4", 66 torch.tensor([5.5], requires_grad=True), 67 ) # folded 68 self.h = {"layer": [torch.tensor([7.7], requires_grad=True)]} 69 self.h2 = {"layer\xB1": [torch.tensor([8.8], requires_grad=True)]} 70 self.t = torch.tensor([1.2, 2.4], requires_grad=True) # folded 71 self.ts = [ 72 torch.tensor([1.0, 2.0], requires_grad=True), 73 torch.tensor([3.0, 4.0], requires_grad=True), 74 ] # folded 75 self.tt = [[torch.tensor([3.3, 2.3], requires_grad=True), None]] 76 77 def forward(self, x): 78 return ( 79 str(self.a) 80 + str(self.b) 81 + self.c 82 + self.c2 83 + str(self.d) 84 + str(self.e) 85 + str(self.f) 86 + str(self.f2) 87 + str(self.g) 88 + str(self.h) 89 + str(self.h2) 90 + str(self.t) 91 + str(self.ts) 92 + str(self.tt) 93 ) 94 95 m = torch.jit.script(M()) 96 m.eval() 97 input = torch.randn(2, 2) 98 output_s = m.forward(input) 99 m._c = torch._C._freeze_module(m._c) 100 buffer = io.BytesIO() 101 torch.jit.save(m._c, buffer) 102 buffer.seek(0) 103 m2 = torch.jit.load(buffer) 104 # Check if frozen module looks as below: 105 # module m { 106 # attributes { 107 # tt = ... 108 # } 109 # ... 
110 # } 111 self.assertFalse(m2._c.hasattr("a")) 112 self.assertFalse(m2._c.hasattr("b")) 113 self.assertFalse(m2._c.hasattr("c")) 114 self.assertFalse(m2._c.hasattr("c2")) 115 self.assertFalse(m2._c.hasattr("d")) 116 self.assertFalse(m2._c.hasattr("e")) 117 self.assertFalse(m2._c.hasattr("f")) 118 self.assertFalse(m2._c.hasattr("f2")) 119 self.assertFalse(m2._c.hasattr("g")) 120 self.assertFalse(m2._c.hasattr("h")) 121 self.assertFalse(m2._c.hasattr("h2")) 122 self.assertFalse(m2._c.hasattr("t")) 123 self.assertFalse(m2._c.hasattr("ts")) 124 self.assertFalse(m2._c.hasattr("tt")) 125 output_f = m2.forward(input) 126 self.assertEqual(output_s, output_f) 127 128 def test_freeze_module_with_submodule(self): 129 class SubModule(nn.Module): 130 def __init__(self) -> None: 131 super().__init__() 132 self.a = 11 133 self.b = 2 134 135 def forward(self, x): 136 return self.a + self.b 137 138 class SubModule2(nn.Module): 139 def __init__(self) -> None: 140 super().__init__() 141 self.a = 12 142 self.b = 2 143 144 def forward(self, x): 145 self.b = 30 146 return self.a + self.b 147 148 class TestModule(nn.Module): 149 def __init__(self) -> None: 150 super().__init__() 151 self.sub1 = SubModule() 152 self.sub2 = SubModule2() 153 self.a = 3 154 self.b = 4 155 156 def forward(self, x): 157 self.b = 20 158 return self.sub1(x) + self.a + self.b + self.sub2(x) 159 160 m = torch.jit.script(TestModule()) 161 m.eval() 162 input = torch.randn(2, 2) 163 output_s = m.forward(input) 164 mf = torch.jit.freeze(m) 165 166 # Check if frozen module looks as below: 167 # module m { 168 # attributes { 169 # sub2 = ... 170 # b = 171 # } 172 # ... 173 # submodule { 174 # module m { 175 # attributes { 176 # sub2 = ... 177 # b = 178 # } 179 # ... 180 # } 181 # } 182 # } 183 mf = mf._c 184 self.assertFalse(mf.hasattr("sub1")) 185 self.assertFalse(mf.hasattr("a")) 186 self.assertTrue(mf.hasattr("b")) 187 self.assertTrue(mf.hasattr("sub2")) 188 self.assertTrue(mf.sub2.hasattr("b")) # verify b is preserved in sub2 189 self.assertFalse(mf.sub2.hasattr("a")) # verify a is removed in sub2 190 output_f = mf.forward(input) 191 self.assertEqual(output_s, output_f) 192 193 def test_freeze_module_with_fork(self): 194 class SubModule(nn.Module): 195 def __init__(self) -> None: 196 super().__init__() 197 self.a = torch.ones(20, 20) 198 self.b = torch.ones(20, 20) 199 200 def forward(self, x): 201 return self.a * self.b + x 202 203 class TestModule(nn.Module): 204 def __init__(self) -> None: 205 super().__init__() 206 self.sub = SubModule() 207 208 def forward(self, x): 209 fut = torch.jit._fork(self.sub.forward, x) 210 y_hat = self.sub(x) 211 y = torch.jit._wait(fut) 212 return y_hat + y 213 214 m = torch.jit.script(TestModule()) 215 m.eval() 216 input = torch.randn(20, 20) 217 output_s = m.forward(input) 218 mf = torch._C._freeze_module(m._c) 219 220 # Check if frozen module looks as below: 221 # module m { 222 # attributes { 223 # } 224 # ... 
        #   submodule {
        #   }
        # }
        self.assertFalse(mf.hasattr("a"))
        self.assertFalse(mf.hasattr("b"))
        output_f = mf.forward(input)
        self.assertEqual(output_s, output_f)

    def test_freeze_module_with_nested_fork(self):
        class SubModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.ones(20, 20)
                self.b = torch.ones(20, 20)

            def forward(self, x):
                return self.a * self.b + x

        class SubModule2(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = SubModule()
                self.c = torch.ones(20, 20)

            def forward(self, x):
                fut = torch.jit._fork(self.sub.forward, x)
                y_hat = self.sub(x)
                y = torch.jit._wait(fut)
                return y_hat + y + self.c

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = SubModule2()
                self.d = 1

            def forward(self, x):
                fut = torch.jit._fork(self.sub.forward, x)
                y_hat = self.sub(x)
                y = torch.jit._wait(fut)
                self.d = 2
                return y_hat * y + self.d

        m = torch.jit.script(TestModule())
        m.eval()
        input = torch.randn(20, 20)
        output_s = m.forward(input)
        mf = torch._C._freeze_module(m._c)
        # Check if frozen module looks as below:
        # module m {
        #   attributes {
        #   }
        #   ...
        #   submodule {
        #   }
        # }
        self.assertFalse(mf.hasattr("a"))
        self.assertFalse(mf.hasattr("b"))
        self.assertFalse(mf.hasattr("c"))
        self.assertTrue(mf.hasattr("d"))
        output_f = mf.forward(input)
        self.assertEqual(output_s, output_f)

    def test_freeze_module_with_fork2(self):
        @torch.jit.script
        def foo(x):
            return x * 2

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.ones(20, 20)
                self.b = torch.ones(20, 20)

            def forward(self, x):
                fut = torch.jit._fork(foo, self.a)
                y_hat = foo(self.b)
                y = torch.jit._wait(fut)
                return y_hat + y

        m = torch.jit.script(TestModule())
        m.eval()
        input = torch.randn(2, 2)
        output_s = m.forward(input)
        mf = torch._C._freeze_module(m._c)

        # Check if frozen module looks as below:
        # module m {
        #   attributes {
        #     self.a = ...
        #   }
        #   ...
        #   submodule {
        #   }
        # }
        # TODO: Although there is no mutation, the alias analysis
        # conservatively assumes a mutation because 'a' is passed to the fork
        # subgraph, so 'a' is preserved while 'b' is folded.
        self.assertTrue(mf.hasattr("a"))
        self.assertFalse(mf.hasattr("b"))
        output_f = mf.forward(input)
        self.assertEqual(output_s, output_f)

    def test_freeze_module_with_fork_calling_module_method(self):
        @torch.jit.script
        def foo(x, y):
            return x * y

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.ones(20, 20)
                self.b = torch.ones(20, 20)

            @torch.jit.export
            def foo(self, x):
                return x * self.a

            @torch.jit.export
            def bar(self, x):
                return x * self.b

            def forward(self, x):
                fut = torch.jit._fork(self.foo, self.b)
                y_hat = self.bar(self.a)
                y = torch.jit._wait(fut)
                return y_hat + y

        m = torch.jit.script(TestModule())
        m.eval()
        input = torch.randn(2, 2)
        output_s = m.forward(input)
        mf = torch._C._freeze_module(m._c)
        # Check if frozen module looks as below:
        # module m {
        #   attributes {
        #     self.b = ..
        #   }
        #   ...
        # TODO: Although there is no mutation, the alias analysis
        # conservatively assumes there is a mutation because attributes are
        # passed to the fork subgraph; 'b' is preserved.
        self.assertFalse(mf.hasattr("a"))
        self.assertTrue(mf.hasattr("b"))
        output_f = mf.forward(input)
        self.assertEqual(output_s, output_f)

    def test_freeze_module_with_sharedclasstype(self):
        class SubModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1.1])
                self.b = torch.tensor([2.2])

            def forward(self, x):
                return self.a + self.b

            @torch.jit.export
            def modify_a(self, x):
                self.a[0] += 10
                return self.b

            @torch.jit.export
            def modify_b(self, x):
                self.b[0] += 20
                return self.a

        class SubModule2(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = SubModule()
                self.b = torch.tensor([3.3])

            def forward(self, x):
                y = self.sub.modify_b(x)
                return y + self.b

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub1 = SubModule()  # sub1 and sub2.sub share the same class type.
                self.sub2 = SubModule2()
                self.a = torch.tensor([4.4])

            def forward(self, x):
                z = self.sub1.modify_a(x)
                return self.sub2(x) + z + self.a

        m = torch.jit.script(TestModule())
        m.eval()
        input = torch.randn(2, 2)
        output_s = m.forward(input)
        mf = torch._C._freeze_module(m._c)

        # Check if frozen module looks as below:
        # module mf {
        #   attributes {
        #     sub1 = ...
        #     sub2 = ...
        #   }
        #   ...
        #   submodules {
        #     module sub1 {
        #       attributes {
        #         a = ...
        #         b = ...
        #       }
        #       ...
        #     }
        #     module sub2 {
        #       attributes {
        #         sub = ...
        #       }
        #       ...
        #       submodule {
        #         module sub {
        #           attributes {
        #             a = ...
        #             b = ...
        #           }
        #           ...
        #         }
        #       }
        #     }
        #   }
        # }

        self.assertTrue(mf.hasattr("sub1"))
        self.assertTrue(mf.sub1.hasattr("a"))
        self.assertTrue(mf.sub1.hasattr("b"))
        self.assertFalse(mf.hasattr("a"))
        self.assertTrue(mf.hasattr("sub2"))
        self.assertTrue(mf.sub2.hasattr("sub"))
        self.assertFalse(mf.sub2.hasattr("b"))
        self.assertTrue(mf.sub2.sub.hasattr("a"))
        self.assertTrue(mf.sub2.sub.hasattr("b"))
        output_f = mf.forward(input)
        self.assertEqual(output_s, output_f)

    def test_freeze_module_with_nestedaliasing(self):
        class SubModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1.1])
                self.b = torch.tensor([2.2])

            def forward(self, x):
                return self.a + self.b

            @torch.jit.export
            def modify_a(self, x):
                self.a[0] = 10
                return self.b

            @torch.jit.export
            def modify_b(self, x):
                self.b[0] = 20
                return self.a

        Sub = SubModule()

        class SubModule2(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = Sub  # aliasing

            def forward(self, x):
                return self.sub.a

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub1 = Sub  # aliasing
                self.sub2 = SubModule2()

            def forward(self, x):
                z = self.sub1.modify_a(x)
                return self.sub2(x) + z

        m = torch.jit.script(TestModule())
        m.eval()
        mf = torch._C._freeze_module(m._c)
        self.assertTrue(mf.hasattr("sub1"))
        self.assertTrue(mf.sub1.hasattr("a"))
        self.assertFalse(mf.sub1.hasattr("b"))
        self.assertTrue(mf.hasattr("sub2"))
        self.assertTrue(mf.sub2.hasattr("sub"))
        self.assertTrue(
            mf.sub2.sub.hasattr("a")
        )  # Freezing detects that self.sub2.sub.a and self.sub1.a are aliases
        self.assertFalse(mf.sub2.sub.hasattr("b"))
        input = torch.randn(2, 2)
        output_s = m.forward(input)
        output_f = mf.forward(input)
        self.assertEqual(output_s, output_f)

    # FIXME: JIT is not honoring aliasing. 'Sub' module is copied. As a result,
    # eager and scripted modules produce different outputs.
    def test_freeze_module_with_nestedaliasingscalar(self):
        class SubModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 1.1
                self.b = 2.2

            def forward(self, x):
                return self.a + self.b

            @torch.jit.export
            def modify_a(self, x):
                self.a = 10.0
                return self.b

            @torch.jit.export
            def modify_b(self, x):
                self.b = 20.0
                return self.a

        Sub = SubModule()

        class SubModule2(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = Sub  # aliasing

            def forward(self, x):
                return self.sub.a

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub1 = Sub  # aliasing
                self.sub2 = SubModule2()

            def forward(self, x):
                z = self.sub1.modify_a(x)
                return self.sub2(x) + z

        m = TestModule()
        ms = torch.jit.script(m)
        ms.eval()
        mf = torch._C._freeze_module(ms._c)
        self.assertTrue(mf.hasattr("sub1"))
        self.assertTrue(mf.sub1.hasattr("a"))
        self.assertFalse(mf.sub1.hasattr("b"))
        # sub2 is fully folded because self.sub1 and self.sub2.sub are not aliased (scripting bug)
        self.assertFalse(mf.hasattr("sub2"))
        input = torch.randn(2, 2)
        output = m.forward(input)
        output_s = ms.forward(input)
        output_f = mf.forward(input)
        # Should be equal, but differs because scripting copies 'Sub' (see FIXME above).
        self.assertNotEqual(output, output_s)
        self.assertEqual(output_s, output_f)

    def test_freeze_module_with_preserve_sub_module(self):
        class SubModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1.1])
                self.b = 2.2

            def forward(self, x):
                return self.a

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub1 = SubModule()
                self.sub2 = SubModule()

            def forward(self, x):
                return self.sub2(x) + self.sub1(x)

        m = TestModule()
        ms = torch.jit.script(m)
        ms.eval()
        mf = torch._C._freeze_module(ms._c, ["sub1"])

        # Test that 'sub1' is preserved entirely and 'sub2' is completely folded
        self.assertTrue(mf.hasattr("sub1"))
        self.assertTrue(mf.sub1.hasattr("a"))
        self.assertTrue(mf.sub1.hasattr("b"))
        self.assertFalse(mf.hasattr("sub2"))
        input = torch.randn(2, 2)
        output_s = ms.forward(input)
        output_f = mf.forward(input)
        self.assertEqual(output_s, output_f)

    def test_freeze_module_with_preserve_sub_module_and_mutation(self):
        class SubModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1.1])
                self.b = 2.2

            def forward(self, x):
                self.a[0] = 3.3
                return self.a

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub1 = SubModule()
                self.sub2 = SubModule()

            def forward(self, x):
                return self.sub2(x) + self.sub1(x)

        m = TestModule()
        ms = torch.jit.script(m)
        ms.eval()
        mf = torch._C._freeze_module(ms._c, ["sub1"])

        # Test that both 'sub1' and 'sub2' are preserved, and that 'b' is kept
        # even though it is not used, to fulfill the user request to preserve 'sub1'.
        self.assertTrue(mf.hasattr("sub1"))
        self.assertTrue(mf.sub1.hasattr("a"))
        self.assertTrue(mf.sub1.hasattr("b"))
        self.assertTrue(mf.hasattr("sub2"))
        self.assertTrue(mf.sub2.hasattr("a"))
        self.assertTrue(mf.sub2.hasattr("b"))
        input = torch.randn(2, 2)
        output_s = ms.forward(input)
        output_f = mf.forward(input)
        self.assertEqual(output_s, output_f)

    def test_freeze_module_with_helperfunction(self):
        class SubModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 11
                self.b = 2

            def forward(self, x):
                return self.a + self.b

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = SubModule()
                self.a = 3
                self.b = 4

            def forward(self, x):
                self.b = 20
                return self._forward(x) + self.a + self.b

            def _forward(self, x):
                return self.sub(x)

        m = torch.jit.script(TestModule())
        m.eval()
        input = torch.randn(2, 2)
        mf = torch._C._freeze_module(m._c)
        self.assertFalse(mf.hasattr("sub"))
        self.assertFalse(mf.hasattr("a"))
        self.assertTrue(mf.hasattr("b"))
        with self.assertRaisesRegex(
            AttributeError, "TestModule (.*) does not have a field with name '_forward'"
        ):
            mf._forward(x)  # noqa: F821

    def test_freeze_module_with_inplace_mutable(self):
        class FreezeMe(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.a = [11, 22]

            @torch.jit.script_method
            def forward(self, x):
                for i in range(3):
                    self.a.append(i)
                return self.a

        m = FreezeMe()
        m.eval()
        m_f = torch._C._freeze_module(m._c)
        self.assertTrue(m_f.hasattr("a"))
        m.forward(torch.tensor([3]))
        out = m_f.forward(torch.tensor([5]))
        expected = [11, 22, 0, 1, 2, 0, 1, 2]
        self.assertEqual(out, expected)

    # Mutable attributes
    def test_freeze_module_with_mutable_list(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = [1, 2]

            def forward(self, x):
                return self.a

        m = FreezeMe()
        m.eval()
        m.a.append(3)
        m_s = torch.jit.script(m)
        v = m_s.a
        v.append(4)
        m_s.a = v
        m_s.eval()
        m_f = torch._C._freeze_module(m_s._c)
        # Post-freezing mutating m_s.a does not affect m_f (m_f has its own copy).
730 v = m_s.a 731 v.append(5) 732 m_s.a = v 733 self.assertFalse(m_f.hasattr("a")) 734 out = m_f.forward(torch.tensor([5])) 735 expected = [1, 2, 3, 4] 736 self.assertEqual(out, expected) 737 738 def test_freeze_module_with_mutable_dict(self): 739 class FreezeMe(nn.Module): 740 def __init__(self) -> None: 741 super().__init__() 742 self.a = {"layer": "4"} 743 744 def forward(self, x): 745 return self.a 746 747 @torch.jit.export 748 def modify_a(self, x): 749 self.a["layer"] = self.a["layer"] + "1" 750 return self.a 751 752 m = FreezeMe() 753 m.eval() 754 m.a["layer2"] = "3" 755 m_s = torch.jit.script(m) 756 t = torch.tensor(5) 757 m_s.modify_a(t) 758 m_s.eval() 759 m_f = torch._C._freeze_module(m_s._c) 760 m.a["layer2"] += "2" 761 m_s.modify_a(t) 762 self.assertFalse(m_f.hasattr("a")) 763 out = m_f.forward(t) 764 expected = {"layer": "411", "layer2": "3"} 765 self.assertEqual(out, expected) 766 767 def test_freeze_module_with_mutable_tensor(self): 768 class FreezeMe(nn.Module): 769 def __init__(self) -> None: 770 super().__init__() 771 self.a = torch.tensor([1.0, 2.0, 3.0]) 772 773 def forward(self, x): 774 return self.a 775 776 m = FreezeMe() 777 m_s = torch.jit.script(m) 778 m_s.a[1] += 3.0 779 m_s.eval() 780 m_f = torch._C._freeze_module(m_s._c) 781 # Post-freezing tensor attribute mutations affect m_f. 782 # FIXME: deep copy all folded attributes so that m_f has full ownership. 783 m_s.a[0] += 5.0 784 self.assertFalse(m_f.hasattr("a")) 785 out = m_f.forward(torch.tensor([5])) 786 expected = [6.0, 5.0, 3.0] 787 self.assertEqual(out, expected) 788 789 def test_freeze_module_with_tuple(self): 790 class FreezeMe(nn.Module): 791 def __init__(self) -> None: 792 super().__init__() 793 self.a = (torch.tensor([1, 2, 3, 4, 5, 6]), "hi") 794 795 def forward(self, x): 796 if x[0] == 2.0: 797 self.a[0][0] = 10 798 return self.a[0].sum() 799 800 m = FreezeMe() 801 m_s = torch.jit.script(m) 802 m_s.eval() 803 inp = torch.tensor([2.0]) 804 expected = m_s.forward(inp) 805 m_s.a[0][0] = 1 806 m_f = torch._C._freeze_module(m_s._c) 807 self.assertFalse(m_f.hasattr("a")) 808 out = m_f.forward(inp) 809 self.assertEqual(out, expected) 810 811 def test_freeze_module_with_tensor(self): 812 class FreezeMe(nn.Module): 813 def __init__(self) -> None: 814 super().__init__() 815 self.a = torch.tensor([1, 2, 3, 4, 5, 6]) 816 817 def forward(self, x): 818 x = self.a.view(2, 3) 819 x[0][0] += 10 820 return self.a.sum() 821 822 m = FreezeMe() 823 m_s = torch.jit.script(m) 824 m_s.eval() 825 inp = torch.tensor([5]) 826 expected = m_s.forward(inp) 827 m_f = torch._C._freeze_module(m_s._c) 828 self.assertTrue(m_f.hasattr("a")) 829 m_f.a[0] -= 10 830 out = m_f.forward(inp) 831 self.assertEqual(out, expected) 832 833 def test_freeze_module_with_list(self): 834 class FreezeMe(nn.Module): 835 def __init__(self) -> None: 836 super().__init__() 837 self.a = [torch.tensor([1, 2, 3, 4, 5, 6])] 838 839 def forward(self, x): 840 self.a[0][1] += 10 841 return self.a[0].sum() 842 843 m = FreezeMe() 844 m_s = torch.jit.script(m) 845 m_s.eval() 846 inp = torch.tensor([5]) 847 expected = m_s.forward(inp) 848 m_s.a[0][1] -= 10 849 m_f = torch._C._freeze_module(m_s._c) 850 self.assertFalse(m_f.hasattr("a")) 851 out = m_f.forward(inp) 852 self.assertEqual(out, expected) 853 854 def test_freeze_module_with_aliased_tensor_attr(self): 855 class FreezeMe(nn.Module): 856 def __init__(self) -> None: 857 super().__init__() 858 self.a = torch.tensor([1, 2, 3, 4, 5, 6]) 859 self.b = self.a.view(2, 3) 860 861 def forward(self, x): 862 self.b[1] += 
10 863 return self.a.sum() 864 865 m = FreezeMe() 866 m_s = torch.jit.script(m) 867 m_s.eval() 868 m_f = torch._C._freeze_module(m_s._c) 869 self.assertTrue(m_f.hasattr("a")) 870 inp = torch.tensor([5]) 871 out = m_f.forward(inp) 872 expected = torch.tensor(51) # 1+2+3+14+15+16 873 self.assertEqual(out, expected) 874 875 def test_freeze_module_with_aliased_tensor_attr2(self): 876 class FreezeMe(nn.Module): 877 def __init__(self) -> None: 878 super().__init__() 879 self.a = torch.tensor([1, 2, 3, 4, 5, 6]) 880 self.b = {"layer": ([self.a.view(2, 3), torch.tensor([10])], 20)} 881 self.c = ([self.a.view(2, 3), torch.tensor([10])], 20) 882 self.d = (self.a.view(2, 3), 20) 883 884 def forward(self, x): 885 self.d[0][0] += 10 886 return self.a.sum() 887 888 m = FreezeMe() 889 m_s = torch.jit.script(m) 890 m_s.eval() 891 inp = torch.tensor([5]) 892 expected = m_s.forward(inp) 893 with self.assertRaisesRegex( 894 RuntimeError, "module contains attributes values that overlaps" 895 ): 896 m_f = torch._C._freeze_module(m_s._c) 897 898 def test_freeze_module_with_aliased_tensor_attr3(self): 899 class FreezeMe(nn.Module): 900 def __init__(self) -> None: 901 super().__init__() 902 self.a = torch.tensor([1, 2, 3, 4, 5, 6]) 903 self.b = [self.a, torch.tensor([10])] 904 905 def forward(self, x): 906 self.a[1] += 10 907 return self.b[0].sum() 908 909 m = FreezeMe() 910 m_s = torch.jit.script(m) 911 m_s.eval() 912 inp = torch.tensor([5]) 913 expected = m_s.forward(inp) 914 m_f = torch._C._freeze_module(m_s._c) 915 self.assertTrue(m_f.hasattr("a")) 916 self.assertTrue(m_f.hasattr("b")) 917 out = m_f.forward(inp) 918 expected += 10 # account for self.a += 10. 919 self.assertEqual(out, expected) 920 921 def test_freeze_module_with_aliased_tensor_attr4(self): 922 class FreezeMe(nn.Module): 923 def __init__(self) -> None: 924 super().__init__() 925 self.a = torch.tensor([1, 2, 3, 4, 5, 6]) 926 self.b = [self.a, torch.tensor([10])] 927 928 def forward(self, x): 929 self.b[0][0] += 10 930 return self.a.sum() 931 932 m = FreezeMe() 933 m_s = torch.jit.script(m) 934 m_s.eval() 935 inp = torch.tensor([5]) 936 expected = m_s.forward(inp) 937 m_s.a[0] -= 10 938 with self.assertRaisesRegex( 939 RuntimeError, "module contains attributes values that overlaps" 940 ): 941 m_f = torch._C._freeze_module(m_s._c) 942 943 def test_freeze_module_with_overlapping_attrs(self): 944 a = torch.tensor([1, 2, 3, 4, 5, 6]) 945 946 class FreezeMe(nn.Module): 947 def __init__(self) -> None: 948 super().__init__() 949 self.b = [a.view(3, 2), torch.tensor([10])] 950 self.c = (20, a.view(2, 3)) 951 952 def forward(self, x): 953 self.b[0][0] += 10 954 return self.c[1].sum() 955 956 m = FreezeMe() 957 m_s = torch.jit.script(m) 958 m_s.eval() 959 inp = torch.tensor([5]) 960 expected = m_s.forward(inp) 961 a[0] -= 10 962 with self.assertRaisesRegex( 963 RuntimeError, "module contains attributes values that overlaps" 964 ): 965 m_f = torch._C._freeze_module(m_s._c) 966 967 def test_freeze_module_with_aliased_attr(self): 968 class FreezeMe(nn.Module): 969 def __init__(self) -> None: 970 super().__init__() 971 self.a = [1, 2, 3, 4, 5, 6] 972 self.b = self.a 973 self.c = (self.a, 10) 974 975 def forward(self, x): 976 self.b[1] += 10 977 return str(self.a) + str(self.c) 978 979 m = FreezeMe() 980 m_s = torch.jit.script(m) 981 m_s.eval() 982 m_f = torch._C._freeze_module(m_s._c) 983 # FIXME: It should be assertTrue. 
Currently scripting is making a copy for setting self.b (see #33034) 984 self.assertFalse(m_f.hasattr("a")) 985 self.assertFalse(m_f.hasattr("c")) 986 inp = torch.tensor([5]) 987 out = m_f.forward(inp) 988 expected = m_s.forward(inp) 989 self.assertEqual(out, expected) 990 991 # Check attribute a is preserved. Alias analysis detects that 'a' has output writers. 992 # In this example, 'a' is not mutated. However, we do not track which sub 993 # values of a composite ivalue is mutated. 994 def test_freeze_module_with_aliased_attr2(self): 995 class FreezeMe(nn.Module): 996 def __init__(self) -> None: 997 super().__init__() 998 self.a = [1, 2, 3, 4, 5, 6] 999 self.b = ([11], [10]) 1000 1001 def forward(self, x): 1002 v = self.a 1003 self.b = (v, [12]) 1004 v2 = self.b[1] 1005 v2.append(7) 1006 return str(v) + str(v2) 1007 1008 m = FreezeMe() 1009 m_s = torch.jit.script(m) 1010 m_s.eval() 1011 m_f = torch._C._freeze_module(m_s._c) 1012 self.assertTrue(m_f.hasattr("a")) 1013 inp = torch.tensor([5]) 1014 out = m_f.forward(inp) 1015 expected = m.forward(inp) 1016 self.assertEqual(out, expected) 1017 1018 def test_freeze_module_with_aliased_attr3(self): 1019 class FreezeMe(nn.Module): 1020 def __init__(self) -> None: 1021 super().__init__() 1022 self.a = [1, 2, 3, 4, 5, 6] 1023 self.b = ([11], [10]) 1024 1025 def forward(self, x): 1026 v = self.a 1027 v2 = (v, [12]) 1028 v3 = v2[0] 1029 v3.append(7) 1030 return str(self.a) 1031 1032 m = FreezeMe() 1033 m_s = torch.jit.script(m) 1034 m_s.eval() 1035 m_f = torch._C._freeze_module(m_s._c) 1036 self.assertTrue(m_f.hasattr("a")) 1037 inp = torch.tensor([5]) 1038 out = m_f.forward(inp) 1039 expected = m.forward(inp) 1040 self.assertEqual(out, expected) 1041 1042 def test_freeze_module_return_self(self): 1043 class FreezeMe(nn.Module): 1044 def __init__(self) -> None: 1045 super().__init__() 1046 self.a = torch.tensor([1.0, 2.0, 3.0]) 1047 1048 def forward(self, x): 1049 return self 1050 1051 m = FreezeMe() 1052 m_s = torch.jit.script(m) 1053 m_s.eval() 1054 with self.assertRaisesRegex( 1055 RuntimeError, "attempted to freeze a module that return itself" 1056 ): 1057 m_f = torch._C._freeze_module(m_s._c) 1058 1059 def test_freeze_module_inlining(self): 1060 @torch.jit.script # noqa: B903 1061 class Obj: # noqa: B903 1062 def __init__(self, x: int, y: int): 1063 self.x = x 1064 self.y = y 1065 1066 class Mod(nn.Module): 1067 def __init__(self) -> None: 1068 super().__init__() 1069 self.obj = Obj(2, 3) 1070 1071 def forward(self, i: int): 1072 print(self.obj) 1073 return i 1074 1075 mod = torch.jit.freeze(torch.jit.script(Mod().eval())) 1076 obj = mod.graph.findNode("prim::Constant") 1077 self.assertTrue(torch._C._jit_object_is_non_holding(obj)) 1078 1079 buffer = io.BytesIO() 1080 torch.jit.save(mod, buffer) 1081 buffer.seek(0) 1082 1083 loaded = torch.jit.load(buffer) 1084 obj = mod.graph.findNode("prim::Constant") 1085 self.assertTrue(torch._C._jit_object_is_non_holding(obj)) 1086 1087 def test_freeze_module_return_sub_module(self): 1088 class FreezeMe(nn.Module): 1089 def __init__(self) -> None: 1090 super().__init__() 1091 self.conv1 = nn.Conv2d(1, 32, 3, 1) 1092 1093 def forward(self, x): 1094 return self.conv1 1095 1096 m = FreezeMe() 1097 m_s = torch.jit.script(m) 1098 m_s.eval() 1099 m_f = torch._C._freeze_module(m_s._c) 1100 self.assertTrue(m_f.hasattr("conv1")) 1101 1102 def test_freeze_module_no_forward(self): 1103 class FreezeMe(nn.Module): 1104 def __init__(self) -> None: 1105 super().__init__() 1106 self.lin = nn.Linear(10, 1) 1107 1108 
@torch.jit.export 1109 def foo(self, x): 1110 return self.lin(x) 1111 1112 m = FreezeMe() 1113 m_s = torch.jit.script(m) 1114 m_s.eval() 1115 m_f = torch._C._freeze_module(m_s._c, preservedAttrs=["foo"]) 1116 input = torch.ones(10) 1117 self.assertEqual(m_s.foo(input), m_f.foo(input)) 1118 1119 def test_freeze_no_forward(self): 1120 class FreezeMe(nn.Module): 1121 def __init__(self) -> None: 1122 super().__init__() 1123 self.lin = nn.Linear(10, 1) 1124 1125 @torch.jit.export 1126 def foo(self, x): 1127 return self.lin(x) 1128 1129 m = FreezeMe() 1130 m_s = torch.jit.script(m) 1131 m_s.eval() 1132 m_f = torch.jit.freeze(m_s, preserved_attrs=["foo"]) 1133 input = torch.ones(10) 1134 self.assertEqual(m_s.foo(input), m_f.foo(input)) 1135 1136 def test_freeze_module_in_training_mode(self): 1137 class Net(nn.Module): 1138 def __init__(self) -> None: 1139 super().__init__() 1140 self.conv1 = nn.Conv2d(1, 32, 3, 1) 1141 self.conv2 = nn.Conv2d(32, 64, 3, 1) 1142 self.dropout1 = nn.Dropout2d(0.25) 1143 self.dropout2 = nn.Dropout2d(0.5) 1144 self.fc1 = nn.Linear(9216, 128) 1145 self.fc2 = nn.Linear(128, 10) 1146 1147 def forward(self, x): 1148 x = self.conv1(x) 1149 x = nn.functional.relu(x) 1150 x = self.conv2(x) 1151 x = nn.functional.max_pool2d(x, 2) 1152 x = self.dropout1(x) 1153 x = torch.flatten(x, 1) 1154 x = self.fc1(x) 1155 x = nn.functional.relu(x) 1156 x = self.dropout2(x) 1157 x = self.fc2(x) 1158 output = nn.functional.log_softmax(x, dim=1) 1159 return output 1160 1161 model = torch.jit.script(Net()) 1162 model.train() 1163 mTrain_freezed = torch._C._freeze_module(model._c) 1164 # verify mTrain_freezed looks exactly as: 1165 # module { 1166 # attributes { 1167 # conv1 = ... 1168 # conv2 = ... 1169 # dropout1 = ... 1170 # dropout2 = ... 1171 # fc1 = ... 1172 # fc2 = ... 1173 # } 1174 # ... 1175 # submodules { 1176 # module conv1 { 1177 # attributes { 1178 # weight = ... 1179 # bias = ... 1180 # } 1181 # ... 1182 # } 1183 # module conv2 { 1184 # attributes { 1185 # weight = ... 1186 # bias = ... 1187 # } 1188 # ... 1189 # } 1190 # module dropout1 { 1191 # attributes { 1192 # training = ... 1193 # } 1194 # ... 1195 # } 1196 # module dropout2 { 1197 # attributes { 1198 # training = ... 1199 # } 1200 # ... 1201 # } 1202 # module fc1 { 1203 # attributes { 1204 # weight = ... 1205 # bias = ... 1206 # } 1207 # ... 1208 # } 1209 # module fc2 { 1210 # attributes { 1211 # weight = ... 1212 # bias = ... 1213 # } 1214 # ... 
1215 # } 1216 self.assertFalse(mTrain_freezed.hasattr("training")) 1217 self.assertTrue(mTrain_freezed.hasattr("conv1")) 1218 self.assertFalse(mTrain_freezed.conv1.hasattr("training")) 1219 self.assertTrue(mTrain_freezed.conv1.hasattr("weight")) 1220 self.assertTrue(mTrain_freezed.conv1.hasattr("bias")) 1221 self.assertTrue(mTrain_freezed.hasattr("conv2")) 1222 self.assertFalse(mTrain_freezed.conv2.hasattr("training")) 1223 self.assertTrue(mTrain_freezed.conv2.hasattr("weight")) 1224 self.assertTrue(mTrain_freezed.conv2.hasattr("bias")) 1225 self.assertTrue(mTrain_freezed.hasattr("dropout1")) 1226 self.assertTrue(mTrain_freezed.dropout1.hasattr("training")) 1227 self.assertTrue(mTrain_freezed.hasattr("dropout2")) 1228 self.assertTrue(mTrain_freezed.dropout2.hasattr("training")) 1229 self.assertTrue(mTrain_freezed.hasattr("fc1")) 1230 self.assertTrue(mTrain_freezed.fc1.hasattr("weight")) 1231 self.assertTrue(mTrain_freezed.fc1.hasattr("bias")) 1232 self.assertTrue(mTrain_freezed.hasattr("fc2")) 1233 self.assertTrue(mTrain_freezed.fc2.hasattr("weight")) 1234 self.assertTrue(mTrain_freezed.fc2.hasattr("bias")) 1235 model.eval() 1236 mEval_freezed = torch._C._freeze_module(model._c) 1237 self.assertFalse(mEval_freezed.hasattr("conv1")) 1238 self.assertFalse(mEval_freezed.hasattr("conv2")) 1239 self.assertFalse(mEval_freezed.hasattr("dropout1")) 1240 self.assertFalse(mEval_freezed.hasattr("training")) 1241 self.assertFalse(mEval_freezed.hasattr("fc1")) 1242 self.assertFalse(mEval_freezed.hasattr("dropout2")) 1243 self.assertFalse(mEval_freezed.hasattr("fc2")) 1244 with self.assertRaisesRegex( 1245 AttributeError, "does not have a field with name 'state_dict'" 1246 ): 1247 print(mEval_freezed.state_dict()) 1248 buffer = io.BytesIO() 1249 torch.jit.save(mEval_freezed, buffer) 1250 buffer.seek(0) 1251 m = torch.jit.load(buffer) 1252 FileCheck().check_not("GetAttr[name=").run(m._c._get_method("forward").graph) 1253 m2 = torch._C._freeze_module(model._c, preserveParameters=True) 1254 self.assertTrue(m2.hasattr("conv1")) 1255 self.assertTrue(m2.hasattr("conv2")) 1256 self.assertFalse(m2.hasattr("dropout1")) 1257 self.assertFalse(m2.hasattr("training")) 1258 self.assertTrue(m2.hasattr("fc1")) 1259 self.assertFalse(m2.hasattr("dropout2")) 1260 self.assertTrue(m2.hasattr("fc2")) 1261 1262 def test_freeze_module_detach_gradient(self): 1263 mod = nn.Conv2d(8, 3, 4, 2, 1) 1264 self.assertTrue(mod.weight.requires_grad) 1265 smod = torch.jit.script(mod) 1266 smod.eval() 1267 fmod = torch._C._freeze_module(smod._c) 1268 self.assertTrue(mod.weight.requires_grad) 1269 self.assertTrue(smod.weight.requires_grad) 1270 self.assertFalse(fmod.hasattr("weight")) 1271 inp = torch.ones(1, 8, 32, 32) 1272 out1 = fmod.forward(inp) 1273 # FIXME: frozen module mutated from outside (original module). 
1274 with torch.no_grad(): 1275 smod.weight[0, 0, 0, 0] += 100.0 1276 out2 = fmod.forward(inp) 1277 out3 = smod(inp) 1278 self.assertNotEqual(out1, out2) 1279 self.assertEqual(out2, out3) 1280 1281 def test_freeze_module_with_user_preserved_attr(self): 1282 class Module(nn.Module): 1283 def __init__(self) -> None: 1284 super().__init__() 1285 self.a = torch.tensor([1.1]) 1286 self.b = torch.tensor([2.2]) 1287 1288 def forward(self, x): 1289 return self.a + self.b 1290 1291 m = torch.jit.script(Module()) 1292 m.eval() 1293 fm = torch._C._freeze_module(m._c, ["a"]) 1294 # Attribute "a" is preserved 1295 self.assertTrue(fm.hasattr("a")) 1296 self.assertFalse(fm.hasattr("b")) 1297 1298 def test_freeze_module_with_user_preserved_method(self): 1299 class Module(nn.Module): 1300 def __init__(self) -> None: 1301 super().__init__() 1302 self.a = torch.tensor([1.1]) 1303 self.b = torch.tensor([2.2]) 1304 1305 def forward(self, x): 1306 return self.a + self.b 1307 1308 @torch.jit.export 1309 def modify_a(self, x): 1310 self.a[0] += 10 1311 return self.b 1312 1313 @torch.jit.export 1314 def modify_b(self, x): 1315 self.b[0] += 20 1316 return self.a 1317 1318 m = torch.jit.script(Module()) 1319 m.eval() 1320 fm = torch._C._freeze_module(m._c, ["modify_a"]) 1321 # Both attribute "a" and method "modify_a" are preserved 1322 self.assertTrue(fm.hasattr("a")) 1323 self.assertFalse(fm.hasattr("b")) 1324 input = torch.randn(2, 2) 1325 expected = m.forward(input) 1326 out = fm.forward(input) 1327 self.assertEqual(out, expected) 1328 1329 def test_freeze_module_with_user_preserved_method2(self): 1330 class Module(nn.Module): 1331 def __init__(self) -> None: 1332 super().__init__() 1333 self.a = torch.tensor([1.1]) 1334 self.b = torch.tensor([2.2]) 1335 1336 def forward(self, x): 1337 self.b += 10 1338 return self.a + self.b 1339 1340 @torch.jit.export 1341 def modify_a(self, x): 1342 self.a[0] += 10 1343 return self.b + self.a 1344 1345 m = torch.jit.script(Module()) 1346 m.eval() 1347 fm = torch._C._freeze_module(m._c, ["modify_a"]) 1348 FileCheck().check('prim::GetAttr[name="a"]').run(fm.forward.graph) 1349 FileCheck().check('prim::GetAttr[name="b"]').run(fm.modify_a.graph) 1350 1351 def test_freeze_module_with_user_preserved_attribute_on_submodule(self): 1352 class SubModule(nn.Module): 1353 def __init__(self) -> None: 1354 super().__init__() 1355 self.a = 1 1356 self.b = 2 1357 1358 def forward(self): 1359 return self.a + self.b 1360 1361 class Module(nn.Module): 1362 def __init__(self) -> None: 1363 super().__init__() 1364 self.sub1 = SubModule() 1365 self.sub2 = SubModule() 1366 1367 def forward(self): 1368 return self.sub1() + self.sub2() 1369 1370 m = torch.jit.script(Module()) 1371 m.eval() 1372 m = torch.jit.freeze(m, preserved_attrs=["sub1.a", "sub2.a"]) 1373 fm = m._c 1374 1375 self.assertTrue(fm.hasattr("sub1")) 1376 self.assertTrue(fm.sub1.hasattr("a")) 1377 self.assertFalse(fm.sub1.hasattr("b")) 1378 self.assertTrue(fm.hasattr("sub2")) 1379 self.assertTrue(fm.sub2.hasattr("a")) 1380 self.assertFalse(fm.sub2.hasattr("b")) 1381 self.assertEqual(m(), 6) 1382 m.sub1.a += 1 1383 self.assertEqual(m(), 7) 1384 1385 def test_freeze_module_with_user_preserved_attribute_on_unused_submodule(self): 1386 class SubModule(nn.Module): 1387 def __init__(self) -> None: 1388 super().__init__() 1389 self.a = 1 1390 self.b = 2 1391 1392 def forward(self): 1393 return self.a + self.b 1394 1395 @torch.jit.export 1396 def method_a(self): 1397 return 42 1398 1399 class Module(nn.Module): 1400 def __init__(self) -> None: 
1401 super().__init__() 1402 self.sub = SubModule() 1403 1404 def forward(self): 1405 return 1 1406 1407 m = torch.jit.script(Module()) 1408 m.eval() 1409 fm = torch.jit.freeze(m, preserved_attrs=["sub.a", "sub.method_a"])._c 1410 1411 self.assertTrue(fm.hasattr("sub")) 1412 self.assertTrue(fm.sub.hasattr("a")) 1413 self.assertFalse(fm.sub.hasattr("b")) 1414 self.assertTrue(fm.sub._has_method("method_a")) 1415 1416 def test_freeze_module_with_user_preserved_method_on_submodule(self): 1417 class SubModule(nn.Module): 1418 def forward(self, x): 1419 return self.method_a(x) + self.method_b(x) 1420 1421 def method_a(self, x): 1422 return x * x 1423 1424 def method_b(self, x): 1425 return x + x 1426 1427 class Module(nn.Module): 1428 def __init__(self) -> None: 1429 super().__init__() 1430 self.sub = SubModule() 1431 1432 def forward(self, x): 1433 return self.sub(x) 1434 1435 m = torch.jit.script(Module()) 1436 m.eval() 1437 fm = torch.jit.freeze(m, preserved_attrs=["sub.method_a"])._c 1438 1439 self.assertTrue(fm.hasattr("sub")) 1440 self.assertTrue(fm.sub._has_method("method_a")) 1441 self.assertFalse(fm.sub._has_method("method_b")) 1442 1443 @skipIfNoFBGEMM 1444 def test_module_with_shared_type_instances(self): 1445 class Child(nn.Module): 1446 def __init__(self) -> None: 1447 super().__init__() 1448 self.conv1 = nn.Conv2d(1, 1, 1).to(dtype=torch.float32) 1449 1450 def forward(self, x): 1451 x = self.conv1(x) 1452 return x 1453 1454 class Parent(nn.Module): 1455 def __init__(self) -> None: 1456 super().__init__() 1457 self.quant = torch.ao.quantization.QuantStub() 1458 self.conv1 = nn.Conv2d(1, 1, 1).to(dtype=torch.float32) 1459 self.child = Child() 1460 self.child2 = Child() 1461 self.dequant = torch.ao.quantization.DeQuantStub() 1462 1463 def forward(self, x): 1464 x = self.quant(x) 1465 x = self.conv1(x) 1466 x = self.child(x) 1467 x = self.child2(x) 1468 x = self.dequant(x) 1469 return x 1470 1471 def _static_quant(model): 1472 qModel = torch.ao.quantization.QuantWrapper(model) 1473 qModel.qconfig = torch.ao.quantization.default_qconfig 1474 torch.ao.quantization.prepare(qModel, inplace=True) 1475 qModel(torch.rand(4, 1, 4, 4, dtype=torch.float32)) 1476 torch.ao.quantization.convert(qModel, inplace=True) 1477 return model 1478 1479 with override_quantized_engine("fbgemm"): 1480 data = torch.randn(4, 1, 4, 4, dtype=torch.float32) 1481 m = Parent().to(torch.float32) 1482 m = _static_quant(m) 1483 m = torch.jit.script(m) 1484 m.eval() 1485 torch._C._jit_pass_inline(m.graph) 1486 m_frozen = wrap_cpp_module(torch._C._freeze_module(m._c)) 1487 # Earlier bug resulted in _packed_params set to false. 1488 FileCheck().check_not("_packed_params = False").run( 1489 m_frozen._c.dump_to_str(True, True, False) 1490 ) 1491 1492 m_res = m(data) 1493 # It used to segfault while running frozen module. 
1494 m_frozen_res = m_frozen(data) 1495 self.assertEqual(m_res, m_frozen_res) 1496 1497 def test_module_getattr_indirection(self): 1498 @torch.jit.script 1499 class ValHolder: 1500 def __init__(self, val: int): 1501 self.val: int = val 1502 1503 class Mod(nn.Module): 1504 def __init__(self) -> None: 1505 super().__init__() 1506 self.mod1 = ValHolder(1) 1507 self.mod2 = ValHolder(2) 1508 1509 def forward(self, cond: bool): 1510 if cond: 1511 mod = self.mod1 1512 else: 1513 mod = self.mod2 1514 return mod.val 1515 1516 mod = Mod() 1517 mod.eval() 1518 frozen_mod = torch.jit.freeze(torch.jit.script(mod)) 1519 mod_eager = Mod() 1520 self.assertEqual(mod_eager(True), frozen_mod(True)) 1521 self.assertEqual(mod_eager(False), frozen_mod(False)) 1522 1523 def test_freeze_module_with_non_static_module_container_index(self): 1524 """ 1525 Test that Modules containing non-static ModuleDict or ModuleList 1526 indexing cannot be frozen. 1527 """ 1528 1529 @torch.jit.interface 1530 class ModuleInterface(torch.nn.Module): 1531 def forward(self, inp: Any) -> Any: 1532 pass 1533 1534 class ImplementsInterface(torch.nn.Module): 1535 def forward(self, inp: Any) -> Any: 1536 if isinstance(inp, torch.Tensor): 1537 return torch.max(inp, dim=0) 1538 1539 return inp 1540 1541 class ModWithDict(torch.nn.Module): 1542 def __init__(self) -> None: 1543 super().__init__() 1544 self.d = torch.nn.ModuleDict({"module": ImplementsInterface()}) 1545 1546 def forward(self, x: torch.Tensor, key: str) -> Any: 1547 value: ModuleInterface = self.d[key] 1548 return value.forward(x) 1549 1550 m = torch.jit.script(ModWithDict()) 1551 m.eval() 1552 with self.assertRaisesRegex( 1553 RuntimeError, 1554 "Freezing modules containing prim::ModuleContainerIndex is not supported", 1555 ): 1556 mf = torch._C._freeze_module(m._c) 1557 1558 class ModWithList(torch.nn.Module): 1559 def __init__(self) -> None: 1560 super().__init__() 1561 self.l = torch.nn.ModuleList([ImplementsInterface()]) 1562 1563 def forward(self, x: torch.Tensor, idx: int) -> Any: 1564 value: ModuleInterface = self.l[idx] 1565 return value.forward(x) 1566 1567 m = torch.jit.script(ModWithList()) 1568 m.eval() 1569 with self.assertRaisesRegex( 1570 RuntimeError, 1571 "Freezing modules containing prim::ModuleContainerIndex is not supported", 1572 ): 1573 mf = torch._C._freeze_module(m._c) 1574 1575 def test_freeze_with_interface_mutable(self): 1576 @torch.jit.interface 1577 class ModuleInterface(torch.nn.Module): 1578 def forward(self, inp: torch.Tensor) -> torch.Tensor: 1579 pass 1580 1581 class ImplementsInterface(torch.nn.Module): 1582 def __init__(self) -> None: 1583 super().__init__() 1584 self.sum = torch.zeros((2, 2)) 1585 1586 def forward(self, inp: torch.Tensor) -> torch.Tensor: 1587 self.sum += inp.relu() 1588 return self.sum 1589 1590 class WrapperModule(torch.nn.Module): 1591 impl: ModuleInterface 1592 1593 def __init__(self) -> None: 1594 super().__init__() 1595 self.impl = ImplementsInterface() 1596 1597 def forward(self, x: torch.Tensor) -> torch.Tensor: 1598 return self.impl.forward(x) 1599 1600 m = torch.jit.script(WrapperModule()) 1601 m.eval() 1602 m_frozen = torch.jit.freeze(m) 1603 1604 x = torch.rand((2, 2)) 1605 1606 m_frozen(x) 1607 self.assertEqual(m_frozen.impl.sum, x.relu()) 1608 1609 def test_freeze_with_swapping_interfaces(self): 1610 @torch.jit.interface 1611 class ModuleInterface(torch.nn.Module): 1612 def forward(self, inp: torch.Tensor) -> torch.Tensor: 1613 pass 1614 1615 class Implementation1(torch.nn.Module): 1616 def forward(self, inp: 
torch.Tensor) -> torch.Tensor: 1617 return inp.relu() 1618 1619 class Implementation2(torch.nn.Module): 1620 def forward(self, inp: torch.Tensor) -> torch.Tensor: 1621 return inp.sin() 1622 1623 class WrapperModule(torch.nn.Module): 1624 impl: ModuleInterface 1625 1626 def __init__(self) -> None: 1627 super().__init__() 1628 self.option1 = Implementation1() 1629 self.option2 = Implementation2() 1630 self.impl = self.option1 1631 self.idx = 0 1632 1633 def forward(self, x: torch.Tensor) -> torch.Tensor: 1634 self.idx += 1 1635 if self.idx % 2 == 1: 1636 self.impl = self.option1 1637 else: 1638 self.impl = self.option2 1639 return self.impl(x) 1640 1641 m = torch.jit.script(WrapperModule()) 1642 m.eval() 1643 with self.assertRaisesRegex( 1644 RuntimeError, "Freezing does not support SetAttr on an interface type" 1645 ): 1646 m_frozen = torch.jit.freeze(m) 1647 1648 def test_freeze_recursive_interfaces(self): 1649 @torch.jit.interface 1650 class InnerInterface(torch.nn.Module): 1651 def forward(self, inp: torch.Tensor) -> torch.Tensor: 1652 pass 1653 1654 @torch.jit.interface 1655 class OuterInterface(torch.nn.Module): 1656 def forward(self, inp: torch.Tensor) -> torch.Tensor: 1657 pass 1658 1659 class InnerImpl(torch.nn.Module): 1660 def __init__(self) -> None: 1661 super().__init__() 1662 self.x = torch.ones((2, 2)) 1663 1664 def forward(self, inp): 1665 return inp.cos() * self.x 1666 1667 class OuterImpl(torch.nn.Module): 1668 inner_impl: InnerInterface 1669 1670 def __init__(self) -> None: 1671 super().__init__() 1672 self.inner_impl = InnerImpl() 1673 1674 def forward(self, inp): 1675 return inp.relu() + self.inner_impl(inp.sin()) 1676 1677 class WrapperModule(torch.nn.Module): 1678 outer_impl: OuterInterface 1679 1680 def __init__(self) -> None: 1681 super().__init__() 1682 self.outer_impl = OuterImpl() 1683 1684 def forward(self, inp): 1685 return self.outer_impl(inp) + inp 1686 1687 m = WrapperModule() 1688 x = torch.rand((2, 2)) 1689 expected = m(x) 1690 1691 m_s = torch.jit.script(m) 1692 m_s.eval() 1693 m_s = torch.jit.freeze(m_s) 1694 actual = m_s(x) 1695 1696 self.assertEqual(expected, actual) 1697 1698 def test_freeze_recursive_interfaces_with_reassignment(self): 1699 @torch.jit.interface 1700 class InnerInterface(torch.nn.Module): 1701 def forward(self, inp: torch.Tensor) -> torch.Tensor: 1702 pass 1703 1704 @torch.jit.interface 1705 class OuterInterface(torch.nn.Module): 1706 def forward(self, inp: torch.Tensor) -> torch.Tensor: 1707 pass 1708 1709 class InnerImpl1(torch.nn.Module): 1710 def __init__(self) -> None: 1711 super().__init__() 1712 self.x = torch.ones((2, 2)) 1713 1714 def forward(self, inp): 1715 return inp.cos() * self.x 1716 1717 class InnerImpl2(torch.nn.Module): 1718 def __init__(self) -> None: 1719 super().__init__() 1720 self.x = torch.ones((2, 2)) * 2 1721 1722 def forward(self, inp): 1723 return inp.sin() / self.x 1724 1725 class OuterImpl(torch.nn.Module): 1726 inner_impl: InnerInterface 1727 1728 def __init__(self) -> None: 1729 super().__init__() 1730 self.inner_impl = InnerImpl1() 1731 self.impl1 = InnerImpl1() 1732 self.impl2 = InnerImpl1() 1733 self.idx = 0 1734 1735 def forward(self, inp): 1736 self.idx += 1 1737 if self.idx % 2 == 0: 1738 self.inner_impl = self.impl1 1739 else: 1740 self.inner_impl = self.impl2 1741 return inp.relu() + self.inner_impl(inp.sin()) 1742 1743 class WrapperModule(torch.nn.Module): 1744 outer_impl: OuterInterface 1745 1746 def __init__(self) -> None: 1747 super().__init__() 1748 self.outer_impl = OuterImpl() 1749 1750 
def forward(self, inp): 1751 return self.outer_impl(inp) + inp 1752 1753 m = WrapperModule() 1754 1755 m_s = torch.jit.script(m) 1756 m_s.eval() 1757 with self.assertRaisesRegex( 1758 RuntimeError, "Freezing does not support SetAttr on an interface type" 1759 ): 1760 m_s = torch.jit.freeze(m_s) 1761 1762 def test_freeze_interface_swapping_two_methods(self): 1763 @torch.jit.interface 1764 class MyInterface(torch.nn.Module): 1765 def forward(self, inp: torch.Tensor) -> torch.Tensor: 1766 pass 1767 1768 class Impl1(torch.nn.Module): 1769 def forward(self, inp): 1770 return inp.cos() 1771 1772 class Impl2(torch.nn.Module): 1773 def forward(self, inp): 1774 return inp.sin() 1775 1776 class WrapperModule1(torch.nn.Module): 1777 interface_impl: MyInterface 1778 1779 def __init__(self) -> None: 1780 super().__init__() 1781 self.interface_impl = Impl1() 1782 self.impl1 = Impl1() 1783 self.impl2 = Impl2() 1784 self.idx = 0 1785 1786 def forward(self, x): 1787 return self.interface_impl(x) 1788 1789 @torch.jit.export 1790 def other_method(self, x): 1791 self.idx += 1 1792 if self.idx % 2 == 0: 1793 self.interface_impl = self.impl1 1794 else: 1795 self.interface_impl = self.impl2 1796 return self.interface_impl(x) 1797 1798 class WrapperModule2(torch.nn.Module): 1799 interface_impl: MyInterface 1800 1801 def __init__(self) -> None: 1802 super().__init__() 1803 self.interface_impl = Impl1() 1804 self.impl1 = Impl1() 1805 self.impl2 = Impl2() 1806 self.idx = 0 1807 1808 def forward(self, x): 1809 self.idx += 1 1810 if self.idx % 2 == 0: 1811 self.interface_impl = self.impl1 1812 else: 1813 self.interface_impl = self.impl2 1814 return self.interface_impl(x) 1815 1816 @torch.jit.export 1817 def other_method(self, x): 1818 return self.interface_impl(x) 1819 1820 m1 = torch.jit.script(WrapperModule1()) 1821 m2 = torch.jit.script(WrapperModule2()) 1822 1823 m1.eval() 1824 m2.eval() 1825 1826 with self.assertRaisesRegex( 1827 RuntimeError, "Freezing does not support SetAttr on an interface type" 1828 ): 1829 torch.jit.freeze(m1, preserved_attrs=["other_method"]) 1830 1831 with self.assertRaisesRegex( 1832 RuntimeError, "Freezing does not support SetAttr on an interface type" 1833 ): 1834 torch.jit.freeze(m2, preserved_attrs=["other_method"]) 1835 1836 def test_freeze_recursive_interfaces_same_name(self): 1837 @torch.jit.interface 1838 class InnerInterface(torch.nn.Module): 1839 def forward(self, inp: torch.Tensor) -> torch.Tensor: 1840 pass 1841 1842 @torch.jit.interface 1843 class OuterInterface(torch.nn.Module): 1844 def forward(self, inp: torch.Tensor) -> torch.Tensor: 1845 pass 1846 1847 class InnerImpl(torch.nn.Module): 1848 def __init__(self) -> None: 1849 super().__init__() 1850 self.x = torch.ones((2, 2)) 1851 1852 def forward(self, inp): 1853 return inp.cos() * self.x 1854 1855 class OuterImpl(torch.nn.Module): 1856 impl: InnerInterface 1857 1858 def __init__(self) -> None: 1859 super().__init__() 1860 self.impl = InnerImpl() 1861 self.x = torch.ones((2, 2)) * 5 1862 1863 def forward(self, inp): 1864 return self.other_method(inp) 1865 1866 def other_method(self, inp): 1867 return inp.relu() + self.impl(inp.sin()) + self.x 1868 1869 class WrapperModule(torch.nn.Module): 1870 impl: OuterInterface 1871 1872 def __init__(self) -> None: 1873 super().__init__() 1874 self.impl = OuterImpl() 1875 1876 def forward(self, inp): 1877 return self.impl(inp) + inp 1878 1879 m = WrapperModule() 1880 x = torch.rand((2, 2)) 1881 expected = m(x) 1882 1883 m_s = torch.jit.script(m) 1884 m_s.eval() 1885 m_s = 
torch.jit.freeze(m_s) 1886 actual = m_s(x) 1887 1888 self.assertEqual(expected, actual) 1889 1890 def test_freeze_non_interface_module_swap(self): 1891 class InnerModule(torch.nn.Module): 1892 def __init__(self, x): 1893 super().__init__() 1894 self.x = x 1895 1896 def forward(self, inp: torch.Tensor) -> torch.Tensor: 1897 return inp.relu() + self.x 1898 1899 class WrapperModule(torch.nn.Module): 1900 def __init__(self) -> None: 1901 super().__init__() 1902 self.option1 = InnerModule(torch.rand((2, 2))) 1903 self.option2 = InnerModule(torch.rand((2, 2))) 1904 self.impl = self.option1 1905 self.idx = 0 1906 1907 def forward(self, x: torch.Tensor) -> torch.Tensor: 1908 self.idx += 1 1909 if self.idx % 2 == 1: 1910 self.impl = self.option1 1911 else: 1912 self.impl = self.option2 1913 return self.impl(x) 1914 1915 unfrozen = WrapperModule() 1916 m = torch.jit.script(unfrozen) 1917 m.eval() 1918 m_frozen = torch.jit.freeze(m) 1919 1920 x = torch.rand((2, 2)) 1921 expected = unfrozen(x) 1922 actual = m_frozen(x) 1923 self.assertEqual(expected, actual) 1924 1925 @unittest.expectedFailure 1926 def test_freeze_interface_within_object(self): 1927 # I don't think there's any way to create a plain python object that 1928 # contains a torch.nn.Module inside it, but just in case... I'm not 1929 # sure freezing would handle this case correctly, so marking as xfail 1930 # so that if this ever _does_ start working someone will need to 1931 # investigate to make sure this is handled correctly. 1932 class MyIface(torch.nn.Module): 1933 def forward(self, inp: torch.Tensor) -> torch.Tensor: 1934 pass 1935 1936 class MyImpl(torch.nn.Module): 1937 def forward(self, inp: torch.Tensor) -> torch.Tensor: 1938 return inp.sin() 1939 1940 class MyObject: 1941 impl: MyIface 1942 1943 def run(self, x): 1944 return self.impl(x) 1945 1946 class WrapperModule(torch.nn.Module): 1947 impl: MyObject 1948 1949 def __init__(self) -> None: 1950 super().__init__() 1951 self.impl = MyObject() 1952 self.impl.impl = MyImpl() 1953 1954 def forward(self, x: torch.Tensor) -> torch.Tensor: 1955 return self.impl(x) 1956 1957 unfrozen = WrapperModule() 1958 m = torch.jit.script(unfrozen) 1959 m.eval() 1960 m_frozen = torch.jit.freeze(m) 1961 1962 x = torch.rand((2, 2)) 1963 expected = unfrozen(x) 1964 actual = m_frozen(x) 1965 self.expectEqual(expected, actual) 1966 1967 def test_freeze_non_module_class_getattr(self): 1968 class BoxCoder: 1969 def __init__(self, bbox_xform_clip): 1970 # type: (float) -> None 1971 self.bbox_xform_clip = bbox_xform_clip 1972 1973 def decode(self, input): 1974 return input * self.bbox_xform_clip 1975 1976 class MyModule(torch.nn.Module): 1977 __annotations__ = { 1978 "box_coder": BoxCoder, 1979 } 1980 1981 def __init__(self) -> None: 1982 super().__init__() 1983 self.box_coder = BoxCoder(50.0) 1984 1985 def forward(self, input): 1986 return self.box_coder.decode(input) 1987 1988 model = MyModule() 1989 model.eval() 1990 script_model = torch.jit.freeze(torch.jit.script(model)) 1991 inp = torch.randn([4, 4]) 1992 output_eager = model(inp) 1993 self.assertEqual(model(inp), script_model(inp)) 1994 FileCheck().check_not("GetAttr").run(script_model.graph) 1995 1996 def test_freeze_module_with_tupleoutput_submodule(self): 1997 class SubModule(nn.Module): 1998 def forward(self, x): 1999 return (x + 1, x + 2) 2000 2001 class TestModule(nn.Module): 2002 def __init__(self) -> None: 2003 super().__init__() 2004 self.sub = SubModule() 2005 2006 def forward(self, x): 2007 y1, y2 = self.sub(x) 2008 return y1 + y2 2009 2010 
m = torch.jit.script(TestModule()) 2011 m = m.eval() 2012 mf = torch.jit.freeze(m) 2013 inp = torch.randn(2, 2) 2014 expected = m.forward(inp) 2015 output = mf.forward(inp) 2016 # Check if prim::TupleConstruct and prim::TupleUnpack 2017 # Don't exist in frozen graph 2018 FileCheck().check_not("prim::TupleConstruct").run(mf.graph) 2019 FileCheck().check_not("prim::TupleUnpack").run(mf.graph) 2020 self.assertEqual(output, expected) 2021 2022 def test_freeze_module_with_call_method(self): 2023 class Mod(nn.Module): 2024 def __init__(self, val): 2025 super().__init__() 2026 self.param = nn.Parameter(val) 2027 2028 def forward(self, x): 2029 # this method will change during freezing 2030 return x + self.param 2031 2032 @torch.jit.export 2033 def make_prediction(self, x): 2034 y = x + x 2035 return self.forward(y) 2036 2037 param = torch.rand([2, 2]) 2038 x = torch.rand([2, 2]) 2039 2040 unscripted_mod = Mod(param) 2041 mod = torch.jit.script(unscripted_mod) 2042 mod.eval() 2043 mod = torch.jit.freeze(mod, preserved_attrs=["make_prediction"]) 2044 2045 self.assertEqual( 2046 mod.forward(x), unscripted_mod.forward(x), atol=1e-5, rtol=1e-5 2047 ) 2048 2049 2050@skipIfTorchDynamo("somehow causing hanging during python shutdown") 2051class TestFrozenOptimizations(JitTestCase): 2052 def setUp(self): 2053 super().setUp() 2054 self.default_dtype = torch.get_default_dtype() 2055 torch.set_default_dtype(torch.double) 2056 2057 def tearDown(self): 2058 torch.set_default_dtype(self.default_dtype) 2059 super().tearDown() 2060 2061 def test_conv_bn_folding(self): 2062 conv_bias = [True, False] 2063 module_pairs = [ 2064 (nn.Conv1d, nn.BatchNorm1d), 2065 (nn.Conv2d, nn.BatchNorm2d), 2066 (nn.Conv3d, nn.BatchNorm3d), 2067 ] 2068 use_tracing = [True, False] 2069 bn_running_stats = [True, False] 2070 2071 for use_bias, modules, tracing, track_stats in product( 2072 conv_bias, module_pairs, use_tracing, bn_running_stats 2073 ): 2074 2075 class ConvBN(torch.nn.Module): 2076 def __init__(self, in_channels, out_channels, **kwargs): 2077 super().__init__() 2078 self.conv = modules[0]( 2079 in_channels, out_channels, bias=use_bias, **kwargs 2080 ) 2081 self.bn = modules[1]( 2082 out_channels, eps=0.001, track_running_stats=track_stats 2083 ) 2084 2085 def forward(self, x): 2086 x = self.conv(x) 2087 return self.bn(x) 2088 2089 mod_eager = ConvBN(3, 32, kernel_size=3, stride=2).eval() 2090 inps = [4, 3, 4] 2091 if modules[0] == nn.Conv2d: 2092 inps.append(inps[-1]) 2093 if modules[0] == nn.Conv3d: 2094 inps.append(inps[-1]) 2095 inps.append(inps[-1]) 2096 2097 inp = torch.rand(inps) 2098 2099 if tracing: 2100 scripted_mod = torch.jit.trace(mod_eager, (inp)) 2101 else: 2102 scripted_mod = torch.jit.script(mod_eager) 2103 2104 self.run_pass("inline", scripted_mod.graph) 2105 self.run_pass("peephole", scripted_mod.graph) 2106 self.run_pass("constant_propagation", scripted_mod.graph) 2107 2108 FileCheck().check("conv").check("batch").run(scripted_mod.graph) 2109 # successfully no-ops with non-const inputs 2110 self.run_pass("fold_frozen_conv_bn", scripted_mod.graph) 2111 FileCheck().check("conv").check("aten::batch_norm").run(scripted_mod.graph) 2112 2113 scripted_mod = torch.jit.freeze(scripted_mod) 2114 self.run_pass("fold_frozen_conv_bn", scripted_mod.graph) 2115 if track_stats: 2116 FileCheck().check("conv").check_not("aten::batch_norm").run( 2117 scripted_mod.graph 2118 ) 2119 else: 2120 FileCheck().check("conv").check("aten::batch_norm").run( 2121 scripted_mod.graph 2122 ) 2123 2124 self.assertEqual(mod_eager(inp), 
scripted_mod(inp)) 2125 self.assertEqual(mod_eager(inp), scripted_mod(inp)) 2126 2127 def test_conv_bn_folding_not_forward(self): 2128 class ConvBN(torch.nn.Module): 2129 def __init__(self, in_channels, out_channels, **kwargs): 2130 super().__init__() 2131 self.conv = torch.nn.Conv2d( 2132 in_channels, out_channels, bias=True, **kwargs 2133 ) 2134 self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001) 2135 self.amt = 3.2 2136 2137 def forward(self, x): 2138 x = self.conv(x) 2139 return self.bn(x) 2140 2141 @torch.jit.export 2142 def make_prediction(self, x): 2143 return self.forward(x) + self.amt 2144 2145 mod_eager = ConvBN(3, 32, kernel_size=3, stride=2).eval() 2146 scripted_mod = torch.jit.script(mod_eager) 2147 torch._C._jit_pass_inline(scripted_mod.make_prediction.graph) 2148 FileCheck().check("conv").check("aten::batch_norm").run( 2149 scripted_mod.make_prediction.graph 2150 ) 2151 2152 # _jit_pass_optimize_frozen_graph should not be called on non-method attributes (e.g. "amt") 2153 scripted_mod = torch.jit.freeze( 2154 scripted_mod, preserved_attrs=["make_prediction", "amt"] 2155 ) 2156 FileCheck().check("conv").check_not("aten::batch_norm").run( 2157 scripted_mod.make_prediction.graph 2158 ) 2159 2160 # During freezing this creates tensors constants that are attached to the frozen graph, 2161 # which is then kept alive by the compilation unit (which causes a leak) 2162 @skipCUDAMemoryLeakCheckIf(True) 2163 @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU") 2164 def test_conv_bn_folding_autocast_scenario_cuda(self): 2165 # CUDA conv takes input tensors which must all be the same dtype, 2166 # which can cause issues if folding produces inputs of different dtypes. 2167 2168 class ConvBN(torch.nn.Module): 2169 def __init__(self, in_channels, out_channels, **kwargs): 2170 super().__init__() 2171 self.conv = torch.nn.Conv2d( 2172 in_channels, out_channels, bias=False, dtype=torch.half, **kwargs 2173 ) 2174 self.bn = torch.nn.BatchNorm2d( 2175 out_channels, eps=0.001, dtype=torch.float 2176 ) 2177 2178 def forward(self, x): 2179 return self.bn(self.conv(x)) 2180 2181 mod_eager = ConvBN(3, 32, kernel_size=3, stride=2).cuda().eval() 2182 scripted_mod = torch.jit.script(mod_eager) 2183 scripted_mod = torch.jit.freeze(scripted_mod) 2184 FileCheck().check("conv").check_not("aten::batch_norm").run(scripted_mod.graph) 2185 conv_node = scripted_mod.graph.findNode("aten::conv2d", True) 2186 self.assertTrue(conv_node is not None) 2187 bias_input = conv_node.namedInput("bias") 2188 self.assertTrue(bias_input is not None) 2189 self.assertTrue(bias_input.type().dtype() == torch.half) 2190 2191 x = torch.rand((3, 3, 32, 32), dtype=torch.half).cuda() 2192 2193 self.assertEqual(mod_eager(x), scripted_mod(x), atol=1e-2, rtol=1e-2) 2194 self.assertEqual(mod_eager(x), scripted_mod(x), atol=1e-2, rtol=1e-2) 2195 2196 def test_conv_add_folding(self): 2197 @torch.no_grad() 2198 def test_conv_fusion( 2199 use_bias, module, tracing, op, scalar, add_tensor, expect_success 2200 ): 2201 class ConvOp(torch.nn.Module): 2202 __constants__ = ["use_scalar"] 2203 2204 def __init__(self, in_channels, out_channels, tensor=None, **kwargs): 2205 super().__init__() 2206 self.conv = module( 2207 in_channels, out_channels, bias=use_bias, **kwargs 2208 ) 2209 self.conv2 = module( 2210 in_channels, out_channels, bias=use_bias, **kwargs 2211 ) 2212 self.use_scalar = scalar 2213 tensor_size = [1 for _ in range(self.conv.weight.ndim)] 2214 tensor_size[1] = self.conv.weight.size(0) 2215 self.tensor = ( 2216 
add_tensor 2217 if add_tensor is not None 2218 else torch.rand(tensor_size) 2219 ) 2220 self.op = op 2221 2222 def forward(self, x): 2223 x = self.conv(x) 2224 if self.use_scalar: 2225 return self.op(x, 2.0) 2226 else: 2227 return self.op(x, self.tensor) 2228 2229 mod_eager = ConvOp(3, 32, kernel_size=3, stride=2).eval() 2230 2231 inps = [4, 3, 4] 2232 if module == nn.Conv2d: 2233 inps.append(inps[-1]) 2234 if module == nn.Conv3d: 2235 inps.append(inps[-1]) 2236 inps.append(inps[-1]) 2237 2238 inp = torch.rand(inps) 2239 2240 if tracing: 2241 scripted_mod = torch.jit.trace(mod_eager, (inp,)) 2242 else: 2243 scripted_mod = torch.jit.script(mod_eager) 2244 2245 self.run_pass("inline", scripted_mod.graph) 2246 op_str = "aten::" + op.__name__ 2247 2248 FileCheck().check("conv").check(op_str).run(scripted_mod.graph) 2249 # successively no-ops with non-const inputs 2250 self.run_pass("fold_frozen_conv_mul_or_div", scripted_mod.graph) 2251 self.run_pass("fold_frozen_conv_add_or_sub", scripted_mod.graph) 2252 FileCheck().check("conv").check(op_str).run(scripted_mod.graph) 2253 scripted_mod = torch.jit.freeze(scripted_mod) 2254 self.run_pass("fold_frozen_conv_mul_or_div", scripted_mod.graph) 2255 self.run_pass("fold_frozen_conv_add_or_sub", scripted_mod.graph) 2256 2257 if expect_success: 2258 FileCheck().check("conv").check_not(op_str).run(scripted_mod.graph) 2259 else: 2260 FileCheck().check("conv").check(op_str).run(scripted_mod.graph) 2261 2262 self.assertEqual(mod_eager(inp), scripted_mod(inp)) 2263 self.assertEqual(mod_eager(inp), scripted_mod(inp)) 2264 2265 conv_bias = [True, False] 2266 modules = [nn.Conv1d, nn.Conv2d, nn.Conv3d] 2267 use_tracing = [False, True] 2268 use_scalar = [False, True] 2269 ops = [torch.add, torch.sub, torch.mul, torch.div] 2270 2271 for use_bias, module, tracing, pytorch_op, scalar in product( 2272 conv_bias, modules, use_tracing, ops, use_scalar 2273 ): 2274 test_conv_fusion( 2275 use_bias, 2276 module, 2277 tracing, 2278 pytorch_op, 2279 scalar, 2280 add_tensor=None, 2281 expect_success=True, 2282 ) 2283 2284 for use_bias, pytorch_op in product(conv_bias, ops): 2285 # broadcasting add 2286 test_conv_fusion( 2287 use_bias, 2288 nn.Conv2d, 2289 False, 2290 pytorch_op, 2291 False, 2292 add_tensor=torch.rand(32, 1, 32), 2293 expect_success=False, 2294 ) 2295 2296 # broadcasting add 2297 test_conv_fusion( 2298 use_bias, 2299 nn.Conv2d, 2300 False, 2301 pytorch_op, 2302 False, 2303 add_tensor=torch.rand(1, 1), 2304 expect_success=True, 2305 ) 2306 2307 # add with different dtype 2308 test_conv_fusion( 2309 use_bias, 2310 nn.Conv2d, 2311 False, 2312 pytorch_op, 2313 False, 2314 add_tensor=torch.tensor([2]).to(torch.int), 2315 expect_success=True, 2316 ) 2317 2318 def test_conv_mul_add_bn(self): 2319 class Conv_Mul_Add_Bn(nn.Module): 2320 def __init__(self, in_channels, out_channels, **kwargs): 2321 super().__init__() 2322 self.conv = nn.Conv2d(in_channels, out_channels, **kwargs) 2323 self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 2324 self.tensor1 = torch.tensor(2.2) 2325 self.tensor2 = torch.tensor(2) 2326 2327 def forward(self, x): 2328 return self.bn( 2329 torch.add(torch.mul(self.conv(x), self.tensor1), self.tensor2) 2330 ) 2331 2332 input = torch.randn(8, 3, 64, 64) 2333 model = Conv_Mul_Add_Bn(3, 32, kernel_size=3, stride=1).eval() 2334 2335 with torch.no_grad(): 2336 result = model(input) 2337 traced_model = torch.jit.trace(model, input).eval() 2338 traced_model = torch.jit.freeze(traced_model) 2339 tresult = traced_model(input) 2340 self.assertEqual(result, 
tresult) 2341 FileCheck().check("conv").check_not("aten::batch_norm").run( 2342 traced_model.graph 2343 ) 2344 FileCheck().check("conv").check_not("aten::add").run(traced_model.graph) 2345 2346 def test_linear_bn_folding(self): 2347 module_pairs = [ 2348 (nn.Linear, nn.BatchNorm1d), 2349 (nn.Linear, nn.BatchNorm2d), 2350 (nn.Linear, nn.BatchNorm3d), 2351 ] 2352 use_tracing = [True, False] 2353 bn_running_stats = [True, False] 2354 2355 for modules, tracing, track_stats in product( 2356 module_pairs, use_tracing, bn_running_stats 2357 ): 2358 2359 class LinearBN(torch.nn.Module): 2360 def __init__(self, in_features, out_features): 2361 super().__init__() 2362 self.linear = modules[0](in_features, out_features) 2363 self.bn = modules[1]( 2364 out_features, eps=0.001, track_running_stats=track_stats 2365 ) 2366 2367 def forward(self, x): 2368 x = self.linear(x) 2369 return self.bn(x) 2370 2371 mod_eager = LinearBN(32, 32).eval() 2372 2373 inps = [3, 32] 2374 if modules[1] == nn.BatchNorm2d: 2375 inps.append(inps[-1]) 2376 inps.append(inps[-1]) 2377 if modules[1] == nn.BatchNorm3d: 2378 inps.append(inps[-1]) 2379 inps.append(inps[-1]) 2380 inps.append(inps[-1]) 2381 2382 inp = torch.rand(inps) 2383 2384 if tracing: 2385 scripted_mod = torch.jit.trace(mod_eager, (inp)) 2386 else: 2387 scripted_mod = torch.jit.script(mod_eager) 2388 2389 self.run_pass("inline", scripted_mod.graph) 2390 self.run_pass("peephole", scripted_mod.graph) 2391 self.run_pass("constant_propagation", scripted_mod.graph) 2392 2393 FileCheck().check("linear").check("batch").run(scripted_mod.graph) 2394 # successfully no-ops with non-const inputs 2395 self.run_pass("fold_frozen_linear_bn", scripted_mod.graph) 2396 FileCheck().check("linear").check("aten::batch_norm").run( 2397 scripted_mod.graph 2398 ) 2399 2400 scripted_mod = torch.jit.freeze(scripted_mod) 2401 self.run_pass("fold_frozen_linear_bn", scripted_mod.graph) 2402 if track_stats: 2403 FileCheck().check("linear").check_not("aten::batch_norm").run( 2404 scripted_mod.graph 2405 ) 2406 else: 2407 FileCheck().check("linear").check("aten::batch_norm").run( 2408 scripted_mod.graph 2409 ) 2410 2411 self.assertEqual(mod_eager(inp), scripted_mod(inp)) 2412 self.assertEqual(mod_eager(inp), scripted_mod(inp)) 2413 2414 def test_bn_not_broadcast_with_linear(self): 2415 module_pairs = [ 2416 (nn.Linear, nn.BatchNorm1d), 2417 (nn.Linear, nn.BatchNorm2d), 2418 (nn.Linear, nn.BatchNorm3d), 2419 ] 2420 use_tracing = [True, False] 2421 linear_in = 3 2422 # (linear_out, bn_in) 2423 # case 1: linear_out < bn_in 2424 # case 2: linear_out > bn_in 2425 # case 3: linear_out != bn_in && linear_out = 1 2426 dims = [(2, 4), (4, 2), (1, 2)] 2427 2428 for modules, tracing, dim in product(module_pairs, use_tracing, dims): 2429 linear_out, bn_in = dim[0], dim[1] 2430 2431 linear = modules[0](linear_in, linear_out) 2432 bn = modules[1](bn_in) 2433 mod_eager = nn.Sequential(linear, bn).eval() 2434 2435 N, C = 3, bn_in 2436 input_shape = [N, C] 2437 if modules[1] == nn.BatchNorm1d: 2438 H = linear_in 2439 input_shape.append(H) 2440 elif modules[1] == nn.BatchNorm2d: 2441 H, W = 4, linear_in 2442 input_shape.append(H) 2443 input_shape.append(W) 2444 elif modules[1] == nn.BatchNorm3d: 2445 D, H, W = 4, 4, linear_in 2446 input_shape.append(D) 2447 input_shape.append(H) 2448 input_shape.append(W) 2449 2450 inp = torch.rand(input_shape) 2451 2452 if tracing: 2453 scripted_mod = torch.jit.trace(mod_eager, (inp)) 2454 else: 2455 scripted_mod = torch.jit.script(mod_eager) 2456 2457 self.run_pass("inline", 
scripted_mod.graph) 2458 self.run_pass("peephole", scripted_mod.graph) 2459 self.run_pass("constant_propagation", scripted_mod.graph) 2460 2461 FileCheck().check("linear").check("batch").run(scripted_mod.graph) 2462 self.run_pass("fold_frozen_linear_bn", scripted_mod.graph) 2463 FileCheck().check("linear").check("aten::batch_norm").run( 2464 scripted_mod.graph 2465 ) 2466 2467 frozen_mod = torch.jit.freeze(scripted_mod) 2468 self.run_pass("fold_frozen_linear_bn", frozen_mod.graph) 2469 # successfully skipped folding 2470 FileCheck().check("linear").check("aten::batch_norm").run(frozen_mod.graph) 2471 2472 self.assertEqual(mod_eager(inp), frozen_mod(inp)) 2473 self.assertEqual(mod_eager(inp), frozen_mod(inp)) 2474 2475 # successfully failed folding 2476 with self.assertRaisesRegex( 2477 AssertionError, 2478 "To fuse, linear.out_features == bn.num_features or bn.num_features == 1", 2479 ): 2480 nn.utils.fusion.fuse_linear_bn_eval(linear, bn) 2481 2482 @skipCUDAMemoryLeakCheckIf(True) 2483 @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU") 2484 def test_linear_bn_folding_autocast_scenario_cuda(self): 2485 module_pairs = [ 2486 (nn.Linear, nn.BatchNorm1d), 2487 (nn.Linear, nn.BatchNorm2d), 2488 (nn.Linear, nn.BatchNorm3d), 2489 ] 2490 use_tracing = [True, False] 2491 bn_running_stats = [True, False] 2492 2493 for modules, tracing, track_stats in product( 2494 module_pairs, use_tracing, bn_running_stats 2495 ): 2496 2497 class LinearBN(torch.nn.Module): 2498 def __init__(self, in_features, out_features): 2499 super().__init__() 2500 self.linear = modules[0]( 2501 in_features, out_features, bias=False, dtype=torch.half 2502 ) 2503 self.bn = modules[1](out_features, eps=0.001, dtype=torch.float) 2504 2505 def forward(self, x): 2506 x = self.linear(x) 2507 return self.bn(x) 2508 2509 mod_eager = LinearBN(32, 32).cuda().eval() 2510 2511 inps = [3, 32] 2512 if modules[1] == nn.BatchNorm2d: 2513 inps.append(inps[-1]) 2514 inps.append(inps[-1]) 2515 if modules[1] == nn.BatchNorm3d: 2516 inps.append(inps[-1]) 2517 inps.append(inps[-1]) 2518 inps.append(inps[-1]) 2519 2520 x = torch.rand(inps, dtype=torch.half).cuda() 2521 2522 if tracing: 2523 scripted_mod = torch.jit.trace(mod_eager, (x)) 2524 else: 2525 scripted_mod = torch.jit.script(mod_eager) 2526 scripted_mod = torch.jit.freeze(scripted_mod) 2527 FileCheck().check("linear").check_not("aten::batch_norm").run( 2528 scripted_mod.graph 2529 ) 2530 lin_node = scripted_mod.graph.findNode("aten::linear", True) 2531 self.assertTrue(lin_node is not None) 2532 weight_input = lin_node.namedInput("weight") 2533 bias_input = lin_node.namedInput("bias") 2534 self.assertTrue(bias_input is not None) 2535 self.assertTrue(weight_input.type().dtype() == torch.half) 2536 self.assertTrue(bias_input.type().dtype() == torch.half) 2537 2538 self.assertEqual(mod_eager(x), scripted_mod(x), atol=1e-2, rtol=1e-2) 2539 self.assertEqual(mod_eager(x), scripted_mod(x), atol=1e-2, rtol=1e-2) 2540 2541 @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU") 2542 def test_linear_concat(self): 2543 out_dimms = [[5, 10], [1, 5]] 2544 2545 for w1_dim, w2_dim in out_dimms: 2546 2547 class ModMultLinear(nn.Module): 2548 def __init__(self, w1_dim, w2_dim): 2549 super().__init__() 2550 self.w1 = nn.Parameter(torch.rand([w1_dim, 5])) 2551 self.b1 = nn.Parameter(torch.rand([w1_dim])) 2552 self.w2 = nn.Parameter(torch.rand([w2_dim, 5])) 2553 self.b2 = nn.Parameter(torch.rand([w2_dim])) 2554 2555 def forward(self, in_tensor1): 2556 res1 = 
torch._C._nn.linear(in_tensor1, self.w1, self.b1) 2557 res2 = torch._C._nn.linear(in_tensor1, self.w2, self.b2) 2558 return res1, res2 2559 2560 mod_eager = ModMultLinear(w1_dim, w2_dim).eval() 2561 2562 test_val1 = torch.rand([50, 5]) 2563 self.check_linear_optimizations(mod_eager, 2, 1, (test_val1,)) 2564 2565 @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU") 2566 def test_linear_concat_complex(self): 2567 """ 2568 Testing that the interleaving of multiple optimizations does not 2569 cause errors, and gets optimized as expected 2570 """ 2571 2572 class ModMultLinear(nn.Module): 2573 def __init__(self) -> None: 2574 super().__init__() 2575 w1_dim = 5 2576 w2_dim = 10 2577 self.w1 = nn.Parameter(torch.rand([w1_dim, 5])) 2578 self.b1 = nn.Parameter(torch.rand([w1_dim])) 2579 self.w2 = nn.Parameter(torch.rand([w2_dim, 5])) 2580 self.b2 = nn.Parameter(torch.rand([w2_dim])) 2581 2582 def forward(self, in_tensor1): 2583 res1 = torch._C._nn.linear(in_tensor1, self.w1, self.b1) 2584 res3 = torch._C._nn.linear(res1, self.w2, self.b2) 2585 res2 = torch._C._nn.linear(in_tensor1, self.w2, self.b2) 2586 res4 = torch._C._nn.linear(res1, self.w1, self.b1) 2587 return res2, res3, res4 2588 2589 mod_eager = ModMultLinear().eval() 2590 test_val1 = torch.rand([50, 5]) 2591 self.check_linear_optimizations(mod_eager, 4, 2, (test_val1,)) 2592 2593 @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU") 2594 def test_linear_concat_different_input(self): 2595 """ 2596 There should be no change to the graph due to the optimization pass 2597 due to the two input tensors being different 2598 """ 2599 2600 # Freezing requires that the graph be a module 2601 class ModMultLinear(nn.Module): 2602 def __init__(self, w1_dim, w2_dim): 2603 super().__init__() 2604 self.w1 = nn.Parameter(torch.rand([w1_dim, 5])) 2605 self.b1 = nn.Parameter(torch.rand([w1_dim])) 2606 self.w2 = nn.Parameter(torch.rand([w2_dim, 5])) 2607 self.b2 = nn.Parameter(torch.rand([w2_dim])) 2608 2609 def forward(self, in_tensor1, in_tensor2): 2610 res1 = torch._C._nn.linear(in_tensor1, self.w1, self.b1) 2611 res2 = torch._C._nn.linear(in_tensor2, self.w2, self.b2) 2612 return res1, res2 2613 2614 mod_eager = ModMultLinear(5, 5).eval() 2615 test_val1 = torch.rand([50, 5]) 2616 test_val2 = torch.rand([50, 5]) 2617 self.check_linear_optimizations(mod_eager, 2, 2, (test_val1, test_val2)) 2618 2619 @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU") 2620 def test_linear_multiple_blocks(self): 2621 class ModMultLinear(nn.Module): 2622 def __init__(self, w1_dim, w2_dim): 2623 super().__init__() 2624 self.w1 = nn.Parameter(torch.rand([w1_dim, 5])) 2625 self.b1 = nn.Parameter(torch.rand([w1_dim])) 2626 self.w2 = nn.Parameter(torch.rand([w2_dim, 5])) 2627 self.b2 = nn.Parameter(torch.rand([w2_dim])) 2628 2629 def forward(self, in_tensor1, in_tensor2, cond: bool): 2630 res1 = torch._C._nn.linear(in_tensor1, self.w1, self.b1) 2631 if cond: 2632 res3 = torch._C._nn.linear(in_tensor2, self.w2, self.b2) 2633 res4 = torch._C._nn.linear(in_tensor1, self.w2, self.b1) 2634 else: 2635 raise AssertionError 2636 res2 = torch._C._nn.linear(in_tensor1, self.w2, self.b1) 2637 return res1, res2, res3, res4 2638 2639 mod_eager = ModMultLinear(5, 5).eval() 2640 test_val1 = torch.rand([50, 5]) 2641 test_val2 = torch.rand([50, 5]) 2642 self.check_linear_optimizations(mod_eager, 4, 3, (test_val1, test_val2, True)) 2643 2644 def check_linear_optimizations( 2645 self, eager_mod, orig_linears, new_linears, test_vals 
2646 ): 2647 for is_cuda in [False, True]: 2648 if is_cuda: 2649 mod_to_device = eager_mod.cuda() 2650 test_vals_to_device = [ 2651 t.cuda() if isinstance(t, torch.Tensor) else t for t in test_vals 2652 ] 2653 else: 2654 mod_to_device = eager_mod 2655 test_vals_to_device = test_vals 2656 2657 script_mod = torch.jit.script(mod_to_device) 2658 op_graph = script_mod.graph 2659 2660 FileCheck().check_count("aten::linear", orig_linears, exactly=True).run( 2661 op_graph 2662 ) 2663 # successively no-ops with non-const inputs 2664 self.run_pass("concat_frozen_linear", op_graph) 2665 FileCheck().check_count("aten::linear", orig_linears, exactly=True).run( 2666 op_graph 2667 ) 2668 2669 script_mod = torch.jit.freeze(script_mod) 2670 op_graph = script_mod.graph 2671 self.run_pass("concat_frozen_linear", op_graph) 2672 if is_cuda: 2673 FileCheck().check_count("aten::linear", new_linears, exactly=True).run( 2674 op_graph 2675 ) 2676 else: 2677 FileCheck().check_count("aten::linear", orig_linears, exactly=True).run( 2678 op_graph 2679 ) 2680 2681 self.assertEqual( 2682 mod_to_device(*test_vals_to_device), script_mod(*test_vals_to_device) 2683 ) 2684 2685 def test_optimize_freeze_module(self): 2686 in_channels, out_channels = 3, 32 2687 conv = torch.nn.Conv2d( 2688 in_channels, out_channels, kernel_size=3, stride=2, bias=True 2689 ) 2690 bn = torch.nn.BatchNorm2d(out_channels, eps=0.001) 2691 mod = torch.nn.Sequential(conv, bn) 2692 # set optimize to False here, by default freezing runs run_frozen_optimizations 2693 frozen_mod = torch.jit.freeze( 2694 torch.jit.script(mod.eval()), optimize_numerics=False 2695 ) 2696 # inspect frozen mod 2697 FileCheck().check("batch_norm").run(frozen_mod.graph) 2698 torch.jit.run_frozen_optimizations(frozen_mod) 2699 FileCheck().check_not("batch_norm").run(frozen_mod.graph) 2700 2701 # run_frozen_optimizations should be run 2702 frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval())) 2703 FileCheck().check_not("batch_norm").run(frozen_mod.graph) 2704 2705 def test_freeze_remove_dropout(self): 2706 class Net(nn.Module): 2707 def __init__(self) -> None: 2708 super().__init__() 2709 self.dropout = nn.Dropout(0.5) 2710 2711 def forward(self, x): 2712 return self.dropout(x) 2713 2714 mod = torch.jit.script(Net()) 2715 # inspect mod 2716 torch._C._jit_pass_inline(mod.graph) 2717 FileCheck().check("aten::dropout").run(mod.graph) 2718 frozen_mod = torch.jit.freeze(mod.eval()) 2719 FileCheck().check_not("aten::dropout").run(frozen_mod.graph) 2720 2721 input = torch.randn(2) 2722 output_s = mod.forward(input) 2723 output_f = frozen_mod.forward(input) 2724 self.assertEqual(output_s, output_f) 2725 2726 def test_freeze_remove_feature_dropout(self): 2727 class Net(nn.Module): 2728 def __init__(self) -> None: 2729 super().__init__() 2730 self.dropout = nn.Dropout2d(0.5) 2731 2732 def forward(self, x): 2733 return self.dropout(x) 2734 2735 mod = torch.jit.script(Net().eval()) 2736 # inspect mod 2737 torch._C._jit_pass_inline(mod.graph) 2738 FileCheck().check("aten::feature_dropout").run(mod.graph) 2739 frozen_mod = torch.jit.freeze(mod) 2740 FileCheck().check_not("aten::feature_dropout").run(frozen_mod.graph) 2741 2742 input = torch.randn(2, 2, 1, 1) 2743 output_s = mod.forward(input) 2744 output_f = frozen_mod.forward(input) 2745 self.assertEqual(output_s, output_f) 2746 2747 @unittest.skipIf( 2748 not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled" 2749 ) 2750 def test_freeze_mkdlnn(self): 2751 conv = torch.nn.Conv2d(3, 32, kernel_size=3, 
stride=2).eval().float() 2752 convmkl = mkldnn_utils.to_mkldnn(conv) 2753 out = torch.jit.freeze(torch.jit.script(convmkl.eval())) 2754 inp = torch.rand([4, 3, 4, 4]).float() 2755 self.assertEqual(out(inp.to_mkldnn()).to_dense(), conv(inp)) 2756 2757 @unittest.skipIf( 2758 not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled" 2759 ) 2760 def test_conv_to_mkldnn(self): 2761 with set_default_dtype(torch.float): 2762 for module, trace in product([nn.Conv2d, nn.Conv3d], [False, True]): 2763 mod = module(3, 32, kernel_size=3, stride=2).eval() 2764 inps = [4, 3, 4] 2765 if module == nn.Conv2d: 2766 inps.append(inps[-1]) 2767 if module == nn.Conv3d: 2768 inps.append(inps[-1]) 2769 inps.append(inps[-1]) 2770 2771 inp = torch.rand(inps) 2772 if trace: 2773 scripted_mod = torch.jit.script(mod) 2774 else: 2775 scripted_mod = torch.jit.trace(mod, (inp,)) 2776 2777 self.run_pass("inline", scripted_mod.graph) 2778 2779 FileCheck().check("conv").run(scripted_mod.graph) 2780 # successfully no-ops with non-const inputs 2781 self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph) 2782 FileCheck().check_not("to_mkldnn").run(scripted_mod.graph) 2783 2784 scripted_mod = torch.jit.freeze(scripted_mod) 2785 self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph) 2786 FileCheck().check("to_mkldnn").check("prim::mkldnn_convolution").check( 2787 "to_dense" 2788 ).run(scripted_mod.graph) 2789 2790 self.assertEqual(mod(inp), scripted_mod(inp)) 2791 self.assertEqual(mod(inp), scripted_mod(inp)) 2792 2793 def test_linear_transpose(self): 2794 class ModLinear(torch.nn.Module): 2795 def __init__(self) -> None: 2796 super().__init__() 2797 self.bias = torch.nn.Parameter(torch.rand(30)) 2798 self.weight = torch.nn.Parameter(torch.rand([30, 20])) 2799 2800 def forward(self, x): 2801 return torch._C._nn.linear(x, self.weight, self.bias) 2802 2803 mod_eager = ModLinear().eval() 2804 test_val = torch.rand([50, 20]) 2805 self.check_linear_optimizations_2( 2806 mod_eager, 1, 0, "transpose_frozen_linear", (test_val,) 2807 ) 2808 2809 def test_linear_non_constant_weight(self): 2810 class ModLinear(torch.nn.Module): 2811 def __init__(self) -> None: 2812 super().__init__() 2813 self.bias = torch.nn.Parameter(torch.rand(30)) 2814 2815 def forward(self, x, weight): 2816 return torch._C._nn.linear(x, weight, self.bias) 2817 2818 mod_eager = ModLinear().eval() 2819 test_val = torch.rand([50, 20]) 2820 test_weight = torch.rand([30, 20]) 2821 self.check_linear_optimizations_2( 2822 mod_eager, 1, 1, "transpose_frozen_linear", (test_val, test_weight) 2823 ) 2824 2825 def check_linear_optimizations_2( 2826 self, eager_mod, orig_linears, new_linears, opt_pass, test_vals 2827 ): 2828 # TODO: merge with check_linear_optimizations once both diffs land 2829 mod_to_device = eager_mod 2830 test_vals_to_device = test_vals 2831 2832 script_mod = torch.jit.script(mod_to_device) 2833 op_graph = script_mod.graph 2834 2835 FileCheck().check_count("aten::linear", orig_linears, exactly=True).run( 2836 op_graph 2837 ) 2838 # successively no-ops with non-const inputs 2839 self.run_pass(opt_pass, op_graph) 2840 FileCheck().check_count("aten::linear", orig_linears, exactly=True).run( 2841 op_graph 2842 ) 2843 2844 script_mod = torch.jit.freeze(script_mod) 2845 op_graph = script_mod.graph 2846 self.run_pass(opt_pass, op_graph) 2847 FileCheck().check_count("aten::linear", new_linears, exactly=True).run(op_graph) 2848 2849 self.assertEqual( 2850 mod_to_device(*test_vals_to_device), script_mod(*test_vals_to_device) 2851 ) 2852 
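    # Rough sketch of the algebra behind the fold_frozen_conv_bn /
    # fold_frozen_linear_bn passes exercised above (standard batch-norm folding;
    # the exact pass implementation may differ):
    #   scale    = bn.weight / sqrt(bn.running_var + bn.eps)
    #   W_folded = W * scale            (applied per output channel / feature)
    #   b_folded = (b - bn.running_mean) * scale + bn.bias
    # Folding only applies when running stats are tracked and the BN features line
    # up with the producer's outputs, which is why the track_running_stats=False
    # and mismatched linear/bn feature-count cases above are expected to keep
    # aten::batch_norm in the graph.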
2853 @staticmethod 2854 def conv(): 2855 # Generic composable conv for testing purposes 2856 return nn.Conv2d(8, 8, 1) 2857 2858 @unittest.skipIf( 2859 not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled" 2860 ) 2861 def test_collapse_adjacent_conversions(self): 2862 with set_default_dtype(torch.float): 2863 mod = nn.Sequential(self.conv(), self.conv()).eval() 2864 scripted_mod = torch.jit.script(mod) 2865 scripted_mod = torch.jit.freeze(scripted_mod) 2866 self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph) 2867 FileCheck().check("to_mkldnn").check("prim::mkldnn_convolution").check( 2868 "prim::mkldnn_convolution" 2869 ).check("to_dense").run(scripted_mod.graph) 2870 FileCheck().check_count("to_mkldnn", 1, exactly=True).run( 2871 scripted_mod.graph 2872 ) 2873 2874 inp = torch.rand([1, 8, 8, 8]) 2875 self.assertEqual(scripted_mod(inp), mod(inp)) 2876 self.assertEqual(scripted_mod(inp), mod(inp)) 2877 2878 @unittest.skipIf( 2879 not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled" 2880 ) 2881 def test_mkldnn_fuser_broadcasting(self): 2882 class Add(nn.Module): 2883 def __init__(self, tensor): 2884 super().__init__() 2885 self.tensor = tensor 2886 2887 def forward(self, x): 2888 return x + self.tensor 2889 2890 with set_default_dtype(torch.float): 2891 for add_inp in [8], [8, 8, 1]: 2892 mod = nn.Sequential(self.conv(), Add(torch.rand(add_inp))).eval() 2893 scripted_mod = torch.jit.script(mod) 2894 scripted_mod = torch.jit.freeze(scripted_mod) 2895 self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph) 2896 FileCheck().check("prim::BroadcastMKLDNNTensors").run( 2897 scripted_mod.graph 2898 ) 2899 inp = torch.rand([1, 8, 8, 8]) 2900 self.assertEqual(scripted_mod(inp), mod(inp)) 2901 self.assertEqual(scripted_mod(inp), mod(inp)) 2902 2903 # for good measure, check that broadcasting does not work without this op 2904 # so we can remove the op if it ever gets supported 2905 with self.assertRaisesRegex(RuntimeError, ""): 2906 ( 2907 torch.rand([1, 8, 8, 8]).to_mkldnn() 2908 + torch.rand(add_inp).to_mkldnn() 2909 ) 2910 2911 @unittest.skipIf( 2912 not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled" 2913 ) 2914 def test_mkldnn_inplace_removal(self): 2915 class AddMul(nn.Module): 2916 def __init__(self, tensor): 2917 super().__init__() 2918 self.tensor = tensor 2919 2920 def forward(self, x): 2921 return x.add_(self.tensor).div_(self.tensor) - 4 2922 2923 with set_default_dtype(torch.float): 2924 mod = nn.Sequential(self.conv(), AddMul(torch.rand([8]))).eval() 2925 scripted_mod = torch.jit.script(mod) 2926 scripted_mod = torch.jit.freeze(scripted_mod) 2927 self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph) 2928 # add gets uninplaced and reinplaced 2929 FileCheck().check("aten::to_mkldnn").check("aten::add_").check( 2930 "aten::div_" 2931 ).run(scripted_mod.graph) 2932 inp = torch.rand([1, 8, 8, 8]) 2933 self.assertEqual(scripted_mod(inp), mod(inp)) 2934 self.assertEqual(scripted_mod(inp), mod(inp)) 2935 2936 @unittest.skipIf( 2937 not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled" 2938 ) 2939 @skipIfNoTorchVision 2940 def test_maxpool_mkldnn(self): 2941 with set_default_dtype(torch.float): 2942 model = torchvision.models.resnet18() 2943 sub_model = torch.nn.Sequential( 2944 model.conv1, model.bn1, model.relu, model.maxpool 2945 ) 2946 mod = torch.jit.freeze(torch.jit.script(sub_model.eval())) 2947 ( 2948 N, 2949 C, 2950 H, 2951 W, 2952 ) = ( 2953 10, 2954 3, 2955 224, 2956 224, 2957 ) 2958 
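            # After convert_frozen_ops_to_mkldnn the conv1/bn1/relu/maxpool chain
            # should stay in MKLDNN layout end to end, so the FileCheck below expects
            # a max_pool op and exactly one to_dense conversion back at the end.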
inp = torch.randn(N, C, H, W) 2959 self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph) 2960 FileCheck().check("max_pool").check("to_dense").run(mod.graph) 2961 FileCheck().check_count("to_dense", 1, exactly=True).run(mod.graph) 2962 self.assertEqual(mod(inp), sub_model(inp)) 2963 2964 @unittest.skipIf(torch.backends.mkldnn.is_available(), "Testing no mkldnn") 2965 def test_conv_to_mkldnn_no_mkldnn(self): 2966 # test no error when mkldnn not available 2967 with set_default_dtype(torch.float): 2968 mod = torch.jit.script(nn.Conv2d(3, 32, kernel_size=3, stride=2).eval()) 2969 frozen = torch.jit.freeze(mod) 2970 self.run_pass("convert_frozen_ops_to_mkldnn", frozen.graph) 2971 inp = torch.rand([4, 3, 4, 4]) 2972 self.assertEqual(frozen(inp), mod(inp)) 2973 2974 @unittest.skipIf(not (TEST_CUDNN or TEST_WITH_ROCM), "requires CUDNN") 2975 def test_freeze_conv_relu_fusion(self): 2976 with set_default_dtype(torch.float): 2977 conv_bias = [True, False] 2978 conv_ops = [nn.Conv2d, nn.Conv3d] 2979 use_add_z = [True, False] 2980 use_tracing = [True, False] 2981 for use_bias, conv, add_z, tracing in product( 2982 conv_bias, conv_ops, use_add_z, use_tracing 2983 ): 2984 2985 class Net(nn.Module): 2986 def __init__(self, in_channels, out_channels, **kwargs): 2987 super().__init__() 2988 self.conv = conv( 2989 in_channels, out_channels, bias=use_bias, **kwargs 2990 ) 2991 self.relu = nn.ReLU(inplace=True) 2992 self.add_z = add_z 2993 2994 def forward(self, x): 2995 z = self.conv(x) 2996 out = self.conv(x) 2997 if self.add_z: 2998 out += z 2999 out = self.relu(out) 3000 return out 3001 3002 mod_eager = Net(3, 6, kernel_size=3, stride=2).eval().cuda() 3003 3004 inps = [5, 3, 4, 4] 3005 if conv == nn.Conv3d: 3006 inps.append(inps[-1]) 3007 inp = torch.rand(inps).cuda() 3008 3009 if tracing: 3010 scripted_mod = torch.jit.trace(mod_eager, (inp)) 3011 else: 3012 scripted_mod = torch.jit.script(mod_eager) 3013 3014 frozen_mod = torch.jit.optimize_for_inference(scripted_mod) 3015 if TEST_WITH_ROCM: 3016 if add_z: 3017 FileCheck().check("aten::miopen_convolution_add_relu").run( 3018 frozen_mod.graph 3019 ) 3020 else: 3021 FileCheck().check("aten::miopen_convolution_relu").run( 3022 frozen_mod.graph 3023 ) 3024 else: 3025 if add_z: 3026 FileCheck().check("aten::cudnn_convolution_add_relu").run( 3027 frozen_mod.graph 3028 ) 3029 else: 3030 FileCheck().check("aten::cudnn_convolution_relu").run( 3031 frozen_mod.graph 3032 ) 3033 3034 self.assertEqual(mod_eager(inp), frozen_mod(inp)) 3035 3036 @unittest.skipIf(not (TEST_CUDNN or TEST_WITH_ROCM), "requires CUDNN") 3037 def test_freeze_conv_relu_fusion_not_forward(self): 3038 with set_default_dtype(torch.float): 3039 3040 class Net(nn.Module): 3041 def __init__(self, in_channels, out_channels, **kwargs): 3042 super().__init__() 3043 self.conv = nn.Conv2d( 3044 in_channels, out_channels, bias=None, **kwargs 3045 ) 3046 self.relu = nn.ReLU(inplace=True) 3047 3048 def forward(self, x): 3049 z = self.conv(x) 3050 out = self.conv(x) 3051 out = self.relu(out) 3052 return out 3053 3054 @torch.jit.export 3055 def make_prediction(self, x): 3056 return self.forward(x) 3057 3058 mod_eager = Net(3, 6, kernel_size=3, stride=2).eval().cuda() 3059 3060 inps = [5, 3, 4, 4] 3061 inp = torch.rand(inps).cuda() 3062 3063 scripted_mod = torch.jit.script(mod_eager) 3064 3065 frozen_mod = torch.jit.freeze( 3066 scripted_mod, preserved_attrs=["make_prediction"] 3067 ) 3068 optimized_mod = torch.jit.optimize_for_inference( 3069 frozen_mod, other_methods=["make_prediction"] 3070 ) 3071 if 
TEST_WITH_ROCM: 3072 FileCheck().check("aten::miopen_convolution_relu").run( 3073 optimized_mod.make_prediction.graph 3074 ) 3075 else: 3076 FileCheck().check("aten::cudnn_convolution_relu").run( 3077 optimized_mod.make_prediction.graph 3078 ) 3079 3080 self.assertEqual( 3081 mod_eager.make_prediction(inp), optimized_mod.make_prediction(inp) 3082 ) 3083 3084 @unittest.skipIf( 3085 not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled" 3086 ) 3087 def test_numel_less_than_size_with_padding(self): 3088 with set_default_dtype(torch.float): 3089 3090 class MyModule(nn.Module): 3091 def __init__(self) -> None: 3092 super().__init__() 3093 self.conv1 = nn.Conv2d( 3094 1, 3095 2, 3096 kernel_size=(2, 4), 3097 stride=2, 3098 padding=2, 3099 dilation=(2, 1), 3100 ) 3101 3102 def forward(self, i0): 3103 x = self.conv1(i0) 3104 o0 = torch.max(x, i0) 3105 o1 = torch.clip(x, -1.5, 1.5) 3106 return o0, o1 3107 3108 i0 = torch.zeros((1, 1, 1, 2), dtype=torch.float32) 3109 mod = MyModule() 3110 out = mod(i0) 3111 3112 exported = torch.jit.trace(mod, [i0]) 3113 exported = torch.jit.optimize_for_inference(exported) 3114 3115 eout = exported(i0) 3116 self.assertTrue(all(torch.allclose(x, y) for x, y in zip(out, eout))) 3117 3118 @unittest.skipIf( 3119 not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled" 3120 ) 3121 def test_incompatible_perf_formats(self): 3122 with set_default_dtype(torch.float): 3123 3124 class Mod(nn.Module): 3125 def __init__(self) -> None: 3126 super().__init__() 3127 self.conv = torch.nn.Conv2d(3, 64, 3, 2) 3128 self.max_pool = torch.nn.MaxPool2d(111, 111) 3129 3130 def forward(self, x): 3131 a = self.conv(x) 3132 b = self.max_pool(a) 3133 return a + b 3134 3135 model = Mod() 3136 model.eval() 3137 mod = torch.jit.freeze(torch.jit.script(model)) 3138 ( 3139 N, 3140 C, 3141 H, 3142 W, 3143 ) = ( 3144 10, 3145 3, 3146 224, 3147 224, 3148 ) 3149 inp = torch.randn(N, C, H, W) 3150 self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph) 3151 self.assertEqual(model(inp), mod(inp)) 3152 3153 @unittest.skipIf( 3154 not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled" 3155 ) 3156 def test_pool2d_batchnorm(self): 3157 with set_default_dtype(torch.float): 3158 pooling_layers = [ 3159 torch.nn.AdaptiveAvgPool2d(4), 3160 # torch.nn.AdaptiveMaxPool2d(4), # return tuples 3161 torch.nn.MaxPool2d(4), 3162 torch.nn.AvgPool2d(4), 3163 torch.nn.BatchNorm2d(64).eval(), 3164 ] 3165 3166 for pl in pooling_layers: 3167 sub_model = torch.nn.Sequential( 3168 torch.nn.Conv2d(3, 64, 2, 2), 3169 torch.nn.ReLU(), 3170 pl, 3171 torch.nn.Hardswish(), 3172 ) 3173 sub_model.eval() 3174 mod = torch.jit.freeze(torch.jit.script(sub_model)) 3175 ( 3176 N, 3177 C, 3178 H, 3179 W, 3180 ) = ( 3181 10, 3182 3, 3183 224, 3184 224, 3185 ) 3186 inp = torch.randn(N, C, H, W) 3187 # these two passes needed to remove 3188 # a size check in BatchNorm2d 3189 removeExceptions(mod.graph) 3190 self.run_pass("dce", mod.graph) 3191 self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph) 3192 FileCheck().check("aten::to_dense").check_next("return").run(mod.graph) 3193 self.assertEqual(sub_model(inp), mod(inp)) 3194 3195 @unittest.skipIf( 3196 not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled" 3197 ) 3198 def test_pool3d_batchnorm(self): 3199 with set_default_dtype(torch.float): 3200 pooling_layers = [ 3201 torch.nn.MaxPool3d(4), 3202 # torch.nn.AdaptiveAvgPool3d(4), # no ideep bindings 3203 # torch.nn.AdaptiveMaxPool3d(4), # return tuples 3204 torch.nn.AvgPool3d(4), 3205 
torch.nn.BatchNorm3d(64).eval(), 3206 ] 3207 3208 for pl in pooling_layers: 3209 sub_model = torch.nn.Sequential( 3210 torch.nn.Conv3d(3, 64, 2, 2), 3211 torch.nn.ReLU(), 3212 pl, 3213 torch.nn.Hardswish(), 3214 ) 3215 sub_model.eval() 3216 mod = torch.jit.freeze(torch.jit.script(sub_model)) 3217 N, C, H, W, D = 10, 3, 64, 64, 64 3218 inp = torch.randn(N, C, D, H, W) 3219 # these two passes needed to remove 3220 # a size check in BatchNorm2d 3221 removeExceptions(mod.graph) 3222 self.run_pass("dce", mod.graph) 3223 self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph) 3224 FileCheck().check("aten::to_dense").check_next("return").run(mod.graph) 3225 self.assertEqual(sub_model(inp), mod(inp)) 3226 3227 @unittest.skipIf( 3228 not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled" 3229 ) 3230 @skipIfNoTorchVision 3231 def test_conv_hardswish(self): 3232 with set_default_dtype(torch.float): 3233 3234 class Clamp(torch.nn.Module): 3235 def __init__(self, min_val, max_val, **kwargs): 3236 super().__init__() 3237 self.min_val = min_val 3238 self.max_val = max_val 3239 3240 def forward(self, x): 3241 return torch.clamp(x, self.min_val, self.max_val) 3242 3243 ( 3244 N, 3245 C, 3246 H, 3247 W, 3248 ) = ( 3249 10, 3250 3, 3251 224, 3252 224, 3253 ) 3254 activations = [ 3255 torch.nn.Hardswish(), 3256 torch.nn.Hardsigmoid(), 3257 torch.nn.ReLU6(), 3258 torch.nn.Tanh(), 3259 torch.nn.Hardtanh(0.0, 6.0), 3260 torch.nn.Hardtanh(1.0, 100.0), 3261 torch.nn.Hardtanh(-100.0, -1.0), 3262 torch.nn.GELU(), 3263 Clamp(-100.0, -1.0), 3264 Clamp(1.0, 100.0), 3265 Clamp(0.0, 6.0), 3266 Clamp(-1.0, 0.0), 3267 ] 3268 3269 model = torchvision.models.resnet18() 3270 for activation in activations: 3271 sub_model = torch.nn.Sequential(model.conv1, activation) 3272 sub_model.eval() 3273 mod = torch.jit.freeze(torch.jit.script(sub_model)) 3274 inp = torch.randn(N, C, H, W) 3275 self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph) 3276 FileCheck().check_count("aten::to_dense", 1, exactly=True).run( 3277 mod.graph 3278 ) 3279 self.assertEqual(sub_model(inp), mod(inp)) 3280 3281 @unittest.skipIf( 3282 not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled" 3283 ) 3284 def test_hardswish_hardsigmoid(self): 3285 with set_default_dtype(torch.float): 3286 op_map = { 3287 "prim::MKLDNNHardSwish": F.hardswish, 3288 "prim::MKLDNNHardSigmoid": F.hardsigmoid, 3289 } 3290 3291 input_sizes = ([0], [1], [3], [1, 3, 8, 8]) 3292 for mkldnn_opname, aten_op in op_map.items(): 3293 for size in input_sizes: 3294 for inplace in (True, False): 3295 inplace_str = "_" if inplace else "" 3296 inplace_tgt = "%34" if inplace else "%35" 3297 graph_str = f"""graph(%input.1 : Tensor): 3298 %33 : None = prim::Constant() 3299 %34 : Tensor = aten::to_mkldnn(%input.1, %33) 3300 %35 : Tensor = {mkldnn_opname}{inplace_str}(%34) 3301 return ({inplace_tgt}) 3302 """ 3303 g = torch._C.parse_ir(graph_str) 3304 m = self.createFunctionFromGraph(g) 3305 x = torch.rand(size) 3306 # `inplace=False` is intentional, otherwise we modify the input 3307 # and we aren't testing aten impls anyways 3308 self.assertEqual(aten_op(x, inplace=False), m(x).to_dense()) 3309 3310 @unittest.skipIf( 3311 not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled" 3312 ) 3313 def test_scalar_mul(self): 3314 with set_default_dtype(torch.float): 3315 3316 class Mod(nn.Module): 3317 def __init__(self) -> None: 3318 super().__init__() 3319 self.mod = nn.Conv2d(8, 8, 1, padding=1) 3320 3321 def forward(self, x): 3322 a1 = self.mod(x) * 4 3323 
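                    # a1 is used twice in the return below; the first multiply has to
                    # keep a1 alive for the second, so after optimize_for_inference
                    # only the second use can be rewritten in-place
                    # (ScalarMul vs ScalarMul_, checked in the test body).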
return a1 * 4 + a1 * 5.0 3324 3325 mod = Mod().eval() 3326 scripted = torch.jit.freeze(torch.jit.script(mod)) 3327 optimized = torch.jit.optimize_for_inference(scripted) 3328 inp = torch.rand([1, 8, 8, 8]) 3329 # a1 cant be inplaced for first use, can for second 3330 FileCheck().check("ScalarMul(").check("ScalarMul_").run(optimized.graph) 3331 self.assertEqual(optimized(inp), mod(inp)) 3332 3333 def test_remove_detach(self): 3334 class Mod(nn.Module): 3335 def forward(self, x): 3336 y = x.detach() 3337 return y * y 3338 3339 mod = Mod().eval() 3340 frozen_mod = torch.jit.freeze(torch.jit.script(mod)) 3341 inp = torch.randn((2, 2)) 3342 FileCheck().check_not("aten::detach").run(frozen_mod.graph) 3343 self.assertEqual(frozen_mod(inp), mod(inp)) 3344 3345 def test_remove_detach_not_applied(self): 3346 class Mod(nn.Module): 3347 def forward(self, x): 3348 y = x.detach() 3349 return x is y 3350 3351 mod = Mod().eval() 3352 frozen_mod = torch.jit.freeze(torch.jit.script(mod)) 3353 inp = torch.randn((2, 2)) 3354 FileCheck().check("aten::detach").run(frozen_mod.graph) 3355 self.assertEqual(frozen_mod(inp), mod(inp)) 3356 3357 3358@skipIfTorchDynamo("somehow causing hanging during python shutdown") 3359@unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled") 3360class TestMKLDNNReinplacing(JitTestCase): 3361 def setUp(self): 3362 super().setUp() 3363 self.default_dtype = torch.get_default_dtype() 3364 torch.set_default_dtype(torch.float) 3365 3366 def tearDown(self): 3367 super().tearDown() 3368 torch.set_default_dtype(self.default_dtype) 3369 3370 def getConv(self): 3371 return nn.Conv2d(3, 32, kernel_size=3, stride=2).eval() 3372 3373 def getInput(self): 3374 return torch.rand([4, 3, 4, 4]) 3375 3376 def freezeAndConvert(self, mod): 3377 mod = torch.jit.freeze(torch.jit.script(mod.eval())) 3378 self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph) 3379 return mod 3380 3381 def checkResults(self, mod1, mod2): 3382 inp = self.getInput() 3383 self.assertEqual(mod1(inp), mod2(inp)) 3384 3385 def test_successful(self): 3386 # simple conv-relu 3387 3388 mod_eager = nn.Sequential(self.getConv(), nn.Hardswish(), nn.ReLU()) 3389 mod = self.freezeAndConvert(mod_eager) 3390 FileCheck().check("mkldnn_convolution").check_next( 3391 "prim::MKLDNNHardSwish_" 3392 ).check_next("aten::relu_").run(mod.graph) 3393 self.checkResults(mod_eager, mod) 3394 3395 def test_merge_liveness(self): 3396 class Mod(nn.Module): 3397 def __init__(self, tensor): 3398 super().__init__() 3399 self.tensor = tensor 3400 3401 def forward(self, x): 3402 # this mul can be inplaced since x is dead after this use 3403 temporary = x * self.tensor 3404 # temporary livespan is the return node, 3405 # add can not be inplaced 3406 return temporary + temporary, temporary 3407 3408 mod_eager = nn.Sequential(self.getConv(), Mod(torch.rand([4, 32, 1, 1]))) 3409 mod = self.freezeAndConvert(mod_eager) 3410 FileCheck().check("aten::mul_").check_not("aten::add_").run(mod.graph) 3411 self.checkResults(mod_eager, mod) 3412 3413 def test_always_alive_values(self): 3414 class Mod(nn.Module): 3415 def __init__(self, tensor): 3416 super().__init__() 3417 self.tensor = tensor 3418 3419 def forward(self, x): 3420 # x can't be inplaced because its a return value, 3421 # check that the inplacing pass doesnt try to inplace 3422 # self.tensor because its always alive 3423 return x * self.tensor, x 3424 3425 mod_eager = nn.Sequential(self.getConv(), Mod(torch.rand([4, 32, 1, 1]))) 3426 mod = self.freezeAndConvert(mod_eager) 3427 
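        # Neither operand may be written in-place here: x is returned to the caller
        # and self.tensor is a frozen constant that stays live across calls, so the
        # reinplacing pass has to keep an out-of-place aten::mul (asserted below).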
FileCheck().check_not("aten::mul_").run(mod.graph) 3428 self.checkResults(mod_eager, mod) 3429 3430 conv = self.getConv() 3431 3432 class Mod(nn.Module): 3433 def __init__(self) -> None: 3434 super().__init__() 3435 self.tensor = torch.rand([4, 32, 1, 1]) 3436 self.conv = conv 3437 3438 def forward(self, x): 3439 # the shapes dont add up on this just testing a particular pattern 3440 conv_output = self.conv(x) 3441 return conv_output, self.conv(torch.add(x, x)) 3442 3443 mod = self.freezeAndConvert(Mod()) 3444 # x is an input to the graph, and so it should not be inplaced 3445 # in the torch.add(x, x) call 3446 FileCheck().check_not("aten::add_").run(mod.graph) 3447 3448 def test_switch_inputs_to_inplace(self): 3449 class Mod(nn.Module): 3450 def __init__(self, tensor): 3451 super().__init__() 3452 self.tensor = tensor 3453 3454 def forward(self, x): 3455 # self.tensor cannot be inplaced, however x can, 3456 # and bc add is commutative we can reverse inputs to add_ 3457 return self.tensor + x 3458 3459 mod_eager = nn.Sequential(self.getConv(), Mod(torch.rand([4, 32, 1, 1]))) 3460 mod = self.freezeAndConvert(mod_eager) 3461 FileCheck().check("aten::add_").run(mod.graph) 3462 self.checkResults(mod_eager, mod) 3463