# Owner(s): ["oncall: jit"]

import inspect
import os
import sys
from collections import namedtuple
from textwrap import dedent
from typing import Dict, Iterator, List, Optional, Tuple

import torch
import torch.testing._internal.jit_utils
from jit.test_module_interface import TestModuleInterface  # noqa: F401
from torch.testing import FileCheck
from torch.testing._internal.jit_utils import JitTestCase


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestTypesAndAnnotation(JitTestCase):
    """TorchScript tests for type annotations, type comments, and ignored members."""

    def test_pep585_type(self):
        """Builtin generics (``list[...]``, ``dict[...]``) work as local annotations."""
        # TODO add test to use PEP585 type annotation for return type after py3.9
        # see: https://www.python.org/dev/peps/pep-0585/#id5
        def fn(x: torch.Tensor) -> Tuple[Tuple[torch.Tensor], Dict[str, int]]:
            xl: list[tuple[torch.Tensor]] = []
            xd: dict[str, int] = {}
            xl.append((x,))
            xd["foo"] = 1
            return xl.pop(), xd

        self.checkScript(fn, [torch.randn(2, 2)])

        # Additionally compare eager and scripted results on a fresh input.
        x = torch.randn(2, 2)
        expected = fn(x)
        scripted = torch.jit.script(fn)(x)

        self.assertEqual(expected, scripted)

    def test_types_as_values(self):
        """``torch.device`` and local namedtuples can appear in signatures."""

        def fn(m: torch.Tensor) -> torch.device:
            return m.device

        self.checkScript(fn, [torch.randn(2, 2)])

        GG = namedtuple("GG", ["f", "g"])

        # Ignored method annotated with a tuple of locally defined namedtuples.
        class Foo(torch.nn.Module):
            @torch.jit.ignore
            def foo(self, x: torch.Tensor, z: torch.Tensor) -> Tuple[GG, GG]:
                return GG(x, z), GG(x, z)

            def forward(self, x, z):
                return self.foo(x, z)

        foo = torch.jit.script(Foo())
        y = foo(torch.randn(2, 2), torch.randn(2, 2))

        # Same return annotation, but with unannotated arguments.
        # NOTE(review): the body returns a single GG rather than
        # Tuple[GG, GG]; the call result is never checked, so this is
        # presumably intentional — confirm.
        class Foo(torch.nn.Module):
            @torch.jit.ignore
            def foo(self, x, z) -> Tuple[GG, GG]:
                return GG(x, z)

            def forward(self, x, z):
                return self.foo(x, z)

        foo = torch.jit.script(Foo())
        y = foo(torch.randn(2, 2), torch.randn(2, 2))

    def test_ignore_with_types(self):
        """Ignored callables taking ``Dict[str, Optional[Tensor]]`` can be scripted against."""

        # The ignored bodies are never executed here: the module is only
        # scripted and its graph inspected, so `x + 10` on a dict is not run.
        @torch.jit.ignore
        def fn(x: Dict[str, Optional[torch.Tensor]]):
            return x + 10

        class M(torch.nn.Module):
            def forward(
                self, in_batch: Dict[str, Optional[torch.Tensor]]
            ) -> torch.Tensor:
                self.dropout_modality(in_batch)
                fn(in_batch)
                return torch.tensor(1)

            @torch.jit.ignore
            def dropout_modality(
                self, in_batch: Dict[str, Optional[torch.Tensor]]
            ) -> Dict[str, Optional[torch.Tensor]]:
                return in_batch

        sm = torch.jit.script(M())
        FileCheck().check("dropout_modality").check("in_batch").run(str(sm.graph))

    def test_python_callable(self):
        """A Python object whose ``__call__`` is ignored can be called from script."""

        class MyPythonClass:
            @torch.jit.ignore
            def __call__(self, *args) -> str:
                return str(type(args[0]))

        the_class = MyPythonClass()

        @torch.jit.script
        def fn(x):
            return the_class(x)

        # This doesn't involve the string frontend, so don't use checkScript
        x = torch.ones(2)
        self.assertEqual(fn(x), the_class(x))

    def test_bad_types(self):
        """A call to an ignored function with a mismatched argument fails at script time."""

        @torch.jit.ignore
        def fn(my_arg):
            return my_arg + 10

        # The error is raised while compiling `other_fn` and names the parameter.
        with self.assertRaisesRegex(RuntimeError, "argument 'my_arg'"):

            @torch.jit.script
            def other_fn(x):
                return fn("2")

    def test_type_annotate_py3(self):
        """Local variable annotations are honored; a mismatched list annotation errors."""

        def fn():
            a: List[int] = []
            b: torch.Tensor = torch.ones(2, 2)
            c: Optional[torch.Tensor] = None
            d: Optional[torch.Tensor] = torch.ones(3, 4)
            for _ in range(10):
                a.append(4)
                c = torch.ones(2, 2)
                d = None
            return a, b, c, d

        self.checkScript(fn, ())

        def wrong_type():
            wrong: List[int] = [0.5]
            return wrong

        with self.assertRaisesRegex(
            RuntimeError,
            "List type annotation"
            r" `List\[int\]` did not match the "
            "types of the given list elements",
        ):
            torch.jit.script(wrong_type)

    def test_optional_no_element_type_annotation(self):
        """
        Test that using an optional with no contained types produces an error.
        """

        # Exercise the MyPy-style type-comment path; `annotated_fn` below
        # covers the Python-annotation path. Keeping both forms distinct is
        # the point of having two functions.
        def fn_with_comment(x):
            # type: (torch.Tensor) -> Optional
            return (x, x)

        def annotated_fn(x: torch.Tensor) -> Optional:
            return (x, x)

        # String frontend (CompilationUnit.define) for both forms.
        with self.assertRaisesRegex(
            RuntimeError, r"Attempted to use Optional without a contained type"
        ):
            cu = torch.jit.CompilationUnit()
            cu.define(dedent(inspect.getsource(fn_with_comment)))

        with self.assertRaisesRegex(
            RuntimeError, r"Attempted to use Optional without a contained type"
        ):
            cu = torch.jit.CompilationUnit()
            cu.define(dedent(inspect.getsource(annotated_fn)))

        # Python frontend (torch.jit.script) for both forms.
        with self.assertRaisesRegex(
            RuntimeError, r"Attempted to use Optional without a contained type"
        ):
            torch.jit.script(fn_with_comment)

        with self.assertRaisesRegex(
            RuntimeError, r"Attempted to use Optional without a contained type"
        ):
            torch.jit.script(annotated_fn)

    def test_tuple_no_element_type_annotation(self):
        """
        Test that using a tuple with no contained types produces an error.
        """

        # Exercise the MyPy-style type-comment path; `annotated_fn` below
        # covers the Python-annotation path.
        def fn_with_comment(x):
            # type: (torch.Tensor) -> Tuple
            return (x, x)

        def annotated_fn(x: torch.Tensor) -> Tuple:
            return (x, x)

        # String frontend (CompilationUnit.define) for both forms.
        with self.assertRaisesRegex(
            RuntimeError, r"Attempted to use Tuple without a contained type"
        ):
            cu = torch.jit.CompilationUnit()
            cu.define(dedent(inspect.getsource(fn_with_comment)))

        with self.assertRaisesRegex(
            RuntimeError, r"Attempted to use Tuple without a contained type"
        ):
            cu = torch.jit.CompilationUnit()
            cu.define(dedent(inspect.getsource(annotated_fn)))

        # Python frontend (torch.jit.script) for both forms.
        with self.assertRaisesRegex(
            RuntimeError, r"Attempted to use Tuple without a contained type"
        ):
            torch.jit.script(fn_with_comment)

        with self.assertRaisesRegex(
            RuntimeError, r"Attempted to use Tuple without a contained type"
        ):
            torch.jit.script(annotated_fn)

    def test_ignoring_module_attributes(self):
        """
        Test that module attributes can be ignored.
        """

        class Sub(torch.nn.Module):
            def forward(self, a: int) -> int:
                return sum([a])

        class ModuleWithIgnoredAttr(torch.nn.Module):
            __jit_ignored_attributes__ = ["a", "sub"]

            def __init__(self, a: int, b: int):
                super().__init__()
                self.a = a
                self.b = b
                self.sub = Sub()

            def forward(self) -> int:
                return self.b

            # Ignored methods run in Python, so they may still use the
            # attributes hidden from the compiler.
            @torch.jit.ignore
            def ignored_fn(self) -> int:
                return self.sub.forward(self.a)

        mod = ModuleWithIgnoredAttr(1, 4)
        scripted_mod = torch.jit.script(mod)
        self.assertEqual(scripted_mod(), 4)
        self.assertEqual(scripted_mod.ignored_fn(), 1)

        # Test the error message for ignored attributes.
        class ModuleUsesIgnoredAttr(torch.nn.Module):
            __jit_ignored_attributes__ = ["a", "sub"]

            def __init__(self, a: int):
                super().__init__()
                self.a = a
                self.sub = Sub()

            def forward(self) -> int:
                return self.sub(self.b)

        mod = ModuleUsesIgnoredAttr(1)

        # The highlight expects the error to point at the ignored `self.sub`.
        with self.assertRaisesRegexWithHighlight(
            RuntimeError, r"attribute was ignored during compilation", "self.sub"
        ):
            scripted_mod = torch.jit.script(mod)

    def test_ignoring_fn_with_nonscriptable_types(self):
        """Methods marked with ``torch.jit._drop`` may use non-scriptable types."""

        class CFX:
            def __init__(self, a: List[torch.Tensor]) -> None:
                self.a = a

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.sin(x)

            @torch.jit._drop
            def __iter__(self) -> Iterator[torch.Tensor]:
                return iter(self.a)

            @torch.jit._drop
            def __fx_create_arg__(
                self, tracer: torch.fx.Tracer
            ) -> torch.fx.node.Argument:
                # torch.fx classes are not scriptable
                # Use `self.a`, the only attribute CFX defines; the previous
                # `self.features` would raise AttributeError if ever called.
                return tracer.create_node(
                    "call_function",
                    CFX,
                    args=(tracer.create_arg(self.a),),
                    kwargs={},
                )

        torch.jit.script(CFX)

    def test_unimported_type_resolution(self):
        # verify fallback from the python resolver to the c++ resolver

        @torch.jit.script
        def fn(x):
            # type: (number) -> number
            return x + 1

        FileCheck().check("Scalar").run(fn.graph)

    def test_parser_bug(self):
        """Defining this function must not raise; nothing is scripted.

        NOTE(review): presumably a regression check for a historical parsing
        issue with this signature — confirm before extending.
        """

        def parser_bug(o: Optional[torch.Tensor]):
            pass

    def test_mismatched_annotation(self):
        """Assigning a value that contradicts the local annotation is a script error."""
        with self.assertRaisesRegex(RuntimeError, "annotated with type"):

            @torch.jit.script
            def foo():
                x: str = 4
                return x

    def test_reannotate(self):
        """Annotating a variable after it was already declared is a script error."""
        with self.assertRaisesRegex(RuntimeError, "declare and annotate"):

            @torch.jit.script
            def foo():
                x = 5
                if 1 == 1:
                    x: Optional[int] = 7

    def test_annotate_outside_init(self):
        """Instance-attribute annotations are only permitted inside ``__init__``."""
        msg = "annotations on instance attributes must be declared in __init__"
        highlight = "self.x: int"

        # Simple case
        with self.assertRaisesRegexWithHighlight(ValueError, msg, highlight):

            @torch.jit.script
            class BadModule:
                def __init__(self, x: int):
                    self.x = x

                def set(self, val: int):
                    self.x: int = val

        # Type annotation in a loop
        with self.assertRaisesRegexWithHighlight(ValueError, msg, highlight):

            @torch.jit.script
            class BadModuleLoop:
                def __init__(self, x: int):
                    self.x = x

                def set(self, val: int):
                    for i in range(3):
                        self.x: int = val

        # Type annotation in __init__, should not fail
        @torch.jit.script
        class GoodModule:
            def __init__(self, x: int):
                self.x: int = x

            def set(self, val: int):
                self.x = val

    def test_inferred_type_error_message(self):
        """A null InferredType raises a generic message that also carries its reason."""
        inferred_type = torch._C.InferredType("ErrorReason")

        with self.assertRaisesRegex(
            RuntimeError,
            "Tried to get the type from an InferredType but the type is null.",
        ):
            t = inferred_type.type()

        with self.assertRaisesRegex(RuntimeError, "ErrorReason"):
            t = inferred_type.type()