1# Owner(s): ["oncall: jit"] 2 3import os 4import sys 5import warnings 6from typing import Any, Dict, List, Optional, Tuple 7 8import torch 9 10 11# Make the helper files in test/ importable 12pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 13sys.path.append(pytorch_test_dir) 14from torch.testing._internal.jit_utils import JitTestCase 15 16 17if __name__ == "__main__": 18 raise RuntimeError( 19 "This test file is not meant to be run directly, use:\n\n" 20 "\tpython test/test_jit.py TESTNAME\n\n" 21 "instead." 22 ) 23 24 25# Tests for torch.jit.isinstance 26class TestIsinstance(JitTestCase): 27 def test_int(self): 28 def int_test(x: Any): 29 assert torch.jit.isinstance(x, int) 30 assert not torch.jit.isinstance(x, float) 31 32 x = 1 33 self.checkScript(int_test, (x,)) 34 35 def test_float(self): 36 def float_test(x: Any): 37 assert torch.jit.isinstance(x, float) 38 assert not torch.jit.isinstance(x, int) 39 40 x = 1.0 41 self.checkScript(float_test, (x,)) 42 43 def test_bool(self): 44 def bool_test(x: Any): 45 assert torch.jit.isinstance(x, bool) 46 assert not torch.jit.isinstance(x, float) 47 48 x = False 49 self.checkScript(bool_test, (x,)) 50 51 def test_list(self): 52 def list_str_test(x: Any): 53 assert torch.jit.isinstance(x, List[str]) 54 assert not torch.jit.isinstance(x, List[int]) 55 assert not torch.jit.isinstance(x, Tuple[int]) 56 57 x = ["1", "2", "3"] 58 self.checkScript(list_str_test, (x,)) 59 60 def test_list_tensor(self): 61 def list_tensor_test(x: Any): 62 assert torch.jit.isinstance(x, List[torch.Tensor]) 63 assert not torch.jit.isinstance(x, Tuple[int]) 64 65 x = [torch.tensor([1]), torch.tensor([2]), torch.tensor([3])] 66 self.checkScript(list_tensor_test, (x,)) 67 68 def test_dict(self): 69 def dict_str_int_test(x: Any): 70 assert torch.jit.isinstance(x, Dict[str, int]) 71 assert not torch.jit.isinstance(x, Dict[int, str]) 72 assert not torch.jit.isinstance(x, Dict[str, str]) 73 74 x = {"a": 1, "b": 2} 75 self.checkScript(dict_str_int_test, (x,)) 76 77 def test_dict_tensor(self): 78 def dict_int_tensor_test(x: Any): 79 assert torch.jit.isinstance(x, Dict[int, torch.Tensor]) 80 81 x = {2: torch.tensor([2])} 82 self.checkScript(dict_int_tensor_test, (x,)) 83 84 def test_tuple(self): 85 def tuple_test(x: Any): 86 assert torch.jit.isinstance(x, Tuple[str, int, str]) 87 assert not torch.jit.isinstance(x, Tuple[int, str, str]) 88 assert not torch.jit.isinstance(x, Tuple[str]) 89 90 x = ("a", 1, "b") 91 self.checkScript(tuple_test, (x,)) 92 93 def test_tuple_tensor(self): 94 def tuple_tensor_test(x: Any): 95 assert torch.jit.isinstance(x, Tuple[torch.Tensor, torch.Tensor]) 96 97 x = (torch.tensor([1]), torch.tensor([[2], [3]])) 98 self.checkScript(tuple_tensor_test, (x,)) 99 100 def test_optional(self): 101 def optional_test(x: Any): 102 assert torch.jit.isinstance(x, Optional[torch.Tensor]) 103 assert not torch.jit.isinstance(x, Optional[str]) 104 105 x = torch.ones(3, 3) 106 self.checkScript(optional_test, (x,)) 107 108 def test_optional_none(self): 109 def optional_test_none(x: Any): 110 assert torch.jit.isinstance(x, Optional[torch.Tensor]) 111 # assert torch.jit.isinstance(x, Optional[str]) 112 # TODO: above line in eager will evaluate to True while in 113 # the TS interpreter will evaluate to False as the 114 # first torch.jit.isinstance refines the 'None' type 115 116 x = None 117 self.checkScript(optional_test_none, (x,)) 118 119 def test_list_nested(self): 120 def list_nested(x: Any): 121 assert torch.jit.isinstance(x, List[Dict[str, int]]) 122 assert not torch.jit.isinstance(x, List[List[str]]) 123 124 x = [{"a": 1, "b": 2}, {"aa": 11, "bb": 22}] 125 self.checkScript(list_nested, (x,)) 126 127 def test_dict_nested(self): 128 def dict_nested(x: Any): 129 assert torch.jit.isinstance(x, Dict[str, Tuple[str, str, str]]) 130 assert not torch.jit.isinstance(x, Dict[str, Tuple[int, int, int]]) 131 132 x = {"a": ("aa", "aa", "aa"), "b": ("bb", "bb", "bb")} 133 self.checkScript(dict_nested, (x,)) 134 135 def test_tuple_nested(self): 136 def tuple_nested(x: Any): 137 assert torch.jit.isinstance( 138 x, Tuple[Dict[str, Tuple[str, str, str]], List[bool], Optional[str]] 139 ) 140 assert not torch.jit.isinstance(x, Dict[str, Tuple[int, int, int]]) 141 assert not torch.jit.isinstance(x, Tuple[str]) 142 assert not torch.jit.isinstance(x, Tuple[List[bool], List[str], List[int]]) 143 144 x = ( 145 {"a": ("aa", "aa", "aa"), "b": ("bb", "bb", "bb")}, 146 [True, False, True], 147 None, 148 ) 149 self.checkScript(tuple_nested, (x,)) 150 151 def test_optional_nested(self): 152 def optional_nested(x: Any): 153 assert torch.jit.isinstance(x, Optional[List[str]]) 154 155 x = ["a", "b", "c"] 156 self.checkScript(optional_nested, (x,)) 157 158 def test_list_tensor_type_true(self): 159 def list_tensor_type_true(x: Any): 160 assert torch.jit.isinstance(x, List[torch.Tensor]) 161 162 x = [torch.rand(3, 3), torch.rand(4, 3)] 163 self.checkScript(list_tensor_type_true, (x,)) 164 165 def test_tensor_type_false(self): 166 def list_tensor_type_false(x: Any): 167 assert not torch.jit.isinstance(x, List[torch.Tensor]) 168 169 x = [1, 2, 3] 170 self.checkScript(list_tensor_type_false, (x,)) 171 172 def test_in_if(self): 173 def list_in_if(x: Any): 174 if torch.jit.isinstance(x, List[int]): 175 assert True 176 if torch.jit.isinstance(x, List[str]): 177 assert not True 178 179 x = [1, 2, 3] 180 self.checkScript(list_in_if, (x,)) 181 182 def test_if_else(self): 183 def list_in_if_else(x: Any): 184 if torch.jit.isinstance(x, Tuple[str, str, str]): 185 assert True 186 else: 187 assert not True 188 189 x = ("a", "b", "c") 190 self.checkScript(list_in_if_else, (x,)) 191 192 def test_in_while_loop(self): 193 def list_in_while_loop(x: Any): 194 count = 0 195 while torch.jit.isinstance(x, List[Dict[str, int]]) and count <= 0: 196 count = count + 1 197 assert count == 1 198 199 x = [{"a": 1, "b": 2}, {"aa": 11, "bb": 22}] 200 self.checkScript(list_in_while_loop, (x,)) 201 202 def test_type_refinement(self): 203 def type_refinement(obj: Any): 204 hit = False 205 if torch.jit.isinstance(obj, List[torch.Tensor]): 206 hit = not hit 207 for el in obj: 208 # perform some tensor operation 209 y = el.clamp(0, 0.5) 210 if torch.jit.isinstance(obj, Dict[str, str]): 211 hit = not hit 212 str_cat = "" 213 for val in obj.values(): 214 str_cat = str_cat + val 215 assert "111222" == str_cat 216 assert hit 217 218 x = [torch.rand(3, 3), torch.rand(4, 3)] 219 self.checkScript(type_refinement, (x,)) 220 x = {"1": "111", "2": "222"} 221 self.checkScript(type_refinement, (x,)) 222 223 def test_list_no_contained_type(self): 224 def list_no_contained_type(x: Any): 225 assert torch.jit.isinstance(x, List) 226 227 x = ["1", "2", "3"] 228 229 err_msg = ( 230 "Attempted to use List without a contained type. " 231 r"Please add a contained type, e.g. List\[int\]" 232 ) 233 234 with self.assertRaisesRegex( 235 RuntimeError, 236 err_msg, 237 ): 238 torch.jit.script(list_no_contained_type) 239 with self.assertRaisesRegex( 240 RuntimeError, 241 err_msg, 242 ): 243 list_no_contained_type(x) 244 245 def test_tuple_no_contained_type(self): 246 def tuple_no_contained_type(x: Any): 247 assert torch.jit.isinstance(x, Tuple) 248 249 x = ("1", "2", "3") 250 251 err_msg = ( 252 "Attempted to use Tuple without a contained type. " 253 r"Please add a contained type, e.g. Tuple\[int\]" 254 ) 255 256 with self.assertRaisesRegex( 257 RuntimeError, 258 err_msg, 259 ): 260 torch.jit.script(tuple_no_contained_type) 261 with self.assertRaisesRegex( 262 RuntimeError, 263 err_msg, 264 ): 265 tuple_no_contained_type(x) 266 267 def test_optional_no_contained_type(self): 268 def optional_no_contained_type(x: Any): 269 assert torch.jit.isinstance(x, Optional) 270 271 x = ("1", "2", "3") 272 273 err_msg = ( 274 "Attempted to use Optional without a contained type. " 275 r"Please add a contained type, e.g. Optional\[int\]" 276 ) 277 278 with self.assertRaisesRegex( 279 RuntimeError, 280 err_msg, 281 ): 282 torch.jit.script(optional_no_contained_type) 283 with self.assertRaisesRegex( 284 RuntimeError, 285 err_msg, 286 ): 287 optional_no_contained_type(x) 288 289 def test_dict_no_contained_type(self): 290 def dict_no_contained_type(x: Any): 291 assert torch.jit.isinstance(x, Dict) 292 293 x = {"a": "aa"} 294 295 err_msg = ( 296 "Attempted to use Dict without contained types. " 297 r"Please add contained type, e.g. Dict\[int, int\]" 298 ) 299 300 with self.assertRaisesRegex( 301 RuntimeError, 302 err_msg, 303 ): 304 torch.jit.script(dict_no_contained_type) 305 with self.assertRaisesRegex( 306 RuntimeError, 307 err_msg, 308 ): 309 dict_no_contained_type(x) 310 311 def test_tuple_rhs(self): 312 def fn(x: Any): 313 assert torch.jit.isinstance(x, (int, List[str])) 314 assert not torch.jit.isinstance(x, (List[float], Tuple[int, str])) 315 assert not torch.jit.isinstance(x, (List[float], str)) 316 317 self.checkScript(fn, (2,)) 318 self.checkScript(fn, (["foo", "bar", "baz"],)) 319 320 def test_nontuple_container_rhs_throws_in_eager(self): 321 def fn1(x: Any): 322 assert torch.jit.isinstance(x, [int, List[str]]) 323 324 def fn2(x: Any): 325 assert not torch.jit.isinstance(x, {List[str], Tuple[int, str]}) 326 327 err_highlight = "must be a type or a tuple of types" 328 329 with self.assertRaisesRegex(RuntimeError, err_highlight): 330 fn1(2) 331 332 with self.assertRaisesRegex(RuntimeError, err_highlight): 333 fn2(2) 334 335 def test_empty_container_throws_warning_in_eager(self): 336 def fn(x: Any): 337 torch.jit.isinstance(x, List[int]) 338 339 with warnings.catch_warnings(record=True) as w: 340 x: List[int] = [] 341 fn(x) 342 self.assertEqual(len(w), 1) 343 344 with warnings.catch_warnings(record=True) as w: 345 x: int = 2 346 fn(x) 347 self.assertEqual(len(w), 0) 348 349 def test_empty_container_special_cases(self): 350 # Should not throw "Boolean value of Tensor with no values is 351 # ambiguous" error 352 torch._jit_internal.check_empty_containers(torch.Tensor([])) 353 354 # Should not throw "Boolean value of Tensor with more than 355 # one value is ambiguous" error 356 torch._jit_internal.check_empty_containers(torch.rand(2, 3)) 357