xref: /aosp_15_r20/external/pytorch/test/package/test_repackage.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: package/deploy"]
2
3from io import BytesIO
4
5from torch.package import PackageExporter, PackageImporter, sys_importer
6from torch.testing._internal.common_utils import run_tests
7
8
9try:
10    from .common import PackageTestCase
11except ImportError:
12    # Support the case where we run this file directly.
13    from common import PackageTestCase
14
15
class TestRepackage(PackageTestCase):
    """Tests for repackaging: exporting objects through an importer that
    is itself backed by a previously created package."""

    def test_repackage_import_indirectly_via_parent_module(self):
        """Re-export a model using ``importer=(pi, sys_importer)``.

        Stage 1 packages an ``ImportsDirectlyFromSubSubPackage`` instance
        (interning every module) and loads it back through a
        ``PackageImporter`` ``pi``. Stage 2 exports an
        ``ImportsIndirectlyFromSubPackage`` instance with an importer sequence
        of ``(pi, sys_importer)``, which ``PackageExporter`` tries in order.
        Judging by the class names, the two models reach the same
        ``package_d`` code via different import paths (directly vs. through
        a parent package) — NOTE(review): confirm against ``package_d``.

        The stage-2 export succeeding without raising is the main check.
        """
        from package_d.imports_directly import ImportsDirectlyFromSubSubPackage
        from package_d.imports_indirectly import ImportsIndirectlyFromSubPackage

        # Stage 1: build a package and an importer backed by it.
        model_a = ImportsDirectlyFromSubSubPackage()
        buffer = BytesIO()
        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_pickle("default", "model.py", model_a)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        loaded_model = pi.load_pickle("default", "model.py")
        # The round trip must reconstruct an instance of the same-named class
        # (resolved from the package's own copy of the defining module).
        self.assertEqual(type(loaded_model).__name__, type(model_a).__name__)

        # Stage 2: re-export a different model, resolving modules through the
        # stage-1 package first and the regular import system second. A fresh
        # buffer is used; `pi` keeps reading from the first one.
        model_b = ImportsIndirectlyFromSubPackage()
        repackaged_buffer = BytesIO()
        with PackageExporter(
            repackaged_buffer,
            importer=(
                pi,
                sys_importer,
            ),
        ) as pe:
            pe.intern("**")
            # Success criterion: dependency resolution and saving must not raise.
            pe.save_pickle("default", "model_b.py", model_b)
44
45
if __name__ == "__main__":
    # Executed as a script (not imported by a test runner): run this
    # module's test cases via the PyTorch test harness.
    run_tests()
48