# xref: /aosp_15_r20/external/pytorch/test/package/test_dependency_api.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["oncall: package/deploy"]
2
3import importlib
4from io import BytesIO
5from sys import version_info
6from textwrap import dedent
7from unittest import skipIf
8
9import torch.nn
10from torch.package import EmptyMatchError, Importer, PackageExporter, PackageImporter
11from torch.package.package_exporter import PackagingError
12from torch.testing._internal.common_utils import IS_WINDOWS, run_tests
13
14
15try:
16    from .common import PackageTestCase
17except ImportError:
18    # Support the case where we run this file directly.
19    from common import PackageTestCase
20
21
class TestDependencyAPI(PackageTestCase):
    """Dependency management API tests.

    Exercises the ``PackageExporter`` dependency actions:

    - mock()
    - extern()
    - deny()
    - intern() (including the implicit intern performed by ``save_module``)

    as well as the ``PackagingError`` / ``EmptyMatchError`` messages reported
    when dependencies cannot be handled.
    """

    def test_extern(self):
        """Externed modules resolve to the host interpreter's module objects."""
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.extern(["package_a.subpackage", "module_a"])
            he.save_source_string("foo", "import package_a.subpackage; import module_a")
        buffer.seek(0)
        hi = PackageImporter(buffer)
        import module_a
        import package_a.subpackage

        module_a_im = hi.import_module("module_a")
        hi.import_module("package_a.subpackage")
        package_a_im = hi.import_module("package_a")

        self.assertIs(module_a, module_a_im)
        # The parent ``package_a`` was not externed, so it is a different object...
        self.assertIsNot(package_a, package_a_im)
        # ...but the externed ``package_a.subpackage`` is shared with the host.
        self.assertIs(package_a.subpackage, package_a_im.subpackage)

    def test_extern_glob(self):
        """Same as ``test_extern`` but selecting the externed modules with globs."""
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.extern(["package_a.*", "module_*"])
            he.save_module("package_a")
            he.save_source_string(
                "test_module",
                dedent(
                    """\
                    import package_a.subpackage
                    import module_a
                    """
                ),
            )
        buffer.seek(0)
        hi = PackageImporter(buffer)
        import module_a
        import package_a.subpackage

        module_a_im = hi.import_module("module_a")
        hi.import_module("package_a.subpackage")
        package_a_im = hi.import_module("package_a")

        self.assertIs(module_a, module_a_im)
        self.assertIsNot(package_a, package_a_im)
        self.assertIs(package_a.subpackage, package_a_im.subpackage)

    def test_extern_glob_allow_empty(self):
        """
        Test that an error is thrown when an extern glob is specified with allow_empty=False
        and no matching module is required during packaging.
        """
        import package_a.subpackage  # noqa: F401

        buffer = BytesIO()
        with self.assertRaisesRegex(EmptyMatchError, r"did not match any modules"):
            with PackageExporter(buffer) as exporter:
                exporter.extern(include=["package_b.*"], allow_empty=False)
                exporter.save_module("package_a.subpackage")

    def test_deny(self):
        """
        Test marking packages as "deny" during export.
        """
        buffer = BytesIO()

        with self.assertRaisesRegex(PackagingError, "denied"):
            with PackageExporter(buffer) as exporter:
                exporter.deny(["package_a.subpackage", "module_a"])
                exporter.save_source_string("foo", "import package_a.subpackage")

    def test_deny_glob(self):
        """
        Test marking packages as "deny" using globs instead of package names.
        """
        buffer = BytesIO()
        with self.assertRaises(PackagingError):
            with PackageExporter(buffer) as exporter:
                exporter.deny(["package_a.*", "module_*"])
                exporter.save_source_string(
                    "test_module",
                    dedent(
                        """\
                        import package_a.subpackage
                        import module_a
                        """
                    ),
                )

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_mock(self):
        """Calling into a mocked module raises NotImplementedError."""
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.mock(["package_a.subpackage", "module_a"])
            # Import something that depends on package_a.subpackage
            he.save_source_string("foo", "import package_a.subpackage")
        buffer.seek(0)
        hi = PackageImporter(buffer)
        import package_a.subpackage

        _ = package_a.subpackage
        import module_a

        _ = module_a

        m = hi.import_module("package_a.subpackage")
        r = m.result
        with self.assertRaisesRegex(NotImplementedError, "was mocked out"):
            r()

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_mock_glob(self):
        """Same as ``test_mock`` but selecting the mocked modules with globs."""
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.mock(["package_a.*", "module*"])
            he.save_module("package_a")
            he.save_source_string(
                "test_module",
                dedent(
                    """\
                    import package_a.subpackage
                    import module_a
                    """
                ),
            )
        buffer.seek(0)
        hi = PackageImporter(buffer)
        import package_a.subpackage

        _ = package_a.subpackage
        import module_a

        _ = module_a

        m = hi.import_module("package_a.subpackage")
        r = m.result
        with self.assertRaisesRegex(NotImplementedError, "was mocked out"):
            r()

    def test_mock_glob_allow_empty(self):
        """
        Test that an error is thrown when a mock glob is specified with allow_empty=False
        and no matching module is required during packaging.
        """
        import package_a.subpackage  # noqa: F401

        buffer = BytesIO()
        with self.assertRaisesRegex(EmptyMatchError, r"did not match any modules"):
            with PackageExporter(buffer) as exporter:
                exporter.mock(include=["package_b.*"], allow_empty=False)
                exporter.save_module("package_a.subpackage")

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_pickle_mocked(self):
        """Pickling an object whose class lives in a mocked module raises a PackagingError."""
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)

        buffer = BytesIO()
        with self.assertRaises(PackagingError):
            with PackageExporter(buffer) as he:
                he.mock(include="package_a.subpackage")
                he.intern("**")
                he.save_pickle("obj", "obj.pkl", obj2)

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_pickle_mocked_all(self):
        """Pickling succeeds when ``package_a.**`` is interned and everything else is mocked."""
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)

        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.intern(include="package_a.**")
            he.mock("**")
            he.save_pickle("obj", "obj.pkl", obj2)

    def test_allow_empty_with_error(self):
        """If an error occurs during packaging, it should not be shadowed by the allow_empty error."""
        buffer = BytesIO()
        with self.assertRaises(ModuleNotFoundError):
            with PackageExporter(buffer) as pe:
                # Even though we did not extern a module that matches this
                # pattern, we want to show the save_module error, not the allow_empty error.

                pe.extern("foo", allow_empty=False)
                pe.save_module("aodoifjodisfj")  # will error

                # we never get here, so technically the allow_empty check
                # should raise an error. However, the error above is more
                # informative to what's actually going wrong with packaging.
                pe.save_source_string("bar", "import foo\n")

    def test_implicit_intern(self):
        """The save_module APIs should implicitly intern the module being saved."""
        import package_a  # noqa: F401

        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.save_module("package_a")

    def test_intern_error(self):
        """Failure to handle all dependencies should lead to an error."""
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)

        buffer = BytesIO()

        with self.assertRaises(PackagingError) as e:
            with PackageExporter(buffer) as he:
                he.save_pickle("obj", "obj.pkl", obj2)

        self.assertEqual(
            str(e.exception),
            dedent(
                """
                * Module did not match against any action pattern. Extern, mock, or intern it.
                    package_a
                    package_a.subpackage

                Set debug=True when invoking PackageExporter for a visualization of where broken modules are coming from!
                """
            ),
        )

        # Interning all dependencies should work. Use a fresh buffer: the failed
        # export above may already have written data into ``buffer``.
        with PackageExporter(BytesIO()) as he:
            he.intern(["package_a", "package_a.subpackage"])
            he.save_pickle("obj", "obj.pkl", obj2)

    @skipIf(IS_WINDOWS, "extension modules have a different file extension on windows")
    def test_broken_dependency(self):
        """An unpackageable dependency should raise a PackagingError."""
        # ``import importlib`` alone does not guarantee that the ``machinery`` and
        # ``util`` submodules are loaded, so import them explicitly.
        import importlib.machinery
        import importlib.util

        def create_module(name):
            # Build a fake module whose __file__ ends in ".so" so that the
            # exporter reports it as a C extension module (see expected message).
            spec = importlib.machinery.ModuleSpec(name, self, is_package=False)  # type: ignore[arg-type]
            module = importlib.util.module_from_spec(spec)
            ns = module.__dict__
            ns["__spec__"] = spec
            ns["__loader__"] = self
            ns["__file__"] = f"{name}.so"
            ns["__cached__"] = None
            return module

        class BrokenImporter(Importer):
            """Importer that only knows about the fake ``foo`` and ``bar`` modules."""

            def __init__(self) -> None:
                self.modules = {
                    "foo": create_module("foo"),
                    "bar": create_module("bar"),
                }

            def import_module(self, module_name):
                return self.modules[module_name]

        buffer = BytesIO()

        with self.assertRaises(PackagingError) as e:
            with PackageExporter(buffer, importer=BrokenImporter()) as exporter:
                exporter.intern(["foo", "bar"])
                exporter.save_source_string("my_module", "import foo; import bar")

        self.assertEqual(
            str(e.exception),
            dedent(
                """
                * Module is a C extension module. torch.package supports Python modules only.
                    foo
                    bar

                Set debug=True when invoking PackageExporter for a visualization of where broken modules are coming from!
                """
            ),
        )

    def test_invalid_import(self):
        """An incorrectly-formed import should raise a PackagingError."""
        buffer = BytesIO()
        with self.assertRaises(PackagingError) as e:
            with PackageExporter(buffer) as exporter:
                # This import will fail to load.
                exporter.save_source_string("foo", "from ........ import lol")

        self.assertEqual(
            str(e.exception),
            dedent(
                """
                * Dependency resolution failed.
                    foo
                      Context: attempted relative import beyond top-level package

                Set debug=True when invoking PackageExporter for a visualization of where broken modules are coming from!
                """
            ),
        )

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_repackage_mocked_module(self):
        """Re-packaging a package that contains a mocked module should work correctly."""
        buffer = BytesIO()
        with PackageExporter(buffer) as exporter:
            exporter.mock("package_a")
            exporter.save_source_string("foo", "import package_a")

        buffer.seek(0)
        importer = PackageImporter(buffer)
        foo = importer.import_module("foo")

        # "package_a" should be mocked out.
        with self.assertRaises(NotImplementedError):
            foo.package_a.get_something()

        # Re-package the model, but intern the previously-mocked module and mock
        # everything else.
        buffer2 = BytesIO()
        with PackageExporter(buffer2, importer=importer) as exporter:
            exporter.intern("package_a")
            exporter.mock("**")
            exporter.save_source_string("foo", "import package_a")

        buffer2.seek(0)
        importer2 = PackageImporter(buffer2)
        foo2 = importer2.import_module("foo")

        # "package_a" should still be mocked out.
        with self.assertRaises(NotImplementedError):
            foo2.package_a.get_something()

    def test_externing_c_extension(self):
        """Externing c extensions modules should allow us to still access them especially those found in torch._C."""

        buffer = BytesIO()
        # The C extension module in question is F.gelu which comes from torch._C._nn
        model = torch.nn.TransformerEncoderLayer(
            d_model=64,
            nhead=2,
            dim_feedforward=64,
            dropout=1.0,
            batch_first=True,
            activation="gelu",
            norm_first=True,
        )
        with PackageExporter(buffer) as exporter:
            exporter.extern("torch.**")
            exporter.intern("**")

            exporter.save_pickle("model", "model.pkl", model)
        buffer.seek(0)
        imp = PackageImporter(buffer)
        imp.load_pickle("model", "model.pkl")
380
381
if __name__ == "__main__":
    # Executed only when this file is run directly (not when it is imported by
    # a test collector); hand control to PyTorch's shared test runner.
    run_tests()
384