# mypy: allow-untyped-defs
import dataclasses
import importlib
import logging
import os
from typing import (
    Any,
    Callable,
    Dict,
    Final,
    List,
    Mapping,
    Optional,
    Sequence,
    Set,
    Tuple,
    TYPE_CHECKING,
    Union,
)
from typing_extensions import TypeAlias

import torch
import torch._C
import torch._ops
import torch._prims.executor
import torch.fx
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx._compatibility import compatibility
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch.fx.passes.operator_support import OperatorSupport
from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
from torch.utils import _pytree


if TYPE_CHECKING:
    import onnx
    import onnxruntime
    from onnxruntime.capi import _pybind_state as ORTC

    import torch.onnx
    import torch.onnx._internal
    import torch.onnx._internal._exporter_legacy
    import torch.onnx._internal.diagnostics
    import torch.onnx._internal.fx.decomposition_table
    import torch.onnx._internal.fx.passes  # noqa: TCH004


_SUPPORT_ONNXRT: Optional[bool] = None

__all__ = [
    "is_onnxrt_backend_supported",
    "torch_compile_backend",
    "OrtExecutionProvider",
    "OrtBackendOptions",
    "OrtBackend",
]


def is_onnxrt_backend_supported() -> bool:
    """Returns ``True`` if ONNX Runtime dependencies are installed and usable
    to support TorchDynamo backend integration; ``False`` otherwise.

    Example::

        # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
        >>> import torch
        >>> if torch.onnx.is_onnxrt_backend_supported():
        ...     @torch.compile(backend="onnxrt")
        ...     def f(x):
        ...         return x * x
        ...     print(f(torch.randn(10)))
        ... else:
        ...     print("pip install onnx onnxscript onnxruntime")
        ...
    """
    global _SUPPORT_ONNXRT

    if _SUPPORT_ONNXRT is None:
        # `onnxruntime` might import a lot of other runtime packages,
        # e.g. apex, deepspeed, transformers.
        # So lazy-importing onnxruntime to avoid possible circular import.
        try:
            importlib.import_module("onnxruntime")
            importlib.import_module("onnxruntime.capi._pybind_state")

            # This is not used directly in DORT but is needed by the underlying
            # exporter, so we still need to check if it exists.
            importlib.import_module("onnxscript")

            import torch.onnx  # noqa: F401
            import torch.onnx._internal  # noqa: F401
            import torch.onnx._internal._exporter_legacy  # noqa: F401
            import torch.onnx._internal.diagnostics  # noqa: F401
            from torch.onnx._internal.fx import (  # noqa: F401
                decomposition_table,
                fx_onnx_interpreter,
                passes,
                type_utils,
            )

            _SUPPORT_ONNXRT = True
        except ImportError:
            _SUPPORT_ONNXRT = False

    return _SUPPORT_ONNXRT


_dumped_onnx_model: Dict[str, int] = {}


def _dump_onnx_model(
    model_string: bytes, graph_module: Optional[torch.fx.GraphModule] = None
) -> str:
    """Stores the onnx model into a file.
    The name is "{ONNXRT_DUMP_PATH}{N}.onnx"
    where *N* is the number of files already stored with
    this prefix.
    If graph_module is not None, the graph is also stored as a string in a file
    with the same name but a .txt extension.
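
    Example: with ONNXRT_DUMP_PATH="dumped/dumped_model_" (the same prefix used as
    an illustration in OrtBackend._ort_acclerated_call), the first call writes
    "dumped/dumped_model_0.onnx" (plus "dumped/dumped_model_0.txt" when graph_module
    is given), the second call writes "dumped/dumped_model_1.onnx", and so on.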
    """
    prefix = os.environ.get("ONNXRT_DUMP_PATH", None)
    if not prefix:
        return ""
    n = _dumped_onnx_model.get(prefix, -1) + 1
    filename = f"{prefix}{n}.onnx"
    with open(filename, "wb") as f:
        f.write(model_string)
    _dumped_onnx_model[prefix] = n
    if graph_module is not None:
        filename_txt = f"{prefix}{n}.txt"
        with open(filename_txt, "w", encoding="utf-8") as f:
            f.write(str(graph_module.graph))
    return filename


def _infer_default_eps() -> Sequence[str]:
    # TODO: select a good default based on the capabilities of the host
    # e.g. DML on Windows, etc.
    return ["CPUExecutionProvider"]


def _nvtx_range_push(name: str):
    """If PyTorch is installed with CUDA support, this starts an NVTX range.

    See torch.cuda.nvtx.range_push's documentation for more details.
    """
    if torch.cuda.is_available():
        torch.cuda.nvtx.range_push(name)


def _nvtx_range_pop():
    """If PyTorch is installed with CUDA support, this terminates an NVTX range.

    See torch.cuda.nvtx.range_pop's documentation for more details.
    """
    if torch.cuda.is_available():
        torch.cuda.nvtx.range_pop()


def _get_ort_device_type(device_type: str):
    from onnxruntime.capi import _pybind_state as ORTC

    if device_type == "cuda":
        return ORTC.OrtDevice.cuda()
    if device_type == "cpu":
        return ORTC.OrtDevice.cpu()
    # ort pytorch device is mapped to NPU OrtDevice type
    if device_type == "maia":
        return ORTC.OrtDevice.npu()
    raise ValueError("Unsupported device type: " + device_type)


logger = logging.getLogger(__name__)
# Uncomment the following lines to print out development info.
# logging.basicConfig(level=logging.WARNING)
# logger.setLevel(logging.WARNING)


class OrtOperatorSupport(OperatorSupport):
    """Operator support for the ONNXRuntime backend.

    It makes support decisions at two levels: one via support_dict and the other
    via extra_support_dict. The logic for support_dict is implemented in
    OrtOperatorSupport, while extra_support_dict is handled by the base
    OperatorSupport.is_node_supported.
    """

    def __init__(self, support_dict: Set[Any], extra_support_dict: Dict[str, Any]):
        # Use extra_support_dict[op_name] = None to indicate
        # we support op_name with all input types. Otherwise,
        # see support_dict (type: SupportDict) in operator_support.py
        # for specifying supported types.
        super().__init__(extra_support_dict)
        self._onnx_support_dict = support_dict

    def is_node_supported(
        self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
    ) -> bool:
        # OperatorSupport.is_node_supported returns True for non-callable nodes.
        # Since ORT can't execute them, we return False here to override the base
        # behavior.
        if node.op not in CALLABLE_NODE_OPS:
            return False
        # This is the only place to decide whether an aten op is supported.
        if node.op == "call_function" and node.target in self._onnx_support_dict:
            logger.info(
                "support_dict supports node.target: %s (type: %s)",
                node.target,
                type(node.target),
            )
            return True
        # If node.target is not in support_dict, we still want to check if torch.jit.script
        # can convert it to an ONNX equivalent. Let's use the base mechanism to do this.
        # See extra_support_dict for supported ops.
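        # Keys of extra_support_dict are qualified call-target strings such as
        # "_operator.getitem"; a value of None means the op is accepted for all
        # input types (see the table built in OrtBackend.__init__).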
        if super().is_node_supported(submodules, node):
            logger.info(
                "extra_support_dict supports node.target: %s (type: %s)",
                node.target,
                type(node.target),
            )
            return True
        logger.warning(
            "support_dict and extra_support_dict don't support node.target: %s (type: %s)",
            node.target,
            type(node.target),
        )
        return False


def _move_placeholder_to_front(graph_module: torch.fx.GraphModule) -> None:
    """
    In torch.fx.Graph, placeholder is a special assignment node. If it's not
    executed in the beginning, it could overwrite values computed by upstream
    nodes.
    """

    graph = graph_module.graph
    placeholders = []
    first_not_placeholder = None
    for node in graph.nodes:
        if node.op == "placeholder":
            placeholders.append(node)
        if first_not_placeholder is None and node.op != "placeholder":
            first_not_placeholder = node
    if first_not_placeholder is None:
        return
    for placeholder in placeholders:
        first_not_placeholder.prepend(placeholder)


def _infer_ep_from_device(*args) -> Tuple[str, ...]:
    """Return execution providers inferred from the devices (GPU or CPU) of the arguments."""
    eps = []
    for arg in args:
        if hasattr(arg, "device"):
            device = arg.device
            if device.type == "cuda":
                eps.append("CUDAExecutionProvider")
            elif device.type == "cpu":
                eps.append("CPUExecutionProvider")
    return tuple(eps)


def _extract_graph_module_inputs(graph_module: torch.fx.GraphModule) -> Tuple[Any, ...]:
    placeholders = []
    for node in graph_module.graph.nodes:
        if node.op == "placeholder":
            if hasattr(node, "meta") and "val" in node.meta:
                assert isinstance(node.meta["val"], torch.Tensor)
            placeholders.append(node)
    return tuple(placeholders)


def _extract_graph_module_outputs(graph_module: torch.fx.GraphModule) -> Any:
    """Collect "val" fields from outputs metadata in this torch.fx.GraphModule."""
    for node in graph_module.graph.nodes:
        if node.op == "output":
            # Output node is unique. Let's retrieve output values from
            # this node's input list. And then just return.
            return node.args[0]
    raise ValueError("No output node found in this torch.fx.GraphModule.")


def _infer_ep_from_graph_module(graph_module: torch.fx.GraphModule) -> Tuple[str, ...]:
    """Return execution providers inferred from the devices (GPU or CPU) of this torch.fx.GraphModule's outputs."""
    flattened_output_args, _ = _pytree.tree_flatten(
        _extract_graph_module_outputs(graph_module)
    )
    # Output arguments with example value (type: torch.Tensor) in the `graph_module`.
    selected_output_args = [
        output_arg.meta["val"]
        for output_arg in flattened_output_args
        # output_arg must have a tensor for its device information.
        # Otherwise, skip it.
        if (hasattr(output_arg, "meta") and "val" in output_arg.meta)
    ]
    return _infer_ep_from_device(*selected_output_args)


def _sort_eps(eps: Tuple[str, ...]) -> Tuple[str, ...]:
    """Sort execution providers in eps based on pre-set priority."""

    def get_execution_provider_priority(ep: str) -> int:
        if ep == "CPUExecutionProvider":
            # Lowest priority.
            return 2
        if ep == "CUDAExecutionProvider":
            # Higher priority than CPU but lower than
            # other specialized EPs.
            return 1
        # Highest priority.
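        # (Any other EP, e.g. a vendor-specific one such as TensorRT or DirectML,
        # falls into this bucket.)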
        return 0

    unique_eps = set(eps)
    return tuple(sorted(unique_eps, key=get_execution_provider_priority, reverse=True))


def _get_onnx_devices(
    values: Tuple[
        Union[
            torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool
        ],
        ...,
    ],
) -> Tuple["ORTC.OrtDevice", ...]:
    from onnxruntime.capi import _pybind_state as ORTC

    def _device_id_or_zero(device_id: int) -> int:
        return device_id or 0

    def _map_tensor_or_sym_to_device(
        value: Union[
            torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool
        ],
    ) -> "ORTC.OrtDevice":
        if isinstance(value, torch.Tensor):
            return ORTC.OrtDevice(
                _get_ort_device_type(value.device.type),
                ORTC.OrtDevice.default_memory(),
                _device_id_or_zero(value.device.index),
            )
        elif isinstance(
            value, (torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool)
        ):
            return ORTC.OrtDevice(
                _get_ort_device_type("cpu"), ORTC.OrtDevice.default_memory(), 0
            )
        else:
            raise ValueError("Unsupported value type: " + str(type(value)))

    if len(values) > 0:
        ort_devices = tuple(_map_tensor_or_sym_to_device(value) for value in values)
        return ort_devices
    else:
        return (_map_tensor_or_sym_to_device(1),)


def _get_ortvalues_from_torch_tensors(
    tensors: Tuple[torch.Tensor, ...], devices: Tuple["ORTC.OrtDevice", ...]
) -> "ORTC.OrtValueVector":
    from onnxruntime.capi import _pybind_state as ORTC

    from torch.onnx._internal.fx.type_utils import _TORCH_DTYPE_TO_NUMPY_DTYPE

    ortvalues = ORTC.OrtValueVector()
    ortvalues.reserve(len(tensors))
    dtypes = []
    shapes = []
    data_ptrs = []

    for tensor in tensors:
        dtypes.append(_TORCH_DTYPE_TO_NUMPY_DTYPE[tensor.dtype])
        shapes.append(tensor.size())
        data_ptrs.append(tensor.data_ptr())
    ortvalues.push_back_batch(tensors, data_ptrs, dtypes, shapes, devices)
    return ortvalues


def _to_real_tensor(tensor: FakeTensor) -> torch.Tensor:
    if tensor.is_sparse:
        raise ValueError("sparse tensor is not yet supported.")
    out = torch.empty(tensor.size(), dtype=tensor.dtype, device=tensor.device)
    return out


def _adjust_scalar_from_fx_to_onnx(
    dynamo_value: Union[
        torch.Tensor,
        int,
        float,
        bool,
    ],
    value_info: "onnx.ValueInfoProto",  # type: ignore[name-defined]
) -> torch.Tensor:
    """Helper function to wrap PyTorch variables as torch.Tensor"""
    if (
        isinstance(dynamo_value, torch.Tensor)
        and len(value_info.type.tensor_type.shape.dim) == 0
        and dynamo_value.shape == (1,)
    ):
        # ONNX expects a scalar with empty shape.
        # In contrast, PyTorch usually allows implicit
        # conversion between shape=() and shape=(1,).
        #
        # Below, PyTorch's shape (1,) is reshaped to ().
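        # torch.squeeze drops the single size-1 dimension, producing the 0-d
        # tensor layout that matches value_info's empty shape.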
        return torch.squeeze(dynamo_value)
    elif isinstance(dynamo_value, bool):
        # Check bool before int because bool is a subclass of int in Python.
        return torch.tensor(dynamo_value, dtype=torch.bool)
    elif isinstance(dynamo_value, int):
        return torch.tensor(dynamo_value, dtype=torch.int64)
    elif isinstance(dynamo_value, float):
        return torch.tensor(dynamo_value, dtype=torch.float32)
    else:
        assert isinstance(dynamo_value, torch.Tensor)
        return dynamo_value.contiguous()


def _adjust_scalar_from_onnx_to_fx(
    tensor: torch.Tensor,
    prim_value: Union[
        torch.Tensor,
        torch.SymInt,
        int,
        torch.SymFloat,
        float,
        torch.SymBool,
        bool,
    ],
) -> Union[
    torch.Tensor,
    int,
    float,
    bool,
]:
    """Helper function to wrap ORT-produced torch.Tensor as PyTorch variables"""
    assert isinstance(tensor, torch.Tensor), "ORT's output must be tensor."
    if isinstance(
        prim_value,
        (torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool),
    ):
        # Convert tensor back to scalar to match Dynamo's expectation.
        return tensor.item()
    return tensor


def _run_onnx_session_with_ortvaluevector(
    sess: "onnxruntime.InferenceSession",
    input_names: Tuple[str, ...],
    inputs: Tuple[torch.Tensor, ...],
    input_devices: Tuple["ORTC.OrtDevice", ...],
    output_names: Tuple[str, ...],
    outputs: Tuple[torch.Tensor, ...],
    output_devices: Tuple["ORTC.OrtDevice", ...],
    preallocate_output: bool,
    input_value_infos: Tuple["onnx.ValueInfoProto", ...],  # type: ignore[name-defined]
    normalized_prim_outputs: Tuple[
        Union[
            torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool
        ],
        ...,
    ],
) -> Tuple[Union[torch.Tensor, int, float, bool], ...]:
    import onnxruntime
    from onnxruntime.capi import _pybind_state as ORTC

    _nvtx_range_push("contiguous")
    inputs = tuple(
        _adjust_scalar_from_fx_to_onnx(arg, value_info)
        for arg, value_info in zip(inputs, input_value_infos)
    )
    _nvtx_range_pop()

    _nvtx_range_push("push_back_batch")
    ort_inputs = _get_ortvalues_from_torch_tensors(inputs, input_devices)

    # Preallocate output PyTorch tensors and use their buffers, which already live on
    # the target torch device, as the backing memory of the output OrtValues.
    # Because those OrtValues are not allocated or owned by ORT, there is no need to
    # convert them back to torch tensors (and transfer ownership) afterwards.
    if preallocate_output:
        pth_outputs = tuple(
            _to_real_tensor(t) if isinstance(t, FakeTensor) else t for t in outputs
        )
        ort_outputs = _get_ortvalues_from_torch_tensors(pth_outputs, output_devices)
    else:
        ort_outputs = ORTC.OrtValueVector()
    _nvtx_range_pop()

    _nvtx_range_push("run_with_ortvaluevector")
    run_options = onnxruntime.RunOptions()
    run_options.add_run_config_entry("disable_synchronize_execution_providers", "1")
    sess.run_with_ortvaluevector(
        run_options, input_names, ort_inputs, output_names, ort_outputs, output_devices
    )
    _nvtx_range_pop()

    # Post-processing step:
    # wrap ORT's outputs to the schema represented by
    # `prim_output` (obtained by running the original
    # torch.fx.GraphModule).
    if preallocate_output:
        # Profile the ORT-to-PyTorch type cast below
        _nvtx_range_push("after run_with_ortvaluevector")
        # Outputs are stored on pre-allocated torch.Tensors' memory,
        # so this case doesn't need to convert ORTValue to torch.Tensor.
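        # Only the scalar adjustment below (0-d tensor -> Python int/float/bool)
        # may still be needed to match the FX graph's output schema.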
        pth_outputs = tuple(
            _adjust_scalar_from_onnx_to_fx(onnx_output, prim_output)  # type: ignore[misc]
            for onnx_output, prim_output in zip(pth_outputs, normalized_prim_outputs)
        )
        _nvtx_range_pop()
        return pth_outputs
    else:
        # Profile the two ORT-to-PyTorch type casts below
        _nvtx_range_push("after run_with_ortvaluevector")
        # Map ORTValue to torch.Tensor.
        pth_outputs = onnxruntime.training.ortmodule._utils._ortvalues_to_torch_tensor(
            ort_outputs
        )
        # Change some torch.Tensor to int, float, bool.
        pth_outputs = tuple(
            _adjust_scalar_from_onnx_to_fx(onnx_output, prim_output)  # type: ignore[misc]
            for onnx_output, prim_output in zip(pth_outputs, normalized_prim_outputs)
        )
        _nvtx_range_pop()
        return pth_outputs


def _run_onnx_session_with_fetch(
    sess: "onnxruntime.InferenceSession",
    input_names: Tuple[str, ...],
    inputs: Tuple[torch.Tensor, ...],
    input_devices: Tuple["ORTC.OrtDevice", ...],
    output_names: Tuple[str, ...],
    outputs: Tuple[torch.Tensor, ...],
    output_devices: Tuple["ORTC.OrtDevice", ...],
    preallocate_output: bool,
    input_value_infos: Tuple["onnx.ValueInfoProto", ...],  # type: ignore[name-defined]
    normalized_prim_outputs: Tuple[
        Union[
            torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool
        ],
        ...,
    ],
) -> Tuple[Union[torch.Tensor, int, float, bool], ...]:
    import onnxruntime

    inputs = tuple(
        _adjust_scalar_from_fx_to_onnx(arg, value_info)
        for arg, value_info in zip(inputs, input_value_infos)
    )
    feed = {
        name: onnxruntime.OrtValue.ortvalue_from_numpy(tensor.cpu().numpy())
        for name, tensor in zip(input_names, inputs)
    }
    ort_outputs = sess.run(output_names, feed)
    pth_outputs = tuple(
        _adjust_scalar_from_onnx_to_fx(
            torch.from_numpy(value),
            prim_output,
        )
        for value, prim_output in zip(ort_outputs, normalized_prim_outputs)
    )
    return pth_outputs


class OrtExecutionInfoPerSession:
    """Information required to execute torch.fx.GraphModule using onnxruntime.InferenceSession"""

    def __init__(
        self,
        session: "onnxruntime.InferenceSession",
        input_names: Tuple[str, ...],
        input_value_infos: Tuple["onnx.ValueInfoProto", ...],  # type: ignore[name-defined]
        output_names: Tuple[str, ...],
        output_value_infos: Tuple["onnx.ValueInfoProto", ...],  # type: ignore[name-defined]
        input_devices: Tuple["ORTC.OrtDevice", ...],
        output_devices: Tuple["ORTC.OrtDevice", ...],
        example_outputs: Union[Tuple[torch.Tensor, ...], torch.Tensor],
    ):
        # Carrier of ONNX model and its executor.
        self.session: onnxruntime.InferenceSession = session
        # For the ONNX model stored in self.session, self.input_names[i] is the
        # name of the i-th positional input.
        self.input_names: Tuple[str, ...] = input_names
        # self.input_names[i]'s type information is stored in self.input_value_infos[i].
        self.input_value_infos: Tuple[onnx.ValueInfoProto, ...] = input_value_infos  # type: ignore[name-defined]
        # Similar to self.input_names, but for outputs.
        self.output_names: Tuple[str, ...] = output_names
        # Similar to self.input_value_infos but for outputs.
        self.output_value_infos: Tuple[onnx.ValueInfoProto, ...] = output_value_infos  # type: ignore[name-defined]
        # For the ONNX model stored in self.session, self.input_devices[i] is the
        # i-th positional input's device.
        self.input_devices: Tuple[ORTC.OrtDevice, ...] = input_devices
        # Similar to self.input_devices, but for outputs.
        self.output_devices: Tuple[ORTC.OrtDevice, ...] = output_devices
        # These are the outputs of executing the original torch.fx.GraphModule with example inputs
        # (i.e., args passed into OrtBackend._ort_acclerated_call).
        self.example_outputs: Union[Tuple[torch.Tensor, ...], torch.Tensor] = (
            example_outputs
        )

    def is_supported(self, *args):
        from torch.onnx._internal.fx.type_utils import (
            _TORCH_DTYPE_TO_ONNX_TENSOR_ELEMENT_TYPE,
            from_python_type_to_onnx_tensor_element_type,
        )

        # Compare the args against the input schema of the ONNX model and
        # return True only if every argument matches.
        if len(args) != len(self.input_value_infos):
            return False
        for arg, value_info in zip(args, self.input_value_infos):
            if not isinstance(arg, (torch.Tensor, float, int)):
                return False

            # Check Python scalars such as int, float, and bool.
            if isinstance(arg, (int, float, bool)):
                # Map, e.g., float to onnx.TensorProto.FLOAT.
                onnx_dtype = from_python_type_to_onnx_tensor_element_type(type(arg))
                if onnx_dtype != value_info.type.tensor_type.elem_type:
                    return False
                if len(value_info.type.tensor_type.shape.dim) != 0:
                    return False
                continue

            # Check tensor.
            onnx_dtype = _TORCH_DTYPE_TO_ONNX_TENSOR_ELEMENT_TYPE[arg.dtype]
            if onnx_dtype != value_info.type.tensor_type.elem_type:
                return False
            for dim, onnx_dim in zip(arg.shape, value_info.type.tensor_type.shape.dim):
                if isinstance(dim, int) and (
                    onnx_dim.dim_value == dim or onnx_dim.dim_param
                ):
                    continue
                elif isinstance(dim, torch.SymInt) and onnx_dim.dim_param:
                    continue
                else:
                    return False
        return True


@dataclasses.dataclass
class OrtExecutionInfoForAllGraphModules:
    def __init__(self) -> None:
        # All sessions (and their related information) created by exporting the same GraphModule
        # with different inputs.
        self.execution_info_per_graph_module: Dict[
            torch.fx.GraphModule, List[OrtExecutionInfoPerSession]
        ] = {}

    def search_reusable_session_execution_info(
        self, graph_module: torch.fx.GraphModule, *args
    ):
        if graph_module not in self.execution_info_per_graph_module:
            return None
        # All execution information for ONNX models exported from the same `graph_module`
        # with different inputs.
        candidates = self.execution_info_per_graph_module[graph_module]

        for candidate in candidates:
            if candidate.is_supported(*args):
                # Returns the first session that accepts this input schema.
                return candidate
        # No reusable session found.
        return None

    def cache_session_execution_info(
        self, graph_module: torch.fx.GraphModule, info: OrtExecutionInfoPerSession
    ):
        if graph_module not in self.execution_info_per_graph_module:
            self.execution_info_per_graph_module[graph_module] = [info]
        else:
            self.execution_info_per_graph_module[graph_module].append(info)


OrtExecutionProvider: TypeAlias = Union[str, Tuple[str, Mapping[str, Any]]]
"""Either the name of an ONNX Runtime execution provider as a string or
a 2-tuple of the name and a dictionary of execution provider options.

Examples::

    >>> "CPUExecutionProvider"

    >>> ("CUDAExecutionProvider", {"device_id": 3})

"""


@dataclasses.dataclass(frozen=True)
@compatibility(is_backward_compatible=False)
class OrtBackendOptions:
    """Options for constructing an ``OrtBackend``, the ONNX Runtime
    backend (``"onnxrt"``) for ``torch.compile``.

    Example::

        >>> @torch.compile(
        ...     backend="onnxrt",
        ...     options=torch.onnx._OrtBackendOptions(...),
        ... )
        ... def ort_function(x):
        ...     return x ** x
    """

    preferred_execution_providers: Optional[Sequence[OrtExecutionProvider]] = None
    """An optional sequence of execution providers to be prioritized ahead of any
    execution providers that may be inferred (see ``infer_execution_providers``).
    """

    infer_execution_providers: bool = True
    """Whether to infer an execution provider from the ``torch.device`` bound to inputs or found in the graph."""

    default_execution_providers: Optional[Sequence[OrtExecutionProvider]] = None
    """The default fallback execution providers. If not specified, one will be
    selected based on the host environment (most likely ``"CPUExecutionProvider"``).
    """

    # preallocate_output allows for allocating output torch Tensor buffers and feeding them to InferenceSession
    # in order to avoid internal allocation of output buffers in InferenceSession.
    # If the output ortvalue returned from InferenceSession is allocated internally, it needs to be
    # converted to a torch Tensor for return, and that torch Tensor should hold the ownership.
    # When a custom torch device is used with a custom aten allocator, the conversion from ortvalue to torch Tensor
    # should be supported, which is currently done through dlpack. Note that dlpack might not support a custom torch device.
    # This can be avoided by preallocating the output buffers with the custom aten allocator and letting
    # InferenceSession use them without taking ownership of them.
    # TODO(wschin): Make this an inference-session-level flag.
    # See https://github.com/pytorch/pytorch/issues/106869.
    preallocate_output: bool = False
    """If ``True``, allocate memory for ONNX Runtime's outputs on the PyTorch side."""

    use_aot_autograd: bool = True
    """Whether to wrap the ``OrtBackend`` with TorchDynamo's aot_autograd backend
    to support training (i.e., backward graphs are also sent to ``OrtBackend``).

    Symbolic execution is used to capture the forward and backward passes as a single graph.
    Then, a selected graph partition algorithm (``min_cut_rematerialization_partition``) is used
    to split the entire graph into a forward sub-graph and a backward sub-graph. Finally, both
    sub-graphs are compiled by ``OrtBackend``.
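
    A rough sketch of how this option is applied (mirroring ``OrtBackend.__call__``;
    ``ort_backend`` stands in for an ``OrtBackend`` instance)::

        aot_autograd(
            fw_compiler=ort_backend.compile,
            partition_fn=min_cut_rematerialization_partition,
            decompositions=ort_backend._resolved_onnx_exporter_options.decomposition_table,
        )(graph_module, args)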
    """

    export_options: Optional["torch.onnx.ExportOptions"] = None
    """Options for the TorchDynamo-based ONNX exporter used by the ``OrtBackend``."""

    ort_session_options: Optional["onnxruntime.SessionOptions"] = None
    """Options for the ``onnxruntime.InferenceSession`` used by the ``OrtBackend``."""

    pre_ort_model_transforms: Optional[  # type: ignore[name-defined]
        Sequence[Callable[["onnx.ModelProto"], None]]
    ] = None
    """A list of graph transforms to be applied to the ONNX model before it
    is fed to ONNXRuntime's InferenceSession."""


@compatibility(is_backward_compatible=False)
class OrtBackend:
    """A backend that compiles (sub-)graphs in torch.fx.GraphModule to onnxruntime.InferenceSession calls.

    The compiler entry point is OrtBackend.compile, which
        1. partitions the original graph into supported sub-graphs (type: torch.fx.GraphModule) and unsupported
           sub-graphs.
        2. For each supported sub-graph, replaces its _wrapped_call function with _ort_acclerated_call.
        3. Inside _ort_acclerated_call, creates an onnxruntime.InferenceSession and calls it to execute the sub-graph.
    """

    def __init__(self, options: Optional[OrtBackendOptions] = None):
        from onnxruntime.capi import _pybind_state as ORTC

        import torch.onnx
        import torch.onnx._internal._exporter_legacy
        import torch.onnx._internal.fx.decomposition_table

        self._options: Final = OrtBackendOptions() if options is None else options

        # options.export_options contains information shared between the exporter and DORT.
        # For example, they should use the same decomposition table when
        #   1. capturing the FX graph in torch.compile (see how we create aot_ort in register_backend.py)
        #   2. calling the exporter's API to convert `torch.fx.GraphModule` to an ONNX model
        #      (see onnxfunction_dispatcher passed to FxOnnxInterpreter.run below).
        #
        # Convert the user-facing option to the internal option used by the ONNX exporter
        # to access required information.
        # Some useful fields:
        # - The decomposition table for decomposing FX operators in the exporter is
        #   self._resolved_onnx_exporter_options.decomposition_table.
        # - self._resolved_onnx_exporter_options.onnx_registry records what
        #   aten/prim ops are supported by the exporter and their exporters (type: callable).
        self._resolved_onnx_exporter_options = (
            torch.onnx._internal._exporter_legacy.ResolvedExportOptions(
                torch.onnx.ExportOptions()
                if self._options.export_options is None
                else self._options.export_options
            )
        )

        # Given DORT's computation flow:
        #   1. OrtOperatorSupport uses support_dict and extra_support_dict to select operators
        #      and send them to DORT.
        #   2. Then, DORT exports the selected sub-graphs into ONNX.
        #   3. Finally, DORT calls ORT to do the computation.
        # OrtOperatorSupport and create_onnx_friendly_decomposition_table(...)
        # must use the same support_dict. If the support_dict here contains something not
        # supported by the exporter, the exporter will fail in step 2 since the selected graphs may
        # contain unsupported operators such as aten::_who_you_are.
        # This restriction is automatically satisfied since DORT and the exporter share the same
        # self._resolved_onnx_exporter_options.
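        # support_dict is a set-like table of the aten/prim op overloads the exporter's
        # ONNX registry can handle; OrtOperatorSupport checks node.target against it.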
        support_dict = torch.onnx._internal.fx.decomposition_table._create_onnx_supports_op_overload_table(
            self._resolved_onnx_exporter_options.onnx_registry
        )

        extra_support_dict: Dict[str, Any] = {
            "getattr": None,
            # To send operator.getitem to ORT, add the corresponding string
            # recognized by PyTorch's OperatorSupport class.
            "_operator.getitem": None,
            # To send operator.mul to ORT, add the corresponding string
            # recognized by PyTorch's OperatorSupport class.
            "_operator.mul": None,
            "_operator.add": None,
            "_operator.sub": None,
        }

        self._supported_ops = OrtOperatorSupport(support_dict, extra_support_dict)
        # TODO(wschin): this is a naive implementation of a cache without a proper guard.
        # See https://github.com/pytorch/pytorch/issues/106868.
        self._partitioner_cache: Dict[torch.fx.GraphModule, torch.fx.GraphModule] = {}
        # Conceptually, this field is a 2-layer dictionary
        #   GraphModule 0
        #     ONNX Model 0 (with ORT InferenceSession and related information. type: OrtExecutionInfoPerSession)
        #     ONNX Model 1
        #     ...
        #   GraphModule 1
        #     ONNX Model 2 (with ORT InferenceSession and related information. type: OrtExecutionInfoPerSession)
        #     ONNX Model 3
        #     ...
        #   ...
        # which caches all previous compilation results so that we can reuse them.
        # ONNX Model 0 and 1 are exported from the same GraphModule 0 but with different inputs
        # (e.g., tensors with different ranks). GraphModule 0 and GraphModule 1 are different
        # graphs captured by Dynamo and sent to OrtBackend.compile.
        self._all_ort_execution_info = OrtExecutionInfoForAllGraphModules()

        self._assert_allclose_to_baseline = False

        self.execution_count = 0

        # Function which invokes ORT to do the real computation.
        self.run = (
            _run_onnx_session_with_ortvaluevector
            if hasattr(ORTC.OrtValueVector, "push_back_batch")
            else _run_onnx_session_with_fetch
        )

    def _select_eps(
        self, graph_module: torch.fx.GraphModule, *args
    ) -> Sequence[Tuple[str, Mapping[str, Any]]]:
        inferred_eps: Tuple[str, ...] = ()
        if self._options.infer_execution_providers:
            if eps_from_args := _infer_ep_from_device(*args):
                # If the user feeds a CUDA tensor as an input argument,
                # we want to use the CUDA EP.
                # Thus, `eps_from_args` (deduced from input arguments)
                # has the highest priority.
                inferred_eps = eps_from_args
            elif eps_from_graph_module := _infer_ep_from_graph_module(graph_module):
                # If there is no EP in the input arguments, we deduce the EP from
                # graph_module's outputs. Those outputs may come from
                # FakeTensorProp or Dynamo's built-in symbolic shape inference.
                inferred_eps = eps_from_graph_module

        selected_eps = []

        for ep in (
            *(self._options.preferred_execution_providers or []),
            *_sort_eps(inferred_eps),
            *(self._options.default_execution_providers or _infer_default_eps()),
        ):
            if isinstance(ep, str):
                ep = (ep, {})
            elif isinstance(ep, tuple) and ep[1] is None:
                ep = (ep[0], {})
            if ep is not None and ep not in selected_eps:
                selected_eps.append(ep)

        return selected_eps

    def _ort_acclerated_call(self, graph_module: torch.fx.GraphModule, *args, **kwargs):
        """This function replaces GraphModule._wrapped_call in the compiled model.

        The _wrapped_call is the underlying implementation of the forward method. Replacing
        it means we delegate the computation to _ort_acclerated_call and therefore to
        onnxruntime.InferenceSession.
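
        Exported ONNX models and their onnxruntime.InferenceSession objects are cached
        in self._all_ort_execution_info, keyed by graph_module and the input schema, so
        later calls with compatible inputs reuse an existing session.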
        """
        import onnxruntime

        from torch.onnx._internal.fx import fx_onnx_interpreter, passes

        cached_execution_info_per_session = (
            self._all_ort_execution_info.search_reusable_session_execution_info(
                graph_module, *args
            )
        )
        if cached_execution_info_per_session:
            onnx_session = cached_execution_info_per_session.session
            input_names = cached_execution_info_per_session.input_names
            output_names = cached_execution_info_per_session.output_names
            input_value_infos = cached_execution_info_per_session.input_value_infos
            output_value_infos = cached_execution_info_per_session.output_value_infos
            input_devices = cached_execution_info_per_session.input_devices
            output_devices = cached_execution_info_per_session.output_devices
            prim_outputs = cached_execution_info_per_session.example_outputs
        else:
            # It's the first time seeing such a graph. Let's make a new session
            # (type: onnxruntime.InferenceSession) for it.

            graph_module = passes.MovePlaceholderToFront(
                self._resolved_onnx_exporter_options.diagnostic_context,
                graph_module,
            ).run()
            # Generate reference outputs. They are used to indicate output
            # tensors' types and devices when calling ORT.
            #
            # WARNING: The downstream code should not change prim_outputs and
            # this backend should always produce outputs with a schema identical to prim_outputs'.

            if self._resolved_onnx_exporter_options.dynamic_shapes:
                # No pre-allocation when dynamic shape is enabled.
                self.preallocate_output = False
                extracted_outputs = _extract_graph_module_outputs(graph_module)

                def maybe_map_to_meta_val(value):
                    if hasattr(value, "meta") and "val" in value.meta:
                        # Select outputs with "val" information. Without "val",
                        # it's not possible to access output_arg.meta["val"].device.
                        return value.meta["val"]
                    else:
                        return value

                prim_outputs = _pytree.tree_map(
                    maybe_map_to_meta_val, extracted_outputs
                )
            else:
                try:
                    prim_outputs = FakeTensorProp(graph_module).propagate(
                        *args, **kwargs
                    )
                except Exception:
                    logger.warning("FakeTensorProp failed for %s", graph_module)
                    # When FakeTensorProp fails, it is not possible to preallocate output buffers
                    # because the output shapes are not inferred.
                    self.preallocate_output = False

                    # Rethrow the FakeTensorProp failure because it is not yet handled.
                    raise

            # Create the object to iterate through the nodes in the graph one-by-one
            # and call the corresponding ONNX exporter for each node.
            fx_interpreter = fx_onnx_interpreter.FxOnnxInterpreter(
                diagnostic_context=self._resolved_onnx_exporter_options.diagnostic_context
            )
            # Cast FX variables if they would result in a schema mismatch when searching
            # for the ONNX operator. E.g., add(double_tensor, int_tensor) is fine in PyTorch,
            # but ONNX expects add(double_tensor, double_tensor).
            graph_module = passes.InsertTypePromotion(
                self._resolved_onnx_exporter_options.diagnostic_context, graph_module
            ).run()
            # Start the per-node exporting process. It's conceptually a for loop
            # scanning through the nodes in the graph.
            exported = fx_interpreter.run(
                fx_graph_module=graph_module,
                onnxfunction_dispatcher=self._resolved_onnx_exporter_options.onnxfunction_dispatcher,
            )
            # Convert the exported result to ONNX ModelProto.
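            # The opset version is taken from the exporter's ONNX registry so the
            # serialized model matches the operator set targeted during export.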
            onnx_model = exported.to_model_proto(
                opset_version=self._resolved_onnx_exporter_options.onnx_registry.opset_version,
            )

            try:
                from onnxscript import optimizer  # type: ignore[import]
                from onnxscript.rewriter import (  # type: ignore[import]
                    onnxruntime as ort_rewriter,
                )

                onnx_model = optimizer.optimize(onnx_model)
                onnx_model = ort_rewriter.rewrite(onnx_model)
            except ImportError:
                logger.warning(
                    "ONNXScript optimizer is not available. Skipping optimization. "
                    "Please `pip install onnxscript -U` to enable post-export optimization."
                )

            # Modify the ONNX model using pre-registered graph transforms.
            # They are in-place modifications to avoid unnecessary
            # copies of ONNX initializers.
            if self._options.pre_ort_model_transforms:
                for transform in self._options.pre_ort_model_transforms:
                    transform(onnx_model)

            onnx_model_bytes = onnx_model.SerializeToString()
            if os.environ.get("ONNXRT_DUMP_PATH", None):
                # If not empty, the environment variable ONNXRT_DUMP_PATH defines the path
                # where generated onnx files should be stored.
                # This module keeps a global variable tracking the stored models.
                # If ONNXRT_DUMP_PATH="dumped/dumped_model_",
                # the first file name will be 'dumped/dumped_model_0.onnx'.
                # For every dumped model, a text file 'dumped/dumped_model_0.txt'
                # is created as well to contain the string representation of the graph_module.
                _dump_onnx_model(onnx_model_bytes, graph_module=graph_module)

            # Initialize an ORT session to execute this ONNX model.
            # Note that TorchDynamo assumes all inputs/outputs are on the
            # same device, but it's subject to change (very likely with
            # dynamic shape support), so we add execution providers
            # based on the logic in _select_eps: (explicitly preferred EPs,
            # EPs inferred from inputs or graph, and the fallback default EP).
            #
            # TODO(wschin): enable external allocators.
            # See https://github.com/pytorch/pytorch/issues/106867
            onnx_session = onnxruntime.InferenceSession(
                path_or_bytes=onnx_model_bytes,
                sess_options=self._options.ort_session_options,
                providers=self._select_eps(graph_module, *args),
            )

            # The ORT session will be cached below and reused for the same "graph_module".
            # Extract the ONNX model's input and output names.
            input_names = tuple(input.name for input in onnx_model.graph.input)
            output_names = tuple(output.name for output in onnx_model.graph.output)
            input_devices = _get_onnx_devices(args)
            # Cache devices for inputs and outputs. They are used to invoke the
            # ORT session. Output devices indicate where (e.g., GPU or CPU)
            # to store outputs.
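            # Note: non-tensor outputs (e.g., SymInt or Python scalars) are mapped
            # to a CPU OrtDevice by _get_onnx_devices.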
            if isinstance(prim_outputs, tuple):
                output_devices = _get_onnx_devices(prim_outputs)
            else:
                output_devices = _get_onnx_devices((prim_outputs,))

            input_value_infos = tuple(input for input in onnx_model.graph.input)
            output_value_infos = tuple(output for output in onnx_model.graph.output)

            execution_info_per_session = OrtExecutionInfoPerSession(
                session=onnx_session,
                input_names=input_names,
                input_value_infos=input_value_infos,
                output_names=output_names,
                output_value_infos=output_value_infos,
                input_devices=input_devices,
                output_devices=output_devices,
                example_outputs=prim_outputs,
            )

            self._all_ort_execution_info.cache_session_execution_info(
                graph_module, execution_info_per_session
            )

        self.execution_count += 1

        # ORT always returns a tuple of outputs. If the original output is a tensor,
        # the ORT output's first element must be extracted and returned. Otherwise, a type
        # mismatch may happen in downstream computation.
        is_single_tensor_output = isinstance(prim_outputs, torch.Tensor)
        normalized_prim_outputs = (
            (prim_outputs,) if is_single_tensor_output else prim_outputs
        )
        assert isinstance(normalized_prim_outputs, tuple)
        assert all(
            isinstance(elem, (torch.Tensor, torch.SymInt, int))
            for elem in normalized_prim_outputs
        )

        _nvtx_range_push("run_onnx_session_with_ortvaluevector")
        onnx_outputs = self.run(
            onnx_session,
            input_names,
            args,
            input_devices,
            output_names,
            normalized_prim_outputs,
            output_devices,
            self._options.preallocate_output,
            input_value_infos,
            normalized_prim_outputs,
        )
        _nvtx_range_pop()

        if self._assert_allclose_to_baseline:
            # Compute baseline.
            baseline_outputs = torch._prims.executor.execute(
                graph_module, *args, executor="aten"
            )
            normalized_baseline_outputs = (
                (baseline_outputs,) if is_single_tensor_output else baseline_outputs
            )
            # Ensure every output tensor is close to the corresponding baseline.
            for onnx_output, baseline_output in zip(
                onnx_outputs, normalized_baseline_outputs
            ):
                torch.testing.assert_close(onnx_output, baseline_output)
        return onnx_outputs[0] if is_single_tensor_output else onnx_outputs

    def compile(self, graph_module: torch.fx.GraphModule, args) -> torch.fx.GraphModule:
        # Deferred import since CapabilityBasedPartitioner is not decorated with
        # @compatibility; importing it at the module level will result in the test
        # failing: pytest test/test_fx.py -k test_public_api_surface
        # because this module is imported into torch.onnx.
        from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner

        # FX-graph-based partitioning based on ONNX-supported ops.
        # Given a graph module
        #   GraphModule0
        #     node_0
        #     node_1
        #     node_2
        #     node_3
        #     node_4
        # If only node_2 is not supported by ONNX, this graph module will be partitioned into
        #   GraphModule0
        #     GraphModule1
        #       node_0
        #       node_1
        #     node_2
        #     GraphModule2
        #       node_3
        #       node_4
        # by calling CapabilityBasedPartitioner.partition_and_fuse.
        # Then, GraphModule1's and GraphModule2's forward method (GraphModule._wrapped_call)
        # will be replaced by OrtBackend._ort_acclerated_call to delegate computation to ORT.
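        # The partitioner names each fused submodule like "fused_0", "fused_1", ...,
        # which is what the `"fused_" in node.name` check below relies on.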
        if graph_module in self._partitioner_cache:
            partitioned_prim_graph_module = self._partitioner_cache[graph_module]
        else:
            prim_graph_module = graph_module
            partitioner = CapabilityBasedPartitioner(
                prim_graph_module,
                self._supported_ops,
                allows_single_node_partition=True,
            )
            partitioned_prim_graph_module = partitioner.partition_and_fuse()
            self._partitioner_cache[graph_module] = partitioned_prim_graph_module

            # Override each fused_module's __call__() with _ort_acclerated_call().
            # This loop goes through all graph partitions (each of them is an ONNX-representable graph)
            # and overrides their _wrapped_call function with _ort_acclerated_call.
            # Inside _ort_acclerated_call, the partition's graph is exported into ONNX and executed by ORT.
            for node in partitioned_prim_graph_module.graph.nodes:
                # TODO(wschin): use a better way to identify fused submodules.
                # See https://github.com/pytorch/pytorch/issues/106872.
                if node.op == "call_module" and "fused_" in node.name:
                    fused_module = getattr(partitioned_prim_graph_module, node.name)
                    # self._ort_acclerated_call is responsible for exporting the graph to ONNX,
                    # creating the ORT session, and running the ORT session.
                    fused_module._wrapped_call = self._ort_acclerated_call

        return partitioned_prim_graph_module

    def __call__(
        self, graph_module: torch.fx.GraphModule, args
    ) -> torch.fx.GraphModule:
        """If ``OrtBackendOptions.use_aot_autograd`` is ``True``, the ``aot_autograd`` compiler
        will be invoked, wrapping this ``OrtBackend`` instance's ``compile`` method. Otherwise,
        the ``compile`` method is invoked directly."""
        if self._options.use_aot_autograd:
            from functorch.compile import min_cut_rematerialization_partition
            from torch._dynamo.backends.common import aot_autograd

            return aot_autograd(
                fw_compiler=self.compile,
                partition_fn=min_cut_rematerialization_partition,
                decompositions=self._resolved_onnx_exporter_options.decomposition_table,
            )(graph_module, args)

        return self.compile(graph_module, args)

    __instance_cache_max_count: Final = 8
    __instance_cache: Final[List["OrtBackend"]] = []

    @staticmethod
    def get_cached_instance_for_options(
        options: Optional[Union[OrtBackendOptions, Mapping[str, Any]]] = None,
    ) -> "OrtBackend":
        """Returns a possibly cached instance of an ``OrtBackend``. If an existing
        backend was created previously through this function with the same options,
        it will be returned. Otherwise a new backend will be created, cached, and
        returned.
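
        A rough usage sketch (requires ``onnxruntime`` to be installed)::

            backend_a = OrtBackend.get_cached_instance_for_options(
                {"use_aot_autograd": True}
            )
            backend_b = OrtBackend.get_cached_instance_for_options(
                {"use_aot_autograd": True}
            )
            assert backend_a is backend_b  # identical options reuse the cached instance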

        Note: if ``options`` sets ``ort_session_options``, a new ``OrtBackend``
        will always be returned, since ``onnxruntime.SessionOptions`` cannot
        participate in caching."""

        def reusable(a: OrtBackendOptions, b: OrtBackendOptions):
            if (
                a.preferred_execution_providers != b.preferred_execution_providers
                or a.infer_execution_providers != b.infer_execution_providers
                or a.default_execution_providers != b.default_execution_providers
                or a.preallocate_output != b.preallocate_output
                or a.use_aot_autograd != b.use_aot_autograd
                or a.pre_ort_model_transforms != b.pre_ort_model_transforms
            ):
                return False

            # onnxruntime.SessionOptions is a pybind11 object, cannot be pickled,
            # and holds too much potential state to reasonably check manually;
            # if ort_session_options is provided at all, the backend does not
            # participate in caching.
            if a.ort_session_options is not None or b.ort_session_options is not None:
                return False

            if a.export_options is b.export_options:
                return True

            # Similarly, some objects in ExportOptions are too stateful to use for
            # caching. We should revisit this.
            if a.export_options is not None and b.export_options is not None:
                return (
                    a.export_options.dynamic_shapes == b.export_options.dynamic_shapes
                    and a.export_options.diagnostic_options
                    == b.export_options.diagnostic_options
                    and a.export_options.onnx_registry is b.export_options.onnx_registry
                    and a.export_options.fake_context is b.export_options.fake_context
                )

            # We can't account for how the two option sets may differ, so it's not safe to reuse.
            return False

        if not isinstance(options, OrtBackendOptions):
            options = OrtBackendOptions(**(options or {}))

        backend = next(
            (b for b in OrtBackend.__instance_cache if reusable(b._options, options)),
            None,
        )

        if backend is None:
            assert (
                len(OrtBackend.__instance_cache) < OrtBackend.__instance_cache_max_count
            ), (
                f"No more than {OrtBackend.__instance_cache_max_count} instances of "
                f"{OrtBackend} allowed. Please instantiate `{OrtBackend}` explicitly "
                "to pass to `torch.compile`. "
                "See https://github.com/pytorch/pytorch/pull/107973#discussion_r1306144795 "
                "for discussion."
            )
            OrtBackend.__instance_cache.append(backend := OrtBackend(options))

        return backend

    @staticmethod
    def clear_cached_instances():
        OrtBackend.__instance_cache.clear()

    @staticmethod
    def get_cached_instances():
        return tuple(OrtBackend.__instance_cache)


@compatibility(is_backward_compatible=False)
def torch_compile_backend(
    graph_module: torch.fx.GraphModule,
    args,
    *,
    options: Optional[Union[OrtBackendOptions, Mapping[str, Any]]] = None,
):
    return OrtBackend.get_cached_instance_for_options(options)(graph_module, args)