1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport inspect 4*da0073e9SAndroid Build Coastguard Workerimport os 5*da0073e9SAndroid Build Coastguard Workerimport sys 6*da0073e9SAndroid Build Coastguard Workerimport unittest 7*da0073e9SAndroid Build Coastguard Workerfrom typing import Dict, List 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Workerimport torch 10*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import FileCheck 11*da0073e9SAndroid Build Coastguard Worker 12*da0073e9SAndroid Build Coastguard Worker 13*da0073e9SAndroid Build Coastguard Worker# Make the helper files in test/ importable 14*da0073e9SAndroid Build Coastguard Workerpytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 15*da0073e9SAndroid Build Coastguard Workersys.path.append(pytorch_test_dir) 16*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA 17*da0073e9SAndroid Build Coastguard Worker 18*da0073e9SAndroid Build Coastguard Worker 19*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 20*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 21*da0073e9SAndroid Build Coastguard Worker "This test file is not meant to be run directly, use:\n\n" 22*da0073e9SAndroid Build Coastguard Worker "\tpython test/test_jit.py TESTNAME\n\n" 23*da0073e9SAndroid Build Coastguard Worker "instead." 24*da0073e9SAndroid Build Coastguard Worker ) 25*da0073e9SAndroid Build Coastguard Worker 26*da0073e9SAndroid Build Coastguard Worker 27*da0073e9SAndroid Build Coastguard Workerclass TestBuiltins(JitTestCase): 28*da0073e9SAndroid Build Coastguard Worker """ 29*da0073e9SAndroid Build Coastguard Worker Tests for TorchScript support of Python builtin functions. 30*da0073e9SAndroid Build Coastguard Worker """ 31*da0073e9SAndroid Build Coastguard Worker 32*da0073e9SAndroid Build Coastguard Worker def test_has_attr(self): 33*da0073e9SAndroid Build Coastguard Worker class HasA(torch.nn.Module): 34*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 35*da0073e9SAndroid Build Coastguard Worker super().__init__() 36*da0073e9SAndroid Build Coastguard Worker self.a = 0 37*da0073e9SAndroid Build Coastguard Worker 38*da0073e9SAndroid Build Coastguard Worker class HasB(torch.nn.Module): 39*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 40*da0073e9SAndroid Build Coastguard Worker super().__init__() 41*da0073e9SAndroid Build Coastguard Worker self.b = 1 42*da0073e9SAndroid Build Coastguard Worker 43*da0073e9SAndroid Build Coastguard Worker class Mod(torch.nn.Module): 44*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 45*da0073e9SAndroid Build Coastguard Worker super().__init__() 46*da0073e9SAndroid Build Coastguard Worker self.mods = torch.nn.ModuleList([HasA(), HasB()]) 47*da0073e9SAndroid Build Coastguard Worker 48*da0073e9SAndroid Build Coastguard Worker def forward(self): 49*da0073e9SAndroid Build Coastguard Worker # use a list to encode hasattr results 50*da0073e9SAndroid Build Coastguard Worker l = torch.jit.annotate(List[int], []) 51*da0073e9SAndroid Build Coastguard Worker for mod in self.mods: 52*da0073e9SAndroid Build Coastguard Worker l.append(int(hasattr(mod, "a"))) 53*da0073e9SAndroid Build Coastguard Worker l.append(int(hasattr(mod, "b"))) 54*da0073e9SAndroid Build Coastguard Worker # actually retrieve the attr to test static refinement 55*da0073e9SAndroid Build Coastguard Worker if hasattr(mod, "a"): 56*da0073e9SAndroid Build Coastguard Worker l.append(mod.a) 57*da0073e9SAndroid Build Coastguard Worker if hasattr(mod, "b"): 58*da0073e9SAndroid Build Coastguard Worker l.append(mod.b) 59*da0073e9SAndroid Build Coastguard Worker return l 60*da0073e9SAndroid Build Coastguard Worker 61*da0073e9SAndroid Build Coastguard Worker self.checkModule(Mod(), ()) 62*da0073e9SAndroid Build Coastguard Worker 63*da0073e9SAndroid Build Coastguard Worker def test_has_attr_invalid_args(self): 64*da0073e9SAndroid Build Coastguard Worker class Mod(torch.nn.Module): 65*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 66*da0073e9SAndroid Build Coastguard Worker super().__init__() 67*da0073e9SAndroid Build Coastguard Worker self.mod = torch.nn.Linear(1, 1) 68*da0073e9SAndroid Build Coastguard Worker 69*da0073e9SAndroid Build Coastguard Worker def forward(self, name): 70*da0073e9SAndroid Build Coastguard Worker # not allowed, `name` must be static. 71*da0073e9SAndroid Build Coastguard Worker return hasattr(self.mod, name) 72*da0073e9SAndroid Build Coastguard Worker 73*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegexWithHighlight(RuntimeError, "hasattr", "name"): 74*da0073e9SAndroid Build Coastguard Worker torch.jit.script(Mod()) 75*da0073e9SAndroid Build Coastguard Worker 76*da0073e9SAndroid Build Coastguard Worker class Mod(torch.nn.Module): 77*da0073e9SAndroid Build Coastguard Worker def forward(self, name): 78*da0073e9SAndroid Build Coastguard Worker # not allowed, `torch.rand` is not a class type 79*da0073e9SAndroid Build Coastguard Worker return hasattr(torch.rand(2, 3), name) 80*da0073e9SAndroid Build Coastguard Worker 81*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegexWithHighlight(RuntimeError, "hasattr", "name"): 82*da0073e9SAndroid Build Coastguard Worker torch.jit.script(Mod()) 83*da0073e9SAndroid Build Coastguard Worker 84*da0073e9SAndroid Build Coastguard Worker def test_del(self): 85*da0073e9SAndroid Build Coastguard Worker def fn(x: List[int]) -> List[int]: 86*da0073e9SAndroid Build Coastguard Worker a = x * 2 87*da0073e9SAndroid Build Coastguard Worker del a 88*da0073e9SAndroid Build Coastguard Worker return x 89*da0073e9SAndroid Build Coastguard Worker 90*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ([1, 2, 3],)) 91*da0073e9SAndroid Build Coastguard Worker 92*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegexWithHighlight(RuntimeError, "undefined value", "a"): 93*da0073e9SAndroid Build Coastguard Worker 94*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 95*da0073e9SAndroid Build Coastguard Worker def fn(x): 96*da0073e9SAndroid Build Coastguard Worker a = x**2 97*da0073e9SAndroid Build Coastguard Worker del a 98*da0073e9SAndroid Build Coastguard Worker return a # noqa: F821 99*da0073e9SAndroid Build Coastguard Worker 100*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegexWithHighlight(RuntimeError, "undefined value", "a"): 101*da0073e9SAndroid Build Coastguard Worker 102*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 103*da0073e9SAndroid Build Coastguard Worker def fn(x): 104*da0073e9SAndroid Build Coastguard Worker a = x**2 105*da0073e9SAndroid Build Coastguard Worker if a: 106*da0073e9SAndroid Build Coastguard Worker del a 107*da0073e9SAndroid Build Coastguard Worker return a 108*da0073e9SAndroid Build Coastguard Worker 109*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegexWithHighlight(RuntimeError, "undefined value", "b"): 110*da0073e9SAndroid Build Coastguard Worker 111*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 112*da0073e9SAndroid Build Coastguard Worker def fn(x): 113*da0073e9SAndroid Build Coastguard Worker a = x**2 114*da0073e9SAndroid Build Coastguard Worker del b # noqa: F821 115*da0073e9SAndroid Build Coastguard Worker return a 116*da0073e9SAndroid Build Coastguard Worker 117*da0073e9SAndroid Build Coastguard Worker def test_del_multiple_operands(self): 118*da0073e9SAndroid Build Coastguard Worker def fn(x: List[int]) -> List[int]: 119*da0073e9SAndroid Build Coastguard Worker a, b, c = x[0], x[1], x[2] 120*da0073e9SAndroid Build Coastguard Worker del a, b, c 121*da0073e9SAndroid Build Coastguard Worker return x 122*da0073e9SAndroid Build Coastguard Worker 123*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ([1, 2, 3],)) 124*da0073e9SAndroid Build Coastguard Worker 125*da0073e9SAndroid Build Coastguard Worker def del_list_multiple_operands(x: List[int]) -> List[int]: 126*da0073e9SAndroid Build Coastguard Worker del x[0], x[1] 127*da0073e9SAndroid Build Coastguard Worker return x 128*da0073e9SAndroid Build Coastguard Worker 129*da0073e9SAndroid Build Coastguard Worker py_out = del_list_multiple_operands([0, 1, 2]) 130*da0073e9SAndroid Build Coastguard Worker jit_out = torch.jit.script(del_list_multiple_operands)([0, 1, 2]) 131*da0073e9SAndroid Build Coastguard Worker self.assertEqual(py_out, jit_out) 132*da0073e9SAndroid Build Coastguard Worker 133*da0073e9SAndroid Build Coastguard Worker def del_dict_multiple_operands(x: Dict[str, int]) -> Dict[str, int]: 134*da0073e9SAndroid Build Coastguard Worker del x["hi"], x["there"] 135*da0073e9SAndroid Build Coastguard Worker return x 136*da0073e9SAndroid Build Coastguard Worker 137*da0073e9SAndroid Build Coastguard Worker py_out = del_dict_multiple_operands({"hi": 5, "there": 6}) 138*da0073e9SAndroid Build Coastguard Worker jit_out = torch.jit.script(del_dict_multiple_operands)({"hi": 5, "there": 6}) 139*da0073e9SAndroid Build Coastguard Worker self.assertEqual(py_out, jit_out) 140*da0073e9SAndroid Build Coastguard Worker 141*da0073e9SAndroid Build Coastguard Worker 142*da0073e9SAndroid Build Coastguard Workerclass TestTensorBuiltins(JitTestCase): 143*da0073e9SAndroid Build Coastguard Worker def test_tensor_properties(self): 144*da0073e9SAndroid Build Coastguard Worker def should_keep(tensor, name): 145*da0073e9SAndroid Build Coastguard Worker if inspect.isroutine(getattr(tensor, name)): 146*da0073e9SAndroid Build Coastguard Worker return False 147*da0073e9SAndroid Build Coastguard Worker if name.startswith("_"): 148*da0073e9SAndroid Build Coastguard Worker return False 149*da0073e9SAndroid Build Coastguard Worker return True 150*da0073e9SAndroid Build Coastguard Worker 151*da0073e9SAndroid Build Coastguard Worker tensor = torch.arange(4, dtype=torch.float).view(2, 2) 152*da0073e9SAndroid Build Coastguard Worker keys = dir(tensor) 153*da0073e9SAndroid Build Coastguard Worker 154*da0073e9SAndroid Build Coastguard Worker # real and imag are only implemented for complex tensors. 155*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: should_keep(tensor, "imag")) 156*da0073e9SAndroid Build Coastguard Worker keys.remove("imag") 157*da0073e9SAndroid Build Coastguard Worker 158*da0073e9SAndroid Build Coastguard Worker properties = [p for p in keys if should_keep(tensor, p)] 159*da0073e9SAndroid Build Coastguard Worker 160*da0073e9SAndroid Build Coastguard Worker code_template = """ 161*da0073e9SAndroid Build Coastguard Worker def fn(x): 162*da0073e9SAndroid Build Coastguard Worker return x.{} 163*da0073e9SAndroid Build Coastguard Worker """ 164*da0073e9SAndroid Build Coastguard Worker 165*da0073e9SAndroid Build Coastguard Worker EQUALITY_MISMATCH = { 166*da0073e9SAndroid Build Coastguard Worker # TorchScript doesn't have real enums so they return an int instead 167*da0073e9SAndroid Build Coastguard Worker # of the actual value 168*da0073e9SAndroid Build Coastguard Worker "dtype", 169*da0073e9SAndroid Build Coastguard Worker "layout", 170*da0073e9SAndroid Build Coastguard Worker } 171*da0073e9SAndroid Build Coastguard Worker MISSING_PROPERTIES = { 172*da0073e9SAndroid Build Coastguard Worker "grad_fn", 173*da0073e9SAndroid Build Coastguard Worker # This is an undocumented property so it's not included 174*da0073e9SAndroid Build Coastguard Worker "output_nr", 175*da0073e9SAndroid Build Coastguard Worker # This has a longer implementation, maybe not worth copying to 176*da0073e9SAndroid Build Coastguard Worker # TorchScript if named tensors don't work there anyways 177*da0073e9SAndroid Build Coastguard Worker "names", 178*da0073e9SAndroid Build Coastguard Worker } 179*da0073e9SAndroid Build Coastguard Worker 180*da0073e9SAndroid Build Coastguard Worker for p in properties: 181*da0073e9SAndroid Build Coastguard Worker if p in MISSING_PROPERTIES: 182*da0073e9SAndroid Build Coastguard Worker continue 183*da0073e9SAndroid Build Coastguard Worker code = code_template.format(p) 184*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit() 185*da0073e9SAndroid Build Coastguard Worker cu.define(code) 186*da0073e9SAndroid Build Coastguard Worker if p in EQUALITY_MISMATCH: 187*da0073e9SAndroid Build Coastguard Worker continue 188*da0073e9SAndroid Build Coastguard Worker self.assertEqual(getattr(tensor, p), cu.fn(tensor)) 189*da0073e9SAndroid Build Coastguard Worker 190*da0073e9SAndroid Build Coastguard Worker def test_tensor_subscript_assign(self): 191*da0073e9SAndroid Build Coastguard Worker def fn1(x): 192*da0073e9SAndroid Build Coastguard Worker a = torch.zeros_like(x, dtype=torch.uint8) 193*da0073e9SAndroid Build Coastguard Worker a[torch.tensor(0)] = torch.tensor(2, dtype=torch.uint8) 194*da0073e9SAndroid Build Coastguard Worker return a 195*da0073e9SAndroid Build Coastguard Worker 196*da0073e9SAndroid Build Coastguard Worker def fn2(x): 197*da0073e9SAndroid Build Coastguard Worker a = torch.zeros_like(x, dtype=torch.uint8) 198*da0073e9SAndroid Build Coastguard Worker a[0] = 2 199*da0073e9SAndroid Build Coastguard Worker return a 200*da0073e9SAndroid Build Coastguard Worker 201*da0073e9SAndroid Build Coastguard Worker def fn3(x): 202*da0073e9SAndroid Build Coastguard Worker a = torch.zeros_like(x, dtype=torch.uint8) 203*da0073e9SAndroid Build Coastguard Worker a[torch.tensor(0)] = 2 204*da0073e9SAndroid Build Coastguard Worker return a 205*da0073e9SAndroid Build Coastguard Worker 206*da0073e9SAndroid Build Coastguard Worker def fn4(x): 207*da0073e9SAndroid Build Coastguard Worker a = torch.zeros_like(x, dtype=torch.uint8) 208*da0073e9SAndroid Build Coastguard Worker a[0] = torch.tensor(2, dtype=torch.uint8) 209*da0073e9SAndroid Build Coastguard Worker return a 210*da0073e9SAndroid Build Coastguard Worker 211*da0073e9SAndroid Build Coastguard Worker def fn5(x): 212*da0073e9SAndroid Build Coastguard Worker a = torch.zeros_like(x, dtype=torch.float32) 213*da0073e9SAndroid Build Coastguard Worker a[torch.tensor(0)] = 2 214*da0073e9SAndroid Build Coastguard Worker return a 215*da0073e9SAndroid Build Coastguard Worker 216*da0073e9SAndroid Build Coastguard Worker for fn in (fn1, fn2, fn3, fn4, fn5): 217*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (torch.zeros(2, dtype=torch.uint8),)) 218*da0073e9SAndroid Build Coastguard Worker 219*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not RUN_CUDA, "requires CUDA") 220*da0073e9SAndroid Build Coastguard Worker def test_tensor_subscript_assign_device(self): 221*da0073e9SAndroid Build Coastguard Worker def fn6(x): 222*da0073e9SAndroid Build Coastguard Worker a = torch.zeros_like(x, dtype=torch.float32, device="cuda") 223*da0073e9SAndroid Build Coastguard Worker a[torch.tensor(0)] = 2 224*da0073e9SAndroid Build Coastguard Worker return a 225*da0073e9SAndroid Build Coastguard Worker 226*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn6, (torch.zeros(2, dtype=torch.float32, device="cuda"),)) 227*da0073e9SAndroid Build Coastguard Worker 228*da0073e9SAndroid Build Coastguard Worker def test_tensor_item(self): 229*da0073e9SAndroid Build Coastguard Worker def test_scalar_cast(x): 230*da0073e9SAndroid Build Coastguard Worker scalar = x.item() 231*da0073e9SAndroid Build Coastguard Worker return int(scalar), float(scalar) 232*da0073e9SAndroid Build Coastguard Worker 233*da0073e9SAndroid Build Coastguard Worker graph = torch.jit.script(test_scalar_cast).graph 234*da0073e9SAndroid Build Coastguard Worker FileCheck().check("(int, float) = prim::TupleConstruct").run(graph) 235*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_scalar_cast, (torch.tensor(1.0),)) 236*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_scalar_cast, (torch.tensor(1),)) 237*da0073e9SAndroid Build Coastguard Worker 238*da0073e9SAndroid Build Coastguard Worker def test_method_on_number(self): 239*da0073e9SAndroid Build Coastguard Worker def func(): 240*da0073e9SAndroid Build Coastguard Worker c = 1 241*da0073e9SAndroid Build Coastguard Worker return c.add(1) 242*da0073e9SAndroid Build Coastguard Worker 243*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "object has no attribute or method"): 244*da0073e9SAndroid Build Coastguard Worker torch.jit.script(func) 245*da0073e9SAndroid Build Coastguard Worker 246*da0073e9SAndroid Build Coastguard Worker # testing implicit conversion of tensors to scalars to match function arguments 247*da0073e9SAndroid Build Coastguard Worker def test_scalar_to_num_conversions(self): 248*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 249*da0073e9SAndroid Build Coastguard Worker def multiple_defs(x): 250*da0073e9SAndroid Build Coastguard Worker c = 1 251*da0073e9SAndroid Build Coastguard Worker x = x + c 252*da0073e9SAndroid Build Coastguard Worker return x 253*da0073e9SAndroid Build Coastguard Worker 254*da0073e9SAndroid Build Coastguard Worker self.assertTrue("ImplicitTensorToNum" not in str(multiple_defs.graph)) 255*da0073e9SAndroid Build Coastguard Worker 256*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 257*da0073e9SAndroid Build Coastguard Worker def tensor_to_int_script(x, tensor): 258*da0073e9SAndroid Build Coastguard Worker return x.unsqueeze(tensor) 259*da0073e9SAndroid Build Coastguard Worker 260*da0073e9SAndroid Build Coastguard Worker # location present in error message 261*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "x.unsqueeze"): 262*da0073e9SAndroid Build Coastguard Worker tensor_to_int_script(torch.tensor([2]), torch.tensor([2, 2])) 263*da0073e9SAndroid Build Coastguard Worker 264*da0073e9SAndroid Build Coastguard Worker def tensor_to_int(x, tensor): 265*da0073e9SAndroid Build Coastguard Worker return x.unsqueeze(tensor) 266*da0073e9SAndroid Build Coastguard Worker 267*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 268*da0073e9SAndroid Build Coastguard Worker def tensor_to_float_script(x, tensor): 269*da0073e9SAndroid Build Coastguard Worker return x.addcmul(tensor, tensor, value=tensor) 270*da0073e9SAndroid Build Coastguard Worker 271*da0073e9SAndroid Build Coastguard Worker def tensor_to_float(x, tensor): 272*da0073e9SAndroid Build Coastguard Worker return x.addcmul(tensor, tensor, value=tensor) 273*da0073e9SAndroid Build Coastguard Worker 274*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(10) 275*da0073e9SAndroid Build Coastguard Worker # float tensor, float tensor with grad, int tensor (can't set grad on int tensor) 276*da0073e9SAndroid Build Coastguard Worker tensors = [ 277*da0073e9SAndroid Build Coastguard Worker torch.tensor(1.1), 278*da0073e9SAndroid Build Coastguard Worker torch.tensor(1.1, requires_grad=True), 279*da0073e9SAndroid Build Coastguard Worker torch.tensor(0), 280*da0073e9SAndroid Build Coastguard Worker torch.tensor([2]), 281*da0073e9SAndroid Build Coastguard Worker ] 282*da0073e9SAndroid Build Coastguard Worker 283*da0073e9SAndroid Build Coastguard Worker script_funs = [tensor_to_int_script, tensor_to_float_script] 284*da0073e9SAndroid Build Coastguard Worker funs = [tensor_to_int, tensor_to_float] 285*da0073e9SAndroid Build Coastguard Worker 286*da0073e9SAndroid Build Coastguard Worker # return the result, or whether exception was thrown 287*da0073e9SAndroid Build Coastguard Worker def test_func(func, x, tensor): 288*da0073e9SAndroid Build Coastguard Worker try: 289*da0073e9SAndroid Build Coastguard Worker result = func(x, tensor) 290*da0073e9SAndroid Build Coastguard Worker except RuntimeError as e: 291*da0073e9SAndroid Build Coastguard Worker result = True 292*da0073e9SAndroid Build Coastguard Worker except TypeError as e: 293*da0073e9SAndroid Build Coastguard Worker result = True 294*da0073e9SAndroid Build Coastguard Worker return result 295*da0073e9SAndroid Build Coastguard Worker 296*da0073e9SAndroid Build Coastguard Worker # assert result or exception equal for each (function, inputs) 297*da0073e9SAndroid Build Coastguard Worker for tensor in tensors: 298*da0073e9SAndroid Build Coastguard Worker for i in range(len(script_funs)): 299*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 300*da0073e9SAndroid Build Coastguard Worker test_func(script_funs[i], x, tensor), test_func(funs[i], x, tensor) 301*da0073e9SAndroid Build Coastguard Worker ) 302