xref: /aosp_15_r20/external/pytorch/test/jit/test_isinstance.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3import os
4import sys
5import warnings
6from typing import Any, Dict, List, Optional, Tuple
7
8import torch
9
10
# Helper modules used by these tests live two directory levels above this
# file (its resolved location), so put that directory on sys.path.
pytorch_test_dir = os.path.split(
    os.path.split(os.path.realpath(__file__))[0]
)[0]
sys.path.append(pytorch_test_dir)
14from torch.testing._internal.jit_utils import JitTestCase
15
16
# This module only defines test cases; refuse to run it as a script and point
# the user at the test_jit.py driver instead.
if __name__ == "__main__":
    raise RuntimeError(
        "\n\n".join(
            (
                "This test file is not meant to be run directly, use:",
                "\tpython test/test_jit.py TESTNAME",
                "instead.",
            )
        )
    )
23
24
# Tests for torch.jit.isinstance
class TestIsinstance(JitTestCase):
    """Tests for ``torch.jit.isinstance``.

    Most tests define a small function taking ``x: Any`` that asserts which
    type annotations ``torch.jit.isinstance`` accepts or rejects for a sample
    value, then hand it to ``checkScript`` (from ``JitTestCase``) so the
    assertions are exercised under TorchScript. The remaining tests check the
    error raised for un-parameterized container types and eager-only
    behavior (non-tuple right-hand sides, empty-container warnings).
    """

    def test_int(self):
        """An ``int`` value matches ``int`` but not ``float``."""

        def int_test(x: Any):
            assert torch.jit.isinstance(x, int)
            assert not torch.jit.isinstance(x, float)

        x = 1
        self.checkScript(int_test, (x,))

    def test_float(self):
        """A ``float`` value matches ``float`` but not ``int``."""

        def float_test(x: Any):
            assert torch.jit.isinstance(x, float)
            assert not torch.jit.isinstance(x, int)

        x = 1.0
        self.checkScript(float_test, (x,))

    def test_bool(self):
        """A ``bool`` value matches ``bool`` but not ``float``."""

        def bool_test(x: Any):
            assert torch.jit.isinstance(x, bool)
            assert not torch.jit.isinstance(x, float)

        x = False
        self.checkScript(bool_test, (x,))

    def test_list(self):
        """A list of strings matches ``List[str]`` only."""

        def list_str_test(x: Any):
            assert torch.jit.isinstance(x, List[str])
            assert not torch.jit.isinstance(x, List[int])
            assert not torch.jit.isinstance(x, Tuple[int])

        x = ["1", "2", "3"]
        self.checkScript(list_str_test, (x,))

    def test_list_tensor(self):
        """A list of tensors matches ``List[torch.Tensor]``, not a tuple type."""

        def list_tensor_test(x: Any):
            assert torch.jit.isinstance(x, List[torch.Tensor])
            assert not torch.jit.isinstance(x, Tuple[int])

        x = [torch.tensor([1]), torch.tensor([2]), torch.tensor([3])]
        self.checkScript(list_tensor_test, (x,))

    def test_dict(self):
        """Both the key and the value type of ``Dict`` are checked."""

        def dict_str_int_test(x: Any):
            assert torch.jit.isinstance(x, Dict[str, int])
            assert not torch.jit.isinstance(x, Dict[int, str])
            assert not torch.jit.isinstance(x, Dict[str, str])

        x = {"a": 1, "b": 2}
        self.checkScript(dict_str_int_test, (x,))

    def test_dict_tensor(self):
        """A dict with tensor values matches ``Dict[int, torch.Tensor]``."""

        def dict_int_tensor_test(x: Any):
            assert torch.jit.isinstance(x, Dict[int, torch.Tensor])

        x = {2: torch.tensor([2])}
        self.checkScript(dict_int_tensor_test, (x,))

    def test_tuple(self):
        """``Tuple`` checks element types, their order, and the arity."""

        def tuple_test(x: Any):
            assert torch.jit.isinstance(x, Tuple[str, int, str])
            assert not torch.jit.isinstance(x, Tuple[int, str, str])
            assert not torch.jit.isinstance(x, Tuple[str])

        x = ("a", 1, "b")
        self.checkScript(tuple_test, (x,))

    def test_tuple_tensor(self):
        """Tensors of different shapes both match ``torch.Tensor`` in a tuple."""

        def tuple_tensor_test(x: Any):
            assert torch.jit.isinstance(x, Tuple[torch.Tensor, torch.Tensor])

        x = (torch.tensor([1]), torch.tensor([[2], [3]]))
        self.checkScript(tuple_tensor_test, (x,))

    def test_optional(self):
        """A tensor matches ``Optional[torch.Tensor]`` but not ``Optional[str]``."""

        def optional_test(x: Any):
            assert torch.jit.isinstance(x, Optional[torch.Tensor])
            assert not torch.jit.isinstance(x, Optional[str])

        x = torch.ones(3, 3)
        self.checkScript(optional_test, (x,))

    def test_optional_none(self):
        """``None`` matches ``Optional[torch.Tensor]``."""

        def optional_test_none(x: Any):
            assert torch.jit.isinstance(x, Optional[torch.Tensor])
            # assert torch.jit.isinstance(x, Optional[str])
            # TODO: above line in eager will evaluate to True while in
            #       the TS interpreter will evaluate to False as the
            #       first torch.jit.isinstance refines the 'None' type

        x = None
        self.checkScript(optional_test_none, (x,))

    def test_list_nested(self):
        """Element types nested inside a ``List`` are checked."""

        def list_nested(x: Any):
            assert torch.jit.isinstance(x, List[Dict[str, int]])
            assert not torch.jit.isinstance(x, List[List[str]])

        x = [{"a": 1, "b": 2}, {"aa": 11, "bb": 22}]
        self.checkScript(list_nested, (x,))

    def test_dict_nested(self):
        """Value types nested inside a ``Dict`` are checked."""

        def dict_nested(x: Any):
            assert torch.jit.isinstance(x, Dict[str, Tuple[str, str, str]])
            assert not torch.jit.isinstance(x, Dict[str, Tuple[int, int, int]])

        x = {"a": ("aa", "aa", "aa"), "b": ("bb", "bb", "bb")}
        self.checkScript(dict_nested, (x,))

    def test_tuple_nested(self):
        """Deeply nested tuple element types, including ``Optional``, are checked."""

        def tuple_nested(x: Any):
            assert torch.jit.isinstance(
                x, Tuple[Dict[str, Tuple[str, str, str]], List[bool], Optional[str]]
            )
            assert not torch.jit.isinstance(x, Dict[str, Tuple[int, int, int]])
            assert not torch.jit.isinstance(x, Tuple[str])
            assert not torch.jit.isinstance(x, Tuple[List[bool], List[str], List[int]])

        x = (
            {"a": ("aa", "aa", "aa"), "b": ("bb", "bb", "bb")},
            [True, False, True],
            None,
        )
        self.checkScript(tuple_nested, (x,))

    def test_optional_nested(self):
        """A non-None list matches ``Optional[List[str]]``."""

        def optional_nested(x: Any):
            assert torch.jit.isinstance(x, Optional[List[str]])

        x = ["a", "b", "c"]
        self.checkScript(optional_nested, (x,))

    def test_list_tensor_type_true(self):
        """Tensors of different shapes in one list match ``List[torch.Tensor]``."""

        def list_tensor_type_true(x: Any):
            assert torch.jit.isinstance(x, List[torch.Tensor])

        x = [torch.rand(3, 3), torch.rand(4, 3)]
        self.checkScript(list_tensor_type_true, (x,))

    def test_tensor_type_false(self):
        """A list of ints does not match ``List[torch.Tensor]``."""

        def list_tensor_type_false(x: Any):
            assert not torch.jit.isinstance(x, List[torch.Tensor])

        x = [1, 2, 3]
        self.checkScript(list_tensor_type_false, (x,))

    def test_in_if(self):
        """``torch.jit.isinstance`` can be used as an ``if`` condition."""

        def list_in_if(x: Any):
            if torch.jit.isinstance(x, List[int]):
                assert True
            if torch.jit.isinstance(x, List[str]):
                assert not True

        x = [1, 2, 3]
        self.checkScript(list_in_if, (x,))

    def test_if_else(self):
        """``torch.jit.isinstance`` can drive an ``if``/``else``."""

        def list_in_if_else(x: Any):
            if torch.jit.isinstance(x, Tuple[str, str, str]):
                assert True
            else:
                assert not True

        x = ("a", "b", "c")
        self.checkScript(list_in_if_else, (x,))

    def test_in_while_loop(self):
        """``torch.jit.isinstance`` can be combined with ``and`` in a ``while``."""

        def list_in_while_loop(x: Any):
            count = 0
            while torch.jit.isinstance(x, List[Dict[str, int]]) and count <= 0:
                count = count + 1
            assert count == 1

        x = [{"a": 1, "b": 2}, {"aa": 11, "bb": 22}]
        self.checkScript(list_in_while_loop, (x,))

    def test_type_refinement(self):
        """A passing check lets the branch use the value as the checked type."""

        def type_refinement(obj: Any):
            hit = False
            if torch.jit.isinstance(obj, List[torch.Tensor]):
                hit = not hit
                for el in obj:
                    # perform some tensor operation
                    y = el.clamp(0, 0.5)
            if torch.jit.isinstance(obj, Dict[str, str]):
                hit = not hit
                str_cat = ""
                for val in obj.values():
                    str_cat = str_cat + val
                assert "111222" == str_cat
            assert hit

        x = [torch.rand(3, 3), torch.rand(4, 3)]
        self.checkScript(type_refinement, (x,))
        x = {"1": "111", "2": "222"}
        self.checkScript(type_refinement, (x,))

    def test_list_no_contained_type(self):
        """Bare ``List`` is rejected both when scripting and in eager mode."""

        def list_no_contained_type(x: Any):
            assert torch.jit.isinstance(x, List)

        x = ["1", "2", "3"]

        err_msg = (
            "Attempted to use List without a contained type. "
            r"Please add a contained type, e.g. List\[int\]"
        )

        with self.assertRaisesRegex(
            RuntimeError,
            err_msg,
        ):
            torch.jit.script(list_no_contained_type)
        with self.assertRaisesRegex(
            RuntimeError,
            err_msg,
        ):
            list_no_contained_type(x)

    def test_tuple_no_contained_type(self):
        """Bare ``Tuple`` is rejected both when scripting and in eager mode."""

        def tuple_no_contained_type(x: Any):
            assert torch.jit.isinstance(x, Tuple)

        x = ("1", "2", "3")

        err_msg = (
            "Attempted to use Tuple without a contained type. "
            r"Please add a contained type, e.g. Tuple\[int\]"
        )

        with self.assertRaisesRegex(
            RuntimeError,
            err_msg,
        ):
            torch.jit.script(tuple_no_contained_type)
        with self.assertRaisesRegex(
            RuntimeError,
            err_msg,
        ):
            tuple_no_contained_type(x)

    def test_optional_no_contained_type(self):
        """Bare ``Optional`` is rejected both when scripting and in eager mode."""

        def optional_no_contained_type(x: Any):
            assert torch.jit.isinstance(x, Optional)

        x = ("1", "2", "3")

        err_msg = (
            "Attempted to use Optional without a contained type. "
            r"Please add a contained type, e.g. Optional\[int\]"
        )

        with self.assertRaisesRegex(
            RuntimeError,
            err_msg,
        ):
            torch.jit.script(optional_no_contained_type)
        with self.assertRaisesRegex(
            RuntimeError,
            err_msg,
        ):
            optional_no_contained_type(x)

    def test_dict_no_contained_type(self):
        """Bare ``Dict`` is rejected both when scripting and in eager mode."""

        def dict_no_contained_type(x: Any):
            assert torch.jit.isinstance(x, Dict)

        x = {"a": "aa"}

        err_msg = (
            "Attempted to use Dict without contained types. "
            r"Please add contained type, e.g. Dict\[int, int\]"
        )

        with self.assertRaisesRegex(
            RuntimeError,
            err_msg,
        ):
            torch.jit.script(dict_no_contained_type)
        with self.assertRaisesRegex(
            RuntimeError,
            err_msg,
        ):
            dict_no_contained_type(x)

    def test_tuple_rhs(self):
        """A tuple of types matches when the value matches one of its members."""

        def fn(x: Any):
            assert torch.jit.isinstance(x, (int, List[str]))
            assert not torch.jit.isinstance(x, (List[float], Tuple[int, str]))
            assert not torch.jit.isinstance(x, (List[float], str))

        self.checkScript(fn, (2,))
        self.checkScript(fn, (["foo", "bar", "baz"],))

    def test_nontuple_container_rhs_throws_in_eager(self):
        """A list or set of types as the second argument raises in eager mode."""

        def fn1(x: Any):
            assert torch.jit.isinstance(x, [int, List[str]])

        def fn2(x: Any):
            assert not torch.jit.isinstance(x, {List[str], Tuple[int, str]})

        err_highlight = "must be a type or a tuple of types"

        with self.assertRaisesRegex(RuntimeError, err_highlight):
            fn1(2)

        with self.assertRaisesRegex(RuntimeError, err_highlight):
            fn2(2)

    def test_empty_container_throws_warning_in_eager(self):
        """Eager checks on an empty container warn once; a scalar does not warn."""

        def fn(x: Any):
            torch.jit.isinstance(x, List[int])

        with warnings.catch_warnings(record=True) as w:
            # Force recording: an inherited "ignore" filter or the
            # once-per-location default action could otherwise hide the
            # warning and make the count below unreliable.
            warnings.simplefilter("always")
            empty: List[int] = []
            fn(empty)
            self.assertEqual(len(w), 1)

        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            scalar: int = 2
            fn(scalar)
            self.assertEqual(len(w), 0)

    def test_empty_container_special_cases(self):
        """``check_empty_containers`` accepts tensors without a truthiness error."""
        # Should not throw "Boolean value of Tensor with no values is
        # ambiguous" error
        torch._jit_internal.check_empty_containers(torch.Tensor([]))

        # Should not throw "Boolean value of Tensor with more than
        # one value is ambiguous" error
        torch._jit_internal.check_empty_containers(torch.rand(2, 3))
357