# Owner(s): ["oncall: jit"]

import inspect
import os
import sys
from collections import namedtuple
from textwrap import dedent
from typing import Dict, Iterator, List, Optional, Tuple

import torch
import torch.testing._internal.jit_utils
from jit.test_module_interface import TestModuleInterface  # noqa: F401
from torch.testing import FileCheck
from torch.testing._internal.jit_utils import JitTestCase


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestTypesAndAnnotation(JitTestCase):
    def test_pep585_type(self):
        # TODO: add a test that uses a PEP 585 annotation for the return type
        # once Python >= 3.9 is required.
        # see: https://www.python.org/dev/peps/pep-0585/#id5
        def fn(x: torch.Tensor) -> Tuple[Tuple[torch.Tensor], Dict[str, int]]:
            xl: list[tuple[torch.Tensor]] = []
            xd: dict[str, int] = {}
            xl.append((x,))
            xd["foo"] = 1
            return xl.pop(), xd

        self.checkScript(fn, [torch.randn(2, 2)])

        x = torch.randn(2, 2)
        expected = fn(x)
        scripted = torch.jit.script(fn)(x)

        self.assertEqual(expected, scripted)
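
        # A sketch of what the TODO above would look like (not executed here;
        # PEP 585 builtins in the return position require Python >= 3.9):
        #
        #     def fn(x: torch.Tensor) -> tuple[tuple[torch.Tensor], dict[str, int]]:
        #         ...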

    def test_types_as_values(self):
        def fn(m: torch.Tensor) -> torch.device:
            return m.device

        self.checkScript(fn, [torch.randn(2, 2)])

        GG = namedtuple("GG", ["f", "g"])

        class Foo(torch.nn.Module):
            @torch.jit.ignore
            def foo(self, x: torch.Tensor, z: torch.Tensor) -> Tuple[GG, GG]:
                return GG(x, z), GG(x, z)

            def forward(self, x, z):
                return self.foo(x, z)

        foo = torch.jit.script(Foo())
        y = foo(torch.randn(2, 2), torch.randn(2, 2))

        # `foo` is unannotated here and its body returns a single GG despite
        # the declared Tuple[GG, GG]. Since the method is ignored, it runs as
        # plain Python rather than being compiled, and the mismatch is
        # tolerated.
        class Foo(torch.nn.Module):
            @torch.jit.ignore
            def foo(self, x, z) -> Tuple[GG, GG]:
                return GG(x, z)

            def forward(self, x, z):
                return self.foo(x, z)

        foo = torch.jit.script(Foo())
        y = foo(torch.randn(2, 2), torch.randn(2, 2))

    def test_ignore_with_types(self):
        @torch.jit.ignore
        def fn(x: Dict[str, Optional[torch.Tensor]]):
            return x + 10

        class M(torch.nn.Module):
            def forward(
                self, in_batch: Dict[str, Optional[torch.Tensor]]
            ) -> torch.Tensor:
                self.dropout_modality(in_batch)
                fn(in_batch)
                return torch.tensor(1)

            @torch.jit.ignore
            def dropout_modality(
                self, in_batch: Dict[str, Optional[torch.Tensor]]
            ) -> Dict[str, Optional[torch.Tensor]]:
                return in_batch

        sm = torch.jit.script(M())
        # The ignored calls should still appear in the compiled graph as
        # Python fallbacks.
        FileCheck().check("dropout_modality").check("in_batch").run(str(sm.graph))

    def test_python_callable(self):
        class MyPythonClass:
            @torch.jit.ignore
            def __call__(self, *args) -> str:
                return str(type(args[0]))

        the_class = MyPythonClass()

        @torch.jit.script
        def fn(x):
            return the_class(x)

        # This doesn't involve the string frontend, so don't use checkScript
        x = torch.ones(2)
        self.assertEqual(fn(x), the_class(x))

    def test_bad_types(self):
        @torch.jit.ignore
        def fn(my_arg):
            return my_arg + 10

        with self.assertRaisesRegex(RuntimeError, "argument 'my_arg'"):

            @torch.jit.script
            def other_fn(x):
                # `fn` is ignored and unannotated, so TorchScript assumes
                # Tensor for `my_arg`; passing a str must be rejected.
                return fn("2")
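
        # For contrast, a sketch (not part of the original test): annotating
        # the ignored function lets the compiler accept a str argument.
        @torch.jit.ignore
        def fn_str(my_arg: str) -> str:
            return my_arg + "10"

        @torch.jit.script
        def other_fn_ok(x):
            return fn_str("2")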

    def test_type_annotate_py3(self):
        def fn():
            a: List[int] = []
            b: torch.Tensor = torch.ones(2, 2)
            c: Optional[torch.Tensor] = None
            d: Optional[torch.Tensor] = torch.ones(3, 4)
            for _ in range(10):
                a.append(4)
                c = torch.ones(2, 2)
                d = None
            return a, b, c, d

        self.checkScript(fn, ())

        def wrong_type():
            wrong: List[int] = [0.5]
            return wrong

        with self.assertRaisesRegex(
            RuntimeError,
            "List type annotation"
            r" `List\[int\]` did not match the "
            "types of the given list elements",
        ):
            torch.jit.script(wrong_type)
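
        # For contrast, a sketch (not in the original test): with a matching
        # element type, the same shape of annotation scripts cleanly.
        def right_type():
            right: List[float] = [0.5]
            return right

        torch.jit.script(right_type)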

    def test_optional_no_element_type_annotation(self):
        """
        Test that using an optional with no contained types produces an error.
        """

        # Note: despite its name, `fn_with_comment` now uses the same inline
        # annotation as `annotated_fn`; the name is presumably a holdover from
        # a legacy `# type:` comment variant.
        def fn_with_comment(x: torch.Tensor) -> Optional:
            return (x, x)

        def annotated_fn(x: torch.Tensor) -> Optional:
            return (x, x)

        with self.assertRaisesRegex(
            RuntimeError, r"Attempted to use Optional without a contained type"
        ):
            cu = torch.jit.CompilationUnit()
            cu.define(dedent(inspect.getsource(fn_with_comment)))

        with self.assertRaisesRegex(
            RuntimeError, r"Attempted to use Optional without a contained type"
        ):
            cu = torch.jit.CompilationUnit()
            cu.define(dedent(inspect.getsource(annotated_fn)))

        with self.assertRaisesRegex(
            RuntimeError, r"Attempted to use Optional without a contained type"
        ):
            torch.jit.script(fn_with_comment)

        with self.assertRaisesRegex(
            RuntimeError, r"Attempted to use Optional without a contained type"
        ):
            torch.jit.script(annotated_fn)
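
        # For reference, a sketch of a spelled-out annotation that does
        # script (not part of the original test):
        def annotated_ok(x: torch.Tensor) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
            return (x, x)

        torch.jit.script(annotated_ok)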

    def test_tuple_no_element_type_annotation(self):
        """
        Test that using a tuple with no contained types produces an error.
        """

        # As above, both helpers use the same inline annotation; the
        # `fn_with_comment` name is presumably historical.
        def fn_with_comment(x: torch.Tensor) -> Tuple:
            return (x, x)

        def annotated_fn(x: torch.Tensor) -> Tuple:
            return (x, x)

        with self.assertRaisesRegex(
            RuntimeError, r"Attempted to use Tuple without a contained type"
        ):
            cu = torch.jit.CompilationUnit()
            cu.define(dedent(inspect.getsource(fn_with_comment)))

        with self.assertRaisesRegex(
            RuntimeError, r"Attempted to use Tuple without a contained type"
        ):
            cu = torch.jit.CompilationUnit()
            cu.define(dedent(inspect.getsource(annotated_fn)))

        with self.assertRaisesRegex(
            RuntimeError, r"Attempted to use Tuple without a contained type"
        ):
            torch.jit.script(fn_with_comment)

        with self.assertRaisesRegex(
            RuntimeError, r"Attempted to use Tuple without a contained type"
        ):
            torch.jit.script(annotated_fn)
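
        # For reference, a sketch with the contained types spelled out
        # (not part of the original test):
        def annotated_ok(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            return (x, x)

        torch.jit.script(annotated_ok)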

    def test_ignoring_module_attributes(self):
        """
        Test that module attributes can be ignored.
        """

        class Sub(torch.nn.Module):
            def forward(self, a: int) -> int:
                return sum([a])

        class ModuleWithIgnoredAttr(torch.nn.Module):
            __jit_ignored_attributes__ = ["a", "sub"]

            def __init__(self, a: int, b: int):
                super().__init__()
                self.a = a
                self.b = b
                self.sub = Sub()

            def forward(self) -> int:
                return self.b

            @torch.jit.ignore
            def ignored_fn(self) -> int:
                # Ignored attributes remain usable from ignored
                # (Python-only) methods.
                return self.sub.forward(self.a)

        mod = ModuleWithIgnoredAttr(1, 4)
        scripted_mod = torch.jit.script(mod)
        self.assertEqual(scripted_mod(), 4)
        self.assertEqual(scripted_mod.ignored_fn(), 1)

        # Test the error message for ignored attributes.
        class ModuleUsesIgnoredAttr(torch.nn.Module):
            __jit_ignored_attributes__ = ["a", "sub"]

            def __init__(self, a: int):
                super().__init__()
                self.a = a
                self.sub = Sub()

            def forward(self) -> int:
                # Compilation fails on `self.sub` (an ignored attribute),
                # before `self.b` is ever resolved.
                return self.sub(self.b)

        mod = ModuleUsesIgnoredAttr(1)

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, r"attribute was ignored during compilation", "self.sub"
        ):
            scripted_mod = torch.jit.script(mod)

    def test_ignoring_fn_with_nonscriptable_types(self):
        class CFX:
            def __init__(self, a: List[torch.Tensor]) -> None:
                self.a = a

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.sin(x)

            @torch.jit._drop
            def __iter__(self) -> Iterator[torch.Tensor]:
                return iter(self.a)

            @torch.jit._drop
            def __fx_create_arg__(
                self, tracer: torch.fx.Tracer
            ) -> torch.fx.node.Argument:
                # torch.fx classes are not scriptable; this method is dropped,
                # so it is neither compiled nor executed by this test.
                return tracer.create_node(
                    "call_function",
                    CFX,
                    args=(tracer.create_arg(self.a),),
                    kwargs={},
                )

        torch.jit.script(CFX)
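
        # Presumably (a sketch, not part of the original test) the class is
        # still usable as plain Python after scripting; only `forward` was
        # compiled, and the dropped methods were skipped.
        inst = CFX([torch.ones(2)])
        self.assertEqual(inst.forward(torch.zeros(2)), torch.sin(torch.zeros(2)))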

    def test_unimported_type_resolution(self):
        # verify fallback from the python resolver to the c++ resolver

        @torch.jit.script
        def fn(x):
            # type: (number) -> number
            return x + 1

        FileCheck().check("Scalar").run(fn.graph)
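
        # `number` is not a Python name; the C++ resolver maps it to
        # TorchScript's Scalar type, which is what the FileCheck above
        # asserts.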

    def test_parser_bug(self):
        # Per the name, presumably a regression test for a parser issue; as
        # written, it only checks that this annotated definition is accepted,
        # without scripting it.
        def parser_bug(o: Optional[torch.Tensor]):
            pass

    def test_mismatched_annotation(self):
        with self.assertRaisesRegex(RuntimeError, "annotated with type"):

            @torch.jit.script
            def foo():
                x: str = 4
                return x
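
        # For contrast, a sketch (not in the original test): matching the
        # annotation to the assigned value scripts cleanly.
        @torch.jit.script
        def foo_ok():
            x: str = "4"
            return x

        self.assertEqual(foo_ok(), "4")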

    def test_reannotate(self):
        with self.assertRaisesRegex(RuntimeError, "declare and annotate"):

            @torch.jit.script
            def foo():
                x = 5
                if 1 == 1:
                    x: Optional[int] = 7
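
        # For contrast, a sketch (not in the original test): annotate once at
        # the first declaration, then assign freely.
        @torch.jit.script
        def foo_ok():
            x: Optional[int] = 5
            if 1 == 1:
                x = 7
            return x

        self.assertEqual(foo_ok(), 7)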

    def test_annotate_outside_init(self):
        msg = "annotations on instance attributes must be declared in __init__"
        highlight = "self.x: int"

        # Simple case
        with self.assertRaisesRegexWithHighlight(ValueError, msg, highlight):

            @torch.jit.script
            class BadModule:
                def __init__(self, x: int):
                    self.x = x

                def set(self, val: int):
                    self.x: int = val

        # Type annotation in a loop
        with self.assertRaisesRegexWithHighlight(ValueError, msg, highlight):

            @torch.jit.script
            class BadModuleLoop:
                def __init__(self, x: int):
                    self.x = x

                def set(self, val: int):
                    for i in range(3):
                        self.x: int = val

        # Type annotation in __init__, should not fail
        @torch.jit.script
        class GoodModule:
            def __init__(self, x: int):
                self.x: int = x

            def set(self, val: int):
                self.x = val
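
        # Presumably (a sketch, not part of the original test) the compiled
        # class is also constructible and usable from eager Python:
        g = GoodModule(1)
        g.set(2)
        self.assertEqual(g.x, 2)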

    def test_inferred_type_error_message(self):
        inferred_type = torch._C.InferredType("ErrorReason")

        # Both assertions inspect the same underlying error: the generic
        # message and the reason string it embeds.
        with self.assertRaisesRegex(
            RuntimeError,
            "Tried to get the type from an InferredType but the type is null.",
        ):
            t = inferred_type.type()

        with self.assertRaisesRegex(RuntimeError, "ErrorReason"):
            t = inferred_type.type()
372