1# Owner(s): ["module: onnx"] 2"""Unit tests for the internal registration wrapper module.""" 3 4from typing import Sequence 5 6from torch.onnx import errors 7from torch.onnx._internal import registration 8from torch.testing._internal import common_utils 9 10 11@common_utils.instantiate_parametrized_tests 12class TestGlobalHelpers(common_utils.TestCase): 13 @common_utils.parametrize( 14 "available_opsets, target, expected", 15 [ 16 ((7, 8, 9, 10, 11), 16, 11), 17 ((7, 8, 9, 10, 11), 11, 11), 18 ((7, 8, 9, 10, 11), 10, 10), 19 ((7, 8, 9, 10, 11), 9, 9), 20 ((7, 8, 9, 10, 11), 8, 8), 21 ((7, 8, 9, 10, 11), 7, 7), 22 ((9, 10, 16), 16, 16), 23 ((9, 10, 16), 15, 10), 24 ((9, 10, 16), 10, 10), 25 ((9, 10, 16), 9, 9), 26 ((9, 10, 16), 8, 9), 27 ((9, 10, 16), 7, 9), 28 ((7, 9, 10, 16), 16, 16), 29 ((7, 9, 10, 16), 10, 10), 30 ((7, 9, 10, 16), 9, 9), 31 ((7, 9, 10, 16), 8, 9), 32 ((7, 9, 10, 16), 7, 7), 33 ([17], 16, None), # New op added in 17 34 ([9], 9, 9), 35 ([9], 8, 9), 36 ([], 16, None), 37 ([], 9, None), 38 ([], 8, None), 39 # Ops registered at opset 1 found as a fallback when target >= 9 40 ([1], 16, 1), 41 ], 42 ) 43 def test_dispatch_opset_version_returns_correct_version( 44 self, available_opsets: Sequence[int], target: int, expected: int 45 ): 46 actual = registration._dispatch_opset_version(target, available_opsets) 47 self.assertEqual(actual, expected) 48 49 50class TestOverrideDict(common_utils.TestCase): 51 def setUp(self): 52 self.override_dict: registration.OverrideDict[str, int] = ( 53 registration.OverrideDict() 54 ) 55 56 def test_get_item_returns_base_value_when_no_override(self): 57 self.override_dict.set_base("a", 42) 58 self.override_dict.set_base("b", 0) 59 60 self.assertEqual(self.override_dict["a"], 42) 61 self.assertEqual(self.override_dict["b"], 0) 62 self.assertEqual(len(self.override_dict), 2) 63 64 def test_get_item_returns_overridden_value_when_override(self): 65 self.override_dict.set_base("a", 42) 66 self.override_dict.set_base("b", 0) 67 self.override_dict.override("a", 100) 68 self.override_dict.override("c", 1) 69 70 self.assertEqual(self.override_dict["a"], 100) 71 self.assertEqual(self.override_dict["b"], 0) 72 self.assertEqual(self.override_dict["c"], 1) 73 self.assertEqual(len(self.override_dict), 3) 74 75 def test_get_item_raises_key_error_when_not_found(self): 76 self.override_dict.set_base("a", 42) 77 78 with self.assertRaises(KeyError): 79 self.override_dict["nonexistent_key"] 80 81 def test_get_returns_overridden_value_when_override(self): 82 self.override_dict.set_base("a", 42) 83 self.override_dict.set_base("b", 0) 84 self.override_dict.override("a", 100) 85 self.override_dict.override("c", 1) 86 87 self.assertEqual(self.override_dict.get("a"), 100) 88 self.assertEqual(self.override_dict.get("b"), 0) 89 self.assertEqual(self.override_dict.get("c"), 1) 90 self.assertEqual(len(self.override_dict), 3) 91 92 def test_get_returns_none_when_not_found(self): 93 self.override_dict.set_base("a", 42) 94 95 self.assertEqual(self.override_dict.get("nonexistent_key"), None) 96 97 def test_in_base_returns_true_for_base_value(self): 98 self.override_dict.set_base("a", 42) 99 self.override_dict.set_base("b", 0) 100 self.override_dict.override("a", 100) 101 self.override_dict.override("c", 1) 102 103 self.assertIn("a", self.override_dict) 104 self.assertIn("b", self.override_dict) 105 self.assertIn("c", self.override_dict) 106 107 self.assertTrue(self.override_dict.in_base("a")) 108 self.assertTrue(self.override_dict.in_base("b")) 109 self.assertFalse(self.override_dict.in_base("c")) 110 self.assertFalse(self.override_dict.in_base("nonexistent_key")) 111 112 def test_overridden_returns_true_for_overridden_value(self): 113 self.override_dict.set_base("a", 42) 114 self.override_dict.set_base("b", 0) 115 self.override_dict.override("a", 100) 116 self.override_dict.override("c", 1) 117 118 self.assertTrue(self.override_dict.overridden("a")) 119 self.assertFalse(self.override_dict.overridden("b")) 120 self.assertTrue(self.override_dict.overridden("c")) 121 self.assertFalse(self.override_dict.overridden("nonexistent_key")) 122 123 def test_remove_override_removes_overridden_value(self): 124 self.override_dict.set_base("a", 42) 125 self.override_dict.set_base("b", 0) 126 self.override_dict.override("a", 100) 127 self.override_dict.override("c", 1) 128 129 self.assertEqual(self.override_dict["a"], 100) 130 self.assertEqual(self.override_dict["c"], 1) 131 132 self.override_dict.remove_override("a") 133 self.override_dict.remove_override("c") 134 self.assertEqual(self.override_dict["a"], 42) 135 self.assertEqual(self.override_dict.get("c"), None) 136 self.assertFalse(self.override_dict.overridden("a")) 137 self.assertFalse(self.override_dict.overridden("c")) 138 139 def test_remove_override_removes_overridden_key(self): 140 self.override_dict.override("a", 100) 141 self.assertEqual(self.override_dict["a"], 100) 142 self.assertEqual(len(self.override_dict), 1) 143 self.override_dict.remove_override("a") 144 self.assertEqual(len(self.override_dict), 0) 145 self.assertNotIn("a", self.override_dict) 146 147 def test_overriden_key_precededs_base_key_regardless_of_insert_order(self): 148 self.override_dict.set_base("a", 42) 149 self.override_dict.override("a", 100) 150 self.override_dict.set_base("a", 0) 151 152 self.assertEqual(self.override_dict["a"], 100) 153 self.assertEqual(len(self.override_dict), 1) 154 155 def test_bool_is_true_when_not_empty(self): 156 if self.override_dict: 157 self.fail("OverrideDict should be false when empty") 158 self.override_dict.override("a", 1) 159 if not self.override_dict: 160 self.fail("OverrideDict should be true when not empty") 161 self.override_dict.set_base("a", 42) 162 if not self.override_dict: 163 self.fail("OverrideDict should be true when not empty") 164 self.override_dict.remove_override("a") 165 if not self.override_dict: 166 self.fail("OverrideDict should be true when not empty") 167 168 169class TestRegistrationDecorators(common_utils.TestCase): 170 def tearDown(self) -> None: 171 registration.registry._registry.pop("test::test_op", None) 172 173 def test_onnx_symbolic_registers_function(self): 174 self.assertFalse(registration.registry.is_registered_op("test::test_op", 9)) 175 176 @registration.onnx_symbolic("test::test_op", opset=9) 177 def test(g, x): 178 return g.op("test", x) 179 180 self.assertTrue(registration.registry.is_registered_op("test::test_op", 9)) 181 function_group = registration.registry.get_function_group("test::test_op") 182 assert function_group is not None 183 self.assertEqual(function_group.get(9), test) 184 185 def test_onnx_symbolic_registers_function_applied_decorator_when_provided(self): 186 wrapper_called = False 187 188 def decorator(func): 189 def wrapper(*args, **kwargs): 190 nonlocal wrapper_called 191 wrapper_called = True 192 return func(*args, **kwargs) 193 194 return wrapper 195 196 @registration.onnx_symbolic("test::test_op", opset=9, decorate=[decorator]) 197 def test(): 198 return 199 200 function_group = registration.registry.get_function_group("test::test_op") 201 assert function_group is not None 202 registered_function = function_group[9] 203 self.assertFalse(wrapper_called) 204 registered_function() 205 self.assertTrue(wrapper_called) 206 207 def test_onnx_symbolic_raises_warning_when_overriding_function(self): 208 self.assertFalse(registration.registry.is_registered_op("test::test_op", 9)) 209 210 @registration.onnx_symbolic("test::test_op", opset=9) 211 def test1(): 212 return 213 214 with self.assertWarnsRegex( 215 errors.OnnxExporterWarning, 216 "Symbolic function 'test::test_op' already registered", 217 ): 218 219 @registration.onnx_symbolic("test::test_op", opset=9) 220 def test2(): 221 return 222 223 def test_custom_onnx_symbolic_registers_custom_function(self): 224 self.assertFalse(registration.registry.is_registered_op("test::test_op", 9)) 225 226 @registration.custom_onnx_symbolic("test::test_op", opset=9) 227 def test(g, x): 228 return g.op("test", x) 229 230 self.assertTrue(registration.registry.is_registered_op("test::test_op", 9)) 231 function_group = registration.registry.get_function_group("test::test_op") 232 assert function_group is not None 233 self.assertEqual(function_group.get(9), test) 234 235 def test_custom_onnx_symbolic_overrides_existing_function(self): 236 self.assertFalse(registration.registry.is_registered_op("test::test_op", 9)) 237 238 @registration.onnx_symbolic("test::test_op", opset=9) 239 def test_original(): 240 return "original" 241 242 self.assertTrue(registration.registry.is_registered_op("test::test_op", 9)) 243 244 @registration.custom_onnx_symbolic("test::test_op", opset=9) 245 def test_custom(): 246 return "custom" 247 248 function_group = registration.registry.get_function_group("test::test_op") 249 assert function_group is not None 250 self.assertEqual(function_group.get(9), test_custom) 251 252 253if __name__ == "__main__": 254 common_utils.run_tests() 255