# xref: /aosp_15_r20/external/pytorch/test/onnx/internal/test_registraion.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["module: onnx"]
"""Unit tests for the internal ONNX registration module (``torch.onnx._internal.registration``)."""
3
from typing import Optional, Sequence

from torch.onnx import errors
from torch.onnx._internal import registration
from torch.testing._internal import common_utils
9
10
@common_utils.instantiate_parametrized_tests
class TestGlobalHelpers(common_utils.TestCase):
    """Tests for module-level helper functions in ``registration``."""

    @common_utils.parametrize(
        "available_opsets, target, expected",
        [
            # Every opset from 7 to 11 registered; a target above all of them
            # resolves to the highest one, otherwise to the exact match.
            ((7, 8, 9, 10, 11), 16, 11),
            ((7, 8, 9, 10, 11), 11, 11),
            ((7, 8, 9, 10, 11), 10, 10),
            ((7, 8, 9, 10, 11), 9, 9),
            ((7, 8, 9, 10, 11), 8, 8),
            ((7, 8, 9, 10, 11), 7, 7),
            # Registered opsets with gaps; note that targets below 9 with no
            # exact match are expected to resolve to 9.
            ((9, 10, 16), 16, 16),
            ((9, 10, 16), 15, 10),
            ((9, 10, 16), 10, 10),
            ((9, 10, 16), 9, 9),
            ((9, 10, 16), 8, 9),
            ((9, 10, 16), 7, 9),
            ((7, 9, 10, 16), 16, 16),
            ((7, 9, 10, 16), 10, 10),
            ((7, 9, 10, 16), 9, 9),
            ((7, 9, 10, 16), 8, 9),
            ((7, 9, 10, 16), 7, 7),
            ([17], 16, None),  # New op added in 17
            ([9], 9, 9),
            ([9], 8, 9),
            # No registrations at all: nothing can be dispatched.
            ([], 16, None),
            ([], 9, None),
            ([], 8, None),
            # Ops registered at opset 1 found as a fallback when target >= 9
            ([1], 16, 1),
        ],
    )
    def test_dispatch_opset_version_returns_correct_version(
        self,
        available_opsets: Sequence[int],
        target: int,
        expected: Optional[int],
    ):
        """Check the opset ``_dispatch_opset_version`` selects for ``target``.

        ``expected`` is ``None`` for cases where no registered opset should be
        selected, hence ``Optional[int]``.
        """
        actual = registration._dispatch_opset_version(target, available_opsets)
        self.assertEqual(actual, expected)
