# mypy: ignore-errors

import logging
import traceback
from dataclasses import dataclass, field
from typing import Any, List, Optional
from unittest import mock

import torch
from torch import fx
from torch._dynamo.output_graph import GraphCompileReason
from torch._dynamo.utils import deepcopy_to_fake_tensor, detect_fake_mode
from torch._logging import trace_structured
from torch.fx.node import Node


# Regular log messages should go through 'log'.
# ddp_graph_log is a separate artifact logger reserved for dumping graphs.
# See docs/source/logging.rst for more info.
log = logging.getLogger(__name__)
ddp_graph_log = torch._logging.getArtifactLogger(__name__, "ddp_graphs")


def args_str(args):
    # a debug helper
    if torch.is_tensor(args):
        return f"T[{args.shape}]"
    elif isinstance(args, tuple):
        return f"tuple({', '.join([args_str(x) for x in args])})"
    elif isinstance(args, list):
        return f"list({', '.join([args_str(x) for x in args])})"
    else:
        return str(args)


@dataclass
class Bucket:
    size: int = 0
    params: List[str] = field(default_factory=list)
    nodes: List[fx.Node] = field(default_factory=list)

    # param_ids is just used for unit testing
    param_ids: List = field(default_factory=list)

    # keep track of any buckets that were extended for logging purposes
    opcount_increased_to_capture_external_output: int = 0
    paramsize_before_opcount_increase: int = 0


def bucket_has_external_output(bucket: Bucket) -> bool:
    nodes_in_bucket = set()
    # we want to iterate in reverse order, but conveniently the bucket.nodes list
    # was already created backwards, so we don't reverse it here
    for node in bucket.nodes:
        # assume node.op != output, since those are filtered in the original iteration
        nodes_in_bucket.add(node)
        for user in node.users:
            if user not in nodes_in_bucket:
                return True
    return False


def pretty_print_buckets(buckets: List[Bucket], bucket_bytes_cap: int):
    headers = ("Index", "Size (b)", "Param Names")
    rows = []
    extended_buckets = []
    for idx, bucket in enumerate(reversed(buckets)):
        if len(bucket.params) > 0:
            rows.append((idx, bucket.size, bucket.params[0]))
            for param in bucket.params[1:]:
                rows.append((None, None, param))
        if bucket.opcount_increased_to_capture_external_output > 0:
            extended_buckets.append(
                (
                    idx,
                    bucket.opcount_increased_to_capture_external_output,
                    bucket.size - bucket.paramsize_before_opcount_increase,
                )
            )

    if len(rows):
        log.info(
            "\nDDPOptimizer used bucket cap %s and created %d buckets. Enable debug logs for detailed bucket info.",
            bucket_bytes_cap,
            len(buckets),
        )

        if len(extended_buckets):
            log.warning(
                "Some buckets were extended beyond their requested parameter capacities"
                " in order to ensure each subgraph has an output node, required for fx graph partitioning."
                " This can be the case when a subgraph would have only contained nodes performing inplace mutation,"
                " and returning no logical outputs. This should not be a problem, unless it results in too few graph"
                " partitions for optimal DDP performance."
            )

        try:
            from tabulate import tabulate

            log.debug(
                "\nDDPOptimizer produced the following bucket assignments:\n%s",
                tabulate(rows, headers=headers, tablefmt="simple_grid"),
            )

            if len(extended_buckets):
                log.warning(
                    "DDPOptimizer extended these buckets to ensure per-subgraph output nodes:\n%s",
                    tabulate(
                        extended_buckets,
                        headers=("Index", "Extra Ops", "Extra Param Size (b)"),
                        tablefmt="simple_grid",
                    ),
                )
        except ImportError:
            log.debug(
                "Please `pip install tabulate` in order to display ddp bucket sizes and diagnostic information."
            )
    else:
        log.debug("DDPOptimizer captured no parameters and did not split this graph.")


def has_higher_order_op(gm):
    # Check if there is a higher order op in the graph
    for node in gm.graph.nodes:
        if node.op == "get_attr":
            maybe_param = getattr(gm, node.target)
            if isinstance(maybe_param, torch.fx.GraphModule):
                return True
    return False


# compile each of the partitioned submodules using the user-provided compiler
class SubmodCompiler(torch.fx.interpreter.Interpreter):
    def __init__(self, module, compiler, fake_mode) -> None:
        super().__init__(module)
        self.compiler = compiler
        self.fake_mode = fake_mode

    def compile_submod(self, input_mod, args, kwargs):
        """
        Compile the submodule,
        using a wrapper to make sure its output is always a tuple,
        which is required by AotAutograd-based compilers
        """
        assert len(kwargs) == 0, "We assume only args for these modules"

        class WrapperModule(torch.nn.Module):
            def __init__(self, submod, unwrap_singleton_tuple) -> None:
                super().__init__()
                self.submod = submod
                self.unwrap_singleton_tuple = unwrap_singleton_tuple

            def forward(self, *args):
                x = self.submod(*args)
                # TODO(whc)
                # for some reason the isinstance check is necessary if I split one node per submod
                # - even though I supposedly wrapped the output in a tuple in those cases, the real
                # compiled module was still returning a tensor
                if self.unwrap_singleton_tuple and isinstance(x, (tuple, list)):
                    return x[0]
                return x

        unwrap_singleton_tuple = False
        for sn in input_mod.graph.nodes:
            if sn.op == "output":
                if not isinstance(sn.args[0], tuple):
                    unwrap_singleton_tuple = True
                    sn.args = (sn.args,)

        input_mod.recompile()
        input_mod.compile_subgraph_reason = GraphCompileReason(
            "DDPOptimizer intentional graph-break (See Note [DDPOptimizer])."
            " Set `torch._dynamo.config.optimize_ddp = False` to disable.",
            [
                # it's close to useless to get a real stacktrace here, and quite verbose.
                traceback.FrameSummary(__file__, 0, DDPOptimizer),
            ],
        )

        wrapper = WrapperModule(
            self.compiler(input_mod, args),
            unwrap_singleton_tuple,
        )
        return wrapper

    # Note:
    #
    # The way distributed works today around fake tensors can be somewhat confusing.
    # Some of these codepaths are shared between runtime and compile time. The presence
    # of a fake_mode, read off of fake tensor inputs, dictates how we will operate.
    #
    # A few things to keep in mind:
    #
    # 1) We invoke `compile_submod` with a real module. The output of that gets stored
    #    on the graph via `self.module.add_submodule(n.target, compiled_submod_real)`.
    #
    # 2) When running a call_module targeted node, if we have a fake_mode, we fakify the
    #    module we got from self.fetch_attr(n.target). Regardless of fake_mode, we then
    #    execute it.
    #
    # 3) Fake tensors should always be around during compile time.
    #
    # 4) Fake tensors should never be around at runtime.
    #
    # 5) We end up with a compilation mode that takes a real submodule and fake tensors,
    #    to match what aot_autograd expects (sketched below). See Note: [Fake Modules and AOTAutograd]
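    #
    # A rough sketch of invariant (5), using illustrative names that are not part of this
    # file (the real logic lives in run_node below):
    #
    #     fake_mode = detect_fake_mode(example_inputs)
    #     fake_args = [
    #         torch._dynamo.utils.to_fake_tensor(a, fake_mode)
    #         if isinstance(a, torch.Tensor)
    #         and not isinstance(a, torch._subclasses.FakeTensor)
    #         else a
    #         for a in example_inputs
    #     ]
    #     # real (uncompiled) submodule in, fake example inputs in, compiled wrapper out
    #     compiled_submod = backend_compile_fn(real_submod, fake_args)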
    def run_node(self, n: Node) -> Any:
        args, kwargs = self.fetch_args_kwargs_from_env(n)
        new_args = []
        assert self.fake_mode
        for arg in args:
            if isinstance(arg, torch.Tensor) and not isinstance(
                arg, torch._subclasses.FakeTensor
            ):
                new_args.append(torch._dynamo.utils.to_fake_tensor(arg, self.fake_mode))
            else:
                new_args.append(arg)

        log.debug("run_node %s, %s got args %s", n.op, n.target, args_str(args))
        assert isinstance(args, tuple)
        assert isinstance(kwargs, dict)

        if n.op == "call_module":
            real_mod = self.fetch_attr(n.target)
            if self.fake_mode:
                curr_submod = deepcopy_to_fake_tensor(real_mod, self.fake_mode)
            else:
                curr_submod = real_mod

            ddp_graph_log.debug("\n---%s graph---\n%s", n.target, curr_submod.graph)

            # When calling the compiler on the submod, inputs (new_args) are expected to
            # be FakeTensors already since Dynamo would have made them FakeTensors in the
            # non-DDP flow. However, the parameters are _not_ expected to be FakeTensors,
            # since this wrapping happens during compilation.

            # Note: Returning Fake Tensors on First AOT Autograd Call
            #
            # Inductor will optimize strides of outputs when it deems it profitable.
            # For instance, converting to channels last. When we split the graph here
            # into multiple inductor compilations, we need to make sure that the
            # output strides of one compilation are appropriately passed to the subsequent
            # compilations. However, the mapping from inductor output to dynamo output
            # is non-trivial due to aot_autograd's deduping, de-aliasing, mutation, re-writing,
            # subclass handling, etc. In order to replay all this logic we set a flag such that
            # the first invocation of inductor in aot_autograd will return Fake Tensors with
            # appropriate strides. Then, all of aot_autograd's runtime logic is replayed.
            # This gives us the appropriately strided outputs here, which will reflect runtime strides.

            class FakeifyFirstAOTInvocationGuard:
                def __init__(self) -> None:
                    self.tc = torch._guards.TracingContext.try_get()
                    assert self.tc
                    torch._guards.TracingContext.try_get().fakify_first_call = True

                def __del__(self) -> None:
                    self.tc.fakify_first_call = False

            # For aot_eager and other backends, tracing context is not set
            has_tracing_context = torch._guards.TracingContext.try_get() is not None
            if has_tracing_context:
                g = FakeifyFirstAOTInvocationGuard()

            from torch._dynamo.utils import counters

            init = counters["aot_autograd"]["total"]
            compiled_submod_real = self.compile_submod(real_mod, new_args, kwargs)

            # TODO - better way of doing this?
            # Only aot_autograd handles fakifying the first call
            invoked_aot_autograd = init != counters["aot_autograd"]["total"]

            # We update the original (outer) graph with a call into the compiled module
            # instead of the uncompiled one.
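            # e.g. a call_module node that targeted "submod_0" (the name assigned by
            # fx.passes.split_module) is retargeted to "compiled_submod_0" below.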
            self.module.delete_submodule(n.target)
            n.target = "compiled_" + n.target
            self.module.add_submodule(n.target, compiled_submod_real)

            # Finally, we have to produce inputs for use in compiling the next submodule,
            # and these need to be FakeTensors, so we execute the module under fake_mode.
            # Because parameters are not fake, we patch fake tensor mode to allow non-fake inputs.
            with self.fake_mode, mock.patch.object(
                self.fake_mode, "allow_non_fake_inputs", True
            ):
                if has_tracing_context and invoked_aot_autograd:
                    out = compiled_submod_real(*new_args, **kwargs)
                    # output should be fake or subclass
                    assert all(
                        (not isinstance(t, torch.Tensor) or type(t) is not torch.Tensor)
                        for t in (out if isinstance(out, (list, tuple)) else [out])
                    )
                    return out
                else:
                    return curr_submod(*new_args, **kwargs)
        else:
            # placeholder or output nodes don't need to get compiled, just executed
            return getattr(self, n.op)(n.target, new_args, kwargs)


class DDPOptimizer:
    """Note [DDPOptimizer]
    DDPOptimizer applies when dynamo compiles models wrapped in DistributedDataParallel (DDP),
    breaking the dynamo graph into chunks to compile separately, with the breaks aligning to
    the boundaries of gradient-allreduce buckets chosen by DDP.

    Background/Motivation
     - DDP uses allreduce collectives to synchronize partial gradients computed on different workers
     - DDP groups gradient allreduces into 'buckets' to optimize communication efficiency of all-reduce
     - Parameters grouped into buckets are assumed to be adjacent in time, so they become ready
       at around the same time during backward and thus can share the same allreduce efficiently
     - Allreduces must overlap with backward compute for optimal training performance
     - DDP schedules allreduces using 'hooks' fired from the c++ autograd engine in pytorch, which
       fire as individual grads become 'ready'
     - Dynamo+AOTAutograd produces a single fused graph that runs 'atomically' from the perspective of the
       autograd engine, such that all gradients become 'ready' at the same time.  Hooks fire after the whole
       fused backward function executes, preventing any overlap of compute and communication

    Algorithm
     - DDPOptimizer starts off with an FX graph traced by dynamo which represents forward. It can traverse
       this graph in reverse order to determine the true order that gradients will become ready during backward.
     - Parameter sizes are counted in reverse order, up to a bucket size limit, at which point a new bucket is
       started and a graph break introduced (see the worked example below)
     - Each of the subgraphs is compiled by the compiler provided to dynamo by the user, and then fused back
       together into an outer module that is returned to the user
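
    Worked example (sizes are made up for illustration; this sketch assumes first_bucket_cap
    equals bucket_bytes_cap and that every full bucket already has an external output):
    with bucket_bytes_cap=100 and parameters encountered in reverse graph order with sizes
    40, 40, 40, 40 bytes, the first (output-side) bucket accumulates 40+40+40=120 bytes,
    because the cap is only checked before each node is added and can therefore be overshot
    by up to one node's worth of parameters; the remaining 40 bytes form a second bucket,
    and one graph break is introduced between the two subgraphs.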

    Notes
     - It would be better to enforce (by adding an API to DDP) that the bucket splits chosen here are used by DDP,
       and that DDP does not need to detect or optimize bucket order by observing execution at runtime, as it does
       in eager.
     - If Dynamo can't capture a whole graph for the portion of the model wrapped by DDP, this algorithm will
       currently produce splits that do not necessarily align with the buckets used by DDP. This should result in
       performance degradation approaching the baseline case where graph-splits are not used, but not worse.
     - If the backend compiler fails to compile a single subgraph, that subgraph will execute eagerly even though
       the rest of the subgraphs are compiled.
     - DDP has a 'parameters_and_buffers_to_ignore' field, which DDPOptimizer attempts to honor by reading markers
       left by DDP on individual parameters. In cases where other transformations, such as reparameterization, are
       also used, the ignore markers could be lost. If DDPOptimizer fails to ignore a parameter ignored by DDP,
       it is not catastrophic but could impact performance by choosing sub-optimal bucket splits.
     - DDPOptimizer always ignores all buffers, regardless of their ignore flag, since buffers do not require
       gradients and therefore aren't allreduced by DDP. (They are broadcast during forward, but this is not
       covered by DDPOptimizer.)

    Debugging
     - Generally, it is easiest to debug DDPOptimizer in a single-process program, using pdb.
     - In many cases, the log messages are helpful (they show bucket size assignments);
       just set the TORCH_LOGS env var to include any of 'dynamo', 'distributed', or 'dist_ddp'.
     - See `benchmarks/dynamo/distributed.py` for a simple harness that will run a toy model or a torchbench model
       in a single process (or with torchrun, in multiple processes)

    Args:
        bucket_bytes_cap (int): Controls the size of buckets, in bytes, used to determine graphbreaks.  Should be
            set to match the equivalent parameter on the original DDP module.

        backend_compile_fn (callable): A dynamo compiler function, to be invoked to compile each subgraph.

        first_bucket_cap (int): Controls the size of the first bucket.  Should match DDP's first bucket cap.  DDP
            special-cases the first bucket size since it is sometimes optimal to start a small allreduce early.
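
    Example (an illustrative sketch; the exact wrapping order and backend are up to the user):
    DDPOptimizer is applied automatically when `torch._dynamo.config.optimize_ddp` is enabled
    (the default) and dynamo compiles a module wrapped in DDP, e.g.

        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])
        model = torch.compile(model)

    To inspect the resulting bucket and graph splits, set TORCH_LOGS as described under
    Debugging above (the 'ddp_graphs' artifact dumps the split graphs).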

    """

    def __init__(
        self,
        bucket_bytes_cap: int,
        backend_compile_fn,
        first_bucket_cap: Optional[int] = None,
    ) -> None:
        if first_bucket_cap is not None:
            self.first_bucket_cap = first_bucket_cap
        elif torch.distributed.is_available():
            # this constant comes from C10D lib which is not always built
            self.first_bucket_cap = torch.distributed._DEFAULT_FIRST_BUCKET_BYTES
        else:
            self.first_bucket_cap = bucket_bytes_cap

        self.bucket_bytes_cap = bucket_bytes_cap
        assert (
            self.first_bucket_cap <= self.bucket_bytes_cap
        ), "First bucket should be smaller/equal to other buckets to get comms warmed up ASAP"

        self.backend_compile_fn = backend_compile_fn

    def _ignore_parameter(self, parameter):
        return hasattr(parameter, "_ddp_ignored") and parameter._ddp_ignored

    def add_param(self, bucket, param, name):
        bucket.size += param.untyped_storage().nbytes()
        bucket.params.append(name)
        bucket.param_ids.append(id(param))

    def add_module_params_to_bucket(self, mod, bucket, processed_modules, prefix):
        processed_modules.add(mod)
        for name, param in mod.named_parameters():
            if param.requires_grad and not self._ignore_parameter(param):
                self.add_param(bucket, param, f"{prefix}_{name}")

    def add_param_args(self, bucket, node):
        for arg in node.args:
            if not isinstance(arg, torch.fx.node.Node):
                continue
            if arg.op != "placeholder":
                continue
            param = arg.meta["example_value"]
            if (
                isinstance(param, torch.nn.Parameter)
                and param.requires_grad
                and not self._ignore_parameter(param)
            ):
                self.add_param(bucket, param, arg.target)

    def compile_fn(self, gm: fx.GraphModule, example_inputs: List[torch.Tensor]):
        """
        Implements graph splitting, first determining a set of buckets by counting
        parameter sizes in reverse graph order, then invoking the user/backend compiler
        to compile each subgraph.  Finally, stitches the compiled graphs back into one
        GraphModule and returns its callable.
        """
        if has_higher_order_op(gm):
            # This indicates presence of a higher order op. For now, we
            # have no way to break the higher order op into two buckets.
            # Allowing higher order ops in the graph also requires
            # changes in the split_module, because the graph splitter
            # currently assumes that all the args of all ops are
            # tensors, but in the case of higher order ops, it could be
            # a graph module. As a workaround, we short-circuit and raise here.
            raise NotImplementedError(
                "DDPOptimizer backend: Found a higher order op in the graph. "
                "This is not supported. Please turn off DDP optimizer using "
                "torch._dynamo.config.optimize_ddp=False. Note that this can "
                "cause performance degradation because there will be one bucket "
                "for the entire Dynamo graph. Please refer to this issue - "
                "https://github.com/pytorch/pytorch/issues/104674."
            )

        # 1: compute the partition map according to DDP bucket logic
        buckets = [Bucket()]  # (size, param_names)
        processed_modules = set()
        for node in reversed(gm.graph.nodes):
            if node.op in ("output", "placeholder"):
                continue

            # note: `and` binds tighter than `or`, so the first_bucket_cap clause only
            # applies while we are still filling the very first bucket
            if (
                buckets[0].size >= self.bucket_bytes_cap
                or len(buckets) == 1
                and buckets[0].size >= self.first_bucket_cap
            ):
                if bucket_has_external_output(buckets[0]):
                    buckets.insert(0, Bucket())
                else:
                    # continue building this bucket past the point of filling its parameter capacity,
                    # to increase chances it contains at least one node that is either a global output or
                    # passed as input to a subsequent graph

                    if buckets[0].opcount_increased_to_capture_external_output == 0:
                        buckets[0].paramsize_before_opcount_increase = buckets[0].size
                    buckets[0].opcount_increased_to_capture_external_output += 1

            if node.op == "call_function":
                self.add_param_args(buckets[0], node)

            elif node.op == "call_module":
                target_mod = gm.get_submodule(node.target)
                if target_mod not in processed_modules:
                    self.add_module_params_to_bucket(
                        target_mod, buckets[0], processed_modules, node.target
                    )
            elif node.op == "call_method":
                if isinstance(node.args[0].target, str):
                    target_mod = None
                    try:
                        target_mod = gm.get_submodule(node.args[0].target)
                    except AttributeError:
                        pass
                    if target_mod is not None and target_mod not in processed_modules:
                        self.add_module_params_to_bucket(
                            target_mod, buckets[0], processed_modules, node.target
                        )
                # This handles situations like  tmp = torch.mm(x, self.weight.t())
                # t: "f32[512, 512]" = l_self_seq_2_weight.t();  l_self_seq_2_weight = None
                # tmp: "f32[512, 512]" = torch.mm(input_2, t);  input_2 = t = None
                self.add_param_args(buckets[0], node)

            elif node.op == "get_attr":
                maybe_param = getattr(gm, node.target)
                if (
                    isinstance(maybe_param, torch.nn.Parameter)
                    and maybe_param.requires_grad
                    and not self._ignore_parameter(maybe_param)
                ):
                    self.add_param(buckets[0], maybe_param, node.target)

            # All nodes have to be mapped to a bucket, even if they don't have their own params.
            # Ignored params still end up in buckets; we just don't count them towards the capacity.
            buckets[0].nodes.append(node)

        if len(buckets) > 1 and buckets[0].size == 0:
            # we collected a small preamble graph with ops that don't include parameters; fuse it back
            buckets[1].nodes.extend(buckets[0].nodes)
            assert len(buckets[0].params) == 0, "Params should be empty if size is 0"
            del buckets[0]

        # stash buckets for testing/debugging purposes
        self.buckets = buckets
        pretty_print_buckets(buckets, self.bucket_bytes_cap)

        if len(buckets) == 1:
            # bypass split/fuse logic if there is only one bucket
            return self.backend_compile_fn(gm, example_inputs)

        # 2: partition the graphmodule according to bucket capacity
        partition_map = {}
        for idx, b in enumerate(buckets):
            for node in b.nodes:
                partition_map[node] = idx

        split_gm = fx.passes.split_module.split_module(
            gm, None, lambda node: partition_map[node]
        )
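
        # For illustration: split_gm's top-level graph now simply calls each partition in
        # order (fx.passes.split_module names them "submod_0", "submod_1", ...), and each
        # partition is itself a GraphModule stored as an attribute on split_gm.
        # SubmodCompiler below replaces each one with its compiled counterpart.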

        debug_str = (
            f"\n---orig graph---\n{gm.graph}\n"
            + f"\n---split graph---\n{split_gm.graph}\n"
        )
        for name, module in split_gm.named_modules():
            if "." not in name and len(name):
                # only print the submod graphs, not their children
                debug_str += f"\n---{name} graph---\n{module.graph}\n"
        debug_str += "\n---------------\n"
        ddp_graph_log.debug(debug_str)

        trace_structured(
            "optimize_ddp_split_graph",
            payload_fn=lambda: split_gm.print_readable(print_output=False),
        )
        for name, module in split_gm.named_modules():
            if "." not in name and len(name):
                trace_structured(
                    "optimize_ddp_split_child",
                    lambda: {"name": name},
                    payload_fn=lambda: module.print_readable(print_output=False),
                )

        fake_mode = detect_fake_mode(example_inputs)
        if fake_mode is None:
            fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()

        submod_compiler = SubmodCompiler(split_gm, self.backend_compile_fn, fake_mode)
        submod_compiler.run(*example_inputs)
        split_gm.recompile()

        ddp_graph_log.debug(
            "\n---final graph---\n%s\n---------------\n", split_gm.graph
        )
        return split_gm