xref: /aosp_15_r20/external/executorch/exir/passes/insert_write_back_for_buffers_pass.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

7from typing import Dict, List, Optional, Tuple
8
9import torch
10
11from torch.export.exported_program import (
12    ExportedProgram,
13    ExportGraphSignature,
14    InputKind,
15    OutputKind,
16    OutputSpec,
17)
18from torch.export.graph_signature import TensorArgument
19from torch.utils import _pytree as pytree
20
21
22def _insert_copy(
23    gm: torch.fx.GraphModule,
24    mutated_outputs: List[Optional[str]],
25    input_name_to_node: Dict[str, torch.fx.Node],
26):
27    """
28    Find the all the buffers and inputs that were mutated and insert copy_
29    operators to reflect mutations.
30    """
31    output_node = None
32    for node in gm.graph.nodes:
33        if node.op == "output":
34            output_node = node
35            break
36    assert output_node is not None
37    outputs = pytree.tree_flatten(output_node.args)[0]
38    assert len(outputs) == len(mutated_outputs)
39
40    user_output_nodes = []
41    buffer_output_nodes = []
42    for return_node, mutated_node_name in zip(outputs, mutated_outputs):
43        # User output, leave alone
44        if mutated_node_name is None:
45            user_output_nodes.append(return_node)
46            continue
47
48        # Mutable buffer grab the node
49        if mutated_node_name in input_name_to_node:
50            mutated_node = input_name_to_node[mutated_node_name]
51        else:
52            raise RuntimeError(
53                f"Could not find {mutated_node_name} in either buffer or input nodes"
54            )
55
56        # insert copy
57        with gm.graph.inserting_before(output_node):
58            buffer_output = gm.graph.call_function(
59                torch.ops.aten.copy_.default, (mutated_node, return_node)
60            )
61            # add output of copy to graph outputs
62            buffer_output_nodes.append(buffer_output)
63
64    with gm.graph.inserting_before(output_node):
65        buffer_output_nodes.extend(user_output_nodes)
66        # Remove old outputs
67        new_output = gm.graph.output(tuple(buffer_output_nodes))
68        output_node.replace_all_uses_with(new_output)
69        gm.graph.erase_node(output_node)
70    return buffer_output_nodes
71
72
def insert_write_back_for_buffers_pass(
    ep: ExportedProgram,
) -> Tuple[torch.fx.GraphModule, ExportGraphSignature]:
    """
    Insert ``aten.copy_`` write-back nodes for every mutated buffer or user
    input of ``ep`` and patch the graph signature accordingly.

    Outputs of kind BUFFER_MUTATION / USER_INPUT_MUTATION are copied back
    into their source placeholders via ``_insert_copy``; the corresponding
    output specs are updated to point at the new copy nodes.

    Returns the (mutated in place) graph module and a new
    ExportGraphSignature.  Raises RuntimeError (from ``_insert_copy``) if a
    mutation target has no matching placeholder.
    """
    gm: torch.fx.GraphModule = ep.graph_module

    # For each placeholder, the name the signature knows it by: lifted
    # inputs (buffers/params/constants/custom objs) by target, tensor user
    # inputs by arg name, everything else None.
    lifted_inputs: List[Optional[str]] = []
    for in_spec in ep.graph_signature.input_specs:
        if in_spec.kind in (
            InputKind.BUFFER,
            InputKind.CONSTANT_TENSOR,
            InputKind.PARAMETER,
            InputKind.CUSTOM_OBJ,
        ):
            lifted_inputs.append(in_spec.target)
        elif in_spec.kind is InputKind.USER_INPUT and isinstance(
            in_spec.arg, TensorArgument
        ):
            lifted_inputs.append(in_spec.arg.name)
        else:
            lifted_inputs.append(None)

    input_name_to_node: Dict[str, torch.fx.Node] = {}

    placeholder_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]
    assert len(lifted_inputs) == len(placeholder_nodes)
    # Map each named (non-None) signature input to its placeholder node.
    for input_node, lifted_name in zip(placeholder_nodes, lifted_inputs):
        if lifted_name is not None:
            input_name_to_node[lifted_name] = input_node

    # Hoisted out of the comprehension below: build the name set once
    # instead of once per output spec.
    input_node_names = {node.name for node in input_name_to_node.values()}

    # An output needs a copy_ only when it is a mutation AND its value is
    # not already the placeholder itself — if the output arg *is* the input
    # node, all operations on it were in-place and no write-back is needed.
    mutated_outputs: List[Optional[str]] = [
        (
            out_spec.target
            if out_spec.kind
            in (OutputKind.BUFFER_MUTATION, OutputKind.USER_INPUT_MUTATION)
            and out_spec.arg.name not in input_node_names
            else None
        )
        for out_spec in ep.graph_signature.output_specs
    ]

    # Insert the copy ops and rebuild the graph outputs.
    buffer_output_nodes = _insert_copy(gm, mutated_outputs, input_name_to_node)
    gm.graph.lint()
    gm.graph.eliminate_dead_code()
    gm.recompile()

    # Point each mutation spec at its copy_ node; _insert_copy returns the
    # copy nodes in the same order the mutation specs appear.
    new_output_specs: List[OutputSpec] = []
    i = 0
    for output_spec in ep.graph_signature.output_specs:
        if output_spec.kind in (
            OutputKind.BUFFER_MUTATION,
            OutputKind.USER_INPUT_MUTATION,
        ):
            output_spec.arg.name = buffer_output_nodes[i].name
            i += 1
        new_output_specs.append(output_spec)

    signature = ExportGraphSignature(
        input_specs=ep.graph_signature.input_specs,
        output_specs=new_output_specs,
    )

    return gm, signature
141