# mypy: allow-untyped-defs
import dataclasses
import importlib
import logging
import os
from typing import (
    Any,
    Callable,
    Dict,
    Final,
    List,
    Mapping,
    Optional,
    Sequence,
    Set,
    Tuple,
    TYPE_CHECKING,
    Union,
)
from typing_extensions import TypeAlias

import torch
import torch._C
import torch._ops
import torch._prims.executor
import torch.fx
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx._compatibility import compatibility
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch.fx.passes.operator_support import OperatorSupport
from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
from torch.utils import _pytree


if TYPE_CHECKING:
    import onnx
    import onnxruntime
    from onnxruntime.capi import _pybind_state as ORTC

    import torch.onnx
    import torch.onnx._internal
    import torch.onnx._internal._exporter_legacy
    import torch.onnx._internal.diagnostics
    import torch.onnx._internal.fx.decomposition_table
    import torch.onnx._internal.fx.passes  # noqa: TCH004


_SUPPORT_ONNXRT: Optional[bool] = None

__all__ = [
    "is_onnxrt_backend_supported",
    "torch_compile_backend",
    "OrtExecutionProvider",
    "OrtBackendOptions",
    "OrtBackend",
]


def is_onnxrt_backend_supported() -> bool:
    """Returns ``True`` if ONNX Runtime dependencies are installed and usable
    to support TorchDynamo backend integration; ``False`` otherwise.

    Example::

        # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
        >>> import torch
        >>> if torch.onnx.is_onnxrt_backend_supported():
        ...     @torch.compile(backend="onnxrt")
        ...     def f(x):
        ...             return x * x
        ...     print(f(torch.randn(10)))
        ... else:
        ...     print("pip install onnx onnxscript onnxruntime")
        ...
    """
    global _SUPPORT_ONNXRT

    if _SUPPORT_ONNXRT is None:
        # `onnxruntime` might import a lot of other runtime packages,
        # e.g. apex, deepspeed, transformers.
        # So we lazily import onnxruntime to avoid a possible circular import.
        try:
            importlib.import_module("onnxruntime")
            importlib.import_module("onnxruntime.capi._pybind_state")

            # This is not used directly in DORT, but it is needed by the
            # underlying exporter, so we still need to check that it exists.
            importlib.import_module("onnxscript")

            import torch.onnx  # noqa: F401
            import torch.onnx._internal  # noqa: F401
            import torch.onnx._internal._exporter_legacy  # noqa: F401
            import torch.onnx._internal.diagnostics  # noqa: F401
            from torch.onnx._internal.fx import (  # noqa: F401
                decomposition_table,
                fx_onnx_interpreter,
                passes,
                type_utils,
            )

            _SUPPORT_ONNXRT = True
        except ImportError:
            _SUPPORT_ONNXRT = False

    return _SUPPORT_ONNXRT


_dumped_onnx_model: Dict[str, int] = {}


def _dump_onnx_model(
    model_string: bytes, graph_module: Optional[torch.fx.GraphModule] = None
) -> str:
    """Stores the ONNX model into a file.

    The filename is "{ONNXRT_DUMP_PATH}{N}.onnx", where *N* is the number of
    files already stored with this prefix.
    If graph_module is not None, the graph is also stored as text in a file
    with the same name but a ".txt" extension.
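
    Illustrative sketch (the prefix below is hypothetical)::

        >>> # xdoctest: +SKIP("writes files to disk")
        >>> import os
        >>> os.environ["ONNXRT_DUMP_PATH"] = "/tmp/dort_model_"
        >>> # The next two dumps would be written to "/tmp/dort_model_0.onnx" and
        >>> # "/tmp/dort_model_1.onnx", plus matching ".txt" files whenever a
        >>> # graph_module is provided.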
    """
    prefix = os.environ.get("ONNXRT_DUMP_PATH", None)
    if not prefix:
        return ""
    n = _dumped_onnx_model.get(prefix, -1) + 1
    filename = f"{prefix}{n}.onnx"
    with open(filename, "wb") as f:
        f.write(model_string)
    _dumped_onnx_model[prefix] = n
    if graph_module is not None:
        filename_txt = f"{prefix}{n}.txt"
        with open(filename_txt, "w", encoding="utf-8") as f:
            f.write(str(graph_module.graph))
    return filename


def _infer_default_eps() -> Sequence[str]:
    # TODO: select a good default based on the capabilities of the host,
    # e.g. DML on Windows, etc.
    return ["CPUExecutionProvider"]


def _nvtx_range_push(name: str):
    """If PyTorch is installed with CUDA support, this starts an NVTX range.

    See torch.cuda.nvtx.range_push's documentation for more details.
    """
    if torch.cuda.is_available():
        torch.cuda.nvtx.range_push(name)


def _nvtx_range_pop():
    """If PyTorch is installed with CUDA support, this terminates the current NVTX range.

    See torch.cuda.nvtx.range_pop's documentation for more details.
    """
    if torch.cuda.is_available():
        torch.cuda.nvtx.range_pop()


def _get_ort_device_type(device_type: str):
    from onnxruntime.capi import _pybind_state as ORTC

    if device_type == "cuda":
        return ORTC.OrtDevice.cuda()
    if device_type == "cpu":
        return ORTC.OrtDevice.cpu()
    # PyTorch's "maia" device is mapped to the NPU OrtDevice type.
    if device_type == "maia":
        return ORTC.OrtDevice.npu()
    raise ValueError("Unsupported device type: " + device_type)


logger = logging.getLogger(__name__)
# Uncomment the following lines to print out development info.
# logging.basicConfig(level=logging.WARNING)
# logger.setLevel(logging.WARNING)


class OrtOperatorSupport(OperatorSupport):
    """Operator support for the ONNXRuntime backend.

    It makes a two-level support decision: one level via support_dict and the
    other via extra_support_dict. The logic for support_dict is implemented in
    OrtOperatorSupport, while extra_support_dict is handled by
    OperatorSupport.is_node_supported.
    """

    def __init__(self, support_dict: Set[Any], extra_support_dict: Dict[str, Any]):
        # Use extra_support_dict[op_name] = None to indicate
        # we support op_name with all input types. Otherwise,
        # see support_dict (type: SupportDict) in operator_support.py
        # for specifying supported types.
        super().__init__(extra_support_dict)
        self._onnx_support_dict = support_dict

    def is_node_supported(
        self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
    ) -> bool:
        # OperatorSupport.is_node_supported returns True for non-callable nodes.
        # Since ORT can't execute them, we return False here to override the base
        # behavior.
        if node.op not in CALLABLE_NODE_OPS:
            return False
        # This is the only place to decide whether an aten op is supported.
        if node.op == "call_function" and node.target in self._onnx_support_dict:
            logger.info(
                "support_dict supports node.target: %s (type: %s)",
                node.target,
                type(node.target),
            )
            return True
        # If node.target is not in support_dict, we still want to check whether
        # torch.jit.script can convert it to an ONNX equivalent. Let's use the base
        # mechanism to do this. See extra_support_dict for supported ops.
        if super().is_node_supported(submodules, node):
            logger.info(
                "extra_support_dict supports node.target: %s (type: %s)",
                node.target,
                type(node.target),
            )
            return True
        logger.warning(
            "support_dict and extra_support_dict don't support node.target: %s (type: %s)",
            node.target,
            type(node.target),
        )
        return False


def _move_placeholder_to_front(graph_module: torch.fx.GraphModule) -> None:
    """
    In a torch.fx.Graph, a placeholder is a special assignment node. If it's not
    executed at the beginning, it could overwrite values computed by upstream
    nodes.
    """

    graph = graph_module.graph
    placeholders = []
    first_not_placeholder = None
    for node in graph.nodes:
        if node.op == "placeholder":
            placeholders.append(node)
        if first_not_placeholder is None and node.op != "placeholder":
            first_not_placeholder = node
    if first_not_placeholder is None:
        return
    for placeholder in placeholders:
        first_not_placeholder.prepend(placeholder)


def _infer_ep_from_device(*args) -> Tuple[str, ...]:
    """Return execution providers (CUDA or CPU EP) inferred from the devices
    of the given arguments.
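
    Hedged illustration (assumes a CUDA-enabled build)::

        >>> # xdoctest: +SKIP("requires CUDA")
        >>> import torch
        >>> x = torch.randn(2, device="cuda")
        >>> y = torch.randn(2)
        >>> _infer_ep_from_device(x, y)
        ('CUDAExecutionProvider', 'CPUExecutionProvider')
    """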
    eps = []
    for arg in args:
        if hasattr(arg, "device"):
            device = arg.device
            if device.type == "cuda":
                eps.append("CUDAExecutionProvider")
            elif device.type == "cpu":
                eps.append("CPUExecutionProvider")
    return tuple(eps)


def _extract_graph_module_inputs(graph_module: torch.fx.GraphModule) -> Tuple[Any, ...]:
    placeholders = []
    for node in graph_module.graph.nodes:
        if node.op == "placeholder":
            if hasattr(node, "meta") and "val" in node.meta:
                assert isinstance(node.meta["val"], torch.Tensor)
            placeholders.append(node)
    return tuple(placeholders)


def _extract_graph_module_outputs(graph_module: torch.fx.GraphModule) -> Any:
    """Return the outputs (i.e., the output node's arguments) of this torch.fx.GraphModule."""
    for node in graph_module.graph.nodes:
        if node.op == "output":
            # The output node is unique. Retrieve the output values from its
            # input list and return them.
            return node.args[0]
    raise ValueError("No output node found in this torch.fx.GraphModule.")


def _infer_ep_from_graph_module(graph_module: torch.fx.GraphModule) -> Tuple[str, ...]:
    """Return execution providers inferred from the devices (GPU or CPU) of this torch.fx.GraphModule's outputs."""
    flattened_output_args, _ = _pytree.tree_flatten(
        _extract_graph_module_outputs(graph_module)
    )
    # Output arguments with example value (type: torch.Tensor) in the `graph_module`.
    selected_output_args = [
        output_arg.meta["val"]
        for output_arg in flattened_output_args
        # output_arg must have a tensor for its device information.
        # Otherwise, skip it.
        if (hasattr(output_arg, "meta") and "val" in output_arg.meta)
    ]
    return _infer_ep_from_device(*selected_output_args)


def _sort_eps(eps: Tuple[str, ...]) -> Tuple[str, ...]:
    """Sort execution providers in eps based on pre-set priority."""

    def get_execution_provider_priority(ep: str) -> int:
        if ep == "CPUExecutionProvider":
            # Lowest priority.
            return 2
        if ep == "CUDAExecutionProvider":
            # Higher priority than CPU but lower than
            # other specialized EPs.
            return 1
        # Highest priority.
        return 0

    unique_eps = set(eps)
    return tuple(sorted(unique_eps, key=get_execution_provider_priority, reverse=True))


def _get_onnx_devices(
    values: Tuple[
        Union[
            torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool
        ],
        ...,
    ],
) -> Tuple["ORTC.OrtDevice", ...]:
    from onnxruntime.capi import _pybind_state as ORTC

    def _device_id_or_zero(device_id: int) -> int:
        return device_id or 0

    def _map_tensor_or_sym_to_device(
        value: Union[
            torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool
        ],
    ) -> "ORTC.OrtDevice":
        if isinstance(value, torch.Tensor):
            return ORTC.OrtDevice(
                _get_ort_device_type(value.device.type),
                ORTC.OrtDevice.default_memory(),
                _device_id_or_zero(value.device.index),
            )
        elif isinstance(
            value, (torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool)
        ):
            return ORTC.OrtDevice(
                _get_ort_device_type("cpu"), ORTC.OrtDevice.default_memory(), 0
            )
        else:
            raise ValueError("Unsupported value type: " + str(type(value)))

    if len(values) > 0:
        ort_devices = tuple(_map_tensor_or_sym_to_device(value) for value in values)
        return ort_devices
    else:
        return (_map_tensor_or_sym_to_device(1),)


def _get_ortvalues_from_torch_tensors(
    tensors: Tuple[torch.Tensor, ...], devices: Tuple["ORTC.OrtDevice", ...]
) -> "ORTC.OrtValueVector":
    from onnxruntime.capi import _pybind_state as ORTC

    from torch.onnx._internal.fx.type_utils import _TORCH_DTYPE_TO_NUMPY_DTYPE

    ortvalues = ORTC.OrtValueVector()
    ortvalues.reserve(len(tensors))
    dtypes = []
    shapes = []
    data_ptrs = []

    for tensor in tensors:
        dtypes.append(_TORCH_DTYPE_TO_NUMPY_DTYPE[tensor.dtype])
        shapes.append(tensor.size())
        data_ptrs.append(tensor.data_ptr())
    ortvalues.push_back_batch(tensors, data_ptrs, dtypes, shapes, devices)
    return ortvalues


def _to_real_tensor(tensor: FakeTensor) -> torch.Tensor:
    if tensor.is_sparse:
        raise ValueError("sparse tensor is not yet supported.")
    out = torch.empty(tensor.size(), dtype=tensor.dtype, device=tensor.device)
    return out


def _adjust_scalar_from_fx_to_onnx(
    dynamo_value: Union[
        torch.Tensor,
        int,
        float,
        bool,
    ],
    value_info: "onnx.ValueInfoProto",  # type: ignore[name-defined]
) -> torch.Tensor:
    """Helper function to wrap PyTorch variables as torch.Tensor"""
    if (
        isinstance(dynamo_value, torch.Tensor)
        and len(value_info.type.tensor_type.shape.dim) == 0
        and dynamo_value.shape == (1,)
    ):
        # ONNX expects a scalar with an empty shape.
        # In contrast, PyTorch usually allows implicit
        # conversion between shape=() and shape=(1,).
        #
        # Below, PyTorch's shape (1,) is reshaped to ().
        return torch.squeeze(dynamo_value)
    elif isinstance(dynamo_value, int):
        return torch.tensor(dynamo_value, dtype=torch.int64)
    elif isinstance(dynamo_value, float):
        return torch.tensor(dynamo_value, dtype=torch.float32)
    elif isinstance(dynamo_value, bool):
        return torch.tensor(dynamo_value, dtype=torch.bool)
    else:
        assert isinstance(dynamo_value, torch.Tensor)
        return dynamo_value.contiguous()


def _adjust_scalar_from_onnx_to_fx(
    tensor: torch.Tensor,
    prim_value: Union[
        torch.Tensor,
        torch.SymInt,
        int,
        torch.SymFloat,
        float,
        torch.SymBool,
        bool,
    ],
) -> Union[
    torch.Tensor,
    int,
    float,
    bool,
]:
    """Helper function to map an ORT-produced torch.Tensor back to a PyTorch variable (scalar or tensor)."""
    assert isinstance(tensor, torch.Tensor), "ORT's output must be tensor."
    if isinstance(
        prim_value,
        (torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool),
    ):
        # Convert the tensor back to a scalar to match Dynamo's expectation.
        return tensor.item()
    return tensor


def _run_onnx_session_with_ortvaluevector(
    sess: "onnxruntime.InferenceSession",
    input_names: Tuple[str, ...],
    inputs: Tuple[torch.Tensor, ...],
    input_devices: Tuple["ORTC.OrtDevice", ...],
    output_names: Tuple[str, ...],
    outputs: Tuple[torch.Tensor, ...],
    output_devices: Tuple["ORTC.OrtDevice", ...],
    preallocate_output: bool,
    input_value_infos: Tuple["onnx.ValueInfoProto", ...],  # type: ignore[name-defined]
    normalized_prim_outputs: Tuple[
        Union[
            torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool
        ],
        ...,
    ],
) -> Tuple[Union[torch.Tensor, int, float, bool], ...]:
    import onnxruntime
    from onnxruntime.capi import _pybind_state as ORTC

    _nvtx_range_push("contiguous")
    inputs = tuple(
        _adjust_scalar_from_fx_to_onnx(arg, value_info)
        for arg, value_info in zip(inputs, input_value_infos)
    )
    _nvtx_range_pop()

    _nvtx_range_push("push_back_batch")
    ort_inputs = _get_ortvalues_from_torch_tensors(inputs, input_devices)

    # Pre-allocate output PyTorch tensors and use their buffers (bound to the torch
    # device) for the output OrtValues. Because these OrtValues are not allocated
    # and owned by ORT, they do not need to be converted back to torch tensors
    # (which would transfer ownership).
    if preallocate_output:
        pth_outputs = tuple(
            _to_real_tensor(t) if isinstance(t, FakeTensor) else t for t in outputs
        )
        ort_outputs = _get_ortvalues_from_torch_tensors(pth_outputs, output_devices)
    else:
        ort_outputs = ORTC.OrtValueVector()
    _nvtx_range_pop()

    _nvtx_range_push("run_with_ortvaluevector")
    run_options = onnxruntime.RunOptions()
    run_options.add_run_config_entry("disable_synchronize_execution_providers", "1")
    sess.run_with_ortvaluevector(
        run_options, input_names, ort_inputs, output_names, ort_outputs, output_devices
    )
    _nvtx_range_pop()

    # Post-processing step:
    #  wrap ORT's outputs in the schema represented by
    #  `prim_output` (obtained by running the original
    #  torch.fx.GraphModule).
    if preallocate_output:
        # Profile the ORT-to-PyTorch type cast below.
        _nvtx_range_push("after run_with_ortvaluevector")
        # Outputs are stored in the pre-allocated torch.Tensors' memory,
        # so this case doesn't need to convert OrtValue to torch.Tensor.
        pth_outputs = tuple(
            _adjust_scalar_from_onnx_to_fx(onnx_output, prim_output)  # type: ignore[misc]
            for onnx_output, prim_output in zip(pth_outputs, normalized_prim_outputs)
        )
        _nvtx_range_pop()
        return pth_outputs
    else:
        # Profile the two ORT-to-PyTorch type casts below.
        _nvtx_range_push("after run_with_ortvaluevector")
        # Map OrtValue to torch.Tensor.
        pth_outputs = onnxruntime.training.ortmodule._utils._ortvalues_to_torch_tensor(
            ort_outputs
        )
        # Change some torch.Tensors to int, float, bool.
        pth_outputs = tuple(
            _adjust_scalar_from_onnx_to_fx(onnx_output, prim_output)  # type: ignore[misc]
            for onnx_output, prim_output in zip(pth_outputs, normalized_prim_outputs)
        )
        _nvtx_range_pop()
        return pth_outputs


def _run_onnx_session_with_fetch(
    sess: "onnxruntime.InferenceSession",
    input_names: Tuple[str, ...],
    inputs: Tuple[torch.Tensor, ...],
    input_devices: Tuple["ORTC.OrtDevice", ...],
    output_names: Tuple[str, ...],
    outputs: Tuple[torch.Tensor, ...],
    output_devices: Tuple["ORTC.OrtDevice", ...],
    preallocate_output: bool,
    input_value_infos: Tuple["onnx.ValueInfoProto", ...],  # type: ignore[name-defined]
    normalized_prim_outputs: Tuple[
        Union[
            torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool
        ],
        ...,
    ],
) -> Tuple[Union[torch.Tensor, int, float, bool], ...]:
    import onnxruntime

    inputs = tuple(
        _adjust_scalar_from_fx_to_onnx(arg, value_info)
        for arg, value_info in zip(inputs, input_value_infos)
    )
    feed = {
        name: onnxruntime.OrtValue.ortvalue_from_numpy(tensor.cpu().numpy())
        for name, tensor in zip(input_names, inputs)
    }
    ort_outputs = sess.run(output_names, feed)
    pth_outputs = tuple(
        _adjust_scalar_from_onnx_to_fx(
            torch.from_numpy(value),
            prim_output,
        )
        for value, prim_output in zip(ort_outputs, normalized_prim_outputs)
    )
    return pth_outputs


class OrtExecutionInfoPerSession:
    """Information required to execute a torch.fx.GraphModule using an onnxruntime.InferenceSession."""

    def __init__(
        self,
        session: "onnxruntime.InferenceSession",
        input_names: Tuple[str, ...],
        input_value_infos: Tuple["onnx.ValueInfoProto", ...],  # type: ignore[name-defined]
        output_names: Tuple[str, ...],
        output_value_infos: Tuple["onnx.ValueInfoProto", ...],  # type: ignore[name-defined]
        input_devices: Tuple["ORTC.OrtDevice", ...],
        output_devices: Tuple["ORTC.OrtDevice", ...],
        example_outputs: Union[Tuple[torch.Tensor, ...], torch.Tensor],
    ):
        # Carrier of the ONNX model and its executor.
        self.session: onnxruntime.InferenceSession = session
        # For the ONNX model stored in self.session, self.input_names[i] is the
        # name of the i-th positional input.
        self.input_names: Tuple[str, ...] = input_names
        # self.input_names[i]'s type information is stored in self.input_value_infos[i].
        self.input_value_infos: Tuple[onnx.ValueInfoProto, ...] = input_value_infos  # type: ignore[name-defined]
        # Similar to self.input_names, but for outputs.
        self.output_names: Tuple[str, ...] = output_names
        # Similar to self.input_value_infos but for outputs.
        self.output_value_infos: Tuple[onnx.ValueInfoProto, ...] = output_value_infos  # type: ignore[name-defined]
        # For the ONNX model stored in self.session, self.input_devices[i] is the
        # i-th positional input's device.
        self.input_devices: Tuple[ORTC.OrtDevice, ...] = input_devices
        # Similar to self.input_devices, but for outputs.
        self.output_devices: Tuple[ORTC.OrtDevice, ...] = output_devices
        # These are the outputs of executing the original torch.fx.GraphModule with example inputs
        # (i.e., the args passed into OrtBackend._ort_acclerated_call).
        self.example_outputs: Union[Tuple[torch.Tensor, ...], torch.Tensor] = (
            example_outputs
        )

    def is_supported(self, *args):
        from torch.onnx._internal.fx.type_utils import (
            _TORCH_DTYPE_TO_ONNX_TENSOR_ELEMENT_TYPE,
            from_python_type_to_onnx_tensor_element_type,
        )

        # Compare the args against the input schema of the ONNX model and
        # return whether every arg matches.
        if len(args) != len(self.input_value_infos):
            return False
        for arg, value_info in zip(args, self.input_value_infos):
            if not isinstance(arg, (torch.Tensor, float, int)):
                return False

            # Check Python scalars such as int, float, and bool.
            if isinstance(arg, (int, float, bool)):
                # Map, e.g., float to onnx.TensorProto.FLOAT.
                onnx_dtype = from_python_type_to_onnx_tensor_element_type(type(arg))
                if onnx_dtype != value_info.type.tensor_type.elem_type:
                    return False
                if len(value_info.type.tensor_type.shape.dim) != 0:
                    return False
                continue

            # Check tensor.
            onnx_dtype = _TORCH_DTYPE_TO_ONNX_TENSOR_ELEMENT_TYPE[arg.dtype]
            if onnx_dtype != value_info.type.tensor_type.elem_type:
                return False
            for dim, onnx_dim in zip(arg.shape, value_info.type.tensor_type.shape.dim):
                if isinstance(dim, int) and (
                    onnx_dim.dim_value == dim or onnx_dim.dim_param
                ):
                    continue
                elif isinstance(dim, torch.SymInt) and onnx_dim.dim_param:
                    continue
                else:
                    return False
        return True


@dataclasses.dataclass
class OrtExecutionInfoForAllGraphModules:
    def __init__(self) -> None:
        # All sessions (and their related information) created by exporting the same GraphModule
        # with different inputs.
        self.execution_info_per_graph_module: Dict[
            torch.fx.GraphModule, List[OrtExecutionInfoPerSession]
        ] = {}

    def search_reusable_session_execution_info(
        self, graph_module: torch.fx.GraphModule, *args
    ):
        if graph_module not in self.execution_info_per_graph_module:
            return None
        # All execution information for ONNX models exported from the same `graph_module`
        # with different inputs.
        candidates = self.execution_info_per_graph_module[graph_module]

        for candidate in candidates:
            if candidate.is_supported(*args):
                # Returns the first session that accepts this input schema.
                return candidate
        # No reusable session found.
        return None

    def cache_session_execution_info(
        self, graph_module: torch.fx.GraphModule, info: OrtExecutionInfoPerSession
    ):
        if graph_module not in self.execution_info_per_graph_module:
            self.execution_info_per_graph_module[graph_module] = [info]
        else:
            self.execution_info_per_graph_module[graph_module].append(info)


OrtExecutionProvider: TypeAlias = Union[str, Tuple[str, Mapping[str, Any]]]
"""Either the name of an ONNX Runtime execution provider as a string or
a 2-tuple of the name and a dictionary of execution provider options.

Examples::

    >>> "CPUExecutionProvider"

    >>> ("CUDAExecutionProvider", {"device_id": 3})

"""


@dataclasses.dataclass(frozen=True)
@compatibility(is_backward_compatible=False)
class OrtBackendOptions:
    """Options for constructing an ``OrtBackend``, the ONNX Runtime
    backend (``"onnxrt"``) for ``torch.compile``.

    Example::

        >>> @torch.compile(
        ...     backend="onnxrt",
        ...     options=torch.onnx._OrtBackendOptions(...),
        ... )
        ... def ort_function(x):
        ...     return x ** x
    """

    preferred_execution_providers: Optional[Sequence[OrtExecutionProvider]] = None
    """An optional sequence of execution providers to be prioritized ahead of any
    execution providers that may be inferred (see ``infer_execution_providers``).
    """

    infer_execution_providers: bool = True
    """Whether to infer an execution provider from the ``torch.device`` bound to inputs or found in the graph."""

    default_execution_providers: Optional[Sequence[OrtExecutionProvider]] = None
    """The default fallback execution providers. If not specified, one will be
    selected based on the host environment (most likely ``"CPUExecutionProvider"``).
    """


    # preallocate_output allows pre-allocating the output torch.Tensor buffers and feeding them to
    # InferenceSession, in order to avoid the internal allocation of output buffers in InferenceSession.
    # If an output OrtValue returned from InferenceSession is allocated internally, it needs to be
    # converted to a torch.Tensor for the return value, and that torch.Tensor should take ownership.
    # When a custom torch device is used with a custom aten allocator, the conversion from OrtValue to
    # torch.Tensor must be supported; this is currently done through dlpack, which might not support a
    # custom torch device. This can be avoided by pre-allocating the output buffers with the custom aten
    # allocator and letting InferenceSession use them without taking any ownership.
    # TODO(wschin): Make this an inference-session-level flag.
    # See https://github.com/pytorch/pytorch/issues/106869.
    preallocate_output: bool = False
    """If ``True``, allocate memory for ONNX Runtime's outputs on the PyTorch side."""

    use_aot_autograd: bool = True
    """Whether to wrap the ``OrtBackend`` with TorchDynamo's aot_autograd backend
    to support training (i.e., backward graphs are also sent to ``OrtBackend``).

    Symbolic execution is used to capture the forward and backward passes as a single graph.
    Then, a selected graph partition algorithm (``min_cut_rematerialization_partition``) is used
    to split the entire graph into a forward sub-graph and a backward sub-graph. Finally, both
    sub-graphs are compiled by ``OrtBackend``.
    """

    export_options: Optional["torch.onnx.ExportOptions"] = None
    """Options for the TorchDynamo-based ONNX exporter used by the ``OrtBackend``."""

    ort_session_options: Optional["onnxruntime.SessionOptions"] = None
    """Options for the ``onnxruntime.InferenceSession`` used by the ``OrtBackend``."""

    pre_ort_model_transforms: Optional[  # type: ignore[name-defined]
        Sequence[Callable[["onnx.ModelProto"], None]]
    ] = None
    """A list of graph transforms to be applied to the ONNX model before it
    is fed to ONNXRuntime's InferenceSession.
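
    Hedged sketch of a transform (``strip_doc_strings`` below is hypothetical)::

        >>> # xdoctest: +SKIP("requires onnx and onnxruntime")
        >>> def strip_doc_strings(model_proto):
        ...     # Modify the onnx.ModelProto in place, as this option expects.
        ...     for node in model_proto.graph.node:
        ...         node.doc_string = ""
        >>> options = OrtBackendOptions(pre_ort_model_transforms=[strip_doc_strings])
    """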


@compatibility(is_backward_compatible=False)
class OrtBackend:
    """A backend that compiles (sub-)graphs in a torch.fx.GraphModule into onnxruntime.InferenceSession calls.

    The compiler entry point is OrtBackend.compile, which
        1. partitions the original graph into supported sub-graphs (type: torch.fx.GraphModule) and unsupported
           sub-graphs.
        2. For each supported sub-graph, replaces its _wrapped_call function with _ort_acclerated_call.
        3. Inside _ort_acclerated_call, creates an onnxruntime.InferenceSession and calls it to execute the sub-graph.
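
    Hedged usage sketch (``my_model`` is a hypothetical callable or nn.Module;
    requires onnx, onnxscript, and onnxruntime to be installed)::

        >>> # xdoctest: +SKIP("requires onnxruntime")
        >>> backend = OrtBackend(OrtBackendOptions())
        >>> compiled_model = torch.compile(my_model, backend=backend)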
    """

    def __init__(self, options: Optional[OrtBackendOptions] = None):
        from onnxruntime.capi import _pybind_state as ORTC

        import torch.onnx
        import torch.onnx._internal._exporter_legacy
        import torch.onnx._internal.fx.decomposition_table

        self._options: Final = OrtBackendOptions() if options is None else options

        # options.export_options contains information shared between the exporter and DORT.
        # For example, they should use the same decomposition table when
        #  1. capturing the FX graph in torch.compile (see how we create aot_ort in register_backend.py)
        #  2. calling the exporter's API to convert `torch.fx.GraphModule` to an ONNX model
        #     (see the onnxfunction_dispatcher passed to FxOnnxInterpreter.run below).
        #
        # Convert the user-facing option to the internal option used by the ONNX exporter
        # to access required information.
        # Some useful fields:
        # - The decomposition table for decomposing FX operators in the exporter is
        #   self._resolved_onnx_exporter_options.decomposition_table.
        # - self._resolved_onnx_exporter_options.onnx_registry records what
        #   aten/prim ops are supported by the exporter and their exporters (type: callable).
        self._resolved_onnx_exporter_options = (
            torch.onnx._internal._exporter_legacy.ResolvedExportOptions(
                torch.onnx.ExportOptions()
                if self._options.export_options is None
                else self._options.export_options
            )
        )

        #  Given DORT's computation flow:
        #   1. OrtOperatorSupport uses support_dict and extra_support_dict to select operators
        #      and send them to DORT.
        #   2. Then, DORT exports the selected sub-graphs into ONNX.
        #   3. Finally, DORT calls ORT to do the computation.
        #  OrtOperatorSupport and create_onnx_friendly_decomposition_table(...)
        #  must use the same support_dict. If the support_dict here contains something not
        #  supported by the exporter, the exporter will fail in step 2 since the selected graphs
        #  may contain unsupported operators such as aten::_who_you_are.
        #  This restriction is automatically satisfied since DORT and the exporter share the same
        #  self._resolved_onnx_exporter_options.
        support_dict = torch.onnx._internal.fx.decomposition_table._create_onnx_supports_op_overload_table(
            self._resolved_onnx_exporter_options.onnx_registry
        )

        extra_support_dict: Dict[str, Any] = {
            "getattr": None,
            # To send operator.getitem to ORT, add the corresponding string
            # recognized by PyTorch's OperatorSupport class.
            "_operator.getitem": None,
            # To send operator.mul to ORT, add the corresponding string
            # recognized by PyTorch's OperatorSupport class.
            "_operator.mul": None,
            "_operator.add": None,
            "_operator.sub": None,
        }

        self._supported_ops = OrtOperatorSupport(support_dict, extra_support_dict)
        # TODO(wschin): this is a naive cache implementation without proper guards.
        # See https://github.com/pytorch/pytorch/issues/106868.
        self._partitioner_cache: Dict[torch.fx.GraphModule, torch.fx.GraphModule] = {}
        # Conceptually, this field is a two-layer dictionary
        #   GraphModule 0
        #     ONNX Model 0 (with ORT InferenceSession and related information. type: OrtExecutionInfoPerSession)
        #     ONNX Model 1
        #     ...
        #   GraphModule 1
        #     ONNX Model 2 (with ORT InferenceSession and related information. type: OrtExecutionInfoPerSession)
        #     ONNX Model 3
        #     ...
        #   ...
        # which caches all previous compilation results so that we can reuse them.
        # ONNX Model 0 and 1 are exported from the same GraphModule 0 but with different inputs
        # (e.g., tensors with different ranks). GraphModule 0 and GraphModule 1 are different
        # graphs captured by Dynamo and sent to OrtBackend.compile.
        self._all_ort_execution_info = OrtExecutionInfoForAllGraphModules()

        self._assert_allclose_to_baseline = False

        self.execution_count = 0

        # Function which invokes ORT to do the real computation.
        self.run = (
            _run_onnx_session_with_ortvaluevector
            if hasattr(ORTC.OrtValueVector, "push_back_batch")
            else _run_onnx_session_with_fetch
        )

    def _select_eps(
        self, graph_module: torch.fx.GraphModule, *args
    ) -> Sequence[Tuple[str, Mapping[str, Any]]]:
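        """Select execution providers for the ONNX Runtime session, in order:
        explicitly preferred EPs, then EPs inferred from the input arguments or
        the graph, then the default fallback EPs.

        Hedged illustration: with a CUDA input tensor and default options, this
        returns ``[("CUDAExecutionProvider", {}), ("CPUExecutionProvider", {})]``.
        """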
        inferred_eps: Tuple[str, ...] = ()
        if self._options.infer_execution_providers:
            if eps_from_args := _infer_ep_from_device(*args):
                # If user feeds CUDA tensor as input argument,
                # we want to use CUDA EP.
                # Thus, `eps_from_args` (deduced from input arguments)
                # has highest priority.
                inferred_eps = eps_from_args
            elif eps_from_graph_module := _infer_ep_from_graph_module(graph_module):
                # If there is no EP in input arguments, we deduce EP from
                # graph_module's outputs. Those outputs may come from
                # FakeTensorProp or Dynamo's built-in symbolic shape inference.
                inferred_eps = eps_from_graph_module

        selected_eps = []

        for ep in (
            *(self._options.preferred_execution_providers or []),
            *_sort_eps(inferred_eps),
            *(self._options.default_execution_providers or _infer_default_eps()),
        ):
            if isinstance(ep, str):
                ep = (ep, {})
            elif isinstance(ep, tuple) and ep[1] is None:
                ep = (ep[0], {})
            if ep is not None and ep not in selected_eps:
                selected_eps.append(ep)

        return selected_eps

    def _ort_acclerated_call(self, graph_module: torch.fx.GraphModule, *args, **kwargs):
        """This function replaces GraphModule._wrapped_call in the compiled model.

        _wrapped_call is the underlying implementation of the forward method. Replacing
        it means we delegate the computation to _ort_acclerated_call and therefore to
        onnxruntime.InferenceSession.
        """
        import onnxruntime

        from torch.onnx._internal.fx import fx_onnx_interpreter, passes

        cached_execution_info_per_session = (
            self._all_ort_execution_info.search_reusable_session_execution_info(
                graph_module, *args
            )
        )
        if cached_execution_info_per_session:
            onnx_session = cached_execution_info_per_session.session
            input_names = cached_execution_info_per_session.input_names
            output_names = cached_execution_info_per_session.output_names
            input_value_infos = cached_execution_info_per_session.input_value_infos
            output_value_infos = cached_execution_info_per_session.output_value_infos
            input_devices = cached_execution_info_per_session.input_devices
            output_devices = cached_execution_info_per_session.output_devices
            prim_outputs = cached_execution_info_per_session.example_outputs
        else:
            # It's the first time we see such a graph. Let's make a new session
            # (type: onnxruntime.InferenceSession) for it.

            graph_module = passes.MovePlaceholderToFront(
                self._resolved_onnx_exporter_options.diagnostic_context,
                graph_module,
            ).run()
            # Generate reference outputs. They are used to indicate the output
            # tensors' types and devices when calling ORT.
            #
            # WARNING: The downstream code should not change prim_outputs, and
            # this backend should always produce outputs with a schema identical to prim_outputs'.

            if self._resolved_onnx_exporter_options.dynamic_shapes:
                # No pre-allocation when dynamic shape is enabled.
                self.preallocate_output = False
                extracted_outputs = _extract_graph_module_outputs(graph_module)

                def maybe_map_to_meta_val(value):
                    if hasattr(value, "meta") and "val" in value.meta:
                        # Select outputs with "val" information. Without "val",
                        # it's not possible to access output_arg.meta["val"].device.
                        return value.meta["val"]
                    else:
                        return value

                prim_outputs = _pytree.tree_map(
                    maybe_map_to_meta_val, extracted_outputs
                )
            else:
                try:
                    prim_outputs = FakeTensorProp(graph_module).propagate(
                        *args, **kwargs
                    )
                except Exception:
                    logger.warning("FakeTensorProp failed for %s", graph_module)
                    # When FakeTensorProp fails, it is not possible to preallocate output buffers
                    # because the output shapes are not inferred.
                    self.preallocate_output = False

                    # Re-raise the FakeTensorProp failure because it is not currently handled.
                    raise

            # Create the object to iterate through the nodes in the graph one by one
            # and call the corresponding ONNX exporter for each node.
            fx_interpreter = fx_onnx_interpreter.FxOnnxInterpreter(
                diagnostic_context=self._resolved_onnx_exporter_options.diagnostic_context
            )
            # Cast FX variables if they would cause a schema mismatch when searching
            # for the ONNX operator. E.g., add(double_tensor, int_tensor) is fine in PyTorch,
            # but ONNX expects add(double_tensor, double_tensor).
            graph_module = passes.InsertTypePromotion(
                self._resolved_onnx_exporter_options.diagnostic_context, graph_module
            ).run()
            # Start the per-node exporting process. It's conceptually a for loop
            # scanning through the nodes in the graph.
            exported = fx_interpreter.run(
                fx_graph_module=graph_module,
                onnxfunction_dispatcher=self._resolved_onnx_exporter_options.onnxfunction_dispatcher,
            )
            # Convert the exported result to an ONNX ModelProto.
            onnx_model = exported.to_model_proto(
                opset_version=self._resolved_onnx_exporter_options.onnx_registry.opset_version,
            )

            try:
                from onnxscript import optimizer  # type: ignore[import]
                from onnxscript.rewriter import (  # type: ignore[import]
                    onnxruntime as ort_rewriter,
                )

                onnx_model = optimizer.optimize(onnx_model)
                onnx_model = ort_rewriter.rewrite(onnx_model)
            except ImportError:
                logger.warning(
                    "ONNXScript optimizer is not available. Skipping optimization. "
                    "Please `pip install onnxscript -U` to enable post-export optimization."
                )

            # Modify the ONNX model using pre-registered graph transforms.
            # They are in-place modifications that avoid unnecessary
            # copies of ONNX initializers.
            if self._options.pre_ort_model_transforms:
                for transform in self._options.pre_ort_model_transforms:
                    transform(onnx_model)

            onnx_model_bytes = onnx_model.SerializeToString()
            if os.environ.get("ONNXRT_DUMP_PATH", None):
                # If not empty, the environment variable ONNXRT_DUMP_PATH defines the path
                # where the generated ONNX files should be stored.
                # This module keeps a global variable tracking the
                # stored models.
                # If ONNXRT_DUMP_PATH="dumped/dumped_model_",
                # the first file name will be 'dumped/dumped_model_0.onnx'.
                # For every dumped model, a text file 'dumped/dumped_model_0.txt'
                # is created as well, containing the string representation of the graph_module.
                _dump_onnx_model(onnx_model_bytes, graph_module=graph_module)

            # Initialize an ORT session to execute this ONNX model.
            # Note that TorchDynamo assumes all inputs/outputs are on the
            # same device, but it's subject to change (very likely with
            # dynamic shape support), so we add execution providers
            # based on the logic in _select_eps: (explicitly preferred EPs,
            # EPs inferred from inputs or graph, and the fallback default EP).
            #
            # TODO(wschin): enable external allocators.
            # See https://github.com/pytorch/pytorch/issues/106867
            onnx_session = onnxruntime.InferenceSession(
                path_or_bytes=onnx_model_bytes,
                sess_options=self._options.ort_session_options,
                providers=self._select_eps(graph_module, *args),
            )

            # Cache the ORT session; it's reused for the same "graph_module".
            # Extract the ONNX model's input and output names.
            input_names = tuple(input.name for input in onnx_model.graph.input)
            output_names = tuple(output.name for output in onnx_model.graph.output)
            input_devices = _get_onnx_devices(args)
            # Cache devices for inputs and outputs. They are used to invoke the
            # ORT session. Output devices indicate where (e.g., GPU or CPU)
            # to store the outputs.
            if isinstance(prim_outputs, tuple):
                output_devices = _get_onnx_devices(prim_outputs)
            else:
                output_devices = _get_onnx_devices((prim_outputs,))

            input_value_infos = tuple(input for input in onnx_model.graph.input)
            output_value_infos = tuple(output for output in onnx_model.graph.output)

            execution_info_per_session = OrtExecutionInfoPerSession(
                session=onnx_session,
                input_names=input_names,
                input_value_infos=input_value_infos,
                output_names=output_names,
                output_value_infos=output_value_infos,
                input_devices=input_devices,
                output_devices=output_devices,
                example_outputs=prim_outputs,
            )

            self._all_ort_execution_info.cache_session_execution_info(
                graph_module, execution_info_per_session
            )

        self.execution_count += 1

        # ORT always returns a tuple of outputs. If the original output is a tensor,
        # the first element of ORT's output must be extracted and returned. Otherwise, a type
        # mismatch may happen in downstream computation.
        is_single_tensor_output = isinstance(prim_outputs, torch.Tensor)
        normalized_prim_outputs = (
            (prim_outputs,) if is_single_tensor_output else prim_outputs
        )
        assert isinstance(normalized_prim_outputs, tuple)
        assert all(
            isinstance(elem, (torch.Tensor, torch.SymInt, int))
            for elem in normalized_prim_outputs
        )

        _nvtx_range_push("run_onnx_session_with_ortvaluevector")
        onnx_outputs = self.run(
            onnx_session,
            input_names,
            args,
            input_devices,
            output_names,
            normalized_prim_outputs,
            output_devices,
            self._options.preallocate_output,
            input_value_infos,
            normalized_prim_outputs,
        )
        _nvtx_range_pop()

        if self._assert_allclose_to_baseline:
            # Compute baseline.
            baseline_outputs = torch._prims.executor.execute(
                graph_module, *args, executor="aten"
            )
            normalized_baseline_outputs = (
                (baseline_outputs,) if is_single_tensor_output else baseline_outputs
            )
            # Ensure every output tensor is close to the corresponding baseline.
            for onnx_output, baseline_output in zip(
                onnx_outputs, normalized_baseline_outputs
            ):
                torch.testing.assert_close(onnx_output, baseline_output)
        return onnx_outputs[0] if is_single_tensor_output else onnx_outputs

    def compile(self, graph_module: torch.fx.GraphModule, args) -> torch.fx.GraphModule:
        # Deferred import since CapabilityBasedPartitioner is not decorated with
        # @compatibility; importing it at the module level would make the test
        # `pytest test/test_fx.py -k test_public_api_surface` fail,
        # because this module is imported into torch.onnx.
        from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner

        # FX-graph-based partitioning based on ONNX-supported ops.
        # Given a graph module
        #  GraphModule0
        #   node_0
        #   node_1
        #   node_2
        #   node_3
        #   node_4
        # If only node_2 is not supported by ONNX, this graph module will be partitioned into
        #  GraphModule0
        #   GraphModule1
        #    node_0
        #    node_1
        #   node_2
        #   GraphModule2
        #    node_3
        #    node_4
        # by calling CapabilityBasedPartitioner.partition_and_fuse.
        # Then, GraphModule1's and GraphModule2's forward method (GraphModule._wrapped_call)
        # will be replaced by OrtBackend._ort_acclerated_call to delegate computation to ORT.
        if graph_module in self._partitioner_cache:
            partitioned_prim_graph_module = self._partitioner_cache[graph_module]
        else:
            prim_graph_module = graph_module
            partitioner = CapabilityBasedPartitioner(
                prim_graph_module,
                self._supported_ops,
                allows_single_node_partition=True,
            )
            partitioned_prim_graph_module = partitioner.partition_and_fuse()
            self._partitioner_cache[graph_module] = partitioned_prim_graph_module

            # Override each fused_module's __call__() with _ort_acclerated_call().
            # This loop goes through all graph partitions (each of them is an ONNX-representable graph)
            # and overrides their _wrapped_call function with _ort_acclerated_call.
            # Inside _ort_acclerated_call, the partition's graph is exported into ONNX and executed by ORT.
            for node in partitioned_prim_graph_module.graph.nodes:
                # TODO(wschin): use a better way to identify fused submodules.
                # See https://github.com/pytorch/pytorch/issues/106872.
                if node.op == "call_module" and "fused_" in node.name:
                    fused_module = getattr(partitioned_prim_graph_module, node.name)
                    # self._ort_acclerated_call is responsible for exporting the graph to ONNX,
                    # creating the ORT session, and running the ORT session.
                    fused_module._wrapped_call = self._ort_acclerated_call

        return partitioned_prim_graph_module

    def __call__(
        self, graph_module: torch.fx.GraphModule, args
    ) -> torch.fx.GraphModule:
        """If ``OrtBackendOptions.use_aot_autograd`` is ``True``, the ``aot_autograd`` compiler
        will be invoked, wrapping this ``OrtBackend`` instance's ``compile`` method. Otherwise,
        the ``compile`` method is invoked directly."""
        if self._options.use_aot_autograd:
            from functorch.compile import min_cut_rematerialization_partition
            from torch._dynamo.backends.common import aot_autograd

            return aot_autograd(
                fw_compiler=self.compile,
                partition_fn=min_cut_rematerialization_partition,
                decompositions=self._resolved_onnx_exporter_options.decomposition_table,
            )(graph_module, args)

        return self.compile(graph_module, args)

    __instance_cache_max_count: Final = 8
    __instance_cache: Final[List["OrtBackend"]] = []

    @staticmethod
    def get_cached_instance_for_options(
        options: Optional[Union[OrtBackendOptions, Mapping[str, Any]]] = None,
    ) -> "OrtBackend":
        """Returns a possibly cached instance of an ``OrtBackend``. If an existing
        backend was created previously through this function with the same options,
        it will be returned. Otherwise a new backend will be created, cached, and
        returned.

        Note: if ``options`` sets ``ort_session_options``, a new ``OrtBackend``
        will always be returned, since ``onnxruntime.SessionOptions`` cannot
        participate in caching.
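
        Hedged illustration of the caching behavior (the option values below are arbitrary)::

            >>> # xdoctest: +SKIP("requires onnxruntime")
            >>> b1 = OrtBackend.get_cached_instance_for_options({"preallocate_output": True})
            >>> b2 = OrtBackend.get_cached_instance_for_options({"preallocate_output": True})
            >>> assert b1 is b2
        """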

        def reusable(a: OrtBackendOptions, b: OrtBackendOptions):
            if (
                a.preferred_execution_providers != b.preferred_execution_providers
                or a.infer_execution_providers != b.infer_execution_providers
                or a.default_execution_providers != b.default_execution_providers
                or a.preallocate_output != b.preallocate_output
                or a.use_aot_autograd != b.use_aot_autograd
                or a.pre_ort_model_transforms != b.pre_ort_model_transforms
            ):
                return False

            # onnxruntime.SessionOptions is a pybind11 object, cannot be pickled,
            # and holds too much potential state to reasonably check manually; if
            # ort_session_options is provided at all, the backend does not participate
            # in caching.
            if a.ort_session_options is not None or b.ort_session_options is not None:
                return False

            if a.export_options is b.export_options:
                return True

            # Similarly, some objects in ExportOptions are too stateful to use for
            # caching. We should revisit this.
            if a.export_options is not None and b.export_options is not None:
                return (
                    a.export_options.dynamic_shapes == b.export_options.dynamic_shapes
                    and a.export_options.diagnostic_options
                    == b.export_options.diagnostic_options
                    and a.export_options.onnx_registry is b.export_options.onnx_registry
                    and a.export_options.fake_context is b.export_options.fake_context
                )

            # We can't account for how the two option sets may differ, so it's not safe to reuse.
            return False

        if not isinstance(options, OrtBackendOptions):
            options = OrtBackendOptions(**(options or {}))

        backend = next(
            (b for b in OrtBackend.__instance_cache if reusable(b._options, options)),
            None,
        )

        if backend is None:
            assert (
                len(OrtBackend.__instance_cache) < OrtBackend.__instance_cache_max_count
            ), (
                f"No more than {OrtBackend.__instance_cache_max_count} instances of "
                f"{OrtBackend} allowed. Please instantiate `{OrtBackend}` explicitly "
                "to pass to `torch.compile`. "
                "See https://github.com/pytorch/pytorch/pull/107973#discussion_r1306144795 "
                "for discussion."
            )
            OrtBackend.__instance_cache.append(backend := OrtBackend(options))

        return backend

    @staticmethod
    def clear_cached_instances():
        OrtBackend.__instance_cache.clear()

    @staticmethod
    def get_cached_instances():
        return tuple(OrtBackend.__instance_cache)


@compatibility(is_backward_compatible=False)
def torch_compile_backend(
    graph_module: torch.fx.GraphModule,
    args,
    *,
    options: Optional[Union[OrtBackendOptions, Mapping[str, Any]]] = None,
):
    return OrtBackend.get_cached_instance_for_options(options)(graph_module, args)
