1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7from typing import Dict, List, Optional, Tuple 8 9import torch 10 11from torch.export.exported_program import ( 12 ExportedProgram, 13 ExportGraphSignature, 14 InputKind, 15 OutputKind, 16 OutputSpec, 17) 18from torch.export.graph_signature import TensorArgument 19from torch.utils import _pytree as pytree 20 21 22def _insert_copy( 23 gm: torch.fx.GraphModule, 24 mutated_outputs: List[Optional[str]], 25 input_name_to_node: Dict[str, torch.fx.Node], 26): 27 """ 28 Find the all the buffers and inputs that were mutated and insert copy_ 29 operators to reflect mutations. 30 """ 31 output_node = None 32 for node in gm.graph.nodes: 33 if node.op == "output": 34 output_node = node 35 break 36 assert output_node is not None 37 outputs = pytree.tree_flatten(output_node.args)[0] 38 assert len(outputs) == len(mutated_outputs) 39 40 user_output_nodes = [] 41 buffer_output_nodes = [] 42 for return_node, mutated_node_name in zip(outputs, mutated_outputs): 43 # User output, leave alone 44 if mutated_node_name is None: 45 user_output_nodes.append(return_node) 46 continue 47 48 # Mutable buffer grab the node 49 if mutated_node_name in input_name_to_node: 50 mutated_node = input_name_to_node[mutated_node_name] 51 else: 52 raise RuntimeError( 53 f"Could not find {mutated_node_name} in either buffer or input nodes" 54 ) 55 56 # insert copy 57 with gm.graph.inserting_before(output_node): 58 buffer_output = gm.graph.call_function( 59 torch.ops.aten.copy_.default, (mutated_node, return_node) 60 ) 61 # add output of copy to graph outputs 62 buffer_output_nodes.append(buffer_output) 63 64 with gm.graph.inserting_before(output_node): 65 buffer_output_nodes.extend(user_output_nodes) 66 # Remove old outputs 67 new_output = gm.graph.output(tuple(buffer_output_nodes)) 68 output_node.replace_all_uses_with(new_output) 69 gm.graph.erase_node(output_node) 70 return buffer_output_nodes 71 72 73def insert_write_back_for_buffers_pass( 74 ep: ExportedProgram, 75) -> Tuple[torch.fx.GraphModule, ExportGraphSignature]: 76 gm: torch.fx.GraphModule = ep.graph_module 77 lifted_inputs: List[Optional[str]] = [] 78 for in_spec in ep.graph_signature.input_specs: 79 if in_spec.kind in ( 80 InputKind.BUFFER, 81 InputKind.CONSTANT_TENSOR, 82 InputKind.PARAMETER, 83 InputKind.CUSTOM_OBJ, 84 ): 85 lifted_inputs.append(in_spec.target) 86 elif in_spec.kind is InputKind.USER_INPUT and isinstance( 87 in_spec.arg, TensorArgument 88 ): 89 lifted_inputs.append(in_spec.arg.name) 90 else: 91 lifted_inputs.append(None) 92 93 input_name_to_node: Dict[str, torch.fx.Node] = {} 94 95 placeholder_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"] 96 assert len(lifted_inputs) == len(placeholder_nodes) 97 # Grab the all the non user inputs 98 for input_node, lifted_node in zip(placeholder_nodes, lifted_inputs): 99 if lifted_node is not None: 100 input_name_to_node[lifted_node] = input_node 101 102 # Grab the mutable buffer nodes in the outputs, 103 mutated_outputs: List[Optional[str]] = [ 104 ( 105 out_spec.target 106 if out_spec.kind 107 in (OutputKind.BUFFER_MUTATION, OutputKind.USER_INPUT_MUTATION) 108 and out_spec.arg.name 109 not in { 110 val.name for val in input_name_to_node.values() 111 } # if the output arg is the input value then all operations on it are in-place so theres no need to add a copy_ node 112 else None 113 ) 114 for out_spec in ep.graph_signature.output_specs 115 ] 116 117 # insert the copy ops and update the outputs 118 buffer_output_nodes = _insert_copy(gm, mutated_outputs, input_name_to_node) 119 gm.graph.lint() 120 gm.graph.eliminate_dead_code() 121 gm.recompile() 122 123 # patch the output signature to point to the new updated outputs 124 new_output_specs: List[OutputSpec] = [] 125 i = 0 126 for output_spec in ep.graph_signature.output_specs: 127 if output_spec.kind in ( 128 OutputKind.BUFFER_MUTATION, 129 OutputKind.USER_INPUT_MUTATION, 130 ): 131 output_spec.arg.name = buffer_output_nodes[i].name 132 i += 1 133 new_output_specs.append(output_spec) 134 135 signature = ExportGraphSignature( 136 input_specs=ep.graph_signature.input_specs, 137 output_specs=new_output_specs, 138 ) 139 140 return gm, signature 141