# mypy: allow-untyped-defs
import copy
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Type, Union

import torch.fx
from torch.fx._compatibility import compatibility
from torch.fx.graph import map_arg
from torch.fx.passes.utils import HolderModule, lift_subgraph_as_module

from .tools_common import NodeList

__all__ = ["getattr_recursive", "setattr_recursive", "Component", "split_by_tags"]


@compatibility(is_backward_compatible=False)
def getattr_recursive(obj, name):
    """
    Resolves a dotted attribute path like "sub.weight" on `obj`. Returns None
    if any attribute along the path is missing.
    """
    for layer in name.split("."):
        if hasattr(obj, layer):
            obj = getattr(obj, layer)
        else:
            return None
    return obj


@compatibility(is_backward_compatible=False)
def setattr_recursive(obj, attr, value):
    """
    Sets `value` at the dotted attribute path `attr` on `obj`, descending
    through the intermediate attributes recursively.
    """
    if "." not in attr:
        setattr(obj, attr, value)
    else:
        layer = attr.split(".")
        setattr_recursive(getattr(obj, layer[0]), ".".join(layer[1:]), value)
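
# A minimal usage sketch for the two helpers above (the module `M` and its
# attribute names are hypothetical, for illustration only):
#
#     class M(torch.nn.Module):
#         def __init__(self) -> None:
#             super().__init__()
#             self.sub = torch.nn.Linear(2, 2)
#
#     m = M()
#     getattr_recursive(m, "sub.weight")   # returns m.sub.weight
#     getattr_recursive(m, "sub.missing")  # returns None instead of raising
#     setattr_recursive(m, "sub.bias", torch.nn.Parameter(torch.zeros(2)))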


@compatibility(is_backward_compatible=False)
@dataclass
class Component:
    """
    A component serves as a container for a subgraph we want to create afterwards.
    """

    graph: torch.fx.Graph
    order: int
    name: str

    # Stores the placeholder nodes in `graph`.
    input_placeholders: List = field(default_factory=list)

    # Stores the nodes in the original graph that are placeholders in `graph`.
    orig_inputs: List = field(default_factory=list)

    # Stores the nodes in the original graph that are outputs in `graph`.
    orig_outputs: List = field(default_factory=list)

    # Mapping from get_attr node in the original graph to get_attr node in `graph`.
    getattr_maps: Dict[torch.fx.Node, torch.fx.Node] = field(default_factory=dict)
    constructor_args: List[str] = field(default_factory=list)
    gm: Optional[torch.fx.GraphModule] = None


@compatibility(is_backward_compatible=False)
def split_by_tags(
    gm: torch.fx.GraphModule,
    tags: List[str],
    return_fqn_mapping: bool = False,
    return_tuple: bool = False,
    GraphModuleCls: Type[torch.fx.GraphModule] = torch.fx.GraphModule,
) -> Union[torch.fx.GraphModule, Tuple[torch.fx.GraphModule, Dict[str, str]]]:
    """
    Splits a GraphModule using tags on its graph nodes. The order of tags is
    honored: given tags = ["a", "b", "c"], the function creates the
    submodules in the order "a", "b", "c".

    To set a tag:
    gm.graph.nodes[idx].tag = "mytag"

    This results in all nodes with the same tag being extracted and placed in
    their own submodule. Tags on placeholder, output, and get_attr nodes are
    ignored: placeholder and output nodes are created as needed, while
    get_attr nodes are copied into the submodules that use them.

    Given the following module definition:

    class SimpleModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear1 = torch.nn.Linear(...)
            self.linear2 = torch.nn.Linear(...)
            self.linear3 = torch.nn.Linear(...)

        def forward(self, in1, in2):
            r1 = self.linear1(in1)
            r2 = self.linear2(in2)
            r3 = torch.cat([r1, r2])
            return self.linear3(r3)

    Marking the node corresponding to in1 with the tag sc.REQUEST_ONLY.lower() results in the following split:

    ro:
    def forward(self, in1):
        self = self.root
        linear1 = self.linear1(in1)
        return linear1

    main:
    def forward(self, in2, linear1):
        self = self.root
        linear2 = self.linear2(in2)
        cat_1 = torch.cat([linear1, linear2])
        linear3 = self.linear3(cat_1)
        return linear3

    main (the returned top-level module, which stitches the submodules together):
    def forward(self, in1, in2):
        self = self.root
        ro_0 = self.ro_0(in1)
        main_1 = self.main_1(in2, ro_0)
        return main_1

    Returns:
        split_gm: the torch.fx GraphModule after the split.
        orig_to_split_fqn_mapping: a map between the original fqn and the fqn
            after split for call_module and get_attr.
    """
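
    # Usage sketch (hypothetical, not executed here), tagging the SimpleModule
    # example from the docstring above so that linear1 lands in "ro" and
    # everything else in "main":
    #
    #     gm = torch.fx.symbolic_trace(SimpleModule())
    #     for node in gm.graph.nodes:
    #         if node.op in ("call_module", "call_function", "call_method"):
    #             node.tag = "ro" if node.target == "linear1" else "main"
    #     split_gm, fqn_map = split_by_tags(gm, ["ro", "main"], return_fqn_mapping=True)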

    def flatten(x: torch.fx.node.Argument) -> NodeList:
        """
        Collects the nodes in x into a list and returns the list.
        """
        r: NodeList = []
        map_arg(x, r.append)
        return r

    # Mapping from node in the original module to node in the created submodule.
    node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}

    # Mapping from node in the original module or created submodules to the
    # corresponding component.
    node_to_component: Dict[torch.fx.Node, Component] = {}

    # Mapping from tag to the corresponding component.
    tag_to_component: Dict[str, Component] = {}

    # Stores all components.
    all_components: List[Component] = []

    # Stores nodes that will be used in the main graph.
    used_in_main: Dict[torch.fx.Node, None] = {}

    # Main graph after split.
    main_g = torch.fx.Graph()

    # Mapping from node in the original module to node in the main graph after split.
    main_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}

    # Output node of the original module.
    output_node: Optional[torch.fx.Node] = None

    # Create a component for each tag; we don't expect to create other components afterwards.
    for tag in tags:
        comp = Component(torch.fx.Graph(), len(all_components), f"{tag}")
        all_components.append(comp)
        tag_to_component[tag] = comp

    # Traverse the nodes in the original graph and handle them by op type.
    for node in gm.graph.nodes:
        if node.op == "output":
            if output_node is not None:
                raise RuntimeError("Multiple output nodes in graph!")
            output_node = node
            continue

        # Placeholders in the original graph get copied to the main graph.
        if node.op == "placeholder":
            main_remapping[node] = main_g.placeholder(node.name, type_expr=node.type)
            main_remapping[node].meta = copy.copy(node.meta)
            continue

        # get_attr nodes are skipped here because we don't tag them. Instead,
        # they are copied directly into the submodules that use them.
        if node.op == "get_attr":
            continue

        # Now we process callable nodes, i.e. nodes whose op is call_module,
        # call_function, or call_method. Every callable node should be tagged.
        assert hasattr(node, "tag")

        upstream_components = [
            node_to_component[x]
            for x in flatten(node.args) + flatten(node.kwargs)
            if x.op not in {"placeholder", "get_attr"}
        ]

        comp = tag_to_component[node.tag]
        node_to_component[node] = comp

        # Max order of upstream components.
        mx = max((c.order for c in upstream_components), default=0)

        # Expect the component for `node` to have an order at least as high as
        # that of its upstream components.
        assert comp.order >= mx

        # Maps an input of `node` to the corresponding node in the component's graph.
        def remap_func(x):
            # If the input is a get_attr node, copy it into the current
            # component's graph and return that copy.
            if x.op == "get_attr":
                if x not in comp.getattr_maps:
                    comp.getattr_maps[x] = comp.graph.get_attr(
                        x.target, type_expr=x.type
                    )
                return comp.getattr_maps[x]

            # If the input is not a placeholder, it should already have been
            # put into a component. If that is the current component, return
            # the corresponding node in the component.
            if x.op != "placeholder" and node_to_component[x] == comp:
                return node_remapping[x]

            # If the input is a placeholder, or it lives in another component,
            # make it a placeholder in the current component's graph.
            if x not in comp.orig_inputs:
                comp.orig_inputs.append(x)
                placeholder = comp.graph.placeholder(x.name, type_expr=x.type)
                placeholder.meta = copy.copy(x.meta)
                comp.input_placeholders.append(placeholder)
                used_in_main[x] = None

            return comp.input_placeholders[comp.orig_inputs.index(x)]

        n = comp.graph.node_copy(node, remap_func)
        n.tag = node.tag  # type: ignore[attr-defined]
        node_remapping[node] = n
        node_to_component[n] = comp

    if output_node is None:
        raise RuntimeError("Graph had no output node!")

    for x in flatten(output_node.args[0]):
        if x.op == "get_attr":
            # We don't need a component mapping for get_attr nodes consumed by
            # the output; we only need to create corresponding counterparts in
            # the resulting graph.
            main_remapping[x] = main_g.get_attr(x.name, type_expr=x.type)
        else:
            # All component results consumed by the output node should be
            # marked as "used in main".
            used_in_main[x] = None

    # If a node is used in the main graph, mark it as an output of the
    # component it belongs to.
    for n in used_in_main:
        if n.op != "placeholder":
            node_to_component[n].orig_outputs.append(n)

    # Now we create a GraphModule for each component.
    orig_to_split_fqn_mapping: Dict[str, str] = {}
    for comp in all_components:
        outs = tuple(map(node_remapping.__getitem__, comp.orig_outputs))

        if return_tuple:
            comp.graph.output(outs)
        else:
            # Take care of the args of the FX output node. If there is a
            # single output, the output node's args look like (output_single,);
            # if there are multiple outputs, they look like
            # ((output_0, output_1, ...),).
            comp.graph.output(outs[0] if len(outs) == 1 else outs)

        comp.gm, comp_orig_to_split_fqn_mapping = lift_subgraph_as_module(
            gm, subgraph=comp.graph, comp_name=comp.name
        )
        orig_to_split_fqn_mapping.update(comp_orig_to_split_fqn_mapping)

        # Create a call_module node in the main graph.
        main_node = main_g.call_module(
            comp.name,
            args=tuple(map(main_remapping.__getitem__, comp.orig_inputs)),
            kwargs=None,
        )

        if len(outs) == 1 and not return_tuple:
            main_remapping[comp.orig_outputs[0]] = main_node
        else:
            for i, o in enumerate(comp.orig_outputs):
                # Use Proxy to record getitem access.
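                # (Indexing a torch.fx.Proxy built from `main_node` appends a
                # call_function node targeting operator.getitem to main_g, so
                # each per-output access is recorded as a node in the main
                # graph; `.node` retrieves that node.)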
                main_remapping[o] = torch.fx.Proxy(main_node)[i].node  # type: ignore[index]

    main_g.output(map_arg(output_node.args[0], main_remapping.__getitem__))
    main_root = HolderModule({comp.name: comp.gm for comp in all_components})
    main_g._codegen = gm.graph._codegen

    # If the output node consumes a get_attr directly in the original graph,
    # make sure that attribute is copied to the new module.
    for x in flatten(output_node.args[0]):
        if x.op == "get_attr":
            setattr(main_root, x.name, getattr_recursive(gm, x.target))  # type: ignore[arg-type]

    result_gm = GraphModuleCls(main_root, main_g)
    if return_fqn_mapping:
        return result_gm, orig_to_split_fqn_mapping

    return result_gm