xref: /aosp_15_r20/external/pytorch/test/test_native_functions.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: unknown"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerfrom typing import Optional, List
4*da0073e9SAndroid Build Coastguard Workerimport torch
5*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorchDynamo
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Worker# End-to-end tests of features in native_functions.yaml
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Worker
class FloatListWrapperModule(torch.nn.Module):
    """Module wrapper around ``torch._C._nn._test_optional_floatlist``.

    ``forward`` hands both of its arguments to the operator unchanged; the
    ``Optional[List[float]]`` annotation declares that ``incr`` may be either
    ``None`` or a list of floats.
    """

    def forward(self, values, incr: Optional[List[float]]):
        # Pure pass-through: no pre- or post-processing of either argument.
        result = torch._C._nn._test_optional_floatlist(values, incr)
        return result
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Worker
class IntListWrapperModule(torch.nn.Module):
    """Module wrapper around ``torch._C._nn._test_optional_intlist``.

    ``forward`` hands both of its arguments to the operator unchanged; the
    ``Optional[List[int]]`` annotation declares that ``incr`` may be either
    ``None`` or a list of ints.
    """

    def forward(self, values, incr: Optional[List[int]]):
        # Pure pass-through: no pre- or post-processing of either argument.
        result = torch._C._nn._test_optional_intlist(values, incr)
        return result
