xref: /aosp_15_r20/external/pytorch/test/jit/test_isinstance.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport os
4*da0073e9SAndroid Build Coastguard Workerimport sys
5*da0073e9SAndroid Build Coastguard Workerimport warnings
6*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, Dict, List, Optional, Tuple
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Workerimport torch
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Worker# Make the helper files in test/ importable
# Resolve symlinks first, then walk up two levels: this file's directory,
# then that directory's parent, which is what gets added to sys.path.
_this_file = os.path.realpath(__file__)
_this_dir = os.path.dirname(_this_file)
pytorch_test_dir = os.path.dirname(_this_dir)
sys.path.append(pytorch_test_dir)
14*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase
15*da0073e9SAndroid Build Coastguard Worker
16*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Executing this module as a script is rejected outright; the message
    # points at the intended entry point.
    _direct_run_message = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_direct_run_message)
23*da0073e9SAndroid Build Coastguard Worker
24*da0073e9SAndroid Build Coastguard Worker
25*da0073e9SAndroid Build Coastguard Worker# Tests for torch.jit.isinstance
26*da0073e9SAndroid Build Coastguard Workerclass TestIsinstance(JitTestCase):
27*da0073e9SAndroid Build Coastguard Worker    def test_int(self):
28*da0073e9SAndroid Build Coastguard Worker        def int_test(x: Any):
29*da0073e9SAndroid Build Coastguard Worker            assert torch.jit.isinstance(x, int)
30*da0073e9SAndroid Build Coastguard Worker            assert not torch.jit.isinstance(x, float)
31*da0073e9SAndroid Build Coastguard Worker
32*da0073e9SAndroid Build Coastguard Worker        x = 1
33*da0073e9SAndroid Build Coastguard Worker        self.checkScript(int_test, (x,))
34*da0073e9SAndroid Build Coastguard Worker
35*da0073e9SAndroid Build Coastguard Worker    def test_float(self):
36*da0073e9SAndroid Build Coastguard Worker        def float_test(x: Any):
37*da0073e9SAndroid Build Coastguard Worker            assert torch.jit.isinstance(x, float)
38*da0073e9SAndroid Build Coastguard Worker            assert not torch.jit.isinstance(x, int)
39*da0073e9SAndroid Build Coastguard Worker
40*da0073e9SAndroid Build Coastguard Worker        x = 1.0
41*da0073e9SAndroid Build Coastguard Worker        self.checkScript(float_test, (x,))
42*da0073e9SAndroid Build Coastguard Worker
43*da0073e9SAndroid Build Coastguard Worker    def test_bool(self):
44*da0073e9SAndroid Build Coastguard Worker        def bool_test(x: Any):
45*da0073e9SAndroid Build Coastguard Worker            assert torch.jit.isinstance(x, bool)
46*da0073e9SAndroid Build Coastguard Worker            assert not torch.jit.isinstance(x, float)
47*da0073e9SAndroid Build Coastguard Worker
48*da0073e9SAndroid Build Coastguard Worker        x = False
49*da0073e9SAndroid Build Coastguard Worker        self.checkScript(bool_test, (x,))
50*da0073e9SAndroid Build Coastguard Worker
51*da0073e9SAndroid Build Coastguard Worker    def test_list(self):
52*da0073e9SAndroid Build Coastguard Worker        def list_str_test(x: Any):
53*da0073e9SAndroid Build Coastguard Worker            assert torch.jit.isinstance(x, List[str])
54*da0073e9SAndroid Build Coastguard Worker            assert not torch.jit.isinstance(x, List[int])
55*da0073e9SAndroid Build Coastguard Worker            assert not torch.jit.isinstance(x, Tuple[int])
56*da0073e9SAndroid Build Coastguard Worker
57*da0073e9SAndroid Build Coastguard Worker        x = ["1", "2", "3"]
58*da0073e9SAndroid Build Coastguard Worker        self.checkScript(list_str_test, (x,))
59*da0073e9SAndroid Build Coastguard Worker
60*da0073e9SAndroid Build Coastguard Worker    def test_list_tensor(self):
61*da0073e9SAndroid Build Coastguard Worker        def list_tensor_test(x: Any):
62*da0073e9SAndroid Build Coastguard Worker            assert torch.jit.isinstance(x, List[torch.Tensor])
63*da0073e9SAndroid Build Coastguard Worker            assert not torch.jit.isinstance(x, Tuple[int])
64*da0073e9SAndroid Build Coastguard Worker
65*da0073e9SAndroid Build Coastguard Worker        x = [torch.tensor([1]), torch.tensor([2]), torch.tensor([3])]
66*da0073e9SAndroid Build Coastguard Worker        self.checkScript(list_tensor_test, (x,))
67*da0073e9SAndroid Build Coastguard Worker
68*da0073e9SAndroid Build Coastguard Worker    def test_dict(self):
69*da0073e9SAndroid Build Coastguard Worker        def dict_str_int_test(x: Any):
70*da0073e9SAndroid Build Coastguard Worker            assert torch.jit.isinstance(x, Dict[str, int])
71*da0073e9SAndroid Build Coastguard Worker            assert not torch.jit.isinstance(x, Dict[int, str])
72*da0073e9SAndroid Build Coastguard Worker            assert not torch.jit.isinstance(x, Dict[str, str])
73*da0073e9SAndroid Build Coastguard Worker
74*da0073e9SAndroid Build Coastguard Worker        x = {"a": 1, "b": 2}
75*da0073e9SAndroid Build Coastguard Worker        self.checkScript(dict_str_int_test, (x,))
76*da0073e9SAndroid Build Coastguard Worker
77*da0073e9SAndroid Build Coastguard Worker    def test_dict_tensor(self):
78*da0073e9SAndroid Build Coastguard Worker        def dict_int_tensor_test(x: Any):
79*da0073e9SAndroid Build Coastguard Worker            assert torch.jit.isinstance(x, Dict[int, torch.Tensor])
80*da0073e9SAndroid Build Coastguard Worker
81*da0073e9SAndroid Build Coastguard Worker        x = {2: torch.tensor([2])}
82*da0073e9SAndroid Build Coastguard Worker        self.checkScript(dict_int_tensor_test, (x,))
83*da0073e9SAndroid Build Coastguard Worker
84*da0073e9SAndroid Build Coastguard Worker    def test_tuple(self):
85*da0073e9SAndroid Build Coastguard Worker        def tuple_test(x: Any):
86*da0073e9SAndroid Build Coastguard Worker            assert torch.jit.isinstance(x, Tuple[str, int, str])
87*da0073e9SAndroid Build Coastguard Worker            assert not torch.jit.isinstance(x, Tuple[int, str, str])
88*da0073e9SAndroid Build Coastguard Worker            assert not torch.jit.isinstance(x, Tuple[str])
89*da0073e9SAndroid Build Coastguard Worker
90*da0073e9SAndroid Build Coastguard Worker        x = ("a", 1, "b")
91*da0073e9SAndroid Build Coastguard Worker        self.checkScript(tuple_test, (x,))
92*da0073e9SAndroid Build Coastguard Worker
93*da0073e9SAndroid Build Coastguard Worker    def test_tuple_tensor(self):
94*da0073e9SAndroid Build Coastguard Worker        def tuple_tensor_test(x: Any):
95*da0073e9SAndroid Build Coastguard Worker            assert torch.jit.isinstance(x, Tuple[torch.Tensor, torch.Tensor])
96*da0073e9SAndroid Build Coastguard Worker
97*da0073e9SAndroid Build Coastguard Worker        x = (torch.tensor([1]), torch.tensor([[2], [3]]))
98*da0073e9SAndroid Build Coastguard Worker        self.checkScript(tuple_tensor_test, (x,))
99*da0073e9SAndroid Build Coastguard Worker
100*da0073e9SAndroid Build Coastguard Worker    def test_optional(self):
101*da0073e9SAndroid Build Coastguard Worker        def optional_test(x: Any):
102*da0073e9SAndroid Build Coastguard Worker            assert torch.jit.isinstance(x, Optional[torch.Tensor])
103*da0073e9SAndroid Build Coastguard Worker            assert not torch.jit.isinstance(x, Optional[str])
104*da0073e9SAndroid Build Coastguard Worker
105*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(3, 3)
106*da0073e9SAndroid Build Coastguard Worker        self.checkScript(optional_test, (x,))
107*da0073e9SAndroid Build Coastguard Worker
108*da0073e9SAndroid Build Coastguard Worker    def test_optional_none(self):
109*da0073e9SAndroid Build Coastguard Worker        def optional_test_none(x: Any):
110*da0073e9SAndroid Build Coastguard Worker            assert torch.jit.isinstance(x, Optional[torch.Tensor])
111*da0073e9SAndroid Build Coastguard Worker            # assert torch.jit.isinstance(x, Optional[str])
112*da0073e9SAndroid Build Coastguard Worker            # TODO: above line in eager will evaluate to True while in
113*da0073e9SAndroid Build Coastguard Worker            #       the TS interpreter will evaluate to False as the
114*da0073e9SAndroid Build Coastguard Worker            #       first torch.jit.isinstance refines the 'None' type
115*da0073e9SAndroid Build Coastguard Worker
116*da0073e9SAndroid Build Coastguard Worker        x = None
117*da0073e9SAndroid Build Coastguard Worker        self.checkScript(optional_test_none, (x,))
118*da0073e9SAndroid Build Coastguard Worker
119*da0073e9SAndroid Build Coastguard Worker    def test_list_nested(self):
120*da0073e9SAndroid Build Coastguard Worker        def list_nested(x: Any):
121*da0073e9SAndroid Build Coastguard Worker            assert torch.jit.isinstance(x, List[Dict[str, int]])
122*da0073e9SAndroid Build Coastguard Worker            assert not torch.jit.isinstance(x, List[List[str]])
123*da0073e9SAndroid Build Coastguard Worker
124*da0073e9SAndroid Build Coastguard Worker        x = [{"a": 1, "b": 2}, {"aa": 11, "bb": 22}]
125*da0073e9SAndroid Build Coastguard Worker        self.checkScript(list_nested, (x,))
126*da0073e9SAndroid Build Coastguard Worker
127*da0073e9SAndroid Build Coastguard Worker    def test_dict_nested(self):
128*da0073e9SAndroid Build Coastguard Worker        def dict_nested(x: Any):
129*da0073e9SAndroid Build Coastguard Worker            assert torch.jit.isinstance(x, Dict[str, Tuple[str, str, str]])
130*da0073e9SAndroid Build Coastguard Worker            assert not torch.jit.isinstance(x, Dict[str, Tuple[int, int, int]])
131*da0073e9SAndroid Build Coastguard Worker
132*da0073e9SAndroid Build Coastguard Worker        x = {"a": ("aa", "aa", "aa"), "b": ("bb", "bb", "bb")}
133*da0073e9SAndroid Build Coastguard Worker        self.checkScript(dict_nested, (x,))
134*da0073e9SAndroid Build Coastguard Worker
135*da0073e9SAndroid Build Coastguard Worker    def test_tuple_nested(self):
136*da0073e9SAndroid Build Coastguard Worker        def tuple_nested(x: Any):
137*da0073e9SAndroid Build Coastguard Worker            assert torch.jit.isinstance(
138*da0073e9SAndroid Build Coastguard Worker                x, Tuple[Dict[str, Tuple[str, str, str]], List[bool], Optional[str]]
139*da0073e9SAndroid Build Coastguard Worker            )
140*da0073e9SAndroid Build Coastguard Worker            assert not torch.jit.isinstance(x, Dict[str, Tuple[int, int, int]])
141*da0073e9SAndroid Build Coastguard Worker            assert not torch.jit.isinstance(x, Tuple[str])
142*da0073e9SAndroid Build Coastguard Worker            assert not torch.jit.isinstance(x, Tuple[List[bool], List[str], List[int]])
143*da0073e9SAndroid Build Coastguard Worker
144*da0073e9SAndroid Build Coastguard Worker        x = (
145*da0073e9SAndroid Build Coastguard Worker            {"a": ("aa", "aa", "aa"), "b": ("bb", "bb", "bb")},
146*da0073e9SAndroid Build Coastguard Worker            [True, False, True],
147*da0073e9SAndroid Build Coastguard Worker            None,
148*da0073e9SAndroid Build Coastguard Worker        )
149*da0073e9SAndroid Build Coastguard Worker        self.checkScript(tuple_nested, (x,))
150*da0073e9SAndroid Build Coastguard Worker
151*da0073e9SAndroid Build Coastguard Worker    def test_optional_nested(self):
152*da0073e9SAndroid Build Coastguard Worker        def optional_nested(x: Any):
153*da0073e9SAndroid Build Coastguard Worker            assert torch.jit.isinstance(x, Optional[List[str]])
154*da0073e9SAndroid Build Coastguard Worker
155*da0073e9SAndroid Build Coastguard Worker        x = ["a", "b", "c"]
156*da0073e9SAndroid Build Coastguard Worker        self.checkScript(optional_nested, (x,))
157*da0073e9SAndroid Build Coastguard Worker
158*da0073e9SAndroid Build Coastguard Worker    def test_list_tensor_type_true(self):
159*da0073e9SAndroid Build Coastguard Worker        def list_tensor_type_true(x: Any):
160*da0073e9SAndroid Build Coastguard Worker            assert torch.jit.isinstance(x, List[torch.Tensor])
161*da0073e9SAndroid Build Coastguard Worker
162*da0073e9SAndroid Build Coastguard Worker        x = [torch.rand(3, 3), torch.rand(4, 3)]
163*da0073e9SAndroid Build Coastguard Worker        self.checkScript(list_tensor_type_true, (x,))
164*da0073e9SAndroid Build Coastguard Worker
165*da0073e9SAndroid Build Coastguard Worker    def test_tensor_type_false(self):
166*da0073e9SAndroid Build Coastguard Worker        def list_tensor_type_false(x: Any):
167*da0073e9SAndroid Build Coastguard Worker            assert not torch.jit.isinstance(x, List[torch.Tensor])
168*da0073e9SAndroid Build Coastguard Worker
169*da0073e9SAndroid Build Coastguard Worker        x = [1, 2, 3]
170*da0073e9SAndroid Build Coastguard Worker        self.checkScript(list_tensor_type_false, (x,))
171*da0073e9SAndroid Build Coastguard Worker
172*da0073e9SAndroid Build Coastguard Worker    def test_in_if(self):
173*da0073e9SAndroid Build Coastguard Worker        def list_in_if(x: Any):
174*da0073e9SAndroid Build Coastguard Worker            if torch.jit.isinstance(x, List[int]):
175*da0073e9SAndroid Build Coastguard Worker                assert True
176*da0073e9SAndroid Build Coastguard Worker            if torch.jit.isinstance(x, List[str]):
177*da0073e9SAndroid Build Coastguard Worker                assert not True
178*da0073e9SAndroid Build Coastguard Worker
179*da0073e9SAndroid Build Coastguard Worker        x = [1, 2, 3]
180*da0073e9SAndroid Build Coastguard Worker        self.checkScript(list_in_if, (x,))
181*da0073e9SAndroid Build Coastguard Worker
182*da0073e9SAndroid Build Coastguard Worker    def test_if_else(self):
183*da0073e9SAndroid Build Coastguard Worker        def list_in_if_else(x: Any):
184*da0073e9SAndroid Build Coastguard Worker            if torch.jit.isinstance(x, Tuple[str, str, str]):
185*da0073e9SAndroid Build Coastguard Worker                assert True
186*da0073e9SAndroid Build Coastguard Worker            else:
187*da0073e9SAndroid Build Coastguard Worker                assert not True
188*da0073e9SAndroid Build Coastguard Worker
189*da0073e9SAndroid Build Coastguard Worker        x = ("a", "b", "c")
190*da0073e9SAndroid Build Coastguard Worker        self.checkScript(list_in_if_else, (x,))
191*da0073e9SAndroid Build Coastguard Worker
192*da0073e9SAndroid Build Coastguard Worker    def test_in_while_loop(self):
193*da0073e9SAndroid Build Coastguard Worker        def list_in_while_loop(x: Any):
194*da0073e9SAndroid Build Coastguard Worker            count = 0
195*da0073e9SAndroid Build Coastguard Worker            while torch.jit.isinstance(x, List[Dict[str, int]]) and count <= 0:
196*da0073e9SAndroid Build Coastguard Worker                count = count + 1
197*da0073e9SAndroid Build Coastguard Worker            assert count == 1
198*da0073e9SAndroid Build Coastguard Worker
199*da0073e9SAndroid Build Coastguard Worker        x = [{"a": 1, "b": 2}, {"aa": 11, "bb": 22}]
200*da0073e9SAndroid Build Coastguard Worker        self.checkScript(list_in_while_loop, (x,))
201*da0073e9SAndroid Build Coastguard Worker
202*da0073e9SAndroid Build Coastguard Worker    def test_type_refinement(self):
203*da0073e9SAndroid Build Coastguard Worker        def type_refinement(obj: Any):
204*da0073e9SAndroid Build Coastguard Worker            hit = False
205*da0073e9SAndroid Build Coastguard Worker            if torch.jit.isinstance(obj, List[torch.Tensor]):
206*da0073e9SAndroid Build Coastguard Worker                hit = not hit
207*da0073e9SAndroid Build Coastguard Worker                for el in obj:
208*da0073e9SAndroid Build Coastguard Worker                    # perform some tensor operation
209*da0073e9SAndroid Build Coastguard Worker                    y = el.clamp(0, 0.5)
210*da0073e9SAndroid Build Coastguard Worker            if torch.jit.isinstance(obj, Dict[str, str]):
211*da0073e9SAndroid Build Coastguard Worker                hit = not hit
212*da0073e9SAndroid Build Coastguard Worker                str_cat = ""
213*da0073e9SAndroid Build Coastguard Worker                for val in obj.values():
214*da0073e9SAndroid Build Coastguard Worker                    str_cat = str_cat + val
215*da0073e9SAndroid Build Coastguard Worker                assert "111222" == str_cat
216*da0073e9SAndroid Build Coastguard Worker            assert hit
217*da0073e9SAndroid Build Coastguard Worker
218*da0073e9SAndroid Build Coastguard Worker        x = [torch.rand(3, 3), torch.rand(4, 3)]
219*da0073e9SAndroid Build Coastguard Worker        self.checkScript(type_refinement, (x,))
220*da0073e9SAndroid Build Coastguard Worker        x = {"1": "111", "2": "222"}
221*da0073e9SAndroid Build Coastguard Worker        self.checkScript(type_refinement, (x,))
222*da0073e9SAndroid Build Coastguard Worker
223*da0073e9SAndroid Build Coastguard Worker    def test_list_no_contained_type(self):
224*da0073e9SAndroid Build Coastguard Worker        def list_no_contained_type(x: Any):
225*da0073e9SAndroid Build Coastguard Worker            assert torch.jit.isinstance(x, List)
226*da0073e9SAndroid Build Coastguard Worker
227*da0073e9SAndroid Build Coastguard Worker        x = ["1", "2", "3"]
228*da0073e9SAndroid Build Coastguard Worker
229*da0073e9SAndroid Build Coastguard Worker        err_msg = (
230*da0073e9SAndroid Build Coastguard Worker            "Attempted to use List without a contained type. "
231*da0073e9SAndroid Build Coastguard Worker            r"Please add a contained type, e.g. List\[int\]"
232*da0073e9SAndroid Build Coastguard Worker        )
233*da0073e9SAndroid Build Coastguard Worker
234*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
235*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
236*da0073e9SAndroid Build Coastguard Worker            err_msg,
237*da0073e9SAndroid Build Coastguard Worker        ):
238*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(list_no_contained_type)
239*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
240*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
241*da0073e9SAndroid Build Coastguard Worker            err_msg,
242*da0073e9SAndroid Build Coastguard Worker        ):
243*da0073e9SAndroid Build Coastguard Worker            list_no_contained_type(x)
244*da0073e9SAndroid Build Coastguard Worker
245*da0073e9SAndroid Build Coastguard Worker    def test_tuple_no_contained_type(self):
246*da0073e9SAndroid Build Coastguard Worker        def tuple_no_contained_type(x: Any):
247*da0073e9SAndroid Build Coastguard Worker            assert torch.jit.isinstance(x, Tuple)
248*da0073e9SAndroid Build Coastguard Worker
249*da0073e9SAndroid Build Coastguard Worker        x = ("1", "2", "3")
250*da0073e9SAndroid Build Coastguard Worker
251*da0073e9SAndroid Build Coastguard Worker        err_msg = (
252*da0073e9SAndroid Build Coastguard Worker            "Attempted to use Tuple without a contained type. "
253*da0073e9SAndroid Build Coastguard Worker            r"Please add a contained type, e.g. Tuple\[int\]"
254*da0073e9SAndroid Build Coastguard Worker        )
255*da0073e9SAndroid Build Coastguard Worker
256*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
257*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
258*da0073e9SAndroid Build Coastguard Worker            err_msg,
259*da0073e9SAndroid Build Coastguard Worker        ):
260*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(tuple_no_contained_type)
261*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
262*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
263*da0073e9SAndroid Build Coastguard Worker            err_msg,
264*da0073e9SAndroid Build Coastguard Worker        ):
265*da0073e9SAndroid Build Coastguard Worker            tuple_no_contained_type(x)
266*da0073e9SAndroid Build Coastguard Worker
267*da0073e9SAndroid Build Coastguard Worker    def test_optional_no_contained_type(self):
268*da0073e9SAndroid Build Coastguard Worker        def optional_no_contained_type(x: Any):
269*da0073e9SAndroid Build Coastguard Worker            assert torch.jit.isinstance(x, Optional)
270*da0073e9SAndroid Build Coastguard Worker
271*da0073e9SAndroid Build Coastguard Worker        x = ("1", "2", "3")
272*da0073e9SAndroid Build Coastguard Worker
273*da0073e9SAndroid Build Coastguard Worker        err_msg = (
274*da0073e9SAndroid Build Coastguard Worker            "Attempted to use Optional without a contained type. "
275*da0073e9SAndroid Build Coastguard Worker            r"Please add a contained type, e.g. Optional\[int\]"
276*da0073e9SAndroid Build Coastguard Worker        )
277*da0073e9SAndroid Build Coastguard Worker
278*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
279*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
280*da0073e9SAndroid Build Coastguard Worker            err_msg,
281*da0073e9SAndroid Build Coastguard Worker        ):
282*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(optional_no_contained_type)
283*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
284*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
285*da0073e9SAndroid Build Coastguard Worker            err_msg,
286*da0073e9SAndroid Build Coastguard Worker        ):
287*da0073e9SAndroid Build Coastguard Worker            optional_no_contained_type(x)
288*da0073e9SAndroid Build Coastguard Worker
289*da0073e9SAndroid Build Coastguard Worker    def test_dict_no_contained_type(self):
290*da0073e9SAndroid Build Coastguard Worker        def dict_no_contained_type(x: Any):
291*da0073e9SAndroid Build Coastguard Worker            assert torch.jit.isinstance(x, Dict)
292*da0073e9SAndroid Build Coastguard Worker
293*da0073e9SAndroid Build Coastguard Worker        x = {"a": "aa"}
294*da0073e9SAndroid Build Coastguard Worker
295*da0073e9SAndroid Build Coastguard Worker        err_msg = (
296*da0073e9SAndroid Build Coastguard Worker            "Attempted to use Dict without contained types. "
297*da0073e9SAndroid Build Coastguard Worker            r"Please add contained type, e.g. Dict\[int, int\]"
298*da0073e9SAndroid Build Coastguard Worker        )
299*da0073e9SAndroid Build Coastguard Worker
300*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
301*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
302*da0073e9SAndroid Build Coastguard Worker            err_msg,
303*da0073e9SAndroid Build Coastguard Worker        ):
304*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(dict_no_contained_type)
305*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
306*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
307*da0073e9SAndroid Build Coastguard Worker            err_msg,
308*da0073e9SAndroid Build Coastguard Worker        ):
309*da0073e9SAndroid Build Coastguard Worker            dict_no_contained_type(x)
310*da0073e9SAndroid Build Coastguard Worker
311*da0073e9SAndroid Build Coastguard Worker    def test_tuple_rhs(self):
312*da0073e9SAndroid Build Coastguard Worker        def fn(x: Any):
313*da0073e9SAndroid Build Coastguard Worker            assert torch.jit.isinstance(x, (int, List[str]))
314*da0073e9SAndroid Build Coastguard Worker            assert not torch.jit.isinstance(x, (List[float], Tuple[int, str]))
315*da0073e9SAndroid Build Coastguard Worker            assert not torch.jit.isinstance(x, (List[float], str))
316*da0073e9SAndroid Build Coastguard Worker
317*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (2,))
318*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (["foo", "bar", "baz"],))
319*da0073e9SAndroid Build Coastguard Worker
320*da0073e9SAndroid Build Coastguard Worker    def test_nontuple_container_rhs_throws_in_eager(self):
321*da0073e9SAndroid Build Coastguard Worker        def fn1(x: Any):
322*da0073e9SAndroid Build Coastguard Worker            assert torch.jit.isinstance(x, [int, List[str]])
323*da0073e9SAndroid Build Coastguard Worker
324*da0073e9SAndroid Build Coastguard Worker        def fn2(x: Any):
325*da0073e9SAndroid Build Coastguard Worker            assert not torch.jit.isinstance(x, {List[str], Tuple[int, str]})
326*da0073e9SAndroid Build Coastguard Worker
327*da0073e9SAndroid Build Coastguard Worker        err_highlight = "must be a type or a tuple of types"
328*da0073e9SAndroid Build Coastguard Worker
329*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, err_highlight):
330*da0073e9SAndroid Build Coastguard Worker            fn1(2)
331*da0073e9SAndroid Build Coastguard Worker
332*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, err_highlight):
333*da0073e9SAndroid Build Coastguard Worker            fn2(2)
334*da0073e9SAndroid Build Coastguard Worker
335*da0073e9SAndroid Build Coastguard Worker    def test_empty_container_throws_warning_in_eager(self):
336*da0073e9SAndroid Build Coastguard Worker        def fn(x: Any):
337*da0073e9SAndroid Build Coastguard Worker            torch.jit.isinstance(x, List[int])
338*da0073e9SAndroid Build Coastguard Worker
339*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
340*da0073e9SAndroid Build Coastguard Worker            x: List[int] = []
341*da0073e9SAndroid Build Coastguard Worker            fn(x)
342*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 1)
343*da0073e9SAndroid Build Coastguard Worker
344*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
345*da0073e9SAndroid Build Coastguard Worker            x: int = 2
346*da0073e9SAndroid Build Coastguard Worker            fn(x)
347*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 0)
348*da0073e9SAndroid Build Coastguard Worker
349*da0073e9SAndroid Build Coastguard Worker    def test_empty_container_special_cases(self):
350*da0073e9SAndroid Build Coastguard Worker        # Should not throw "Boolean value of Tensor with no values is
351*da0073e9SAndroid Build Coastguard Worker        # ambiguous" error
352*da0073e9SAndroid Build Coastguard Worker        torch._jit_internal.check_empty_containers(torch.Tensor([]))
353*da0073e9SAndroid Build Coastguard Worker
354*da0073e9SAndroid Build Coastguard Worker        # Should not throw "Boolean value of Tensor with more than
355*da0073e9SAndroid Build Coastguard Worker        # one value is ambiguous" error
356*da0073e9SAndroid Build Coastguard Worker        torch._jit_internal.check_empty_containers(torch.rand(2, 3))
357