# mypy: allow-untyped-defs
from __future__ import annotations


__all__ = [
    "DiagnosticOptions",
    "ExportOptions",
    "ONNXProgram",
    "ONNXRuntimeOptions",
    "InvalidExportOptionsError",
    "OnnxRegistry",
    "UnsatisfiedDependencyError",
    "dynamo_export",
    "enable_fake_mode",
]


import abc
import contextlib
import dataclasses
import logging
import os
import tempfile
import warnings
from collections import defaultdict
from typing import Any, Callable, Final, Mapping, Sequence, TYPE_CHECKING, TypeVar
from typing_extensions import Self

import torch
import torch._ops
import torch.utils._pytree as pytree
from torch.onnx import errors
from torch.onnx._internal import io_adapter
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.fx import (
    decomposition_table,
    patcher as patcher,
    registration,
    serialization as fx_serialization,
)


# We can only import onnx from this module in a type-checking context to ensure that
# 'import torch.onnx' continues to work without having 'onnx' installed. We fully
# 'import onnx' inside of dynamo_export (by way of _assert_dependencies).
if TYPE_CHECKING:
    import io

    import onnx

    import onnxruntime
    import onnxscript

    from torch._subclasses import fake_tensor
    from torch.onnx._internal.fx import diagnostics

_DEFAULT_OPSET_VERSION: Final[int] = 18
"""The default ONNX opset version the exporter will use if one is not specified explicitly
through :class:`ExportOptions`. This should NEVER be accessed outside of this module! Users
should reference :attr:`OnnxRegistry.opset_version`."""

_PYTORCH_GITHUB_ISSUES_URL = "https://github.com/pytorch/pytorch/issues"
"""The URL to the PyTorch GitHub issues page."""

_DEFAULT_FAILED_EXPORT_SARIF_LOG_PATH = "report_dynamo_export.sarif"
"""The default path to write the SARIF log to if the export fails."""

_PROTOBUF_SIZE_MAX_LIMIT = 2 * 1024 * 1024 * 1024
"""The maximum size of a Protobuf file in bytes. This is used to determine whether to
serialize the model with external data or not."""

log = logging.getLogger(__name__)


DiagnosticOptions = infra.DiagnosticOptions


@dataclasses.dataclass
class ONNXFakeContext:
    """A dataclass used to store context for model export using FakeTensor.

    This dataclass stores the FakeTensorMode instance used to convert
    real tensors and model parameters into fake tensors. This :attr:`ONNXFakeContext.fake_mode` is
    reused internally during tracing of a :class:`torch.nn.Module` into a FX :class:`GraphModule`.
    """

    fake_mode: fake_tensor.FakeTensorMode
    """The fake tensor mode used for tracing the model using fake tensors and parameters."""

    state_dict_paths: tuple[str | io.BytesIO | dict[str, Any]] | None = None
    """List of paths of files that contain the model :meth:`state_dict`."""


class OnnxRegistry:
    """Registry for ONNX functions.

    The registry maintains a mapping from qualified names to symbolic functions under a
    fixed opset version. It supports registering custom onnx-script functions, and the
    dispatcher uses it to dispatch calls to the appropriate function.

    """

    def __init__(self) -> None:
        """Initializes the registry"""

        # NOTE: _registry maps an OpName to a list of ONNXFunctions. It is important
        # not to directly modify this variable. Instead, access it through the
        # public methods: register_op, get_op_functions, and is_registered_op.
        self._registry: dict[registration.OpName, list[registration.ONNXFunction]] = (
            defaultdict(list)
        )

        # opset_version is unused for now, since torchlib only supports opset18.
        # TODO: get opset version from torchlib
        self._opset_version = _DEFAULT_OPSET_VERSION
        warnings.warn(
            f"torch.onnx.dynamo_export only implements opset version {self._opset_version} for now. If you need to use a "
            "different opset version, please register them with register_op."
        )

        self._initiate_registry_from_torchlib()

    @property
    def opset_version(self) -> int:
        """The ONNX opset version the exporter should target. Defaults to the latest
        supported ONNX opset version: 18. The default version will increment over time as
        ONNX continues to evolve."""

        return self._opset_version

    def _initiate_registry_from_torchlib(self) -> None:
        """Populates the registry with ATen functions from torchlib."""
        import onnxscript._framework_apis.torch_2_5 as onnxscript_apis

        for meta in onnxscript_apis.get_torchlib_ops():
            internal_name_instance = registration.OpName.from_qualified_name(
                meta.qualified_name
            )
            symbolic_function = registration.ONNXFunction(
                onnx_function=meta.function,  # type: ignore[arg-type]
                op_full_name=internal_name_instance.qualified_name(),
                is_custom=False,
                is_complex=meta.is_complex,
            )
            self._register(internal_name_instance, symbolic_function)

    def _register(
        self,
        internal_qualified_name: registration.OpName,
        symbolic_function: registration.ONNXFunction,
    ) -> None:
        """Registers an ONNXFunction to an operator.

        Args:
            internal_qualified_name: The qualified name of the operator to register: OpName.
            symbolic_function: The ONNXFunction to register.
        """
        self._registry[internal_qualified_name].append(symbolic_function)

    def register_op(
        self,
        function: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction,
        namespace: str,
        op_name: str,
        overload: str | None = None,
        is_complex: bool = False,
    ) -> None:
        """Registers a custom operator: torch.ops.<namespace>.<op_name>.<overload>.

        Args:
            function: The onnx-script function to register.
            namespace: The namespace of the operator to register.
            op_name: The name of the operator to register.
            overload: The overload of the operator to register. If it's default overload,
                leave it to None.
            is_complex: Whether the function is a function that handles complex valued inputs.

        Raises:
            ValueError: If the name is not in the form of 'namespace::op'.
        """
        internal_name_instance = registration.OpName.from_name_parts(
            namespace=namespace, op_name=op_name, overload=overload
        )
        symbolic_function = registration.ONNXFunction(
            onnx_function=function,
            op_full_name=internal_name_instance.qualified_name(),
            is_custom=True,
            is_complex=is_complex,
        )
        self._register(internal_name_instance, symbolic_function)
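
    # Example (illustrative sketch, not part of the registry API): registering a
    # custom onnx-script implementation for ``torch.ops.aten.add.Tensor``. The
    # ``onnxscript.script``/``Opset`` usage below follows the onnxscript API and
    # may need adjusting across onnxscript versions.
    #
    #     import torch
    #     import onnxscript
    #     from onnxscript import opset18 as op
    #
    #     custom_opset = onnxscript.values.Opset(domain="com.example", version=1)
    #
    #     @onnxscript.script(custom_opset)
    #     def custom_aten_add(input_x, input_y, alpha: float = 1.0):
    #         alpha = op.CastLike(alpha, input_y)
    #         input_y = op.Mul(input_y, alpha)
    #         return op.Add(input_x, input_y)
    #
    #     registry = torch.onnx.OnnxRegistry()
    #     registry.register_op(
    #         function=custom_aten_add, namespace="aten", op_name="add", overload="Tensor"
    #     )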

    def get_op_functions(
        self, namespace: str, op_name: str, overload: str | None = None
    ) -> list[registration.ONNXFunction] | None:
        """Returns a list of ONNXFunctions for the given op: torch.ops.<namespace>.<op_name>.<overload>.

        The list is ordered by the time of registration. The custom operators should be
        in the second half of the list.

        Args:
            namespace: The namespace of the operator to get.
            op_name: The name of the operator to get.
            overload: The overload of the operator to get. If it's default overload,
                leave it to None.

        Returns:
            A list of ONNXFunctions corresponding to the given name, or None if
            the name is not in the registry.
        """
        internal_name_instance = registration.OpName.from_name_parts(
            namespace=namespace, op_name=op_name, overload=overload
        )
        return self._registry.get(internal_name_instance)

    def is_registered_op(
        self, namespace: str, op_name: str, overload: str | None = None
    ) -> bool:
        """Returns whether the given op is registered: torch.ops.<namespace>.<op_name>.<overload>.

        Args:
            namespace: The namespace of the operator to check.
            op_name: The name of the operator to check.
            overload: The overload of the operator to check. If it's default overload,
                leave it to None.

        Returns:
            True if the given op is registered, otherwise False.
        """
        functions = self.get_op_functions(
            namespace=namespace, op_name=op_name, overload=overload
        )
        return functions is not None

    def _all_registered_ops(self) -> set[str]:
        """Returns the set of all registered function names."""
        return {
            op_name_class.qualified_name() for op_name_class in self._registry.keys()
        }
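

# Example (illustrative, assuming onnxscript/torchlib are installed): querying a
# default registry. ``aten::add.Tensor`` is expected to be provided by torchlib,
# so the check below should normally return True; ``op_full_name`` is the field
# set on ``registration.ONNXFunction`` above.
#
#     registry = torch.onnx.OnnxRegistry()
#     print(registry.opset_version)  # 18
#     print(registry.is_registered_op(namespace="aten", op_name="add", overload="Tensor"))
#     functions = registry.get_op_functions("aten", "add", overload="Tensor")
#     print([f.op_full_name for f in functions or []])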


class ExportOptions:
    """Options to influence the TorchDynamo ONNX exporter.

    Attributes:
        dynamic_shapes: Shape information hint for input/output tensors.
            When ``None``, the exporter determines the most compatible setting.
            When ``True``, all input shapes are considered dynamic.
            When ``False``, all input shapes are considered static.
        diagnostic_options: The diagnostic options for the exporter.
        fake_context: The fake context used for symbolic tracing.
        onnx_registry: The ONNX registry used to register ATen operators to ONNX functions.
    """

    dynamic_shapes: bool | None = None
    """Shape information hint for input/output tensors.

    - ``None``: the exporter determines the most compatible setting.
    - ``True``: all input shapes are considered dynamic.
    - ``False``: all input shapes are considered static.
    """

    diagnostic_options: DiagnosticOptions
    """The diagnostic options for the exporter."""

    fake_context: ONNXFakeContext | None = None
    """The fake context used for symbolic tracing."""

    onnx_registry: OnnxRegistry | None = None
    """The ONNX registry used to register ATen operators to ONNX functions."""

    def __init__(
        self,
        *,
        dynamic_shapes: bool | None = None,
        fake_context: ONNXFakeContext | None = None,
        onnx_registry: OnnxRegistry | None = None,
        diagnostic_options: DiagnosticOptions | None = None,
    ):
        self.dynamic_shapes = dynamic_shapes
        self.fake_context = fake_context
        self.onnx_registry = onnx_registry
        self.diagnostic_options = diagnostic_options or DiagnosticOptions()


class ResolvedExportOptions(ExportOptions):
    """Consolidates :class:`ExportOptions` with default values.
    All unspecified options from :class:`ExportOptions` are assigned a default value.
    This is an internal class and its API may be changed at any time without notice.
    """

    # Public attributes MUST be redefined below without ``Optional[]`` from ``ExportOptions``
    dynamic_shapes: bool
    diagnostic_options: DiagnosticOptions
    fake_context: ONNXFakeContext
    onnx_registry: OnnxRegistry

    # Private only attributes
    decomposition_table: dict[torch._ops.OpOverload, Callable]
    """A dictionary that maps operators to their decomposition functions."""

    onnxfunction_dispatcher: (
        torch.onnx._internal.fx.onnxfunction_dispatcher.OnnxFunctionDispatcher
    )
    """The ONNX dispatcher used to dispatch ATen operators to ONNX functions."""

    fx_tracer: FXGraphExtractor
    """The FXGraphExtractor instance used to extract the FX graph from the model."""

    diagnostic_context: diagnostics.DiagnosticContext
    """The diagnostics context for the export. Responsible for recording diagnostics,
    logging diagnostics, and generating the SARIF log."""

    def __init__(
        self,
        options: ExportOptions | ResolvedExportOptions,
        model: torch.nn.Module | Callable | None = None,  # type: ignore[name-defined]
    ):
        from torch.onnx._internal.fx import (  # TODO: Prevent circular dep
            diagnostics,
            dynamo_graph_extractor,
        )

        if isinstance(options, ResolvedExportOptions):
            self.dynamic_shapes = options.dynamic_shapes
            self.diagnostic_options = options.diagnostic_options
            self.fake_context = options.fake_context
            self.fx_tracer = options.fx_tracer
            self.onnx_registry = options.onnx_registry
            self.onnxfunction_dispatcher = options.onnxfunction_dispatcher
            self.decomposition_table = options.decomposition_table
            self.diagnostic_context = options.diagnostic_context
        else:
            T = TypeVar("T")

            def resolve(value: T | None, fallback: T | Callable[[], T]) -> T:
                if value is not None:
                    return value
                if callable(fallback):
                    return fallback()
                return fallback

            self.dynamic_shapes = resolve(options.dynamic_shapes, False)

            self.diagnostic_options = resolve(
                options.diagnostic_options, DiagnosticOptions()
            )

            self.fx_tracer = dynamo_graph_extractor.DynamoExport()

            self.fake_context = resolve(options.fake_context, None)  # type: ignore[arg-type]
            self.diagnostic_context = diagnostics.DiagnosticContext(
                "torch.onnx.dynamo_export",
                torch.__version__,
                self.diagnostic_options,
            )

            self.onnx_registry = resolve(options.onnx_registry, OnnxRegistry())
            self.decomposition_table = (
                decomposition_table.create_onnx_friendly_decomposition_table(  # type: ignore[assignment]
                    self.onnx_registry
                )
            )

            from torch.onnx._internal.fx import onnxfunction_dispatcher

            self.onnxfunction_dispatcher = (
                onnxfunction_dispatcher.OnnxFunctionDispatcher(
                    self.onnx_registry,
                    self.diagnostic_context,
                )
            )

            for key in dir(options):
                if not key.startswith("_"):  # skip private attributes
                    assert hasattr(self, key), f"Unresolved option '{key}'"


@contextlib.contextmanager
def enable_fake_mode():
    """Enable fake mode for the duration of the context.

    Internally it instantiates a :class:`torch._subclasses.fake_tensor.FakeTensorMode` context manager
    that converts user input and model parameters into :class:`torch._subclasses.fake_tensor.FakeTensor`.

    A :class:`torch._subclasses.fake_tensor.FakeTensor`
    is a :class:`torch.Tensor` with the ability to run PyTorch code without having to
    actually do computation through tensors allocated on a ``meta`` device. Because
    there is no actual data being allocated on the device, this API allows for
    exporting large models without the actual memory footprint needed to execute them.

    It is highly recommended to enable fake mode when exporting models that
    are too large to fit into memory.

    Returns:
        A :class:`ONNXFakeContext` object.

    Example::

        # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
        >>> import torch
        >>> import torch.onnx
        >>> class MyModel(torch.nn.Module):  # Dummy model
        ...     def __init__(self) -> None:
        ...         super().__init__()
        ...         self.linear = torch.nn.Linear(2, 2)
        ...     def forward(self, x):
        ...         out = self.linear(x)
        ...         return out
        >>> with torch.onnx.enable_fake_mode() as fake_context:
        ...     my_nn_module = MyModel()
        ...     arg1 = torch.randn(2, 2, 2)  # positional input 1
        >>> export_options = torch.onnx.ExportOptions(fake_context=fake_context)
        >>> onnx_program = torch.onnx.export(my_nn_module, (arg1,), dynamo=True)
        >>> onnx_program.apply_weights(MyModel().state_dict())
        >>> # Saving model WITHOUT initializers
        >>> onnx_program.save(
        ...     "my_model_without_initializers.onnx",
        ...     include_initializers=False,
        ...     keep_initializers_as_inputs=True,
        ... )
        >>> # Saving model WITH initializers
        >>> onnx_program.save("my_model_with_initializers.onnx")

    .. warning::
        This API is experimental and is *NOT* backward-compatible.

    """
    from torch._subclasses import fake_tensor
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    # This overrides the internal `FakeTensorMode` instance created by `torch._dynamo.export`[1].
    # It is a good idea to keep them in sync (constructor args) to maintain the same default behavior.
    # [1] `torch/_dynamo/output_graph.py::InstructionTranslator::OutputGraph.__init__`
    # Mixed fake/real tensors are only allowed when `torch.onnx.dynamo_export` is not called within `FakeTensorMode`.
    # This is needed because models can create new parameters during the `forward(self, *args, **kwargs)` run.
    fake_mode = fake_tensor.FakeTensorMode(
        allow_non_fake_inputs=not torch._guards.detect_fake_mode(),
        shape_env=ShapeEnv(
            allow_scalar_outputs=False, allow_dynamic_output_shape_ops=False
        ),
    )
    # The patcher is needed for when the user calls `fake_model.load_state_dict(...)` within fake mode.
    patcher_context = patcher.ONNXTorchPatcher()
    fake_context = ONNXFakeContext(fake_mode=fake_mode)
    with fake_mode, patcher_context:
        yield fake_context
    fake_context.state_dict_paths = tuple(
        patcher_context.paths,
    )  # type: ignore[assignment]


class ONNXRuntimeOptions:
    """Options to influence the execution of the ONNX model through ONNX Runtime.

    Attributes:
        session_options: ONNX Runtime session options.
        execution_providers: ONNX Runtime execution providers to use during model execution.
        execution_provider_options: ONNX Runtime execution provider options.
    """

    session_options: Sequence[onnxruntime.SessionOptions] | None = None
    """ONNX Runtime session options."""

    execution_providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None
    """ONNX Runtime execution providers to use during model execution."""

    execution_provider_options: Sequence[dict[Any, Any]] | None = None
    """ONNX Runtime execution provider options."""

    def __init__(
        self,
        *,
        session_options: Sequence[onnxruntime.SessionOptions] | None = None,
        execution_providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None,
        execution_provider_options: Sequence[dict[Any, Any]] | None = None,
    ):
        self.session_options = session_options
        self.execution_providers = execution_providers
        self.execution_provider_options = execution_provider_options
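

# Example (illustrative): pinning execution to the CPU provider when invoking an
# exported model through ``ONNXProgram.__call__`` (defined below). ``onnx_program``
# and ``x`` are assumed to come from an earlier ``torch.onnx.dynamo_export`` call;
# onnxruntime must be installed.
#
#     options = torch.onnx.ONNXRuntimeOptions(
#         execution_providers=["CPUExecutionProvider"]
#     )
#     outputs = onnx_program(x, options=options)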


class ONNXProgram:
    """An in-memory representation of a PyTorch model that has been exported to ONNX.

    Args:
        model_proto: The exported ONNX model as an :py:obj:`onnx.ModelProto`.
        input_adapter: The input adapter used to convert PyTorch inputs into ONNX inputs.
        output_adapter: The output adapter used to convert PyTorch outputs into ONNX outputs.
        diagnostic_context: Context object for the SARIF diagnostic system responsible for logging errors and metadata.
        fake_context: The fake context used for symbolic tracing.
        export_exception: The exception that occurred during export, if any.
    """

    _model_proto: Final[onnx.ModelProto]  # type: ignore[name-defined, misc]
    _input_adapter: Final[io_adapter.InputAdapter]  # type: ignore[misc]
    _output_adapter: Final[io_adapter.OutputAdapter]  # type: ignore[misc]
    _diagnostic_context: Final[diagnostics.DiagnosticContext]  # type: ignore[misc]
    _fake_context: Final[ONNXFakeContext | None]  # type: ignore[misc]
    _export_exception: Final[Exception | None]  # type: ignore[misc]
    _model_torch: Final[  # type: ignore[misc]
        torch.nn.Module | Callable | None
    ]

    def __init__(
        self,
        model_proto: onnx.ModelProto,  # type: ignore[name-defined]
        input_adapter: io_adapter.InputAdapter,
        output_adapter: io_adapter.OutputAdapter,
        diagnostic_context: diagnostics.DiagnosticContext,
        *,
        fake_context: ONNXFakeContext | None = None,
        export_exception: Exception | None = None,
        model_torch: torch.nn.Module | Callable | None = None,
    ):
        self._model_proto = model_proto
        self._model_torch = model_torch
        self._input_adapter = input_adapter
        self._output_adapter = output_adapter
        self._diagnostic_context = diagnostic_context
        self._fake_context = fake_context
        self._export_exception = export_exception
        self._state_dict: dict[str, torch.Tensor] = {}

    def __call__(
        self,
        *args: Any,
        model_with_state_dict: torch.nn.Module | Callable | None = None,
        options: ONNXRuntimeOptions | None = None,
        **kwargs: Any,
    ) -> Any:
        """Runs the ONNX model using ONNX Runtime.

        Args:
            args: The positional inputs to the model.
            kwargs: The keyword inputs to the model.
            model_with_state_dict: The PyTorch model to fetch state from.
                Required when :func:`enable_fake_mode` is used to extract real initializers as needed by the ONNX graph.
            options: The options to use for running the model with ONNX Runtime.

        Returns:
            The model output as computed by ONNX Runtime.
        """

        # TODO: If ONNX used absolute paths on the initializers external data files,
        # users could call ONNXProgram.save and use ONNXProgram.__call__ without the internal save below.
        with contextlib.ExitStack() as stack:
            # model specified by the user has precedence, when specified
            model_with_state_dict = model_with_state_dict or self._model_torch

            if self.fake_context:
                tmpdir_path = stack.enter_context(tempfile.TemporaryDirectory())
                warnings.warn(
                    "Cannot run model directly from `ONNXProgram` because"
                    " the model was exported using `enable_fake_mode`."
                    f" The model will be serialized to disk using a temporary folder ({tmpdir_path})"
                    " to populate the model with initializers before execution."
                )
                # TODO: Revisit the need of `model_with_state_dict` being a real model and not just its state
                onnx_model = os.path.join(tmpdir_path, "model.onnx")
                if isinstance(model_with_state_dict, torch.nn.Module):
                    model_state = model_with_state_dict.state_dict()
                else:
                    model_state = self._state_dict
                self.save(
                    onnx_model,
                    model_state=model_state,
                )
            else:
                onnx_model = self.model_proto.SerializeToString()  # type: ignore[assignment]

            import onnxruntime  # type: ignore[import]

            onnx_input = self.adapt_torch_inputs_to_onnx(
                *args, model_with_state_dict=model_with_state_dict, **kwargs
            )
            options = options or ONNXRuntimeOptions()
            providers = (
                options.execution_providers or onnxruntime.get_available_providers()
            )
            ort_session = onnxruntime.InferenceSession(onnx_model, providers=providers)

            onnxruntime_input = {
                k.name: v.numpy(force=True)  # type: ignore[union-attr]
                for k, v in zip(ort_session.get_inputs(), onnx_input)
            }

            return ort_session.run(None, onnxruntime_input)

    @property
    def model_proto(self) -> onnx.ModelProto:  # type: ignore[name-defined]
        """The exported ONNX model as an :py:obj:`onnx.ModelProto`."""

        if self._export_exception is not None:
            raise self._export_exception
        return self._model_proto

    @property
    def diagnostic_context(self) -> diagnostics.DiagnosticContext:
        """The diagnostic context associated with the export."""

        return self._diagnostic_context

    @property
    def fake_context(self) -> ONNXFakeContext | None:
        """The fake context associated with the export."""

        return self._fake_context

    def adapt_torch_inputs_to_onnx(
        self,
        *model_args,
        model_with_state_dict: torch.nn.Module | Callable | None = None,
        **model_kwargs,
    ) -> Sequence[torch.Tensor | int | float | bool | torch.dtype]:
        """Converts the PyTorch model inputs to exported ONNX model inputs format.

        Due to design differences, the input/output format between a PyTorch model and
        its exported ONNX model is often not the same. E.g., ``None`` is allowed as a
        PyTorch model input but is not supported by ONNX; nested constructs of tensors
        are allowed as PyTorch model inputs, but ONNX supports only flattened tensors, etc.

        The actual adapting steps are associated with each individual export. It
        depends on the PyTorch model, the particular set of model_args and model_kwargs
        used for the export, and export options.

        This method replays the adapting steps recorded during export.

        Args:
            model_args: The PyTorch model inputs.
            model_with_state_dict: The PyTorch model to get extra state from.
                If not specified, the model used during export is used.
                Required when :func:`enable_fake_mode` is used to extract real initializers as needed by the ONNX graph.
            model_kwargs: The PyTorch model keyword inputs.

        Returns:
            A sequence of tensors converted from PyTorch model inputs.

        Example::

            # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
            >>> import torch
            >>> import torch.onnx
            >>> from typing import Dict, Tuple
            >>> def func_nested_input(
            ...     x_dict: Dict[str, torch.Tensor],
            ...     y_tuple: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
            ... ):
            ...     if "a" in x_dict:
            ...         x = x_dict["a"]
            ...     elif "b" in x_dict:
            ...         x = x_dict["b"]
            ...     else:
            ...         x = torch.randn(3)
            ...
            ...     y1, (y2, y3) = y_tuple
            ...
            ...     return x + y1 + y2 + y3
            >>> x_dict = {"a": torch.tensor(1.)}
            >>> y_tuple = (torch.tensor(2.), (torch.tensor(3.), torch.tensor(4.)))
            >>> onnx_program = torch.onnx.dynamo_export(func_nested_input, x_dict, y_tuple)
            >>> print(x_dict, y_tuple)
            {'a': tensor(1.)} (tensor(2.), (tensor(3.), tensor(4.)))
            >>> print(onnx_program.adapt_torch_inputs_to_onnx(x_dict, y_tuple, model_with_state_dict=func_nested_input))
            (tensor(1.), tensor(2.), tensor(3.), tensor(4.))

        .. warning::
            This API is experimental and is *NOT* backward-compatible.

        """
        # model specified by the user has precedence, when specified
        model_with_state_dict = model_with_state_dict or self._model_torch
        assert (
            model_with_state_dict is not None
        ), "model_with_state_dict must be specified."
        return self._input_adapter.apply(  # type: ignore[return-value]
            *model_args, model=model_with_state_dict, **model_kwargs
        )

    def adapt_torch_outputs_to_onnx(
        self,
        model_outputs: Any,
        model_with_state_dict: torch.nn.Module | Callable | None = None,
    ) -> Sequence[torch.Tensor | int | float | bool]:
        """Converts the PyTorch model outputs to exported ONNX model outputs format.

        Due to design differences, the input/output format between a PyTorch model and
        its exported ONNX model is often not the same. E.g., ``None`` is allowed as a
        PyTorch model input but is not supported by ONNX; nested constructs of tensors
        are allowed as PyTorch model inputs, but ONNX supports only flattened tensors, etc.

        The actual adapting steps are associated with each individual export. It
        depends on the PyTorch model, the particular set of model_args and model_kwargs
        used for the export, and export options.

        This method replays the adapting steps recorded during export.

        Args:
            model_outputs: The PyTorch model outputs.
            model_with_state_dict: The PyTorch model to get extra state from.
                If not specified, the model used during export is used.
                Required when :func:`enable_fake_mode` is used to extract real initializers as needed by the ONNX graph.

        Returns:
            PyTorch model outputs in exported ONNX model outputs format.

        Example::

            # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
            >>> import torch
            >>> import torch.onnx
            >>> def func_returning_tuples(x, y, z):
            ...     x = x + y
            ...     y = y + z
            ...     z = x + y
            ...     return (x, (y, z))
            >>> x = torch.tensor(1.)
            >>> y = torch.tensor(2.)
            >>> z = torch.tensor(3.)
            >>> onnx_program = torch.onnx.dynamo_export(func_returning_tuples, x, y, z)
            >>> pt_output = func_returning_tuples(x, y, z)
            >>> print(pt_output)
            (tensor(3.), (tensor(5.), tensor(8.)))
            >>> print(onnx_program.adapt_torch_outputs_to_onnx(pt_output, model_with_state_dict=func_returning_tuples))
            [tensor(3.), tensor(5.), tensor(8.)]

        .. warning::
            This API is experimental and is *NOT* backward-compatible.

        """
        # model specified by the user has precedence, when specified
        model_with_state_dict = model_with_state_dict or self._model_torch
        assert (
            model_with_state_dict is not None
        ), "model_with_state_dict must be specified."
        return self._output_adapter.apply(model_outputs, model=model_with_state_dict)  # type: ignore[return-value]

    def apply_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
        """Apply the weights from the specified state dict to the ONNX model.

        Args:
            state_dict: The state dict containing the weights to apply to the ONNX model.
        """
        self._state_dict = state_dict

    def save(
        self,
        destination: str | io.BufferedIOBase,
        *,
        include_initializers: bool = True,
        model_state: dict[str, Any] | str | None = None,
    ) -> None:
        """Saves the in-memory ONNX model to ``destination``.

        Args:
            destination: The destination to save the ONNX model. It can be either a string or a file-like object.
                When used with ``model_state``, it must be a string with a full path to the destination.
                If `destination` is a string, besides saving the ONNX model into a file, model weights are also stored
                in separate files in the same directory as the ONNX model. E.g. for `destination="/path/model.onnx"`,
                the initializers are saved in the "/path/" folder along with "model.onnx".
            include_initializers: Whether to include initializers in the ONNX graph as external data.
                Cannot be combined with `model_state`.
            model_state: The state_dict of the PyTorch model containing all weights on it.
                It can be either a string with the path to a checkpoint or a dictionary with the actual model state.
                The supported file formats are the same as those supported by `torch.load` and `safetensors.safe_open`.
                Required when :func:`enable_fake_mode` is used but real initializers are needed on the ONNX graph.
        """
        import onnx

        assert (
            include_initializers is True or model_state is None
        ), "Cannot specify both `include_initializers=False` and `model_state`."

        if self._state_dict and model_state is None:
            model_state = self._state_dict

        # Add initializers when symbolic tracing is enabled
        _model_state_files: list[str | io.BytesIO | dict[str, Any]] = []
        if include_initializers:
            if model_state is not None:
                assert isinstance(
                    model_state, (dict, str)
                ), "model_state must be a path to the model's state_dict or the actual state_dict"
                # NOTE: For dict, there can be performance penalty or high memory usage that might lead to OOM
                #       if the dict wasn't loaded with torch.load(..., mmap=True, map_location="cpu")
                _model_state_files.append(model_state)
            elif self._fake_context and self._fake_context.state_dict_paths:
                # Load state from previous model.load_state_dict() call within enable_fake_mode() context
                for path in self._fake_context.state_dict_paths:
                    if path in _model_state_files:
                        # ignore duplicate
                        continue
                    if os.path.exists(path):  # type: ignore[arg-type]
                        _model_state_files.append(path)
        else:
            # self.model_proto.graph.initializer.clear() not available in older protobuf versions
            initializer_count = len(self.model_proto.graph.initializer)
            for _ in range(initializer_count):
                del self.model_proto.graph.initializer[0]

        if _model_state_files:
            if not isinstance(destination, str):
                raise RuntimeError(
                    "`destination` must be a string with a path when `model_state` is specified."
                )
            destination_path, destination_filename = os.path.split(destination)
            destination_path = destination_path or os.getcwd()
            onnx_model_location = destination_filename

            # TODO: Should this be part of the serializer?
            fx_serialization.save_model_with_external_data(
                destination_path,
                onnx_model_location,
                "",  # When initializers >2GB, must be in the same folder as the model
                tuple(_model_state_files),
                self.model_proto,
            )
        else:
            if isinstance(destination, str):
                with open(destination, "wb") as f:
                    if self.model_proto.ByteSize() < _PROTOBUF_SIZE_MAX_LIMIT:
                        onnx.save_model(self.model_proto, destination)  # type: ignore[attr-defined]
                    else:
                        # ValueError: Message onnx.ModelProto exceeds maximum protobuf size of 2GB
                        # Fallback to serializing the model with external data.
                        onnx.save_model(  # type: ignore[attr-defined]
                            self.model_proto,
                            destination,
                            save_as_external_data=True,
                            all_tensors_to_one_file=True,
                        )
            else:
                try:
                    destination.write(self.model_proto.SerializeToString())
                except ValueError as exc:
                    raise ValueError(
                        "'destination' should be provided as a path-like string when saving a model larger than 2GB. "
                        "External tensor data will be saved alongside the model on disk."
                    ) from exc
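
    # Example (illustrative): saving an export produced under ``enable_fake_mode``.
    # Passing ``model_state`` lets ``save`` write the real weights as external data
    # files next to the .onnx file; the path below is hypothetical.
    #
    #     onnx_program.save("/path/model.onnx", model_state=MyModel().state_dict())
    #
    # Without ``model_state``, initializers can instead be dropped entirely:
    #
    #     onnx_program.save("/path/model_no_weights.onnx", include_initializers=False)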

    def save_diagnostics(self, destination: str) -> None:
        """Saves the export diagnostics as a SARIF log to the specified destination path.

        Args:
            destination: The destination to save the diagnostics SARIF log.
                It must have a `.sarif` extension.

        Raises:
            ValueError: If the destination path does not end with the `.sarif` extension.
        """
        if not destination.endswith(".sarif"):
            message = f"'destination' must have a .sarif extension, got {destination}"
            log.fatal(message)
            raise ValueError(message)

        self.diagnostic_context.dump(destination)

    @classmethod
    def _from_failure(
        cls,
        export_exception: Exception,
        diagnostic_context: diagnostics.DiagnosticContext,
    ) -> Self:
        """
        Creates an instance of :class:`ONNXProgram` when the export process encounters a failure.

        In case of a failed export, this method is used to encapsulate the exception
        and associated diagnostic context within an :class:`ONNXProgram` instance for
        easier handling and debugging.

        Args:
            export_exception: The exception raised during the export process.
            diagnostic_context: The context associated with diagnostics during export.

        Returns:
            An instance of :class:`ONNXProgram` representing the failed ONNX program.
        """
        # Defer `import onnx` out of `import torch` path
        # https://github.com/pytorch/pytorch/issues/103764
        import onnx

        return cls(
            onnx.ModelProto(),  # type: ignore[attr-defined]
            io_adapter.InputAdapter(),
            io_adapter.OutputAdapter(),
            diagnostic_context,
            export_exception=export_exception,
        )


class FXGraphExtractor(abc.ABC):
    """Abstract interface for FX graph extractor engines.

    This class isolates FX extraction logic from the rest of the export logic, which
    allows a single ONNX exporter to leverage different FX graph extraction engines.
    """

    def __init__(self) -> None:
        super().__init__()
        self.input_adapter: io_adapter.InputAdapter = io_adapter.InputAdapter()
        self.output_adapter: io_adapter.OutputAdapter = io_adapter.OutputAdapter()

    @abc.abstractmethod
    def generate_fx(
        self,
        options: ResolvedExportOptions,
        model: torch.nn.Module | Callable,
        model_args: Sequence[Any],
        model_kwargs: Mapping[str, Any],
    ) -> torch.fx.GraphModule:
        """Analyzes user ``model`` and generates an FX graph.

        Args:
            options: The export options.
            model: The user model.
            model_args: The model's positional input arguments.
            model_kwargs: The model's keyword input arguments.

        Returns:
            The generated FX Graph.
        """
        ...

    # TODO: Design the passes API
    @abc.abstractmethod
    def pre_export_passes(
        self,
        options: ResolvedExportOptions,
        original_model: torch.nn.Module | Callable,
        fx_module: torch.fx.GraphModule,
        fx_module_args: Sequence[Any],
    ):
        """Applies pre-export passes to the FX graph.

        Pre-export passes are FX-to-FX graph transformations that make the graph
        more palatable for the FX-to-ONNX conversion.
        For example, it can be used to flatten model input/output, add explicit
        casts to the graph, replace/decompose operators, functionalize the graph, etc.
        """
        ...
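

# Illustrative sketch of a custom extractor (not part of this module): a real
# implementation must also record input/output adaptation steps on
# ``self.input_adapter``/``self.output_adapter``. The ``torch._dynamo.export``
# usage below is an assumption based on the dynamo API of this era and may differ.
#
#     class MyDynamoTracer(FXGraphExtractor):
#         def generate_fx(self, options, model, model_args, model_kwargs):
#             graph_module, _guards = torch._dynamo.export(model)(
#                 *model_args, **model_kwargs
#             )
#             return self.pre_export_passes(options, model, graph_module, model_args)
#
#         def pre_export_passes(self, options, original_model, fx_module, fx_module_args):
#             # Reuse the shared pass pipeline defined at the bottom of this file.
#             return common_pre_export_passes(
#                 options, original_model, fx_module, fx_module_args
#             )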


class Exporter:
    def __init__(
        self,
        options: ResolvedExportOptions,
        model: torch.nn.Module | Callable,
        model_args: Sequence[Any],
        model_kwargs: Mapping[str, Any],
    ):
        self.options = options
        assert self.options is not None

        self.model = model
        self.model_args = model_args
        self.model_kwargs = model_kwargs

        # TODO: https://github.com/pytorch/pytorch/issues/107714
        # NOTE: FXSymbolicTracer would fail in this assert, as it does not use `enable_fake_mode`
        from torch.onnx._internal.fx import fx_symbolic_graph_extractor

        if not isinstance(
            self.options.fx_tracer, fx_symbolic_graph_extractor.FXSymbolicTracer
        ):
            self._assert_fake_tensor_mode()

    def export(self) -> ONNXProgram:
        from torch.export._trace import (  # TODO: Prevent circular dependency
            DEFAULT_EXPORT_DYNAMO_CONFIG,
        )

        # TODO: Defer `import onnxscript` out of `import torch` path
        # https://github.com/pytorch/pytorch/issues/103764
        from torch.onnx._internal.fx import decomposition_skip

        with self.options.diagnostic_context, decomposition_skip.enable_decomposition_skips(
            self.options
        ), torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)):
            graph_module = self.options.fx_tracer.generate_fx(
                self.options, self.model, self.model_args, self.model_kwargs
            )
            # TODO: Defer `import onnxscript` out of `import torch` path
            # https://github.com/pytorch/pytorch/issues/103764
            from torch.onnx._internal.fx import fx_onnx_interpreter

            fx_interpreter = fx_onnx_interpreter.FxOnnxInterpreter(
                diagnostic_context=self.options.diagnostic_context
            )
            onnxscript_graph = fx_interpreter.run(
                fx_graph_module=graph_module,
                onnxfunction_dispatcher=self.options.onnxfunction_dispatcher,
            )

            # NOTE: Filter out initializers that are fake tensors when exporting in fake mode.
            # Otherwise, the ONNX exporter will fail: RuntimeError: basic_string::_M_construct null
            # not valid.
            # Concrete data is expected to be filled for those initializers later during `ONNXProgram.save`.
            if self.options.fake_context is not None:
                initializers_with_real_tensors: dict[str, torch.Tensor] = {}
                for (
                    initializer_name,
                    initializer,
                ) in onnxscript_graph.initializers.items():
                    if not isinstance(initializer, torch._subclasses.FakeTensor):
                        initializers_with_real_tensors[initializer_name] = initializer
                onnxscript_graph.initializers = initializers_with_real_tensors

            # Export TorchScript graph to ONNX ModelProto.
            onnx_model = onnxscript_graph.to_model_proto(
                self.options.onnx_registry.opset_version,
            )

            try:
                from onnxscript import optimizer

                onnx_model = optimizer.optimize(onnx_model)
            except ImportError:
                warnings.warn(
                    "ONNXScript optimizer is not available. Skipping optimization. "
                    "Please `pip install onnxscript -U` to enable post-export optimization."
                )
            except Exception as e:
                warnings.warn(
                    "ONNXScript optimizer failed. Skipping optimization. "
                    "\n\nPLEASE REPORT A BUG AT https://github.com/microsoft/onnxscript/issues "
                    f"\n\nDetail:\n{e}"
                )

            return torch.onnx.ONNXProgram(
                onnx_model,
                self.options.fx_tracer.input_adapter,
                self.options.fx_tracer.output_adapter,
                self.options.diagnostic_context,
                fake_context=self.options.fake_context,
                model_torch=self.model,
            )

    def _assert_fake_tensor_mode(self):
        """Asserts that the model and its input do not contain fake tensors."""

        # Case 1: Model with fake inputs/weights and without enabling fake mode
        has_any_fake_tensor = pytree.tree_any(
            lambda x: isinstance(x, torch._subclasses.FakeTensor),
            (self.model_args, self.model_kwargs),
        )
        has_any_fake_param_or_buffer = False
        if isinstance(self.model, torch.nn.Module):
            has_any_fake_param_or_buffer = pytree.tree_any(
                lambda x: isinstance(x, torch._subclasses.FakeTensor),
                (self.model.parameters(), self.model.buffers()),
            )
        if (
            has_any_fake_tensor or has_any_fake_param_or_buffer
        ) and not self.options.fake_context:
            raise RuntimeError(
                "Cannot export a model with fake inputs/weights without enabling fake mode.",
            )
        # Case 2: Model with non-fake inputs/weights while fake mode is enabled
        has_any_non_fake_tensors = pytree.tree_any(
            lambda x: isinstance(x, torch.Tensor)
            and not isinstance(x, torch._subclasses.FakeTensor),
            (self.model_args, self.model_kwargs),
        )
        has_any_non_fake_param_or_buffer = False
        if isinstance(self.model, torch.nn.Module):
            has_any_non_fake_param_or_buffer = pytree.tree_any(
                lambda x: isinstance(x, torch.Tensor)
                and not isinstance(x, torch._subclasses.FakeTensor),
                (self.model.parameters(), self.model.buffers()),
            )
        if (
            has_any_non_fake_tensors or has_any_non_fake_param_or_buffer
        ) and self.options.fake_context:
            raise RuntimeError(
                "Cannot export a model with non-fake inputs/weights while fake mode is enabled.",
            )


class UnsatisfiedDependencyError(RuntimeError):
    """Raised when an ONNX exporter dependency cannot be satisfied."""

    def __init__(self, package_name: str, message: str):
        super().__init__(message)
        self.package_name = package_name


class InvalidExportOptionsError(RuntimeError):
    """Raised when the user specified an invalid value for :class:`ExportOptions`."""


def _assert_dependencies(export_options: ResolvedExportOptions):
    opset_version = export_options.onnx_registry.opset_version

    def missing_package(package_name: str, exc_info: logging._ExcInfoType):
        message = (
            f"Please install the `{package_name}` package "
            f"(e.g. `python -m pip install {package_name}`)."
        )
        log.fatal(message, exc_info=exc_info)
        return UnsatisfiedDependencyError(package_name, message)

    def missing_opset(package_name: str):
        message = (
            f"The installed `{package_name}` does not support the specified ONNX opset "
            f"version {opset_version}. Install a newer `{package_name}` package or "
            f"specify an older opset version."
        )
        log.fatal(message)
        return UnsatisfiedDependencyError(package_name, message)

    try:
        import onnx
    except ImportError as e:
        raise missing_package("onnx", e) from e

    if onnx.defs.onnx_opset_version() < opset_version:
        raise missing_opset("onnx")

    try:
        # PyTorch runs lintrunner in CI without onnxscript installed
        import onnxscript  # type: ignore[import]
    except ImportError as e:
        raise missing_package("onnxscript", e) from e

    if not isinstance(
        onnxscript.onnx_opset.all_opsets[("", opset_version)],
        onnxscript.values.Opset,
    ):
        raise missing_opset("onnxscript")


def dynamo_export(
    model: torch.nn.Module | Callable,
    /,
    *model_args,
    export_options: ExportOptions | None = None,
    **model_kwargs,
) -> ONNXProgram | Any:
    """Export a torch.nn.Module to an ONNX graph.

    Args:
        model: The PyTorch model to be exported to ONNX.
        model_args: Positional inputs to ``model``.
        model_kwargs: Keyword inputs to ``model``.
        export_options: Options to influence the export to ONNX.

    Returns:
        An in-memory representation of the exported ONNX model.

    **Example 1 - Simplest export**
    ::

        class MyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x, bias=None):
                out = self.linear(x)
                out = out + bias
                return out


        model = MyModel()
        kwargs = {"bias": 3.0}
        args = (torch.randn(2, 2, 2),)
        onnx_program = torch.onnx.dynamo_export(model, *args, **kwargs)
        onnx_program.save("my_simple_model.onnx")

    **Example 2 - Exporting with dynamic shapes**
    ::

        # The previous model can be exported with dynamic shapes
        export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
        onnx_program = torch.onnx.dynamo_export(
            model, *args, **kwargs, export_options=export_options
        )
        onnx_program.save("my_dynamic_model.onnx")


    By printing the input's dynamic dimensions we can see the input shape is no longer (2, 2, 2)
    ::

        >>> print(onnx_program.model_proto.graph.input[0])
        name: "arg0"
        type {
          tensor_type {
            elem_type: 1
            shape {
              dim {
                dim_param: "arg0_dim_0"
              }
              dim {
                dim_param: "arg0_dim_1"
              }
              dim {
                dim_param: "arg0_dim_2"
              }
            }
          }
        }
    """

    if export_options is not None:
        resolved_export_options = (
            export_options
            if isinstance(export_options, ResolvedExportOptions)
            else ResolvedExportOptions(export_options, model=model)
        )
    else:
        resolved_export_options = ResolvedExportOptions(ExportOptions(), model=model)

    _assert_dependencies(resolved_export_options)

    try:
        from torch._dynamo import config as _dynamo_config

        with _dynamo_config.patch(do_not_emit_runtime_asserts=True):
            return Exporter(
                options=resolved_export_options,
                model=model,
                model_args=model_args,
                model_kwargs=model_kwargs,
            ).export()
    except Exception as e:
        sarif_report_path = _DEFAULT_FAILED_EXPORT_SARIF_LOG_PATH
        resolved_export_options.diagnostic_context.dump(sarif_report_path)
        message = (
            f"Failed to export the model to ONNX. Generating SARIF report at '{sarif_report_path}'. "
            "SARIF is a standard format for the output of static analysis tools. "
            "SARIF logs can be loaded in the VS Code SARIF viewer extension, "
            "or the SARIF web viewer (https://microsoft.github.io/sarif-web-component/). "
            f"Please report a bug on PyTorch GitHub: {_PYTORCH_GITHUB_ISSUES_URL}"
        )
        raise errors.OnnxExporterError(message) from e


def common_pre_export_passes(
    options: ResolvedExportOptions,
    original_model: torch.nn.Module | Callable,
    fx_module: torch.fx.GraphModule,
    fx_module_args: Sequence[Any],
):
    # TODO: Import here to prevent circular dependency
    from torch.onnx._internal.fx import analysis, passes

    diagnostic_context = options.diagnostic_context

    # Apply decomposition table to the input graph.
    module = passes.Decompose(
        diagnostic_context,
        fx_module,
        options.decomposition_table,
        enable_dynamic_axes=options.dynamic_shapes,
        allow_fake_constant=options.fake_context is not None,
    ).run(*fx_module_args)

    # ONNX does not support views and mutations.
    # Functionalize to get a semantically equivalent graph without mutations.
    module = passes.Functionalize(
        diagnostic_context,
        module,
        enable_dynamic_axes=options.dynamic_shapes,
        allow_fake_constant=options.fake_context is not None,
    ).run(*fx_module_args)

    # Input mutations are detected and distilled after the `Functionalize` pass.
    # Remove them since ONNX inference does not need them.
    module = passes.RemoveInputMutation(diagnostic_context, module).run(*fx_module_args)

    # ONNX does not support the concept of (implicit) type promotion.
    # Insert type casts explicitly where needed.
    module = passes.InsertTypePromotion(diagnostic_context, module).run()

    analysis.UnsupportedFxNodesAnalysis(
        diagnostic_context, module, options.onnxfunction_dispatcher
    ).analyze(infra.levels.ERROR)

    if isinstance(original_model, torch.nn.Module):
        module = passes.RestoreParameterAndBufferNames(
            diagnostic_context, module, original_model
        ).run()

    # This operation should be invoked as the last pre-export pass.
    # See [NOTE: Modularize pass ordering]
    module = passes.Modularize(diagnostic_context, module).run()

    # ONNX does not support None inputs. During graph building, all None inputs
    # are removed. Here we register this step to the input adapter.
    options.fx_tracer.input_adapter.append_step(io_adapter.RemoveNoneInputStep())

    # NOTE: Temporary workaround for https://github.com/pytorch/pytorch/issues/99534
    # Dynamo doesn't support non-tensor inputs.
    options.fx_tracer.input_adapter.append_step(io_adapter.RemoveNonTensorInputStep())

    # ONNX does not support complex inputs. During graph building, all complex inputs
    # are converted to real representation inputs. Here we register this step to the
    # input/output adapter.
    options.fx_tracer.input_adapter.append_step(
        io_adapter.ConvertComplexToRealRepresentationInputStep()
    )

    # ONNX can't represent collection types (e.g., dictionary, tuple of tuple of
    # tensor, etc.), so we flatten the collection and register each element as an output.
    options.fx_tracer.output_adapter.append_step(io_adapter.FlattenOutputStep())

    # Output post-processing steps should happen after `FlattenOutputStep`.
    options.fx_tracer.output_adapter.append_step(
        io_adapter.ConvertComplexToRealRepresentationOutputStep()
    )

    return module