# Owner(s): ["module: onnx"] """Unit tests for the internal registration wrapper module.""" from typing import Sequence from torch.onnx import errors from torch.onnx._internal import registration from torch.testing._internal import common_utils @common_utils.instantiate_parametrized_tests class TestGlobalHelpers(common_utils.TestCase): @common_utils.parametrize( "available_opsets, target, expected", [ ((7, 8, 9, 10, 11), 16, 11), ((7, 8, 9, 10, 11), 11, 11), ((7, 8, 9, 10, 11), 10, 10), ((7, 8, 9, 10, 11), 9, 9), ((7, 8, 9, 10, 11), 8, 8), ((7, 8, 9, 10, 11), 7, 7), ((9, 10, 16), 16, 16), ((9, 10, 16), 15, 10), ((9, 10, 16), 10, 10), ((9, 10, 16), 9, 9), ((9, 10, 16), 8, 9), ((9, 10, 16), 7, 9), ((7, 9, 10, 16), 16, 16), ((7, 9, 10, 16), 10, 10), ((7, 9, 10, 16), 9, 9), ((7, 9, 10, 16), 8, 9), ((7, 9, 10, 16), 7, 7), ([17], 16, None), # New op added in 17 ([9], 9, 9), ([9], 8, 9), ([], 16, None), ([], 9, None), ([], 8, None), # Ops registered at opset 1 found as a fallback when target >= 9 ([1], 16, 1), ], ) def test_dispatch_opset_version_returns_correct_version( self, available_opsets: Sequence[int], target: int, expected: int ): actual = registration._dispatch_opset_version(target, available_opsets) self.assertEqual(actual, expected) class TestOverrideDict(common_utils.TestCase): def setUp(self): self.override_dict: registration.OverrideDict[str, int] = ( registration.OverrideDict() ) def test_get_item_returns_base_value_when_no_override(self): self.override_dict.set_base("a", 42) self.override_dict.set_base("b", 0) self.assertEqual(self.override_dict["a"], 42) self.assertEqual(self.override_dict["b"], 0) self.assertEqual(len(self.override_dict), 2) def test_get_item_returns_overridden_value_when_override(self): self.override_dict.set_base("a", 42) self.override_dict.set_base("b", 0) self.override_dict.override("a", 100) self.override_dict.override("c", 1) self.assertEqual(self.override_dict["a"], 100) self.assertEqual(self.override_dict["b"], 0) self.assertEqual(self.override_dict["c"], 1) self.assertEqual(len(self.override_dict), 3) def test_get_item_raises_key_error_when_not_found(self): self.override_dict.set_base("a", 42) with self.assertRaises(KeyError): self.override_dict["nonexistent_key"] def test_get_returns_overridden_value_when_override(self): self.override_dict.set_base("a", 42) self.override_dict.set_base("b", 0) self.override_dict.override("a", 100) self.override_dict.override("c", 1) self.assertEqual(self.override_dict.get("a"), 100) self.assertEqual(self.override_dict.get("b"), 0) self.assertEqual(self.override_dict.get("c"), 1) self.assertEqual(len(self.override_dict), 3) def test_get_returns_none_when_not_found(self): self.override_dict.set_base("a", 42) self.assertEqual(self.override_dict.get("nonexistent_key"), None) def test_in_base_returns_true_for_base_value(self): self.override_dict.set_base("a", 42) self.override_dict.set_base("b", 0) self.override_dict.override("a", 100) self.override_dict.override("c", 1) self.assertIn("a", self.override_dict) self.assertIn("b", self.override_dict) self.assertIn("c", self.override_dict) self.assertTrue(self.override_dict.in_base("a")) self.assertTrue(self.override_dict.in_base("b")) self.assertFalse(self.override_dict.in_base("c")) self.assertFalse(self.override_dict.in_base("nonexistent_key")) def test_overridden_returns_true_for_overridden_value(self): self.override_dict.set_base("a", 42) self.override_dict.set_base("b", 0) self.override_dict.override("a", 100) self.override_dict.override("c", 1) self.assertTrue(self.override_dict.overridden("a")) self.assertFalse(self.override_dict.overridden("b")) self.assertTrue(self.override_dict.overridden("c")) self.assertFalse(self.override_dict.overridden("nonexistent_key")) def test_remove_override_removes_overridden_value(self): self.override_dict.set_base("a", 42) self.override_dict.set_base("b", 0) self.override_dict.override("a", 100) self.override_dict.override("c", 1) self.assertEqual(self.override_dict["a"], 100) self.assertEqual(self.override_dict["c"], 1) self.override_dict.remove_override("a") self.override_dict.remove_override("c") self.assertEqual(self.override_dict["a"], 42) self.assertEqual(self.override_dict.get("c"), None) self.assertFalse(self.override_dict.overridden("a")) self.assertFalse(self.override_dict.overridden("c")) def test_remove_override_removes_overridden_key(self): self.override_dict.override("a", 100) self.assertEqual(self.override_dict["a"], 100) self.assertEqual(len(self.override_dict), 1) self.override_dict.remove_override("a") self.assertEqual(len(self.override_dict), 0) self.assertNotIn("a", self.override_dict) def test_overriden_key_precededs_base_key_regardless_of_insert_order(self): self.override_dict.set_base("a", 42) self.override_dict.override("a", 100) self.override_dict.set_base("a", 0) self.assertEqual(self.override_dict["a"], 100) self.assertEqual(len(self.override_dict), 1) def test_bool_is_true_when_not_empty(self): if self.override_dict: self.fail("OverrideDict should be false when empty") self.override_dict.override("a", 1) if not self.override_dict: self.fail("OverrideDict should be true when not empty") self.override_dict.set_base("a", 42) if not self.override_dict: self.fail("OverrideDict should be true when not empty") self.override_dict.remove_override("a") if not self.override_dict: self.fail("OverrideDict should be true when not empty") class TestRegistrationDecorators(common_utils.TestCase): def tearDown(self) -> None: registration.registry._registry.pop("test::test_op", None) def test_onnx_symbolic_registers_function(self): self.assertFalse(registration.registry.is_registered_op("test::test_op", 9)) @registration.onnx_symbolic("test::test_op", opset=9) def test(g, x): return g.op("test", x) self.assertTrue(registration.registry.is_registered_op("test::test_op", 9)) function_group = registration.registry.get_function_group("test::test_op") assert function_group is not None self.assertEqual(function_group.get(9), test) def test_onnx_symbolic_registers_function_applied_decorator_when_provided(self): wrapper_called = False def decorator(func): def wrapper(*args, **kwargs): nonlocal wrapper_called wrapper_called = True return func(*args, **kwargs) return wrapper @registration.onnx_symbolic("test::test_op", opset=9, decorate=[decorator]) def test(): return function_group = registration.registry.get_function_group("test::test_op") assert function_group is not None registered_function = function_group[9] self.assertFalse(wrapper_called) registered_function() self.assertTrue(wrapper_called) def test_onnx_symbolic_raises_warning_when_overriding_function(self): self.assertFalse(registration.registry.is_registered_op("test::test_op", 9)) @registration.onnx_symbolic("test::test_op", opset=9) def test1(): return with self.assertWarnsRegex( errors.OnnxExporterWarning, "Symbolic function 'test::test_op' already registered", ): @registration.onnx_symbolic("test::test_op", opset=9) def test2(): return def test_custom_onnx_symbolic_registers_custom_function(self): self.assertFalse(registration.registry.is_registered_op("test::test_op", 9)) @registration.custom_onnx_symbolic("test::test_op", opset=9) def test(g, x): return g.op("test", x) self.assertTrue(registration.registry.is_registered_op("test::test_op", 9)) function_group = registration.registry.get_function_group("test::test_op") assert function_group is not None self.assertEqual(function_group.get(9), test) def test_custom_onnx_symbolic_overrides_existing_function(self): self.assertFalse(registration.registry.is_registered_op("test::test_op", 9)) @registration.onnx_symbolic("test::test_op", opset=9) def test_original(): return "original" self.assertTrue(registration.registry.is_registered_op("test::test_op", 9)) @registration.custom_onnx_symbolic("test::test_op", opset=9) def test_custom(): return "custom" function_group = registration.registry.get_function_group("test::test_op") assert function_group is not None self.assertEqual(function_group.get(9), test_custom) if __name__ == "__main__": common_utils.run_tests()