1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: unknown"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerfrom typing import Optional, List 4*da0073e9SAndroid Build Coastguard Workerimport torch 5*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorchDynamo 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard Worker# End-to-end tests of features in native_functions.yaml 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Worker 10*da0073e9SAndroid Build Coastguard Workerclass FloatListWrapperModule(torch.nn.Module): 11*da0073e9SAndroid Build Coastguard Worker def forward(self, values, incr: Optional[List[float]]): 12*da0073e9SAndroid Build Coastguard Worker return torch._C._nn._test_optional_floatlist(values, incr) 13*da0073e9SAndroid Build Coastguard Worker 14*da0073e9SAndroid Build Coastguard Worker 15*da0073e9SAndroid Build Coastguard Workerclass IntListWrapperModule(torch.nn.Module): 16*da0073e9SAndroid Build Coastguard Worker def forward(self, values, incr: Optional[List[int]]): 17*da0073e9SAndroid Build Coastguard Worker return torch._C._nn._test_optional_intlist(values, incr) 18*da0073e9SAndroid Build Coastguard Worker 19*da0073e9SAndroid Build Coastguard Worker 20*da0073e9SAndroid Build Coastguard Workerclass TestNativeFunctions(TestCase): 21*da0073e9SAndroid Build Coastguard Worker 22*da0073e9SAndroid Build Coastguard Worker def _lists_with_str(self): 23*da0073e9SAndroid Build Coastguard Worker return [ 24*da0073e9SAndroid Build Coastguard Worker ("foo",), 25*da0073e9SAndroid Build Coastguard Worker (2, "foo"), 26*da0073e9SAndroid Build Coastguard Worker ("foo", 3), 27*da0073e9SAndroid Build Coastguard Worker ["foo"], 28*da0073e9SAndroid Build Coastguard Worker [2, "foo"], 29*da0073e9SAndroid Build Coastguard Worker ["foo", 3], 30*da0073e9SAndroid Build Coastguard Worker "foo", 31*da0073e9SAndroid Build Coastguard Worker ] 32*da0073e9SAndroid Build Coastguard Worker 33*da0073e9SAndroid Build Coastguard Worker def _test_raises_str_typeerror(self, fn): 34*da0073e9SAndroid Build Coastguard Worker for arg in self._lists_with_str(): 35*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(TypeError, "str", lambda: fn(arg)) 36*da0073e9SAndroid Build Coastguard Worker try: 37*da0073e9SAndroid Build Coastguard Worker fn(arg) 38*da0073e9SAndroid Build Coastguard Worker except TypeError as e: 39*da0073e9SAndroid Build Coastguard Worker print(e) 40*da0073e9SAndroid Build Coastguard Worker 41*da0073e9SAndroid Build Coastguard Worker def test_symintlist_error(self): 42*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1) 43*da0073e9SAndroid Build Coastguard Worker self._test_raises_str_typeerror(lambda arg: torch._C._nn.pad(x, arg)) 44*da0073e9SAndroid Build Coastguard Worker 45*da0073e9SAndroid Build Coastguard Worker def test_vararg_symintlist_error(self): 46*da0073e9SAndroid Build Coastguard Worker self._test_raises_str_typeerror(lambda arg: torch.rand(arg)) 47*da0073e9SAndroid Build Coastguard Worker self._test_raises_str_typeerror(lambda arg: torch.rand(*arg)) 48*da0073e9SAndroid Build Coastguard Worker 49*da0073e9SAndroid Build Coastguard Worker def test_symintlist_error_with_overload_but_is_unique(self): 50*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1) 51*da0073e9SAndroid Build Coastguard Worker y = torch.randn(1) 52*da0073e9SAndroid Build Coastguard Worker self._test_raises_str_typeerror(lambda arg: x.set_(y, 0, arg)) 53*da0073e9SAndroid Build Coastguard Worker 54*da0073e9SAndroid Build Coastguard Worker def test_symintlist_error_with_overload(self): 55*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1) 56*da0073e9SAndroid Build Coastguard Worker self._test_raises_str_typeerror(lambda arg: x.view(arg)) 57*da0073e9SAndroid Build Coastguard Worker 58*da0073e9SAndroid Build Coastguard Worker def test_intlist_error_with_overload(self): 59*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1) 60*da0073e9SAndroid Build Coastguard Worker self._test_raises_str_typeerror(lambda arg: torch._C._nn.pad(x, arg)) 61*da0073e9SAndroid Build Coastguard Worker 62*da0073e9SAndroid Build Coastguard Worker # 63*da0073e9SAndroid Build Coastguard Worker # optional float list 64*da0073e9SAndroid Build Coastguard Worker # 65*da0073e9SAndroid Build Coastguard Worker 66*da0073e9SAndroid Build Coastguard Worker def do_test_optional_floatlist_with_module(self, module): 67*da0073e9SAndroid Build Coastguard Worker values = torch.tensor([1.5, 2.5], dtype=torch.float) 68*da0073e9SAndroid Build Coastguard Worker 69*da0073e9SAndroid Build Coastguard Worker returned = module(values, None) 70*da0073e9SAndroid Build Coastguard Worker self.assertEqual(values, returned) 71*da0073e9SAndroid Build Coastguard Worker # Make sure that it's an alias, indicating that the operator saw a nullopt. 72*da0073e9SAndroid Build Coastguard Worker values[0] = 3.5 73*da0073e9SAndroid Build Coastguard Worker self.assertEqual(values, returned) 74*da0073e9SAndroid Build Coastguard Worker 75*da0073e9SAndroid Build Coastguard Worker returned = module(values, [5.1, 4.1]) 76*da0073e9SAndroid Build Coastguard Worker self.assertEqual(values, torch.tensor([3.5, 2.5], dtype=torch.float)) 77*da0073e9SAndroid Build Coastguard Worker self.assertEqual(returned, torch.tensor([8.6, 6.6], dtype=torch.float)) 78*da0073e9SAndroid Build Coastguard Worker 79*da0073e9SAndroid Build Coastguard Worker def trace_optional_floatlist(self, const): 80*da0073e9SAndroid Build Coastguard Worker def wrapper(values): 81*da0073e9SAndroid Build Coastguard Worker return torch._C._nn._test_optional_floatlist(values, const) 82*da0073e9SAndroid Build Coastguard Worker return torch.jit.trace(wrapper, torch.tensor([1.5, 2.5], dtype=torch.float)) 83*da0073e9SAndroid Build Coastguard Worker 84*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Not a suitable test for TorchDynamo") 85*da0073e9SAndroid Build Coastguard Worker def test_optional_floatlist(self): 86*da0073e9SAndroid Build Coastguard Worker self.do_test_optional_floatlist_with_module(FloatListWrapperModule()) 87*da0073e9SAndroid Build Coastguard Worker self.do_test_optional_floatlist_with_module(torch.jit.script(FloatListWrapperModule())) 88*da0073e9SAndroid Build Coastguard Worker 89*da0073e9SAndroid Build Coastguard Worker traced_none = self.trace_optional_floatlist(None) 90*da0073e9SAndroid Build Coastguard Worker traced_list = self.trace_optional_floatlist([5.1, 4.1]) 91*da0073e9SAndroid Build Coastguard Worker 92*da0073e9SAndroid Build Coastguard Worker # Not really a module, just lets us use our two traced functions to handle 93*da0073e9SAndroid Build Coastguard Worker # the specific cases of passing None and [5.1, 4.1]. 94*da0073e9SAndroid Build Coastguard Worker def fake_module(values, const): 95*da0073e9SAndroid Build Coastguard Worker if const is None: 96*da0073e9SAndroid Build Coastguard Worker return traced_none(values) 97*da0073e9SAndroid Build Coastguard Worker if const == [5.1, 4.1]: 98*da0073e9SAndroid Build Coastguard Worker return traced_list(values) 99*da0073e9SAndroid Build Coastguard Worker raise Exception("Invalid argument") # noqa: TRY002 100*da0073e9SAndroid Build Coastguard Worker 101*da0073e9SAndroid Build Coastguard Worker self.do_test_optional_floatlist_with_module(fake_module) 102*da0073e9SAndroid Build Coastguard Worker 103*da0073e9SAndroid Build Coastguard Worker def test_optional_floatlist_invalid(self): 104*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, "must be tuple of floats, not list"): 105*da0073e9SAndroid Build Coastguard Worker FloatListWrapperModule()(torch.zeros(1), ["hi"]) 106*da0073e9SAndroid Build Coastguard Worker 107*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"): 108*da0073e9SAndroid Build Coastguard Worker torch.jit.script(FloatListWrapperModule())(torch.zeros(1), ["hi"]) 109*da0073e9SAndroid Build Coastguard Worker 110*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, "must be .* Tensor"): 111*da0073e9SAndroid Build Coastguard Worker FloatListWrapperModule()(torch.zeros(1), torch.zeros(1)) 112*da0073e9SAndroid Build Coastguard Worker 113*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"): 114*da0073e9SAndroid Build Coastguard Worker torch.jit.script(FloatListWrapperModule())(torch.zeros(1), torch.zeros(1)) 115*da0073e9SAndroid Build Coastguard Worker 116*da0073e9SAndroid Build Coastguard Worker # 117*da0073e9SAndroid Build Coastguard Worker # optional int list 118*da0073e9SAndroid Build Coastguard Worker # 119*da0073e9SAndroid Build Coastguard Worker 120*da0073e9SAndroid Build Coastguard Worker def do_test_optional_intlist_with_module(self, module): 121*da0073e9SAndroid Build Coastguard Worker values = torch.tensor([1, 2], dtype=torch.int) 122*da0073e9SAndroid Build Coastguard Worker 123*da0073e9SAndroid Build Coastguard Worker returned = module(values, None) 124*da0073e9SAndroid Build Coastguard Worker self.assertEqual(values, returned) 125*da0073e9SAndroid Build Coastguard Worker # Make sure that it's an alias, indicating that the operator saw a nullopt. 126*da0073e9SAndroid Build Coastguard Worker values[0] = 3 127*da0073e9SAndroid Build Coastguard Worker self.assertEqual(values, returned) 128*da0073e9SAndroid Build Coastguard Worker 129*da0073e9SAndroid Build Coastguard Worker returned = module(values, [5, 4]) 130*da0073e9SAndroid Build Coastguard Worker self.assertEqual(values, torch.tensor([3, 2], dtype=torch.int)) 131*da0073e9SAndroid Build Coastguard Worker self.assertEqual(returned, torch.tensor([8, 6], dtype=torch.int)) 132*da0073e9SAndroid Build Coastguard Worker 133*da0073e9SAndroid Build Coastguard Worker def trace_optional_intlist(self, const): 134*da0073e9SAndroid Build Coastguard Worker def wrapper(values): 135*da0073e9SAndroid Build Coastguard Worker return torch._C._nn._test_optional_intlist(values, const) 136*da0073e9SAndroid Build Coastguard Worker return torch.jit.trace(wrapper, torch.tensor([1, 2], dtype=torch.int)) 137*da0073e9SAndroid Build Coastguard Worker 138*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Not a suitable test for TorchDynamo") 139*da0073e9SAndroid Build Coastguard Worker def test_optional_intlist(self): 140*da0073e9SAndroid Build Coastguard Worker self.do_test_optional_intlist_with_module(IntListWrapperModule()) 141*da0073e9SAndroid Build Coastguard Worker self.do_test_optional_intlist_with_module(torch.jit.script(IntListWrapperModule())) 142*da0073e9SAndroid Build Coastguard Worker 143*da0073e9SAndroid Build Coastguard Worker traced_none = self.trace_optional_intlist(None) 144*da0073e9SAndroid Build Coastguard Worker traced_list = self.trace_optional_intlist([5, 4]) 145*da0073e9SAndroid Build Coastguard Worker 146*da0073e9SAndroid Build Coastguard Worker # Not really a module, just lets us use our two traced functions to handle 147*da0073e9SAndroid Build Coastguard Worker # the specific cases of passing None and [5, 4]. 148*da0073e9SAndroid Build Coastguard Worker def fake_module(values, const): 149*da0073e9SAndroid Build Coastguard Worker if const is None: 150*da0073e9SAndroid Build Coastguard Worker return traced_none(values) 151*da0073e9SAndroid Build Coastguard Worker if const == [5, 4]: 152*da0073e9SAndroid Build Coastguard Worker return traced_list(values) 153*da0073e9SAndroid Build Coastguard Worker raise Exception("Invalid argument") # noqa: TRY002 154*da0073e9SAndroid Build Coastguard Worker 155*da0073e9SAndroid Build Coastguard Worker self.do_test_optional_intlist_with_module(fake_module) 156*da0073e9SAndroid Build Coastguard Worker 157*da0073e9SAndroid Build Coastguard Worker def test_optional_intlist_invalid(self): 158*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, "must be .* but found"): 159*da0073e9SAndroid Build Coastguard Worker IntListWrapperModule()(torch.zeros(1), [0.5]) 160*da0073e9SAndroid Build Coastguard Worker 161*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"): 162*da0073e9SAndroid Build Coastguard Worker torch.jit.script(IntListWrapperModule())(torch.zeros(1), [0.5]) 163*da0073e9SAndroid Build Coastguard Worker 164*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, "must be .* Tensor"): 165*da0073e9SAndroid Build Coastguard Worker IntListWrapperModule()(torch.zeros(1), torch.zeros(1)) 166*da0073e9SAndroid Build Coastguard Worker 167*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"): 168*da0073e9SAndroid Build Coastguard Worker torch.jit.script(IntListWrapperModule())(torch.zeros(1), torch.zeros(1)) 169*da0073e9SAndroid Build Coastguard Worker 170*da0073e9SAndroid Build Coastguard Worker # 171*da0073e9SAndroid Build Coastguard Worker # optional filled int list 172*da0073e9SAndroid Build Coastguard Worker # 173*da0073e9SAndroid Build Coastguard Worker 174*da0073e9SAndroid Build Coastguard Worker def do_test_optional_filled_intlist_with_module(self, module): 175*da0073e9SAndroid Build Coastguard Worker values = torch.tensor([1, 2], dtype=torch.int) 176*da0073e9SAndroid Build Coastguard Worker 177*da0073e9SAndroid Build Coastguard Worker returned = module(values, None) 178*da0073e9SAndroid Build Coastguard Worker self.assertEqual(values, returned) 179*da0073e9SAndroid Build Coastguard Worker # Make sure that it's an alias, indicating that the operator saw a nullopt. 180*da0073e9SAndroid Build Coastguard Worker values[0] = 3 181*da0073e9SAndroid Build Coastguard Worker self.assertEqual(values, returned) 182*da0073e9SAndroid Build Coastguard Worker 183*da0073e9SAndroid Build Coastguard Worker returned = module(values, 10) 184*da0073e9SAndroid Build Coastguard Worker self.assertEqual(values, torch.tensor([3, 2], dtype=torch.int)) 185*da0073e9SAndroid Build Coastguard Worker self.assertEqual(returned, torch.tensor([13, 12], dtype=torch.int)) 186*da0073e9SAndroid Build Coastguard Worker 187*da0073e9SAndroid Build Coastguard Worker def trace_optional_filled_intlist(self, const): 188*da0073e9SAndroid Build Coastguard Worker def wrapper(values): 189*da0073e9SAndroid Build Coastguard Worker return torch._C._nn._test_optional_filled_intlist(values, const) 190*da0073e9SAndroid Build Coastguard Worker return torch.jit.trace(wrapper, torch.tensor([1, 2], dtype=torch.int)) 191*da0073e9SAndroid Build Coastguard Worker 192*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Not a suitable test for TorchDynamo") 193*da0073e9SAndroid Build Coastguard Worker def test_optional_filled_intlist(self): 194*da0073e9SAndroid Build Coastguard Worker 195*da0073e9SAndroid Build Coastguard Worker def f(n: int): 196*da0073e9SAndroid Build Coastguard Worker x = torch._C._nn._test_optional_filled_intlist(torch.tensor([1, 1], dtype=torch.int), (n, n)) 197*da0073e9SAndroid Build Coastguard Worker y = torch._C._nn._test_optional_filled_intlist(torch.tensor([1, 1], dtype=torch.int), n) 198*da0073e9SAndroid Build Coastguard Worker return x, y 199*da0073e9SAndroid Build Coastguard Worker 200*da0073e9SAndroid Build Coastguard Worker # eager 201*da0073e9SAndroid Build Coastguard Worker returned = f(10) 202*da0073e9SAndroid Build Coastguard Worker self.assertEqual(returned[0], returned[1]) 203*da0073e9SAndroid Build Coastguard Worker 204*da0073e9SAndroid Build Coastguard Worker # scripted 205*da0073e9SAndroid Build Coastguard Worker s = torch.jit.script(f) 206*da0073e9SAndroid Build Coastguard Worker returned = s(10) 207*da0073e9SAndroid Build Coastguard Worker self.assertEqual(returned[0], returned[1]) 208*da0073e9SAndroid Build Coastguard Worker 209*da0073e9SAndroid Build Coastguard Worker # traced 210*da0073e9SAndroid Build Coastguard Worker traced_none = self.trace_optional_filled_intlist(None) 211*da0073e9SAndroid Build Coastguard Worker traced_int = self.trace_optional_filled_intlist(10) 212*da0073e9SAndroid Build Coastguard Worker 213*da0073e9SAndroid Build Coastguard Worker # Not really a module, just lets us use our two traced functions to handle 214*da0073e9SAndroid Build Coastguard Worker # the specific cases of passing None and 10. 215*da0073e9SAndroid Build Coastguard Worker def fake_module(values, const): 216*da0073e9SAndroid Build Coastguard Worker if const is None: 217*da0073e9SAndroid Build Coastguard Worker return traced_none(values) 218*da0073e9SAndroid Build Coastguard Worker if const == 10: 219*da0073e9SAndroid Build Coastguard Worker return traced_int(values) 220*da0073e9SAndroid Build Coastguard Worker raise Exception("Invalid argument") # noqa: TRY002 221*da0073e9SAndroid Build Coastguard Worker 222*da0073e9SAndroid Build Coastguard Worker self.do_test_optional_filled_intlist_with_module(fake_module) 223*da0073e9SAndroid Build Coastguard Worker 224*da0073e9SAndroid Build Coastguard Worker def test_string_defaults(self): 225*da0073e9SAndroid Build Coastguard Worker dummy = torch.rand(1) 226*da0073e9SAndroid Build Coastguard Worker fn = torch._C._nn._test_string_default 227*da0073e9SAndroid Build Coastguard Worker fn(dummy) 228*da0073e9SAndroid Build Coastguard Worker 229*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "A"): 230*da0073e9SAndroid Build Coastguard Worker fn(dummy, a="") 231*da0073e9SAndroid Build Coastguard Worker 232*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "B"): 233*da0073e9SAndroid Build Coastguard Worker fn(dummy, b="") 234*da0073e9SAndroid Build Coastguard Worker 235*da0073e9SAndroid Build Coastguard Worker def f(x): 236*da0073e9SAndroid Build Coastguard Worker torch._C._nn._test_string_default(x) 237*da0073e9SAndroid Build Coastguard Worker scripted_fn = torch.jit.script(f) 238*da0073e9SAndroid Build Coastguard Worker scripted_fn(dummy) 239*da0073e9SAndroid Build Coastguard Worker 240*da0073e9SAndroid Build Coastguard Worker 241*da0073e9SAndroid Build Coastguard Workerif __name__ == '__main__': 242*da0073e9SAndroid Build Coastguard Worker run_tests() 243