xref: /aosp_15_r20/external/pytorch/test/onnx/dynamo/test_registry_dispatcher.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: onnx"]
2"""Unit tests for the internal registration wrapper module."""
3
4from __future__ import annotations
5
6import operator
7from typing import TypeVar, Union
8
9import onnxscript  # type: ignore[import]
10from onnxscript import BFLOAT16, DOUBLE, FLOAT, FLOAT16  # type: ignore[import]
11from onnxscript.onnx_opset import opset15 as op  # type: ignore[import]
12
13import torch
14import torch.fx
15from torch.onnx._internal.fx import diagnostics, onnxfunction_dispatcher, registration
16from torch.testing._internal import common_utils
17
18
19# TODO: this can only be global. https://github.com/microsoft/onnxscript/issues/805
# TCustomFloat annotates the parameters and return values of the
# ``@onnxscript.script`` functions defined inside the TestDispatcher tests; it is
# bound to onnxscript's FLOAT16, FLOAT, DOUBLE and BFLOAT16 types.
# NOTE(review): presumably it has to live at module level because onnxscript
# resolves annotation names from the module namespace rather than from function
# locals -- confirm against the issue linked in the TODO.
20TCustomFloat = TypeVar("TCustomFloat", bound=Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16])
21
22
class TestRegistration(common_utils.TestCase):
    """Tests for adding ONNX Script functions to ``torch.onnx.OnnxRegistry``."""

    def setUp(self) -> None:
        # Every test works on a newly constructed registry and a "custom"
        # opset at version 1.
        self.registry = torch.onnx.OnnxRegistry()
        self.custom_domain = onnxscript.values.Opset(domain="custom", version=1)

    def tearDown(self) -> None:
        # Drop anything a test stored under "test::test_op"; ``pop`` with a
        # default makes this a no-op when nothing was registered.
        self.registry._registry.pop(
            registration.OpName.from_name_parts(namespace="test", op_name="test_op"),
            None,
        )

    def test_register_custom_op_registers_custom_function(self):
        """``register_op`` makes a function visible through the registry queries."""
        qualified = ("test", "test_op", "default")
        self.assertFalse(self.registry.is_registered_op(*qualified))

        @onnxscript.script(self.custom_domain)
        def custom_add(x, y):
            return op.Add(x, y)

        self.registry.register_op(custom_add, *qualified)
        self.assertTrue(self.registry.is_registered_op(*qualified))

        # The function that was just registered must come back from get_op_functions.
        registered_functions = self.registry.get_op_functions(*qualified)
        self.assertIsNotNone(registered_functions)
        self.assertEqual({entry.onnx_function for entry in registered_functions}, {custom_add})  # type: ignore[arg-type]

    def test_custom_onnx_symbolic_joins_existing_function(self):
        """A function added via ``register_op`` is appended after an existing one."""
        self.assertFalse(self.registry.is_registered_op("test", "test_op"))

        @onnxscript.script(self.custom_domain)
        def test_original(x, y):
            return op.Add(x, y)

        # Seed the registry through the private ``_register`` hook, which takes an
        # explicit OpName; the "default" overload is spelled out here.
        # NOTE(review): presumably the public ``register_op`` fills in "default"
        # on its own when no overload is given -- confirm against OnnxRegistry.
        original_name = registration.OpName.from_name_parts(
            namespace="test", op_name="test_op", overload="default"
        )
        self.registry._register(
            original_name,
            registration.ONNXFunction(
                test_original, op_full_name=original_name.qualified_name()
            ),
        )
        self.assertTrue(self.registry.is_registered_op("test", "test_op"))

        @onnxscript.script(self.custom_domain)
        def test_custom(x, y):
            return op.Add(x, y)

        self.registry.register_op(test_custom, "test", "test_op")

        registered_functions = self.registry.get_op_functions("test", "test_op")
        assert registered_functions is not None
        # Order is significant (a list): the pre-existing function stays first.
        self.assertEqual(
            [entry.onnx_function for entry in registered_functions],
            [test_original, test_custom],
        )
