xref: /aosp_15_r20/external/pytorch/torch/export/_remove_effect_tokens_pass.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import operator
3from typing import List
4
5import torch
6from torch._higher_order_ops.effects import _get_schema, with_effects
7
8from .exported_program import ExportedProgram
9from .graph_signature import (
10    CustomObjArgument,
11    InputKind,
12    InputSpec,
13    OutputKind,
14    OutputSpec,
15    TokenArgument,
16)
17
18
19def _remove_effect_tokens_from_graph_helper(
20    ep, num_tokens, input_token_names, output_token_names
21):
22    inputs_to_lifted_custom_objs = ep.graph_signature.inputs_to_lifted_custom_objs
23
24    output_node = None
25    with_effect_nodes: List[torch.fx.Node] = []
26
27    # Output node need to check its args agianst output_token_names (collected from output_spec)
28    # Therefore, we only need to find the top-levele output node
29    output_node = next(reversed(ep.graph_module.graph.find_nodes(op="output")))
30    for module in ep.graph_module.modules():
31        if not isinstance(module, torch.fx.GraphModule):
32            continue
33
34        for node in module.graph.nodes:
35            if not (node.op == "call_function" and node.target is with_effects):
36                continue
37
38            with_effect_nodes.append(node)
39
40    # Remove tokens from outputs
41    assert output_node is not None
42    output_args = output_node.args[0]
43    assert len(output_args) >= num_tokens
44    out_token_nodes = output_args[:num_tokens]
45    output_node.args = (tuple(output_args[num_tokens:]),)
46    for out_token in out_token_nodes:
47        assert out_token.name in output_token_names
48        out_token.users.clear()
49        ep.graph.erase_node(out_token)
50
51    # Replace with_effects(token, func, args) with just func(args)
52    for node in reversed(with_effect_nodes):
53        func = node.args[1]
54        assert isinstance(func, (torch._ops.OpOverload, torch._ops.HigherOrderOperator))
55
56        if func == torch.ops.higher_order.call_torchbind:
57            custom_obj_meta = node.args[2].meta["val"]
58            assert isinstance(custom_obj_meta, CustomObjArgument)
59            if custom_obj_meta.fake_val:
60                custom_obj = custom_obj_meta.fake_val
61            elif node.args[2].name in inputs_to_lifted_custom_objs:
62                custom_obj = ep.constants[
63                    inputs_to_lifted_custom_objs[node.args[2].name]
64                ]
65            else:
66                raise RuntimeError(f"Unable to find custom obj for node {node}")
67            schema = _get_schema(func, (custom_obj,) + node.args[3:])
68        else:
69            schema = _get_schema(func, node.args[2:])
70
71        with ep.graph.inserting_before(node):
72            new_node = ep.graph.call_function(func, node.args[2:], node.kwargs)
73        for k, v in node.meta.items():
74            new_node.meta[k] = v
75
76        node.replace_all_uses_with(new_node)
77
78        # Update user getitem nodes
79        for user in list(new_node.users.keys()):
80            assert user.target == operator.getitem
81            # getitem(with_effects, 0) == token
82            if user.args[1] == 0:
83                ep.graph.erase_node(user)
84
85        if len(schema.returns) == 1:
86            # If the function has 1 return then it will just directly return the
87            # result -- we don't need a getitem. So we can replace all the
88            # getitem(with_effects, 1) with just the note itself.
89            for user in list(new_node.users.keys()):
90                assert user.args[1] == 1
91                user.replace_all_uses_with(new_node)
92
93            new_node.meta["val"] = node.meta["val"][1]
94        elif len(schema.returns) > 1:
95            # If the function has more than 1 return then since we got rid of
96            # the 1st return value (the token), we need to bump all the other
97            # getitem calls by 1 down
98            for user in list(new_node.users.keys()):
99                assert user.args[1] >= 1
100                user.args = (user.args[0], user.args[1] - 1)
101
102            new_node.meta["val"] = node.meta["val"][1:]
103        else:
104            assert len(schema.returns) == 0
105            assert len(new_node.users) == 0
106            new_node.meta["val"] = None
107
108        ep.graph.erase_node(node)
109
110    # Remove tokens from inputs
111    placeholders = [node for node in ep.graph.nodes if node.op == "placeholder"]
112    assert len(placeholders) >= num_tokens
113    inp_token_nodes = placeholders[:num_tokens]
114    for inp_token in inp_token_nodes:
115        assert inp_token.name in input_token_names
116        ep.graph.erase_node(inp_token)
117
118    ep.graph.eliminate_dead_code()
119
120
def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram:
    """
    Removes the existence of tokens from the exported program, including:
    - Removes the input and output tokens
    - Replaces with_effects(token, func, args) with just func(args)

    This function does an inplace modification on the given ExportedProgram,
    and also returns it for convenience.
    """
    # Collect the token inputs and drop them from the input signature.
    num_tokens: int = 0
    input_token_names: List[str] = []
    new_input_specs: List[InputSpec] = []
    for inp in ep.graph_signature.input_specs:
        if inp.kind == InputKind.TOKEN:
            num_tokens += 1
            assert isinstance(inp.arg, TokenArgument)
            input_token_names.append(inp.arg.name)
        else:
            new_input_specs.append(inp)

    # Collect the token outputs and drop them from the output signature.
    num_out_tokens: int = 0
    new_output_specs: List[OutputSpec] = []
    output_token_names: List[str] = []
    for out in ep.graph_signature.output_specs:
        if out.kind == OutputKind.TOKEN:
            num_out_tokens += 1
            output_token_names.append(out.arg.name)
        else:
            new_output_specs.append(out)

    # Update graph signature
    ep.graph_signature.input_specs = new_input_specs
    ep.graph_signature.output_specs = new_output_specs

    # Every token that flows in must also flow back out.
    assert num_tokens == num_out_tokens

    # The replace hook keeps the graph signature's node names in sync while
    # the graph is rewritten underneath it.
    with ep.graph_module._set_replace_hook(ep.graph_signature.get_replace_hook()):
        _remove_effect_tokens_from_graph_helper(
            ep, num_tokens, input_token_names, output_token_names
        )

    return ep
162