# Owner(s): ["oncall: jit"]

import io
import os
import sys
from pathlib import Path
from typing import NamedTuple, Optional

import torch
from torch import Tensor
from torch.testing._internal.common_utils import skipIfTorchDynamo, TemporaryFileName


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import clear_class_registry, JitTestCase


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestSaveLoad(JitTestCase):
    def test_different_modules(self):
        """
        Exercise the situation where we have the same qualified name
        in two different CompilationUnits on save/load.
        """

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Linear(2, 2)
                self.bar = torch.nn.Linear(2, 2)

            def forward(self, x):
                x = self.foo(x)
                x = self.bar(x)
                return x

        first_script_module = torch.jit.script(Foo())
        first_saved_module = io.BytesIO()
        torch.jit.save(first_script_module, first_saved_module)
        first_saved_module.seek(0)

        clear_class_registry()

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Linear(2, 2)

            def forward(self, x):
                x = self.foo(x)
                return x

        second_script_module = torch.jit.script(Foo())
        second_saved_module = io.BytesIO()
        torch.jit.save(second_script_module, second_saved_module)
        second_saved_module.seek(0)

        clear_class_registry()

        self.assertEqual(
            first_script_module._c.qualified_name,
            second_script_module._c.qualified_name,
        )

        class ContainsBoth(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.add_module("second", torch.jit.load(second_saved_module))
                self.add_module("first", torch.jit.load(first_saved_module))

            def forward(self, x):
                x = self.first(x)
                x = self.second(x)
                return x

        sm = torch.jit.script(ContainsBoth())
        contains_both = io.BytesIO()
        torch.jit.save(sm, contains_both)
        contains_both.seek(0)
        sm = torch.jit.load(contains_both)

    def test_different_functions(self):
        """
        Exercise the situation where we have the same qualified name
        in two different CompilationUnits on save/load.
        """

        def lol(x):
            return x

        class Foo(torch.nn.Module):
            def forward(self, x):
                return lol(x)

        first_script_module = torch.jit.script(Foo())
        first_saved_module = io.BytesIO()
        torch.jit.save(first_script_module, first_saved_module)
        first_saved_module.seek(0)

        clear_class_registry()

        def lol(x):  # noqa: F811
            return "hello"

        class Foo(torch.nn.Module):
            def forward(self, x):
                return lol(x)

        second_script_module = torch.jit.script(Foo())
        second_saved_module = io.BytesIO()
        torch.jit.save(second_script_module, second_saved_module)
        second_saved_module.seek(0)

        clear_class_registry()

        self.assertEqual(
            first_script_module._c.qualified_name,
            second_script_module._c.qualified_name,
        )

        class ContainsBoth(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.add_module("second", torch.jit.load(second_saved_module))
                self.add_module("first", torch.jit.load(first_saved_module))

            def forward(self, x):
                x = self.first(x)
                x = self.second(x)
                return x

        sm = torch.jit.script(ContainsBoth())
        contains_both = io.BytesIO()
        torch.jit.save(sm, contains_both)
        contains_both.seek(0)
        sm = torch.jit.load(contains_both)

    def test_different_interfaces(self):
        """
        Exercise the situation where we have the same qualified name
        in two different CompilationUnits on save/load.
        """

        @torch.jit.interface
        class MyInterface:
            def bar(self, x: Tensor) -> Tensor:
                pass

        @torch.jit.script
        class ImplementInterface:
            def __init__(self) -> None:
                pass

            def bar(self, x):
                return x

        class Foo(torch.nn.Module):
            __annotations__ = {"interface": MyInterface}

            def __init__(self) -> None:
                super().__init__()
                self.interface = ImplementInterface()

            def forward(self, x):
                return self.interface.bar(x)

        first_script_module = torch.jit.script(Foo())
        first_saved_module = io.BytesIO()
        torch.jit.save(first_script_module, first_saved_module)
        first_saved_module.seek(0)

        clear_class_registry()

        @torch.jit.interface
        class MyInterface:
            def not_bar(self, x: Tensor) -> Tensor:
                pass

        @torch.jit.script  # noqa: F811
        class ImplementInterface:  # noqa: F811
            def __init__(self) -> None:
                pass

            def not_bar(self, x):
                return x

        class Foo(torch.nn.Module):
            __annotations__ = {"interface": MyInterface}

            def __init__(self) -> None:
                super().__init__()
                self.interface = ImplementInterface()

            def forward(self, x):
                return self.interface.not_bar(x)

        second_script_module = torch.jit.script(Foo())
        second_saved_module = io.BytesIO()
        torch.jit.save(second_script_module, second_saved_module)
        second_saved_module.seek(0)

        clear_class_registry()

        self.assertEqual(
            first_script_module._c.qualified_name,
            second_script_module._c.qualified_name,
        )

        class ContainsBoth(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.add_module("second", torch.jit.load(second_saved_module))
                self.add_module("first", torch.jit.load(first_saved_module))

            def forward(self, x):
                x = self.first(x)
                x = self.second(x)
                return x

        sm = torch.jit.script(ContainsBoth())
        contains_both = io.BytesIO()
        torch.jit.save(sm, contains_both)
        contains_both.seek(0)
        sm = torch.jit.load(contains_both)

    def test_many_collisions(self):
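        """
        Exercise qualified-name collisions across several kinds of entities at
        once: modules, interfaces, script classes, free functions, and
        NamedTuples that share names but have different definitions.
        """
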
        class MyCoolNamedTuple(NamedTuple):
            a: int

        @torch.jit.interface
        class MyInterface:
            def bar(self, x: Tensor) -> Tensor:
                pass

        @torch.jit.script
        class ImplementInterface:
            def __init__(self) -> None:
                pass

            def bar(self, x):
                return x

        def lol(x):
            return x

        class Foo(torch.nn.Module):
            interface: MyInterface

            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Linear(2, 2)
                self.bar = torch.nn.Linear(2, 2)
                self.interface = ImplementInterface()

            def forward(self, x):
                x = self.foo(x)
                x = self.bar(x)
                x = lol(x)
                x = self.interface.bar(x)

                return x, MyCoolNamedTuple(a=5)

        first_script_module = torch.jit.script(Foo())
        first_saved_module = io.BytesIO()
        torch.jit.save(first_script_module, first_saved_module)
        first_saved_module.seek(0)

        clear_class_registry()

        @torch.jit.interface
        class MyInterface:
            def not_bar(self, x: Tensor) -> Tensor:
                pass

        @torch.jit.script  # noqa: F811
        class ImplementInterface:  # noqa: F811
            def __init__(self) -> None:
                pass

            def not_bar(self, x):
                return x

        def lol(x):  # noqa: F811
            return "asdofij"

        class MyCoolNamedTuple(NamedTuple):  # noqa: F811
            a: str

        class Foo(torch.nn.Module):
            interface: MyInterface

            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Linear(2, 2)
                self.interface = ImplementInterface()

            def forward(self, x):
                x = self.foo(x)
                self.interface.not_bar(x)
                x = lol(x)
                return x, MyCoolNamedTuple(a="hello")

        second_script_module = torch.jit.script(Foo())
        second_saved_module = io.BytesIO()
        torch.jit.save(second_script_module, second_saved_module)
        second_saved_module.seek(0)

        clear_class_registry()

        self.assertEqual(
            first_script_module._c.qualified_name,
            second_script_module._c.qualified_name,
        )

        class ContainsBoth(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.add_module("second", torch.jit.load(second_saved_module))
                self.add_module("first", torch.jit.load(first_saved_module))

            def forward(self, x):
                x, named_tuple_1 = self.first(x)
                x, named_tuple_2 = self.second(x)
                return len(x + named_tuple_2.a) + named_tuple_1.a

        sm = torch.jit.script(ContainsBoth())
        contains_both = io.BytesIO()
        torch.jit.save(sm, contains_both)
        contains_both.seek(0)
        sm = torch.jit.load(contains_both)

    def test_save_load_with_extra_files(self):
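        """
        Check that arbitrary extra files (including binary data) round-trip
        through both the module-level and torch.jit save/load APIs, for file
        paths as well as in-memory buffers.
        """
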
        class MyMod(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, a):
                return a

        # specifically test binary data
        value = b"bar\x00\xffbaz"

        expected_extra_files = {}
        expected_extra_files["foo"] = value
        # verify that str to bytes conversion also works
        expected_extra_files["foo2"] = "bar"
        m = MyMod()

        # Save to file.
        with TemporaryFileName() as fname:
            m.save(fname, _extra_files=expected_extra_files)
            # values don't matter
            extra_files = {"foo": "", "foo2": None}
            torch.jit.load(fname, _extra_files=extra_files)
            self.assertEqual(value, extra_files["foo"])
            # results come back always as bytes
            self.assertEqual(b"bar", extra_files["foo2"])

            # Use torch.jit API
            torch.jit.save(m, fname, _extra_files=expected_extra_files)
            extra_files["foo"] = ""
            torch.jit.load(fname, _extra_files=extra_files)
            self.assertEqual(value, extra_files["foo"])

        # Save to buffer.
        buffer = io.BytesIO(m.save_to_buffer(_extra_files=expected_extra_files))
        extra_files = {"foo": ""}
        torch.jit.load(buffer, _extra_files=extra_files)
        self.assertEqual(value, extra_files["foo"])

        # Use torch.jit API
        buffer = io.BytesIO()
        torch.jit.save(m, buffer, _extra_files=expected_extra_files)
        buffer.seek(0)
        extra_files = {"foo": ""}
        torch.jit.load(buffer, _extra_files=extra_files)
        self.assertEqual(value, extra_files["foo"])

        # Non-existent file 'bar'
        with self.assertRaises(RuntimeError):
            extra_files["bar"] = ""
            torch.jit.load(buffer, _extra_files=extra_files)

    def test_save_load_using_pathlib(self):
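        """
        Saving and loading should accept pathlib.Path objects in addition to
        string filenames.
        """
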
        class MyMod(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, a):
                return 2 * a

        m = MyMod()

        # Save then load.
        with TemporaryFileName() as fname:
            path = Path(fname)
            m.save(path)
            m2 = torch.jit.load(path)

        x = torch.tensor([1.0, 2.0, 3.0, 4.0])
        self.assertTrue(torch.equal(m(x), m2(x)))

    def test_save_nonexistent_file(self):
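        """
        Saving to a path whose parent directory does not exist should raise a
        RuntimeError.
        """
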
        class Foo(torch.nn.Module):
            def forward(self, x):
                return 2 * x

        script_module = torch.jit.script(Foo())
        with self.assertRaises(RuntimeError):
            script_module.save("NonExist/path/test.pt")

    def test_save_namedtuple_input_only(self):
        """
        Even if a NamedTuple is only used as an input argument, saving and
        loading should work correctly.
        """
        global FooTuple  # see [local resolution in python]

        class FooTuple(NamedTuple):
            a: int

        class MyModule(torch.nn.Module):
            def forward(self, x: FooTuple) -> torch.Tensor:
                return torch.tensor(3)

        m_loaded = self.getExportImportCopy(torch.jit.script(MyModule()))
        output = m_loaded(FooTuple(a=5))
        self.assertEqual(output, torch.tensor(3))

    def test_save_namedtuple_input_only_forwardref(self):
        """
        Even if a NamedTuple is only used as an input argument, saving and
        loading should work correctly.
        """
        global FooTuple  # see [local resolution in python]

        class FooTuple(NamedTuple):
            a: "int"

        class MyModule(torch.nn.Module):
            def forward(self, x: FooTuple) -> torch.Tensor:
                return torch.tensor(3)

        m_loaded = self.getExportImportCopy(torch.jit.script(MyModule()))
        output = m_loaded(FooTuple(a=5))
        self.assertEqual(output, torch.tensor(3))

    def test_save_namedtuple_output_only(self):
        """
        Even if a NamedTuple is only used as a return value, saving and
        loading should work correctly.
        """
        global FooTuple  # see [local resolution in python]

        class FooTuple(NamedTuple):
            a: int

        class MyModule(torch.nn.Module):
            def forward(self) -> Optional[FooTuple]:
                return None

        m_loaded = self.getExportImportCopy(torch.jit.script(MyModule()))
        output = m_loaded()
        self.assertEqual(output, None)

    def test_save_load_params_buffers_submodules(self):
        """
        Check that parameters, buffers, and submodules are the same after loading.
        """

        class Submodule(torch.nn.Module):
            pass

        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.add_module("submodule_a", Submodule())
                self.register_parameter(
                    "parameter_a", torch.nn.Parameter(torch.randn(4))
                )
                self.buffer = torch.nn.Buffer(torch.randn(4))
                self.t = torch.rand(4)  # not buffer

                self.parameter_b = torch.nn.Parameter(torch.randn(4))
                self.submodule_b = Submodule()
                self.buffer_b = torch.nn.Buffer(torch.randn(4))

        m = TestModule()
        m_loaded = self.getExportImportCopy(torch.jit.script(m))

        # Check submodules.
        self.assertEqual(
            len(list(m.named_modules())), len(list(m_loaded.named_modules()))
        )
        for m_s, loaded_s in zip(m.named_modules(), m_loaded.named_modules()):
            m_name, _ = m_s
            loaded_name, _ = loaded_s
            self.assertEqual(m_name, loaded_name)

        # Check parameters.
        self.assertEqual(len(list(m.parameters())), len(list(m_loaded.parameters())))
        for m_p, loaded_p in zip(m.parameters(), m_loaded.parameters()):
            self.assertEqual(m_p, loaded_p)

        # Check buffers.
        self.assertEqual(
            len(list(m.named_buffers())), len(list(m_loaded.named_buffers()))
        )
        for m_b, loaded_b in zip(m.named_buffers(), m_loaded.named_buffers()):
            m_name, m_buffer = m_b
            loaded_name, loaded_buffer = loaded_b
            self.assertEqual(m_name, loaded_name)
            self.assertEqual(m_buffer, loaded_buffer)

    def test_save_load_meta_tensors(self):
        """
        Check that parameters, buffers, and submodules are the same after loading
        for a module with parameters and buffers that are meta tensors
        """

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Linear(2, 3, device="meta")
                self.bar = torch.nn.Linear(3, 4)
                self.buffer = torch.nn.Buffer(torch.randn(4, device="meta"))

            def forward(self, x):
                x = self.foo(x)
                x = self.bar(x)
                return x

        m = Foo()
        m_loaded = self.getExportImportCopy(torch.jit.script(m))
        # Check submodules.
        self.assertEqual(
            len(list(m.named_modules())), len(list(m_loaded.named_modules()))
        )
        self.assertEqual(
            {name for name, _ in m.named_modules()},
            {name for name, _ in m_loaded.named_modules()},
        )
        # Check parameters.
        m_params = dict(m.named_parameters())
        m_loaded_params = dict(m_loaded.named_parameters())
        self.assertEqual(len(m_params), len(m_loaded_params))
        self.assertEqual(m_params, m_loaded_params)
        # Check buffers.
        m_buffers = dict(m.named_buffers())
        m_loaded_buffers = dict(m_loaded.named_buffers())
        self.assertEqual(len(m_buffers), len(m_loaded_buffers))
        self.assertEqual(m_buffers, m_loaded_buffers)
        # Check params and buffers that are/are not meta tensors
        self.assertTrue(m_params["foo.weight"].is_meta)
        self.assertTrue(m_loaded_params["foo.weight"].is_meta)
        self.assertTrue(m_params["foo.bias"].is_meta)
        self.assertTrue(m_loaded_params["foo.bias"].is_meta)
        self.assertFalse(m_params["bar.weight"].is_meta)
        self.assertFalse(m_loaded_params["bar.weight"].is_meta)
        self.assertFalse(m_params["bar.bias"].is_meta)
        self.assertFalse(m_loaded_params["bar.bias"].is_meta)
        self.assertTrue(m_buffers["buffer"].is_meta)
        self.assertTrue(m_loaded_buffers["buffer"].is_meta)

    def test_save_load_meta_tensors_to_device(self):
        """
        Check that when loading a module with meta tensors to device, the meta tensors
        stay on meta, but non-meta tensors are set to the indicated device.
        """

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Linear(2, 3, device="meta")
                self.bar = torch.nn.Linear(3, 4)

            def forward(self, x):
                x = self.foo(x)
                x = self.bar(x)
                return x

        m = Foo()

        m_loaded = self.getExportImportCopy(torch.jit.script(m), map_location="cpu")
        # Check submodules.
        self.assertEqual(
            len(list(m.named_modules())), len(list(m_loaded.named_modules()))
        )
        self.assertEqual(
            {name for name, _ in m.named_modules()},
            {name for name, _ in m_loaded.named_modules()},
        )
        # Check parameters.
        m_params = dict(m.named_parameters())
        m_loaded_params = dict(m_loaded.named_parameters())
        self.assertEqual(len(m_params), len(m_loaded_params))
        self.assertEqual(m_params, m_loaded_params)
        # Check params and buffers that are/are not meta tensors
        self.assertTrue(m_params["foo.weight"].is_meta)
        self.assertTrue(m_loaded_params["foo.weight"].is_meta)
        self.assertTrue(m_params["foo.bias"].is_meta)
        self.assertTrue(m_loaded_params["foo.bias"].is_meta)
        self.assertTrue(m_params["bar.weight"].is_cpu)
        self.assertTrue(m_loaded_params["bar.weight"].is_cpu)
        self.assertTrue(m_params["bar.bias"].is_cpu)
        self.assertTrue(m_loaded_params["bar.bias"].is_cpu)

    def test_save_load_with_saved_traced_inputs(self):
        """
        Check that saving and loading with traced inputs works as expected
        """

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return torch.ones(1)

        def get_loaded_inputs(inputs):
            traced_module = torch.jit.trace(module, inputs)
            traced_inputs = list(traced_module.graph.inputs())
            with TemporaryFileName() as fname:
                path = Path(fname)
                traced_module.save(path)
                print(traced_module.graph)
                loaded_module = torch.jit.load(path, _restore_shapes=True)
                print(loaded_module.graph)
                return traced_inputs, list(loaded_module.graph.inputs())

        module = Module()
        input_tensor = torch.rand(1, 3, 24, 24)
        # Validate that, when _store_inputs is not specified, traced inputs are stored by default
        traced_module = torch.jit.trace(module, input_tensor)
        traced_inputs = list(traced_module.graph.inputs())
        self.assertEqual(
            traced_module._c._retrieve_traced_inputs()["forward"], [input_tensor]
        )
        with TemporaryFileName() as fname:
            path = Path(fname)
            traced_module.save(path)
            loaded_module = torch.jit.load(path, _restore_shapes=True)
            loaded_inputs = list(loaded_module.graph.inputs())
            self.assertEqual(traced_inputs[1].type(), loaded_inputs[1].type())
            self.assertEqual(
                traced_inputs[1].type().sizes(), loaded_inputs[1].type().sizes()
            )
            # Validate that if no shapes are requested previous functionality remains
            loaded_module = torch.jit.load(path)
            loaded_inputs = list(loaded_module.graph.inputs())
            self.assertEqual(loaded_inputs[1].type().sizes(), None)

        # Validate that inputs aren't saved when requested not to
        traced_module = torch.jit.trace(module, input_tensor, _store_inputs=False)
        traced_inputs = list(traced_module.graph.inputs())
        self.assertEqual(len(traced_module._c._retrieve_traced_inputs()), 0)

        with TemporaryFileName() as fname:
            path = Path(fname)
            traced_module.save(path)
            loaded_module = torch.jit.load(path, _restore_shapes=True)
            loaded_inputs = list(loaded_module.graph.inputs())
            self.assertEqual(loaded_inputs[1].type().sizes(), None)
            # Validate that if no shapes are requested previous functionality remains
            loaded_module = torch.jit.load(path)
            loaded_inputs = list(loaded_module.graph.inputs())
            self.assertEqual(loaded_inputs[1].type().sizes(), None)

        # Validate that complex inputs work
        # Testing dict of list with empty tensors
        input1 = {
            "1000": (
                torch.tensor([0]),
                torch.tensor([], dtype=torch.int64),
                torch.tensor([]),
            )
        }
        traced_inputs, loaded_inputs = get_loaded_inputs(input1)
        self.assertEqual(traced_inputs[1].type(), loaded_inputs[1].type())

        # Testing dict of list
        input2 = {
            "1000": (
                torch.tensor([0]),
                torch.tensor([1500000, 1500004], dtype=torch.int64),
                torch.tensor([2.0, 3.0]),
            )
        }
        traced_inputs, loaded_inputs = get_loaded_inputs(input2)
        self.assertEqual(traced_inputs[1].type(), loaded_inputs[1].type())

        # Testing list
        input3 = [
            torch.tensor([0]),
            torch.tensor([1500000, 1500004], dtype=torch.int64),
            torch.tensor([2.0, 3.0]),
        ]

        traced_inputs, loaded_inputs = get_loaded_inputs(input3)
        self.assertEqual(traced_inputs[1].type(), loaded_inputs[1].type())

        # Testing list of dict of list
        input4 = [
            {
                "1000": (
                    torch.tensor([0]),
                    torch.tensor([1500000, 1500004], dtype=torch.int64),
                    torch.tensor([2.0, 3.0]),
                )
            }
        ]

        traced_inputs, loaded_inputs = get_loaded_inputs(input4)
        self.assertEqual(traced_inputs[1].type(), loaded_inputs[1].type())

    @skipIfTorchDynamo("too slow")
    def test_save_load_large_string_attribute(self):
        """
        Check that a model with a string attribute larger than 4 GiB can be saved and loaded.
        """
        import psutil

        if psutil.virtual_memory().available < 60 * 1024 * 1024 * 1024:
            # Profiling the test showed ~60 GB of available memory is needed to run it safely
            self.skipTest(
                "Doesn't have enough memory to run test_save_load_large_string_attribute"
            )

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = "x" * (2**32 + 1)

            def forward(self, i) -> int:
                return len(self.x) + i.numel()

        inp = torch.ones(0)
        ts = torch.jit.script(Model())
        ts_output = ts(inp)

        b = io.BytesIO(ts.save_to_buffer())
        del ts

        loaded_ts = torch.jit.load(b)
        del b
        loaded_output = loaded_ts(inp)
        self.assertEqual(ts_output, loaded_output)


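# Helper for the flatbuffer tests below: serialize a ScriptModule to an
# in-memory buffer using the lite-interpreter/flatbuffer format so it can be
# read back with torch.jit.load.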
def script_module_to_buffer(script_module):
    module_buffer = io.BytesIO(
        script_module._save_to_buffer_for_lite_interpreter(_use_flatbuffer=True)
    )
    module_buffer.seek(0)
    return module_buffer


class TestSaveLoadFlatbuffer(JitTestCase):
    def test_different_modules(self):
        """
        Exercise the situation where we have the same qualified name
        in two different CompilationUnits on save/load.
        """

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Linear(2, 2)
                self.bar = torch.nn.Linear(2, 2)

            def forward(self, x):
                x = self.foo(x)
                x = self.bar(x)
                return x

        first_script_module = torch.jit.script(Foo())
        first_saved_module = script_module_to_buffer(first_script_module)

        clear_class_registry()

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Linear(2, 2)

            def forward(self, x):
                x = self.foo(x)
                return x

        second_script_module = torch.jit.script(Foo())
        second_saved_module = script_module_to_buffer(second_script_module)

        clear_class_registry()

        self.assertEqual(
            first_script_module._c.qualified_name,
            second_script_module._c.qualified_name,
        )

        class ContainsBoth(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.add_module("second", torch.jit.load(second_saved_module))
                self.add_module("first", torch.jit.load(first_saved_module))

            def forward(self, x):
                x = self.first(x)
                x = self.second(x)
                return x

        sm = torch.jit.script(ContainsBoth())
        contains_both = script_module_to_buffer(sm)
        sm = torch.jit.load(contains_both)

    def test_different_functions(self):
        """
        Exercise the situation where we have the same qualified name
        in two different CompilationUnits on save/load.
        """

        def lol(x):
            return x

        class Foo(torch.nn.Module):
            def forward(self, x):
                return lol(x)

        first_script_module = torch.jit.script(Foo())
        first_saved_module = script_module_to_buffer(first_script_module)
        clear_class_registry()

        def lol(x):  # noqa: F811
            return "hello"

        class Foo(torch.nn.Module):
            def forward(self, x):
                return lol(x)

        second_script_module = torch.jit.script(Foo())
        second_saved_module = script_module_to_buffer(second_script_module)

        clear_class_registry()

        self.assertEqual(
            first_script_module._c.qualified_name,
            second_script_module._c.qualified_name,
        )

        class ContainsBoth(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.add_module("second", torch.jit.load(second_saved_module))
                self.add_module("first", torch.jit.load(first_saved_module))

            def forward(self, x):
                x = self.first(x)
                x = self.second(x)
                return x

        sm = torch.jit.script(ContainsBoth())
        contains_both = script_module_to_buffer(sm)
        sm = torch.jit.load(contains_both)

    def test_different_interfaces(self):
        """
        Exercise the situation where we have the same qualified name
        in two different CompilationUnits on save/load.
        """

        @torch.jit.interface
        class MyInterface:
            def bar(self, x: Tensor) -> Tensor:
                pass

        @torch.jit.script
        class ImplementInterface:
            def __init__(self) -> None:
                pass

            def bar(self, x):
                return x

        class Foo(torch.nn.Module):
            __annotations__ = {"interface": MyInterface}

            def __init__(self) -> None:
                super().__init__()
                self.interface = ImplementInterface()

            def forward(self, x):
                return self.interface.bar(x)

        first_script_module = torch.jit.script(Foo())
        first_saved_module = script_module_to_buffer(first_script_module)
        clear_class_registry()

        @torch.jit.interface
        class MyInterface:
            def not_bar(self, x: Tensor) -> Tensor:
                pass

        @torch.jit.script  # noqa: F811
        class ImplementInterface:  # noqa: F811
            def __init__(self) -> None:
                pass

            def not_bar(self, x):
                return x

        class Foo(torch.nn.Module):
            __annotations__ = {"interface": MyInterface}

            def __init__(self) -> None:
                super().__init__()
                self.interface = ImplementInterface()

            def forward(self, x):
                return self.interface.not_bar(x)

        second_script_module = torch.jit.script(Foo())
        second_saved_module = script_module_to_buffer(second_script_module)

        clear_class_registry()

        self.assertEqual(
            first_script_module._c.qualified_name,
            second_script_module._c.qualified_name,
        )

        class ContainsBoth(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.add_module("second", torch.jit.load(second_saved_module))
                self.add_module("first", torch.jit.load(first_saved_module))

            def forward(self, x):
                x = self.first(x)
                x = self.second(x)
                return x

        sm = torch.jit.script(ContainsBoth())
        contains_both = script_module_to_buffer(sm)
        sm = torch.jit.load(contains_both)

    def test_many_collisions(self):
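        """
        Flatbuffer variant of test_many_collisions: exercise qualified-name
        collisions across modules, interfaces, script classes, free functions,
        and NamedTuples in two different CompilationUnits.
        """
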
        class MyCoolNamedTuple(NamedTuple):
            a: int

        @torch.jit.interface
        class MyInterface:
            def bar(self, x: Tensor) -> Tensor:
                pass

        @torch.jit.script
        class ImplementInterface:
            def __init__(self) -> None:
                pass

            def bar(self, x):
                return x

        def lol(x):
            return x

        class Foo(torch.nn.Module):
            interface: MyInterface

            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Linear(2, 2)
                self.bar = torch.nn.Linear(2, 2)
                self.interface = ImplementInterface()

            def forward(self, x):
                x = self.foo(x)
                x = self.bar(x)
                x = lol(x)
                x = self.interface.bar(x)

                return x, MyCoolNamedTuple(a=5)

        first_script_module = torch.jit.script(Foo())
        first_saved_module = script_module_to_buffer(first_script_module)

        clear_class_registry()

        @torch.jit.interface
        class MyInterface:
            def not_bar(self, x: Tensor) -> Tensor:
                pass

        @torch.jit.script  # noqa: F811
        class ImplementInterface:  # noqa: F811
            def __init__(self) -> None:
                pass

            def not_bar(self, x):
                return x

        def lol(x):  # noqa: F811
            return "asdofij"

        class MyCoolNamedTuple(NamedTuple):  # noqa: F811
            a: str

        class Foo(torch.nn.Module):
            interface: MyInterface

            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Linear(2, 2)
                self.interface = ImplementInterface()

            def forward(self, x):
                x = self.foo(x)
                self.interface.not_bar(x)
                x = lol(x)
                return x, MyCoolNamedTuple(a="hello")

        second_script_module = torch.jit.script(Foo())
        second_saved_module = script_module_to_buffer(second_script_module)

        clear_class_registry()

        self.assertEqual(
            first_script_module._c.qualified_name,
            second_script_module._c.qualified_name,
        )

        class ContainsBoth(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.add_module("second", torch.jit.load(second_saved_module))
                self.add_module("first", torch.jit.load(first_saved_module))

            def forward(self, x):
                x, named_tuple_1 = self.first(x)
                x, named_tuple_2 = self.second(x)
                return len(x + named_tuple_2.a) + named_tuple_1.a

        sm = torch.jit.script(ContainsBoth())
        contains_both = script_module_to_buffer(sm)
        sm = torch.jit.load(contains_both)

    def test_save_load_using_pathlib(self):
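        """
        Flatbuffer saving and loading should accept pathlib.Path objects in
        addition to string filenames.
        """
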
        class MyMod(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, a):
                return 2 * a

        m = MyMod()

        # Save then load.
        with TemporaryFileName() as fname:
            path = Path(fname)
            torch.jit.save_jit_module_to_flatbuffer(m, path)
            m2 = torch.jit.load(path)

        x = torch.tensor([1.0, 2.0, 3.0, 4.0])
        self.assertTrue(torch.equal(m(x), m2(x)))

    def test_save_namedtuple_input_only(self):
        """
        Even if a NamedTuple is only used as an input argument, saving and
        loading should work correctly.
        """
        global FooTuple  # see [local resolution in python]

        class FooTuple(NamedTuple):
            a: int

        class MyModule(torch.nn.Module):
            def forward(self, x: FooTuple) -> torch.Tensor:
                return torch.tensor(3)

        m_loaded = self.getExportImportCopy(torch.jit.script(MyModule()))
        output = m_loaded(FooTuple(a=5))
        self.assertEqual(output, torch.tensor(3))

    def test_save_namedtuple_output_only(self):
        """
        Even if a NamedTuple is only used as a return value, saving and
        loading should work correctly.
        """
        global FooTuple  # see [local resolution in python]

        class FooTuple(NamedTuple):
            a: int

        class MyModule(torch.nn.Module):
            def forward(self) -> Optional[FooTuple]:
                return None

        m_loaded = self.getExportImportCopy(torch.jit.script(MyModule()))
        output = m_loaded()
        self.assertEqual(output, None)

    def test_module_info_flatbuffer(self):
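        """
        Check that get_flatbuffer_module_info reports the expected metadata
        (bytecode/operator versions, type names, operator arities, and
        function names) for a flatbuffer-serialized module.
        """
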
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Linear(2, 2)
                self.bar = torch.nn.Linear(2, 2)

            def forward(self, x):
                x = self.foo(x)
                x = self.bar(x)
                return x

        first_script_module = torch.jit.script(Foo())
        first_saved_module = io.BytesIO()
        torch.jit.save_jit_module_to_flatbuffer(first_script_module, first_saved_module)
        first_saved_module.seek(0)
        ff_info = torch.jit._serialization.get_flatbuffer_module_info(
            first_saved_module
        )
        self.assertEqual(ff_info["bytecode_version"], 9)
        self.assertEqual(ff_info["operator_version"], 1)
        self.assertEqual(ff_info["type_names"], set())
        self.assertEqual(ff_info["opname_to_num_args"], {"aten::linear": 3})

        self.assertEqual(len(ff_info["function_names"]), 1)
        self.assertTrue(next(iter(ff_info["function_names"])).endswith("forward"))

    def test_save_load_params_buffers_submodules(self):
        """
        Check that parameters, buffers, and submodules are the same after loading.
        """

        class Submodule(torch.nn.Module):
            pass

        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.add_module("submodule_a", Submodule())
                self.register_parameter(
                    "parameter_a", torch.nn.Parameter(torch.randn(4))
                )
                self.buffer = torch.nn.Buffer(torch.randn(4))
                self.t = torch.rand(4)  # not buffer

                self.parameter_b = torch.nn.Parameter(torch.randn(4))
                self.submodule_b = Submodule()
                self.buffer_b = torch.nn.Buffer(torch.randn(4))

        m = TestModule()
        m_loaded = self.getExportImportCopy(torch.jit.script(m))

        # Check submodules.
        self.assertEqual(
            len(list(m.named_modules())), len(list(m_loaded.named_modules()))
        )
        for m_s, loaded_s in zip(m.named_modules(), m_loaded.named_modules()):
            m_name, _ = m_s
            loaded_name, _ = loaded_s
            self.assertEqual(m_name, loaded_name)

        # Check parameters.
        self.assertEqual(len(list(m.parameters())), len(list(m_loaded.parameters())))
        for m_p, loaded_p in zip(m.parameters(), m_loaded.parameters()):
            self.assertEqual(m_p, loaded_p)

        # Check buffers.
        self.assertEqual(
            len(list(m.named_buffers())), len(list(m_loaded.named_buffers()))
        )
        for m_b, loaded_b in zip(m.named_buffers(), m_loaded.named_buffers()):
            m_name, m_buffer = m_b
            loaded_name, loaded_buffer = loaded_b
            self.assertEqual(m_name, loaded_name)
            self.assertEqual(m_buffer, loaded_buffer)

    def test_save_load_with_extra_files(self):
        """
        Check that extra files round-trip through flatbuffer serialization.
        """

        class Module(torch.nn.Module):
            def forward(self, x: Tensor):
                return x

        module = Module()
        script_module = torch.jit.script(module)

        extra_files = {"abc.json": b"[1,2,3]"}
        script_module_io = script_module._save_to_buffer_for_lite_interpreter(
            _extra_files=extra_files, _use_flatbuffer=True
        )

        re_extra_files = {}
        torch._C._get_model_extra_files_from_buffer(script_module_io, re_extra_files)

        self.assertEqual(extra_files, re_extra_files)