# Owner(s): ["module: unknown"]

from typing import Optional, List
import torch
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorchDynamo

# End-to-end tests of features in native_functions.yaml
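#
# The _test_* operators exercised here (_test_optional_intlist, _test_optional_floatlist,
# _test_optional_filled_intlist, _test_string_default) are test-only operators declared in
# native_functions.yaml for exactly these checks.  The wrapper modules below exist so the
# optional-list operators can be run both eagerly and through torch.jit.script with the
# same code.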


class FloatListWrapperModule(torch.nn.Module):
    def forward(self, values, incr: Optional[List[float]]):
        return torch._C._nn._test_optional_floatlist(values, incr)


class IntListWrapperModule(torch.nn.Module):
    def forward(self, values, incr: Optional[List[int]]):
        return torch._C._nn._test_optional_intlist(values, incr)


class TestNativeFunctions(TestCase):

    def _lists_with_str(self):
        return [
            ("foo",),
            (2, "foo"),
            ("foo", 3),
            ["foo"],
            [2, "foo"],
            ["foo", 3],
            "foo",
        ]

    def _test_raises_str_typeerror(self, fn):
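        # Every candidate from _lists_with_str() contains a str where a numeric list is
        # expected, so fn(arg) must raise a TypeError whose message mentions "str".  The
        # extra call under try/except just prints the message so it shows up in the test log.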
        for arg in self._lists_with_str():
            self.assertRaisesRegex(TypeError, "str", lambda: fn(arg))
            try:
                fn(arg)
            except TypeError as e:
                print(e)

    def test_symintlist_error(self):
        x = torch.randn(1)
        self._test_raises_str_typeerror(lambda arg: torch._C._nn.pad(x, arg))

    def test_vararg_symintlist_error(self):
        self._test_raises_str_typeerror(lambda arg: torch.rand(arg))
        self._test_raises_str_typeerror(lambda arg: torch.rand(*arg))

    def test_symintlist_error_with_overload_but_is_unique(self):
        x = torch.randn(1)
        y = torch.randn(1)
        self._test_raises_str_typeerror(lambda arg: x.set_(y, 0, arg))

    def test_symintlist_error_with_overload(self):
        x = torch.randn(1)
        self._test_raises_str_typeerror(lambda arg: x.view(arg))

    def test_intlist_error_with_overload(self):
        x = torch.randn(1)
        self._test_raises_str_typeerror(lambda arg: torch._C._nn.pad(x, arg))

    #
    # optional float list
    #
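    # _test_optional_floatlist(values, incr) should return `values` itself (an alias)
    # when incr is None, and values + incr elementwise otherwise.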

    def do_test_optional_floatlist_with_module(self, module):
        values = torch.tensor([1.5, 2.5], dtype=torch.float)

        returned = module(values, None)
        self.assertEqual(values, returned)
        # Make sure that it's an alias, indicating that the operator saw a nullopt.
        values[0] = 3.5
        self.assertEqual(values, returned)

        returned = module(values, [5.1, 4.1])
        self.assertEqual(values, torch.tensor([3.5, 2.5], dtype=torch.float))
        self.assertEqual(returned, torch.tensor([8.6, 6.6], dtype=torch.float))

    def trace_optional_floatlist(self, const):
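        # Tracing treats the non-tensor `const` as a constant baked into the graph, so the
        # tests build one traced function per constant value they exercise (see fake_module).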
        def wrapper(values):
            return torch._C._nn._test_optional_floatlist(values, const)
        return torch.jit.trace(wrapper, torch.tensor([1.5, 2.5], dtype=torch.float))

    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    def test_optional_floatlist(self):
        self.do_test_optional_floatlist_with_module(FloatListWrapperModule())
        self.do_test_optional_floatlist_with_module(torch.jit.script(FloatListWrapperModule()))

        traced_none = self.trace_optional_floatlist(None)
        traced_list = self.trace_optional_floatlist([5.1, 4.1])

        # Not really a module, just lets us use our two traced functions to handle
        # the specific cases of passing None and [5.1, 4.1].
        def fake_module(values, const):
            if const is None:
                return traced_none(values)
            if const == [5.1, 4.1]:
                return traced_list(values)
            raise Exception("Invalid argument")  # noqa: TRY002

        self.do_test_optional_floatlist_with_module(fake_module)

    def test_optional_floatlist_invalid(self):
        with self.assertRaisesRegex(TypeError, "must be tuple of floats, not list"):
            FloatListWrapperModule()(torch.zeros(1), ["hi"])

        with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"):
            torch.jit.script(FloatListWrapperModule())(torch.zeros(1), ["hi"])

        with self.assertRaisesRegex(TypeError, "must be .* Tensor"):
            FloatListWrapperModule()(torch.zeros(1), torch.zeros(1))

        with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"):
            torch.jit.script(FloatListWrapperModule())(torch.zeros(1), torch.zeros(1))

    #
    # optional int list
    #
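    # _test_optional_intlist mirrors the floatlist variant: it should alias `values`
    # when incr is None, and return values + incr elementwise otherwise.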

    def do_test_optional_intlist_with_module(self, module):
        values = torch.tensor([1, 2], dtype=torch.int)

        returned = module(values, None)
        self.assertEqual(values, returned)
        # Make sure that it's an alias, indicating that the operator saw a nullopt.
        values[0] = 3
        self.assertEqual(values, returned)

        returned = module(values, [5, 4])
        self.assertEqual(values, torch.tensor([3, 2], dtype=torch.int))
        self.assertEqual(returned, torch.tensor([8, 6], dtype=torch.int))

    def trace_optional_intlist(self, const):
        def wrapper(values):
            return torch._C._nn._test_optional_intlist(values, const)
        return torch.jit.trace(wrapper, torch.tensor([1, 2], dtype=torch.int))

    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    def test_optional_intlist(self):
        self.do_test_optional_intlist_with_module(IntListWrapperModule())
        self.do_test_optional_intlist_with_module(torch.jit.script(IntListWrapperModule()))

        traced_none = self.trace_optional_intlist(None)
        traced_list = self.trace_optional_intlist([5, 4])

        # Not really a module, just lets us use our two traced functions to handle
        # the specific cases of passing None and [5, 4].
        def fake_module(values, const):
            if const is None:
                return traced_none(values)
            if const == [5, 4]:
                return traced_list(values)
            raise Exception("Invalid argument")  # noqa: TRY002

        self.do_test_optional_intlist_with_module(fake_module)

    def test_optional_intlist_invalid(self):
        with self.assertRaisesRegex(TypeError, "must be .* but found"):
            IntListWrapperModule()(torch.zeros(1), [0.5])

        with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"):
            torch.jit.script(IntListWrapperModule())(torch.zeros(1), [0.5])

        with self.assertRaisesRegex(TypeError, "must be .* Tensor"):
            IntListWrapperModule()(torch.zeros(1), torch.zeros(1))

        with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"):
            torch.jit.script(IntListWrapperModule())(torch.zeros(1), torch.zeros(1))

    #
    # optional filled int list
    #
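    # _test_optional_filled_intlist also accepts a single int, which should be
    # broadcast ("filled") to every list position, so passing n and (n, n) is
    # expected to give the same result (checked in test_optional_filled_intlist).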

    def do_test_optional_filled_intlist_with_module(self, module):
        values = torch.tensor([1, 2], dtype=torch.int)

        returned = module(values, None)
        self.assertEqual(values, returned)
        # Make sure that it's an alias, indicating that the operator saw a nullopt.
        values[0] = 3
        self.assertEqual(values, returned)

        returned = module(values, 10)
        self.assertEqual(values, torch.tensor([3, 2], dtype=torch.int))
        self.assertEqual(returned, torch.tensor([13, 12], dtype=torch.int))

    def trace_optional_filled_intlist(self, const):
        def wrapper(values):
            return torch._C._nn._test_optional_filled_intlist(values, const)
        return torch.jit.trace(wrapper, torch.tensor([1, 2], dtype=torch.int))

    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    def test_optional_filled_intlist(self):

        def f(n: int):
            x = torch._C._nn._test_optional_filled_intlist(torch.tensor([1, 1], dtype=torch.int), (n, n))
            y = torch._C._nn._test_optional_filled_intlist(torch.tensor([1, 1], dtype=torch.int), n)
            return x, y

        # eager
        returned = f(10)
        self.assertEqual(returned[0], returned[1])

        # scripted
        s = torch.jit.script(f)
        returned = s(10)
        self.assertEqual(returned[0], returned[1])

        # traced
        traced_none = self.trace_optional_filled_intlist(None)
        traced_int = self.trace_optional_filled_intlist(10)

        # Not really a module, just lets us use our two traced functions to handle
        # the specific cases of passing None and 10.
        def fake_module(values, const):
            if const is None:
                return traced_none(values)
            if const == 10:
                return traced_int(values)
            raise Exception("Invalid argument")  # noqa: TRY002

        self.do_test_optional_filled_intlist_with_module(fake_module)

    def test_string_defaults(self):
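        # _test_string_default exercises string arguments with defaults declared in
        # native_functions.yaml: calling with the defaults should succeed, overriding
        # either argument with "" is expected to raise, and the defaults should also
        # resolve when the call is scripted.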
        dummy = torch.rand(1)
        fn = torch._C._nn._test_string_default
        fn(dummy)

        with self.assertRaisesRegex(RuntimeError, "A"):
            fn(dummy, a="")

        with self.assertRaisesRegex(RuntimeError, "B"):
            fn(dummy, b="")

        def f(x):
            torch._C._nn._test_string_default(x)
        scripted_fn = torch.jit.script(f)
        scripted_fn(dummy)


if __name__ == '__main__':
    run_tests()