48
49
class TestOverrideDict(common_utils.TestCase):
    """Tests for ``registration.OverrideDict``.

    The dict holds a base layer (``set_base``) and an override layer
    (``override``). Lookups prefer the override when one exists.
    """

    def setUp(self):
        # Chain to the base class so common_utils.TestCase's own per-test
        # setup is not skipped by this override.
        super().setUp()
        self.override_dict: registration.OverrideDict[str, int] = (
            registration.OverrideDict()
        )

    def test_get_item_returns_base_value_when_no_override(self):
        self.override_dict.set_base("a", 42)
        self.override_dict.set_base("b", 0)

        self.assertEqual(self.override_dict["a"], 42)
        self.assertEqual(self.override_dict["b"], 0)
        self.assertEqual(len(self.override_dict), 2)

    def test_get_item_returns_overridden_value_when_override(self):
        self.override_dict.set_base("a", 42)
        self.override_dict.set_base("b", 0)
        self.override_dict.override("a", 100)
        # "c" exists only in the override layer.
        self.override_dict.override("c", 1)

        self.assertEqual(self.override_dict["a"], 100)
        self.assertEqual(self.override_dict["b"], 0)
        self.assertEqual(self.override_dict["c"], 1)
        # "a" is counted once even though it is in both layers.
        self.assertEqual(len(self.override_dict), 3)

    def test_get_item_raises_key_error_when_not_found(self):
        self.override_dict.set_base("a", 42)

        with self.assertRaises(KeyError):
            self.override_dict["nonexistent_key"]

    def test_get_returns_overridden_value_when_override(self):
        self.override_dict.set_base("a", 42)
        self.override_dict.set_base("b", 0)
        self.override_dict.override("a", 100)
        self.override_dict.override("c", 1)

        self.assertEqual(self.override_dict.get("a"), 100)
        self.assertEqual(self.override_dict.get("b"), 0)
        self.assertEqual(self.override_dict.get("c"), 1)
        self.assertEqual(len(self.override_dict), 3)

    def test_get_returns_none_when_not_found(self):
        self.override_dict.set_base("a", 42)

        self.assertIsNone(self.override_dict.get("nonexistent_key"))

    def test_in_base_returns_true_for_base_value(self):
        self.override_dict.set_base("a", 42)
        self.override_dict.set_base("b", 0)
        self.override_dict.override("a", 100)
        self.override_dict.override("c", 1)

        # Membership covers both layers...
        self.assertIn("a", self.override_dict)
        self.assertIn("b", self.override_dict)
        self.assertIn("c", self.override_dict)

        # ...while in_base only reports keys set via set_base.
        self.assertTrue(self.override_dict.in_base("a"))
        self.assertTrue(self.override_dict.in_base("b"))
        self.assertFalse(self.override_dict.in_base("c"))
        self.assertFalse(self.override_dict.in_base("nonexistent_key"))

    def test_overridden_returns_true_for_overridden_value(self):
        self.override_dict.set_base("a", 42)
        self.override_dict.set_base("b", 0)
        self.override_dict.override("a", 100)
        self.override_dict.override("c", 1)

        self.assertTrue(self.override_dict.overridden("a"))
        self.assertFalse(self.override_dict.overridden("b"))
        self.assertTrue(self.override_dict.overridden("c"))
        self.assertFalse(self.override_dict.overridden("nonexistent_key"))

    def test_remove_override_removes_overridden_value(self):
        self.override_dict.set_base("a", 42)
        self.override_dict.set_base("b", 0)
        self.override_dict.override("a", 100)
        self.override_dict.override("c", 1)

        self.assertEqual(self.override_dict["a"], 100)
        self.assertEqual(self.override_dict["c"], 1)

        self.override_dict.remove_override("a")
        self.override_dict.remove_override("c")
        # "a" falls back to its base value; "c" had no base value and is gone.
        self.assertEqual(self.override_dict["a"], 42)
        self.assertIsNone(self.override_dict.get("c"))
        self.assertFalse(self.override_dict.overridden("a"))
        self.assertFalse(self.override_dict.overridden("c"))

    def test_remove_override_removes_overridden_key(self):
        self.override_dict.override("a", 100)
        self.assertEqual(self.override_dict["a"], 100)
        self.assertEqual(len(self.override_dict), 1)
        self.override_dict.remove_override("a")
        self.assertEqual(len(self.override_dict), 0)
        self.assertNotIn("a", self.override_dict)

    # The method name (including its spelling) is kept unchanged so existing
    # test IDs remain stable.
    def test_overriden_key_precededs_base_key_regardless_of_insert_order(self):
        self.override_dict.set_base("a", 42)
        self.override_dict.override("a", 100)
        # Setting the base after the override must not shadow the override.
        self.override_dict.set_base("a", 0)

        self.assertEqual(self.override_dict["a"], 100)
        self.assertEqual(len(self.override_dict), 1)

    def test_bool_is_true_when_not_empty(self):
        self.assertFalse(
            self.override_dict, "OverrideDict should be false when empty"
        )
        self.override_dict.override("a", 1)
        self.assertTrue(
            self.override_dict, "OverrideDict should be true when not empty"
        )
        self.override_dict.set_base("a", 42)
        self.assertTrue(
            self.override_dict, "OverrideDict should be true when not empty"
        )
        # The base value for "a" remains after the override is removed.
        self.override_dict.remove_override("a")
        self.assertTrue(
            self.override_dict, "OverrideDict should be true when not empty"
        )
