# mypy: allow-untyped-defs
from __future__ import annotations

import contextlib
from typing import Callable

import torch
import torch._ops
import torch.func
import torch.fx
from torch._subclasses import fake_tensor
from torch.fx.experimental import proxy_tensor
from torch.onnx._internal.fx import _pass, diagnostics
from torch.onnx._internal.fx.passes import _utils
from torch.utils import _pytree as pytree


class Functionalize(_pass.Transform):
    """Functionalize a GraphModule.

    This pass uses the ``functionalization`` utility of ``torch._functorch`` to
    convert a GraphModule into a functional form. Its two main effects are (adapted
    from that utility's documentation):

    * ``functionalization`` removes (intermediate) mutations and aliasing from a
      function, while preserving the function's semantics.

    * ``functionalization`` also removes mutations (and views) that were performed
      on function inputs. However, to preserve semantics, it will "fix up" those
      mutations after the transform has finished running, by detecting whether any
      tensor inputs "should have" been mutated and copying the new data back to the
      inputs if necessary. For example, consider::

        def fn(a, b):
            a.add_(b)
            return a

    For a call like ``fn(x, y)``, the variable ``x`` outside the function is also
    mutated, so functionalizing alone is not enough to preserve the original
    semantics. A "special" input mutation step needs to be inserted at the end::

        # After functionalization, without input mutation "fix up".
        # This is not semantically the same. The variable outside the function call
        # that was passed in as `a` is not mutated.
        def fn(a, b):
            new_a = a + b
            return new_a

        # Functionalization with input mutation "fix up" that preserves semantics.
        def fn(a, b):
            new_a = a + b

            # Copy the new data back to the input.
            a.copy_(new_a)

            return new_a

    For ONNX inference, it is recommended to run ``RemoveInputMutation`` after this
    pass; it removes the "fix up" nodes added by ``Functionalize``, which are not
    needed for inference.
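
    Example (a rough sketch, not taken from the original sources: the
    ``DiagnosticContext`` arguments, the ``graph_module``/``example_inputs``
    placeholders, and the use of the inherited ``Transform.run`` entry point are
    assumptions)::

        context = diagnostics.DiagnosticContext("functionalize", torch.__version__)
        functionalized = Functionalize(
            context, graph_module, enable_dynamic_axes=False
        ).run(*example_inputs)
        # Strip the input mutation "fix up" nodes for inference-only export.
        inference_module = RemoveInputMutation(context, functionalized).run(
            *example_inputs
        )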
    """

    def __init__(
        self,
        diagnostic_context: diagnostics.DiagnosticContext,
        module: torch.fx.GraphModule,
        enable_dynamic_axes: bool,
        allow_fake_constant: bool | None = False,
    ):
        super().__init__(diagnostic_context, module)
        self.enable_dynamic_axes = enable_dynamic_axes
        self.allow_fake_constant = allow_fake_constant

    def _functionalize(self, function: Callable) -> Callable:
        # Working around a dispatcher issue with `torch.func.functionalize` when used
        # together with `make_fx`.
        # Ref: https://github.com/pytorch/pytorch/issues/99774#issuecomment-1527949391
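        # Note: conceptually this approximates `torch.func.functionalize(function,
        # remove="mutations")` (a rough equivalence, not an exact one); the manual
        # enable/disable wrapping below sidesteps the dispatcher issue above.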
        def wrapped(*inputs):
            # Wrap all tensor inputs as functional tensors before running `function`.
            inputs_functional = pytree.tree_map_only(
                torch.Tensor, torch._to_functional_tensor, inputs
            )
            torch._enable_functionalization(reapply_views=True)
            try:
                out = function(*inputs_functional)
            finally:
                torch._disable_functionalization()
            # Sync pending input mutations back to the original input tensors. Under
            # tracing, this step is what materializes the input mutation "fix up"
            # (`aten.copy_.default`) nodes described in the class docstring.
            for input_functional in pytree.tree_leaves(inputs_functional):
                if isinstance(input_functional, torch.Tensor):
                    torch._sync(input_functional)
                    torch._from_functional_tensor(input_functional)
            # Sync and unwrap the outputs back to plain tensors.
            pytree.tree_map(torch._sync, out)
            out_unwrapped = pytree.tree_map(torch._from_functional_tensor, out)
            return out_unwrapped

        return wrapped

    def _run(self, *args) -> torch.fx.GraphModule:
        # To preserve stack trace info after `make_fx`.
        module = _utils.wrap_graph_module_for_node_meta_preservation(self.module)

        functionalized_callable = self._functionalize(module)

        # Mimic `torch._dynamo.export(aten_graph=True)` behavior in invoking `make_fx`.
        # TODO: May need revisit for user fake mode export + dynamic shape scenario.
        fake_mode: fake_tensor.FakeTensorMode | None = self.fake_mode
        maybe_fake_args = self._maybe_fakefy_args(fake_mode, *args)
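        # `make_fx` tracing modes, as used below (a brief summary of its contract):
        #   "real"     - trace with the provided args as-is (here, already fake);
        #   "fake"     - create a fresh FakeTensorMode; shapes remain static;
        #   "symbolic" - fake tensors with symbolic shapes, enabling dynamic axes.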
        if fake_mode is not None:
            # Using the existing fake mode as context, signal `make_fx` that it does
            # not need to create a new fake mode by passing tracing_mode="real".
            tracing_mode = "real"
        else:
            # No existing fake mode was found; signal `make_fx` to create one.
            fake_mode = contextlib.nullcontext()  # type: ignore[assignment]
            tracing_mode = "symbolic" if self.enable_dynamic_axes else "fake"

        assert fake_mode is not None  # for mypy
        with fake_tensor.unset_fake_temporarily(), fake_mode:
            graph_module = proxy_tensor.make_fx(
                functionalized_callable,
                decomposition_table={},
                tracing_mode=tracing_mode,
                _allow_non_fake_inputs=True,
                _allow_fake_constant=bool(self.allow_fake_constant),
            )(*maybe_fake_args)

        # Rename placeholder targets to match the original module's signature, since
        # we don't want to map forward(x, y, z) to forward(arg0, arg1, arg2).
        _utils.replace_placeholder_name_and_target(graph_module, self.module)

        return graph_module


class RemoveInputMutation(_pass.Transform):
    """Remove `aten.copy_.default` nodes that mutate module inputs.

    This pass is recommended to be run after the ``Functionalize`` pass, which adds
    `aten.copy_.default` nodes to the graph when it detects mutations to inputs.
    These nodes are not needed for inference-oriented ONNX export, though they can
    be useful for training.
    """

    def _run(self, *args) -> torch.fx.GraphModule:
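        # A mutated input shows up in the functionalized graph as a trailing
        # `aten.copy_.default` call whose result has no users, roughly:
        #     %copy_ = call_function[target=aten.copy_.default](args=(%arg0, %new_a))
        # (illustrative node, not verbatim output). Iterating in reverse lets nodes
        # be erased safely while walking the graph.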
        for node in reversed(self.module.graph.nodes):
            if (
                node.op == "call_function"
                and node.target == torch.ops.aten.copy_.default
                and len(node.users) == 0
                and isinstance(node.args[0], torch.fx.Node)
                and node.args[0].op == "placeholder"
            ):
                self.module.graph.erase_node(node)
        return self.module