# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import copy
import logging
import operator
from collections import defaultdict
from enum import Enum
from inspect import Parameter, Signature, signature
from types import MethodType
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import torch
import torch.fx as fx
from torch.distributed import ProcessGroup
from torch.export import ExportedProgram
from torch.export.unflatten import (
    _assign_attr,
    _AttrKind,
    _sink_params,
    InterpreterModule,
)
from torch.fx.node import map_aggregate
from torch.fx.passes.split_module import split_module

from ._backward import _null_coalesce_accumulate, stage_backward
from ._unflatten import _outline_submodules
from ._utils import PipeInfo
from .stage import _PipelineStage


logger = logging.getLogger(__name__)

# TODO:
# 1. investigate gradient sync for shared parameters. how does DDP do it?
# 2. Add parameter movement to split_module


def _find_loss_from_output_and_spec(output_val, spec_val):
    if spec_val is False:
        return None
    if spec_val is True:
        if not isinstance(output_val, fx.Node):
            raise RuntimeError(
                f"Loss spec must specify a dynamic value but got {output_val}"
            )
        return output_val

    if isinstance(spec_val, (tuple, list)):
        if not isinstance(output_val, (tuple, list)):
            raise RuntimeError(
                f"Output value {output_val} must match type of loss specification "
                f"{spec_val}"
            )
        if len(output_val) != len(spec_val):
            raise RuntimeError(
                f"Output value {output_val} must match length of loss specification "
                f"{spec_val}"
            )
        for out, spec in zip(output_val, spec_val):
            loss_val = _find_loss_from_output_and_spec(out, spec)
            if loss_val is not None:
                return loss_val
        raise RuntimeError(f"Did not find loss value in specification {spec_val}")

    if isinstance(spec_val, dict):
        if not isinstance(output_val, dict):
            raise RuntimeError(
                f"Output value {output_val} must match type of loss specification "
                f"{spec_val}"
            )
        if set(output_val.keys()) != set(spec_val.keys()):
            raise RuntimeError(
                f"Output value {output_val} must match keys of loss specification "
                f"{spec_val}"
            )
        for k in spec_val:
            loss_val = _find_loss_from_output_and_spec(output_val[k], spec_val[k])
            if loss_val is not None:
                return loss_val
        raise RuntimeError(f"Did not find loss value in specification {spec_val}")

    raise RuntimeError(f"Unsupported type {type(spec_val)} in loss specification")


def _find_loss_output(mod: torch.nn.Module, g: fx.Graph, output_loss_value_spec):
    output_nodes = [n for n in g.nodes if n.op == "output"]
    assert len(output_nodes) == 1
    output_node = output_nodes[0]
    output_val = output_node.args[0]
    generated_spec: Any = None

    if isinstance(mod, TrivialLossWrapper):
        # TrivialLossWrapper is pre-defined by PiPPy.
        # It has loss as the only output so we can safely assume the first output arg is the loss.
        assert len(output_node.args) == 1
        loss_node = output_val
        generated_spec = TrivialLossWrapper.loss_spec
    elif output_loss_value_spec is None:
        # Use default spec, i.e. search for "loss" in output values
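        # (e.g. for an output dict {"loss": l, "logits": y}, the spec generated
        # below would be {"loss": True, "logits": False})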
        if isinstance(output_val, dict) and "loss" in output_val.keys():
            loss_node = output_val["loss"]
            generated_spec = {k: k == "loss" for k in output_val}
        else:
            loss_node = None
            generated_spec = None
    else:
        loss_node = _find_loss_from_output_and_spec(output_val, output_loss_value_spec)
        generated_spec = output_loss_value_spec

    return loss_node, output_node, generated_spec


def _insert_stage_symbolic_backward(
    g: fx.Graph,
    loss_node: fx.Node,
    output_node: fx.Node,
):
    # Collect metadata about tuple output values. TODO: move this to split_module or FX IR
    tuples: Dict[fx.Node, Tuple] = {}
    for node in reversed(g.nodes):
        if node.op == "call_function":
            # In the forward pass, only emit placeholder, module calls, and
            # getitem calls. If we have a target other than getitem in this
            # (forward-only) code, there is a bug.
            assert node.target == operator.getitem, (
                "Found non-getitem call in forward pass. "
                "Please report a bug to PiPPy"
            )
            assert (
                len(node.args) == 2
            ), "Found malformed getitem call. Please report a bug to PiPPy"
            indexed_value, node_idx = tuple(node.args)

            # indexed_value is a collection that we are indexing into. It could
            # exist in the tuples map if we've processed another `getitem`
            # already.
            existing_list_size = (
                len(tuples[indexed_value]) if indexed_value in tuples else -1
            )
            new_list_size = max(node_idx + 1, existing_list_size)

            reconstructed_list = [None for _ in range(new_list_size)]

            # Copy over existing elements if present
            if indexed_value in tuples:
                for i, val in enumerate(tuples[indexed_value]):
                    reconstructed_list[i] = val

            # Populate value represented by this node
            reconstructed_list[node_idx] = node

            tuples[indexed_value] = tuple(reconstructed_list)

    # Keep track of nodes that dominate the loss node.
    # We will only emit backward operations for nodes that can contribute
    # to the specified loss value.
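    # (A sketch of the idea: if the graph computes both `loss` and an auxiliary
    # output that does not feed into `loss`, only the chain of module calls
    # reaching `loss` will get a corresponding `stage_backward` call below.)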
    live_nodes = {loss_node: None}
    val_to_grad: Dict[fx.Node, Optional[fx.Node]] = {loss_node: None}

    def assign_or_accumulate_grad(forward_node, grad_value):
        if forward_node in val_to_grad and forward_node.op != "placeholder":
            grad_value = g.call_function(
                _null_coalesce_accumulate,
                (val_to_grad[forward_node], grad_value),
            )
        val_to_grad[forward_node] = grad_value

    with g.inserting_before(output_node):
        for node in reversed(g.nodes):
            if node not in live_nodes:
                continue

            def add_to_live_nodes(n):
                live_nodes.setdefault(n, None)

            fx.node.map_arg(node.args, add_to_live_nodes)
            fx.node.map_arg(node.kwargs, add_to_live_nodes)
            if node.op == "call_module":
                output_grads: Union[Tuple[Optional[fx.Node], ...], Optional[fx.Node]]
                if node in tuples:
                    stage_output = tuples[node]
                    output_grads = tuple(val_to_grad.get(n, None) for n in tuples[node])
                    outputs_with_grads_idxs = [
                        i for i, n in enumerate(tuples[node]) if n in live_nodes
                    ]
                else:
                    stage_output = (node,)
                    output_grads = val_to_grad[node]
                    outputs_with_grads_idxs = [0]

                output_grads = (
                    (output_grads,)
                    if not isinstance(output_grads, tuple)
                    else output_grads
                )

                grad_call = g.call_function(
                    stage_backward,
                    kwargs={
                        "stage_output": stage_output,
                        "output_grads": output_grads,
                        "input_values": list(node.all_input_nodes),
                        "outputs_with_grads_idxs": outputs_with_grads_idxs,
                    },
                )
                # Insert backward stage debug info
                kwargs_copy = dict(grad_call.kwargs)
                grad_call.kwargs = kwargs_copy

                grad_call_proxy = fx.Proxy(grad_call)
                grads = grad_call_proxy.node

                input_nodes = list(node.all_input_nodes)
                grads_proxy = fx.Proxy(grads)
                for i, input_node in enumerate(input_nodes):
                    assign_or_accumulate_grad(input_node, grads_proxy[i].node)  # type: ignore[index]

    return g


class PipeSequential(torch.nn.Sequential):
    @staticmethod
    def from_sequential(sequential_instance: torch.nn.Sequential):
        return PipeSequential(*[copy.copy(m) for m in sequential_instance])

    def forward(self, input):
        for i, module in enumerate(self):
            input = module(input)
            if i != len(self) - 1:
                pipe_split()
        return input


class LossWrapper(torch.nn.Module):
    """
    LossWrapper is a convenient abstract class that allows you to wrap up both
    your model as well as its loss function and specify the connectivity between
    the inputs, model, loss function, and output value. Example::

        class MyModelWrapper(LossWrapper):
            def forward(self, x, targets):
                model_out = self.module(x)
                loss_value = self.loss_fn(model_out, targets)
                return loss_value

    The above example defines a connectivity where we expect the forward/loss/backward
    training procedure to take two arguments (x and targets), pass x into the module
    to get the output of the feedforward computation, pass the model output and the
    targets value into the loss function, and get and return the loss value, which will
    be backpropagated by PiPPy. The above class would then be instantiated like::

        model = ...  # instantiate the model
        loss_fn = torch.nn.MSELoss()  # for the sake of demonstration

        wrapper = MyModelWrapper(model, loss_fn)
        pipe = Pipe.from_tracing(wrapper, ...)
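
    If your wrapper would simply compute ``loss_fn(module(x), targets)``, the
    pre-defined ``TrivialLossWrapper`` below already implements that
    connectivity; as a minimal sketch::

        wrapper = TrivialLossWrapper(model, loss_fn)
        pipe = Pipe.from_tracing(wrapper, ...)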

    """

    def __init__(self, module, loss_fn):
        super().__init__()
        self.module = module
        self.loss_fn = loss_fn

    def forward(self, *args, **kwargs):
        raise NotImplementedError(
            "This instance of LossWrapper does not have an overridden "
            "forward(). Please implement forward() to specify the arguments, "
            "connection between the module and loss, and loss output "
            "value."
        )


class TrivialLossWrapper(LossWrapper):
    def forward(self, x, targets):
        model_out = self.module(x)
        return self.loss_fn(model_out, targets)

    loss_spec = True


# Pipe model representation
#
# Pipe can be thought of as an `nn.Sequential++`. That is to say: it specifies
# a single topological ordering of pipeline "stages" that, when run in series,
# constitutes all of the operations of the program. However, unlike `nn.Sequential`,
# Pipe allows non-local usages of values, so long as those uses still respect
# topological ordering. In particular:
#
# 1. Non-local activations. This type of usage can appear in, for example, skip
#    connections. These values will be directly transmitted from the "def" stage
#    to all stages that use them, skipping intermediate stages. During autograd,
#    gradients will be propagated back through this skip connection in reverse
#    of how the activations propagated in the forward pass.
# 2. Non-local parameter/module invocations. This occurs when a parameter is used
#    in a stage downstream of where it is resident. These values can be carried
#    forward similarly to (1), but in addition one might want to replicate the
#    value on multiple stages. Gradients for these shared parameters will be
#    accumulated separately on each stage, but there will be an additional
#    gradient accumulation before the optimizer step.


# Register `_pipe_split()` as an ATen operator. This is required for Export to
# preserve this marker in the graph.
torch.library.define("pippy::_pipe_split", "() -> ()")


@torch.library.impl("pippy::_pipe_split", "BackendSelect")
def _pipe_split():
    return None


@torch.library.register_fake("pippy::_pipe_split")  # type: ignore[no-redef]
def _pipe_split():  # noqa: F811
    return None


# Add an alias for convenience
aten_pipe_split_alias = torch.ops.pippy._pipe_split.default

# Ask Export to preserve the `_pipe_split` op.
# See examples in pytorch/torch/fx/node.py
fx.node._side_effectful_functions.add(aten_pipe_split_alias)


# User facing API
def pipe_split():
    """
    pipe_split is a special operator that is used to mark the boundary between
    stages in a module. It is used to split the module into stages. It is a
    no-op if your annotated module is run eagerly.

    Example:
        >>> # xdoctest: +SKIP
        >>> def forward(self, x):
        >>>     x = torch.mm(x, self.mm_param)
        >>>     x = torch.relu(x)
        >>>     pipe_split()
        >>>     x = self.lin(x)
        >>>     return x

    The above example will be split into two stages.
    """
    return torch.ops.pippy._pipe_split()


class MultiUseParameterConfig(Enum):
    TRANSMIT = 1
    REPLICATE = 2


MultiUseParamSpec = Union[MultiUseParameterConfig, Dict[str, MultiUseParameterConfig]]


class DetachExecutor(fx.Interpreter):
    """
    Special interpreter to run the split_gm in testing that detaches all inputs to
    a module invocation.
    This is needed so that the values at the boundary are
    leaf tensors in autograd execution.
    """

    def __init__(self, module, garbage_collect_values=True):
        garbage_collect_values = False
        super().__init__(module, garbage_collect_values)
        self.value_remap = {}

    def run(self, *args, initial_env=None):
        self.value_remap = {}
        return super().run(*args, initial_env=initial_env)

    def call_module(self, target, args, kwargs):
        def detach_tensors(a):
            if isinstance(a, torch.Tensor) and a.requires_grad:
                if a not in self.value_remap:
                    new_val = a.detach().requires_grad_(True)
                    self.value_remap[a] = new_val
                return self.value_remap[a]
            else:
                return a

        """
        def dont_traverse_size(a):
            return type(a) != torch.Size
        """

        args = map_aggregate(
            args,
            detach_tensors,  # dont_traverse_size
        )
        kwargs = map_aggregate(
            kwargs,
            detach_tensors,  # dont_traverse_size
        )

        return super().call_module(target, args, kwargs)

    def call_function(self, target, args, kwargs):
        # HACK to reroute saved input tensors to point to the detach()ed version
        if target == stage_backward:
            kwargs = dict(kwargs)
            kwargs["input_values"] = [
                self.value_remap.get(v, v) for v in kwargs["input_values"]
            ]
        return super().call_function(target, args, kwargs)


class _NodeReference:
    def __init__(self, name):
        self.name = name

    name: str


class _LinearNodeList:
    def __init__(self, node_list):
        self.serialize_node_list = []
        for node in node_list:
            node_args = fx.node.map_arg(node.args, lambda n: _NodeReference(n.name))  # type: ignore[arg-type,return-value]
            node_kwargs = fx.node.map_arg(node.kwargs, lambda n: _NodeReference(n.name))  # type: ignore[arg-type,return-value]
            serialize_node = fx.Node(
                graph=None,  # type: ignore[arg-type]
                name=node.name,
                op=node.op,
                target=node.target,
                args=node_args,  # type: ignore[arg-type]
                kwargs=node_kwargs,  # type: ignore[arg-type]
                return_type=node.type,
            )
            serialize_node.meta = copy.copy(node.meta)
            self.serialize_node_list.append(serialize_node)

    def to_graph(self):
        graph = fx.Graph()

        ref_str_to_node: Dict[str, fx.Node] = {}

        def ref_to_node(arg):
            if isinstance(arg, _NodeReference):
                return ref_str_to_node[arg.name]
            else:
                return arg

        for node in self.serialize_node_list:
            node_args = map_aggregate(node.args, ref_to_node)
            node_kwargs = map_aggregate(node.kwargs, ref_to_node)
            deser_node = graph.create_node(
                op=node.op,
                target=node.target,
                args=node_args,  # type: ignore[arg-type]
                kwargs=node_kwargs,  # type: ignore[arg-type]
                name=node.name,
                type_expr=node.type,
            )
            ref_str_to_node[node.name] = deser_node

        return graph


def _direct_serialization_deserialize(body, nodes):
    """
    Custom `__reduce__` method for serialization.
    DO AS I SAY -- NOT AS I DO. This violates the principle that
    GraphModules serialize via code export & re-tracing. We allow
    for this here because **PIPE STAGES SHOULD NOT BE PERSISTED
    TO DISK -- THIS IS ONLY FOR TRANSMISSION VIA RPC**. Persisting
    these instances to disk will expose internal implementation
    details of `fx.Graph` and related data structures and is
    NOT advised.
    """

    class DummyModule(torch.nn.Module):
        def __init__(self, body):
            super().__init__()
            self.__dict__.update(body)

    dummy = DummyModule(body)

    return fx.GraphModule(dummy, nodes.to_graph())


def _direct_serialization_reduce(self):
    serialization_dict = dict(self.__dict__)
    serialization_dict.pop("_graph")
    return (
        _direct_serialization_deserialize,
        (serialization_dict, _LinearNodeList(self.graph.nodes)),
    )


def _modify_graph_op_device(
    gm: torch.fx.GraphModule,
    new_device: torch.device,
):
    """
    Modify the device argument of all "call_function" nodes in the graph. This
    is useful for moving the graph to a different device, in particular for
    generator ops like torch.ones.
    """
    modified = False
    for node in gm.graph.nodes:
        if node.op == "call_function":
            if "device" in node.kwargs and node.kwargs["device"] != new_device:
                logger.debug(
                    f"Changing device of Node {node.name} from {node.kwargs['device']} to {new_device}"  # noqa: G004
                )
                node.update_kwarg("device", new_device)
                modified = True
        elif node.op == "call_module":
            # Recursively modify "device" in submodules
            submod = gm.get_submodule(node.target)
            if isinstance(submod, torch.fx.GraphModule):
                _modify_graph_op_device(submod, new_device)
            elif isinstance(submod, InterpreterModule):
                # If unflattening has been performed, we need to access its graph module by `.graph_module`
                _modify_graph_op_device(submod.graph_module, new_device)
            else:
                logger.warning(
                    f"Skipping device modification for submodule {node.target} because it is a {type(submod)}"  # noqa: G004
                )

    if modified:
        gm.recompile()


class Pipe(torch.nn.Module):
    def __init__(
        self,
        split_gm: fx.GraphModule,
        num_stages: int,
        has_loss_and_backward: bool,
        loss_spec,
    ):
        # TODO: is there a way not to hard wire init?
        torch.nn.Module.__init__(self)
        self.split_gm: fx.GraphModule = split_gm
        self.executor: DetachExecutor = DetachExecutor(self.split_gm)
        self.num_stages: int = num_stages
        self.has_loss_and_backward = has_loss_and_backward
        self.loss_spec = loss_spec

        for node in split_gm.graph.nodes:
            assert (
                node.op in {"call_module", "placeholder", "output"}
                or (node.op, node.target) == ("call_function", operator.getitem)
                or (node.op, node.target) == ("call_method", "backward")
                or (node.op, node.target) == ("call_function", stage_backward)
                or (node.op, node.target)
                == ("call_function", _null_coalesce_accumulate)
            ), node

        # Detect replicated parameters so we know that we have to do an additional allreduce
        # before applying the optimizer
        #
        # Note that this also handles the case where there were multiple calls to a single
        # module from different stages, regardless of whether that module invocation
        # was handled by the logic above.
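        #
        # As a hypothetical illustration: if one nn.Parameter is reachable from
        # both `submod_0` (as "lin.weight") and `submod_1` (as
        # "shared.lin.weight"), the mapping built below would contain
        #     {"submod_0": "lin.weight", "submod_1": "shared.lin.weight"}
        # and that entry would end up in `self.replicated_params`.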

        # Map parameter value to a dictionary that maps the user pipeline module
        # to the local qualname within that module
        params_to_users: Dict[torch.nn.Parameter, Dict[str, str]] = {}

        for m_qualname, mod in self.split_gm.named_children():
            for p_qualname, param in mod.named_parameters():
                params_to_users.setdefault(param, {})
                params_to_users[param][m_qualname] = p_qualname

        self.replicated_params: List[Dict[str, str]] = [
            use_mapping
            for _, use_mapping in params_to_users.items()
            if len(use_mapping) > 1
        ]

        # We must break the aliasing relationship between the replicated parameters for correct
        # numerics in reference runs. If we do not do this, the autograd tape in separate stages
        # will have a reference to the same tensor value and will erroneously apply gradient
        # updates multiple times. Therefore, for each replicated parameter set, we deepcopy the
        # values so that we have separate instances.
        for param_mapping in self.replicated_params:
            for submod_name, param_qualname in param_mapping.items():
                submod = getattr(self.split_gm, submod_name)
                atoms = param_qualname.split(".")
                for atom in atoms[:-1]:
                    submod = getattr(submod, atom)
                setattr(submod, atoms[-1], copy.deepcopy(getattr(submod, atoms[-1])))

        def throw(self, *args, **kwargs):
            raise RuntimeError(
                "To run pipeline locally, invoke the Pipe object directly, not `split_gm`"
            )

        self.split_gm.forward = throw

        # Make submodules use custom direct-serialized GraphModule
        i = 0
        while True:
            try:
                name = f"submod_{i}"
                submod = getattr(self.split_gm, name)
                submod.__class__.__reduce__ = _direct_serialization_reduce
                i += 1
            except AttributeError:
                break

    def forward(self, *args, **kwargs):
        executor_args = args
        if len(kwargs) > 0:
            parameters = []
            for node in self.split_gm.graph.nodes:
                if node.op == "placeholder":
                    if node.args and len(node.args) > 0:
                        parameters.append(
                            Parameter(
                                node.target,
                                Parameter.POSITIONAL_OR_KEYWORD,
                                default=node.args[0],
                            )
                        )
                    else:
                        parameter_kind = Parameter.POSITIONAL_OR_KEYWORD
                        param_name = node.target
                        if node.target.startswith("**"):
                            parameter_kind = Parameter.VAR_KEYWORD  # type: ignore[assignment]
                            param_name = param_name[2:]
                        elif node.target.startswith("*"):
                            parameter_kind = Parameter.VAR_POSITIONAL  # type: ignore[assignment]
                            param_name = param_name[1:]
                        parameters.append(Parameter(param_name, parameter_kind))
            signature = Signature(parameters)
            ba = signature.bind(*args, **kwargs)
            ba.apply_defaults()
            executor_args = ba.arguments.values()  # type: ignore[assignment]

        res = self.executor.run(*executor_args)

        return res

    def get_stage_module(self, stage_idx: int) -> torch.nn.Module:
        """
        Return a stage module corresponding to `stage_idx` of the `pipe`.
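
        Example (a minimal sketch, assuming a `pipe` produced by `pipeline()`)::

            stage0_mod = pipe.get_stage_module(0)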
        """
        if stage_idx < 0 or stage_idx >= self.num_stages:
            raise ValueError(f"Invalid stage index {stage_idx}!")
        return getattr(self.split_gm, f"submod_{stage_idx}")

    @staticmethod
    def _number_and_count_forward_stages(gm: fx.GraphModule):
        num_stages = 0
        found_idxs: Dict[int, None] = {}
        for node in gm.graph.nodes:
            if node.op == "call_module" and node.target.startswith("submod_"):
                node.meta["stage_idx"] = int(node.target[len("submod_") :])
                found_idxs.setdefault(node.meta["stage_idx"])
                num_stages += 1

        # this assert will fail if a split point is inserted before the first layer, which creates empty first submodule
        # Update: the following assert may fail against some torch versions >=
        # 2.2.0, as:
        # submod_0, submod_1, submod_2, ...
        # may be named as
        # submod_0, submod_2, submod_4, ...
        # TODO: investigate
        # assert all(i in found_idxs for i in range(num_stages))

        return num_stages

    @staticmethod
    def _from_traced(
        mod: torch.nn.Module,
        exported_program: ExportedProgram,
        multi_use_param_spec: Optional[MultiUseParamSpec] = None,
        output_loss_value_spec=None,
        split_policy: Optional[
            Callable[[torch.fx.GraphModule], torch.fx.GraphModule]
        ] = None,
    ):
        """
        Additionally, the ``output_loss_value_spec`` value can be specified to disambiguate
        which value in the output of `forward` is the loss value on which PiPPy should apply
        backpropagation. For example, if your ``forward`` returns a tuple ``(loss, model_out)``,
        you can specify ``output_loss_value_spec=(True, False)``. Or, if your ``forward`` returns
        a dict ``{'loss': loss_value, 'model_out': model_out}``, you can specify
        ``output_loss_value_spec={'loss': True, 'model_out': False}``
        """

        traced = exported_program.module()

        if split_policy is not None:
            logger.info("Auto-splitting model")
            traced = split_policy(traced)  # type: ignore[arg-type]

        logger.debug(traced.print_readable(print_output=False))

        # Deduplicate `get_attr` nodes that refer to the same parameter. Downstream code for moving
        # parameters relies on the invariant that parameter accesses happen once. This is not necessarily
        # the case (especially with custom tracers), so fix that up here.
        get_attr_nodes: Dict[str, fx.Node] = {}
        for node in traced.graph.nodes:
            if node.op == "get_attr":
                get_attr_nodes.setdefault(node.target, node)

                if get_attr_nodes[node.target] != node:
                    node.replace_all_uses_with(get_attr_nodes[node.target])
                    traced.graph.erase_node(node)

        # avoid looking at next node by keeping track of previous pipe_split
        prev_pipe_split_idx = -1
        pipe_split_nodes_to_erase = set()
        for i, node in enumerate(traced.graph.nodes):
            if (node.op, node.target) == ("call_function", pipe_split):
                if prev_pipe_split_idx == i - 1:
                    pipe_split_nodes_to_erase.add(node)
                prev_pipe_split_idx = i

        for node in pipe_split_nodes_to_erase:
            traced.graph.erase_node(node)

        traced.recompile()

        part_idx = 0

        def split_callback(n: fx.Node):
            nonlocal part_idx
            if (n.op, n.target) == (
                "call_function",
                aten_pipe_split_alias,
            ):
                logger.debug(f"Found pipe_split {part_idx}")  # noqa: G004
                part_idx += 1
            return part_idx

        # TODO: what does split do with module invocations? does it move the modules
        # into the submodules?
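        #
        # A rough sketch (not the exact generated code) of what `split_module`
        # returns for a traced model containing a single `pipe_split()`: the
        # root forward is rewritten to chain two submodules, e.g.
        #
        #     def forward(self, x):
        #         submod_0 = self.submod_0(x)
        #         submod_1 = self.submod_1(submod_0)
        #         return submod_1
        #
        # where each `submod_i` holds the ops of one pipeline stage.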
        split = split_module(traced, mod, split_callback)  # type: ignore[arg-type]
        # a (custom) tracer can produce dead code like orphan get_attr nodes
        split.graph.eliminate_dead_code()

        # peephole to remove pipe_split
        for submodule in split.modules():
            if isinstance(submodule, fx.GraphModule):
                for node in submodule.graph.nodes:
                    if (node.op, node.target) == (
                        "call_function",
                        aten_pipe_split_alias,
                    ):
                        submodule.graph.erase_node(node)
                submodule.recompile()

        for name, submodule in split.named_children():
            if isinstance(submodule, fx.GraphModule):
                new_submod = _outline_submodules(submodule.graph)
                # Replace old submod
                split.register_module(name, new_submod)

        # TODO: backport this into split_module
        def delete_user_reference(node, user):
            """
            Delete reference of `node` from `user`'s arg list.
            Args:
                - node: a `get_attr` node at root.
                - user: a submodule node that uses `node`.
            """
            assert len(user.kwargs) == 0
            use_idxs = [i for i, arg in enumerate(user.args) if arg == node]
            assert len(use_idxs) == 1
            args_copy = list(user.args)
            args_copy.pop(use_idxs[0])
            user.args = tuple(args_copy)
            logger.debug(
                f"Deleted {node} from user {user}, arg index = {use_idxs[0]}"  # noqa: G004
            )

        # A list of param referrals for deferred deletion.
        # To be accumulated in `move_param_to_callee`.
        to_delete = []

        def _recursive_getattr_with_parent(mod, fqn):
            # Return the last parent module and the attribute value, given a nested FQN
            atoms = fqn.split(".")
            for atom in atoms[:-1]:
                if not hasattr(mod, atom):
                    return None, None
                mod = getattr(mod, atom)
            if not hasattr(mod, atoms[-1]):
                return mod, None
            attr = getattr(mod, atoms[-1])
            return mod, attr

        def move_param_to_callee(
            root,
            callee_name,
            param_fqn,
        ):
            """
            Move a parameter from the root module to a submodule.
            Args:
                root: The root module.
                callee_name: The name of the submodule to move the parameter to.
                param_fqn: The fully qualified name of the parameter to move.
            """
            # `atoms` is a list of strings representing the path to the
            # parameter in the original model
            atoms = param_fqn.split(".")
            mod_itr, param_val = _recursive_getattr_with_parent(split, param_fqn)
            # Check whether the parameter is a buffer or a parameter
            is_buffer = atoms[-1] in mod_itr._buffers

            # Check whether the parameter is a tensor
            assert isinstance(param_val, torch.Tensor), (
                f"Expected '{param_fqn}' to be {torch.Tensor} but got {type(param_val)}."
                + (
                    f" It might happen if module '{param_fqn}' was passed to some 'leaf function' "
                    f"(see https://pytorch.org/docs/stable/fx.html#fx.wrap). Please inspect "
                    f"usages of '{param_fqn}' in the traced graph."
                    if isinstance(param_val, torch.nn.Module)
                    else ""
                )
            )

            # Get submodule
            callee = root.get_submodule(callee_name)
            assert not hasattr(
                callee, param_fqn
            ), f"Module {callee_name} already has a parameter named {param_fqn}"

            # Assign the parameter to the submodule
            if is_buffer:
                _assign_attr(
                    param_val,
                    callee,
                    param_fqn,
                    attr_kind=_AttrKind.BUFFER,
                    persistent=True,  # TODO: handle non-persistent buffer
                )
            else:
                _assign_attr(
                    param_val,
                    callee,
                    param_fqn,
                    attr_kind=_AttrKind.PARAMETER,
                )
            logger.debug(f"Moved parameter {param_fqn} to {callee_name}")  # noqa: G004

            # Next step is to replace placeholder of submodule with a get_attr.
            # Those placeholders are created by `split_module` inside each
            # submodule.
            # Update: this step is now moved to `_sink_params` because
            # `_sink_params` can do it recursively (i.e. for modules inside
            # submodule)

            to_delete.append((mod_itr, atoms[-1]))

        # Get the list of all parameters in the root module
        attr_nodes = list(filter(lambda n: n.op == "get_attr", split.graph.nodes))
        for node in attr_nodes:
            # Check whether the parameter is used in only one submodule
            if len(node.users) > 1:
                logger.info(
                    f"Parameter {node.target} used in multiple stages: {node.users}."  # noqa: G004
                )
            for user in node.users:
                assert user.op == "call_module"
                # Move parameter into submodule
                move_param_to_callee(
                    split,
                    user.target,
                    node.target,
                )

        # [aliasing] store tensor id -> list of FQNs, built from state dict
        # Also assign non-persistent buffers
        id_to_fqns: Dict[int, Set[str]] = defaultdict(set)
        for fqn, tensor in mod.state_dict(keep_vars=True).items():
            id_to_fqns[id(tensor)].add(fqn)
        for fqn, tensor in mod.named_buffers():
            id_to_fqns[id(tensor)].add(fqn)

        # After moving the params to their corresponding hierarchies, we also
        # need to move the `get_attr` nodes from the root of the graph to those
        # hierarchies.
        # [aliasing] use id -> fqn mapping to list out all valid FQNs
        inputs_to_state: Dict[str, List[str]] = {}
        for attr in attr_nodes:
            _, tensor = _recursive_getattr_with_parent(mod, attr.target)
            fqns = list(id_to_fqns[id(tensor)])
            if fqns:
                inputs_to_state[attr.name] = fqns
            elif attr.target in exported_program.constants:  # lifted constants
                inputs_to_state[attr.name] = [attr.target]

        # [aliasing] for each submodule split, assign attributes on FQNs that may be used.
        # We determine this based on whether or not the FQN attribute parent exists.
        # i.e. if the last submodule exists, assign the attribute.
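        #
        # Hypothetical example: for fqn "block.lin.weight", if a submodule
        # already has a "block.lin" child but no "weight" attribute on it, the
        # parent exists and the child does not, so we set
        # `submod.block.lin.weight = tensor` below; `_sink_params` can then
        # resolve the corresponding placeholder to this attribute.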
        added_attributes: Dict[str, List[str]] = defaultdict(list)
        for fqn, tensor in mod.state_dict(keep_vars=True).items():
            for name, submod in split.named_children():
                if isinstance(submod, fx.GraphModule):
                    parent, child = _recursive_getattr_with_parent(submod, fqn)
                    if (
                        parent and child is None
                    ):  # parent exists, attribute doesn't -> assign
                        added_attributes[name].append(fqn)
                        setattr(parent, fqn.split(".")[-1], tensor)

        # Deferred deletion: remove the original attributes (to params) from the
        # root GraphModule
        for mod_itr, last_atom in to_delete:
            try:
                delattr(mod_itr, last_atom)
            except AttributeError:
                # This is expected if the parameter is used in multiple stages
                pass

        # This is done by (1) `_sink_params` at each submodule;
        for name, submod in split.named_children():
            if isinstance(submod, fx.GraphModule):
                _sink_params(submod, inputs_to_state, [])
                submod.graph.lint()
                submod.recompile()

        # [aliasing] This step is not strictly necessary, but helps reduce parameter usage/memory.
        # After the _sink_params() routine has run, clean up unused attributes that we previously added.
        # Determine this based on the get_attr nodes - if not used, remove it.
        for name, attributes in added_attributes.items():
            submod = getattr(split, name)
            unused_attributes = set(attributes)
            # track used attributes in the submodule, running DFS on subgraph hierarchy
            stack = [("", submod)]  # (scope, submodule)
            while stack:
                scope, _mod = stack.pop()
                if isinstance(_mod, (fx.GraphModule, InterpreterModule)):
                    for node in _mod.graph.nodes:
                        if node.op == "get_attr":
                            # get_attr might access a deeper-level attribute
                            fqn = scope + "." + node.target if scope else node.target
                            if fqn in unused_attributes:  # used, remove it
                                unused_attributes.remove(fqn)
                for _name, _submod in _mod.named_children():
                    stack.append((scope + "." + _name if scope else _name, _submod))
            # delete unused attributes
            for attr in unused_attributes:
                mod_itr, atoms = submod, attr.split(".")
                for atom in atoms[:-1]:
                    mod_itr = getattr(mod_itr, atom)
                delattr(mod_itr, atoms[-1])

        for node in attr_nodes:
            # And (2): remove `get_attr` node from submod's arg list
            for user in copy.copy(node.users):
                assert user.op == "call_module"
                delete_user_reference(node, user)
            # And (3): remove the `get_attr` node from the root graph.
            split.graph.erase_node(node)

        split.delete_all_unused_submodules()
        split.graph.lint()
        split.recompile()

        num_stages = Pipe._number_and_count_forward_stages(split)

        has_loss_and_backward = False
        generated_loss_spec = output_loss_value_spec

        if output_loss_value_spec is not None:
            loss_node, output_node, generated_loss_spec = _find_loss_output(
                mod, split.graph, output_loss_value_spec
            )
            if loss_node is not None:
                _insert_stage_symbolic_backward(
                    split.graph,
                    loss_node,
                    output_node,
                )
                split.recompile()
                has_loss_and_backward = True
                logger.debug("Pipeline is in training mode, backward pass generated")
            else:
                raise RuntimeError(
                    f"Did not find any loss value according to {output_loss_value_spec=}"
                )
        else:
            logger.debug("Pipeline is in inference mode, backward pass not generated")

        logger.debug("Full pipe model:\n" f"{split}")  # noqa: G004

        return Pipe(
            split,
            num_stages,
            has_loss_and_backward,
            generated_loss_spec,
        )

    def print_readable(self):
        """
        Print the pipe in a human-readable format.
        This will print both the root pipe and each stage module.
        """
        self.split_gm.print_readable()

    @staticmethod
    def _trace_with_export(
        mod: torch.nn.Module,
        example_args: Tuple[Any, ...],
        example_kwargs: Optional[Dict[str, Any]] = None,
    ) -> ExportedProgram:
        logger.info("Tracing model ...")
        try:
            ep = torch.export.export(
                mod,
                example_args,
                example_kwargs,
            )
        except Exception as e:
            raise RuntimeError(
                "It seems that we cannot capture your model as a full graph. "
                "Typical reasons include graph breaks, data/shape-dependent "
                "control flow, or missing meta kernels for custom operators. "
                "You can use our manual pipeline interfaces, or try to fix the "
                "graph breaks, see https://pytorch.org/docs/stable/export.html"
            ) from e

        return ep

    @staticmethod
    def from_tracing(
        mod: torch.nn.Module,
        example_args: Tuple[Any, ...],
        example_kwargs: Optional[Dict[str, Any]] = None,
        split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
    ):
        # If a param will be used in multiple pipeline stages, we default the strategy to REPLICATE'ing the param across
        # stages instead of TRANSMIT'ting it
        multi_use_param_spec = MultiUseParameterConfig.REPLICATE

        # Figure out which output is loss from output_chunk_spec
        output_loss_value_spec: Any = None
        # Deprecated
        """
        if output_chunk_spec is not None:
            output_loss_value_spec = map_aggregate(
                output_chunk_spec, lambda v: isinstance(v, _LossReducer)
            )
        """

        # Trace with export
        exported_program = Pipe._trace_with_export(
            mod,
            example_args,
            example_kwargs,
        )

        pipe = Pipe._from_traced(
            mod,
            exported_program,
            multi_use_param_spec,
            output_loss_value_spec=output_loss_value_spec,
            split_policy=split_policy,
        )

        # Users want the first pipeline stage to accept kwargs if the original
        # program does. This is controlled by the `_codegen` field of the graph,
        # so we make a copy here. Note: we only want the input spec and not the
        # output spec, because the output spec is for the last stage. Maybe a
        # TODO? Not sure yet.
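        #
        # As a hypothetical illustration: if the original program's forward is
        #     def forward(self, x, mask=None): ...
        # then copying the input-side `_codegen` below lets callers invoke the
        # first stage as `submod_0(x, mask=mask)` rather than positionally only.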
        split = pipe.split_gm
        traced = exported_program.module()
        submod0 = next(iter(split.children()))
        submod0_sign = signature(submod0.forward)
        model_sign = signature(traced.forward)
        if len(model_sign.parameters) != len(submod0_sign.parameters):
            # We don't change the signature of the first stage if it takes a
            # different number of args than the original model
            logger.info(
                f"Original model takes {len(model_sign.parameters)} args but the "  # noqa: G004
                f"first pipeline stage takes {len(submod0_sign.parameters)}. "
                "Please provide args to respective pipeline stages."
            )
        else:
            # Support kwargs for the first stage
            submod0.graph._codegen = copy.deepcopy(traced.graph._codegen)
            # `_replace` is actually not "private" or internal. Based on this doc:
            # To prevent conflicts with field names, the method and attribute names
            # start with an underscore
            submod0.graph._codegen.pytree_info = (
                submod0.graph._codegen.pytree_info._replace(out_spec=None)
            )
            submod0.recompile()

        return pipe

    def __str__(self):
        return self.split_gm.__str__()

    def __repr__(self):
        return self.split_gm.__repr__()

    def info(self) -> PipeInfo:
        """
        Get information about the pipe.

        Returns
        -------
        PipeInfo
            A dataclass containing information about the pipe.
        """
        return PipeInfo(
            graph=self.split_gm.graph,
            num_stages=self.num_stages,
            has_loss_and_backward=self.has_loss_and_backward,
        )

    def build_stage(
        self,
        stage_index: int,
        device: torch.device,
        group: Optional[ProcessGroup] = None,
    ) -> _PipelineStage:
        """
        Create a `PipelineStage` given a stage index and distributed group.
        The `PipelineStage` can run with `PipelineSchedule`s.
        """
        # Find stage module
        stage_module = self.get_stage_module(stage_index)

        # Move ops argument to device
        # Today the PT2 tracer does not treat `x.device` as a symbolic device;
        # instead, the device at tracing time gets burned into the generated
        # code. Here we provide a workaround for users to manually modify the
        # "device" kwarg of operations. Such operations may include:
        # `torch.ones`, `torch.zeros`, `torch.rand`, etc.
        if isinstance(stage_module, torch.fx.GraphModule):
            _modify_graph_op_device(stage_module, device)
        else:
            logger.warning(
                f"Expected a `torch.fx.GraphModule` but got {type(stage_module)}"  # noqa: G004
            )

        # Detach pipe info
        # Note: be careful what's included in `pipe_info`. We don't want to keep
        # a reference to `Pipe` or `Pipe.split_gm`, which stops Python from
        # recycling them. When Python recycles them, other stage modules (which
        # are irrelevant to the current rank) can be automatically freed.
        pipe_info = self.info()
        return _PipelineStage(stage_module, stage_index, pipe_info, device, group)


class SplitPoint(Enum):
    BEGINNING = 1
    END = 2


# For backward compatibility, we kept the PipeSplitWrapper class because `class
# SplitPoint` used to be defined in this class.
class PipeSplitWrapper:
    # Create a class alias for BC
    SplitPoint = SplitPoint


def _split_before_forward(self, *args, **kwargs):
    pipe_split()
    return self._orig_forward(*args, **kwargs)


def _split_after_forward(self, *args, **kwargs):
    try:
        return self._orig_forward(*args, **kwargs)
    finally:
        pipe_split()


def annotate_split_points(mod: torch.nn.Module, spec: Dict[str, SplitPoint]):
    # TODO: make this implementation out-of-place?
    for qualname, split_type in spec.items():
        atoms = qualname.split(".")
        predecessor_module = mod
        for i, atom in enumerate(atoms[:-1]):
            try:
                predecessor_module = getattr(predecessor_module, atom)
            except AttributeError as e:
                raise AttributeError(
                    f"Specified target {qualname} referenced "
                    f'nonexistent module {".".join(atoms[: i + 1])}'
                ) from e

        mod_to_wrap = getattr(predecessor_module, atoms[-1])
        mod_to_wrap._orig_forward = mod_to_wrap.forward
        if split_type == SplitPoint.BEGINNING:
            mod_to_wrap.forward = MethodType(_split_before_forward, mod_to_wrap)
        elif split_type == SplitPoint.END:
            mod_to_wrap.forward = MethodType(_split_after_forward, mod_to_wrap)
        else:
            raise ValueError("Unknown split point type.")


def pipeline(
    module: torch.nn.Module,
    mb_args: Tuple[Any, ...],
    mb_kwargs: Optional[Dict[str, Any]] = None,
    split_spec: Optional[Dict[str, SplitPoint]] = None,
    split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
) -> Pipe:
    """
    Split a module based on a specification.

    See `Pipe` for more details.

    Arguments
    ---------
    module:
        The module to be split.
    mb_args:
        Example positional inputs, in micro-batch form.
    mb_kwargs:
        Example keyword inputs, in micro-batch form. (default: `None`)
    split_spec:
        A dictionary using submodule names as split markers. (default: `None`)
    split_policy:
        The policy to use for splitting the module. (default: `None`)

    Returns
    -------
    A pipeline representation of class `Pipe`.
    """
    if split_spec is not None and split_policy is not None:
        raise ValueError(
            "Cannot specify both `split_spec` and `split_policy`. Please use only one of them."
        )

    if split_spec is not None:
        # Annotate split points in the module based on user spec
        annotate_split_points(module, split_spec)
        return Pipe.from_tracing(
            mod=module,
            example_args=mb_args,
            example_kwargs=mb_kwargs,
        )
    else:
        # Use split policy
        return Pipe.from_tracing(
            mod=module,
            example_args=mb_args,
            example_kwargs=mb_kwargs,
            split_policy=split_policy,
        )
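

# Example usage (a minimal sketch; the model below is hypothetical):
#
#     class TwoStage(torch.nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.a = torch.nn.Linear(16, 16)
#             self.b = torch.nn.Linear(16, 16)
#
#         def forward(self, x):
#             x = self.a(x)
#             pipe_split()
#             return self.b(x)
#
#     pipe = pipeline(TwoStage(), mb_args=(torch.randn(4, 16),))
#     print(pipe.num_stages)            # -> 2
#     stage0_mod = pipe.get_stage_module(0)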