79
80
@common_utils.instantiate_parametrized_tests
class TestDispatcher(common_utils.TestCase):
    """Tests for ``onnxfunction_dispatcher.OnnxFunctionDispatcher``.

    Covers three dispatcher steps for an FX node:
    resolving its op name (``_get_aten_name``),
    looking up the registered ONNX function overloads
    (``get_function_overloads``), and choosing one of those overloads
    (``_find_the_perfect_or_nearest_match_onnxfunction``).
    """

    def setUp(self) -> None:
        # Each test gets its own registry, diagnostic context and dispatcher.
        self.registry = torch.onnx.OnnxRegistry()
        self.diagnostic_context = diagnostics.DiagnosticContext(
            "torch.onnx.dynamo_export", torch.__version__
        )
        self.dispatcher = onnxfunction_dispatcher.OnnxFunctionDispatcher(
            self.registry, self.diagnostic_context
        )

    @common_utils.parametrize(
        "node, expected_name",
        [
            # An OpOverload target carries its overload name ("Tensor").
            common_utils.subtest(
                (
                    torch.fx.Node(
                        graph=torch.fx.Graph(),
                        name="aten::add.Tensor",
                        op="call_function",
                        target=torch.ops.aten.add.Tensor,  # type: ignore[attr-defined]
                        args=(torch.tensor(3), torch.tensor(4)),
                        kwargs={},
                    ),
                    ("aten", "add", "Tensor"),
                ),
                name="get_Opoverload_name",
            ),
            # An OpOverloadPacket target that is expected to resolve with no
            # explicit overload.
            common_utils.subtest(
                (
                    torch.fx.Node(
                        graph=torch.fx.Graph(),
                        name="aten::sym_size",
                        op="call_function",
                        target=torch.ops.aten.sym_size,
                        args=(),
                        kwargs={},
                    ),
                    ("aten", "sym_size", None),
                ),
                name="get_Opoverloadpacket_name",
            ),
            # A Python builtin; ``operator.add`` is defined in the ``_operator``
            # module, which is why that is the expected namespace.
            common_utils.subtest(
                (
                    torch.fx.Node(
                        graph=torch.fx.Graph(),
                        name="builtin_add",
                        op="call_function",
                        target=operator.add,
                        args=(1, 2),
                        kwargs={},
                    ),
                    ("_operator", "add", None),
                ),
                name="get_builtin_op_name",
            ),
        ],
    )
    def test_get_aten_name_on_supported_fx_node(
        self, node: torch.fx.Node, expected_name: tuple[str, str, str | None]
    ) -> None:
        """Supported node targets resolve to the expected ``registration.OpName``.

        ``expected_name`` is a ``(namespace, op_name, overload)`` triple that is
        unpacked into ``registration.OpName.from_name_parts``.
        """
        expected_name_class = registration.OpName.from_name_parts(*expected_name)
        self.assertEqual(
            self.dispatcher._get_aten_name(node, self.diagnostic_context),
            expected_name_class,
        )

    @common_utils.parametrize(
        "node",
        [
            # ``aten.add`` as a bare packet (no overload selected) is rejected.
            common_utils.subtest(
                torch.fx.Node(
                    graph=torch.fx.Graph(),
                    name="aten::add",
                    op="call_function",
                    target=torch.ops.aten.add,
                    args=(),
                    kwargs={},
                ),
                name="unsupported_Opoverloadpacket_name",
            ),
            # A builtin op whose inputs are strings is rejected.
            common_utils.subtest(
                torch.fx.Node(
                    graph=torch.fx.Graph(),
                    name="builtin_add",
                    op="call_function",
                    target=operator.add,
                    args=("A", "B"),
                    kwargs={},
                ),
                name="unsupported_input_dtypes_for_builtin_op",
            ),
            # An arbitrary Python callable is rejected.
            common_utils.subtest(
                torch.fx.Node(
                    graph=torch.fx.Graph(),
                    name="aten::made_up_node",
                    op="call_function",
                    target=lambda: None,
                    args=(),
                    kwargs={},
                ),
                name="unsupported_target_function",
            ),
        ],
    )
    def test_get_aten_name_on_unsupported_fx_node(self, node: torch.fx.Node) -> None:
        """``_get_aten_name`` raises ``RuntimeError`` for nodes it cannot name."""
        with self.assertRaises(RuntimeError):
            self.dispatcher._get_aten_name(node, self.diagnostic_context)

    def test_get_function_overloads_gives_overload_fall_back_default(self) -> None:
        """Both ``aten::add`` nodes yield the same overloads; unknown targets raise."""
        # Test fall back to default op name.
        # NOTE(review): both nodes below use ``torch.ops.aten.add.Tensor`` as their
        # target and differ only in ``name`` and ``args``. Confirm this setup
        # really exercises the overload -> default fallback the test name
        # describes.
        node_overload = torch.fx.Node(
            graph=torch.fx.Graph(),
            name="aten::add.Tensor",
            op="call_function",
            target=torch.ops.aten.add.Tensor,  # type: ignore[attr-defined]
            args=(torch.tensor(3), torch.tensor(4)),
            kwargs={},
        )
        node_overloadpacket = torch.fx.Node(
            graph=torch.fx.Graph(),
            name="aten::add",
            op="call_function",
            target=torch.ops.aten.add.Tensor,  # type: ignore[attr-defined]
            args=(),
            kwargs={},
        )

        self.assertEqual(
            self.dispatcher.get_function_overloads(
                node_overload, self.diagnostic_context
            ),
            self.dispatcher.get_function_overloads(
                node_overloadpacket,
                self.diagnostic_context,
            ),
        )

        # Non-registered op: the lookup must fail loudly.
        unsupported_op_node = torch.fx.Node(
            graph=torch.fx.Graph(),
            name="aten::made_up_node",
            op="call_function",
            target=lambda: None,
            args=(),
            kwargs={},
        )
        with self.assertRaises(RuntimeError):
            self.dispatcher.get_function_overloads(
                unsupported_op_node,
                self.diagnostic_context,
            )

    @common_utils.parametrize(
        "node",
        [
            # No ``alpha`` kwarg: the candidates only match approximately.
            common_utils.subtest(
                torch.fx.Node(
                    graph=torch.fx.Graph(),
                    name="aten::add.Tensor",
                    op="call_function",
                    target=torch.ops.aten.add.Tensor,  # type: ignore[attr-defined]
                    args=(torch.tensor(3.0), torch.tensor(4.0)),
                    kwargs={},
                ),
                name="nearest_match",
            ),
            # ``alpha`` supplied: the candidates' signatures match exactly.
            common_utils.subtest(
                torch.fx.Node(
                    graph=torch.fx.Graph(),
                    name="aten::add.Tensor",
                    op="call_function",
                    target=torch.ops.aten.add.Tensor,  # type: ignore[attr-defined]
                    args=(torch.tensor(3.0), torch.tensor(4.0)),
                    kwargs={"alpha": 1},
                ),
                name="perfect_match_with_kwargs",
            ),
        ],
    )
    def test_find_the_perfect_or_nearest_match_onnxfunction_gives_custom_ops_precedence(
        self, node: torch.fx.Node
    ) -> None:
        """Between equally matching overloads, the ``is_custom=True`` one is chosen.

        The default overload is listed first and the custom overload second,
        both with identical signatures.
        """
        custom_domain = onnxscript.values.Opset(domain="custom", version=1)

        @onnxscript.script(custom_domain)
        def test_custom_op(
            x: TCustomFloat, y: TCustomFloat, alpha: int = 1
        ) -> TCustomFloat:
            return op.Add(x, y)

        @onnxscript.script(custom_domain)
        def test_default_op(
            x: TCustomFloat, y: TCustomFloat, alpha: int = 1
        ) -> TCustomFloat:
            return op.Add(x, y)

        op_full_name = "test::test_op"

        custom_overloads = [
            registration.ONNXFunction(
                test_custom_op, op_full_name=op_full_name, is_custom=True
            )
        ]
        function_overloads = [
            registration.ONNXFunction(test_default_op, op_full_name=op_full_name)
        ] + custom_overloads

        symbolic_fn = self.dispatcher._find_the_perfect_or_nearest_match_onnxfunction(
            node,
            function_overloads,
            node.args,
            node.kwargs,
            self.diagnostic_context,
        )
        self.assertEqual(symbolic_fn, test_custom_op)

    @common_utils.parametrize(
        "node",
        [
            # ``attr=None`` must not favour the overload that declares ``attr``.
            common_utils.subtest(
                torch.fx.Node(
                    graph=torch.fx.Graph(),
                    name="aten::add.Tensor",
                    op="call_function",
                    target=torch.ops.aten.add.Tensor,  # type: ignore[attr-defined]
                    args=(torch.tensor(3.0), torch.tensor(4.0)),
                    kwargs={"attr": None},
                ),
                name="perfect_match_with_ignoring_none_attribute",
            ),
            # A ``None`` kwarg that no overload declares is also ignored.
            common_utils.subtest(
                torch.fx.Node(
                    graph=torch.fx.Graph(),
                    name="aten::add.Tensor",
                    op="call_function",
                    target=torch.ops.aten.add.Tensor,  # type: ignore[attr-defined]
                    args=(torch.tensor(3.0), torch.tensor(4.0)),
                    kwargs={"unrelated": None},
                ),
                name="perfect_match_with_ignoring_unrelated_none_attribute",
            ),
        ],
    )
    def test_find_the_perfect_or_nearest_match_onnxfunction_ignores_attribute_with_none(
        self, node: torch.fx.Node
    ) -> None:
        """Keyword arguments whose value is ``None`` are ignored during matching.

        The expected result is ``test_op``, which has no attributes.
        It is preferred over ``test_op_attribute``, which requires ``attr``.
        """
        custom_domain = onnxscript.values.Opset(domain="custom", version=1)

        @onnxscript.script(custom_domain)
        def test_op_attribute(
            x: TCustomFloat, y: TCustomFloat, attr: int
        ) -> TCustomFloat:
            return op.Add(x, y)

        @onnxscript.script(custom_domain)
        def test_op(x: TCustomFloat, y: TCustomFloat) -> TCustomFloat:
            return op.Add(x, y)

        op_full_name = "test::test_op"

        function_overloads = [
            registration.ONNXFunction(test_op_attribute, op_full_name=op_full_name),
            registration.ONNXFunction(test_op, op_full_name=op_full_name),
        ]

        symbolic_fn = self.dispatcher._find_the_perfect_or_nearest_match_onnxfunction(
            node,
            function_overloads,
            node.args,
            node.kwargs,
            self.diagnostic_context,
        )
        self.assertEqual(symbolic_fn, test_op)

    @common_utils.parametrize(
        "node",
        [
            common_utils.subtest(
                torch.fx.Node(
                    graph=torch.fx.Graph(),
                    name="aten::add.Tensor",
                    op="call_function",
                    target=torch.ops.aten.add.Tensor,  # type: ignore[attr-defined]
                    args=(torch.tensor(3.0), torch.tensor(4.0)),
                    kwargs={},
                ),
                name="nearest_match",
            ),
            common_utils.subtest(
                torch.fx.Node(
                    graph=torch.fx.Graph(),
                    name="aten::add.Tensor",
                    op="call_function",
                    target=torch.ops.aten.add.Tensor,  # type: ignore[attr-defined]
                    args=(torch.tensor(3.0), torch.tensor(4.0)),
                    kwargs={"alpha": 1},
                ),
                name="perfect_match_with_kwargs",
            ),
        ],
    )
    def test_find_the_perfect_or_nearest_match_onnxfunction_gives_tie_breaks_to_registered_order(
        self, node: torch.fx.Node
    ) -> None:
        """Ties among identical custom overloads go to the last entry in the list.

        ``function_overloads`` is ordered first, second, third.
        ``test_third_custom_op``, the last entry, is expected to be chosen.
        """
        custom_domain = onnxscript.values.Opset(domain="custom", version=1)

        @onnxscript.script(custom_domain)
        def test_second_custom_op(
            x: TCustomFloat, y: TCustomFloat, alpha: int = 1
        ) -> TCustomFloat:
            return op.Add(x, y)

        @onnxscript.script(custom_domain)
        def test_third_custom_op(
            x: TCustomFloat, y: TCustomFloat, alpha: int = 1
        ) -> TCustomFloat:
            return op.Add(x, y)

        @onnxscript.script(custom_domain)
        def test_first_custom_op(
            x: TCustomFloat, y: TCustomFloat, alpha: int = 1
        ) -> TCustomFloat:
            return op.Add(x, y)

        op_full_name = "aten::add"

        # List order (not definition order above) is what the tie-break follows.
        function_overloads = [
            registration.ONNXFunction(
                test_first_custom_op, op_full_name=op_full_name, is_custom=True
            ),
            registration.ONNXFunction(
                test_second_custom_op, op_full_name=op_full_name, is_custom=True
            ),
            registration.ONNXFunction(
                test_third_custom_op, op_full_name=op_full_name, is_custom=True
            ),
        ]

        symbolic_fn = self.dispatcher._find_the_perfect_or_nearest_match_onnxfunction(
            node,
            function_overloads,
            node.args,
            node.kwargs,
            self.diagnostic_context,
        )
        self.assertEqual(symbolic_fn, test_third_custom_op)
428
429
430if __name__ == "__main__":
    # When this file is executed directly, hand control to
    # torch.testing._internal.common_utils.run_tests to run the tests above.
431    common_utils.run_tests()
432