1# Owner(s): ["oncall: jit"] 2 3import inspect 4import os 5import sys 6import unittest 7from typing import Dict, List 8 9import torch 10from torch.testing import FileCheck 11 12 13# Make the helper files in test/ importable 14pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 15sys.path.append(pytorch_test_dir) 16from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA 17 18 19if __name__ == "__main__": 20 raise RuntimeError( 21 "This test file is not meant to be run directly, use:\n\n" 22 "\tpython test/test_jit.py TESTNAME\n\n" 23 "instead." 24 ) 25 26 27class TestBuiltins(JitTestCase): 28 """ 29 Tests for TorchScript support of Python builtin functions. 30 """ 31 32 def test_has_attr(self): 33 class HasA(torch.nn.Module): 34 def __init__(self) -> None: 35 super().__init__() 36 self.a = 0 37 38 class HasB(torch.nn.Module): 39 def __init__(self) -> None: 40 super().__init__() 41 self.b = 1 42 43 class Mod(torch.nn.Module): 44 def __init__(self) -> None: 45 super().__init__() 46 self.mods = torch.nn.ModuleList([HasA(), HasB()]) 47 48 def forward(self): 49 # use a list to encode hasattr results 50 l = torch.jit.annotate(List[int], []) 51 for mod in self.mods: 52 l.append(int(hasattr(mod, "a"))) 53 l.append(int(hasattr(mod, "b"))) 54 # actually retrieve the attr to test static refinement 55 if hasattr(mod, "a"): 56 l.append(mod.a) 57 if hasattr(mod, "b"): 58 l.append(mod.b) 59 return l 60 61 self.checkModule(Mod(), ()) 62 63 def test_has_attr_invalid_args(self): 64 class Mod(torch.nn.Module): 65 def __init__(self) -> None: 66 super().__init__() 67 self.mod = torch.nn.Linear(1, 1) 68 69 def forward(self, name): 70 # not allowed, `name` must be static. 71 return hasattr(self.mod, name) 72 73 with self.assertRaisesRegexWithHighlight(RuntimeError, "hasattr", "name"): 74 torch.jit.script(Mod()) 75 76 class Mod(torch.nn.Module): 77 def forward(self, name): 78 # not allowed, `torch.rand` is not a class type 79 return hasattr(torch.rand(2, 3), name) 80 81 with self.assertRaisesRegexWithHighlight(RuntimeError, "hasattr", "name"): 82 torch.jit.script(Mod()) 83 84 def test_del(self): 85 def fn(x: List[int]) -> List[int]: 86 a = x * 2 87 del a 88 return x 89 90 self.checkScript(fn, ([1, 2, 3],)) 91 92 with self.assertRaisesRegexWithHighlight(RuntimeError, "undefined value", "a"): 93 94 @torch.jit.script 95 def fn(x): 96 a = x**2 97 del a 98 return a # noqa: F821 99 100 with self.assertRaisesRegexWithHighlight(RuntimeError, "undefined value", "a"): 101 102 @torch.jit.script 103 def fn(x): 104 a = x**2 105 if a: 106 del a 107 return a 108 109 with self.assertRaisesRegexWithHighlight(RuntimeError, "undefined value", "b"): 110 111 @torch.jit.script 112 def fn(x): 113 a = x**2 114 del b # noqa: F821 115 return a 116 117 def test_del_multiple_operands(self): 118 def fn(x: List[int]) -> List[int]: 119 a, b, c = x[0], x[1], x[2] 120 del a, b, c 121 return x 122 123 self.checkScript(fn, ([1, 2, 3],)) 124 125 def del_list_multiple_operands(x: List[int]) -> List[int]: 126 del x[0], x[1] 127 return x 128 129 py_out = del_list_multiple_operands([0, 1, 2]) 130 jit_out = torch.jit.script(del_list_multiple_operands)([0, 1, 2]) 131 self.assertEqual(py_out, jit_out) 132 133 def del_dict_multiple_operands(x: Dict[str, int]) -> Dict[str, int]: 134 del x["hi"], x["there"] 135 return x 136 137 py_out = del_dict_multiple_operands({"hi": 5, "there": 6}) 138 jit_out = torch.jit.script(del_dict_multiple_operands)({"hi": 5, "there": 6}) 139 self.assertEqual(py_out, jit_out) 140 141 142class TestTensorBuiltins(JitTestCase): 143 def test_tensor_properties(self): 144 def should_keep(tensor, name): 145 if inspect.isroutine(getattr(tensor, name)): 146 return False 147 if name.startswith("_"): 148 return False 149 return True 150 151 tensor = torch.arange(4, dtype=torch.float).view(2, 2) 152 keys = dir(tensor) 153 154 # real and imag are only implemented for complex tensors. 155 self.assertRaises(RuntimeError, lambda: should_keep(tensor, "imag")) 156 keys.remove("imag") 157 158 properties = [p for p in keys if should_keep(tensor, p)] 159 160 code_template = """ 161 def fn(x): 162 return x.{} 163 """ 164 165 EQUALITY_MISMATCH = { 166 # TorchScript doesn't have real enums so they return an int instead 167 # of the actual value 168 "dtype", 169 "layout", 170 } 171 MISSING_PROPERTIES = { 172 "grad_fn", 173 # This is an undocumented property so it's not included 174 "output_nr", 175 # This has a longer implementation, maybe not worth copying to 176 # TorchScript if named tensors don't work there anyways 177 "names", 178 } 179 180 for p in properties: 181 if p in MISSING_PROPERTIES: 182 continue 183 code = code_template.format(p) 184 cu = torch.jit.CompilationUnit() 185 cu.define(code) 186 if p in EQUALITY_MISMATCH: 187 continue 188 self.assertEqual(getattr(tensor, p), cu.fn(tensor)) 189 190 def test_tensor_subscript_assign(self): 191 def fn1(x): 192 a = torch.zeros_like(x, dtype=torch.uint8) 193 a[torch.tensor(0)] = torch.tensor(2, dtype=torch.uint8) 194 return a 195 196 def fn2(x): 197 a = torch.zeros_like(x, dtype=torch.uint8) 198 a[0] = 2 199 return a 200 201 def fn3(x): 202 a = torch.zeros_like(x, dtype=torch.uint8) 203 a[torch.tensor(0)] = 2 204 return a 205 206 def fn4(x): 207 a = torch.zeros_like(x, dtype=torch.uint8) 208 a[0] = torch.tensor(2, dtype=torch.uint8) 209 return a 210 211 def fn5(x): 212 a = torch.zeros_like(x, dtype=torch.float32) 213 a[torch.tensor(0)] = 2 214 return a 215 216 for fn in (fn1, fn2, fn3, fn4, fn5): 217 self.checkScript(fn, (torch.zeros(2, dtype=torch.uint8),)) 218 219 @unittest.skipIf(not RUN_CUDA, "requires CUDA") 220 def test_tensor_subscript_assign_device(self): 221 def fn6(x): 222 a = torch.zeros_like(x, dtype=torch.float32, device="cuda") 223 a[torch.tensor(0)] = 2 224 return a 225 226 self.checkScript(fn6, (torch.zeros(2, dtype=torch.float32, device="cuda"),)) 227 228 def test_tensor_item(self): 229 def test_scalar_cast(x): 230 scalar = x.item() 231 return int(scalar), float(scalar) 232 233 graph = torch.jit.script(test_scalar_cast).graph 234 FileCheck().check("(int, float) = prim::TupleConstruct").run(graph) 235 self.checkScript(test_scalar_cast, (torch.tensor(1.0),)) 236 self.checkScript(test_scalar_cast, (torch.tensor(1),)) 237 238 def test_method_on_number(self): 239 def func(): 240 c = 1 241 return c.add(1) 242 243 with self.assertRaisesRegex(RuntimeError, "object has no attribute or method"): 244 torch.jit.script(func) 245 246 # testing implicit conversion of tensors to scalars to match function arguments 247 def test_scalar_to_num_conversions(self): 248 @torch.jit.script 249 def multiple_defs(x): 250 c = 1 251 x = x + c 252 return x 253 254 self.assertTrue("ImplicitTensorToNum" not in str(multiple_defs.graph)) 255 256 @torch.jit.script 257 def tensor_to_int_script(x, tensor): 258 return x.unsqueeze(tensor) 259 260 # location present in error message 261 with self.assertRaisesRegex(RuntimeError, "x.unsqueeze"): 262 tensor_to_int_script(torch.tensor([2]), torch.tensor([2, 2])) 263 264 def tensor_to_int(x, tensor): 265 return x.unsqueeze(tensor) 266 267 @torch.jit.script 268 def tensor_to_float_script(x, tensor): 269 return x.addcmul(tensor, tensor, value=tensor) 270 271 def tensor_to_float(x, tensor): 272 return x.addcmul(tensor, tensor, value=tensor) 273 274 x = torch.zeros(10) 275 # float tensor, float tensor with grad, int tensor (can't set grad on int tensor) 276 tensors = [ 277 torch.tensor(1.1), 278 torch.tensor(1.1, requires_grad=True), 279 torch.tensor(0), 280 torch.tensor([2]), 281 ] 282 283 script_funs = [tensor_to_int_script, tensor_to_float_script] 284 funs = [tensor_to_int, tensor_to_float] 285 286 # return the result, or whether exception was thrown 287 def test_func(func, x, tensor): 288 try: 289 result = func(x, tensor) 290 except RuntimeError as e: 291 result = True 292 except TypeError as e: 293 result = True 294 return result 295 296 # assert result or exception equal for each (function, inputs) 297 for tensor in tensors: 298 for i in range(len(script_funs)): 299 self.assertEqual( 300 test_func(script_funs[i], x, tensor), test_func(funs[i], x, tensor) 301 ) 302