# Owner(s): ["oncall: jit"]

import inspect
import os
import sys
import unittest
from typing import Dict, List

import torch
from torch.testing import FileCheck


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestBuiltins(JitTestCase):
    """
    Tests for TorchScript support of Python builtin functions.
    """

    def test_has_attr(self):
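        """hasattr() on submodules works and statically refines attribute access in a ModuleList loop."""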
        class HasA(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 0

        class HasB(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.b = 1

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mods = torch.nn.ModuleList([HasA(), HasB()])

            def forward(self):
                # use a list to encode hasattr results
                l = torch.jit.annotate(List[int], [])
                for mod in self.mods:
                    l.append(int(hasattr(mod, "a")))
                    l.append(int(hasattr(mod, "b")))
                    # actually retrieve the attr to test static refinement
                    if hasattr(mod, "a"):
                        l.append(mod.a)
                    if hasattr(mod, "b"):
                        l.append(mod.b)
                return l

        self.checkModule(Mod(), ())

    def test_has_attr_invalid_args(self):
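        """hasattr() needs a static attribute name and a module-typed object; misuse fails to script."""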
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mod = torch.nn.Linear(1, 1)

            def forward(self, name):
                # not allowed, `name` must be static.
                return hasattr(self.mod, name)

        with self.assertRaisesRegexWithHighlight(RuntimeError, "hasattr", "name"):
            torch.jit.script(Mod())

        class Mod(torch.nn.Module):
            def forward(self, name):
                # not allowed, `torch.rand(2, 3)` returns a Tensor, not a class type
                return hasattr(torch.rand(2, 3), name)

        with self.assertRaisesRegexWithHighlight(RuntimeError, "hasattr", "name"):
            torch.jit.script(Mod())

    def test_del(self):
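        """del removes a local from scope; later uses fail to compile with 'undefined value'."""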
        def fn(x: List[int]) -> List[int]:
            a = x * 2
            del a
            return x

        self.checkScript(fn, ([1, 2, 3],))

        with self.assertRaisesRegexWithHighlight(RuntimeError, "undefined value", "a"):

            @torch.jit.script
            def fn(x):
                a = x**2
                del a
                return a  # noqa: F821

        with self.assertRaisesRegexWithHighlight(RuntimeError, "undefined value", "a"):

            @torch.jit.script
            def fn(x):
                a = x**2
                if a:
                    del a
                return a

        with self.assertRaisesRegexWithHighlight(RuntimeError, "undefined value", "b"):

            @torch.jit.script
            def fn(x):
                a = x**2
                del b  # noqa: F821
                return a

    def test_del_multiple_operands(self):
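        """del with multiple targets works for locals, list elements, and dict keys."""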
        def fn(x: List[int]) -> List[int]:
            a, b, c = x[0], x[1], x[2]
            del a, b, c
            return x

        self.checkScript(fn, ([1, 2, 3],))

        def del_list_multiple_operands(x: List[int]) -> List[int]:
            del x[0], x[1]
            return x

        py_out = del_list_multiple_operands([0, 1, 2])
        jit_out = torch.jit.script(del_list_multiple_operands)([0, 1, 2])
        self.assertEqual(py_out, jit_out)

        def del_dict_multiple_operands(x: Dict[str, int]) -> Dict[str, int]:
            del x["hi"], x["there"]
            return x

        py_out = del_dict_multiple_operands({"hi": 5, "there": 6})
        jit_out = torch.jit.script(del_dict_multiple_operands)({"hi": 5, "there": 6})
        self.assertEqual(py_out, jit_out)


class TestTensorBuiltins(JitTestCase):
    def test_tensor_properties(self):
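        """Public, non-callable tensor properties (minus known gaps) are scriptable and match eager mode."""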
        def should_keep(tensor, name):
            if inspect.isroutine(getattr(tensor, name)):
                return False
            if name.startswith("_"):
                return False
            return True

        tensor = torch.arange(4, dtype=torch.float).view(2, 2)
        keys = dir(tensor)

        # imag is only implemented for complex tensors, so accessing it here raises.
        self.assertRaises(RuntimeError, lambda: should_keep(tensor, "imag"))
        keys.remove("imag")

        properties = [p for p in keys if should_keep(tensor, p)]

        code_template = """
        def fn(x):
            return x.{}
        """

        EQUALITY_MISMATCH = {
            # TorchScript doesn't have real enums so they return an int instead
            # of the actual value
            "dtype",
            "layout",
        }
        MISSING_PROPERTIES = {
            "grad_fn",
            # This is an undocumented property so it's not included
            "output_nr",
            # This has a longer implementation, maybe not worth copying to
            # TorchScript if named tensors don't work there anyway
            "names",
        }

        for p in properties:
            if p in MISSING_PROPERTIES:
                continue
            code = code_template.format(p)
            cu = torch.jit.CompilationUnit()
            cu.define(code)
            if p in EQUALITY_MISMATCH:
                continue
            self.assertEqual(getattr(tensor, p), cu.fn(tensor))

    def test_tensor_subscript_assign(self):
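        """Subscript assignment works with tensor or int indices and tensor or scalar values."""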
        def fn1(x):
            a = torch.zeros_like(x, dtype=torch.uint8)
            a[torch.tensor(0)] = torch.tensor(2, dtype=torch.uint8)
            return a

        def fn2(x):
            a = torch.zeros_like(x, dtype=torch.uint8)
            a[0] = 2
            return a

        def fn3(x):
            a = torch.zeros_like(x, dtype=torch.uint8)
            a[torch.tensor(0)] = 2
            return a

        def fn4(x):
            a = torch.zeros_like(x, dtype=torch.uint8)
            a[0] = torch.tensor(2, dtype=torch.uint8)
            return a

        def fn5(x):
            a = torch.zeros_like(x, dtype=torch.float32)
            a[torch.tensor(0)] = 2
            return a

        for fn in (fn1, fn2, fn3, fn4, fn5):
            self.checkScript(fn, (torch.zeros(2, dtype=torch.uint8),))

    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
    def test_tensor_subscript_assign_device(self):
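        """Subscript assignment with a tensor index also works on a CUDA tensor."""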
        def fn6(x):
            a = torch.zeros_like(x, dtype=torch.float32, device="cuda")
            a[torch.tensor(0)] = 2
            return a

        self.checkScript(fn6, (torch.zeros(2, dtype=torch.float32, device="cuda"),))

    def test_tensor_item(self):
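        """Tensor.item() returns a Scalar that can be converted with int() and float()."""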
        def test_scalar_cast(x):
            scalar = x.item()
            return int(scalar), float(scalar)

        graph = torch.jit.script(test_scalar_cast).graph
        FileCheck().check("(int, float) = prim::TupleConstruct").run(graph)
        self.checkScript(test_scalar_cast, (torch.tensor(1.0),))
        self.checkScript(test_scalar_cast, (torch.tensor(1),))

    def test_method_on_number(self):
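        """Calling a tensor method on a plain Python number is a compile-time error."""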
        def func():
            c = 1
            return c.add(1)

        with self.assertRaisesRegex(RuntimeError, "object has no attribute or method"):
            torch.jit.script(func)

    # testing implicit conversion of tensors to scalars to match function arguments
    def test_scalar_to_num_conversions(self):
        @torch.jit.script
        def multiple_defs(x):
            c = 1
            x = x + c
            return x

        self.assertTrue("ImplicitTensorToNum" not in str(multiple_defs.graph))

        @torch.jit.script
        def tensor_to_int_script(x, tensor):
            return x.unsqueeze(tensor)

        # the error message should point at the offending call site
        with self.assertRaisesRegex(RuntimeError, "x.unsqueeze"):
            tensor_to_int_script(torch.tensor([2]), torch.tensor([2, 2]))

        def tensor_to_int(x, tensor):
            return x.unsqueeze(tensor)

        @torch.jit.script
        def tensor_to_float_script(x, tensor):
            return x.addcmul(tensor, tensor, value=tensor)

        def tensor_to_float(x, tensor):
            return x.addcmul(tensor, tensor, value=tensor)

        x = torch.zeros(10)
        # float tensor, float tensor with grad, int tensor (can't set grad on an
        # int tensor), and a 1-element int tensor
        tensors = [
            torch.tensor(1.1),
            torch.tensor(1.1, requires_grad=True),
            torch.tensor(0),
            torch.tensor([2]),
        ]

        script_funs = [tensor_to_int_script, tensor_to_float_script]
        funs = [tensor_to_int, tensor_to_float]

        # return the result, or True if an exception was thrown
        def test_func(func, x, tensor):
            try:
                result = func(x, tensor)
            except (RuntimeError, TypeError):
                result = True
            return result

        # the result (or exception behavior) must match for each (function, input) pair
        for tensor in tensors:
            for i in range(len(script_funs)):
                self.assertEqual(
                    test_func(script_funs[i], x, tensor), test_func(funs[i], x, tensor)
                )
302