1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport os 4*da0073e9SAndroid Build Coastguard Workerimport sys 5*da0073e9SAndroid Build Coastguard Workerimport warnings 6*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, Dict, List, Optional, Tuple 7*da0073e9SAndroid Build Coastguard Worker 8*da0073e9SAndroid Build Coastguard Workerimport torch 9*da0073e9SAndroid Build Coastguard Worker 10*da0073e9SAndroid Build Coastguard Worker 11*da0073e9SAndroid Build Coastguard Worker# Make the helper files in test/ importable 12*da0073e9SAndroid Build Coastguard Workerpytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 13*da0073e9SAndroid Build Coastguard Workersys.path.append(pytorch_test_dir) 14*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase 15*da0073e9SAndroid Build Coastguard Worker 16*da0073e9SAndroid Build Coastguard Worker 17*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 18*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 19*da0073e9SAndroid Build Coastguard Worker "This test file is not meant to be run directly, use:\n\n" 20*da0073e9SAndroid Build Coastguard Worker "\tpython test/test_jit.py TESTNAME\n\n" 21*da0073e9SAndroid Build Coastguard Worker "instead." 22*da0073e9SAndroid Build Coastguard Worker ) 23*da0073e9SAndroid Build Coastguard Worker 24*da0073e9SAndroid Build Coastguard Worker 25*da0073e9SAndroid Build Coastguard Worker# Tests for torch.jit.isinstance 26*da0073e9SAndroid Build Coastguard Workerclass TestIsinstance(JitTestCase): 27*da0073e9SAndroid Build Coastguard Worker def test_int(self): 28*da0073e9SAndroid Build Coastguard Worker def int_test(x: Any): 29*da0073e9SAndroid Build Coastguard Worker assert torch.jit.isinstance(x, int) 30*da0073e9SAndroid Build Coastguard Worker assert not torch.jit.isinstance(x, float) 31*da0073e9SAndroid Build Coastguard Worker 32*da0073e9SAndroid Build Coastguard Worker x = 1 33*da0073e9SAndroid Build Coastguard Worker self.checkScript(int_test, (x,)) 34*da0073e9SAndroid Build Coastguard Worker 35*da0073e9SAndroid Build Coastguard Worker def test_float(self): 36*da0073e9SAndroid Build Coastguard Worker def float_test(x: Any): 37*da0073e9SAndroid Build Coastguard Worker assert torch.jit.isinstance(x, float) 38*da0073e9SAndroid Build Coastguard Worker assert not torch.jit.isinstance(x, int) 39*da0073e9SAndroid Build Coastguard Worker 40*da0073e9SAndroid Build Coastguard Worker x = 1.0 41*da0073e9SAndroid Build Coastguard Worker self.checkScript(float_test, (x,)) 42*da0073e9SAndroid Build Coastguard Worker 43*da0073e9SAndroid Build Coastguard Worker def test_bool(self): 44*da0073e9SAndroid Build Coastguard Worker def bool_test(x: Any): 45*da0073e9SAndroid Build Coastguard Worker assert torch.jit.isinstance(x, bool) 46*da0073e9SAndroid Build Coastguard Worker assert not torch.jit.isinstance(x, float) 47*da0073e9SAndroid Build Coastguard Worker 48*da0073e9SAndroid Build Coastguard Worker x = False 49*da0073e9SAndroid Build Coastguard Worker self.checkScript(bool_test, (x,)) 50*da0073e9SAndroid Build Coastguard Worker 51*da0073e9SAndroid Build Coastguard Worker def test_list(self): 52*da0073e9SAndroid Build Coastguard Worker def list_str_test(x: Any): 53*da0073e9SAndroid Build Coastguard Worker assert torch.jit.isinstance(x, List[str]) 54*da0073e9SAndroid Build Coastguard Worker assert not torch.jit.isinstance(x, List[int]) 55*da0073e9SAndroid Build Coastguard Worker assert not torch.jit.isinstance(x, Tuple[int]) 56*da0073e9SAndroid Build Coastguard Worker 57*da0073e9SAndroid Build Coastguard Worker x = ["1", "2", "3"] 58*da0073e9SAndroid Build Coastguard Worker self.checkScript(list_str_test, (x,)) 59*da0073e9SAndroid Build Coastguard Worker 60*da0073e9SAndroid Build Coastguard Worker def test_list_tensor(self): 61*da0073e9SAndroid Build Coastguard Worker def list_tensor_test(x: Any): 62*da0073e9SAndroid Build Coastguard Worker assert torch.jit.isinstance(x, List[torch.Tensor]) 63*da0073e9SAndroid Build Coastguard Worker assert not torch.jit.isinstance(x, Tuple[int]) 64*da0073e9SAndroid Build Coastguard Worker 65*da0073e9SAndroid Build Coastguard Worker x = [torch.tensor([1]), torch.tensor([2]), torch.tensor([3])] 66*da0073e9SAndroid Build Coastguard Worker self.checkScript(list_tensor_test, (x,)) 67*da0073e9SAndroid Build Coastguard Worker 68*da0073e9SAndroid Build Coastguard Worker def test_dict(self): 69*da0073e9SAndroid Build Coastguard Worker def dict_str_int_test(x: Any): 70*da0073e9SAndroid Build Coastguard Worker assert torch.jit.isinstance(x, Dict[str, int]) 71*da0073e9SAndroid Build Coastguard Worker assert not torch.jit.isinstance(x, Dict[int, str]) 72*da0073e9SAndroid Build Coastguard Worker assert not torch.jit.isinstance(x, Dict[str, str]) 73*da0073e9SAndroid Build Coastguard Worker 74*da0073e9SAndroid Build Coastguard Worker x = {"a": 1, "b": 2} 75*da0073e9SAndroid Build Coastguard Worker self.checkScript(dict_str_int_test, (x,)) 76*da0073e9SAndroid Build Coastguard Worker 77*da0073e9SAndroid Build Coastguard Worker def test_dict_tensor(self): 78*da0073e9SAndroid Build Coastguard Worker def dict_int_tensor_test(x: Any): 79*da0073e9SAndroid Build Coastguard Worker assert torch.jit.isinstance(x, Dict[int, torch.Tensor]) 80*da0073e9SAndroid Build Coastguard Worker 81*da0073e9SAndroid Build Coastguard Worker x = {2: torch.tensor([2])} 82*da0073e9SAndroid Build Coastguard Worker self.checkScript(dict_int_tensor_test, (x,)) 83*da0073e9SAndroid Build Coastguard Worker 84*da0073e9SAndroid Build Coastguard Worker def test_tuple(self): 85*da0073e9SAndroid Build Coastguard Worker def tuple_test(x: Any): 86*da0073e9SAndroid Build Coastguard Worker assert torch.jit.isinstance(x, Tuple[str, int, str]) 87*da0073e9SAndroid Build Coastguard Worker assert not torch.jit.isinstance(x, Tuple[int, str, str]) 88*da0073e9SAndroid Build Coastguard Worker assert not torch.jit.isinstance(x, Tuple[str]) 89*da0073e9SAndroid Build Coastguard Worker 90*da0073e9SAndroid Build Coastguard Worker x = ("a", 1, "b") 91*da0073e9SAndroid Build Coastguard Worker self.checkScript(tuple_test, (x,)) 92*da0073e9SAndroid Build Coastguard Worker 93*da0073e9SAndroid Build Coastguard Worker def test_tuple_tensor(self): 94*da0073e9SAndroid Build Coastguard Worker def tuple_tensor_test(x: Any): 95*da0073e9SAndroid Build Coastguard Worker assert torch.jit.isinstance(x, Tuple[torch.Tensor, torch.Tensor]) 96*da0073e9SAndroid Build Coastguard Worker 97*da0073e9SAndroid Build Coastguard Worker x = (torch.tensor([1]), torch.tensor([[2], [3]])) 98*da0073e9SAndroid Build Coastguard Worker self.checkScript(tuple_tensor_test, (x,)) 99*da0073e9SAndroid Build Coastguard Worker 100*da0073e9SAndroid Build Coastguard Worker def test_optional(self): 101*da0073e9SAndroid Build Coastguard Worker def optional_test(x: Any): 102*da0073e9SAndroid Build Coastguard Worker assert torch.jit.isinstance(x, Optional[torch.Tensor]) 103*da0073e9SAndroid Build Coastguard Worker assert not torch.jit.isinstance(x, Optional[str]) 104*da0073e9SAndroid Build Coastguard Worker 105*da0073e9SAndroid Build Coastguard Worker x = torch.ones(3, 3) 106*da0073e9SAndroid Build Coastguard Worker self.checkScript(optional_test, (x,)) 107*da0073e9SAndroid Build Coastguard Worker 108*da0073e9SAndroid Build Coastguard Worker def test_optional_none(self): 109*da0073e9SAndroid Build Coastguard Worker def optional_test_none(x: Any): 110*da0073e9SAndroid Build Coastguard Worker assert torch.jit.isinstance(x, Optional[torch.Tensor]) 111*da0073e9SAndroid Build Coastguard Worker # assert torch.jit.isinstance(x, Optional[str]) 112*da0073e9SAndroid Build Coastguard Worker # TODO: above line in eager will evaluate to True while in 113*da0073e9SAndroid Build Coastguard Worker # the TS interpreter will evaluate to False as the 114*da0073e9SAndroid Build Coastguard Worker # first torch.jit.isinstance refines the 'None' type 115*da0073e9SAndroid Build Coastguard Worker 116*da0073e9SAndroid Build Coastguard Worker x = None 117*da0073e9SAndroid Build Coastguard Worker self.checkScript(optional_test_none, (x,)) 118*da0073e9SAndroid Build Coastguard Worker 119*da0073e9SAndroid Build Coastguard Worker def test_list_nested(self): 120*da0073e9SAndroid Build Coastguard Worker def list_nested(x: Any): 121*da0073e9SAndroid Build Coastguard Worker assert torch.jit.isinstance(x, List[Dict[str, int]]) 122*da0073e9SAndroid Build Coastguard Worker assert not torch.jit.isinstance(x, List[List[str]]) 123*da0073e9SAndroid Build Coastguard Worker 124*da0073e9SAndroid Build Coastguard Worker x = [{"a": 1, "b": 2}, {"aa": 11, "bb": 22}] 125*da0073e9SAndroid Build Coastguard Worker self.checkScript(list_nested, (x,)) 126*da0073e9SAndroid Build Coastguard Worker 127*da0073e9SAndroid Build Coastguard Worker def test_dict_nested(self): 128*da0073e9SAndroid Build Coastguard Worker def dict_nested(x: Any): 129*da0073e9SAndroid Build Coastguard Worker assert torch.jit.isinstance(x, Dict[str, Tuple[str, str, str]]) 130*da0073e9SAndroid Build Coastguard Worker assert not torch.jit.isinstance(x, Dict[str, Tuple[int, int, int]]) 131*da0073e9SAndroid Build Coastguard Worker 132*da0073e9SAndroid Build Coastguard Worker x = {"a": ("aa", "aa", "aa"), "b": ("bb", "bb", "bb")} 133*da0073e9SAndroid Build Coastguard Worker self.checkScript(dict_nested, (x,)) 134*da0073e9SAndroid Build Coastguard Worker 135*da0073e9SAndroid Build Coastguard Worker def test_tuple_nested(self): 136*da0073e9SAndroid Build Coastguard Worker def tuple_nested(x: Any): 137*da0073e9SAndroid Build Coastguard Worker assert torch.jit.isinstance( 138*da0073e9SAndroid Build Coastguard Worker x, Tuple[Dict[str, Tuple[str, str, str]], List[bool], Optional[str]] 139*da0073e9SAndroid Build Coastguard Worker ) 140*da0073e9SAndroid Build Coastguard Worker assert not torch.jit.isinstance(x, Dict[str, Tuple[int, int, int]]) 141*da0073e9SAndroid Build Coastguard Worker assert not torch.jit.isinstance(x, Tuple[str]) 142*da0073e9SAndroid Build Coastguard Worker assert not torch.jit.isinstance(x, Tuple[List[bool], List[str], List[int]]) 143*da0073e9SAndroid Build Coastguard Worker 144*da0073e9SAndroid Build Coastguard Worker x = ( 145*da0073e9SAndroid Build Coastguard Worker {"a": ("aa", "aa", "aa"), "b": ("bb", "bb", "bb")}, 146*da0073e9SAndroid Build Coastguard Worker [True, False, True], 147*da0073e9SAndroid Build Coastguard Worker None, 148*da0073e9SAndroid Build Coastguard Worker ) 149*da0073e9SAndroid Build Coastguard Worker self.checkScript(tuple_nested, (x,)) 150*da0073e9SAndroid Build Coastguard Worker 151*da0073e9SAndroid Build Coastguard Worker def test_optional_nested(self): 152*da0073e9SAndroid Build Coastguard Worker def optional_nested(x: Any): 153*da0073e9SAndroid Build Coastguard Worker assert torch.jit.isinstance(x, Optional[List[str]]) 154*da0073e9SAndroid Build Coastguard Worker 155*da0073e9SAndroid Build Coastguard Worker x = ["a", "b", "c"] 156*da0073e9SAndroid Build Coastguard Worker self.checkScript(optional_nested, (x,)) 157*da0073e9SAndroid Build Coastguard Worker 158*da0073e9SAndroid Build Coastguard Worker def test_list_tensor_type_true(self): 159*da0073e9SAndroid Build Coastguard Worker def list_tensor_type_true(x: Any): 160*da0073e9SAndroid Build Coastguard Worker assert torch.jit.isinstance(x, List[torch.Tensor]) 161*da0073e9SAndroid Build Coastguard Worker 162*da0073e9SAndroid Build Coastguard Worker x = [torch.rand(3, 3), torch.rand(4, 3)] 163*da0073e9SAndroid Build Coastguard Worker self.checkScript(list_tensor_type_true, (x,)) 164*da0073e9SAndroid Build Coastguard Worker 165*da0073e9SAndroid Build Coastguard Worker def test_tensor_type_false(self): 166*da0073e9SAndroid Build Coastguard Worker def list_tensor_type_false(x: Any): 167*da0073e9SAndroid Build Coastguard Worker assert not torch.jit.isinstance(x, List[torch.Tensor]) 168*da0073e9SAndroid Build Coastguard Worker 169*da0073e9SAndroid Build Coastguard Worker x = [1, 2, 3] 170*da0073e9SAndroid Build Coastguard Worker self.checkScript(list_tensor_type_false, (x,)) 171*da0073e9SAndroid Build Coastguard Worker 172*da0073e9SAndroid Build Coastguard Worker def test_in_if(self): 173*da0073e9SAndroid Build Coastguard Worker def list_in_if(x: Any): 174*da0073e9SAndroid Build Coastguard Worker if torch.jit.isinstance(x, List[int]): 175*da0073e9SAndroid Build Coastguard Worker assert True 176*da0073e9SAndroid Build Coastguard Worker if torch.jit.isinstance(x, List[str]): 177*da0073e9SAndroid Build Coastguard Worker assert not True 178*da0073e9SAndroid Build Coastguard Worker 179*da0073e9SAndroid Build Coastguard Worker x = [1, 2, 3] 180*da0073e9SAndroid Build Coastguard Worker self.checkScript(list_in_if, (x,)) 181*da0073e9SAndroid Build Coastguard Worker 182*da0073e9SAndroid Build Coastguard Worker def test_if_else(self): 183*da0073e9SAndroid Build Coastguard Worker def list_in_if_else(x: Any): 184*da0073e9SAndroid Build Coastguard Worker if torch.jit.isinstance(x, Tuple[str, str, str]): 185*da0073e9SAndroid Build Coastguard Worker assert True 186*da0073e9SAndroid Build Coastguard Worker else: 187*da0073e9SAndroid Build Coastguard Worker assert not True 188*da0073e9SAndroid Build Coastguard Worker 189*da0073e9SAndroid Build Coastguard Worker x = ("a", "b", "c") 190*da0073e9SAndroid Build Coastguard Worker self.checkScript(list_in_if_else, (x,)) 191*da0073e9SAndroid Build Coastguard Worker 192*da0073e9SAndroid Build Coastguard Worker def test_in_while_loop(self): 193*da0073e9SAndroid Build Coastguard Worker def list_in_while_loop(x: Any): 194*da0073e9SAndroid Build Coastguard Worker count = 0 195*da0073e9SAndroid Build Coastguard Worker while torch.jit.isinstance(x, List[Dict[str, int]]) and count <= 0: 196*da0073e9SAndroid Build Coastguard Worker count = count + 1 197*da0073e9SAndroid Build Coastguard Worker assert count == 1 198*da0073e9SAndroid Build Coastguard Worker 199*da0073e9SAndroid Build Coastguard Worker x = [{"a": 1, "b": 2}, {"aa": 11, "bb": 22}] 200*da0073e9SAndroid Build Coastguard Worker self.checkScript(list_in_while_loop, (x,)) 201*da0073e9SAndroid Build Coastguard Worker 202*da0073e9SAndroid Build Coastguard Worker def test_type_refinement(self): 203*da0073e9SAndroid Build Coastguard Worker def type_refinement(obj: Any): 204*da0073e9SAndroid Build Coastguard Worker hit = False 205*da0073e9SAndroid Build Coastguard Worker if torch.jit.isinstance(obj, List[torch.Tensor]): 206*da0073e9SAndroid Build Coastguard Worker hit = not hit 207*da0073e9SAndroid Build Coastguard Worker for el in obj: 208*da0073e9SAndroid Build Coastguard Worker # perform some tensor operation 209*da0073e9SAndroid Build Coastguard Worker y = el.clamp(0, 0.5) 210*da0073e9SAndroid Build Coastguard Worker if torch.jit.isinstance(obj, Dict[str, str]): 211*da0073e9SAndroid Build Coastguard Worker hit = not hit 212*da0073e9SAndroid Build Coastguard Worker str_cat = "" 213*da0073e9SAndroid Build Coastguard Worker for val in obj.values(): 214*da0073e9SAndroid Build Coastguard Worker str_cat = str_cat + val 215*da0073e9SAndroid Build Coastguard Worker assert "111222" == str_cat 216*da0073e9SAndroid Build Coastguard Worker assert hit 217*da0073e9SAndroid Build Coastguard Worker 218*da0073e9SAndroid Build Coastguard Worker x = [torch.rand(3, 3), torch.rand(4, 3)] 219*da0073e9SAndroid Build Coastguard Worker self.checkScript(type_refinement, (x,)) 220*da0073e9SAndroid Build Coastguard Worker x = {"1": "111", "2": "222"} 221*da0073e9SAndroid Build Coastguard Worker self.checkScript(type_refinement, (x,)) 222*da0073e9SAndroid Build Coastguard Worker 223*da0073e9SAndroid Build Coastguard Worker def test_list_no_contained_type(self): 224*da0073e9SAndroid Build Coastguard Worker def list_no_contained_type(x: Any): 225*da0073e9SAndroid Build Coastguard Worker assert torch.jit.isinstance(x, List) 226*da0073e9SAndroid Build Coastguard Worker 227*da0073e9SAndroid Build Coastguard Worker x = ["1", "2", "3"] 228*da0073e9SAndroid Build Coastguard Worker 229*da0073e9SAndroid Build Coastguard Worker err_msg = ( 230*da0073e9SAndroid Build Coastguard Worker "Attempted to use List without a contained type. " 231*da0073e9SAndroid Build Coastguard Worker r"Please add a contained type, e.g. List\[int\]" 232*da0073e9SAndroid Build Coastguard Worker ) 233*da0073e9SAndroid Build Coastguard Worker 234*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 235*da0073e9SAndroid Build Coastguard Worker RuntimeError, 236*da0073e9SAndroid Build Coastguard Worker err_msg, 237*da0073e9SAndroid Build Coastguard Worker ): 238*da0073e9SAndroid Build Coastguard Worker torch.jit.script(list_no_contained_type) 239*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 240*da0073e9SAndroid Build Coastguard Worker RuntimeError, 241*da0073e9SAndroid Build Coastguard Worker err_msg, 242*da0073e9SAndroid Build Coastguard Worker ): 243*da0073e9SAndroid Build Coastguard Worker list_no_contained_type(x) 244*da0073e9SAndroid Build Coastguard Worker 245*da0073e9SAndroid Build Coastguard Worker def test_tuple_no_contained_type(self): 246*da0073e9SAndroid Build Coastguard Worker def tuple_no_contained_type(x: Any): 247*da0073e9SAndroid Build Coastguard Worker assert torch.jit.isinstance(x, Tuple) 248*da0073e9SAndroid Build Coastguard Worker 249*da0073e9SAndroid Build Coastguard Worker x = ("1", "2", "3") 250*da0073e9SAndroid Build Coastguard Worker 251*da0073e9SAndroid Build Coastguard Worker err_msg = ( 252*da0073e9SAndroid Build Coastguard Worker "Attempted to use Tuple without a contained type. " 253*da0073e9SAndroid Build Coastguard Worker r"Please add a contained type, e.g. Tuple\[int\]" 254*da0073e9SAndroid Build Coastguard Worker ) 255*da0073e9SAndroid Build Coastguard Worker 256*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 257*da0073e9SAndroid Build Coastguard Worker RuntimeError, 258*da0073e9SAndroid Build Coastguard Worker err_msg, 259*da0073e9SAndroid Build Coastguard Worker ): 260*da0073e9SAndroid Build Coastguard Worker torch.jit.script(tuple_no_contained_type) 261*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 262*da0073e9SAndroid Build Coastguard Worker RuntimeError, 263*da0073e9SAndroid Build Coastguard Worker err_msg, 264*da0073e9SAndroid Build Coastguard Worker ): 265*da0073e9SAndroid Build Coastguard Worker tuple_no_contained_type(x) 266*da0073e9SAndroid Build Coastguard Worker 267*da0073e9SAndroid Build Coastguard Worker def test_optional_no_contained_type(self): 268*da0073e9SAndroid Build Coastguard Worker def optional_no_contained_type(x: Any): 269*da0073e9SAndroid Build Coastguard Worker assert torch.jit.isinstance(x, Optional) 270*da0073e9SAndroid Build Coastguard Worker 271*da0073e9SAndroid Build Coastguard Worker x = ("1", "2", "3") 272*da0073e9SAndroid Build Coastguard Worker 273*da0073e9SAndroid Build Coastguard Worker err_msg = ( 274*da0073e9SAndroid Build Coastguard Worker "Attempted to use Optional without a contained type. " 275*da0073e9SAndroid Build Coastguard Worker r"Please add a contained type, e.g. Optional\[int\]" 276*da0073e9SAndroid Build Coastguard Worker ) 277*da0073e9SAndroid Build Coastguard Worker 278*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 279*da0073e9SAndroid Build Coastguard Worker RuntimeError, 280*da0073e9SAndroid Build Coastguard Worker err_msg, 281*da0073e9SAndroid Build Coastguard Worker ): 282*da0073e9SAndroid Build Coastguard Worker torch.jit.script(optional_no_contained_type) 283*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 284*da0073e9SAndroid Build Coastguard Worker RuntimeError, 285*da0073e9SAndroid Build Coastguard Worker err_msg, 286*da0073e9SAndroid Build Coastguard Worker ): 287*da0073e9SAndroid Build Coastguard Worker optional_no_contained_type(x) 288*da0073e9SAndroid Build Coastguard Worker 289*da0073e9SAndroid Build Coastguard Worker def test_dict_no_contained_type(self): 290*da0073e9SAndroid Build Coastguard Worker def dict_no_contained_type(x: Any): 291*da0073e9SAndroid Build Coastguard Worker assert torch.jit.isinstance(x, Dict) 292*da0073e9SAndroid Build Coastguard Worker 293*da0073e9SAndroid Build Coastguard Worker x = {"a": "aa"} 294*da0073e9SAndroid Build Coastguard Worker 295*da0073e9SAndroid Build Coastguard Worker err_msg = ( 296*da0073e9SAndroid Build Coastguard Worker "Attempted to use Dict without contained types. " 297*da0073e9SAndroid Build Coastguard Worker r"Please add contained type, e.g. Dict\[int, int\]" 298*da0073e9SAndroid Build Coastguard Worker ) 299*da0073e9SAndroid Build Coastguard Worker 300*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 301*da0073e9SAndroid Build Coastguard Worker RuntimeError, 302*da0073e9SAndroid Build Coastguard Worker err_msg, 303*da0073e9SAndroid Build Coastguard Worker ): 304*da0073e9SAndroid Build Coastguard Worker torch.jit.script(dict_no_contained_type) 305*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 306*da0073e9SAndroid Build Coastguard Worker RuntimeError, 307*da0073e9SAndroid Build Coastguard Worker err_msg, 308*da0073e9SAndroid Build Coastguard Worker ): 309*da0073e9SAndroid Build Coastguard Worker dict_no_contained_type(x) 310*da0073e9SAndroid Build Coastguard Worker 311*da0073e9SAndroid Build Coastguard Worker def test_tuple_rhs(self): 312*da0073e9SAndroid Build Coastguard Worker def fn(x: Any): 313*da0073e9SAndroid Build Coastguard Worker assert torch.jit.isinstance(x, (int, List[str])) 314*da0073e9SAndroid Build Coastguard Worker assert not torch.jit.isinstance(x, (List[float], Tuple[int, str])) 315*da0073e9SAndroid Build Coastguard Worker assert not torch.jit.isinstance(x, (List[float], str)) 316*da0073e9SAndroid Build Coastguard Worker 317*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (2,)) 318*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (["foo", "bar", "baz"],)) 319*da0073e9SAndroid Build Coastguard Worker 320*da0073e9SAndroid Build Coastguard Worker def test_nontuple_container_rhs_throws_in_eager(self): 321*da0073e9SAndroid Build Coastguard Worker def fn1(x: Any): 322*da0073e9SAndroid Build Coastguard Worker assert torch.jit.isinstance(x, [int, List[str]]) 323*da0073e9SAndroid Build Coastguard Worker 324*da0073e9SAndroid Build Coastguard Worker def fn2(x: Any): 325*da0073e9SAndroid Build Coastguard Worker assert not torch.jit.isinstance(x, {List[str], Tuple[int, str]}) 326*da0073e9SAndroid Build Coastguard Worker 327*da0073e9SAndroid Build Coastguard Worker err_highlight = "must be a type or a tuple of types" 328*da0073e9SAndroid Build Coastguard Worker 329*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, err_highlight): 330*da0073e9SAndroid Build Coastguard Worker fn1(2) 331*da0073e9SAndroid Build Coastguard Worker 332*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, err_highlight): 333*da0073e9SAndroid Build Coastguard Worker fn2(2) 334*da0073e9SAndroid Build Coastguard Worker 335*da0073e9SAndroid Build Coastguard Worker def test_empty_container_throws_warning_in_eager(self): 336*da0073e9SAndroid Build Coastguard Worker def fn(x: Any): 337*da0073e9SAndroid Build Coastguard Worker torch.jit.isinstance(x, List[int]) 338*da0073e9SAndroid Build Coastguard Worker 339*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as w: 340*da0073e9SAndroid Build Coastguard Worker x: List[int] = [] 341*da0073e9SAndroid Build Coastguard Worker fn(x) 342*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(w), 1) 343*da0073e9SAndroid Build Coastguard Worker 344*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as w: 345*da0073e9SAndroid Build Coastguard Worker x: int = 2 346*da0073e9SAndroid Build Coastguard Worker fn(x) 347*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(w), 0) 348*da0073e9SAndroid Build Coastguard Worker 349*da0073e9SAndroid Build Coastguard Worker def test_empty_container_special_cases(self): 350*da0073e9SAndroid Build Coastguard Worker # Should not throw "Boolean value of Tensor with no values is 351*da0073e9SAndroid Build Coastguard Worker # ambiguous" error 352*da0073e9SAndroid Build Coastguard Worker torch._jit_internal.check_empty_containers(torch.Tensor([])) 353*da0073e9SAndroid Build Coastguard Worker 354*da0073e9SAndroid Build Coastguard Worker # Should not throw "Boolean value of Tensor with more than 355*da0073e9SAndroid Build Coastguard Worker # one value is ambiguous" error 356*da0073e9SAndroid Build Coastguard Worker torch._jit_internal.check_empty_containers(torch.rand(2, 3)) 357