1# Owner(s): ["module: onnx"] 2"""Unit tests for the internal registration wrapper module.""" 3 4from __future__ import annotations 5 6import operator 7from typing import TypeVar, Union 8 9import onnxscript # type: ignore[import] 10from onnxscript import BFLOAT16, DOUBLE, FLOAT, FLOAT16 # type: ignore[import] 11from onnxscript.onnx_opset import opset15 as op # type: ignore[import] 12 13import torch 14import torch.fx 15from torch.onnx._internal.fx import diagnostics, onnxfunction_dispatcher, registration 16from torch.testing._internal import common_utils 17 18 19# TODO: this can only be global. https://github.com/microsoft/onnxscript/issues/805 20TCustomFloat = TypeVar("TCustomFloat", bound=Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16]) 21 22 23class TestRegistration(common_utils.TestCase): 24 def setUp(self) -> None: 25 self.registry = torch.onnx.OnnxRegistry() 26 self.custom_domain = onnxscript.values.Opset(domain="custom", version=1) 27 28 def tearDown(self) -> None: 29 internal_name_instance = registration.OpName.from_name_parts( 30 namespace="test", op_name="test_op" 31 ) 32 self.registry._registry.pop(internal_name_instance, None) 33 34 def test_register_custom_op_registers_custom_function(self): 35 self.assertFalse(self.registry.is_registered_op("test", "test_op", "default")) 36 37 @onnxscript.script(self.custom_domain) 38 def custom_add(x, y): 39 return op.Add(x, y) 40 41 self.registry.register_op(custom_add, "test", "test_op", "default") 42 self.assertTrue(self.registry.is_registered_op("test", "test_op", "default")) 43 44 # Test on get_ops 45 function_group = self.registry.get_op_functions("test", "test_op", "default") 46 self.assertIsNotNone(function_group) 47 self.assertEqual({func.onnx_function for func in function_group}, {custom_add}) # type: ignore[arg-type] 48 49 def test_custom_onnx_symbolic_joins_existing_function(self): 50 self.assertFalse(self.registry.is_registered_op("test", "test_op")) 51 52 @onnxscript.script(self.custom_domain) 53 def test_original(x, y): 54 return op.Add(x, y) 55 56 # default has to be specified, as we are not using the registration.OpName 57 internal_name_instance = registration.OpName.from_name_parts( 58 namespace="test", op_name="test_op", overload="default" 59 ) 60 symbolic_fn = registration.ONNXFunction( 61 test_original, op_full_name=internal_name_instance.qualified_name() 62 ) 63 self.registry._register(internal_name_instance, symbolic_fn) 64 self.assertTrue(self.registry.is_registered_op("test", "test_op")) 65 66 @onnxscript.script(self.custom_domain) 67 def test_custom(x, y): 68 return op.Add(x, y) 69 70 self.registry.register_op(test_custom, "test", "test_op") 71 72 function_group = self.registry.get_op_functions("test", "test_op") 73 assert function_group is not None 74 # The order does matter (list) 75 self.assertEqual( 76 [func.onnx_function for func in function_group], 77 [test_original, test_custom], 78 ) 79 80 81@common_utils.instantiate_parametrized_tests 82class TestDispatcher(common_utils.TestCase): 83 def setUp(self): 84 self.registry = torch.onnx.OnnxRegistry() 85 self.diagnostic_context = diagnostics.DiagnosticContext( 86 "torch.onnx.dynamo_export", torch.__version__ 87 ) 88 self.dispatcher = onnxfunction_dispatcher.OnnxFunctionDispatcher( 89 self.registry, self.diagnostic_context 90 ) 91 92 @common_utils.parametrize( 93 "node, expected_name", 94 [ 95 common_utils.subtest( 96 ( 97 torch.fx.Node( 98 graph=torch.fx.Graph(), 99 name="aten::add.Tensor", 100 op="call_function", 101 target=torch.ops.aten.add.Tensor, # type: ignore[attr-defined] 102 args=(torch.tensor(3), torch.tensor(4)), 103 kwargs={}, 104 ), 105 ("aten", "add", "Tensor"), 106 ), 107 name="get_Opoverload_name", 108 ), 109 common_utils.subtest( 110 ( 111 torch.fx.Node( 112 graph=torch.fx.Graph(), 113 name="aten::sym_size", 114 op="call_function", 115 target=torch.ops.aten.sym_size, 116 args=(), 117 kwargs={}, 118 ), 119 ("aten", "sym_size", None), 120 ), 121 name="get_Opoverloadpacket_name", 122 ), 123 common_utils.subtest( 124 ( 125 torch.fx.Node( 126 graph=torch.fx.Graph(), 127 name="builtin_add", 128 op="call_function", 129 target=operator.add, 130 args=(1, 2), 131 kwargs={}, 132 ), 133 ("_operator", "add", None), 134 ), 135 name="get_builtin_op_name", 136 ), 137 ], 138 ) 139 def test_get_aten_name_on_supported_fx_node( 140 self, node: torch.fx.Node, expected_name: str 141 ): 142 expected_name_class = registration.OpName.from_name_parts(*expected_name) 143 self.assertEqual( 144 self.dispatcher._get_aten_name(node, self.diagnostic_context), 145 expected_name_class, 146 ) 147 148 @common_utils.parametrize( 149 "node", 150 [ 151 common_utils.subtest( 152 torch.fx.Node( 153 graph=torch.fx.Graph(), 154 name="aten::add", 155 op="call_function", 156 target=torch.ops.aten.add, 157 args=(), 158 kwargs={}, 159 ), 160 name="unsupported_Opoverloadpacket_name", 161 ), 162 common_utils.subtest( 163 torch.fx.Node( 164 graph=torch.fx.Graph(), 165 name="builtin_add", 166 op="call_function", 167 target=operator.add, 168 args=("A", "B"), 169 kwargs={}, 170 ), 171 name="unsupported_input_dtypes_for_builtin_op", 172 ), 173 common_utils.subtest( 174 torch.fx.Node( 175 graph=torch.fx.Graph(), 176 name="aten::made_up_node", 177 op="call_function", 178 target=lambda: None, 179 args=(), 180 kwargs={}, 181 ), 182 name="unsupported_target_function", 183 ), 184 ], 185 ) 186 def test_get_aten_name_on_unsupported_fx_node(self, node: torch.fx.Node): 187 with self.assertRaises(RuntimeError): 188 self.dispatcher._get_aten_name(node, self.diagnostic_context) 189 190 def test_get_function_overloads_gives_overload_fall_back_default(self): 191 # Test fall back to default op name 192 node_overload = torch.fx.Node( 193 graph=torch.fx.Graph(), 194 name="aten::add.Tensor", 195 op="call_function", 196 target=torch.ops.aten.add.Tensor, # type: ignore[attr-defined] 197 args=(torch.tensor(3), torch.tensor(4)), 198 kwargs={}, 199 ) 200 node_overloadpacket = torch.fx.Node( 201 graph=torch.fx.Graph(), 202 name="aten::add", 203 op="call_function", 204 target=torch.ops.aten.add.Tensor, # type: ignore[attr-defined] 205 args=(), 206 kwargs={}, 207 ) 208 209 self.assertEqual( 210 self.dispatcher.get_function_overloads( 211 node_overload, self.diagnostic_context 212 ), 213 self.dispatcher.get_function_overloads( 214 node_overloadpacket, 215 self.diagnostic_context, 216 ), 217 ) 218 219 # Non-registered op 220 unsupported_op_node = torch.fx.Node( 221 graph=torch.fx.Graph(), 222 name="aten::made_up_node", 223 op="call_function", 224 target=lambda: None, 225 args=(), 226 kwargs={}, 227 ) 228 with self.assertRaises(RuntimeError): 229 self.dispatcher.get_function_overloads( 230 unsupported_op_node, 231 self.diagnostic_context, 232 ) 233 234 @common_utils.parametrize( 235 "node", 236 [ 237 common_utils.subtest( 238 torch.fx.Node( 239 graph=torch.fx.Graph(), 240 name="aten::add.Tensor", 241 op="call_function", 242 target=torch.ops.aten.add.Tensor, # type: ignore[attr-defined] 243 args=(torch.tensor(3.0), torch.tensor(4.0)), 244 kwargs={}, 245 ), 246 name="nearest_match", 247 ), 248 common_utils.subtest( 249 torch.fx.Node( 250 graph=torch.fx.Graph(), 251 name="aten::add.Tensor", 252 op="call_function", 253 target=torch.ops.aten.add.Tensor, # type: ignore[attr-defined] 254 args=(torch.tensor(3.0), torch.tensor(4.0)), 255 kwargs={"alpha": 1}, 256 ), 257 name="perfect_match_with_kwargs", 258 ), 259 ], 260 ) 261 def test_find_the_perfect_or_nearest_match_onnxfunction_gives_custom_ops_precedence( 262 self, node 263 ): 264 custom_domain = onnxscript.values.Opset(domain="custom", version=1) 265 266 @onnxscript.script(custom_domain) 267 def test_custom_op( 268 x: TCustomFloat, y: TCustomFloat, alpha: int = 1 269 ) -> TCustomFloat: 270 return op.Add(x, y) 271 272 @onnxscript.script(custom_domain) 273 def test_default_op( 274 x: TCustomFloat, y: TCustomFloat, alpha: int = 1 275 ) -> TCustomFloat: 276 return op.Add(x, y) 277 278 op_full_name = "test::test_op" 279 280 custom_overloads = [ 281 registration.ONNXFunction( 282 test_custom_op, op_full_name=op_full_name, is_custom=True 283 ) 284 ] 285 function_overloads = [ 286 registration.ONNXFunction(test_default_op, op_full_name=op_full_name) 287 ] + custom_overloads 288 289 symbolic_fn = self.dispatcher._find_the_perfect_or_nearest_match_onnxfunction( 290 node, 291 function_overloads, 292 node.args, 293 node.kwargs, 294 self.diagnostic_context, 295 ) 296 self.assertEqual(symbolic_fn, test_custom_op) 297 298 @common_utils.parametrize( 299 "node", 300 [ 301 common_utils.subtest( 302 torch.fx.Node( 303 graph=torch.fx.Graph(), 304 name="aten::add.Tensor", 305 op="call_function", 306 target=torch.ops.aten.add.Tensor, # type: ignore[attr-defined] 307 args=(torch.tensor(3.0), torch.tensor(4.0)), 308 kwargs={"attr": None}, 309 ), 310 name="perfect_match_with_ignoring_none_attribute", 311 ), 312 common_utils.subtest( 313 torch.fx.Node( 314 graph=torch.fx.Graph(), 315 name="aten::add.Tensor", 316 op="call_function", 317 target=torch.ops.aten.add.Tensor, # type: ignore[attr-defined] 318 args=(torch.tensor(3.0), torch.tensor(4.0)), 319 kwargs={"unrelated": None}, 320 ), 321 name="perfect_match_with_ignoring_unrelated_none_attribute", 322 ), 323 ], 324 ) 325 def test_find_the_perfect_or_nearest_match_onnxfunction_ignores_attribute_with_none( 326 self, node 327 ): 328 custom_domain = onnxscript.values.Opset(domain="custom", version=1) 329 330 @onnxscript.script(custom_domain) 331 def test_op_attribute( 332 x: TCustomFloat, y: TCustomFloat, attr: int 333 ) -> TCustomFloat: 334 return op.Add(x, y) 335 336 @onnxscript.script(custom_domain) 337 def test_op(x: TCustomFloat, y: TCustomFloat) -> TCustomFloat: 338 return op.Add(x, y) 339 340 op_full_name = "test::test_op" 341 342 function_overloads = [ 343 registration.ONNXFunction(test_op_attribute, op_full_name=op_full_name), 344 registration.ONNXFunction(test_op, op_full_name=op_full_name), 345 ] 346 347 symbolic_fn = self.dispatcher._find_the_perfect_or_nearest_match_onnxfunction( 348 node, 349 function_overloads, 350 node.args, 351 node.kwargs, 352 self.diagnostic_context, 353 ) 354 self.assertEqual(symbolic_fn, test_op) 355 356 @common_utils.parametrize( 357 "node", 358 [ 359 common_utils.subtest( 360 torch.fx.Node( 361 graph=torch.fx.Graph(), 362 name="aten::add.Tensor", 363 op="call_function", 364 target=torch.ops.aten.add.Tensor, # type: ignore[attr-defined] 365 args=(torch.tensor(3.0), torch.tensor(4.0)), 366 kwargs={}, 367 ), 368 name="nearest_match", 369 ), 370 common_utils.subtest( 371 torch.fx.Node( 372 graph=torch.fx.Graph(), 373 name="aten::add.Tensor", 374 op="call_function", 375 target=torch.ops.aten.add.Tensor, # type: ignore[attr-defined] 376 args=(torch.tensor(3.0), torch.tensor(4.0)), 377 kwargs={"alpha": 1}, 378 ), 379 name="perfect_match_with_kwargs", 380 ), 381 ], 382 ) 383 def test_find_the_perfect_or_nearest_match_onnxfunction_gives_tie_breaks_to_registered_order( 384 self, node 385 ): 386 custom_domain = onnxscript.values.Opset(domain="custom", version=1) 387 388 @onnxscript.script(custom_domain) 389 def test_second_custom_op( 390 x: TCustomFloat, y: TCustomFloat, alpha: int = 1 391 ) -> TCustomFloat: 392 return op.Add(x, y) 393 394 @onnxscript.script(custom_domain) 395 def test_third_custom_op( 396 x: TCustomFloat, y: TCustomFloat, alpha: int = 1 397 ) -> TCustomFloat: 398 return op.Add(x, y) 399 400 @onnxscript.script(custom_domain) 401 def test_first_custom_op( 402 x: TCustomFloat, y: TCustomFloat, alpha: int = 1 403 ) -> TCustomFloat: 404 return op.Add(x, y) 405 406 op_full_name = "aten::add" 407 408 function_overloads = [ 409 registration.ONNXFunction( 410 test_first_custom_op, op_full_name=op_full_name, is_custom=True 411 ), 412 registration.ONNXFunction( 413 test_second_custom_op, op_full_name=op_full_name, is_custom=True 414 ), 415 registration.ONNXFunction( 416 test_third_custom_op, op_full_name=op_full_name, is_custom=True 417 ), 418 ] 419 420 symbolic_fn = self.dispatcher._find_the_perfect_or_nearest_match_onnxfunction( 421 node, 422 function_overloads, 423 node.args, 424 node.kwargs, 425 self.diagnostic_context, 426 ) 427 self.assertEqual(symbolic_fn, test_third_custom_op) 428 429 430if __name__ == "__main__": 431 common_utils.run_tests() 432