1# Owner(s): ["oncall: jit"] 2 3import io 4import os 5import sys 6from pathlib import Path 7from typing import NamedTuple, Optional 8 9import torch 10from torch import Tensor 11from torch.testing._internal.common_utils import skipIfTorchDynamo, TemporaryFileName 12 13 14# Make the helper files in test/ importable 15pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 16sys.path.append(pytorch_test_dir) 17from torch.testing._internal.jit_utils import clear_class_registry, JitTestCase 18 19 20if __name__ == "__main__": 21 raise RuntimeError( 22 "This test file is not meant to be run directly, use:\n\n" 23 "\tpython test/test_jit.py TESTNAME\n\n" 24 "instead." 25 ) 26 27 28class TestSaveLoad(JitTestCase): 29 def test_different_modules(self): 30 """ 31 Exercise the situation where we have the same qualified name 32 in two different CompilationUnits on save/load. 33 """ 34 35 class Foo(torch.nn.Module): 36 def __init__(self) -> None: 37 super().__init__() 38 self.foo = torch.nn.Linear(2, 2) 39 self.bar = torch.nn.Linear(2, 2) 40 41 def forward(self, x): 42 x = self.foo(x) 43 x = self.bar(x) 44 return x 45 46 first_script_module = torch.jit.script(Foo()) 47 first_saved_module = io.BytesIO() 48 torch.jit.save(first_script_module, first_saved_module) 49 first_saved_module.seek(0) 50 51 clear_class_registry() 52 53 class Foo(torch.nn.Module): 54 def __init__(self) -> None: 55 super().__init__() 56 self.foo = torch.nn.Linear(2, 2) 57 58 def forward(self, x): 59 x = self.foo(x) 60 return x 61 62 second_script_module = torch.jit.script(Foo()) 63 second_saved_module = io.BytesIO() 64 torch.jit.save(torch.jit.script(Foo()), second_saved_module) 65 second_saved_module.seek(0) 66 67 clear_class_registry() 68 69 self.assertEqual( 70 first_script_module._c.qualified_name, 71 second_script_module._c.qualified_name, 72 ) 73 74 class ContainsBoth(torch.nn.Module): 75 def __init__(self) -> None: 76 super().__init__() 77 self.add_module("second", torch.jit.load(second_saved_module)) 78 self.add_module("first", torch.jit.load(first_saved_module)) 79 80 def forward(self, x): 81 x = self.first(x) 82 x = self.second(x) 83 return x 84 85 sm = torch.jit.script(ContainsBoth()) 86 contains_both = io.BytesIO() 87 torch.jit.save(sm, contains_both) 88 contains_both.seek(0) 89 sm = torch.jit.load(contains_both) 90 91 def test_different_functions(self): 92 """ 93 Exercise the situation where we have the same qualified name 94 in two different CompilationUnits on save/load. 95 """ 96 97 def lol(x): 98 return x 99 100 class Foo(torch.nn.Module): 101 def forward(self, x): 102 return lol(x) 103 104 first_script_module = torch.jit.script(Foo()) 105 first_saved_module = io.BytesIO() 106 torch.jit.save(first_script_module, first_saved_module) 107 first_saved_module.seek(0) 108 109 clear_class_registry() 110 111 def lol(x): # noqa: F811 112 return "hello" 113 114 class Foo(torch.nn.Module): 115 def forward(self, x): 116 return lol(x) 117 118 second_script_module = torch.jit.script(Foo()) 119 second_saved_module = io.BytesIO() 120 torch.jit.save(torch.jit.script(Foo()), second_saved_module) 121 second_saved_module.seek(0) 122 123 clear_class_registry() 124 125 self.assertEqual( 126 first_script_module._c.qualified_name, 127 second_script_module._c.qualified_name, 128 ) 129 130 class ContainsBoth(torch.nn.Module): 131 def __init__(self) -> None: 132 super().__init__() 133 self.add_module("second", torch.jit.load(second_saved_module)) 134 self.add_module("first", torch.jit.load(first_saved_module)) 135 136 def forward(self, x): 137 x = self.first(x) 138 x = self.second(x) 139 return x 140 141 sm = torch.jit.script(ContainsBoth()) 142 contains_both = io.BytesIO() 143 torch.jit.save(sm, contains_both) 144 contains_both.seek(0) 145 sm = torch.jit.load(contains_both) 146 147 def test_different_interfaces(self): 148 """ 149 Exercise the situation where we have the same qualified name 150 in two different CompilationUnits on save/load. 151 """ 152 153 @torch.jit.interface 154 class MyInterface: 155 def bar(self, x: Tensor) -> Tensor: 156 pass 157 158 @torch.jit.script 159 class ImplementInterface: 160 def __init__(self) -> None: 161 pass 162 163 def bar(self, x): 164 return x 165 166 class Foo(torch.nn.Module): 167 __annotations__ = {"interface": MyInterface} 168 169 def __init__(self) -> None: 170 super().__init__() 171 self.interface = ImplementInterface() 172 173 def forward(self, x): 174 return self.interface.bar(x) 175 176 first_script_module = torch.jit.script(Foo()) 177 first_saved_module = io.BytesIO() 178 torch.jit.save(first_script_module, first_saved_module) 179 first_saved_module.seek(0) 180 181 clear_class_registry() 182 183 @torch.jit.interface 184 class MyInterface: 185 def not_bar(self, x: Tensor) -> Tensor: 186 pass 187 188 @torch.jit.script # noqa: F811 189 class ImplementInterface: # noqa: F811 190 def __init__(self) -> None: 191 pass 192 193 def not_bar(self, x): 194 return x 195 196 class Foo(torch.nn.Module): 197 __annotations__ = {"interface": MyInterface} 198 199 def __init__(self) -> None: 200 super().__init__() 201 self.interface = ImplementInterface() 202 203 def forward(self, x): 204 return self.interface.not_bar(x) 205 206 second_script_module = torch.jit.script(Foo()) 207 second_saved_module = io.BytesIO() 208 torch.jit.save(torch.jit.script(Foo()), second_saved_module) 209 second_saved_module.seek(0) 210 211 clear_class_registry() 212 213 self.assertEqual( 214 first_script_module._c.qualified_name, 215 second_script_module._c.qualified_name, 216 ) 217 218 class ContainsBoth(torch.nn.Module): 219 def __init__(self) -> None: 220 super().__init__() 221 self.add_module("second", torch.jit.load(second_saved_module)) 222 self.add_module("first", torch.jit.load(first_saved_module)) 223 224 def forward(self, x): 225 x = self.first(x) 226 x = self.second(x) 227 return x 228 229 sm = torch.jit.script(ContainsBoth()) 230 contains_both = io.BytesIO() 231 torch.jit.save(sm, contains_both) 232 contains_both.seek(0) 233 sm = torch.jit.load(contains_both) 234 235 def test_many_collisions(self): 236 class MyCoolNamedTuple(NamedTuple): 237 a: int 238 239 @torch.jit.interface 240 class MyInterface: 241 def bar(self, x: Tensor) -> Tensor: 242 pass 243 244 @torch.jit.script 245 class ImplementInterface: 246 def __init__(self) -> None: 247 pass 248 249 def bar(self, x): 250 return x 251 252 def lol(x): 253 return x 254 255 class Foo(torch.nn.Module): 256 interface: MyInterface 257 258 def __init__(self) -> None: 259 super().__init__() 260 self.foo = torch.nn.Linear(2, 2) 261 self.bar = torch.nn.Linear(2, 2) 262 self.interface = ImplementInterface() 263 264 def forward(self, x): 265 x = self.foo(x) 266 x = self.bar(x) 267 x = lol(x) 268 x = self.interface.bar(x) 269 270 return x, MyCoolNamedTuple(a=5) 271 272 first_script_module = torch.jit.script(Foo()) 273 first_saved_module = io.BytesIO() 274 torch.jit.save(first_script_module, first_saved_module) 275 first_saved_module.seek(0) 276 277 clear_class_registry() 278 279 @torch.jit.interface 280 class MyInterface: 281 def not_bar(self, x: Tensor) -> Tensor: 282 pass 283 284 @torch.jit.script # noqa: F811 285 class ImplementInterface: # noqa: F811 286 def __init__(self) -> None: 287 pass 288 289 def not_bar(self, x): 290 return x 291 292 def lol(x): # noqa: F811 293 return "asdofij" 294 295 class MyCoolNamedTuple(NamedTuple): # noqa: F811 296 a: str 297 298 class Foo(torch.nn.Module): 299 interface: MyInterface 300 301 def __init__(self) -> None: 302 super().__init__() 303 self.foo = torch.nn.Linear(2, 2) 304 self.interface = ImplementInterface() 305 306 def forward(self, x): 307 x = self.foo(x) 308 self.interface.not_bar(x) 309 x = lol(x) 310 return x, MyCoolNamedTuple(a="hello") 311 312 second_script_module = torch.jit.script(Foo()) 313 second_saved_module = io.BytesIO() 314 torch.jit.save(second_script_module, second_saved_module) 315 second_saved_module.seek(0) 316 317 clear_class_registry() 318 319 self.assertEqual( 320 first_script_module._c.qualified_name, 321 second_script_module._c.qualified_name, 322 ) 323 324 class ContainsBoth(torch.nn.Module): 325 def __init__(self) -> None: 326 super().__init__() 327 self.add_module("second", torch.jit.load(second_saved_module)) 328 self.add_module("first", torch.jit.load(first_saved_module)) 329 330 def forward(self, x): 331 x, named_tuple_1 = self.first(x) 332 x, named_tuple_2 = self.second(x) 333 return len(x + named_tuple_2.a) + named_tuple_1.a 334 335 sm = torch.jit.script(ContainsBoth()) 336 contains_both = io.BytesIO() 337 torch.jit.save(sm, contains_both) 338 contains_both.seek(0) 339 sm = torch.jit.load(contains_both) 340 341 def test_save_load_with_extra_files(self): 342 class MyMod(torch.jit.ScriptModule): 343 @torch.jit.script_method 344 def forward(self, a): 345 return a 346 347 # specifically test binary data 348 value = b"bar\x00\xffbaz" 349 350 expected_extra_files = {} 351 expected_extra_files["foo"] = value 352 # verify that str to bytes conversion also works 353 expected_extra_files["foo2"] = "bar" 354 m = MyMod() 355 356 # Save to file. 357 with TemporaryFileName() as fname: 358 m.save(fname, _extra_files=expected_extra_files) 359 # values don't matter 360 extra_files = {"foo": "", "foo2": None} 361 torch.jit.load(fname, _extra_files=extra_files) 362 self.assertEqual(value, extra_files["foo"]) 363 # results come back always as bytes 364 self.assertEqual(b"bar", extra_files["foo2"]) 365 366 # Use torch.jit API 367 torch.jit.save(m, fname, _extra_files=expected_extra_files) 368 extra_files["foo"] = "" 369 torch.jit.load(fname, _extra_files=extra_files) 370 self.assertEqual(value, extra_files["foo"]) 371 372 # Save to buffer. 373 buffer = io.BytesIO(m.save_to_buffer(_extra_files=expected_extra_files)) 374 extra_files = {"foo": ""} 375 torch.jit.load(buffer, _extra_files=extra_files) 376 self.assertEqual(value, extra_files["foo"]) 377 378 # Use torch.jit API 379 buffer = io.BytesIO() 380 torch.jit.save(m, buffer, _extra_files=expected_extra_files) 381 buffer.seek(0) 382 extra_files = {"foo": ""} 383 torch.jit.load(buffer, _extra_files=extra_files) 384 self.assertEqual(value, extra_files["foo"]) 385 386 # Non-existent file 'bar' 387 with self.assertRaises(RuntimeError): 388 extra_files["bar"] = "" 389 torch.jit.load(buffer, _extra_files=extra_files) 390 391 def test_save_load_using_pathlib(self): 392 class MyMod(torch.jit.ScriptModule): 393 @torch.jit.script_method 394 def forward(self, a): 395 return 2 * a 396 397 m = MyMod() 398 399 # Save then load. 400 with TemporaryFileName() as fname: 401 path = Path(fname) 402 m.save(path) 403 m2 = torch.jit.load(path) 404 405 x = torch.tensor([1.0, 2.0, 3.0, 4.0]) 406 self.assertTrue(torch.equal(m(x), m2(x))) 407 408 def test_save_nonexit_file(self): 409 class Foo(torch.nn.Module): 410 def forward(self, x): 411 return 2 * x 412 413 script_module = torch.jit.script(Foo()) 414 with self.assertRaises(RuntimeError): 415 script_module.save("NonExist/path/test.pt") 416 417 def test_save_namedtuple_input_only(self): 418 """ 419 Even if a NamedTuple is only used as an input argument, saving and 420 loading should work correctly. 421 """ 422 global FooTuple # see [local resolution in python] 423 424 class FooTuple(NamedTuple): 425 a: int 426 427 class MyModule(torch.nn.Module): 428 def forward(self, x: FooTuple) -> torch.Tensor: 429 return torch.tensor(3) 430 431 m_loaded = self.getExportImportCopy(torch.jit.script(MyModule())) 432 output = m_loaded(FooTuple(a=5)) 433 self.assertEqual(output, torch.tensor(3)) 434 435 def test_save_namedtuple_input_only_forwardref(self): 436 """ 437 Even if a NamedTuple is only used as an input argument, saving and 438 loading should work correctly. 439 """ 440 global FooTuple # see [local resolution in python] 441 442 class FooTuple(NamedTuple): 443 a: "int" 444 445 class MyModule(torch.nn.Module): 446 def forward(self, x: FooTuple) -> torch.Tensor: 447 return torch.tensor(3) 448 449 m_loaded = self.getExportImportCopy(torch.jit.script(MyModule())) 450 output = m_loaded(FooTuple(a=5)) 451 self.assertEqual(output, torch.tensor(3)) 452 453 def test_save_namedtuple_output_only(self): 454 """ 455 Even if a NamedTuple is only used as an output argument, saving and 456 loading should work correctly. 457 """ 458 global FooTuple # see [local resolution in python] 459 460 class FooTuple(NamedTuple): 461 a: int 462 463 class MyModule(torch.nn.Module): 464 def forward(self) -> Optional[FooTuple]: 465 return None 466 467 m_loaded = self.getExportImportCopy(torch.jit.script(MyModule())) 468 output = m_loaded() 469 self.assertEqual(output, None) 470 471 def test_save_load_params_buffers_submodules(self): 472 """ 473 Check that parameters, buffers, and submodules are the same after loading. 474 """ 475 476 class Submodule(torch.nn.Module): 477 pass 478 479 class TestModule(torch.nn.Module): 480 def __init__(self) -> None: 481 super().__init__() 482 self.add_module("submodule_a", Submodule()) 483 self.register_parameter( 484 "parameter_a", torch.nn.Parameter(torch.randn(4)) 485 ) 486 self.buffer = torch.nn.Buffer(torch.randn(4)) 487 self.t = torch.rand(4) # not buffer 488 489 self.parameter_b = torch.nn.Parameter(torch.randn(4)) 490 self.submodule_b = Submodule() 491 self.buffer_b = torch.nn.Buffer(torch.randn(4)) 492 493 m = TestModule() 494 m_loaded = self.getExportImportCopy(torch.jit.script(m)) 495 496 # Check submodules. 497 self.assertEqual( 498 len(list(m.named_modules())), len(list(m_loaded.named_modules())) 499 ) 500 for m_s, loaded_s in zip(m.named_modules(), m_loaded.named_modules()): 501 m_name, _ = m_s 502 loaded_name, _ = loaded_s 503 self.assertEqual(m_name, loaded_name) 504 505 # Check parameters. 506 self.assertEqual(len(list(m.parameters())), len(list(m_loaded.parameters()))) 507 for m_p, loaded_p in zip(m.parameters(), m_loaded.parameters()): 508 self.assertEqual(m_p, loaded_p) 509 510 # Check buffers. 511 self.assertEqual( 512 len(list(m.named_buffers())), len(list(m_loaded.named_buffers())) 513 ) 514 for m_b, loaded_b in zip(m.named_buffers(), m_loaded.named_buffers()): 515 m_name, m_buffer = m_b 516 loaded_name, loaded_buffer = loaded_b 517 self.assertEqual(m_name, loaded_name) 518 self.assertEqual(m_buffer, loaded_buffer) 519 520 def test_save_load_meta_tensors(self): 521 """ 522 Check that parameters, buffers, and submodules are the same after loading 523 for a module with parameters and buffers that are meta tensors 524 """ 525 526 class Foo(torch.nn.Module): 527 def __init__(self) -> None: 528 super().__init__() 529 self.foo = torch.nn.Linear(2, 3, device="meta") 530 self.bar = torch.nn.Linear(3, 4) 531 self.buffer = torch.nn.Buffer(torch.randn(4, device="meta")) 532 533 def forward(self, x): 534 x = self.foo(x) 535 x = self.bar(x) 536 return x 537 538 m = Foo() 539 m_loaded = self.getExportImportCopy(torch.jit.script(m)) 540 # Check submodules. 541 self.assertEqual( 542 len(list(m.named_modules())), len(list(m_loaded.named_modules())) 543 ) 544 self.assertEqual( 545 {name for name, _ in m.named_modules()}, 546 {name for name, _ in m_loaded.named_modules()}, 547 ) 548 # Check parameters. 549 m_params = dict(m.named_parameters()) 550 m_loaded_params = dict(m_loaded.named_parameters()) 551 self.assertEqual(len(m_params), len(m_loaded_params)) 552 self.assertEqual(m_params, m_loaded_params) 553 # Check buffers. 554 m_buffers = dict(m.named_buffers()) 555 m_loaded_buffers = dict(m_loaded.named_buffers()) 556 self.assertEqual(len(m_buffers), len(m_loaded_buffers)) 557 self.assertEqual(m_buffers, m_loaded_buffers) 558 # Check params and buffers that are/are not meta tensors 559 self.assertTrue(m_params["foo.weight"].is_meta) 560 self.assertTrue(m_loaded_params["foo.weight"].is_meta) 561 self.assertTrue(m_params["foo.bias"].is_meta) 562 self.assertTrue(m_loaded_params["foo.bias"].is_meta) 563 self.assertFalse(m_params["bar.weight"].is_meta) 564 self.assertFalse(m_loaded_params["bar.weight"].is_meta) 565 self.assertFalse(m_params["bar.bias"].is_meta) 566 self.assertFalse(m_loaded_params["bar.bias"].is_meta) 567 self.assertTrue(m_buffers["buffer"].is_meta) 568 self.assertTrue(m_loaded_buffers["buffer"].is_meta) 569 570 def test_save_load_meta_tensors_to_device(self): 571 """ 572 Check that when loading a module with meta tensors to device, the meta tensors 573 stay on meta, but non-meta tensors are set to the indicated device. 574 """ 575 576 class Foo(torch.nn.Module): 577 def __init__(self) -> None: 578 super().__init__() 579 self.foo = torch.nn.Linear(2, 3, device="meta") 580 self.bar = torch.nn.Linear(3, 4) 581 582 def forward(self, x): 583 x = self.foo(x) 584 x = self.bar(x) 585 return x 586 587 m = Foo() 588 589 m_loaded = self.getExportImportCopy(torch.jit.script(m), map_location="cpu") 590 # Check submodules. 591 self.assertEqual( 592 len(list(m.named_modules())), len(list(m_loaded.named_modules())) 593 ) 594 self.assertEqual( 595 {name for name, _ in m.named_modules()}, 596 {name for name, _ in m_loaded.named_modules()}, 597 ) 598 # Check parameters. 599 m_params = dict(m.named_parameters()) 600 m_loaded_params = dict(m_loaded.named_parameters()) 601 self.assertEqual(len(m_params), len(m_loaded_params)) 602 self.assertEqual(m_params, m_loaded_params) 603 # Check params and buffers that are/are not meta tensors 604 self.assertTrue(m_params["foo.weight"].is_meta) 605 self.assertTrue(m_loaded_params["foo.weight"].is_meta) 606 self.assertTrue(m_params["foo.bias"].is_meta) 607 self.assertTrue(m_loaded_params["foo.bias"].is_meta) 608 self.assertTrue(m_params["bar.weight"].is_cpu) 609 self.assertTrue(m_loaded_params["bar.weight"].is_cpu) 610 self.assertTrue(m_params["bar.bias"].is_cpu) 611 self.assertTrue(m_loaded_params["bar.bias"].is_cpu) 612 613 def test_save_load_with_saved_traced_inputs(self): 614 """ 615 Check that saving and loading with traced inputs works as expected 616 """ 617 618 class Module(torch.nn.Module): 619 def __init__(self) -> None: 620 super().__init__() 621 622 def forward(self, x): 623 return torch.ones(1) 624 625 def get_loaded_inputs(inputs): 626 traced_module = torch.jit.trace(module, input1) 627 traced_inputs = list(traced_module.graph.inputs()) 628 with TemporaryFileName() as fname: 629 path = Path(fname) 630 traced_module.save(path) 631 print(traced_module.graph) 632 loaded_module = torch.jit.load(path, _restore_shapes=True) 633 print(loaded_module.graph) 634 return traced_inputs, list(loaded_module.graph.inputs()) 635 636 module = Module() 637 input_tensor = torch.rand(1, 3, 24, 24) 638 # Validate that with no input specified the traced inputs are stored 639 traced_module = torch.jit.trace(module, input_tensor) 640 traced_inputs = list(traced_module.graph.inputs()) 641 self.assertEqual( 642 traced_module._c._retrieve_traced_inputs()["forward"], [input_tensor] 643 ) 644 with TemporaryFileName() as fname: 645 path = Path(fname) 646 traced_module.save(path) 647 loaded_module = torch.jit.load(path, _restore_shapes=True) 648 loaded_inputs = list(loaded_module.graph.inputs()) 649 self.assertEqual(traced_inputs[1].type(), loaded_inputs[1].type()) 650 self.assertEqual( 651 traced_inputs[1].type().sizes(), loaded_inputs[1].type().sizes() 652 ) 653 # Validate that if no shapes are requested previous functionality remains 654 loaded_module = torch.jit.load(path) 655 loaded_inputs = list(loaded_module.graph.inputs()) 656 self.assertEqual(loaded_inputs[1].type().sizes(), None) 657 658 # Validate that inputs aren't saved when requested not to 659 traced_module = torch.jit.trace(module, input_tensor, _store_inputs=False) 660 traced_inputs = list(traced_module.graph.inputs()) 661 self.assertEqual(len(traced_module._c._retrieve_traced_inputs()), 0) 662 663 with TemporaryFileName() as fname: 664 path = Path(fname) 665 traced_module.save(path) 666 loaded_module = torch.jit.load(path, _restore_shapes=True) 667 loaded_inputs = list(loaded_module.graph.inputs()) 668 self.assertEqual(loaded_inputs[1].type().sizes(), None) 669 # Validate that if no shapes are requested previous functionality remains 670 loaded_module = torch.jit.load(path) 671 loaded_inputs = list(loaded_module.graph.inputs()) 672 self.assertEqual(loaded_inputs[1].type().sizes(), None) 673 674 # Validate that complex inputs work 675 # Testing dict of list with empty tensors 676 input1 = { 677 "1000": ( 678 torch.tensor([0]), 679 torch.tensor([], dtype=torch.int64), 680 torch.tensor([]), 681 ) 682 } 683 traced_inputs, loaded_inputs = get_loaded_inputs(input1) 684 self.assertEqual(traced_inputs[1].type(), loaded_inputs[1].type()) 685 686 # Testing dict of list 687 input2 = { 688 "1000": ( 689 torch.tensor([0]), 690 torch.tensor([1500000, 1500004], dtype=torch.int64), 691 torch.tensor([2.0, 3.0]), 692 ) 693 } 694 traced_inputs, loaded_inputs = get_loaded_inputs(input2) 695 self.assertEqual(traced_inputs[1].type(), loaded_inputs[1].type()) 696 697 # Testing list 698 input3 = [ 699 torch.tensor([0]), 700 torch.tensor([1500000, 1500004], dtype=torch.int64), 701 torch.tensor([2.0, 3.0]), 702 ] 703 704 traced_inputs, loaded_inputs = get_loaded_inputs(input3) 705 self.assertEqual(traced_inputs[1].type(), loaded_inputs[1].type()) 706 707 # Testing list of dict of list 708 input4 = [ 709 { 710 "1000": ( 711 torch.tensor([0]), 712 torch.tensor([1500000, 1500004], dtype=torch.int64), 713 torch.tensor([2.0, 3.0]), 714 ) 715 } 716 ] 717 718 traced_inputs, loaded_inputs = get_loaded_inputs(input4) 719 self.assertEqual(traced_inputs[1].type(), loaded_inputs[1].type()) 720 721 @skipIfTorchDynamo("too slow") 722 def test_save_load_large_string_attribute(self): 723 """ 724 Check if the model with string > 4GB can be loaded. 725 """ 726 import psutil 727 728 if psutil.virtual_memory().available < 60 * 1024 * 1024 * 1024: 729 # Profiled the test execution, and got this number to be safe to run the test 730 self.skipTest( 731 "Doesn't have enough memory to run test_save_load_large_string_attribute" 732 ) 733 734 class Model(torch.nn.Module): 735 def __init__(self) -> None: 736 super().__init__() 737 self.x = "x" * (2**32 + 1) 738 739 def forward(self, i) -> int: 740 return len(self.x) + i.numel() 741 742 inp = torch.ones(0) 743 ts = torch.jit.script(Model()) 744 ts_output = ts(inp) 745 746 b = io.BytesIO(ts.save_to_buffer()) 747 del ts 748 749 loaded_ts = torch.jit.load(b) 750 del b 751 loaded_output = loaded_ts(inp) 752 self.assertEqual(ts_output, loaded_output) 753 754 755def script_module_to_buffer(script_module): 756 module_buffer = io.BytesIO( 757 script_module._save_to_buffer_for_lite_interpreter(_use_flatbuffer=True) 758 ) 759 module_buffer.seek(0) 760 return module_buffer 761 762 763class TestSaveLoadFlatbuffer(JitTestCase): 764 def test_different_modules(self): 765 """ 766 Exercise the situation where we have the same qualified name 767 in two different CompilationUnits on save/load. 768 """ 769 770 class Foo(torch.nn.Module): 771 def __init__(self) -> None: 772 super().__init__() 773 self.foo = torch.nn.Linear(2, 2) 774 self.bar = torch.nn.Linear(2, 2) 775 776 def forward(self, x): 777 x = self.foo(x) 778 x = self.bar(x) 779 return x 780 781 first_script_module = torch.jit.script(Foo()) 782 first_saved_module = script_module_to_buffer(first_script_module) 783 784 clear_class_registry() 785 786 class Foo(torch.nn.Module): 787 def __init__(self) -> None: 788 super().__init__() 789 self.foo = torch.nn.Linear(2, 2) 790 791 def forward(self, x): 792 x = self.foo(x) 793 return x 794 795 second_script_module = torch.jit.script(Foo()) 796 second_saved_module = script_module_to_buffer(second_script_module) 797 798 clear_class_registry() 799 800 self.assertEqual( 801 first_script_module._c.qualified_name, 802 second_script_module._c.qualified_name, 803 ) 804 805 class ContainsBoth(torch.nn.Module): 806 def __init__(self) -> None: 807 super().__init__() 808 self.add_module("second", torch.jit.load(second_saved_module)) 809 self.add_module("first", torch.jit.load(first_saved_module)) 810 811 def forward(self, x): 812 x = self.first(x) 813 x = self.second(x) 814 return x 815 816 sm = torch.jit.script(ContainsBoth()) 817 contains_both = script_module_to_buffer(sm) 818 sm = torch.jit.load(contains_both) 819 820 def test_different_functions(self): 821 """ 822 Exercise the situation where we have the same qualified name 823 in two different CompilationUnits on save/load. 824 """ 825 826 def lol(x): 827 return x 828 829 class Foo(torch.nn.Module): 830 def forward(self, x): 831 return lol(x) 832 833 first_script_module = torch.jit.script(Foo()) 834 first_saved_module = script_module_to_buffer(first_script_module) 835 clear_class_registry() 836 837 def lol(x): # noqa: F811 838 return "hello" 839 840 class Foo(torch.nn.Module): 841 def forward(self, x): 842 return lol(x) 843 844 second_script_module = torch.jit.script(Foo()) 845 second_saved_module = script_module_to_buffer(second_script_module) 846 847 clear_class_registry() 848 849 self.assertEqual( 850 first_script_module._c.qualified_name, 851 second_script_module._c.qualified_name, 852 ) 853 854 class ContainsBoth(torch.nn.Module): 855 def __init__(self) -> None: 856 super().__init__() 857 self.add_module("second", torch.jit.load(second_saved_module)) 858 self.add_module("first", torch.jit.load(first_saved_module)) 859 860 def forward(self, x): 861 x = self.first(x) 862 x = self.second(x) 863 return x 864 865 sm = torch.jit.script(ContainsBoth()) 866 contains_both = script_module_to_buffer(sm) 867 sm = torch.jit.load(contains_both) 868 869 def test_different_interfaces(self): 870 """ 871 Exercise the situation where we have the same qualified name 872 in two different CompilationUnits on save/load. 873 """ 874 875 @torch.jit.interface 876 class MyInterface: 877 def bar(self, x: Tensor) -> Tensor: 878 pass 879 880 @torch.jit.script 881 class ImplementInterface: 882 def __init__(self) -> None: 883 pass 884 885 def bar(self, x): 886 return x 887 888 class Foo(torch.nn.Module): 889 __annotations__ = {"interface": MyInterface} 890 891 def __init__(self) -> None: 892 super().__init__() 893 self.interface = ImplementInterface() 894 895 def forward(self, x): 896 return self.interface.bar(x) 897 898 first_script_module = torch.jit.script(Foo()) 899 first_saved_module = script_module_to_buffer(first_script_module) 900 clear_class_registry() 901 902 @torch.jit.interface 903 class MyInterface: 904 def not_bar(self, x: Tensor) -> Tensor: 905 pass 906 907 @torch.jit.script # noqa: F811 908 class ImplementInterface: # noqa: F811 909 def __init__(self) -> None: 910 pass 911 912 def not_bar(self, x): 913 return x 914 915 class Foo(torch.nn.Module): 916 __annotations__ = {"interface": MyInterface} 917 918 def __init__(self) -> None: 919 super().__init__() 920 self.interface = ImplementInterface() 921 922 def forward(self, x): 923 return self.interface.not_bar(x) 924 925 second_script_module = torch.jit.script(Foo()) 926 second_saved_module = script_module_to_buffer(second_script_module) 927 928 clear_class_registry() 929 930 self.assertEqual( 931 first_script_module._c.qualified_name, 932 second_script_module._c.qualified_name, 933 ) 934 935 class ContainsBoth(torch.nn.Module): 936 def __init__(self) -> None: 937 super().__init__() 938 self.add_module("second", torch.jit.load(second_saved_module)) 939 self.add_module("first", torch.jit.load(first_saved_module)) 940 941 def forward(self, x): 942 x = self.first(x) 943 x = self.second(x) 944 return x 945 946 sm = torch.jit.script(ContainsBoth()) 947 contains_both = script_module_to_buffer(sm) 948 sm = torch.jit.load(contains_both) 949 950 def test_many_collisions(self): 951 class MyCoolNamedTuple(NamedTuple): 952 a: int 953 954 @torch.jit.interface 955 class MyInterface: 956 def bar(self, x: Tensor) -> Tensor: 957 pass 958 959 @torch.jit.script 960 class ImplementInterface: 961 def __init__(self) -> None: 962 pass 963 964 def bar(self, x): 965 return x 966 967 def lol(x): 968 return x 969 970 class Foo(torch.nn.Module): 971 interface: MyInterface 972 973 def __init__(self) -> None: 974 super().__init__() 975 self.foo = torch.nn.Linear(2, 2) 976 self.bar = torch.nn.Linear(2, 2) 977 self.interface = ImplementInterface() 978 979 def forward(self, x): 980 x = self.foo(x) 981 x = self.bar(x) 982 x = lol(x) 983 x = self.interface.bar(x) 984 985 return x, MyCoolNamedTuple(a=5) 986 987 first_script_module = torch.jit.script(Foo()) 988 first_saved_module = script_module_to_buffer(first_script_module) 989 990 clear_class_registry() 991 992 @torch.jit.interface 993 class MyInterface: 994 def not_bar(self, x: Tensor) -> Tensor: 995 pass 996 997 @torch.jit.script # noqa: F811 998 class ImplementInterface: # noqa: F811 999 def __init__(self) -> None: 1000 pass 1001 1002 def not_bar(self, x): 1003 return x 1004 1005 def lol(x): # noqa: F811 1006 return "asdofij" 1007 1008 class MyCoolNamedTuple(NamedTuple): # noqa: F811 1009 a: str 1010 1011 class Foo(torch.nn.Module): 1012 interface: MyInterface 1013 1014 def __init__(self) -> None: 1015 super().__init__() 1016 self.foo = torch.nn.Linear(2, 2) 1017 self.interface = ImplementInterface() 1018 1019 def forward(self, x): 1020 x = self.foo(x) 1021 self.interface.not_bar(x) 1022 x = lol(x) 1023 return x, MyCoolNamedTuple(a="hello") 1024 1025 second_script_module = torch.jit.script(Foo()) 1026 second_saved_module = script_module_to_buffer(second_script_module) 1027 1028 clear_class_registry() 1029 1030 self.assertEqual( 1031 first_script_module._c.qualified_name, 1032 second_script_module._c.qualified_name, 1033 ) 1034 1035 class ContainsBoth(torch.nn.Module): 1036 def __init__(self) -> None: 1037 super().__init__() 1038 self.add_module("second", torch.jit.load(second_saved_module)) 1039 self.add_module("first", torch.jit.load(first_saved_module)) 1040 1041 def forward(self, x): 1042 x, named_tuple_1 = self.first(x) 1043 x, named_tuple_2 = self.second(x) 1044 return len(x + named_tuple_2.a) + named_tuple_1.a 1045 1046 sm = torch.jit.script(ContainsBoth()) 1047 contains_both = script_module_to_buffer(sm) 1048 sm = torch.jit.load(contains_both) 1049 1050 def test_save_load_using_pathlib(self): 1051 class MyMod(torch.jit.ScriptModule): 1052 @torch.jit.script_method 1053 def forward(self, a): 1054 return 2 * a 1055 1056 m = MyMod() 1057 1058 # Save then load. 1059 with TemporaryFileName() as fname: 1060 path = Path(fname) 1061 torch.jit.save_jit_module_to_flatbuffer(m, path) 1062 m2 = torch.jit.load(path) 1063 1064 x = torch.tensor([1.0, 2.0, 3.0, 4.0]) 1065 self.assertTrue(torch.equal(m(x), m2(x))) 1066 1067 def test_save_namedtuple_input_only(self): 1068 """ 1069 Even if a NamedTuple is only used as an input argument, saving and 1070 loading should work correctly. 1071 """ 1072 global FooTuple # see [local resolution in python] 1073 1074 class FooTuple(NamedTuple): 1075 a: int 1076 1077 class MyModule(torch.nn.Module): 1078 def forward(self, x: FooTuple) -> torch.Tensor: 1079 return torch.tensor(3) 1080 1081 m_loaded = self.getExportImportCopy(torch.jit.script(MyModule())) 1082 output = m_loaded(FooTuple(a=5)) 1083 self.assertEqual(output, torch.tensor(3)) 1084 1085 def test_save_namedtuple_output_only(self): 1086 """ 1087 Even if a NamedTuple is only used as an output argument, saving and 1088 loading should work correctly. 1089 """ 1090 global FooTuple # see [local resolution in python] 1091 1092 class FooTuple(NamedTuple): 1093 a: int 1094 1095 class MyModule(torch.nn.Module): 1096 def forward(self) -> Optional[FooTuple]: 1097 return None 1098 1099 m_loaded = self.getExportImportCopy(torch.jit.script(MyModule())) 1100 output = m_loaded() 1101 self.assertEqual(output, None) 1102 1103 def test_module_info_flatbuffer(self): 1104 class Foo(torch.nn.Module): 1105 def __init__(self) -> None: 1106 super().__init__() 1107 self.foo = torch.nn.Linear(2, 2) 1108 self.bar = torch.nn.Linear(2, 2) 1109 1110 def forward(self, x): 1111 x = self.foo(x) 1112 x = self.bar(x) 1113 return x 1114 1115 first_script_module = torch.jit.script(Foo()) 1116 first_saved_module = io.BytesIO() 1117 torch.jit.save_jit_module_to_flatbuffer(first_script_module, first_saved_module) 1118 first_saved_module.seek(0) 1119 ff_info = torch.jit._serialization.get_flatbuffer_module_info( 1120 first_saved_module 1121 ) 1122 self.assertEqual(ff_info["bytecode_version"], 9) 1123 self.assertEqual(ff_info["operator_version"], 1) 1124 self.assertEqual(ff_info["type_names"], set()) 1125 self.assertEqual(ff_info["opname_to_num_args"], {"aten::linear": 3}) 1126 1127 self.assertEqual(len(ff_info["function_names"]), 1) 1128 self.assertTrue(next(iter(ff_info["function_names"])).endswith("forward")) 1129 1130 def test_save_load_params_buffers_submodules(self): 1131 """ 1132 Check that parameters, buffers, and submodules are the same after loading. 1133 """ 1134 1135 class Submodule(torch.nn.Module): 1136 pass 1137 1138 class TestModule(torch.nn.Module): 1139 def __init__(self) -> None: 1140 super().__init__() 1141 self.add_module("submodule_a", Submodule()) 1142 self.register_parameter( 1143 "parameter_a", torch.nn.Parameter(torch.randn(4)) 1144 ) 1145 self.buffer = torch.nn.Buffer(torch.randn(4)) 1146 self.t = torch.rand(4) # not buffer 1147 1148 self.parameter_b = torch.nn.Parameter(torch.randn(4)) 1149 self.submodule_b = Submodule() 1150 self.buffer_b = torch.nn.Buffer(torch.randn(4)) 1151 1152 m = TestModule() 1153 m_loaded = self.getExportImportCopy(torch.jit.script(m)) 1154 1155 # Check submodules. 1156 self.assertEqual( 1157 len(list(m.named_modules())), len(list(m_loaded.named_modules())) 1158 ) 1159 for m_s, loaded_s in zip(m.named_modules(), m_loaded.named_modules()): 1160 m_name, _ = m_s 1161 loaded_name, _ = loaded_s 1162 self.assertEqual(m_name, loaded_name) 1163 1164 # Check parameters. 1165 self.assertEqual(len(list(m.parameters())), len(list(m_loaded.parameters()))) 1166 for m_p, loaded_p in zip(m.parameters(), m_loaded.parameters()): 1167 self.assertEqual(m_p, loaded_p) 1168 1169 # Check buffers. 1170 self.assertEqual( 1171 len(list(m.named_buffers())), len(list(m_loaded.named_buffers())) 1172 ) 1173 for m_b, loaded_b in zip(m.named_buffers(), m_loaded.named_buffers()): 1174 m_name, m_buffer = m_b 1175 loaded_name, loaded_buffer = loaded_b 1176 self.assertEqual(m_name, loaded_name) 1177 self.assertEqual(m_buffer, loaded_buffer) 1178 1179 def test_save_load_with_extra_files(self): 1180 """ 1181 Check that parameters, buffers, and submodules are the same after loading. 1182 """ 1183 1184 class Module(torch.nn.Module): 1185 def forward(self, x: Tensor): 1186 return x 1187 1188 module = Module() 1189 script_module = torch.jit.script(module) 1190 1191 extra_files = {"abc.json": b"[1,2,3]"} 1192 script_module_io = script_module._save_to_buffer_for_lite_interpreter( 1193 _extra_files=extra_files, _use_flatbuffer=True 1194 ) 1195 1196 re_extra_files = {} 1197 torch._C._get_model_extra_files_from_buffer(script_module_io, re_extra_files) 1198 1199 self.assertEqual(extra_files, re_extra_files) 1200