# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import contextlib
import copy
import dataclasses
import functools
import operator
import types
import warnings
from collections import namedtuple
from contextlib import contextmanager
from typing import (
    Any,
    Callable,
    Dict,
    final,
    Iterator,
    List,
    Optional,
    Tuple,
    Type,
    TYPE_CHECKING,
    Union,
)

from torch._higher_order_ops.utils import autograd_not_implemented
from torch._library.fake_class_registry import FakeScriptObject
from torch.fx._utils import first_call_function_nn_module_stack
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
from torch.fx.immutable_collections import immutable_dict, immutable_list
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts


if TYPE_CHECKING:
    # Import the following modules during type checking to enable code intelligence features,
    # such as auto-completion in tools like pylance, even when these modules are not explicitly
    # imported in user code.

    import sympy

    from torch.utils._sympy.value_ranges import ValueRanges

import torch
import torch.utils._pytree as pytree
from torch._export.utils import (
    _collect_and_set_constant_attrs,
    _collect_param_buffer_metadata,
    _detect_fake_mode_from_gm,
    _name_hoo_subgraph_placeholders,
    _overwrite_signature_for_non_persistent_buffers,
    _populate_param_buffer_metadata_to_new_gm,
    _rename_without_collisions,
)
from torch._export.verifier import Verifier
from torch._guards import detect_fake_mode
from torch._subclasses.fake_tensor import unset_fake_temporarily
from torch._subclasses.functional_tensor import FunctionalTensor
from torch.export._tree_utils import is_equivalent, reorder_kwargs
from torch.fx._compatibility import compatibility
from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.infra.pass_manager import PassManager

from .graph_signature import (  # noqa: F401
    ArgumentSpec,
    ConstantArgument,
    CustomObjArgument,
    ExportGraphSignature,
    InputKind,
    InputSpec,
    OutputKind,
    OutputSpec,
    SymIntArgument,
    TensorArgument,
    TokenArgument,
)


__all__ = [
    "ExportedProgram",
    "ModuleCallEntry",
    "ModuleCallSignature",
]


PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]


@dataclasses.dataclass
class ModuleCallSignature:
    inputs: List[ArgumentSpec]
    outputs: List[ArgumentSpec]
    in_spec: pytree.TreeSpec
    out_spec: pytree.TreeSpec

    def replace_all_uses_with(self, original_node, new_node):
        for i in self.inputs:
            if i.name == original_node.name:
                i.name = new_node.name
        for o in self.outputs:
            if o.name == original_node.name:
                o.name = new_node.name


@dataclasses.dataclass
class ModuleCallEntry:
    fqn: str
    signature: Optional[ModuleCallSignature] = None


def _disable_prexisiting_fake_mode(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        with unset_fake_temporarily():
            return fn(*args, **kwargs)

    return wrapper


def _fx_collection_equivalence_fn(
    spec1_type: Optional[type],
    spec1_context: pytree.Context,
    spec2_type: Optional[type],
    spec2_context: pytree.Context,
) -> bool:
    """Treat containers and their immutable variants as the same type. Otherwise
    compare as normal.
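
    For example, a spec of type ``dict`` and a spec of type ``immutable_dict`` are
    treated as equivalent as long as their contexts (for dicts, the keys) match.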
    """
    if spec1_type is None or spec2_type is None:
        return spec1_type is spec2_type and spec1_context == spec2_context

    if issubclass(spec1_type, (dict, immutable_dict)) and issubclass(
        spec2_type, (dict, immutable_dict)
    ):
        return spec1_context == spec2_context

    if issubclass(spec1_type, (list, immutable_list)) and issubclass(
        spec2_type, (list, immutable_list)
    ):
        return spec1_context == spec2_context

    return spec1_type is spec2_type and spec1_context == spec2_context


def _register_cia_to_meta(*args, **kwargs):
    kernel = kwargs["kernel"]
    del kwargs["kernel"]

    assert torch._C._dispatch_has_kernel_for_dispatch_key(
        kernel.name(), torch._C.DispatchKey.CompositeImplicitAutograd
    )

    return kernel._op_dk(
        torch._C.DispatchKey.CompositeImplicitAutograd, *args, **kwargs
    )


# This list is compiled from DispatchKey.cpp.
# The idea is that we use these keys to override
# CIA decomp in export
_AUTOGRAD_ALIAS_BACKEND_KEYS_TO_OVERRIDE = [
    torch._C.DispatchKey.AutogradCPU,
    torch._C.DispatchKey.AutogradCUDA,
    torch._C.DispatchKey.AutogradMeta,
    torch._C.DispatchKey.AutogradXLA,
    torch._C.DispatchKey.AutogradLazy,
    torch._C.DispatchKey.AutogradIPU,
    torch._C.DispatchKey.AutogradXPU,
    torch._C.DispatchKey.AutogradMPS,
    torch._C.DispatchKey.AutogradHPU,
    torch._C.DispatchKey.AutogradPrivateUse1,
    torch._C.DispatchKey.AutogradPrivateUse2,
    torch._C.DispatchKey.AutogradPrivateUse3,
]


@contextmanager
def _override_composite_implicit_decomp(ops_to_preserve, decomp_table, safe=True):
    # This function overrides the CompositeImplicitAutograd decomp for the
    # functional composite ops that the user specified. Ideally we would like to
    # not decompose ALL composite ops, but today's C++ functionalization relies on
    # the fact that it is working with the opset after decomp is run.
    # Hence we can only do it for functional ops. One caveat is that
    # there are some composite ops that lie about their schema (they claim to be
    # functional but are not, e.g. dropout); for those cases, we just decompose.

    # When safe=False, we will assume that ops_to_preserve can be mutating/aliasing
    # and their usual decompositions need to be shadowed rather than overridden.
    # Thus we will avoid asserting that they are valid to preserve, and will not
    # replace their CompositeImplicitAutograd kernels with NotImplemented.
    # The only current users of this mode are variants of aten::to that we will
    # replace with aten::_to_copy in FunctionalTensorMode.__torch_dispatch__.

    saved_tables = {}
    patched_ops = set()
    removed_decomps = {}
    for op_overload in ops_to_preserve:
        # Our strategy for deciding whether we can preserve CIA is the following:
        # 1. The op must be statically known to be functional.
        # 2. If it is maybe aliasing, we decompose, because we must know whether an op
        #    is mutating or aliasing.
        # TODO (tmanlaibaatar) make this a utility function and share it with the functional_tensor
        # decomp part. (https://github.com/pytorch/pytorch/issues/129431)
        def assert_valid_to_preserve(op_overload):
            if op_overload in FunctionalTensor.maybe_aliasing_or_mutating_ops:
                raise RuntimeError(
                    f"We can't detect {op_overload} as a functional op statically, so we can't preserve it"
                )
            if op_overload in FunctionalTensor.metadata_fns:
                raise RuntimeError(
                    f"{op_overload} is a metadata query function, "
                    "it will be preserved implicitly in our tracing system. "
                    "Please file an issue on github if you see otherwise"
                )

            alias_info = len(
                [i for i in op_overload._schema.arguments if i.alias_info is not None]
            )

            is_mutating_or_aliasing = alias_info != 0 or op_overload._schema.is_mutable

            if is_mutating_or_aliasing:
                raise RuntimeError(
                    f"{op_overload} is a mutating/aliasing op, we can't preserve it as is"
                )

            if not torch._C._dispatch_has_kernel(op_overload.name()):
                raise RuntimeError(
                    f"{op_overload} is a TorchScript op, we can't preserve it as is"
                )

            return True

        if safe:
            # If we didn't error, it means we can go ahead
            assert_valid_to_preserve(op_overload)

        saved_tables[op_overload] = op_overload.py_kernels.copy()
        patched_ops.add(op_overload)

        for override_dispatch_key in _AUTOGRAD_ALIAS_BACKEND_KEYS_TO_OVERRIDE:
            if override_dispatch_key not in op_overload.py_kernels:
                # TODO (tmanlaibaatar) https://github.com/pytorch/pytorch/issues/129430
                op_overload.py_impl(override_dispatch_key)(
                    autograd_not_implemented(op_overload, deferred_error=True)
                )
        if torch._C.DispatchKey.CompositeImplicitAutograd in op_overload.py_kernels:
            del op_overload.py_kernels[torch._C.DispatchKey.CompositeImplicitAutograd]

        if safe:

            def _(*args, **kwargs):
                return NotImplemented

            op_overload.py_impl(torch._C.DispatchKey.CompositeImplicitAutograd)(_)

        # For fake tensor prop, we do want to register the meta kernel directly
        if torch._C.DispatchKey.Meta not in op_overload.py_kernels:
            op_overload.py_impl(torch._C.DispatchKey.Meta)(
                functools.partial(_register_cia_to_meta, kernel=op_overload)
            )

        if op_overload in decomp_table:
            removed_decomps[op_overload] = decomp_table[op_overload]
            del decomp_table[op_overload]

    try:
        yield
    finally:
        for op in patched_ops:
            op.py_kernels.clear()
            op.py_kernels.update(saved_tables[op])
            op._dispatch_cache.clear()

        for op, decomp in removed_decomps.items():
            decomp_table[op] = decomp


@contextmanager
def _override_decomp_aten_to_variants():
    # Preserve variants of aten::to, with the understanding that they are mutating/aliasing
    # and that their CompositeImplicitAutograd kernels will not be replaced with NotImplemented.
    # We will later replace them with aten._to_copy when functionalizing.
    with _override_composite_implicit_decomp(
        (torch.ops.aten.to.dtype_layout, torch.ops.aten.to.dtype),
        {},
        safe=False,
    ):
        yield


def _decompose_and_get_gm_with_new_signature_constants(
    ep,
    *,
    decomp_table: Dict[torch._ops.OperatorBase, Callable],
    _preserve_ops: Tuple[torch._ops.OpOverload],
    joint_loss_index: Optional[int],
):
    from torch._functorch.aot_autograd import aot_export_module
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.export._trace import (
        _export_to_aten_ir,
        _fakify_params_buffers,
        _ignore_backend_decomps,
        _verify_nn_module_stack,
        _verify_placeholder_names,
        _verify_stack_trace,
    )
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    # TODO Merge this path with the inference IR decomp path, but it will require some
    # additional work, so I will leave it for now. T200307782
    if ep.verifier.dialect == "TRAINING":
        mod = ep.module()

        fake_args = []
        for node in mod.graph.nodes:
            if node.op == "placeholder":
                fake_args.append(node.meta["val"])

        fake_args_unwrapped = pytree.tree_unflatten(fake_args, mod._in_spec)
        fake_mode = _detect_fake_mode_from_gm(mod)
        if fake_mode is None:
            fake_mode = FakeTensorMode(shape_env=ShapeEnv(), export=True)

        # Fix the graph output signature to be a tuple if it is a scalar.
        out_spec = mod._out_spec

        orig_arg_names = mod.graph._codegen.pytree_info.orig_args  # type: ignore[attr-defined]

        # aot_export expects the return type to always be a tuple.
        if out_spec.type not in (list, tuple):
            out_spec = pytree.TreeSpec(tuple, None, [out_spec])

        mod.graph._codegen = _PyTreeCodeGen(
            _PyTreeInfo(
                orig_arg_names,
                mod._in_spec,
                out_spec,
            )
        )

        mod.recompile()

        # the exported module will store constants & non-persistent buffers such that
        # retracing treats them as persistent buffers, so we inform the constants lifting pass
        # and overwrite the new graph signature using the previous program.
        constant_attrs = _collect_and_set_constant_attrs(
            ep.graph_signature, ep.constants, mod
        )

        # get params & buffers after excluding constants
        fake_params_buffers = _fakify_params_buffers(fake_mode, mod)

        params_buffers_to_node_meta = _collect_param_buffer_metadata(mod)

        with _ignore_backend_decomps(), (
            fake_mode
        ), _override_decomp_aten_to_variants(), _override_composite_implicit_decomp(
            _preserve_ops,
            decomp_table,
        ):
            aten_export_artifact = _export_to_aten_ir(
                mod,
                # this requires empty kwargs, but not in pytree.flattened format
                (
                    *fake_args_unwrapped[0],
                    *fake_args_unwrapped[1].values(),
                ),
                {},
                fake_params_buffers,
                constant_attrs,
                decomp_table=decomp_table,
                _check_autograd_state=False,
            )

        gm = aten_export_artifact.gm
        new_graph_signature = aten_export_artifact.sig

        _populate_param_buffer_metadata_to_new_gm(
            params_buffers_to_node_meta, gm, new_graph_signature
        )

        # overwrite signature for non-persistent buffers
        new_graph_signature = _overwrite_signature_for_non_persistent_buffers(
            ep.graph_signature, new_graph_signature
        )

        _verify_nn_module_stack(gm)
        _verify_stack_trace(gm)
        _verify_placeholder_names(gm, new_graph_signature)

        return _remove_unneccessary_copy_op_pass(gm, new_graph_signature)

    old_placeholders = [
        node for node in ep.graph_module.graph.nodes if node.op == "placeholder"
    ]
    fake_args = [node.meta["val"] for node in old_placeholders]

    buffers_to_remove = [name for name, _ in ep.graph_module.named_buffers()]
    for name in buffers_to_remove:
        delattr(ep.graph_module, name)

    # TODO(zhxhchen17) Return the new graph_signature directly.
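    # Re-trace the graph through aot_export_module below so that the entries in
    # decomp_table are applied; _override_composite_implicit_decomp keeps the ops
    # listed in _preserve_ops from being decomposed along the way.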
    fake_mode = detect_fake_mode(fake_args)
    fake_mode = contextlib.nullcontext() if fake_mode is None else fake_mode
    with _ignore_backend_decomps(), fake_mode, _override_composite_implicit_decomp(
        _preserve_ops,
        decomp_table,
    ):
        gm, graph_signature = aot_export_module(
            ep.graph_module,
            fake_args,
            decompositions=decomp_table,
            trace_joint=True if joint_loss_index is not None else False,
            output_loss_index=joint_loss_index
            if joint_loss_index is not None
            else None,
        )

    # Update the signatures with the new placeholder names in case they
    # changed when calling aot_export
    def update_arg(old_arg, new_ph):
        if isinstance(old_arg, ConstantArgument):
            return old_arg
        elif isinstance(old_arg, TensorArgument):
            return TensorArgument(name=new_ph.name)
        elif isinstance(old_arg, SymIntArgument):
            return SymIntArgument(name=new_ph.name)
        raise RuntimeError(f"Type of old_arg not supported: {type(old_arg)}")

    new_placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"]
    new_outputs = list(gm.graph.nodes)[-1].args[0]

    # rename the placeholders
    assert len(new_placeholders) == len(old_placeholders)
    for old_ph, new_ph in zip(old_placeholders, new_placeholders):
        new_ph.name = new_ph.target = old_ph.name

    # handle name collisions with newly decomposed graph nodes
    name_map = {ph.name: ph.name for ph in new_placeholders}
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            continue
        node.name = _rename_without_collisions(name_map, node.name, node.name)

    # propagate names to higher order op subgraphs
    _name_hoo_subgraph_placeholders(gm)

    # Run this pass before creating input/output specs, since size-related CSE/DCE might affect output signature.
    # Overwrite output specs afterwards.
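    # The pass in question is insert_deferred_runtime_asserts below, which turns
    # guards recorded in the ShapeEnv into runtime assertion nodes in the graph.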
    from torch._export.passes._node_metadata_hook import (
        _node_metadata_hook,
        _set_node_metadata_hook,
    )
    from torch._functorch._aot_autograd.input_output_analysis import _graph_output_names

    if not torch._dynamo.config.do_not_emit_runtime_asserts:
        stack_trace = (
            'File "torch/fx/passes/runtime_assert.py", line 24, '
            "in insert_deferred_runtime_asserts"
        )
        shape_env = _get_shape_env(gm)
        if shape_env is not None:
            with _set_node_metadata_hook(
                gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace)
            ):
                insert_deferred_runtime_asserts(
                    gm,
                    shape_env,
                    f"exported program: {first_call_function_nn_module_stack(gm.graph)}",
                    export=True,
                )

    # update output specs
    gm.recompile()
    for i, name in enumerate(_graph_output_names(gm)):
        if isinstance(new_outputs[i], torch.fx.Node):
            new_outputs[i].name = name

    # To match the output target with the correct input for input mutations,
    # we need to find the old-to-new placeholder map
    old_new_placeholder_map = {
        spec.arg.name: new_placeholders[i].name
        for i, spec in enumerate(ep.graph_signature.input_specs)
        if not isinstance(spec.arg, ConstantArgument)
    }

    input_specs = [
        InputSpec(
            spec.kind,
            update_arg(spec.arg, new_placeholders[i]),
            spec.target,
            spec.persistent,
        )
        for i, spec in enumerate(ep.graph_signature.input_specs)
    ]
    output_specs = [
        OutputSpec(
            spec.kind,
            update_arg(spec.arg, new_outputs[i]),
            old_new_placeholder_map.get(spec.target, spec.target),
        )
        for i, spec in enumerate(ep.graph_signature.output_specs)
    ]

    if joint_loss_index is not None:
        assert graph_signature.backward_signature is not None
        gradients = graph_signature.backward_signature.gradients_to_user_inputs
        assert len(graph_signature.user_inputs) == len(ep.graph_signature.input_specs)
        specs = {
            graph_signature.user_inputs[i]: spec
            for i, spec in enumerate(ep.graph_signature.input_specs)
            if isinstance(spec.arg, TensorArgument)
        }
        for i, node in enumerate(new_outputs[len(output_specs) :]):
            source = gradients[node.name]
            spec = specs[source]  # type: ignore[index]
            if spec.kind == InputKind.PARAMETER:
                kind = OutputKind.GRADIENT_TO_PARAMETER
                target = spec.target
            elif spec.kind == InputKind.USER_INPUT:
                kind = OutputKind.GRADIENT_TO_USER_INPUT
                target = source
            else:
                raise AssertionError(f"Unknown input kind: {spec.kind}")
            output_specs.append(
                OutputSpec(
                    kind,
                    TensorArgument(name=node.name),
                    target,
                )
            )

    assert len(new_placeholders) == len(old_placeholders)

    new_graph_signature = ExportGraphSignature(
        input_specs=input_specs, output_specs=output_specs
    )
    # NOTE: aot_export adds symint metadata for placeholders with int
    # values; since these become specialized, we replace such metadata with
    # the original values.
    # Also, set the param/buffer metadata back on the placeholders.
    for old_node, new_node in zip(old_placeholders, new_placeholders):
        if not isinstance(old_node.meta["val"], torch.Tensor):
            new_node.meta["val"] = old_node.meta["val"]

        if (
            new_node.target in new_graph_signature.inputs_to_parameters
            or new_node.target in new_graph_signature.inputs_to_buffers
        ):
            for k, v in old_node.meta.items():
                new_node.meta[k] = v
    return gm, new_graph_signature


def _remove_unneccessary_copy_op_pass(
    gm: torch.fx.GraphModule, new_graph_signature: ExportGraphSignature
) -> Tuple[torch.fx.GraphModule, ExportGraphSignature]:
    """
    Removes redundant copy_ nodes that were introduced due to mutated buffers.
    """
    with gm._set_replace_hook(new_graph_signature.get_replace_hook()):
        for node in gm.graph.nodes:
            if node.op == "output":
                args, _ = pytree.tree_flatten(node.args)
                for out in args:
                    if (
                        isinstance(out, torch.fx.Node)
                        and out.name in new_graph_signature.buffers_to_mutate
                    ):
                        if (
                            out.op == "call_function"
                            and out.target == torch.ops.aten.copy.default
                        ):
                            out.replace_all_uses_with(out.args[1])  # type: ignore[arg-type]
                            gm.graph.erase_node(out)
        gm.recompile()
    return gm, new_graph_signature


def _common_getitem_elimination_pass(
    gm: torch.fx.GraphModule, graph_signature, module_call_graph
):
    with gm._set_replace_hook(graph_signature.get_replace_hook()):
        for module in gm.modules():
            if not isinstance(module, torch.fx.GraphModule):
                continue

            node_id: Dict[torch.fx.Node, str] = {}
            getitems: Dict[str, torch.fx.Node] = {}
            for node in list(module.graph.nodes):
                if node.op == "call_function" and node.target == operator.getitem:
                    source, idx = node.args
                    new_id = f"{node_id[source]}.{idx}"
                    if new_id in getitems:
                        node.replace_all_uses_with(getitems[new_id])
                        for entry in module_call_graph:
                            if entry.signature is not None:
                                entry.signature.replace_all_uses_with(
                                    node, getitems[new_id]
                                )
                        module.graph.erase_node(node)
                    else:
                        getitems[new_id] = node
                        node_id[node] = new_id
                else:
                    node_id[node] = node.name


def _decompose_exported_program(
    ep,
    *,
    decomp_table: Dict[torch._ops.OperatorBase, Callable],
    _preserve_ops: Tuple[torch._ops.OpOverload],
    joint_loss_index: Optional[int],
):
    gm, new_graph_signature = _decompose_and_get_gm_with_new_signature_constants(
        ep,
        decomp_table=decomp_table,
        _preserve_ops=_preserve_ops,
        joint_loss_index=joint_loss_index,
    )

    # TODO unfortunately preserving graph-level metadata is not
    # working well with aot_export. So we manually copy it.
    # (The node-level meta is addressed above.)
    gm.meta.update(ep.graph_module.meta)

    new_range_constraints = _get_updated_range_constraints(
        gm,
        ep.range_constraints,
    )

    exported_program = ExportedProgram(
        root=gm,
        graph=gm.graph,
        graph_signature=new_graph_signature,
        state_dict=ep.state_dict,
        range_constraints=new_range_constraints,
        module_call_graph=copy.deepcopy(ep.module_call_graph),
        example_inputs=ep.example_inputs,
        constants=ep.constants,
    )
    return exported_program


class ExportedProgram:
    """
    Package of a program from :func:`export`. It contains
    a :class:`torch.fx.Graph` that represents Tensor computation, a state_dict containing
    tensor values of all lifted parameters and buffers, and various metadata.
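
    A minimal usage sketch (``M`` here is a hypothetical user module)::

        class M(torch.nn.Module):
            def forward(self, x):
                return x.sin() + 1

        ep = torch.export.export(M(), (torch.randn(3),))
        print(ep)                          # readable graph and signature
        out = ep.module()(torch.randn(3))  # run via the unlifted module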

    You can call the module returned by :meth:`module` like the original callable
    traced by :func:`export`, with the same calling convention; calling an
    ``ExportedProgram`` directly is not supported.

    To perform transformations on the graph, use the ``.module()`` method to access
    the underlying :class:`torch.fx.GraphModule`. You can then use
    `FX transformations <https://pytorch.org/docs/stable/fx.html#writing-transformations>`_
    to rewrite the graph. Afterwards, you can simply use :func:`export`
    again to construct a correct ExportedProgram.
    """

    def __init__(
        self,
        root: Union[torch.nn.Module, Dict[str, Any]],
        graph: torch.fx.Graph,
        graph_signature: ExportGraphSignature,
        state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]],
        range_constraints: "Dict[sympy.Symbol, Any]",
        module_call_graph: List[ModuleCallEntry],
        example_inputs: Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]] = None,
        constants: Optional[
            Dict[str, Union[torch.Tensor, FakeScriptObject, torch._C.ScriptObject]]
        ] = None,
        *,
        verifiers: Optional[List[Type[Verifier]]] = None,
    ):
        # Remove codegen related things from the graph. It should just be a flat graph.
        graph._codegen = torch.fx.graph.CodeGen()
        self._graph_module = _create_graph_module_for_export(root, graph)
        if isinstance(root, torch.fx.GraphModule):
            self._graph_module.meta.update(root.meta)

        _common_getitem_elimination_pass(
            self._graph_module, graph_signature, module_call_graph
        )
        self._graph_signature: ExportGraphSignature = graph_signature
        self._state_dict: Dict[str, Any] = state_dict
        self._range_constraints: Dict[sympy.Symbol, ValueRanges] = range_constraints
        assert module_call_graph is not None
        self._module_call_graph: List[ModuleCallEntry] = module_call_graph
        self._example_inputs = example_inputs

        self._constants = constants or {}

        verifiers = verifiers or [Verifier]
        assert all(issubclass(v, Verifier) for v in verifiers)
        self._verifiers = verifiers
        # Validation should always be the last step of the constructor.
        self.validate()

    @property
    @compatibility(is_backward_compatible=False)
    def graph_module(self):
        return self._graph_module

    @property
    @compatibility(is_backward_compatible=False)
    def graph(self):
        return self.graph_module.graph

    @property
    @compatibility(is_backward_compatible=False)
    def graph_signature(self):
        return self._graph_signature

    @property
    @compatibility(is_backward_compatible=False)
    def state_dict(self):
        return self._state_dict

    @compatibility(is_backward_compatible=False)
    def parameters(self) -> Iterator[torch.nn.Parameter]:
        """
        Returns an iterator over the original module's parameters.
        """
        for _, param in self.named_parameters():
            yield param

    @compatibility(is_backward_compatible=False)
    def named_parameters(self) -> Iterator[Tuple[str, torch.nn.Parameter]]:
        """
        Returns an iterator over the original module's parameters, yielding
        both the name of the parameter as well as the parameter itself.
        """
        for param_name in self.graph_signature.parameters:
            yield param_name, self.state_dict[param_name]

    @compatibility(is_backward_compatible=False)
    def buffers(self) -> Iterator[torch.Tensor]:
        """
        Returns an iterator over the original module's buffers.
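        This includes non-persistent buffers, which are sourced from ``constants``
        rather than from the state dict (see :meth:`named_buffers`).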
        """
        for _, buf in self.named_buffers():
            yield buf

    @compatibility(is_backward_compatible=False)
    def named_buffers(self) -> Iterator[Tuple[str, torch.Tensor]]:
        """
        Returns an iterator over the original module's buffers, yielding
        both the name of the buffer as well as the buffer itself.
        """
        non_persistent_buffers = set(self.graph_signature.non_persistent_buffers)
        for buffer_name in self.graph_signature.buffers:
            if buffer_name in non_persistent_buffers:
                yield buffer_name, self.constants[buffer_name]
            else:
                yield buffer_name, self.state_dict[buffer_name]

    @property
    @compatibility(is_backward_compatible=False)
    def range_constraints(self):
        return self._range_constraints

    @property
    @compatibility(is_backward_compatible=False)
    def module_call_graph(self):
        return self._module_call_graph

    @property
    @compatibility(is_backward_compatible=False)
    def example_inputs(self):
        return self._example_inputs

    @property
    @compatibility(is_backward_compatible=False)
    def call_spec(self):
        CallSpec = namedtuple("CallSpec", ["in_spec", "out_spec"])

        if len(self.module_call_graph) == 0:
            return CallSpec(in_spec=None, out_spec=None)
        assert self.module_call_graph[0].fqn == ""
        return CallSpec(
            in_spec=self.module_call_graph[0].signature.in_spec,
            out_spec=self.module_call_graph[0].signature.out_spec,
        )

    @property
    @compatibility(is_backward_compatible=False)
    def verifier(self) -> Any:
        return self._verifiers[0]

    @property
    @compatibility(is_backward_compatible=False)
    def dialect(self) -> str:
        assert self._verifiers is not None
        return self._verifiers[0].dialect

    @property
    @compatibility(is_backward_compatible=False)
    def verifiers(self):
        return self._verifiers

    @property
    @compatibility(is_backward_compatible=False)
    def tensor_constants(self):
        return self._constants

    @property
    @compatibility(is_backward_compatible=False)
    def constants(self):
        return self._constants

    def _get_flat_args_with_check(self, args, kwargs):
        """Flatten args and kwargs using pytree, then check specs.

        Args:
            args: List[Any] original args passed to __call__
            kwargs: Dict[str, Any] original kwargs passed to __call__

        Returns:
            A tuple of (flat_args, received_spec)
            flat_args is the flattened args / kwargs
            received_spec is the pytree spec produced while flattening the
            tuple (args, kwargs)
        """
        in_spec = self.call_spec.in_spec
        if in_spec is not None:
            kwargs = reorder_kwargs(kwargs, in_spec)
        flat_args_with_path, received_spec = pytree.tree_flatten_with_path(
            (args, kwargs)
        )  # type: ignore[possibly-undefined]
        self._check_input_constraints(flat_args_with_path)
        flat_args = tuple(x[1] for x in flat_args_with_path)
        return flat_args, received_spec

    def _graph_module_flat_inputs(self, args: Any, kwargs: Any) -> Any:
        """Transform args, kwargs of __call__ to args for graph_module.

        self.graph_module takes stuff from the state dict as inputs.
        The invariant is that for ep: ExportedProgram,
        ep(args, kwargs) ==
        ep.postprocess(ep.graph_module(ep.graph_module_flat_inputs(args, kwargs)))
        """

        in_spec = self.call_spec.in_spec
        flat_args, received_spec = self._get_flat_args_with_check(args, kwargs)
        if in_spec is not None and not is_equivalent(
            received_spec, in_spec, _fx_collection_equivalence_fn
        ):
            raise ValueError(
                "Trying to flatten user inputs with exported input tree spec: \n"
                f"{in_spec}\n"
                "but actually got inputs with tree spec of: \n"
                f"{received_spec}"
            )

        additional_inputs = []
        for input_ in self.graph_signature.input_specs:
            if input_.kind == InputKind.USER_INPUT:
                continue
            elif input_.kind in (
                InputKind.PARAMETER,
                InputKind.BUFFER,
            ):
                if input_.persistent is False:
                    # This is a non-persistent buffer, grab it from our
                    # constants instead of the state dict.
                    additional_inputs.append(self.constants[input_.target])
                else:
                    additional_inputs.append(self.state_dict[input_.target])
            elif input_.kind in (
                InputKind.CONSTANT_TENSOR,
                InputKind.CUSTOM_OBJ,
            ):
                additional_inputs.append(self.constants[input_.target])
        additional_inputs = tuple(additional_inputs)

        # NOTE: calling convention is first params, then buffers, then args as the user supplied them.
        # See: torch/_functorch/aot_autograd.py#L1034
        return additional_inputs + flat_args

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        raise RuntimeError(
            "Unable to call ExportedProgram directly. "
            "You should use `exported_program.module()` instead."
        )

    def _postprocess_graph_module_outputs(self, res, orig_args, orig_kwargs):
        """Process potential mutations to the input.

        Because self.graph_module is functional, mutations have to be written
        back after execution of graph_module.
        """
        import torch._export.error as error

        flat_args, _ = self._get_flat_args_with_check(orig_args, orig_kwargs)
        if self.call_spec.out_spec is not None:
            buffer_mutation = self.graph_signature.buffers_to_mutate
            user_input_mutation = self.graph_signature.user_inputs_to_mutate
            num_mutated = len(buffer_mutation) + len(user_input_mutation)
            mutated_values = res[:num_mutated]

            # Exclude the dependency token from the final result.
            assertion_dep_token = self.graph_signature.assertion_dep_token
            if assertion_dep_token is not None:
                assertion_dep_token_index = next(iter(assertion_dep_token.keys()))
                res = res[:assertion_dep_token_index]

            res = res[num_mutated:]
            try:
                res = pytree.tree_unflatten(res, self.call_spec.out_spec)
            except Exception:
                _, received_spec = pytree.tree_flatten(res)
                raise error.InternalError(  # noqa: B904
                    "Trying to flatten user outputs with exported output tree spec: \n"
                    f"{self.call_spec.out_spec}\n"
                    "but actually got outputs with tree spec of: \n"
                    f"{received_spec}"
                )
            finally:
                user_inputs = [
                    spec
                    for spec in self.graph_signature.input_specs
                    if spec.kind == InputKind.USER_INPUT
                ]
                for i, value in enumerate(mutated_values):
                    output_spec = self.graph_signature.output_specs[i]
                    if output_spec.kind == OutputKind.BUFFER_MUTATION:
                        assert output_spec.target is not None
                        self.state_dict[output_spec.target] = value
                    elif output_spec.kind == OutputKind.USER_INPUT_MUTATION:
                        assert output_spec.target is not None
                        index = next(
                            i
                            for i, spec in enumerate(user_inputs)
                            if spec.arg.name == output_spec.target
                        )
                        flat_args[index].copy_(value)
                    else:
                        raise AssertionError(f"Unexpected kind: {output_spec.kind}")
        return res

    def __str__(self) -> str:
        graph_module = self.graph_module.print_readable(
            print_output=False, colored=False
        ).replace("\n", "\n    ")
        string = (
            "ExportedProgram:\n"
            f"    {graph_module}\n"
            f"Graph signature: {self.graph_signature}\n"
            f"Range constraints: {self.range_constraints}\n"
        )
        return string

    def module(self) -> torch.nn.Module:
        """
        Returns a self-contained GraphModule with all the parameters/buffers inlined.
        """
        from ._unlift import _unlift_exported_program_lifted_states

        module = _unlift_exported_program_lifted_states(self)

        def _train(self, mode: bool = True):
            raise NotImplementedError("Calling train() is not supported yet.")

        def _eval(self, mode: bool = True):
            raise NotImplementedError("Calling eval() is not supported yet.")

        module.train = types.MethodType(_train, module)  # type: ignore[method-assign]
        module.eval = types.MethodType(_eval, module)  # type: ignore[method-assign]
        return module

    def _num_lifted_params_buffers(self):
        return next(
            (
                i
                for i, s in enumerate(self._graph_signature.input_specs)
                if s.kind == InputKind.USER_INPUT
            ),
            len(self._graph_signature.input_specs),
        )

    @_disable_prexisiting_fake_mode
    def run_decompositions(
        self,
        decomp_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None,
        _preserve_ops: Tuple[torch._ops.OpOverload, ...] = (),
    ) -> "ExportedProgram":
        """
        Run a set of decompositions on the exported program and return a new
        exported program. By default we will run the Core ATen decompositions to
        get operators in the
        `Core ATen Operator Set <https://pytorch.org/docs/stable/torch.compiler_ir.html>`_.

        For now, we do not decompose joint graphs.
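
        A minimal sketch (``ep`` is an :class:`ExportedProgram` produced by
        :func:`torch.export.export`)::

            # With no arguments, the default Core ATen decomposition table is used.
            core_aten_ep = ep.run_decompositions()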
        """
        from torch._decomp import core_aten_decompositions

        if decomp_table is None:
            decomp_table = core_aten_decompositions()

        return _decompose_exported_program(
            self,
            decomp_table=decomp_table,
            _preserve_ops=_preserve_ops,  # type: ignore[arg-type]
            joint_loss_index=None,
        )

    def _transform_do_not_use(self, *passes: PassType) -> "ExportedProgram":
        pm = PassManager(list(passes))
        # Since we abstractly run the passes, we need to disable backend decomp here
        # again.
        from torch.export._trace import _ignore_backend_decomps

        with _ignore_backend_decomps():
            res = pm(self.graph_module)
        transformed_gm = res.graph_module if res is not None else self.graph_module
        assert transformed_gm is not None

        if transformed_gm is self.graph_module and not res.modified:
            return self

        # TODO(zhxchen17) Remove this.
        def _get_updated_graph_signature(
            old_signature: ExportGraphSignature,
            new_gm: torch.fx.GraphModule,
        ) -> ExportGraphSignature:
            """
            Update the graph signature's user_input/user_outputs.
            """
            new_input_specs = []
            for i, node in enumerate(new_gm.graph.nodes):
                if node.op != "placeholder":
                    break

                assert i < len(
                    old_signature.input_specs
                ), "Number of inputs changed after transformation"
                old_input_spec = old_signature.input_specs[i]
                arg = (
                    old_input_spec.arg
                    if isinstance(
                        old_input_spec.arg, (ConstantArgument, CustomObjArgument)
                    )
                    else type(old_input_spec.arg)(node.name)
                )
                new_input_specs.append(
                    InputSpec(
                        old_input_spec.kind,
                        arg,
                        old_input_spec.target,
                        old_input_spec.persistent,
                    )
                )

            output_node = list(new_gm.graph.nodes)[-1]
            assert output_node.op == "output"

            new_output_specs = []
            for i, node in enumerate(output_node.args[0]):
                assert i < len(
                    old_signature.output_specs
                ), "Number of outputs changed after transformation"
                old_output_spec = old_signature.output_specs[i]
                arg = (
                    old_output_spec.arg
                    if isinstance(
                        old_output_spec.arg, (ConstantArgument, CustomObjArgument)
                    )
                    else type(old_output_spec.arg)(node.name)
                )
                new_output_specs.append(
                    OutputSpec(old_output_spec.kind, arg, old_output_spec.target)
                )

            new_signature = ExportGraphSignature(
                input_specs=new_input_specs, output_specs=new_output_specs
            )
            return new_signature

        transformed_ep = ExportedProgram(
            root=transformed_gm,
            graph=transformed_gm.graph,
            graph_signature=_get_updated_graph_signature(
                self.graph_signature, transformed_gm
            ),
            state_dict=self.state_dict,
            range_constraints=_get_updated_range_constraints(
                transformed_gm,
                self.range_constraints,
            ),
            module_call_graph=copy.deepcopy(self._module_call_graph),
            example_inputs=self.example_inputs,
            constants=self.constants,
            verifiers=self.verifiers,
        )
        transformed_ep.graph_module.meta.update(self.graph_module.meta)
        transformed_ep.graph_module.meta.update(res.graph_module.meta)
        return transformed_ep

    def _check_input_constraints(self, flat_args_with_path):
        from torch._export.utils import _check_input_constraints_for_graph

        placeholders = [p for p in self.graph.nodes if p.op == "placeholder"]
        input_placeholders = [
            p
            for p, s in zip(placeholders, self.graph_signature.input_specs)
            if s.kind == InputKind.USER_INPUT
        ]
        _check_input_constraints_for_graph(
            input_placeholders, flat_args_with_path, self.range_constraints
        )

    @compatibility(is_backward_compatible=False)
    def validate(self):
        self._validate()

    # TODO: remove this
    @final
    def _validate(self):
        assert (
            len(self.verifiers) > 0
        ), "ExportedProgram must have at least one verifier."
        for v in self.verifiers:
            v().check(self)

    # TODO(zhxchen17) Formalize this.
    def _update(
        self, graph_module, graph_signature, *, state_dict=None, verifiers=None
    ) -> "ExportedProgram":
        return ExportedProgram(
            root=graph_module,
            graph=graph_module.graph,
            graph_signature=graph_signature,
            state_dict=state_dict if state_dict is not None else self.state_dict,
            range_constraints=copy.deepcopy(self.range_constraints),
            module_call_graph=copy.deepcopy(self._module_call_graph),
            example_inputs=self.example_inputs,
            constants=self.constants,
            verifiers=verifiers if verifiers is not None else self.verifiers,
        )


def _get_shape_env(gm):
    vals = [
        node.meta["val"]
        for node in gm.graph.nodes
        if node.meta.get("val", None) is not None
    ]
    from torch._guards import detect_fake_mode

    fake_mode = detect_fake_mode(vals)
    if fake_mode is not None:
        return fake_mode.shape_env
    for v in vals:
        if isinstance(v, torch.SymInt):
            return v.node.shape_env


def _get_updated_range_constraints(
    gm: torch.fx.GraphModule,
    old_range_constraints: "Optional[Dict[sympy.Symbol, Any]]" = None,
) -> "Dict[sympy.Symbol, Any]":
    assert old_range_constraints is not None

    shape_env = _get_shape_env(gm)
    if shape_env is None:
        return {}

    range_constraints = copy.copy(old_range_constraints)
    range_constraints = {
        k: v for k, v in range_constraints.items() if k not in shape_env.replacements
    }
    # Only when we have an unbacked symint, and it is used as a constructor input,
    # will runtime_var_to_range make a difference compared to var_to_range,
    # e.g. [2, oo) -> [0, oo)
    for k, v in shape_env.var_to_range.items():
        if k not in shape_env.replacements and k not in range_constraints:
            range_constraints[k] = v
    return range_constraints


def _create_graph_module_for_export(root, graph):
    try:
        gm = torch.fx.GraphModule(root, graph)
    except SyntaxError:
        # If custom objects stored in memory are being used in the graph,
        # the generated python code will result in a syntax error on the custom
        # object, since it is unable to parse the in-memory object. However
        # we can still run the graph eagerly through torch.fx.Interpreter,
        # so we will bypass this error.
        warnings.warn(
            "Unable to execute the generated python source code from "
            "the graph. The graph module will no longer be directly callable, "
            "but you can still run the ExportedProgram, and if needed, you can "
            "run the graph module eagerly using torch.fx.Interpreter."
        )
        gm = torch.fx.GraphModule(root, torch.fx.Graph())
        gm._graph = graph

    return gm