# mypy: allow-untyped-defs
from __future__ import annotations

import contextlib
from typing import Callable

import torch
import torch._ops
import torch.func
import torch.fx
from torch._subclasses import fake_tensor
from torch.fx.experimental import proxy_tensor
from torch.onnx._internal.fx import _pass, diagnostics
from torch.onnx._internal.fx.passes import _utils
from torch.utils import _pytree as pytree


class Functionalize(_pass.Transform):
    """Functionalize a GraphModule.

    This pass utilizes the ``functionalization`` utility of ``torch._functorch`` to
    convert a GraphModule into a functional form. The two main functionalities are
    (copied from its documentation):

    * ``functionalization`` removes (intermediate) mutations and aliasing from a
      function, while preserving the function's semantics.

    * ``functionalization`` also removes mutations (and views) that were performed
      on function inputs. However, to preserve semantics, functionalize will "fix up"
      the mutations after the transform has finished running, by detecting if any
      tensor inputs "should have" been mutated, and copying the new data back to the
      inputs if necessary. For example, consider::

        def fn(a, b):
            a.add_(b)
            return a

      For a call like `fn(x, y)`, the variable `x` outside is also mutated. Hence
      just functionalizing is not enough to preserve the original semantics. A
      "special" input mutation step needs to be inserted at the end::

        # After functionalization, without input mutation "fix up".
        # This is not semantically the same. The variable outside the function call
        # that was passed in as `a` is not mutated.
        def fn(a, b):
            new_a = a + b
            return new_a

        # Functionalization with input mutation "fix up" that preserves semantics.
        def fn(a, b):
            new_a = a + b

            # Copying the new data back to the inputs
            a.copy_(new_a)

            return new_a

    For ONNX inference, it is recommended to run ``RemoveInputMutation`` after this
    pass. ``RemoveInputMutation`` removes the "fix up" nodes added by
    ``Functionalize``, which are not needed for ONNX inference.
    """

    def __init__(
        self,
        diagnostic_context: diagnostics.DiagnosticContext,
        module: torch.fx.GraphModule,
        enable_dynamic_axes: bool,
        allow_fake_constant: bool | None = False,
    ):
        super().__init__(diagnostic_context, module)
        self.enable_dynamic_axes = enable_dynamic_axes
        self.allow_fake_constant = allow_fake_constant
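
    # For intuition: the wrapper built by `_functionalize` below emulates the public
    # `torch.func.functionalize` API, which this pass avoids due to the dispatcher
    # issue referenced in `_functionalize`. A minimal sketch of the public-API
    # equivalent (`remove="mutations"` corresponds to `reapply_views=True`; the
    # names here are illustrative):
    #
    #     def fn(a, b):
    #         a.add_(b)
    #         return a
    #
    #     fn_functional = torch.func.functionalize(fn, remove="mutations")
    #     out = fn_functional(torch.ones(2), torch.ones(2))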

    def _functionalize(self, function: Callable) -> Callable:
        # Working around a dispatcher issue with `torch.func.functionalize` when used
        # together with `make_fx`.
        # Ref: https://github.com/pytorch/pytorch/issues/99774#issuecomment-1527949391
        def wrapped(*inputs):
            inputs_functional = pytree.tree_map_only(
                torch.Tensor, torch._to_functional_tensor, inputs
            )
            torch._enable_functionalization(reapply_views=True)
            try:
                out = function(*inputs_functional)
            finally:
                torch._disable_functionalization()
            for input_functional in pytree.tree_leaves(inputs_functional):
                if isinstance(input_functional, torch.Tensor):
                    # Propagate pending input mutations back to the underlying
                    # tensors so that the "fix up" `copy_` nodes are traced.
                    torch._sync(input_functional)
            pytree.tree_map(torch._sync, out)
            out_unwrapped = pytree.tree_map(torch._from_functional_tensor, out)
            return out_unwrapped

        return wrapped

    def _run(self, *args) -> torch.fx.GraphModule:
        # To preserve stack trace info after `make_fx`.
        module = _utils.wrap_graph_module_for_node_meta_preservation(self.module)

        functionalized_callable = self._functionalize(module)

        # Mimic `torch._dynamo.export(aten_graph=True)` behavior in invoking `make_fx`.
        # TODO: May need revisit for user fake mode export + dynamic shape scenario.
        fake_mode: fake_tensor.FakeTensorMode | None = self.fake_mode
        maybe_fake_args = self._maybe_fakefy_args(fake_mode, *args)
        if fake_mode is not None:
            # Using existing fake mode as context, signal `make_fx` that it does not
            # need to create a new fake mode by passing tracing_mode as "real".
            tracing_mode = "real"
        else:
            # Existing fake mode not found, signal `make_fx` to create one.
            fake_mode = contextlib.nullcontext()  # type: ignore[assignment]
            tracing_mode = "symbolic" if self.enable_dynamic_axes else "fake"

        assert fake_mode is not None  # for mypy
        with fake_tensor.unset_fake_temporarily(), fake_mode:
            graph_module = proxy_tensor.make_fx(
                functionalized_callable,
                decomposition_table={},
                tracing_mode=tracing_mode,
                _allow_non_fake_inputs=True,
                _allow_fake_constant=bool(self.allow_fake_constant),
            )(*maybe_fake_args)

        # Rename placeholder targets to match the original module's signature, since
        # we don't want to map forward(x, y, z) to forward(arg0, arg1, arg2).
        _utils.replace_placeholder_name_and_target(graph_module, self.module)

        return graph_module


class RemoveInputMutation(_pass.Transform):
    """Remove `aten.copy_.default` nodes that mutate module inputs.

    This pass is recommended to be used after the ``Functionalize`` pass.
    ``Functionalize`` adds `aten.copy_.default` nodes to the graph when it detects
    mutations to inputs. These nodes are not needed for ONNX export for inference,
    although they could be useful for training.
    """

    def _run(self, *args) -> torch.fx.GraphModule:
        # Iterate in reverse so erasing a node never invalidates an upcoming one.
        for node in reversed(self.module.graph.nodes):
            if (
                node.op == "call_function"
                and node.target == torch.ops.aten.copy_.default
                and len(node.users) == 0
                and isinstance(node.args[0], torch.fx.Node)
                and node.args[0].op == "placeholder"
            ):
                self.module.graph.erase_node(node)
        return self.module
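

# Example of chaining these passes (a minimal sketch; the ``DiagnosticContext``
# arguments and the ``graph_module``/``args`` names below are illustrative
# assumptions, not a verbatim recipe from this module):
#
#     diagnostic_context = diagnostics.DiagnosticContext(
#         "torch.onnx.dynamo_export", torch.__version__
#     )
#     graph_module = Functionalize(
#         diagnostic_context, graph_module, enable_dynamic_axes=False
#     ).run(*args)
#     graph_module = RemoveInputMutation(diagnostic_context, graph_module).run(*args)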