# mypy: allow-untyped-defs
from __future__ import annotations

import inspect
import logging
import operator
import re
from typing import Callable, Sequence

import onnxscript  # type: ignore[import]
from onnxscript.function_libs.torch_lib import (  # type: ignore[import]
    graph_building as onnxscript_graph_building,
)

import torch
import torch.fx
from torch.onnx import _type_utils as jit_type_utils
from torch.onnx._internal.fx import (
    _pass,
    diagnostics,
    onnxfunction_dispatcher,
    type_utils as fx_type_utils,
)
from torch.utils import _pytree


def _fx_node_to_onnx_message_formatter(
    fn: Callable,
    self,
    node: torch.fx.Node,
    *args,
    **kwargs,
) -> str:
    return f"FX Node: {node.op}:{node.target}[name={node.name}]. "


def _fx_graph_to_onnx_message_formatter(
    fn: Callable,
    self,
    fx_graph_module: torch.fx.GraphModule,
    *args,
    **kwargs,
) -> str:
    return f"FX Graph: {fx_graph_module._get_name()}. "


def _location_from_fx_stack_trace(
    node_stack_trace: str,
) -> diagnostics.infra.Location | None:
    """Extract location from FX node stack trace.

    TODO(bowbao): Create fx utils module and move this function there.

    Args:
        node_stack_trace: The stack trace of the FX node. Example:

            File "path/file.py", line 311, in <function>
                <code>
            | File "path/file2.py", line 389, in <function>
                <code>

    Returns:
        location: The location of the FX node.
    """
    if "File" not in node_stack_trace:
        return None

    lines = node_stack_trace.strip().split("\n")
    idx = 0
    while idx < len(lines) and "File" not in lines[idx]:
        idx += 1
    if idx + 1 >= len(lines):
        return None

    pattern = re.compile(r"^File \"(.+)\", line (\d+), in (.+)$")
    matches = pattern.match(lines[idx].strip())
    if matches:
        uri = matches.group(1)
        line_number = int(matches.group(2))
        snippet = lines[idx + 1].strip()
        return diagnostics.infra.Location(uri=uri, line=line_number, snippet=snippet)
    return None


def _retrieve_or_adapt_input_to_graph_set(
    fx_node_arg: fx_type_utils.Argument,
    fx_name_to_onnxscript_value: dict[
        str,
        onnxscript_graph_building.TorchScriptTensor
        | tuple[onnxscript_graph_building.TorchScriptTensor, ...],
    ],
    tracer: onnxscript_graph_building.TorchScriptTracingEvaluator,
):
    """Map FX value to TorchScript value.

    When creating a TorchScript graph from an FX graph, we need a mapping from FX variables
    to TorchScript variables. This function maps the FX variable, fx_node_arg, to torch.jit.Value.
    """

    onnx_tensor = fx_node_arg
    if isinstance(onnx_tensor, torch.fx.Node):
        # 1. fx_node_arg is a torch.fx.Node, which means
        #    fx_node_arg stands for the output of that torch.fx.Node.
        # 2. fx_node_arg (a variable in torch.fx.Graph) is mapped to a
        #    torch.jit.Value, fx_name_to_onnxscript_value[fx_node_arg.name],
        #    in the TorchScript graph.
        return fx_name_to_onnxscript_value[onnx_tensor.name]
    elif isinstance(onnx_tensor, (tuple, list)) and any(
        isinstance(node, torch.fx.Node)
        and fx_type_utils.is_torch_symbolic_type(node.meta.get("val"))
        for node in onnx_tensor
    ):
        # This intends to handle dynamic axes. For example, if the input size of op.Expand
        # is dynamic, each dimension would be a variable (i.e., a sym variable in the PyTorch
        # FX graph; note that a sym variable is mapped to a tensor in the ONNX Script world)
        # calculated by other operators.
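        # For example (hypothetical shapes): a reshape target such as [2, s0, 4], where s0
        # is a torch.SymInt-valued fx.Node, takes this branch; the plain ints and the
        # SymInt are all promoted to (1,)-shaped int64 tensors and concatenated below.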
        sequence_mixed_elements: list[
            onnxscript_graph_building.TorchScriptTensor
            | tuple[onnxscript_graph_building.TorchScriptTensor, ...]
            | list[int]
        ] = []
        # onnx_tensor contains a list of scalars, each of which could be one of
        #   - a tensor with empty shape,
        #   - a tensor with shape (1,),
        #   - a torch.SymInt,
        #   - an int,
        #   - ...
        # They should all be promoted to tensors with shape (1,)
        # in order to call ONNX's Concat.
        for tensor in onnx_tensor:
            # Prepare `tensor` as input of ONNX's Concat.

            if isinstance(
                tensor, torch.fx.Node
            ) and fx_type_utils.is_torch_symbolic_type(tensor.meta.get("val")):
                # In this case, tensor is a torch.SymInt from Dynamo's perspective.
                # It might be mapped to a tensor with shape () or (1,) in ONNX.
                element_value = fx_name_to_onnxscript_value[tensor.name]
                if isinstance(
                    element_value, onnxscript_graph_building.TorchScriptTensor
                ):
                    # All elements in sequence_mixed_elements will be sent to ONNX's Concat
                    # as inputs. Therefore, they are required to have the same rank.
                    # Since tensors with rank=0 (i.e., scalars) cannot be concatenated, all
                    # scalars are promoted to tensors with shape (1,).
                    with onnxscript.evaluator.default_as(tracer):
                        element_value = onnxscript.opset18.Reshape(element_value, [1])  # type: ignore[arg-type, type-var]
                sequence_mixed_elements.append(element_value)
            elif isinstance(tensor, int):
                # NOTE: op.Concat doesn't support scalars, so we need to wrap the int in a
                # list, and onnx-script will promote it to tensor(int64).
                sequence_mixed_elements.append([tensor])
            else:
                raise RuntimeError(
                    f"Unsupported type in sequence_mixed_elements: {type(tensor)}"
                )
        # Concat all the elements in the sequence.
        # Shapes are mapped to tensors in the ONNX graph (TorchScriptGraph),
        # so the list of sym_ints is concatenated to a tensor before calling the ONNX op.

        # For example:
        #   inputs: [[2], [4], fx.Node(SymIntA), [1], fx.Node(SymIntB)]
        #   outputs: op.Concat([op.Constant(2), op.Constant(4), TorchScriptTensor(A), op.Constant(1), TorchScriptTensor(B)])

        # onnx-script auto-wraps Python numbers with op.Constant,
        # so we don't need to specifically process them.
        with onnxscript.evaluator.default_as(tracer):
            output = onnxscript.opset18.Concat(*sequence_mixed_elements, axis=0)  # type: ignore[type-var]
        output.dtype = torch.int64  # type: ignore[union-attr]
        output.shape = [len(sequence_mixed_elements)]  # type: ignore[union-attr]
        return output
    elif isinstance(onnx_tensor, (tuple, list)) and all(
        isinstance(node, torch.fx.Node) or node is None for node in onnx_tensor
    ):
        sequence_elements: list[
            onnxscript_graph_building.TorchScriptTensor
            | None
            | tuple[onnxscript_graph_building.TorchScriptTensor, ...]
        ] = []
        for tensor in onnx_tensor:
            sequence_elements.append(
                fx_name_to_onnxscript_value[tensor.name] if tensor is not None else None
            )
        return sequence_elements
    if isinstance(onnx_tensor, torch.dtype):
        onnx_tensor = int(  # type: ignore[call-overload]
            jit_type_utils.JitScalarType.from_dtype(onnx_tensor).onnx_type()
        )
    # NOTE: if device is specified in kwargs (not consumed), it can safely be ignored. But
    # if it's in args, we need to set it to a string for the dispatcher to match the schema.
    if isinstance(onnx_tensor, torch.device):
        # torch.device is not supported by onnxscript (no op). We turn it into
        # a string.
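        # e.g. torch.device("cpu") -> "cpu"; str() also keeps the device index when it is
        # present, e.g. torch.device("cuda", 0) -> "cuda:0".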
        return str(onnx_tensor)
    # In all other cases, we do nothing.
    return onnx_tensor


def filter_incompatible_and_dtype_convert_kwargs(kwargs):
    """Filter out kwargs that are not supported by onnxscript."""
    filtered = {}
    for key, value in kwargs.items():
        if key in {
            "layout",
            "device",
            "requires_grad",
            "pin_memory",
            "memory_format",
            "implicit",
        }:
            continue
        if key == "dtype":
            if value is None:
                # We omit dtype when it is not provided, because onnxscript handles the
                # default case.
                continue
            else:
                value = int(jit_type_utils.JitScalarType.from_dtype(value).onnx_type())  # type: ignore[call-overload]
        filtered[key] = value
    return filtered


def _fill_tensor_shape_type(
    onnxscript_values: onnxscript_graph_building.TorchScriptTensor
    | tuple[onnxscript_graph_building.TorchScriptTensor, ...],
    name: str,
    expected_values: fx_type_utils.META_VALUE_TYPE
    | list[fx_type_utils.META_VALUE_TYPE]
    | tuple[fx_type_utils.META_VALUE_TYPE | None, ...],
):
    """Fill the meta information of onnxscript_values with that from the fx FakeTensor."""

    if isinstance(expected_values, (list, tuple)) and not isinstance(
        onnxscript_values, (list, tuple)
    ):
        # ex: aten::split - in onnx_dtype: seq(tensor)
        # onnxscript_values is a single tensor, but expected_values is a list of tensors.
        return

    flat_onnxscript_values, _ = _pytree.tree_flatten(onnxscript_values)
    flat_expected_values, _ = _pytree.tree_flatten(expected_values)
    for i, (onnxscript_value, expected_value) in enumerate(
        zip(flat_onnxscript_values, flat_expected_values)
    ):
        if expected_value is None:
            # There is no shape/type from None.
            # NOTE: according to https://github.com/pytorch/pytorch/blob/main/torch/_meta_registrations.py,
            # None could be a valid value for the return type, so we need to handle it.
            # e.g. the function meta__scaled_dot_product_flash() in cpu mode.
            continue
        elif fx_type_utils.is_torch_symbolic_type(expected_value):
            # aten::sym_size output is an int, not a tensor, which stands
            # for the size of one dim. We treat it as a 1-D tensor.
            onnxscript_value.dtype = fx_type_utils.from_sym_value_to_torch_dtype(
                expected_value
            )
            onnxscript_value.shape = torch.Size([1])
        elif isinstance(expected_value, (int, float, bool)):
            onnxscript_value.dtype = fx_type_utils.from_scalar_type_to_torch_dtype(
                type(expected_value)
            )
            onnxscript_value.shape = torch.Size([])
        elif isinstance(expected_value, complex):
            # From complex scalar to real representation
            onnxscript_value_to_torch_dtype = (
                fx_type_utils.from_scalar_type_to_torch_dtype(type(expected_value))
            )
            onnxscript_value.dtype = (
                fx_type_utils.from_complex_to_float(onnxscript_value_to_torch_dtype)
                if onnxscript_value_to_torch_dtype is not None
                else None
            )
            onnxscript_value.shape = torch.Size([2])
        elif fx_type_utils.is_torch_complex_dtype(expected_value.dtype):
            # Like torch.view_as_real, we flatten complex tensors to real tensors with
            # an additional last dimension of 2.
            onnxscript_value.shape = torch.Size((*expected_value.size(), 2))
            # complex64 -> float32, complex128 -> float64, etc.
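            # e.g. a complex64 FakeTensor of shape (3, 4) is described here as a float32
            # tensor of shape (3, 4, 2), matching torch.view_as_real.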
            onnxscript_value.dtype = fx_type_utils.from_complex_to_float(
                expected_value.dtype
            )
            # Dispatcher needs to know the value is complex
            onnxscript_value.is_complex = True
        else:
            # We set node output sizes to be dynamic to continue the model conversion,
            # and inputs are also set to be dynamic in add_input().
            onnxscript_value.shape = expected_value.size()
            onnxscript_value.dtype = expected_value.dtype

        # naming
        if i > 0:
            onnxscript_value.name = f"{name}_{i}"
        else:
            onnxscript_value.name = name


def _fill_in_default_kwargs(
    node: torch.fx.Node,
) -> tuple[list[fx_type_utils.Argument], dict[str, fx_type_utils.Argument]]:
    """Find kwargs that are not provided and fill them in with default values."""

    # TODO: aten::sym_size has an overload, but the fx graph is using
    # the overloadpacket for some reason.
    # https://github.com/pytorch/pytorch/issues/97201
    # We manually assign the overload for aten::sym_size.
    if hasattr(node.target, "_schema"):
        node_schema = node.target._schema  # type: ignore[union-attr]
    else:
        node_schema = torch.ops.aten.sym_size.int._schema  # type: ignore[union-attr]

    # This function assumes the order of arguments in the FX op is the
    # same as the order of arguments in the TorchScript op.
    complete_args: list[fx_type_utils.Argument] = []
    complete_kwargs: dict[str, fx_type_utils.Argument] = {}

    if inspect.isbuiltin(node.target):
        complete_args = list(node.args)
    else:
        for i, expected_arg in enumerate(node_schema.arguments):
            if i < len(node.args):
                complete_args.append(node.args[i])
            elif expected_arg.name in node.kwargs:
                complete_kwargs[expected_arg.name] = node.kwargs[expected_arg.name]
            else:
                # Get default from schema.
                complete_kwargs[expected_arg.name] = expected_arg.default_value

    return complete_args, complete_kwargs


def _wrap_fx_args_as_onnxscript_args(
    complete_args: list[fx_type_utils.Argument],
    complete_kwargs: dict[str, fx_type_utils.Argument],
    fx_name_to_onnxscript_value: dict[
        str,
        onnxscript_graph_building.TorchScriptTensor
        | tuple[onnxscript_graph_building.TorchScriptTensor, ...],
    ],
    tracer: onnxscript_graph_building.TorchScriptTracingEvaluator,
) -> tuple[
    Sequence[
        onnxscript_graph_building.TorchScriptTensor
        | str
        | int
        | float
        | bool
        | list
        | complex
        | None
    ],
    dict[str, fx_type_utils.Argument],
]:
    """Map all FX arguments of a node to arguments in the TorchScript graph."""

    onnxscript_args = tuple(
        _retrieve_or_adapt_input_to_graph_set(arg, fx_name_to_onnxscript_value, tracer)
        for arg in complete_args
    )
    onnxscript_kwargs = filter_incompatible_and_dtype_convert_kwargs(complete_kwargs)

    return onnxscript_args, onnxscript_kwargs


class FxOnnxInterpreter:
    """Stateless class to process FX graph nodes and translate them into their ONNX counterparts.

    All FX nodes described by [FX Graph](https://pytorch.org/docs/stable/fx.html#torch.fx.Graph) are supported.
    Similarly to the [FX Interpreter pattern](https://pytorch.org/docs/stable/fx.html#torch.fx.Interpreter), each FX node
    op must be implemented in its own method in this class.

    Each operator's implementation returns either an `onnxscript.OnnxFunction` or
    `onnxscript.TracedOnnxFunction` instance based on the dispatch algorithm.
    They can also raise a RuntimeError if there are no overloaded functions available for the given FX node.

    TODO: Convert methods to @staticmethod when the diagnostic system supports it
    DO NOT ADD NEW ATTRIBUTES TO THIS CLASS!
    """

    def __init__(
        self,
        diagnostic_context: diagnostics.DiagnosticContext,
    ):
        # THIS SHOULD BE THE ONLY STATE IN THIS CLASS (constraint from diagnostics API)
        # TODO: Diagnostics API should be revised to get rid of this attribute.
        # DO NOT add other class-level attributes.
        self.diagnostic_context = diagnostic_context

    @diagnostics.diagnose_call(
        diagnostics.rules.fx_node_to_onnx,
        diagnostic_message_formatter=_fx_node_to_onnx_message_formatter,
    )
    def run_node(
        self,
        node,
        fx_graph_module: torch.fx.GraphModule,
        onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher,
        onnxscript_graph: onnxscript_graph_building.TorchScriptGraph,
        onnxscript_tracer: onnxscript_graph_building.TorchScriptTracingEvaluator,
        fx_name_to_onnxscript_value: dict[
            str,
            onnxscript_graph_building.TorchScriptTensor
            | tuple[onnxscript_graph_building.TorchScriptTensor, ...],
        ],
    ):
        """Execute a single FX node to produce its ONNX counterpart.

        Args:
            node: The FX node to be translated.
            fx_graph_module: The FX graph module containing the node.
            onnxfunction_dispatcher: The dispatcher to find the best matched ONNX op.
            onnxscript_graph: The ONNX graph to be populated.
            onnxscript_tracer: The tracer to trace the ONNX graph.
            fx_name_to_onnxscript_value: The mapping from FX node name to ONNX Script value.

        Raises:
            RuntimeError: When a node.op is not supported.
        """
        # Record stack trace of node in diagnostic.
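        # node.stack_trace, when present, holds the Python source location that produced
        # this node, e.g. (hypothetical):
        #   File "model.py", line 10, in forward
        #     x = self.linear(x)
        # See _location_from_fx_stack_trace above for the expected format.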
        node_stack_trace = node.stack_trace
        if node_stack_trace:
            diagnostic = self.diagnostic_context.inflight_diagnostic(
                rule=diagnostics.rules.fx_node_to_onnx
            )
            with diagnostic.log_section(logging.INFO, "PyTorch source information"):
                diagnostic.info("```\n%s\n```", node_stack_trace)
            location = _location_from_fx_stack_trace(node_stack_trace)
            if location is not None:
                diagnostic.with_location(location)

        if node.op == "placeholder":
            self.placeholder(node, onnxscript_graph, fx_name_to_onnxscript_value)
        elif node.op == "get_attr":
            self.get_attr(
                node,
                onnxscript_graph,
                fx_name_to_onnxscript_value,
                fx_graph_module,
            )
        elif node.op == "call_function":
            self.call_function(
                node,
                onnxscript_tracer,
                fx_name_to_onnxscript_value,
                onnxfunction_dispatcher,
                fx_graph_module,
            )
        elif node.op == "call_method":
            self.call_method(node)
        elif node.op == "call_module":
            self.call_module(
                node,
                onnxscript_graph,
                fx_name_to_onnxscript_value,
                onnxscript_tracer,
                fx_graph_module,
                onnxfunction_dispatcher,
            )
        elif node.op == "output":
            self.output(node, onnxscript_graph, fx_name_to_onnxscript_value)
        else:
            raise RuntimeError(f"Found node type not defined in torch.fx: {node.op}")

    @diagnostics.diagnose_call(
        diagnostics.rules.fx_graph_to_onnx,
        diagnostic_message_formatter=_fx_graph_to_onnx_message_formatter,
    )
    def run(
        self,
        fx_graph_module: torch.fx.GraphModule,
        onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher,
        parent_onnxscript_graph: onnxscript_graph_building.TorchScriptGraph
        | None = None,
    ) -> onnxscript_graph_building.TorchScriptGraph:
        """Analyze all FX nodes and trigger their ONNX translation.

        Args:
            fx_graph_module: FX graph module to be translated.
            onnxfunction_dispatcher: ONNX function dispatcher.
            parent_onnxscript_graph: The parent TorchScript graph. Must be provided if
                `fx_graph_module` is a submodule. If not provided,
                `fx_graph_module` is assumed to be the root module.
        """
        diagnostic = self.diagnostic_context.inflight_diagnostic()
        with diagnostic.log_section(logging.DEBUG, "FX Graph:"):
            diagnostic.debug(
                "```\n%s\n```",
                diagnostics.LazyString(fx_graph_module.print_readable, False),
            )

        if parent_onnxscript_graph is not None:
            # If parent_onnxscript_graph is provided, we assume fx_graph_module is a
            # submodule representing a forward call of an nn.Module.
            # Compose the package and version where the nn.Module is defined as the
            # domain name for the local function.

            onnx_meta: _pass.GraphModuleOnnxMeta | None = fx_graph_module.meta.get(
                "onnx"
            )
            if onnx_meta is None:
                raise RuntimeError(
                    f"ONNX meta is not found in submodule {fx_graph_module._get_name()}. "
                    f"Only submodules produced by the `Modularize` pass are supported in ONNX export."
                )

            onnx_domain = onnx_meta.package_info.to_onnx_domain_string()
        else:
            # Leave as the default domain name for the root module.
            onnx_domain = None

        onnxscript_graph = onnxscript_graph_building.TorchScriptGraph(
            parent_onnxscript_graph, domain_name=onnx_domain
        )
        onnxscript_tracer = onnxscript_graph_building.TorchScriptTracingEvaluator(
            onnxscript_graph
        )
        # In the following loop, a TorchScript graph is created to
        # represent the input FX graph with ONNX symbols (e.g., onnx::add).
        # To connect the values to nodes in the TorchScript graph, we maintain
        # fx_name_to_onnxscript_value. Basically, we want to translate
        #   fx_tensor_x (type: torch.fx.Node) -> fx_node_1 -> fx_tensor_y (type: torch.fx.Node)
        # to
        #   fx_name_to_onnxscript_value[fx_tensor_x.name] -> onnx_node_1 -> fx_name_to_onnxscript_value[fx_tensor_y.name]
        fx_name_to_onnxscript_value: dict[
            str,
            onnxscript_graph_building.TorchScriptTensor
            | tuple[onnxscript_graph_building.TorchScriptTensor, ...],
        ] = {}

        # TODO: Fix FakeTensorMode limitation asap
        # We want to pass lists of ints and floats to the TorchScript graph correctly
        # in _export_fx_to_ts, so we must disable FakeTensorMode. Otherwise, the graph may
        # receive FakeTensor and result in a runtime error. In addition, the TorchScript-based
        # ONNX exporter used in _ts_graph_to_onnx_model_in_protobuf is not compatible
        # with FakeTensorMode.
        with torch.utils._mode_utils.no_dispatch():
            for node in fx_graph_module.graph.nodes:
                self.run_node(
                    node,
                    fx_graph_module,
                    onnxfunction_dispatcher,
                    onnxscript_graph,
                    onnxscript_tracer,
                    fx_name_to_onnxscript_value,
                )

        with diagnostic.log_section(logging.DEBUG, "ONNX Graph:"):
            diagnostic.debug("```\n%s\n```", onnxscript_graph.torch_graph)  # type: ignore[attr-defined]

        return onnxscript_graph

    def placeholder(
        self,
        node: torch.fx.Node,
        onnxscript_graph: onnxscript_graph_building.TorchScriptGraph,
        fx_name_to_onnxscript_value: dict[
            str,
            onnxscript_graph_building.TorchScriptTensor
            | tuple[onnxscript_graph_building.TorchScriptTensor, ...],
        ],
    ):
        # Input of graph.
        # The node.meta["val"] is generated by FakeTensorProp.
        # NOTE: add_input() intends to create nodes with shape/type
        fake_tensor = node.meta.get("val", None)
        # NOTE: During tracing, when inputs are constants, they are represented
        # by nodes with node.meta['val'] being None (nn.Module to dynamo_export)
        # or nodes with node.meta['val'] being a builtin value (ExportedProgram to dynamo_export).
        # Nonetheless, the nodes are not consumed by others, so we don't need to
        # create a TorchScriptTensor for them.
        if fake_tensor is None or isinstance(fake_tensor, (int, float, bool, str)):
            output = onnxscript_graph.add_input(
                input_name=None,
            )
        elif isinstance(fake_tensor, torch.Tensor):
            # NOTE: ONNX doesn't support tensors of complex64/complex128, so we
            # convert them to float32/float64 with a real representation.
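            # e.g. a complex64 input of shape (B, N) is registered as a float32 input of
            # shape (B, N, 2) via torch.view_as_real below.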
            if fx_type_utils.is_torch_complex_dtype(fake_tensor.dtype):
                fake_tensor = torch.view_as_real(fake_tensor.resolve_conj())
            output = onnxscript_graph.add_input(
                input_name=node.name,
                shape=fake_tensor.shape,
                dtype=fake_tensor.dtype,
            )

        elif fx_type_utils.is_torch_symbolic_type(fake_tensor):
            output = onnxscript_graph.add_input(
                input_name=node.name,
                shape=torch.Size([]),
                dtype=fx_type_utils.from_sym_value_to_torch_dtype(fake_tensor),
            )
        else:
            raise RuntimeError(
                f"Unsupported type(node.meta['val']) for placeholder: {type(fake_tensor)}"
            )
        assert (
            output is not None
        ), f"Node creates None with target={node.target} and name={node.name}"

        assert isinstance(output, onnxscript_graph_building.TorchScriptTensor)
        assert isinstance(output, onnxscript.tensor.Tensor)

        fx_name_to_onnxscript_value[node.name] = output

    def call_function(
        self,
        node: torch.fx.Node,
        onnxscript_tracer: onnxscript_graph_building.TorchScriptTracingEvaluator,
        fx_name_to_onnxscript_value: dict[
            str,
            onnxscript_graph_building.TorchScriptTensor
            | tuple[onnxscript_graph_building.TorchScriptTensor, ...],
        ],
        onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher,
        fx_graph_module: torch.fx.GraphModule,
    ):
        # aten ops and other stateless functions.
        if node.target == operator.getitem and isinstance(
            fx_name_to_onnxscript_value[node.args[0].name],  # type: ignore[union-attr,index]
            tuple,
        ):
            onnx_tensor_tuple = fx_name_to_onnxscript_value[node.args[0].name]  # type: ignore[union-attr,index]
            index = node.args[1]
            value = onnx_tensor_tuple[index]  # type: ignore[index]
            assert (
                value is not None
            ), f"Node creates None with target={node.target} and name={node.name}"
            assert isinstance(
                value, (onnxscript_graph_building.TorchScriptTensor, tuple)
            ), type(value)

            fx_name_to_onnxscript_value[node.name] = value
            return

        # Map FX inputs to ONNX inputs and fill optional inputs with default values.
        # torch_args and torch_kwargs are for op-level validation
        fx_args, fx_kwargs = _fill_in_default_kwargs(node)

        onnx_args, onnx_kwargs = _wrap_fx_args_as_onnxscript_args(
            fx_args,
            fx_kwargs,
            fx_name_to_onnxscript_value,
            onnxscript_tracer,
        )
        # Dispatch to the ONNX op through OpSchema. The input argument dtypes are compared
        # to the function signature in OpSchema to find the best matched overload.
        symbolic_fn = onnxfunction_dispatcher.dispatch(
            node=node,
            onnx_args=onnx_args,  # type: ignore[arg-type]
            onnx_kwargs=onnx_kwargs,
            diagnostic_context=self.diagnostic_context,
        )
        with onnxscript.evaluator.default_as(onnxscript_tracer):
            output: (
                onnxscript_graph_building.TorchScriptTensor
                | tuple[onnxscript_graph_building.TorchScriptTensor, ...]
            ) = symbolic_fn(*onnx_args, **onnx_kwargs)
        assert (
            output is not None
        ), f"Node creates None with target={node.target}, name={node.name}, args={onnx_args}, kwargs={onnx_kwargs}"
        # Assign type and shape from the fx graph.
        _fill_tensor_shape_type(output, node.name, node.meta["val"])
        # One fx node could produce multiple outputs (e.g., a tuple of tensors); in
        # that case, `output` is a tuple of TorchScriptTensors.
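        # e.g. aten.split returns a sequence of tensors, so its ONNX translation yields a
        # tuple of TorchScriptTensors here.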
        assert isinstance(
            output, (onnxscript_graph_building.TorchScriptTensor, tuple)
        ), type(output)
        fx_name_to_onnxscript_value[node.name] = output

    def output(
        self,
        node: torch.fx.Node,
        onnxscript_graph: onnxscript_graph_building.TorchScriptGraph,
        fx_name_to_onnxscript_value: dict[
            str,
            onnxscript_graph_building.TorchScriptTensor
            | tuple[onnxscript_graph_building.TorchScriptTensor, ...],
        ],
    ):
        if isinstance(node.args[0], torch.fx.Node):
            onnx_tensor_or_tensor_tuple = fx_name_to_onnxscript_value[node.args[0].name]
            onnxscript_graph.register_outputs(onnx_tensor_or_tensor_tuple)
        else:
            # ONNX can't represent collection types (e.g., dictionary, tuple of tuples of
            # tensors, etc.), so we flatten the collection and register each element as an output.
            flat_args, _ = _pytree.tree_flatten(node.args[0])
            for arg in flat_args:
                assert isinstance(
                    arg, torch.fx.Node
                ), f"arg must be a torch.fx.Node, not {type(arg)}"
                onnx_tensor_or_tensor_tuple = fx_name_to_onnxscript_value[arg.name]
                onnxscript_graph.register_outputs(onnx_tensor_or_tensor_tuple)

    def call_method(self, node: torch.fx.Node):
        # TODO(wechi): Support call_method.
        raise RuntimeError("call_method is not supported yet.")

    def call_module(
        self,
        node: torch.fx.Node,
        parent_onnxscript_graph: onnxscript_graph_building.TorchScriptGraph,
        fx_name_to_onnxscript_value: dict[
            str,
            onnxscript_graph_building.TorchScriptTensor
            | tuple[onnxscript_graph_building.TorchScriptTensor, ...],
        ],
        tracer: onnxscript_graph_building.TorchScriptTracingEvaluator,
        root_fx_graph_module: torch.fx.GraphModule,
        onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher,
    ) -> None:
        """Export an fx.GraphModule submodule to an ONNXScript graph.

        The export process specifically targets `call_module` nodes that are created by
        the exporter's `Modularize` pass. Each `call_module` node has an associated fx.GraphModule,
        identified by `node.target`, underneath the root fx.GraphModule. These `call_module` nodes are exported as ONNX
        function nodes. The related `sub_module` is then exported as an ONNX model local function,
        which is represented by another `TorchScriptGraph`. This `TorchScriptGraph` sets the current
        `onnxscript_graph` as its parent.

        Args:
            node: The call_module node in the FX graph that represents the submodule call.
            parent_onnxscript_graph: The parent ONNXScript graph to which the ONNX function and
                function node belong.
            fx_name_to_onnxscript_value: The mapping from FX node name to ONNXScript value.
            tracer: The tracer used to trace the ONNXScript graph.
            root_fx_graph_module: The root FX module.
            onnxfunction_dispatcher: The dispatcher.
        """
        assert isinstance(
            node.target, str
        ), f"node.target must be a str, not {type(node.target)} for node {node}."

        sub_module = root_fx_graph_module.get_submodule(node.target)

        assert isinstance(
            sub_module, torch.fx.GraphModule
        ), f"sub_module must be a torch.fx.GraphModule, not {type(sub_module)} for node {node}."

        sub_onnxscript_graph = self.run(
            sub_module, onnxfunction_dispatcher, parent_onnxscript_graph
        )

        onnx_args, _ = _wrap_fx_args_as_onnxscript_args(
            list(node.args), {}, fx_name_to_onnxscript_value, tracer
        )

        # TODO: We may want to consider other naming styles.
        # The goal is to be stable and unique such that it can be easily identified
        # in case of kernel substitution.
        # An example of the current style is the combination of the qualified module class
        # name and the module attribute name: `torch_nn_modules_conv_Conv2d_conv1`.
        # Other naming styles, such as the qualified module class name made unique, can
        # also be considered.
        unique_module_name = f"{sub_module._get_name()}_{node.target}"

        outputs: (
            onnxscript_graph_building.TorchScriptTensor
            | tuple[onnxscript_graph_building.TorchScriptTensor, ...]
        ) = parent_onnxscript_graph.add_module_call(  # type: ignore[assignment]
            unique_module_name, sub_onnxscript_graph, onnx_args
        )

        assert isinstance(
            outputs, (onnxscript_graph_building.TorchScriptTensor, tuple)
        ), f"Unexpected outputs type {type(outputs)} for node {node}."

        _fill_tensor_shape_type(outputs, node.name, node.meta["val"])
        fx_name_to_onnxscript_value[node.name] = outputs

        # Skip op_level_validation for call_module. Subgraph nodes are validated individually.

    def get_attr(
        self,
        node: torch.fx.Node,
        onnxscript_graph: onnxscript_graph_building.TorchScriptGraph,
        fx_name_to_onnxscript_value: dict[
            str,
            onnxscript_graph_building.TorchScriptTensor
            | tuple[onnxscript_graph_building.TorchScriptTensor, ...],
        ],
        fx_graph_module: torch.fx.GraphModule,
    ):
        assert isinstance(node.target, str), f"node.target {node.target} is not a str."
        attr_tensor = getattr(fx_graph_module, node.target)
        assert isinstance(attr_tensor, torch.Tensor), f"{attr_tensor} is not a tensor."

        # Parameter/buffer name cannot contain "."
        # Revert from "/" to restore namespace formatting.
        input_ = onnxscript_graph.add_initializer(
            name=node.target.replace("/", "."),
            value=attr_tensor,
        )

        assert isinstance(input_, onnxscript_graph_building.TorchScriptTensor)
        assert isinstance(input_, onnxscript.tensor.Tensor)
        fx_name_to_onnxscript_value[node.name] = input_
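
# Example usage (a minimal sketch; in practice the dynamo-based exporter constructs and
# drives this interpreter internally, and the variable names below are illustrative only):
#
#   interpreter = FxOnnxInterpreter(diagnostic_context=diagnostic_context)
#   onnxscript_graph = interpreter.run(fx_graph_module, onnxfunction_dispatcher)
#   # `onnxscript_graph` is then serialized into an ONNX model by the exporter.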