# Owner(s): ["oncall: package/deploy"]

from io import BytesIO

from torch.package import PackageExporter, PackageImporter, sys_importer
from torch.testing._internal.common_utils import run_tests


try:
    from .common import PackageTestCase
except ImportError:
    # Support the case where we run this file directly.
    from common import PackageTestCase


class TestRepackage(PackageTestCase):
    """Tests for repackaging.

    Repackaging here means exporting objects with a ``PackageExporter`` whose
    importer chain includes a ``PackageImporter`` for an already-built package.
    """

    def test_repackage_import_indirectly_via_parent_module(self):
        """Export a new package while a previously loaded package is consulted first.

        ``model_a`` is packaged and loaded back through ``pi``. ``model_b`` is
        then exported with ``importer=(pi, sys_importer)``. Judging by the
        fixture names, ``model_b`` reaches a sub-package indirectly through its
        parent module, while ``model_a`` imports the sub-sub-package directly
        (TODO confirm against the ``package_d`` fixtures).

        There is no explicit assertion: the test passes when the second
        export completes without raising.
        """
        from package_d.imports_directly import ImportsDirectlyFromSubSubPackage
        from package_d.imports_indirectly import ImportsIndirectlyFromSubPackage

        # Build the first package, interning every dependency.
        model_a = ImportsDirectlyFromSubSubPackage()
        buffer = BytesIO()
        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_pickle("default", "model.py", model_a)

        # Load it back. The returned object is not needed. The call is kept
        # because unpickling presumably imports model_a's modules through
        # ``pi``, which sets up the state the repackaging below exercises.
        buffer.seek(0)
        pi = PackageImporter(buffer)
        pi.load_pickle("default", "model.py")

        # Repackage: importers are consulted in order, so ``pi`` is tried
        # before the regular ``sys`` import system.
        model_b = ImportsIndirectlyFromSubPackage()
        buffer = BytesIO()
        with PackageExporter(
            buffer,
            importer=(
                pi,
                sys_importer,
            ),
        ) as pe:
            pe.intern("**")
            pe.save_pickle("default", "model_b.py", model_b)


if __name__ == "__main__":
    run_tests()