xref: /aosp_15_r20/external/executorch/exir/capture/_unlift.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import copy
8from typing import Any
9
10import torch
11import torch._export
12from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
13from torch.utils import _pytree as pytree
14
15
# Generic alias for "any value". NOTE(review): not referenced elsewhere in
# this file — presumably kept for external importers; confirm before removing.
Val = Any
17
18
19def _unlift(gm, inp_pos_to_param_buffer_name, in_spec, out_spec, state_dict):
20    count = 0
21    # Step 1: make lifted params as get_attr
22    for node in gm.graph.nodes:
23        if node.op == "placeholder":
24            if count in inp_pos_to_param_buffer_name:
25                with gm.graph.inserting_after(node):
26                    getattr_node = gm.graph.get_attr(
27                        inp_pos_to_param_buffer_name[count]
28                    )
29                    node.replace_all_uses_with(getattr_node)
30                    metadata = node.meta
31                    gm.graph.erase_node(node)
32                    getattr_node.meta = metadata
33            count += 1
34
35    # Step 2: Fix the input/output of the graph now that we deleted
36    # some args.
37    gm.graph.lint()
38    names = [f"arg_{i}" for i in range(len(in_spec.children_specs))]
39    gm.graph._codegen = _PyTreeCodeGen(
40        _PyTreeInfo(
41            names,
42            in_spec,
43            out_spec,
44        )
45    )
46    gm.recompile()
47
48    # Step 3: Find state references in HigherOrderOps and recursively
49    # fix them.
50    for node in gm.graph.nodes:
51        if node.op == "call_function" and node.target == torch.ops.cond:
52            pred, true_graph, false_graph, operands = node.args
53            true_gm = getattr(gm, true_graph.name)
54            false_gm = getattr(gm, false_graph.name)
55            inp_pos_to_param_buffer_name_for_submod = {}
56            real_operands = []
57            for ix, operand in enumerate(operands):
58                if operand.target in inp_pos_to_param_buffer_name.values():
59                    inp_pos_to_param_buffer_name_for_submod[ix] = operand.target
60                    true_gm.register_buffer(operand.target, state_dict[operand.target])
61                    false_gm.register_buffer(operand.target, state_dict[operand.target])
62                else:
63                    real_operands.append(operand)
64            node.args = (pred, true_graph, false_graph, real_operands)
65
66            _, in_spec = pytree.tree_flatten(real_operands)
67
68            _unlift(
69                true_gm,
70                inp_pos_to_param_buffer_name_for_submod,
71                in_spec,
72                None,
73                state_dict,
74            )
75            _unlift(
76                false_gm,
77                inp_pos_to_param_buffer_name_for_submod,
78                in_spec,
79                None,
80                state_dict,
81            )
82        if node.op == "call_function" and node.target.__name__ == "map_impl":
83            body_graph, num_mapped, *operands = node.args
84            body_gm = getattr(gm, body_graph.name)
85            inp_pos_to_buffer_name_for_submod = {}
86            real_operands = []
87            for ix, operand in enumerate(operands):
88                if operand.target in inp_pos_to_param_buffer_name.values():
89                    inp_pos_to_buffer_name_for_submod[ix] = operand.target
90                    body_gm.register_buffer(operand.target, state_dict[operand.target])
91                else:
92                    real_operands.append(operand)
93            node.args = (body_graph, num_mapped, *real_operands)
94
95            _, in_spec = pytree.tree_flatten(real_operands)
96
97            _unlift(
98                body_gm, inp_pos_to_buffer_name_for_submod, in_spec, None, state_dict
99            )
100    gm.graph.lint()
101    gm.graph.eliminate_dead_code()
102    gm.recompile()
103    return gm
104
105
def unlift_exported_program_lifted_states(
    ep: torch.export.exported_program.ExportedProgram,
):
    """Return a copy of ``ep.graph_module`` with lifted state restored as attributes.

    Every ``ep.state_dict`` entry is registered on a deep copy of the graph
    module, as a buffer or a parameter according to ``ep.graph_signature``.
    Names containing ``"."`` are registered with each dot replaced by ``"_"``.
    The placeholders that carried that state are then rewritten into
    ``get_attr`` nodes by ``_unlift``.

    Raises:
        AssertionError: if a ``state_dict`` entry is neither a buffer nor a
            parameter of the graph signature.
    """
    new_gm = copy.deepcopy(ep.graph_module)
    signature = ep.graph_signature

    # TODO Fix the period in params/buffers names later
    # maybe a pass to replace graph signature with fixed names
    param_buffer_name_to_corrected_name = {}

    for state_name, tensor in ep.state_dict.items():
        # Buffers are checked before parameters.
        if state_name in signature.buffers:
            register = new_gm.register_buffer
        elif state_name in signature.parameters:
            register = new_gm.register_parameter
        else:
            raise AssertionError("encountered not registered param/buffer")
        attr_name = state_name.replace(".", "_")
        register(attr_name, tensor)
        # Only names that actually changed need a correction entry.
        if attr_name != state_name:
            param_buffer_name_to_corrected_name[state_name] = attr_name

    placeholders = [n for n in new_gm.graph.nodes if n.op == "placeholder"]
    inp_pos_to_param_buffer_name = {}
    for position, node in enumerate(placeholders):
        # Buffer mapping first, parameter mapping second: a parameter entry
        # overwrites a buffer entry for the same placeholder.
        for input_mapping in (
            signature.inputs_to_buffers,
            signature.inputs_to_parameters,
        ):
            if node.name in input_mapping:
                state_name = input_mapping[node.name]
                inp_pos_to_param_buffer_name[position] = (
                    param_buffer_name_to_corrected_name.get(state_name, state_name)
                )

    return _unlift(
        new_gm,
        inp_pos_to_param_buffer_name,
        ep.call_spec.in_spec,
        ep.call_spec.out_spec,
        ep.state_dict,
    )
160