# mypy: allow-untyped-defs
from __future__ import annotations


__all__ = [
    "DiagnosticOptions",
    "ExportOptions",
    "ONNXProgram",
    "ONNXRuntimeOptions",
    "InvalidExportOptionsError",
    "OnnxRegistry",
    "UnsatisfiedDependencyError",
    "dynamo_export",
    "enable_fake_mode",
]


import abc
import contextlib
import dataclasses
import logging
import os
import tempfile
import warnings
from collections import defaultdict
from typing import Any, Callable, Final, Mapping, Sequence, TYPE_CHECKING, TypeVar
from typing_extensions import Self

import torch
import torch._ops
import torch.utils._pytree as pytree
from torch.onnx import errors
from torch.onnx._internal import io_adapter
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.fx import (
    decomposition_table,
    patcher as patcher,
    registration,
    serialization as fx_serialization,
)


# We can only import onnx from this module in a type-checking context to ensure that
# 'import torch.onnx' continues to work without having 'onnx' installed. We fully
# 'import onnx' inside of dynamo_export (by way of _assert_dependencies).
if TYPE_CHECKING:
    import io

    import onnx

    import onnxruntime
    import onnxscript

    from torch._subclasses import fake_tensor
    from torch.onnx._internal.fx import diagnostics

_DEFAULT_OPSET_VERSION: Final[int] = 18
"""The default ONNX opset version the exporter will use if one is not specified explicitly
through :class:`ExportOptions`. This should NEVER be accessed outside of this module! Users
should reference :attr:`OnnxRegistry.opset_version`."""

_PYTORCH_GITHUB_ISSUES_URL = "https://github.com/pytorch/pytorch/issues"
"""The URL to the PyTorch GitHub issues page."""

_DEFAULT_FAILED_EXPORT_SARIF_LOG_PATH = "report_dynamo_export.sarif"
"""The default path to write the SARIF log to if the export fails."""

_PROTOBUF_SIZE_MAX_LIMIT = 2 * 1024 * 1024 * 1024
"""The maximum size of a Protobuf file in bytes. This is used to determine whether to
serialize the model with external data or not."""

log = logging.getLogger(__name__)


DiagnosticOptions = infra.DiagnosticOptions


@dataclasses.dataclass
class ONNXFakeContext:
    """A dataclass used to store context for model export using FakeTensor.

    This dataclass stores the FakeTensorMode instance used to convert
    real tensors and model parameters into fake tensors. This :attr:`ONNXFakeContext.fake_mode` is
    reused internally during tracing of a :class:`torch.nn.Module` into an FX :class:`GraphModule`.
    """

    fake_mode: fake_tensor.FakeTensorMode
    """The fake tensor mode used for tracing the model with fake tensors and parameters."""

    state_dict_paths: tuple[str | io.BytesIO | dict[str, Any]] | None = None
    """Paths of files that contain the model's :meth:`state_dict`."""


class OnnxRegistry:
    """Registry for ONNX functions.

    The registry maintains a mapping from qualified names to symbolic functions under a
    fixed opset version. It supports registering custom onnx-script functions and allows
    the dispatcher to dispatch calls to the appropriate function.

    """

    def __init__(self) -> None:
        """Initializes the registry"""

        # NOTE: _registry maps an OpName to a list of ONNXFunctions. It is important
        # not to directly modify this variable. Instead, access to it should be done through
        # the public methods: register_op, get_op_functions, and is_registered_op.
        self._registry: dict[registration.OpName, list[registration.ONNXFunction]] = (
            defaultdict(list)
        )

        # opset_version is unused for now, since torchlib only supports opset18.
        # TODO: get opset version from torchlib
        self._opset_version = _DEFAULT_OPSET_VERSION
        warnings.warn(
            f"torch.onnx.dynamo_export only implements opset version {self._opset_version} for now. If you need to use a "
            "different opset version, please register the relevant operators with register_op."
        )

        self._initiate_registry_from_torchlib()

    @property
    def opset_version(self) -> int:
        """The ONNX opset version the exporter should target. Defaults to the latest
        supported ONNX opset version: 18. The default version will increment over time as
        ONNX continues to evolve."""

        return self._opset_version

    def _initiate_registry_from_torchlib(self) -> None:
        """Populates the registry with ATen functions from torchlib."""
        import onnxscript._framework_apis.torch_2_5 as onnxscript_apis

        for meta in onnxscript_apis.get_torchlib_ops():
            internal_name_instance = registration.OpName.from_qualified_name(
                meta.qualified_name
            )
            symbolic_function = registration.ONNXFunction(
                onnx_function=meta.function,  # type: ignore[arg-type]
                op_full_name=internal_name_instance.qualified_name(),
                is_custom=False,
                is_complex=meta.is_complex,
            )
            self._register(internal_name_instance, symbolic_function)

    def _register(
        self,
        internal_qualified_name: registration.OpName,
        symbolic_function: registration.ONNXFunction,
    ) -> None:
        """Registers an ONNXFunction to an operator.

        Args:
            internal_qualified_name: The qualified name of the operator to register, as an OpName.
            symbolic_function: The ONNXFunction to register.
        """
        self._registry[internal_qualified_name].append(symbolic_function)

    def register_op(
        self,
        function: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction,
        namespace: str,
        op_name: str,
        overload: str | None = None,
        is_complex: bool = False,
    ) -> None:
        """Registers a custom operator: torch.ops.<namespace>.<op_name>.<overload>.

        Args:
            function: The onnx-script function to register.
            namespace: The namespace of the operator to register.
            op_name: The name of the operator to register.
            overload: The overload of the operator to register. If it is the default
                overload, leave it as None.
            is_complex: Whether the function handles complex-valued inputs.

        Raises:
            ValueError: If the name is not in the form of 'namespace::op'.
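
        The following is a minimal sketch of overriding ``aten::add.Tensor`` with a
        custom onnx-script function, adapted from the PyTorch ONNX tutorial; the
        ``custom.aten`` domain and the function body are illustrative.

        Example::

            # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
            >>> import torch
            >>> import onnxscript
            >>> from onnxscript import opset18 as op
            >>> custom_aten = onnxscript.values.Opset(domain="custom.aten", version=1)
            >>> @onnxscript.script(custom_aten)
            ... def custom_aten_add(x, y, alpha: float = 1.0):
            ...     alpha = op.CastLike(alpha, y)
            ...     return op.Add(x, op.Mul(y, alpha))
            >>> registry = torch.onnx.OnnxRegistry()
            >>> registry.register_op(
            ...     function=custom_aten_add, namespace="aten", op_name="add", overload="Tensor"
            ... )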
        """
        internal_name_instance = registration.OpName.from_name_parts(
            namespace=namespace, op_name=op_name, overload=overload
        )
        symbolic_function = registration.ONNXFunction(
            onnx_function=function,
            op_full_name=internal_name_instance.qualified_name(),
            is_custom=True,
            is_complex=is_complex,
        )
        self._register(internal_name_instance, symbolic_function)

    def get_op_functions(
        self, namespace: str, op_name: str, overload: str | None = None
    ) -> list[registration.ONNXFunction] | None:
        """Returns a list of ONNXFunctions for the given op: torch.ops.<namespace>.<op_name>.<overload>.

        The list is ordered by the time of registration. The custom operators should be
        in the second half of the list.

        Args:
            namespace: The namespace of the operator to get.
            op_name: The name of the operator to get.
            overload: The overload of the operator to get. If it is the default
                overload, leave it as None.

        Returns:
            A list of ONNXFunctions corresponding to the given name, or None if
            the name is not in the registry.
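
        The lookup below is illustrative; the default registry is populated from
        torchlib at construction time.

        Example::

            # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
            >>> import torch
            >>> registry = torch.onnx.OnnxRegistry()
            >>> functions = registry.get_op_functions("aten", "add", overload="Tensor")
            >>> functions is not None
            True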
        """
        internal_name_instance = registration.OpName.from_name_parts(
            namespace=namespace, op_name=op_name, overload=overload
        )
        return self._registry.get(internal_name_instance)

    def is_registered_op(
        self, namespace: str, op_name: str, overload: str | None = None
    ) -> bool:
        """Returns whether the given op is registered: torch.ops.<namespace>.<op_name>.<overload>.

        Args:
            namespace: The namespace of the operator to check.
            op_name: The name of the operator to check.
            overload: The overload of the operator to check. If it is the default
                overload, leave it as None.

        Returns:
            True if the given op is registered, otherwise False.
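
        The check below is illustrative, assuming the default torchlib-populated
        registry.

        Example::

            # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
            >>> import torch
            >>> registry = torch.onnx.OnnxRegistry()
            >>> registry.is_registered_op("aten", "add", overload="Tensor")
            True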
        """
        functions = self.get_op_functions(
            namespace=namespace, op_name=op_name, overload=overload
        )
        return functions is not None

    def _all_registered_ops(self) -> set[str]:
        """Returns the set of all registered function names."""
        return {
            op_name_class.qualified_name() for op_name_class in self._registry.keys()
        }


class ExportOptions:
    """Options to influence the TorchDynamo ONNX exporter.

    Attributes:
        dynamic_shapes: Shape information hint for input/output tensors.
            When ``None``, the exporter determines the most compatible setting.
            When ``True``, all input shapes are considered dynamic.
            When ``False``, all input shapes are considered static.
        diagnostic_options: The diagnostic options for the exporter.
        fake_context: The fake context used for symbolic tracing.
        onnx_registry: The ONNX registry used to register ATen operators to ONNX functions.
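
    The sketch below is illustrative; ``model`` and ``x`` stand in for any
    exportable :class:`torch.nn.Module` and a matching input.

    Example::

        >>> import torch
        >>> options = torch.onnx.ExportOptions(dynamic_shapes=True)
        >>> # onnx_program = torch.onnx.dynamo_export(model, x, export_options=options)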
    """

    dynamic_shapes: bool | None = None
    """Shape information hint for input/output tensors.

    - ``None``: the exporter determines the most compatible setting.
    - ``True``: all input shapes are considered dynamic.
    - ``False``: all input shapes are considered static.
    """

    diagnostic_options: DiagnosticOptions
    """The diagnostic options for the exporter."""

    fake_context: ONNXFakeContext | None = None
    """The fake context used for symbolic tracing."""

    onnx_registry: OnnxRegistry | None = None
    """The ONNX registry used to register ATen operators to ONNX functions."""

    def __init__(
        self,
        *,
        dynamic_shapes: bool | None = None,
        fake_context: ONNXFakeContext | None = None,
        onnx_registry: OnnxRegistry | None = None,
        diagnostic_options: DiagnosticOptions | None = None,
    ):
        self.dynamic_shapes = dynamic_shapes
        self.fake_context = fake_context
        self.onnx_registry = onnx_registry
        self.diagnostic_options = diagnostic_options or DiagnosticOptions()


class ResolvedExportOptions(ExportOptions):
    """Consolidates :class:`ExportOptions` with default values.
    All unspecified options from :class:`ExportOptions` are assigned a default value.
    This is an internal class and its API may be changed at any time without notice.
    """

    # Public attributes MUST be redefined below without ``Optional[]`` from ``ExportOptions``
    dynamic_shapes: bool
    diagnostic_options: DiagnosticOptions
    fake_context: ONNXFakeContext
    onnx_registry: OnnxRegistry

    # Private only attributes
    decomposition_table: dict[torch._ops.OpOverload, Callable]
    """A dictionary that maps operators to their decomposition functions."""

    onnxfunction_dispatcher: (
        torch.onnx._internal.fx.onnxfunction_dispatcher.OnnxFunctionDispatcher
    )
    """The ONNX dispatcher used to dispatch ATen operators to ONNX functions."""

    fx_tracer: FXGraphExtractor
    """The FXGraphExtractor instance used to extract the FX graph from the model."""

    diagnostic_context: diagnostics.DiagnosticContext
    """The diagnostics context for the export. Responsible for recording diagnostics,
    logging diagnostics, and generating the SARIF log."""

    def __init__(
        self,
        options: ExportOptions | ResolvedExportOptions,
        model: torch.nn.Module | Callable | None = None,  # type: ignore[name-defined]
    ):
        from torch.onnx._internal.fx import (  # TODO: Prevent circular dep
            diagnostics,
            dynamo_graph_extractor,
        )

        if isinstance(options, ResolvedExportOptions):
            self.dynamic_shapes = options.dynamic_shapes
            self.diagnostic_options = options.diagnostic_options
            self.fake_context = options.fake_context
            self.fx_tracer = options.fx_tracer
            self.onnx_registry = options.onnx_registry
            self.onnxfunction_dispatcher = options.onnxfunction_dispatcher
            self.decomposition_table = options.decomposition_table
            self.diagnostic_context = options.diagnostic_context
        else:
            T = TypeVar("T")

            def resolve(value: T | None, fallback: T | Callable[[], T]) -> T:
                if value is not None:
                    return value
                if callable(fallback):
                    return fallback()
                return fallback

            self.dynamic_shapes = resolve(options.dynamic_shapes, False)

            self.diagnostic_options = resolve(
                options.diagnostic_options, DiagnosticOptions()
            )

            self.fx_tracer = dynamo_graph_extractor.DynamoExport()

            self.fake_context = resolve(options.fake_context, None)  # type: ignore[arg-type]
            self.diagnostic_context = diagnostics.DiagnosticContext(
                "torch.onnx.dynamo_export",
                torch.__version__,
                self.diagnostic_options,
            )

            self.onnx_registry = resolve(options.onnx_registry, OnnxRegistry())
            self.decomposition_table = (
                decomposition_table.create_onnx_friendly_decomposition_table(  # type: ignore[assignment]
                    self.onnx_registry
                )
            )

            from torch.onnx._internal.fx import onnxfunction_dispatcher

            self.onnxfunction_dispatcher = (
                onnxfunction_dispatcher.OnnxFunctionDispatcher(
                    self.onnx_registry,
                    self.diagnostic_context,
                )
            )

            for key in dir(options):
                if not key.startswith("_"):  # skip private attributes
                    assert hasattr(self, key), f"Unresolved option '{key}'"


@contextlib.contextmanager
def enable_fake_mode():
    """Enable fake mode for the duration of the context.

    Internally it instantiates a :class:`torch._subclasses.fake_tensor.FakeTensorMode` context manager
    that converts user input and model parameters into :class:`torch._subclasses.fake_tensor.FakeTensor`.

    A :class:`torch._subclasses.fake_tensor.FakeTensor`
    is a :class:`torch.Tensor` with the ability to run PyTorch code without having to
    actually do computation through tensors allocated on a ``meta`` device. Because
    there is no actual data being allocated on the device, this API allows for
    exporting large models without the actual memory footprint needed for executing them.

    It is highly recommended to enable fake mode when exporting models that
    are too large to fit into memory.

    Returns:
        A :class:`ONNXFakeContext` object.

    Example::

        # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
        >>> import torch
        >>> import torch.onnx
        >>> class MyModel(torch.nn.Module):  # Dummy model
        ...     def __init__(self) -> None:
        ...         super().__init__()
        ...         self.linear = torch.nn.Linear(2, 2)
        ...     def forward(self, x):
        ...         out = self.linear(x)
        ...         return out
        >>> with torch.onnx.enable_fake_mode() as fake_context:
        ...     my_nn_module = MyModel()
        ...     arg1 = torch.randn(2, 2, 2)  # positional input 1
        >>> onnx_program = torch.onnx.export(my_nn_module, (arg1,), dynamo=True)
        >>> onnx_program.apply_weights(MyModel().state_dict())
        >>> # Saving model WITHOUT initializers
        >>> onnx_program.save(
        ...     "my_model_without_initializers.onnx",
        ...     include_initializers=False,
        ...     keep_initializers_as_inputs=True,
        ... )
        >>> # Saving model WITH initializers
        >>> onnx_program.save("my_model_with_initializers.onnx")

    .. warning::
        This API is experimental and is *NOT* backward-compatible.

    """
    from torch._subclasses import fake_tensor
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    # This overrides the internal `FakeTensorMode` instance created by `torch._dynamo.export`[1].
    # It is a good idea to keep them in sync (constructor args) to maintain the same default behavior
    # [1] `torch/_dynamo/output_graph.py::InstructionTranslator::OutputGraph.__init__`
    # Mixed fake/real tensors are only allowed when `torch.onnx.dynamo_export` is not called within `FakeTensorMode`
    # This is needed because models can create new parameters during `forward(self, *args, **kwargs)` run
    fake_mode = fake_tensor.FakeTensorMode(
        allow_non_fake_inputs=not torch._guards.detect_fake_mode(),
        shape_env=ShapeEnv(
            allow_scalar_outputs=False, allow_dynamic_output_shape_ops=False
        ),
    )
    # The patcher is needed for when user calls `fake_model.load_state_dict(...)` within fake mode
    patcher_context = patcher.ONNXTorchPatcher()
    fake_context = ONNXFakeContext(fake_mode=fake_mode)
    with fake_mode, patcher_context:
        yield fake_context
    fake_context.state_dict_paths = tuple(
        patcher_context.paths,
    )  # type: ignore[assignment]


class ONNXRuntimeOptions:
    """Options to influence the execution of the ONNX model through ONNX Runtime.

    Attributes:
        session_options: ONNX Runtime session options.
        execution_providers: ONNX Runtime execution providers to use during model execution.
        execution_provider_options: ONNX Runtime execution provider options.
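
    The sketch below pins execution to ONNX Runtime's CPU provider; the provider
    name follows ONNX Runtime conventions.

    Example::

        >>> import torch
        >>> options = torch.onnx.ONNXRuntimeOptions(execution_providers=["CPUExecutionProvider"])
        >>> # outputs = onnx_program(x, options=options)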
    """

    session_options: Sequence[onnxruntime.SessionOptions] | None = None
    """ONNX Runtime session options."""

    execution_providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None
    """ONNX Runtime execution providers to use during model execution."""

    execution_provider_options: Sequence[dict[Any, Any]] | None = None
    """ONNX Runtime execution provider options."""

    def __init__(
        self,
        *,
        session_options: Sequence[onnxruntime.SessionOptions] | None = None,
        execution_providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None,
        execution_provider_options: Sequence[dict[Any, Any]] | None = None,
    ):
        self.session_options = session_options
        self.execution_providers = execution_providers
        self.execution_provider_options = execution_provider_options


class ONNXProgram:
    """An in-memory representation of a PyTorch model that has been exported to ONNX.

    Args:
        model_proto: The exported ONNX model as an :py:obj:`onnx.ModelProto`.
        input_adapter: The input adapter used to convert PyTorch inputs into ONNX inputs.
        output_adapter: The output adapter used to convert PyTorch outputs into ONNX outputs.
        diagnostic_context: Context object for the SARIF diagnostic system responsible for logging errors and metadata.
        fake_context: The fake context used for symbolic tracing.
        export_exception: The exception that occurred during export, if any.
    """

    _model_proto: Final[onnx.ModelProto]  # type: ignore[name-defined, misc]
    _input_adapter: Final[io_adapter.InputAdapter]  # type: ignore[misc]
    _output_adapter: Final[io_adapter.OutputAdapter]  # type: ignore[misc]
    _diagnostic_context: Final[diagnostics.DiagnosticContext]  # type: ignore[misc]
    _fake_context: Final[ONNXFakeContext | None]  # type: ignore[misc]
    _export_exception: Final[Exception | None]  # type: ignore[misc]
    _model_torch: Final[  # type: ignore[misc]
        torch.nn.Module | Callable | None
    ]

    def __init__(
        self,
        model_proto: onnx.ModelProto,  # type: ignore[name-defined]
        input_adapter: io_adapter.InputAdapter,
        output_adapter: io_adapter.OutputAdapter,
        diagnostic_context: diagnostics.DiagnosticContext,
        *,
        fake_context: ONNXFakeContext | None = None,
        export_exception: Exception | None = None,
        model_torch: torch.nn.Module | Callable | None = None,
    ):
        self._model_proto = model_proto
        self._model_torch = model_torch
        self._input_adapter = input_adapter
        self._output_adapter = output_adapter
        self._diagnostic_context = diagnostic_context
        self._fake_context = fake_context
        self._export_exception = export_exception
        self._state_dict: dict[str, torch.Tensor] = {}

    def __call__(
        self,
        *args: Any,
        model_with_state_dict: torch.nn.Module | Callable | None = None,
        options: ONNXRuntimeOptions | None = None,
        **kwargs: Any,
    ) -> Any:
        """Runs the ONNX model using ONNX Runtime.

        Args:
            args: The positional inputs to the model.
            kwargs: The keyword inputs to the model.
            model_with_state_dict: The PyTorch model to fetch state from.
                Required when :func:`enable_fake_mode` is used to extract real initializers as needed by the ONNX graph.
            options: The options to use for running the model with ONNX Runtime.

        Returns:
            The model output, as computed by ONNX Runtime.
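
        The sketch below assumes the ``onnxruntime`` package is installed.

        Example::

            # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
            >>> import torch
            >>> linear = torch.nn.Linear(2, 2)
            >>> x = torch.randn(1, 2)
            >>> onnx_program = torch.onnx.dynamo_export(linear, x)
            >>> onnx_outputs = onnx_program(x)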
        """

        # TODO: If ONNX used absolute paths on the initializers external data files,
        # users could call ONNXProgram.save and use ONNXProgram.__call__ without the internal save below
        with contextlib.ExitStack() as stack:
            # model specified by the user has precedence, when specified
            model_with_state_dict = model_with_state_dict or self._model_torch

            if self.fake_context:
                tmpdir_path = stack.enter_context(tempfile.TemporaryDirectory())
                warnings.warn(
                    "Cannot run model directly from `ONNXProgram` because"
                    " the model was exported using `enable_fake_mode`."
                    f" The model will be serialized to disk using a temporary folder ({tmpdir_path})"
                    " to populate the model with initializers before being executed."
                )
                # TODO: Revisit the need of `model_with_state_dict` being a real model and not just its state
                onnx_model = os.path.join(tmpdir_path, "model.onnx")
                if isinstance(model_with_state_dict, torch.nn.Module):
                    model_state = model_with_state_dict.state_dict()
                else:
                    model_state = self._state_dict
                self.save(
                    onnx_model,
                    model_state=model_state,
                )
            else:
                onnx_model = self.model_proto.SerializeToString()  # type: ignore[assignment]

            import onnxruntime  # type: ignore[import]

            onnx_input = self.adapt_torch_inputs_to_onnx(
                *args, model_with_state_dict=model_with_state_dict, **kwargs
            )
            options = options or ONNXRuntimeOptions()
            providers = (
                options.execution_providers or onnxruntime.get_available_providers()
            )
            ort_session = onnxruntime.InferenceSession(onnx_model, providers=providers)

            onnxruntime_input = {
                k.name: v.numpy(force=True)  # type: ignore[union-attr]
                for k, v in zip(ort_session.get_inputs(), onnx_input)
            }

            return ort_session.run(None, onnxruntime_input)

    @property
    def model_proto(self) -> onnx.ModelProto:  # type: ignore[name-defined]
        """The exported ONNX model as an :py:obj:`onnx.ModelProto`."""

        if self._export_exception is not None:
            raise self._export_exception
        return self._model_proto

    @property
    def diagnostic_context(self) -> diagnostics.DiagnosticContext:
        """The diagnostic context associated with the export."""

        return self._diagnostic_context

    @property
    def fake_context(self) -> ONNXFakeContext | None:
        """The fake context associated with the export."""

        return self._fake_context

    def adapt_torch_inputs_to_onnx(
        self,
        *model_args,
        model_with_state_dict: torch.nn.Module | Callable | None = None,
        **model_kwargs,
    ) -> Sequence[torch.Tensor | int | float | bool | torch.dtype]:
        """Converts the PyTorch model inputs to exported ONNX model inputs format.

        Due to design differences, the input/output formats of a PyTorch model and its
        exported ONNX model are often not the same. E.g., ``None`` is allowed as a
        PyTorch model input but is not supported by ONNX. Nested constructs of tensors
        are allowed as PyTorch model inputs, but ONNX supports only flattened tensors,
        etc.

        The actual adapting steps are associated with each individual export. It
        depends on the PyTorch model, the particular set of model_args and model_kwargs
        used for the export, and export options.

        This method replays the adapting steps recorded during export.

        Args:
            model_args: The PyTorch model inputs.
            model_with_state_dict: The PyTorch model to get extra state from.
                If not specified, the model used during export is used.
                Required when :func:`enable_fake_mode` is used to extract real initializers as needed by the ONNX graph.
            model_kwargs: The PyTorch model keyword inputs.

        Returns:
            A sequence of tensors converted from PyTorch model inputs.

        Example::

            # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
            >>> import torch
            >>> import torch.onnx
            >>> from typing import Dict, Tuple
            >>> def func_nested_input(
            ...     x_dict: Dict[str, torch.Tensor],
            ...     y_tuple: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
            ... ):
            ...     if "a" in x_dict:
            ...         x = x_dict["a"]
            ...     elif "b" in x_dict:
            ...         x = x_dict["b"]
            ...     else:
            ...         x = torch.randn(3)
            ...
            ...     y1, (y2, y3) = y_tuple
            ...
            ...     return x + y1 + y2 + y3
            >>> x_dict = {"a": torch.tensor(1.)}
            >>> y_tuple = (torch.tensor(2.), (torch.tensor(3.), torch.tensor(4.)))
            >>> onnx_program = torch.onnx.dynamo_export(func_nested_input, x_dict, y_tuple)
            >>> print(x_dict, y_tuple)
            {'a': tensor(1.)} (tensor(2.), (tensor(3.), tensor(4.)))
            >>> print(onnx_program.adapt_torch_inputs_to_onnx(x_dict, y_tuple, model_with_state_dict=func_nested_input))
            (tensor(1.), tensor(2.), tensor(3.), tensor(4.))

        .. warning::
            This API is experimental and is *NOT* backward-compatible.

        """
        # model specified by the user has precedence, when specified
        model_with_state_dict = model_with_state_dict or self._model_torch
        assert (
            model_with_state_dict is not None
        ), "model_with_state_dict must be specified."
        return self._input_adapter.apply(  # type: ignore[return-value]
            *model_args, model=model_with_state_dict, **model_kwargs
        )

    def adapt_torch_outputs_to_onnx(
        self,
        model_outputs: Any,
        model_with_state_dict: torch.nn.Module | Callable | None = None,
    ) -> Sequence[torch.Tensor | int | float | bool]:
        """Converts the PyTorch model outputs to exported ONNX model outputs format.

        Due to design differences, the input/output formats of a PyTorch model and its
        exported ONNX model are often not the same. E.g., ``None`` is allowed as a
        PyTorch model input but is not supported by ONNX. Nested constructs of tensors
        are allowed as PyTorch model inputs, but ONNX supports only flattened tensors,
        etc.

        The actual adapting steps are associated with each individual export. It
        depends on the PyTorch model, the particular set of model_args and model_kwargs
        used for the export, and export options.

        This method replays the adapting steps recorded during export.

        Args:
            model_outputs: The PyTorch model outputs.
            model_with_state_dict: The PyTorch model to get extra state from.
                If not specified, the model used during export is used.
                Required when :func:`enable_fake_mode` is used to extract real initializers as needed by the ONNX graph.

        Returns:
            PyTorch model outputs in exported ONNX model outputs format.

        Example::

            # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
            >>> import torch
            >>> import torch.onnx
            >>> def func_returning_tuples(x, y, z):
            ...     x = x + y
            ...     y = y + z
            ...     z = x + y
            ...     return (x, (y, z))
            >>> x = torch.tensor(1.)
            >>> y = torch.tensor(2.)
            >>> z = torch.tensor(3.)
            >>> onnx_program = torch.onnx.dynamo_export(func_returning_tuples, x, y, z)
            >>> pt_output = func_returning_tuples(x, y, z)
            >>> print(pt_output)
            (tensor(3.), (tensor(5.), tensor(8.)))
            >>> print(onnx_program.adapt_torch_outputs_to_onnx(pt_output, model_with_state_dict=func_returning_tuples))
            [tensor(3.), tensor(5.), tensor(8.)]

        .. warning::
            This API is experimental and is *NOT* backward-compatible.

        """
        # model specified by the user has precedence, when specified
        model_with_state_dict = model_with_state_dict or self._model_torch
        assert (
            model_with_state_dict is not None
        ), "model_with_state_dict must be specified."
        return self._output_adapter.apply(model_outputs, model=model_with_state_dict)  # type: ignore[return-value]

    def apply_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
        """Apply the weights from the specified state dict to the ONNX model.

        Args:
            state_dict: The state dict containing the weights to apply to the ONNX model.
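
        The sketch below shows the :func:`enable_fake_mode` workflow, where real
        weights are attached after a fake-mode export; ``real_model`` is assumed to
        be the concrete counterpart of the exported model.

        Example::

            >>> # onnx_program.apply_weights(real_model.state_dict())
            >>> # onnx_program.save("model.onnx")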
        """
        self._state_dict = state_dict

    def save(
        self,
        destination: str | io.BufferedIOBase,
        *,
        include_initializers: bool = True,
        model_state: dict[str, Any] | str | None = None,
    ) -> None:
        """Saves the in-memory ONNX model to ``destination``.

        Args:
            destination: The destination to save the ONNX model. It can be either a string or a file-like object.
                When used with ``model_state``, it must be a string with a full path to the destination.
                If `destination` is a string, besides saving the ONNX model into a file, model weights are also stored
                in separate files in the same directory as the ONNX model. E.g. for `destination="/path/model.onnx"`,
                the initializers are saved in the "/path/" folder along with "model.onnx".
            include_initializers: Whether to include initializers in the ONNX graph as external data.
                Cannot be combined with `model_state`.
            model_state: The state_dict of the PyTorch model containing all of its weights.
                It can be either a string with the path to a checkpoint or a dictionary with the actual model state.
                The supported file formats are the same as those supported by `torch.load` and `safetensors.safe_open`.
                Required when :func:`enable_fake_mode` is used but real initializers are needed on the ONNX graph.
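
        The invocations below are illustrative; the checkpoint path is an
        assumption.

        Example::

            >>> # onnx_program.save("model.onnx")  # weights embedded as initializers
            >>> # onnx_program.save("model.onnx", model_state="checkpoint.pt")  # after fake-mode export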
        """
        import onnx

        assert (
            include_initializers is True or model_state is None
        ), "Cannot specify both `include_initializers=False` and `model_state`."

        if self._state_dict and model_state is None:
            model_state = self._state_dict

        # Add initializers when symbolic tracing is enabled
        _model_state_files: list[str | io.BytesIO | dict[str, Any]] = []
        if include_initializers:
            if model_state is not None:
                assert isinstance(
                    model_state, (dict, str)
                ), "model_state must be a path to the model's state_dict or the actual state_dict"
                # NOTE: For dict, there can be performance penalty or high memory usage that might lead to OOM
                #       if the dict wasn't loaded with torch.load(..., mmap=True, map_location="cpu")
                _model_state_files.append(model_state)
            elif self._fake_context and self._fake_context.state_dict_paths:
                # Load state from previous model.load_state_dict() call within enable_fake_mode() context
                for path in self._fake_context.state_dict_paths:
                    if path in _model_state_files:
                        # ignore duplicate
                        continue
                    if os.path.exists(path):  # type: ignore[arg-type]
                        _model_state_files.append(path)
        else:
            # self.model_proto.graph.initializer.clear() not available in older protobuf versions
            initializer_count = len(self.model_proto.graph.initializer)
            for _ in range(initializer_count):
                del self.model_proto.graph.initializer[0]

        if _model_state_files:
            if not isinstance(destination, str):
                raise RuntimeError(
                    "`destination` must be a string with a path when `model_state` is specified."
                )
            destination_path, destination_filename = os.path.split(destination)
            destination_path = destination_path or os.getcwd()
            onnx_model_location = destination_filename

            # TODO: Should this be part of the serializer?
            fx_serialization.save_model_with_external_data(
                destination_path,
                onnx_model_location,
                "",  # When initializers >2GB, must be in the same folder as the model
                tuple(_model_state_files),
                self.model_proto,
            )
        else:
            if isinstance(destination, str):
                if self.model_proto.ByteSize() < _PROTOBUF_SIZE_MAX_LIMIT:
                    onnx.save_model(self.model_proto, destination)  # type: ignore[attr-defined]
                else:
                    # ValueError: Message onnx.ModelProto exceeds maximum protobuf size of 2GB
                    # Fallback to serializing the model with external data.
                    onnx.save_model(  # type: ignore[attr-defined]
                        self.model_proto,
                        destination,
                        save_as_external_data=True,
                        all_tensors_to_one_file=True,
                    )
            else:
                try:
                    destination.write(self.model_proto.SerializeToString())
                except ValueError as exc:
                    raise ValueError(
                        "'destination' should be provided as a path-like string when saving a model larger than 2GB. "
                        "External tensor data will be saved alongside the model on disk."
                    ) from exc

    def save_diagnostics(self, destination: str) -> None:
        """Saves the export diagnostics as a SARIF log to the specified destination path.

        Args:
            destination: The destination to save the diagnostics SARIF log.
                It must have a `.sarif` extension.

        Raises:
            ValueError: If the destination path does not end with the `.sarif` extension.
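
        The destination below is illustrative; only the ``.sarif`` extension is
        required.

        Example::

            >>> # onnx_program.save_diagnostics("export_report.sarif")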
        """
        if not destination.endswith(".sarif"):
            message = f"'destination' must have a .sarif extension, got {destination}"
            log.fatal(message)
            raise ValueError(message)

        self.diagnostic_context.dump(destination)

    @classmethod
    def _from_failure(
        cls,
        export_exception: Exception,
        diagnostic_context: diagnostics.DiagnosticContext,
    ) -> Self:
        """
        Creates an instance of :class:`ONNXProgram` when the export process encounters a failure.

        In case of a failed export, this method is used to encapsulate the exception
        and associated diagnostic context within an :class:`ONNXProgram` instance for
        easier handling and debugging.

        Args:
            export_exception: The exception raised during the export process.
            diagnostic_context: The context associated with diagnostics during export.

        Returns:
            An instance of :class:`ONNXProgram` representing the failed ONNX program.
        """
        # Defer `import onnx` out of `import torch` path
        # https://github.com/pytorch/pytorch/issues/103764
        import onnx

        return cls(
            onnx.ModelProto(),  # type: ignore[attr-defined]
            io_adapter.InputAdapter(),
            io_adapter.OutputAdapter(),
            diagnostic_context,
            export_exception=export_exception,
        )


class FXGraphExtractor(abc.ABC):
    """Abstract interface for FX graph extractor engines.
    This class isolates FX extraction logic from the rest of the export logic.
    That allows a single ONNX exporter that can leverage different FX graphs."""

    def __init__(self) -> None:
        super().__init__()
        self.input_adapter: io_adapter.InputAdapter = io_adapter.InputAdapter()
        self.output_adapter: io_adapter.OutputAdapter = io_adapter.OutputAdapter()

    @abc.abstractmethod
    def generate_fx(
        self,
        options: ResolvedExportOptions,
        model: torch.nn.Module | Callable,
        model_args: Sequence[Any],
        model_kwargs: Mapping[str, Any],
    ) -> torch.fx.GraphModule:
        """Analyzes user ``model`` and generates an FX graph.

        Args:
            options: The export options.
            model: The user model.
            model_args: The model's positional input arguments.
            model_kwargs: The model's keyword input arguments.

        Returns:
            The generated FX Graph.
        """
        ...

    # TODO: Design the passes API
    @abc.abstractmethod
    def pre_export_passes(
        self,
        options: ResolvedExportOptions,
        original_model: torch.nn.Module | Callable,
        fx_module: torch.fx.GraphModule,
        fx_module_args: Sequence[Any],
    ):
        """Applies pre-export passes to the FX graph.

        Pre-export passes are FX-to-FX graph transformations that make the graph
        more palatable for the FX-to-ONNX conversion.
        For example, they can be used to flatten model input/output, add explicit
        casts to the graph, replace/decompose operators, functionalize the graph, etc.
        """
        ...


class Exporter:
    def __init__(
        self,
        options: ResolvedExportOptions,
        model: torch.nn.Module | Callable,
        model_args: Sequence[Any],
        model_kwargs: Mapping[str, Any],
    ):
        self.options = options
        assert self.options is not None

        self.model = model
        self.model_args = model_args
        self.model_kwargs = model_kwargs

        # TODO: https://github.com/pytorch/pytorch/issues/107714
        # NOTE: FXSymbolicTracer would fail in this assert, as it does not use `enable_fake_mode`
        from torch.onnx._internal.fx import fx_symbolic_graph_extractor

        if not isinstance(
            self.options.fx_tracer, fx_symbolic_graph_extractor.FXSymbolicTracer
        ):
            self._assert_fake_tensor_mode()

    def export(self) -> ONNXProgram:
        from torch.export._trace import (  # TODO: Prevent circular dependency
            DEFAULT_EXPORT_DYNAMO_CONFIG,
        )

        # TODO: Defer `import onnxscript` out of `import torch` path
        # https://github.com/pytorch/pytorch/issues/103764
        from torch.onnx._internal.fx import decomposition_skip

        with self.options.diagnostic_context, decomposition_skip.enable_decomposition_skips(
            self.options
        ), torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)):
            graph_module = self.options.fx_tracer.generate_fx(
                self.options, self.model, self.model_args, self.model_kwargs
            )
            # TODO: Defer `import onnxscript` out of `import torch` path
            # https://github.com/pytorch/pytorch/issues/103764
            from torch.onnx._internal.fx import fx_onnx_interpreter

            fx_interpreter = fx_onnx_interpreter.FxOnnxInterpreter(
                diagnostic_context=self.options.diagnostic_context
            )
            onnxscript_graph = fx_interpreter.run(
                fx_graph_module=graph_module,
                onnxfunction_dispatcher=self.options.onnxfunction_dispatcher,
            )

            # NOTE: Filter out initializers that are fake tensors when exporting in fake mode.
            # Otherwise, the ONNX exporter will fail: RuntimeError: basic_string::_M_construct null
            # not valid.
            # Concrete data is expected to be filled for those initializers later during `ONNXProgram.save`.
            if self.options.fake_context is not None:
                initializers_with_real_tensors: dict[str, torch.Tensor] = {}
                for (
                    initializer_name,
                    initializer,
                ) in onnxscript_graph.initializers.items():
                    if not isinstance(initializer, torch._subclasses.FakeTensor):
                        initializers_with_real_tensors[initializer_name] = initializer
                onnxscript_graph.initializers = initializers_with_real_tensors

            # Export the ONNXScript graph to an ONNX ModelProto.
            onnx_model = onnxscript_graph.to_model_proto(
                self.options.onnx_registry.opset_version,
            )

            try:
                from onnxscript import optimizer

                onnx_model = optimizer.optimize(onnx_model)
            except ImportError:
                warnings.warn(
                    "ONNXScript optimizer is not available. Skipping optimization. "
                    "Please `pip install onnxscript -U` to enable post-export optimization."
                )
            except Exception as e:
                warnings.warn(
                    "ONNXScript optimizer failed. Skipping optimization. "
                    "\n\nPLEASE REPORT A BUG AT https://github.com/microsoft/onnxscript/issues "
                    f"\n\nDetail:\n{e}"
                )

            return torch.onnx.ONNXProgram(
                onnx_model,
                self.options.fx_tracer.input_adapter,
                self.options.fx_tracer.output_adapter,
                self.options.diagnostic_context,
                fake_context=self.options.fake_context,
                model_torch=self.model,
            )

    def _assert_fake_tensor_mode(self):
        """Asserts that the model and its input do not contain fake tensors."""

        # Case 1: Model with fake inputs/weights and without enabling fake mode
        has_any_fake_tensor = pytree.tree_any(
            lambda x: isinstance(x, torch._subclasses.FakeTensor),
            (self.model_args, self.model_kwargs),
        )
        has_any_fake_param_or_buffer = False
        if isinstance(self.model, torch.nn.Module):
            has_any_fake_param_or_buffer = pytree.tree_any(
                lambda x: isinstance(x, torch._subclasses.FakeTensor),
                (self.model.parameters(), self.model.buffers()),
            )
        if (
            has_any_fake_tensor or has_any_fake_param_or_buffer
        ) and not self.options.fake_context:
            raise RuntimeError(
                "Cannot export a model with fake inputs/weights without enabling fake mode.",
            )
        # Case 2: Model with non-fake inputs/weights while fake mode is enabled
        has_any_non_fake_tensors = pytree.tree_any(
            lambda x: isinstance(x, torch.Tensor)
            and not isinstance(x, torch._subclasses.FakeTensor),
            (self.model_args, self.model_kwargs),
        )
        has_any_non_fake_param_or_buffer = False
        if isinstance(self.model, torch.nn.Module):
            has_any_non_fake_param_or_buffer = pytree.tree_any(
                lambda x: isinstance(x, torch.Tensor)
                and not isinstance(x, torch._subclasses.FakeTensor),
                (self.model.parameters(), self.model.buffers()),
            )
        if (
            has_any_non_fake_tensors or has_any_non_fake_param_or_buffer
        ) and self.options.fake_context:
            raise RuntimeError(
                "Cannot export a model with non-fake inputs/weights while fake mode is enabled.",
            )


class UnsatisfiedDependencyError(RuntimeError):
    """Raised when an ONNX exporter dependency cannot be satisfied."""

    def __init__(self, package_name: str, message: str):
        super().__init__(message)
        self.package_name = package_name


class InvalidExportOptionsError(RuntimeError):
    """Raised when the user specifies an invalid value for :class:`ExportOptions`."""


def _assert_dependencies(export_options: ResolvedExportOptions):
    opset_version = export_options.onnx_registry.opset_version

    def missing_package(package_name: str, exc_info: logging._ExcInfoType):
        message = (
            f"Please install the `{package_name}` package "
            f"(e.g. `python -m pip install {package_name}`)."
        )
        log.fatal(message, exc_info=exc_info)
        return UnsatisfiedDependencyError(package_name, message)

    def missing_opset(package_name: str):
        message = (
            f"The installed `{package_name}` does not support the specified ONNX opset "
            f"version {opset_version}. Install a newer `{package_name}` package or "
            f"specify an older opset version."
        )
        log.fatal(message)
        return UnsatisfiedDependencyError(package_name, message)

    try:
        import onnx
    except ImportError as e:
        raise missing_package("onnx", e) from e

    if onnx.defs.onnx_opset_version() < opset_version:
        raise missing_opset("onnx")

    try:
        # PyTorch runs lintrunner in CI without onnxscript installed
        import onnxscript  # type: ignore[import]
    except ImportError as e:
        raise missing_package("onnxscript", e) from e

    if not isinstance(
        onnxscript.onnx_opset.all_opsets[("", opset_version)],
        onnxscript.values.Opset,
    ):
        raise missing_opset("onnxscript")


def dynamo_export(
    model: torch.nn.Module | Callable,
    /,
    *model_args,
    export_options: ExportOptions | None = None,
    **model_kwargs,
) -> ONNXProgram | Any:
    """Export a torch.nn.Module to an ONNX graph.

    Args:
        model: The PyTorch model to be exported to ONNX.
        model_args: Positional inputs to ``model``.
        model_kwargs: Keyword inputs to ``model``.
        export_options: Options to influence the export to ONNX.

    Returns:
        An in-memory representation of the exported ONNX model.

    **Example 1 - Simplest export**
    ::

        class MyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x, bias=None):
                out = self.linear(x)
                out = out + bias
                return out


        model = MyModel()
        kwargs = {"bias": 3.0}
        args = (torch.randn(2, 2, 2),)
        onnx_program = torch.onnx.dynamo_export(model, *args, **kwargs)
        onnx_program.save("my_simple_model.onnx")

    **Example 2 - Exporting with dynamic shapes**
    ::

        # The previous model can be exported with dynamic shapes
        export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
        onnx_program = torch.onnx.dynamo_export(
            model, *args, **kwargs, export_options=export_options
        )
        onnx_program.save("my_dynamic_model.onnx")


    By printing the input's dynamic dimensions we can see the input shape is no longer (2, 2, 2)
    ::

        >>> print(onnx_program.model_proto.graph.input[0])
        name: "arg0"
        type {
          tensor_type {
            elem_type: 1
            shape {
              dim {
                dim_param: "arg0_dim_0"
              }
              dim {
                dim_param: "arg0_dim_1"
              }
              dim {
                dim_param: "arg0_dim_2"
              }
            }
          }
        }
    """

    if export_options is not None:
        resolved_export_options = (
            export_options
            if isinstance(export_options, ResolvedExportOptions)
            else ResolvedExportOptions(export_options, model=model)
        )
    else:
        resolved_export_options = ResolvedExportOptions(ExportOptions(), model=model)

    _assert_dependencies(resolved_export_options)

    try:
        from torch._dynamo import config as _dynamo_config

        with _dynamo_config.patch(do_not_emit_runtime_asserts=True):
            return Exporter(
                options=resolved_export_options,
                model=model,
                model_args=model_args,
                model_kwargs=model_kwargs,
            ).export()
    except Exception as e:
        sarif_report_path = _DEFAULT_FAILED_EXPORT_SARIF_LOG_PATH
        resolved_export_options.diagnostic_context.dump(sarif_report_path)
        message = (
            f"Failed to export the model to ONNX. Generating SARIF report at '{sarif_report_path}'. "
            "SARIF is a standard format for the output of static analysis tools. "
            "SARIF logs can be loaded in VS Code SARIF viewer extension, "
            "or SARIF web viewer (https://microsoft.github.io/sarif-web-component/). "
            f"Please report a bug on PyTorch Github: {_PYTORCH_GITHUB_ISSUES_URL}"
        )
        raise errors.OnnxExporterError(message) from e


def common_pre_export_passes(
    options: ResolvedExportOptions,
    original_model: torch.nn.Module | Callable,
    fx_module: torch.fx.GraphModule,
    fx_module_args: Sequence[Any],
):
    # TODO: Import here to prevent circular dependency
    from torch.onnx._internal.fx import analysis, passes

    diagnostic_context = options.diagnostic_context

    # Apply decomposition table to the input graph.
    module = passes.Decompose(
        diagnostic_context,
        fx_module,
        options.decomposition_table,
        enable_dynamic_axes=options.dynamic_shapes,
        allow_fake_constant=options.fake_context is not None,
    ).run(*fx_module_args)

    # ONNX does not support views and mutations.
    # Functionalize to get a semantically equivalent graph without mutations.
    module = passes.Functionalize(
        diagnostic_context,
        module,
        enable_dynamic_axes=options.dynamic_shapes,
        allow_fake_constant=options.fake_context is not None,
    ).run(*fx_module_args)

    # Input mutations are detected and distilled after the `Functionalize` pass.
    # Remove them since ONNX inference does not need them.
    module = passes.RemoveInputMutation(diagnostic_context, module).run(*fx_module_args)

    # ONNX does not support the concept of (implicit) type promotion.
    # Insert type casts explicitly where needed.
    module = passes.InsertTypePromotion(diagnostic_context, module).run()

    analysis.UnsupportedFxNodesAnalysis(
        diagnostic_context, module, options.onnxfunction_dispatcher
    ).analyze(infra.levels.ERROR)

    if isinstance(original_model, torch.nn.Module):
        module = passes.RestoreParameterAndBufferNames(
            diagnostic_context, module, original_model
        ).run()

    # This operation should be invoked as the last pre-export pass.
    # See [NOTE: Modularize pass ordering]
    module = passes.Modularize(diagnostic_context, module).run()

    # ONNX does not support None inputs. During graph building, all None inputs
    # are removed. Here we register this step to the input adapter.
    options.fx_tracer.input_adapter.append_step(io_adapter.RemoveNoneInputStep())

    # NOTE: temp workaround for https://github.com/pytorch/pytorch/issues/99534
    # Dynamo doesn't support non-tensor inputs.
    options.fx_tracer.input_adapter.append_step(io_adapter.RemoveNonTensorInputStep())

    # ONNX does not support complex inputs. During graph building, all complex inputs
    # are converted to real representation inputs. Here we register this step to the
    # input/output adapter.
    options.fx_tracer.input_adapter.append_step(
        io_adapter.ConvertComplexToRealRepresentationInputStep()
    )

    # ONNX can't represent collection types (e.g., dictionary, tuple of tuple of
    # tensor, etc.), so we flatten the collection and register each element as an output.
    options.fx_tracer.output_adapter.append_step(io_adapter.FlattenOutputStep())

    # Output post-processing steps should happen after `FlattenOutputStep`.
    options.fx_tracer.output_adapter.append_step(
        io_adapter.ConvertComplexToRealRepresentationOutputStep()
    )

    return module