167
168
class TestRegistrationDecorators(common_utils.TestCase):
    """Tests for the ``onnx_symbolic`` and ``custom_onnx_symbolic`` decorators.

    Every test registers into the global ``registration.registry`` under the
    name ``"test::test_op"``. ``tearDown`` removes that entry again.
    """

    def tearDown(self) -> None:
        try:
            # Drop whatever the test registered so the next test starts with
            # "test::test_op" unregistered.
            registration.registry._registry.pop("test::test_op", None)
        finally:
            # Chain to the base class so common_utils.TestCase's own per-test
            # teardown still runs, even if the cleanup above raises.
            super().tearDown()

    def test_onnx_symbolic_registers_function(self):
        self.assertFalse(registration.registry.is_registered_op("test::test_op", 9))

        @registration.onnx_symbolic("test::test_op", opset=9)
        def test(g, x):
            return g.op("test", x)

        self.assertTrue(registration.registry.is_registered_op("test::test_op", 9))
        function_group = registration.registry.get_function_group("test::test_op")
        assert function_group is not None
        self.assertEqual(function_group.get(9), test)

    def test_onnx_symbolic_registers_function_applied_decorator_when_provided(self):
        wrapper_called = False

        def decorator(func):
            def wrapper(*args, **kwargs):
                nonlocal wrapper_called
                wrapper_called = True
                return func(*args, **kwargs)

            return wrapper

        @registration.onnx_symbolic("test::test_op", opset=9, decorate=[decorator])
        def test():
            return

        function_group = registration.registry.get_function_group("test::test_op")
        assert function_group is not None
        registered_function = function_group[9]
        # Registration alone must not invoke the wrapper; calling the
        # registered function must.
        self.assertFalse(wrapper_called)
        registered_function()
        self.assertTrue(wrapper_called)

    def test_onnx_symbolic_raises_warning_when_overriding_function(self):
        self.assertFalse(registration.registry.is_registered_op("test::test_op", 9))

        @registration.onnx_symbolic("test::test_op", opset=9)
        def test1():
            return

        # A second onnx_symbolic registration for the same name and opset
        # is expected to warn.
        with self.assertWarnsRegex(
            errors.OnnxExporterWarning,
            "Symbolic function 'test::test_op' already registered",
        ):

            @registration.onnx_symbolic("test::test_op", opset=9)
            def test2():
                return

    def test_custom_onnx_symbolic_registers_custom_function(self):
        self.assertFalse(registration.registry.is_registered_op("test::test_op", 9))

        @registration.custom_onnx_symbolic("test::test_op", opset=9)
        def test(g, x):
            return g.op("test", x)

        self.assertTrue(registration.registry.is_registered_op("test::test_op", 9))
        function_group = registration.registry.get_function_group("test::test_op")
        assert function_group is not None
        self.assertEqual(function_group.get(9), test)

    def test_custom_onnx_symbolic_overrides_existing_function(self):
        self.assertFalse(registration.registry.is_registered_op("test::test_op", 9))

        @registration.onnx_symbolic("test::test_op", opset=9)
        def test_original():
            return "original"

        self.assertTrue(registration.registry.is_registered_op("test::test_op", 9))

        @registration.custom_onnx_symbolic("test::test_op", opset=9)
        def test_custom():
            return "custom"

        # The custom registration takes precedence over the built-in one.
        function_group = registration.registry.get_function_group("test::test_op")
        assert function_group is not None
        self.assertEqual(function_group.get(9), test_custom)
251
252
if __name__ == "__main__":
    # When this file is executed directly, hand control to the shared runner
    # from torch.testing._internal.common_utils.
    common_utils.run_tests()
255