18*da0073e9SAndroid Build Coastguard Worker
19*da0073e9SAndroid Build Coastguard Worker
20*da0073e9SAndroid Build Coastguard Workerclass TestNativeFunctions(TestCase):
21*da0073e9SAndroid Build Coastguard Worker
22*da0073e9SAndroid Build Coastguard Worker    def _lists_with_str(self):
23*da0073e9SAndroid Build Coastguard Worker        return [
24*da0073e9SAndroid Build Coastguard Worker            ("foo",),
25*da0073e9SAndroid Build Coastguard Worker            (2, "foo"),
26*da0073e9SAndroid Build Coastguard Worker            ("foo", 3),
27*da0073e9SAndroid Build Coastguard Worker            ["foo"],
28*da0073e9SAndroid Build Coastguard Worker            [2, "foo"],
29*da0073e9SAndroid Build Coastguard Worker            ["foo", 3],
30*da0073e9SAndroid Build Coastguard Worker            "foo",
31*da0073e9SAndroid Build Coastguard Worker        ]
32*da0073e9SAndroid Build Coastguard Worker
33*da0073e9SAndroid Build Coastguard Worker    def _test_raises_str_typeerror(self, fn):
34*da0073e9SAndroid Build Coastguard Worker        for arg in self._lists_with_str():
35*da0073e9SAndroid Build Coastguard Worker            self.assertRaisesRegex(TypeError, "str", lambda: fn(arg))
36*da0073e9SAndroid Build Coastguard Worker            try:
37*da0073e9SAndroid Build Coastguard Worker                fn(arg)
38*da0073e9SAndroid Build Coastguard Worker            except TypeError as e:
39*da0073e9SAndroid Build Coastguard Worker                print(e)
40*da0073e9SAndroid Build Coastguard Worker
41*da0073e9SAndroid Build Coastguard Worker    def test_symintlist_error(self):
42*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1)
43*da0073e9SAndroid Build Coastguard Worker        self._test_raises_str_typeerror(lambda arg: torch._C._nn.pad(x, arg))
44*da0073e9SAndroid Build Coastguard Worker
45*da0073e9SAndroid Build Coastguard Worker    def test_vararg_symintlist_error(self):
46*da0073e9SAndroid Build Coastguard Worker        self._test_raises_str_typeerror(lambda arg: torch.rand(arg))
47*da0073e9SAndroid Build Coastguard Worker        self._test_raises_str_typeerror(lambda arg: torch.rand(*arg))
48*da0073e9SAndroid Build Coastguard Worker
49*da0073e9SAndroid Build Coastguard Worker    def test_symintlist_error_with_overload_but_is_unique(self):
50*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1)
51*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(1)
52*da0073e9SAndroid Build Coastguard Worker        self._test_raises_str_typeerror(lambda arg: x.set_(y, 0, arg))
53*da0073e9SAndroid Build Coastguard Worker
54*da0073e9SAndroid Build Coastguard Worker    def test_symintlist_error_with_overload(self):
55*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1)
56*da0073e9SAndroid Build Coastguard Worker        self._test_raises_str_typeerror(lambda arg: x.view(arg))
57*da0073e9SAndroid Build Coastguard Worker
58*da0073e9SAndroid Build Coastguard Worker    def test_intlist_error_with_overload(self):
59*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1)
60*da0073e9SAndroid Build Coastguard Worker        self._test_raises_str_typeerror(lambda arg: torch._C._nn.pad(x, arg))
61*da0073e9SAndroid Build Coastguard Worker
62*da0073e9SAndroid Build Coastguard Worker    #
63*da0073e9SAndroid Build Coastguard Worker    # optional float list
64*da0073e9SAndroid Build Coastguard Worker    #
65*da0073e9SAndroid Build Coastguard Worker
66*da0073e9SAndroid Build Coastguard Worker    def do_test_optional_floatlist_with_module(self, module):
67*da0073e9SAndroid Build Coastguard Worker        values = torch.tensor([1.5, 2.5], dtype=torch.float)
68*da0073e9SAndroid Build Coastguard Worker
69*da0073e9SAndroid Build Coastguard Worker        returned = module(values, None)
70*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(values, returned)
71*da0073e9SAndroid Build Coastguard Worker        # Make sure that it's an alias, indicating that the operator saw a nullopt.
72*da0073e9SAndroid Build Coastguard Worker        values[0] = 3.5
73*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(values, returned)
74*da0073e9SAndroid Build Coastguard Worker
75*da0073e9SAndroid Build Coastguard Worker        returned = module(values, [5.1, 4.1])
76*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(values, torch.tensor([3.5, 2.5], dtype=torch.float))
77*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(returned, torch.tensor([8.6, 6.6], dtype=torch.float))
78*da0073e9SAndroid Build Coastguard Worker
79*da0073e9SAndroid Build Coastguard Worker    def trace_optional_floatlist(self, const):
80*da0073e9SAndroid Build Coastguard Worker        def wrapper(values):
81*da0073e9SAndroid Build Coastguard Worker            return torch._C._nn._test_optional_floatlist(values, const)
82*da0073e9SAndroid Build Coastguard Worker        return torch.jit.trace(wrapper, torch.tensor([1.5, 2.5], dtype=torch.float))
83*da0073e9SAndroid Build Coastguard Worker
84*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
85*da0073e9SAndroid Build Coastguard Worker    def test_optional_floatlist(self):
86*da0073e9SAndroid Build Coastguard Worker        self.do_test_optional_floatlist_with_module(FloatListWrapperModule())
87*da0073e9SAndroid Build Coastguard Worker        self.do_test_optional_floatlist_with_module(torch.jit.script(FloatListWrapperModule()))
88*da0073e9SAndroid Build Coastguard Worker
89*da0073e9SAndroid Build Coastguard Worker        traced_none = self.trace_optional_floatlist(None)
90*da0073e9SAndroid Build Coastguard Worker        traced_list = self.trace_optional_floatlist([5.1, 4.1])
91*da0073e9SAndroid Build Coastguard Worker
92*da0073e9SAndroid Build Coastguard Worker        # Not really a module, just lets us use our two traced functions to handle
93*da0073e9SAndroid Build Coastguard Worker        # the specific cases of passing None and [5.1, 4.1].
94*da0073e9SAndroid Build Coastguard Worker        def fake_module(values, const):
95*da0073e9SAndroid Build Coastguard Worker            if const is None:
96*da0073e9SAndroid Build Coastguard Worker                return traced_none(values)
97*da0073e9SAndroid Build Coastguard Worker            if const == [5.1, 4.1]:
98*da0073e9SAndroid Build Coastguard Worker                return traced_list(values)
99*da0073e9SAndroid Build Coastguard Worker            raise Exception("Invalid argument")  # noqa: TRY002
100*da0073e9SAndroid Build Coastguard Worker
101*da0073e9SAndroid Build Coastguard Worker        self.do_test_optional_floatlist_with_module(fake_module)
102*da0073e9SAndroid Build Coastguard Worker
103*da0073e9SAndroid Build Coastguard Worker    def test_optional_floatlist_invalid(self):
104*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, "must be tuple of floats, not list"):
105*da0073e9SAndroid Build Coastguard Worker            FloatListWrapperModule()(torch.zeros(1), ["hi"])
106*da0073e9SAndroid Build Coastguard Worker
107*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"):
108*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(FloatListWrapperModule())(torch.zeros(1), ["hi"])
109*da0073e9SAndroid Build Coastguard Worker
110*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, "must be .* Tensor"):
111*da0073e9SAndroid Build Coastguard Worker            FloatListWrapperModule()(torch.zeros(1), torch.zeros(1))
112*da0073e9SAndroid Build Coastguard Worker
113*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"):
114*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(FloatListWrapperModule())(torch.zeros(1), torch.zeros(1))
115*da0073e9SAndroid Build Coastguard Worker
116*da0073e9SAndroid Build Coastguard Worker    #
117*da0073e9SAndroid Build Coastguard Worker    # optional int list
118*da0073e9SAndroid Build Coastguard Worker    #
119*da0073e9SAndroid Build Coastguard Worker
120*da0073e9SAndroid Build Coastguard Worker    def do_test_optional_intlist_with_module(self, module):
121*da0073e9SAndroid Build Coastguard Worker        values = torch.tensor([1, 2], dtype=torch.int)
122*da0073e9SAndroid Build Coastguard Worker
123*da0073e9SAndroid Build Coastguard Worker        returned = module(values, None)
124*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(values, returned)
125*da0073e9SAndroid Build Coastguard Worker        # Make sure that it's an alias, indicating that the operator saw a nullopt.
126*da0073e9SAndroid Build Coastguard Worker        values[0] = 3
127*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(values, returned)
128*da0073e9SAndroid Build Coastguard Worker
129*da0073e9SAndroid Build Coastguard Worker        returned = module(values, [5, 4])
130*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(values, torch.tensor([3, 2], dtype=torch.int))
131*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(returned, torch.tensor([8, 6], dtype=torch.int))
132*da0073e9SAndroid Build Coastguard Worker
133*da0073e9SAndroid Build Coastguard Worker    def trace_optional_intlist(self, const):
134*da0073e9SAndroid Build Coastguard Worker        def wrapper(values):
135*da0073e9SAndroid Build Coastguard Worker            return torch._C._nn._test_optional_intlist(values, const)
136*da0073e9SAndroid Build Coastguard Worker        return torch.jit.trace(wrapper, torch.tensor([1, 2], dtype=torch.int))
137*da0073e9SAndroid Build Coastguard Worker
138*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
139*da0073e9SAndroid Build Coastguard Worker    def test_optional_intlist(self):
140*da0073e9SAndroid Build Coastguard Worker        self.do_test_optional_intlist_with_module(IntListWrapperModule())
141*da0073e9SAndroid Build Coastguard Worker        self.do_test_optional_intlist_with_module(torch.jit.script(IntListWrapperModule()))
142*da0073e9SAndroid Build Coastguard Worker
143*da0073e9SAndroid Build Coastguard Worker        traced_none = self.trace_optional_intlist(None)
144*da0073e9SAndroid Build Coastguard Worker        traced_list = self.trace_optional_intlist([5, 4])
145*da0073e9SAndroid Build Coastguard Worker
146*da0073e9SAndroid Build Coastguard Worker        # Not really a module, just lets us use our two traced functions to handle
147*da0073e9SAndroid Build Coastguard Worker        # the specific cases of passing None and [5, 4].
148*da0073e9SAndroid Build Coastguard Worker        def fake_module(values, const):
149*da0073e9SAndroid Build Coastguard Worker            if const is None:
150*da0073e9SAndroid Build Coastguard Worker                return traced_none(values)
151*da0073e9SAndroid Build Coastguard Worker            if const == [5, 4]:
152*da0073e9SAndroid Build Coastguard Worker                return traced_list(values)
153*da0073e9SAndroid Build Coastguard Worker            raise Exception("Invalid argument")  # noqa: TRY002
154*da0073e9SAndroid Build Coastguard Worker
155*da0073e9SAndroid Build Coastguard Worker        self.do_test_optional_intlist_with_module(fake_module)
156*da0073e9SAndroid Build Coastguard Worker
157*da0073e9SAndroid Build Coastguard Worker    def test_optional_intlist_invalid(self):
158*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, "must be .* but found"):
159*da0073e9SAndroid Build Coastguard Worker            IntListWrapperModule()(torch.zeros(1), [0.5])
160*da0073e9SAndroid Build Coastguard Worker
161*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"):
162*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(IntListWrapperModule())(torch.zeros(1), [0.5])
163*da0073e9SAndroid Build Coastguard Worker
164*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, "must be .* Tensor"):
165*da0073e9SAndroid Build Coastguard Worker            IntListWrapperModule()(torch.zeros(1), torch.zeros(1))
166*da0073e9SAndroid Build Coastguard Worker
167*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"):
168*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(IntListWrapperModule())(torch.zeros(1), torch.zeros(1))
169*da0073e9SAndroid Build Coastguard Worker
170*da0073e9SAndroid Build Coastguard Worker    #
171*da0073e9SAndroid Build Coastguard Worker    # optional filled int list
172*da0073e9SAndroid Build Coastguard Worker    #
173*da0073e9SAndroid Build Coastguard Worker
174*da0073e9SAndroid Build Coastguard Worker    def do_test_optional_filled_intlist_with_module(self, module):
175*da0073e9SAndroid Build Coastguard Worker        values = torch.tensor([1, 2], dtype=torch.int)
176*da0073e9SAndroid Build Coastguard Worker
177*da0073e9SAndroid Build Coastguard Worker        returned = module(values, None)
178*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(values, returned)
179*da0073e9SAndroid Build Coastguard Worker        # Make sure that it's an alias, indicating that the operator saw a nullopt.
180*da0073e9SAndroid Build Coastguard Worker        values[0] = 3
181*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(values, returned)
182*da0073e9SAndroid Build Coastguard Worker
183*da0073e9SAndroid Build Coastguard Worker        returned = module(values, 10)
184*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(values, torch.tensor([3, 2], dtype=torch.int))
185*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(returned, torch.tensor([13, 12], dtype=torch.int))
186*da0073e9SAndroid Build Coastguard Worker
187*da0073e9SAndroid Build Coastguard Worker    def trace_optional_filled_intlist(self, const):
188*da0073e9SAndroid Build Coastguard Worker        def wrapper(values):
189*da0073e9SAndroid Build Coastguard Worker            return torch._C._nn._test_optional_filled_intlist(values, const)
190*da0073e9SAndroid Build Coastguard Worker        return torch.jit.trace(wrapper, torch.tensor([1, 2], dtype=torch.int))
191*da0073e9SAndroid Build Coastguard Worker
192*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
193*da0073e9SAndroid Build Coastguard Worker    def test_optional_filled_intlist(self):
194*da0073e9SAndroid Build Coastguard Worker
195*da0073e9SAndroid Build Coastguard Worker        def f(n: int):
196*da0073e9SAndroid Build Coastguard Worker            x = torch._C._nn._test_optional_filled_intlist(torch.tensor([1, 1], dtype=torch.int), (n, n))
197*da0073e9SAndroid Build Coastguard Worker            y = torch._C._nn._test_optional_filled_intlist(torch.tensor([1, 1], dtype=torch.int), n)
198*da0073e9SAndroid Build Coastguard Worker            return x, y
199*da0073e9SAndroid Build Coastguard Worker
200*da0073e9SAndroid Build Coastguard Worker        # eager
201*da0073e9SAndroid Build Coastguard Worker        returned = f(10)
202*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(returned[0], returned[1])
203*da0073e9SAndroid Build Coastguard Worker
204*da0073e9SAndroid Build Coastguard Worker        # scripted
205*da0073e9SAndroid Build Coastguard Worker        s = torch.jit.script(f)
206*da0073e9SAndroid Build Coastguard Worker        returned = s(10)
207*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(returned[0], returned[1])
208*da0073e9SAndroid Build Coastguard Worker
209*da0073e9SAndroid Build Coastguard Worker        # traced
210*da0073e9SAndroid Build Coastguard Worker        traced_none = self.trace_optional_filled_intlist(None)
211*da0073e9SAndroid Build Coastguard Worker        traced_int = self.trace_optional_filled_intlist(10)
212*da0073e9SAndroid Build Coastguard Worker
213*da0073e9SAndroid Build Coastguard Worker        # Not really a module, just lets us use our two traced functions to handle
214*da0073e9SAndroid Build Coastguard Worker        # the specific cases of passing None and 10.
215*da0073e9SAndroid Build Coastguard Worker        def fake_module(values, const):
216*da0073e9SAndroid Build Coastguard Worker            if const is None:
217*da0073e9SAndroid Build Coastguard Worker                return traced_none(values)
218*da0073e9SAndroid Build Coastguard Worker            if const == 10:
219*da0073e9SAndroid Build Coastguard Worker                return traced_int(values)
220*da0073e9SAndroid Build Coastguard Worker            raise Exception("Invalid argument")  # noqa: TRY002
221*da0073e9SAndroid Build Coastguard Worker
222*da0073e9SAndroid Build Coastguard Worker        self.do_test_optional_filled_intlist_with_module(fake_module)
223*da0073e9SAndroid Build Coastguard Worker
224*da0073e9SAndroid Build Coastguard Worker    def test_string_defaults(self):
225*da0073e9SAndroid Build Coastguard Worker        dummy = torch.rand(1)
226*da0073e9SAndroid Build Coastguard Worker        fn = torch._C._nn._test_string_default
227*da0073e9SAndroid Build Coastguard Worker        fn(dummy)
228*da0073e9SAndroid Build Coastguard Worker
229*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "A"):
230*da0073e9SAndroid Build Coastguard Worker            fn(dummy, a="")
231*da0073e9SAndroid Build Coastguard Worker
232*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "B"):
233*da0073e9SAndroid Build Coastguard Worker            fn(dummy, b="")
234*da0073e9SAndroid Build Coastguard Worker
235*da0073e9SAndroid Build Coastguard Worker        def f(x):
236*da0073e9SAndroid Build Coastguard Worker            torch._C._nn._test_string_default(x)
237*da0073e9SAndroid Build Coastguard Worker        scripted_fn = torch.jit.script(f)
238*da0073e9SAndroid Build Coastguard Worker        scripted_fn(dummy)
239*da0073e9SAndroid Build Coastguard Worker
240*da0073e9SAndroid Build Coastguard Worker
# Hand control to the PyTorch test runner when this file is executed directly.
if __name__ == "__main__":
    run_tests()
243