xref: /aosp_15_r20/external/pytorch/test/package/test_save_load.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: package/deploy"]
2
3import pickle
4from io import BytesIO
5from textwrap import dedent
6
7from torch.package import PackageExporter, PackageImporter, sys_importer
8from torch.testing._internal.common_utils import run_tests
9
10
11try:
12    from .common import PackageTestCase
13except ImportError:
14    # Support the case where we run this file directly.
15    from common import PackageTestCase
16
17from pathlib import Path
18
19
# Directory that contains this test module. Fixture sources that the tests save
# by file path (e.g. "module_a.py" and the "package_a" directory) sit next to it.
packaging_directory: Path = Path(__file__).parent
21
22
class TestSaveLoad(PackageTestCase):
    """Core save_* and loading API tests."""

    def test_saving_source(self):
        """A single source file and a package directory can be saved under new
        module names and imported back from the package."""
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            # The saved names ("foo", "foodir") deliberately differ from the
            # on-disk names of the fixtures they are read from.
            he.save_source_file("foo", str(packaging_directory / "module_a.py"))
            he.save_source_file("foodir", str(packaging_directory / "package_a"))
        buffer.seek(0)
        hi = PackageImporter(buffer)
        foo = hi.import_module("foo")
        s = hi.import_module("foodir.subpackage")
        self.assertEqual(foo.result, "module_a")
        self.assertEqual(s.result, "package_a.subpackage")

    def test_saving_string(self):
        """A module saved from a source string is importable, and its ``import
        math`` resolves to the very same ``math`` module as the host's."""
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            src = dedent(
                """\
                import math
                the_math = math
                """
            )
            he.save_source_string("my_mod", src)
        buffer.seek(0)
        hi = PackageImporter(buffer)
        m = hi.import_module("math")
        import math

        self.assertIs(m, math)
        my_mod = hi.import_module("my_mod")
        self.assertIs(my_mod.math, math)

    def _check_save_module_roundtrip(self, buffer):
        """Save ``module_a`` and ``package_a`` into ``buffer`` with save_module,
        reload them through a PackageImporter, and check that the loaded copies
        have the expected contents but are distinct module objects from the
        host's.

        Args:
            buffer: an empty, writable and seekable binary buffer.
        """
        with PackageExporter(buffer) as he:
            import module_a
            import package_a

            he.save_module(module_a.__name__)
            he.save_module(package_a.__name__)
        buffer.seek(0)
        hi = PackageImporter(buffer)
        module_a_i = hi.import_module("module_a")
        self.assertEqual(module_a_i.result, "module_a")
        self.assertIsNot(module_a, module_a_i)
        package_a_i = hi.import_module("package_a")
        self.assertEqual(package_a_i.result, "package_a")
        self.assertIsNot(package_a_i, package_a)

    def test_save_module(self):
        """save_module packages a plain module and a package by name."""
        self._check_save_module_roundtrip(BytesIO())

    def test_dunder_imports(self):
        """Pickling an object from ``package_b`` with everything interned brings
        along the modules it depends on. Per the test name, these are presumably
        reached via ``__import__`` calls in the fixture; all of them must be
        importable from the package afterwards."""
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            import package_b

            obj = package_b.PackageBObject
            he.intern("**")
            he.save_pickle("res", "obj.pkl", obj)

        buffer.seek(0)
        hi = PackageImporter(buffer)
        loaded_obj = hi.load_pickle("res", "obj.pkl")

        # Rebinds the local name to the importer's copy of package_b.
        package_b = hi.import_module("package_b")
        self.assertEqual(package_b.result, "package_b")

        # The unpickled object must come from the packaged (interned) copy of
        # package_b, not from the host's module.
        self.assertIsNot(loaded_obj, obj)
        self.assertIs(loaded_obj, package_b.PackageBObject)

        math = hi.import_module("math")
        self.assertEqual(math.__name__, "math")

        xml_sub_sub_package = hi.import_module("xml.sax.xmlreader")
        self.assertEqual(xml_sub_sub_package.__name__, "xml.sax.xmlreader")

        subpackage_1 = hi.import_module("package_b.subpackage_1")
        self.assertEqual(subpackage_1.result, "subpackage_1")

        subpackage_2 = hi.import_module("package_b.subpackage_2")
        self.assertEqual(subpackage_2.result, "subpackage_2")

        subsubpackage_0 = hi.import_module("package_b.subpackage_0.subsubpackage_0")
        self.assertEqual(subsubpackage_0.result, "subsubpackage_0")

    def test_bad_dunder_imports(self):
        """Test to ensure bad __imports__ don't cause PackageExporter to fail."""
        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            # Every argument is an undefined name. The module "m" is only saved
            # here and never imported or executed.
            e.save_source_string(
                "m", '__import__(these, unresolvable, "things", wont, crash, me)'
            )

    def test_save_module_binary(self):
        """Same save_module round trip as test_save_module.

        NOTE(review): this is now identical to test_save_module, since both write
        to an in-memory BytesIO. It is kept as a separate test so the test name
        remains.
        """
        self._check_save_module_roundtrip(BytesIO())

    def test_pickle(self):
        """save_pickle packages exactly the dependencies of the pickled object,
        and load_pickle rebuilds it from the packaged types."""
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)

        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.intern("**")
            he.save_pickle("obj", "obj.pkl", obj2)
        buffer.seek(0)
        hi = PackageImporter(buffer)

        # check we got dependencies
        sp = hi.import_module("package_a.subpackage")
        # check we didn't get other stuff
        with self.assertRaises(ImportError):
            hi.import_module("module_a")

        obj_loaded = hi.load_pickle("obj", "obj.pkl")
        self.assertIsNot(obj2, obj_loaded)
        self.assertIsInstance(obj_loaded.obj, sp.PackageASubpackageObject)
        self.assertIsNot(
            package_a.subpackage.PackageASubpackageObject, sp.PackageASubpackageObject
        )

    def test_pickle_long_name_with_protocol_4(self):
        """An object referring to a function with a very long name survives a
        save_pickle / load_pickle round trip with pickle protocol 4."""
        import package_a.long_name

        container = []

        # Indirectly grab the function to avoid pasting a 256 character
        # function into the test
        package_a.long_name.add_function(container)

        buffer = BytesIO()
        with PackageExporter(buffer) as exporter:
            exporter.intern("**")
            exporter.save_pickle(
                "container", "container.pkl", container, pickle_protocol=4
            )

        buffer.seek(0)
        importer = PackageImporter(buffer)
        unpickled_container = importer.load_pickle("container", "container.pkl")
        self.assertIsNot(container, unpickled_container)
        self.assertEqual(len(unpickled_container), 1)
        self.assertEqual(container[0](), unpickled_container[0]())

    def test_exporting_mismatched_code(self):
        """
        If an object with the same qualified name is loaded from different
        packages, the user should get an error if they try to re-save the
        object with the wrong package's source code.
        """
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)

        b1 = BytesIO()
        with PackageExporter(b1) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", obj2)

        # Load the same archive through two independent importers, so that
        # loaded1 and loaded2 come from different importers.
        b1.seek(0)
        importer1 = PackageImporter(b1)
        loaded1 = importer1.load_pickle("obj", "obj.pkl")

        b1.seek(0)
        importer2 = PackageImporter(b1)
        loaded2 = importer2.load_pickle("obj", "obj.pkl")

        def make_exporter():
            # importer1 is listed before sys_importer, so the exporter finds the
            # 'PackageAObject' defined in 'importer1' first.
            # NOTE(review): these exporters are never closed or used as context
            # managers; presumably intentional, because only save_pickle's
            # behavior is under test. Confirm before adding cleanup.
            return PackageExporter(BytesIO(), importer=[importer1, sys_importer])

        # This should fail. The 'PackageAObject' type defined from 'importer1'
        # is not necessarily the same 'obj2's version of 'PackageAObject'.
        pe = make_exporter()
        with self.assertRaises(pickle.PicklingError):
            pe.save_pickle("obj", "obj.pkl", obj2)

        # This should also fail. The 'PackageAObject' type defined from 'importer1'
        # is not necessarily the same as the one defined from 'importer2'
        pe = make_exporter()
        with self.assertRaises(pickle.PicklingError):
            pe.save_pickle("obj", "obj.pkl", loaded2)

        # This should succeed. The 'PackageAObject' type defined from
        # 'importer1' is a match for the one used by loaded1.
        pe = make_exporter()
        pe.save_pickle("obj", "obj.pkl", loaded1)

    def test_save_imported_module(self):
        """Saving a module that came from another PackageImporter should work."""
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)

        buffer = BytesIO()
        with PackageExporter(buffer) as exporter:
            exporter.intern("**")
            exporter.save_pickle("model", "model.pkl", obj2)

        buffer.seek(0)

        importer = PackageImporter(buffer)
        imported_obj2 = importer.load_pickle("model", "model.pkl")
        # Name of the module, as seen by the first importer, that defines the
        # loaded object's class.
        imported_obj2_module = imported_obj2.__class__.__module__

        # Should export without error.
        buffer2 = BytesIO()
        with PackageExporter(buffer2, importer=(importer, sys_importer)) as exporter:
            exporter.intern("**")
            exporter.save_module(imported_obj2_module)

    def test_save_imported_module_using_package_importer(self):
        """Exercise a corner case: re-packaging a module that uses `torch_package_importer`"""
        import package_a.use_torch_package_importer  # noqa: F401

        buffer = BytesIO()
        with PackageExporter(buffer) as exporter:
            exporter.intern("**")
            exporter.save_module("package_a.use_torch_package_importer")

        buffer.seek(0)

        importer = PackageImporter(buffer)

        # Should export without error.
        buffer2 = BytesIO()
        with PackageExporter(buffer2, importer=(importer, sys_importer)) as exporter:
            exporter.intern("**")
            exporter.save_module("package_a.use_torch_package_importer")
267
268
if __name__ == "__main__":
    # Only reached when this file is run directly as a script; hands control to
    # the run_tests entry point imported from torch's common_utils.
    run_tests()
271