1# Owner(s): ["module: unknown"] 2 3from typing import Optional, List 4import torch 5from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorchDynamo 6 7# End-to-end tests of features in native_functions.yaml 8 9 10class FloatListWrapperModule(torch.nn.Module): 11 def forward(self, values, incr: Optional[List[float]]): 12 return torch._C._nn._test_optional_floatlist(values, incr) 13 14 15class IntListWrapperModule(torch.nn.Module): 16 def forward(self, values, incr: Optional[List[int]]): 17 return torch._C._nn._test_optional_intlist(values, incr) 18 19 20class TestNativeFunctions(TestCase): 21 22 def _lists_with_str(self): 23 return [ 24 ("foo",), 25 (2, "foo"), 26 ("foo", 3), 27 ["foo"], 28 [2, "foo"], 29 ["foo", 3], 30 "foo", 31 ] 32 33 def _test_raises_str_typeerror(self, fn): 34 for arg in self._lists_with_str(): 35 self.assertRaisesRegex(TypeError, "str", lambda: fn(arg)) 36 try: 37 fn(arg) 38 except TypeError as e: 39 print(e) 40 41 def test_symintlist_error(self): 42 x = torch.randn(1) 43 self._test_raises_str_typeerror(lambda arg: torch._C._nn.pad(x, arg)) 44 45 def test_vararg_symintlist_error(self): 46 self._test_raises_str_typeerror(lambda arg: torch.rand(arg)) 47 self._test_raises_str_typeerror(lambda arg: torch.rand(*arg)) 48 49 def test_symintlist_error_with_overload_but_is_unique(self): 50 x = torch.randn(1) 51 y = torch.randn(1) 52 self._test_raises_str_typeerror(lambda arg: x.set_(y, 0, arg)) 53 54 def test_symintlist_error_with_overload(self): 55 x = torch.randn(1) 56 self._test_raises_str_typeerror(lambda arg: x.view(arg)) 57 58 def test_intlist_error_with_overload(self): 59 x = torch.randn(1) 60 self._test_raises_str_typeerror(lambda arg: torch._C._nn.pad(x, arg)) 61 62 # 63 # optional float list 64 # 65 66 def do_test_optional_floatlist_with_module(self, module): 67 values = torch.tensor([1.5, 2.5], dtype=torch.float) 68 69 returned = module(values, None) 70 self.assertEqual(values, returned) 71 # Make sure that it's an alias, indicating that the operator saw a nullopt. 72 values[0] = 3.5 73 self.assertEqual(values, returned) 74 75 returned = module(values, [5.1, 4.1]) 76 self.assertEqual(values, torch.tensor([3.5, 2.5], dtype=torch.float)) 77 self.assertEqual(returned, torch.tensor([8.6, 6.6], dtype=torch.float)) 78 79 def trace_optional_floatlist(self, const): 80 def wrapper(values): 81 return torch._C._nn._test_optional_floatlist(values, const) 82 return torch.jit.trace(wrapper, torch.tensor([1.5, 2.5], dtype=torch.float)) 83 84 @skipIfTorchDynamo("Not a suitable test for TorchDynamo") 85 def test_optional_floatlist(self): 86 self.do_test_optional_floatlist_with_module(FloatListWrapperModule()) 87 self.do_test_optional_floatlist_with_module(torch.jit.script(FloatListWrapperModule())) 88 89 traced_none = self.trace_optional_floatlist(None) 90 traced_list = self.trace_optional_floatlist([5.1, 4.1]) 91 92 # Not really a module, just lets us use our two traced functions to handle 93 # the specific cases of passing None and [5.1, 4.1]. 94 def fake_module(values, const): 95 if const is None: 96 return traced_none(values) 97 if const == [5.1, 4.1]: 98 return traced_list(values) 99 raise Exception("Invalid argument") # noqa: TRY002 100 101 self.do_test_optional_floatlist_with_module(fake_module) 102 103 def test_optional_floatlist_invalid(self): 104 with self.assertRaisesRegex(TypeError, "must be tuple of floats, not list"): 105 FloatListWrapperModule()(torch.zeros(1), ["hi"]) 106 107 with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"): 108 torch.jit.script(FloatListWrapperModule())(torch.zeros(1), ["hi"]) 109 110 with self.assertRaisesRegex(TypeError, "must be .* Tensor"): 111 FloatListWrapperModule()(torch.zeros(1), torch.zeros(1)) 112 113 with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"): 114 torch.jit.script(FloatListWrapperModule())(torch.zeros(1), torch.zeros(1)) 115 116 # 117 # optional int list 118 # 119 120 def do_test_optional_intlist_with_module(self, module): 121 values = torch.tensor([1, 2], dtype=torch.int) 122 123 returned = module(values, None) 124 self.assertEqual(values, returned) 125 # Make sure that it's an alias, indicating that the operator saw a nullopt. 126 values[0] = 3 127 self.assertEqual(values, returned) 128 129 returned = module(values, [5, 4]) 130 self.assertEqual(values, torch.tensor([3, 2], dtype=torch.int)) 131 self.assertEqual(returned, torch.tensor([8, 6], dtype=torch.int)) 132 133 def trace_optional_intlist(self, const): 134 def wrapper(values): 135 return torch._C._nn._test_optional_intlist(values, const) 136 return torch.jit.trace(wrapper, torch.tensor([1, 2], dtype=torch.int)) 137 138 @skipIfTorchDynamo("Not a suitable test for TorchDynamo") 139 def test_optional_intlist(self): 140 self.do_test_optional_intlist_with_module(IntListWrapperModule()) 141 self.do_test_optional_intlist_with_module(torch.jit.script(IntListWrapperModule())) 142 143 traced_none = self.trace_optional_intlist(None) 144 traced_list = self.trace_optional_intlist([5, 4]) 145 146 # Not really a module, just lets us use our two traced functions to handle 147 # the specific cases of passing None and [5, 4]. 148 def fake_module(values, const): 149 if const is None: 150 return traced_none(values) 151 if const == [5, 4]: 152 return traced_list(values) 153 raise Exception("Invalid argument") # noqa: TRY002 154 155 self.do_test_optional_intlist_with_module(fake_module) 156 157 def test_optional_intlist_invalid(self): 158 with self.assertRaisesRegex(TypeError, "must be .* but found"): 159 IntListWrapperModule()(torch.zeros(1), [0.5]) 160 161 with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"): 162 torch.jit.script(IntListWrapperModule())(torch.zeros(1), [0.5]) 163 164 with self.assertRaisesRegex(TypeError, "must be .* Tensor"): 165 IntListWrapperModule()(torch.zeros(1), torch.zeros(1)) 166 167 with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"): 168 torch.jit.script(IntListWrapperModule())(torch.zeros(1), torch.zeros(1)) 169 170 # 171 # optional filled int list 172 # 173 174 def do_test_optional_filled_intlist_with_module(self, module): 175 values = torch.tensor([1, 2], dtype=torch.int) 176 177 returned = module(values, None) 178 self.assertEqual(values, returned) 179 # Make sure that it's an alias, indicating that the operator saw a nullopt. 180 values[0] = 3 181 self.assertEqual(values, returned) 182 183 returned = module(values, 10) 184 self.assertEqual(values, torch.tensor([3, 2], dtype=torch.int)) 185 self.assertEqual(returned, torch.tensor([13, 12], dtype=torch.int)) 186 187 def trace_optional_filled_intlist(self, const): 188 def wrapper(values): 189 return torch._C._nn._test_optional_filled_intlist(values, const) 190 return torch.jit.trace(wrapper, torch.tensor([1, 2], dtype=torch.int)) 191 192 @skipIfTorchDynamo("Not a suitable test for TorchDynamo") 193 def test_optional_filled_intlist(self): 194 195 def f(n: int): 196 x = torch._C._nn._test_optional_filled_intlist(torch.tensor([1, 1], dtype=torch.int), (n, n)) 197 y = torch._C._nn._test_optional_filled_intlist(torch.tensor([1, 1], dtype=torch.int), n) 198 return x, y 199 200 # eager 201 returned = f(10) 202 self.assertEqual(returned[0], returned[1]) 203 204 # scripted 205 s = torch.jit.script(f) 206 returned = s(10) 207 self.assertEqual(returned[0], returned[1]) 208 209 # traced 210 traced_none = self.trace_optional_filled_intlist(None) 211 traced_int = self.trace_optional_filled_intlist(10) 212 213 # Not really a module, just lets us use our two traced functions to handle 214 # the specific cases of passing None and 10. 215 def fake_module(values, const): 216 if const is None: 217 return traced_none(values) 218 if const == 10: 219 return traced_int(values) 220 raise Exception("Invalid argument") # noqa: TRY002 221 222 self.do_test_optional_filled_intlist_with_module(fake_module) 223 224 def test_string_defaults(self): 225 dummy = torch.rand(1) 226 fn = torch._C._nn._test_string_default 227 fn(dummy) 228 229 with self.assertRaisesRegex(RuntimeError, "A"): 230 fn(dummy, a="") 231 232 with self.assertRaisesRegex(RuntimeError, "B"): 233 fn(dummy, b="") 234 235 def f(x): 236 torch._C._nn._test_string_default(x) 237 scripted_fn = torch.jit.script(f) 238 scripted_fn(dummy) 239 240 241if __name__ == '__main__': 242 run_tests() 243