1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7import copy 8from typing import Any 9 10import torch 11import torch._export 12from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo 13from torch.utils import _pytree as pytree 14 15 16Val = Any 17 18 19def _unlift(gm, inp_pos_to_param_buffer_name, in_spec, out_spec, state_dict): 20 count = 0 21 # Step 1: make lifted params as get_attr 22 for node in gm.graph.nodes: 23 if node.op == "placeholder": 24 if count in inp_pos_to_param_buffer_name: 25 with gm.graph.inserting_after(node): 26 getattr_node = gm.graph.get_attr( 27 inp_pos_to_param_buffer_name[count] 28 ) 29 node.replace_all_uses_with(getattr_node) 30 metadata = node.meta 31 gm.graph.erase_node(node) 32 getattr_node.meta = metadata 33 count += 1 34 35 # Step 2: Fix the input/output of the graph now that we deleted 36 # some args. 37 gm.graph.lint() 38 names = [f"arg_{i}" for i in range(len(in_spec.children_specs))] 39 gm.graph._codegen = _PyTreeCodeGen( 40 _PyTreeInfo( 41 names, 42 in_spec, 43 out_spec, 44 ) 45 ) 46 gm.recompile() 47 48 # Step 3: Find state references in HigherOrderOps and recursively 49 # fix them. 50 for node in gm.graph.nodes: 51 if node.op == "call_function" and node.target == torch.ops.cond: 52 pred, true_graph, false_graph, operands = node.args 53 true_gm = getattr(gm, true_graph.name) 54 false_gm = getattr(gm, false_graph.name) 55 inp_pos_to_param_buffer_name_for_submod = {} 56 real_operands = [] 57 for ix, operand in enumerate(operands): 58 if operand.target in inp_pos_to_param_buffer_name.values(): 59 inp_pos_to_param_buffer_name_for_submod[ix] = operand.target 60 true_gm.register_buffer(operand.target, state_dict[operand.target]) 61 false_gm.register_buffer(operand.target, state_dict[operand.target]) 62 else: 63 real_operands.append(operand) 64 node.args = (pred, true_graph, false_graph, real_operands) 65 66 _, in_spec = pytree.tree_flatten(real_operands) 67 68 _unlift( 69 true_gm, 70 inp_pos_to_param_buffer_name_for_submod, 71 in_spec, 72 None, 73 state_dict, 74 ) 75 _unlift( 76 false_gm, 77 inp_pos_to_param_buffer_name_for_submod, 78 in_spec, 79 None, 80 state_dict, 81 ) 82 if node.op == "call_function" and node.target.__name__ == "map_impl": 83 body_graph, num_mapped, *operands = node.args 84 body_gm = getattr(gm, body_graph.name) 85 inp_pos_to_buffer_name_for_submod = {} 86 real_operands = [] 87 for ix, operand in enumerate(operands): 88 if operand.target in inp_pos_to_param_buffer_name.values(): 89 inp_pos_to_buffer_name_for_submod[ix] = operand.target 90 body_gm.register_buffer(operand.target, state_dict[operand.target]) 91 else: 92 real_operands.append(operand) 93 node.args = (body_graph, num_mapped, *real_operands) 94 95 _, in_spec = pytree.tree_flatten(real_operands) 96 97 _unlift( 98 body_gm, inp_pos_to_buffer_name_for_submod, in_spec, None, state_dict 99 ) 100 gm.graph.lint() 101 gm.graph.eliminate_dead_code() 102 gm.recompile() 103 return gm 104 105 106def unlift_exported_program_lifted_states( 107 ep: torch.export.exported_program.ExportedProgram, 108): 109 new_gm = copy.deepcopy(ep.graph_module) 110 111 # TODO Fix the period in params/buffers names later 112 # maybe a pass to replace graph signature with fixed names 113 param_buffer_name_to_corrected_name = {} 114 115 for name, stuff in ep.state_dict.items(): 116 if name in ep.graph_signature.buffers: 117 if "." in name: 118 new_gm.register_buffer(name.replace(".", "_"), stuff) 119 param_buffer_name_to_corrected_name[name] = name.replace(".", "_") 120 else: 121 new_gm.register_buffer(name, stuff) 122 elif name in ep.graph_signature.parameters: 123 if "." in name: 124 new_gm.register_parameter(name.replace(".", "_"), stuff) 125 param_buffer_name_to_corrected_name[name] = name.replace(".", "_") 126 else: 127 new_gm.register_parameter(name, stuff) 128 else: 129 raise AssertionError("encountered not registered param/buffer") 130 131 count = 0 132 inp_pos_to_param_buffer_name = {} 133 for node in new_gm.graph.nodes: 134 if node.op == "placeholder": 135 if node.name in ep.graph_signature.inputs_to_buffers: 136 buffer_name = ep.graph_signature.inputs_to_buffers[node.name] 137 if buffer_name in param_buffer_name_to_corrected_name: 138 inp_pos_to_param_buffer_name[count] = ( 139 param_buffer_name_to_corrected_name[buffer_name] 140 ) 141 else: 142 inp_pos_to_param_buffer_name[count] = buffer_name 143 if node.name in ep.graph_signature.inputs_to_parameters: 144 param_name = ep.graph_signature.inputs_to_parameters[node.name] 145 if param_name in param_buffer_name_to_corrected_name: 146 inp_pos_to_param_buffer_name[count] = ( 147 param_buffer_name_to_corrected_name[param_name] 148 ) 149 else: 150 inp_pos_to_param_buffer_name[count] = param_name 151 count += 1 152 new_gm = _unlift( 153 new_gm, 154 inp_pos_to_param_buffer_name, 155 ep.call_spec.in_spec, 156 ep.call_spec.out_spec, 157 ep.state_dict, 158 ) 159 return new_gm 160