xref: /aosp_15_r20/external/pytorch/test/package/test_repackage.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: package/deploy"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerfrom io import BytesIO
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Workerfrom torch.package import PackageExporter, PackageImporter, sys_importer
6*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Workertry:
10*da0073e9SAndroid Build Coastguard Worker    from .common import PackageTestCase
11*da0073e9SAndroid Build Coastguard Workerexcept ImportError:
12*da0073e9SAndroid Build Coastguard Worker    # Support the case where we run this file directly.
13*da0073e9SAndroid Build Coastguard Worker    from common import PackageTestCase
14*da0073e9SAndroid Build Coastguard Worker
15*da0073e9SAndroid Build Coastguard Worker
class TestRepackage(PackageTestCase):
    """Tests for re-exporting objects through an importer chain that includes
    a previously loaded ``PackageImporter``."""

    def test_repackage_import_indirectly_via_parent_module(self):
        """Repackage a model whose module reaches a sub-package via its parent.

        Stage 1 pickles an ``ImportsDirectlyFromSubSubPackage`` instance into an
        in-memory package (interning every module) and loads it back through a
        ``PackageImporter``. Stage 2 pickles an ``ImportsIndirectlyFromSubPackage``
        instance into a second in-memory package, resolving modules through the
        chain ``(stage-1 PackageImporter, sys_importer)``. The test passes if no
        step raises; it makes no explicit assertions.
        """
        from package_d.imports_directly import ImportsDirectlyFromSubSubPackage
        from package_d.imports_indirectly import ImportsIndirectlyFromSubPackage

        # Stage 1: export the "direct" model with every dependency interned.
        direct_model = ImportsDirectlyFromSubSubPackage()
        first_archive = BytesIO()
        with PackageExporter(first_archive) as exporter:
            exporter.intern("**")
            exporter.save_pickle("default", "model.py", direct_model)

        # Rewind so the importer reads the archive from its beginning.
        first_archive.seek(0)
        package_importer = PackageImporter(first_archive)
        # The loaded object itself is unused; loading it is what pulls the
        # packaged modules into ``package_importer`` before stage 2 runs.
        _reloaded_direct_model = package_importer.load_pickle("default", "model.py")

        # Stage 2: export the "indirect" model, consulting the stage-1 package
        # importer first and falling back to the regular Python import system.
        indirect_model = ImportsIndirectlyFromSubPackage()
        second_archive = BytesIO()
        importer_chain = (package_importer, sys_importer)
        with PackageExporter(second_archive, importer=importer_chain) as exporter:
            exporter.intern("**")
            exporter.save_pickle("default", "model_b.py", indirect_model)
45*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Allow running this file directly via PyTorch's shared test runner.
    run_tests()