# mypy: allow-untyped-defs
import operator
from typing import List

import torch
from torch._higher_order_ops.effects import _get_schema, with_effects

from .exported_program import ExportedProgram
from .graph_signature import (
    CustomObjArgument,
    InputKind,
    InputSpec,
    OutputKind,
    OutputSpec,
    TokenArgument,
)


def _remove_effect_tokens_from_graph_helper(
    ep, num_tokens, input_token_names, output_token_names
):
    """
    Strip effect tokens out of the graph(s) of ``ep`` in place.

    Args:
        ep: The ExportedProgram whose graph module is rewritten.
        num_tokens: Number of leading placeholders of the top-level graph that
            are input tokens, and number of leading top-level outputs that are
            output tokens.
        input_token_names: Names of the input-token placeholder nodes; each
            removed placeholder is checked against this list.
        output_token_names: Names of the nodes returned as output tokens; each
            removed output is checked against this list.

    Raises:
        RuntimeError: If a ``call_torchbind`` effect refers to a custom object
            that is found neither in its ``fake_val`` nor among the lifted
            custom-object constants of ``ep``.
    """
    inputs_to_lifted_custom_objs = ep.graph_signature.inputs_to_lifted_custom_objs

    # The output node's args are checked against output_token_names (collected
    # from the output specs), so only the top-level output node is rewritten.
    output_node = next(reversed(ep.graph_module.graph.find_nodes(op="output")))

    # with_effects calls may live in the top-level graph or in any nested
    # GraphModule, so collect them from every module.
    with_effect_nodes: List[torch.fx.Node] = []
    for module in ep.graph_module.modules():
        if not isinstance(module, torch.fx.GraphModule):
            continue
        for node in module.graph.nodes:
            if node.op == "call_function" and node.target is with_effects:
                with_effect_nodes.append(node)

    # Remove tokens from outputs
    output_args = output_node.args[0]
    assert len(output_args) >= num_tokens
    out_token_nodes = output_args[:num_tokens]
    output_node.args = (tuple(output_args[num_tokens:]),)
    for out_token in out_token_nodes:
        assert out_token.name in output_token_names
        out_token.users.clear()
        ep.graph.erase_node(out_token)

    # Replace with_effects(token, func, *args) with just func(*args).
    # Walking in reverse means a later effectful call, which consumes the token
    # getitem of an earlier one, is erased before that getitem is erased
    # (erase_node refuses to erase a node that still has users).
    for node in reversed(with_effect_nodes):
        # Operate on the graph that owns this node: it may belong to a nested
        # GraphModule, and fx does not allow inserting or erasing nodes
        # across graphs.
        graph = node.graph
        func = node.args[1]
        assert isinstance(func, (torch._ops.OpOverload, torch._ops.HigherOrderOperator))

        if func == torch.ops.higher_order.call_torchbind:
            # The schema of a torchbind call depends on the custom object it
            # is invoked on, so resolve that object first.
            custom_obj_meta = node.args[2].meta["val"]
            assert isinstance(custom_obj_meta, CustomObjArgument)
            if custom_obj_meta.fake_val is not None:
                custom_obj = custom_obj_meta.fake_val
            elif node.args[2].name in inputs_to_lifted_custom_objs:
                custom_obj = ep.constants[
                    inputs_to_lifted_custom_objs[node.args[2].name]
                ]
            else:
                raise RuntimeError(f"Unable to find custom obj for node {node}")
            schema = _get_schema(func, (custom_obj,) + node.args[3:])
        else:
            schema = _get_schema(func, node.args[2:])

        with graph.inserting_before(node):
            new_node = graph.call_function(func, node.args[2:], node.kwargs)
        new_node.meta.update(node.meta)

        node.replace_all_uses_with(new_node)

        # Update user getitem nodes
        for user in list(new_node.users.keys()):
            assert user.target == operator.getitem
            # getitem(with_effects, 0) == token
            if user.args[1] == 0:
                graph.erase_node(user)

        if len(schema.returns) == 1:
            # If the function has 1 return then it directly returns the
            # result -- no getitem is needed. Replace every
            # getitem(with_effects, 1) with the new node itself and erase the
            # getitem right away (DCE below only covers the top-level graph).
            for user in list(new_node.users.keys()):
                assert user.args[1] == 1
                user.replace_all_uses_with(new_node)
                graph.erase_node(user)

            new_node.meta["val"] = node.meta["val"][1]
        elif len(schema.returns) > 1:
            # The leading return value (the token) is gone, so every remaining
            # getitem index shifts down by 1.
            for user in list(new_node.users.keys()):
                assert user.args[1] >= 1
                user.args = (user.args[0], user.args[1] - 1)

            new_node.meta["val"] = node.meta["val"][1:]
        else:
            assert len(schema.returns) == 0
            assert len(new_node.users) == 0
            new_node.meta["val"] = None

        graph.erase_node(node)

    # Remove tokens from inputs
    placeholders = [node for node in ep.graph.nodes if node.op == "placeholder"]
    assert len(placeholders) >= num_tokens
    for inp_token in placeholders[:num_tokens]:
        assert inp_token.name in input_token_names
        ep.graph.erase_node(inp_token)

    ep.graph.eliminate_dead_code()


def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram:
    """
    Removes the existence of tokens from the exported program, including:
    - Removes the input and output tokens
    - Replaces with_effects(token, func, args) with just func(args)

    This function does an inplace modification on the given ExportedProgram.

    Args:
        ep: The program to rewrite in place.

    Returns:
        The same ``ep`` object, for convenience.
    """
    num_tokens: int = 0
    input_token_names: List[str] = []
    new_input_specs: List[InputSpec] = []
    for inp in ep.graph_signature.input_specs:
        if inp.kind == InputKind.TOKEN:
            num_tokens += 1
            assert isinstance(inp.arg, TokenArgument)
            input_token_names.append(inp.arg.name)
        else:
            new_input_specs.append(inp)

    num_out_tokens: int = 0
    new_output_specs: List[OutputSpec] = []
    output_token_names: List[str] = []
    for out in ep.graph_signature.output_specs:
        if out.kind == OutputKind.TOKEN:
            num_out_tokens += 1
            output_token_names.append(out.arg.name)
        else:
            new_output_specs.append(out)

    # Validate before mutating the signature so that a mismatch does not leave
    # the program half-modified.
    assert num_tokens == num_out_tokens

    # Update graph signature
    ep.graph_signature.input_specs = new_input_specs
    ep.graph_signature.output_specs = new_output_specs

    # NOTE(review): the replace hook presumably propagates node replacements
    # made by the helper into the graph signature's names -- confirm.
    with ep.graph_module._set_replace_hook(ep.graph_signature.get_replace_hook()):
        _remove_effect_tokens_from_graph_helper(
            ep, num_tokens, input_token_names, output_token_names
        )

    return ep