# mypy: ignore-errors

import logging
import traceback
from dataclasses import dataclass, field
from typing import Any, List, Optional
from unittest import mock

import torch
from torch import fx
from torch._dynamo.output_graph import GraphCompileReason
from torch._dynamo.utils import deepcopy_to_fake_tensor, detect_fake_mode
from torch._logging import trace_structured
from torch.fx.node import Node


# Regular log messages should go through 'log'.
# ddp_graph_log is a separate artifact logger reserved for dumping graphs.
# See docs/source/logging.rst for more info.
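#
# A minimal sketch for enabling the graph dumps emitted through the ddp_graph_log
# artifact logger defined below: set the TORCH_LOGS environment variable to
# include "ddp_graphs", or, assuming the public torch._logging.set_logs API
# accepts the ddp_graphs artifact flag, enable it programmatically:
#
#   import torch
#   torch._logging.set_logs(ddp_graphs=True)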
log = logging.getLogger(__name__)
ddp_graph_log = torch._logging.getArtifactLogger(__name__, "ddp_graphs")


def args_str(args):
    # a debug helper
    if torch.is_tensor(args):
        return f"T[{args.shape}]"
    elif isinstance(args, tuple):
        return f"tuple({', '.join([args_str(x) for x in args])})"
    elif isinstance(args, list):
        return f"list({', '.join([args_str(x) for x in args])})"
    else:
        return str(args)


@dataclass
class Bucket:
    size: int = 0
    params: List[str] = field(default_factory=list)
    nodes: List[fx.Node] = field(default_factory=list)

    # param_ids is just used for unit testing
    param_ids: List = field(default_factory=list)

    # keep track of any buckets that were extended for logging purposes
    opcount_increased_to_capture_external_output: int = 0
    paramsize_before_opcount_increase: int = 0


def bucket_has_external_output(bucket: Bucket) -> bool:
    nodes_in_bucket = set()
    # we want to iterate in reverse order, and conveniently the bucket.nodes list was already created backwards,
    # so we don't reverse it here
    for node in bucket.nodes:
        # assume node.op != output, since those are filtered in the original iteration
        nodes_in_bucket.add(node)
        for user in node.users:
            if user not in nodes_in_bucket:
                return True
    return False
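

# A minimal sketch of what "external output" means above (illustrative only; the
# toy module and names below are assumptions, not part of this file): trace a
# two-op module, put only the first op's node into a bucket, and the bucket has
# an external output because that node is consumed by a node outside the bucket.
#
#   class _Toy(torch.nn.Module):
#       def forward(self, x):
#           y = x + 1
#           return y * 2
#
#   g = torch.fx.symbolic_trace(_Toy())
#   add_node = next(n for n in g.graph.nodes if n.op == "call_function")
#   assert bucket_has_external_output(Bucket(nodes=[add_node]))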


def pretty_print_buckets(buckets: List[Bucket], bucket_bytes_cap: int):
    headers = ("Index", "Size (b)", "Param Names")
    rows = []
    extended_buckets = []
    for idx, bucket in enumerate(reversed(buckets)):
        if len(bucket.params) > 0:
            rows.append((idx, bucket.size, bucket.params[0]))
            for param in bucket.params[1:]:
                rows.append((None, None, param))
        if bucket.opcount_increased_to_capture_external_output > 0:
            extended_buckets.append(
                (
                    idx,
                    bucket.opcount_increased_to_capture_external_output,
                    bucket.size - bucket.paramsize_before_opcount_increase,
                )
            )

    if len(rows):
        log.info(
            "\nDDPOptimizer used bucket cap %s and created %d buckets. Enable debug logs for detailed bucket info.",
            bucket_bytes_cap,
            len(buckets),
        )

        if len(extended_buckets):
            log.warning(
                "Some buckets were extended beyond their requested parameter capacities"
                " in order to ensure each subgraph has an output node, required for fx graph partitioning."
                " This can be the case when a subgraph would have only contained nodes performing inplace mutation,"
                " and returning no logical outputs. This should not be a problem, unless it results in too few graph"
                " partitions for optimal DDP performance."
            )

        try:
            from tabulate import tabulate

            log.debug(
                "\nDDPOptimizer produced the following bucket assignments:\n%s",
                tabulate(rows, headers=headers, tablefmt="simple_grid"),
            )

            if len(extended_buckets):
                log.warning(
                    "DDPOptimizer extended these buckets to ensure per-subgraph output nodes:\n%s",
                    tabulate(
                        extended_buckets,
                        headers=("Index", "Extra Ops", "Extra Param Size (b)"),
                        tablefmt="simple_grid",
                    ),
                )
        except ImportError:
            log.debug(
                "Please `pip install tabulate` in order to display ddp bucket sizes and diagnostic information."
            )
    else:
        log.debug("DDPOptimizer captured no parameters and did not split this graph.")


def has_higher_order_op(gm):
    # Check if there is a higher order op in the graph
    for node in gm.graph.nodes:
        if node.op == "get_attr":
            maybe_param = getattr(gm, node.target)
            if isinstance(maybe_param, torch.fx.GraphModule):
                return True
    return False


# compile each of the partitioned submodules using the user-provided compiler
class SubmodCompiler(torch.fx.interpreter.Interpreter):
    def __init__(self, module, compiler, fake_mode) -> None:
        super().__init__(module)
        self.compiler = compiler
        self.fake_mode = fake_mode

    def compile_submod(self, input_mod, args, kwargs):
        """
        Compile the submodule,
        using a wrapper to make sure its output is always a tuple,
        which is required by AotAutograd based compilers
        """
        assert len(kwargs) == 0, "We assume only args for these modules"

        class WrapperModule(torch.nn.Module):
            def __init__(self, submod, unwrap_singleton_tuple) -> None:
                super().__init__()
                self.submod = submod
                self.unwrap_singleton_tuple = unwrap_singleton_tuple

            def forward(self, *args):
                x = self.submod(*args)
                # TODO(whc)
                # for some reason the isinstance check is necessary if I split one node per submod
                # - even though I supposedly wrapped the output in a tuple in those cases, the real
                # compiled module was still returning a tensor
                if self.unwrap_singleton_tuple and isinstance(x, (tuple, list)):
                    return x[0]
                return x

        unwrap_singleton_tuple = False
        for sn in input_mod.graph.nodes:
            if sn.op == "output":
                if not isinstance(sn.args[0], tuple):
                    unwrap_singleton_tuple = True
                    sn.args = (sn.args,)

        input_mod.recompile()
        input_mod.compile_subgraph_reason = GraphCompileReason(
            "DDPOptimizer intentional graph-break (See Note [DDPOptimizer])."
            " Set `torch._dynamo.config.optimize_ddp = False` to disable.",
            [
                # it's close to useless to get a real stacktrace here, and quite verbose.
                traceback.FrameSummary(__file__, 0, DDPOptimizer),
            ],
        )

        wrapper = WrapperModule(
            self.compiler(input_mod, args),
            unwrap_singleton_tuple,
        )
        return wrapper

    # Note:
    #
    # The way distributed works today around fake tensors can be somewhat confusing.
    # Some of these codepaths are shared between runtime and compile time. The presence
    # of a fake_mode, read off of fake tensor inputs, dictates how we will operate.
    #
    # A few things to keep in mind:
    #
    # 1) We invoke `compile_submod` with a real module. The output of that gets stored
    # on the graph via `self.module.add_submodule(n.target, compiled_submod_real)`.
    #
    # 2) When running a call_module targeted node, if we have a fake_mode, we fakify the
    # module we got from self.fetch_attr(n.target). Regardless of fake_mode, we then execute it.
    #
    # 3) Fake tensors should always be around during compile time.
    #
    # 4) Fake tensors should never be around at runtime.
    #
    # 5) We end up with a compilation mode that takes a real submodule and fake tensors,
    # to match what aot_autograd expects. See Note: [Fake Modules and AOTAutograd]
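    #
    # A minimal sketch of the fakification step performed in `run_node` below
    # (illustrative only; the standalone FakeTensorMode and the example tensor
    # are assumptions, not how this interpreter is normally driven):
    #
    #   fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()
    #   fake_t = torch._dynamo.utils.to_fake_tensor(torch.randn(4, 4), fake_mode)
    #   assert isinstance(fake_t, torch._subclasses.FakeTensor)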
    def run_node(self, n: Node) -> Any:
        args, kwargs = self.fetch_args_kwargs_from_env(n)
        new_args = []
        assert self.fake_mode
        for arg in args:
            if isinstance(arg, torch.Tensor) and not isinstance(
                arg, torch._subclasses.FakeTensor
            ):
                new_args.append(torch._dynamo.utils.to_fake_tensor(arg, self.fake_mode))
            else:
                new_args.append(arg)

        log.debug("run_node %s, %s got args %s", n.op, n.target, args_str(args))
        assert isinstance(args, tuple)
        assert isinstance(kwargs, dict)

        if n.op == "call_module":
            real_mod = self.fetch_attr(n.target)
            if self.fake_mode:
                curr_submod = deepcopy_to_fake_tensor(real_mod, self.fake_mode)
            else:
                curr_submod = real_mod

            ddp_graph_log.debug("\n---%s graph---\n%s", n.target, curr_submod.graph)

            # When calling the compiler on the submod, inputs (new_args) are expected to
            # be FakeTensors already, since Dynamo would have made them FakeTensors in the
            # non-DDP flow.  However, the parameters are _not_ expected to be FakeTensors,
            # since this wrapping happens during compilation.

            # Note: Returning Fake Tensors on First AOT Autograd Call
            #
            # Inductor will optimize strides of outputs when it deems it profitable.
            # For instance, converting to channels last. When we split the graph here
            # into multiple inductor compilations, we need to make sure that the
            # output strides of one compilation are appropriately passed to the subsequent
            # compilations. However, the mapping from inductor output to dynamo output
            # is non-trivial due to aot_autograd's deduping, de-aliasing, mutation, re-writing,
            # subclass handling, etc. In order to replay all this logic we set a flag such that
            # the first invocation of inductor in aot_autograd will return Fake Tensors with
            # appropriate strides. Then, all of aot autograd's runtime logic is replayed.
            # This gives us the appropriately strided outputs here which will reflect runtime strides.

            class FakeifyFirstAOTInvocationGuard:
                def __init__(self) -> None:
                    self.tc = torch._guards.TracingContext.try_get()
                    assert self.tc
                    torch._guards.TracingContext.try_get().fakify_first_call = True

                def __del__(self) -> None:
                    self.tc.fakify_first_call = False

            # For aot_eager and other backends, tracing context is not set
            has_tracing_context = torch._guards.TracingContext.try_get() is not None
            if has_tracing_context:
                g = FakeifyFirstAOTInvocationGuard()

            from torch._dynamo.utils import counters

            init = counters["aot_autograd"]["total"]
            compiled_submod_real = self.compile_submod(real_mod, new_args, kwargs)

            # TODO - better way of doing this?
            # Only aot autograd handles fakifying first call
            invoked_aot_autograd = init != counters["aot_autograd"]["total"]

            # We update the original (outer) graph with a call into the compiled module
            # instead of the uncompiled one.
            self.module.delete_submodule(n.target)
            n.target = "compiled_" + n.target
            self.module.add_submodule(n.target, compiled_submod_real)

            # Finally, we have to produce inputs for use in compiling the next submodule,
            # and these need to be FakeTensors, so we execute the module under fake_mode.
            # Because parameters are not fake, we patch the fake tensor mode to allow non-fake inputs.
            with self.fake_mode, mock.patch.object(
                self.fake_mode, "allow_non_fake_inputs", True
            ):
                if has_tracing_context and invoked_aot_autograd:
                    out = compiled_submod_real(*new_args, **kwargs)
                    # output should be fake or subclass
                    assert all(
                        (not isinstance(t, torch.Tensor) or type(t) is not torch.Tensor)
                        for t in (out if isinstance(out, (list, tuple)) else [out])
                    )
                    return out
                else:
                    return curr_submod(*new_args, **kwargs)
        else:
            # placeholder or output nodes don't need to get compiled, just executed
            return getattr(self, n.op)(n.target, new_args, kwargs)


class DDPOptimizer:
    """Note [DDPOptimizer]
    DDPOptimizer applies when dynamo compiles models wrapped in DistributedDataParallel (DDP),
    breaking the dynamo graph into chunks to compile separately, with the breaks aligning to
    the boundaries of gradient-allreduce buckets chosen by DDP.

    Background/Motivation
     - DDP uses allreduce collectives to synchronize partial gradients computed on different workers
     - DDP groups gradient allreduces into 'buckets' to optimize communication efficiency of all-reduce
     - Parameters grouped into buckets are assumed to be adjacent in time, so they become ready
       at around the same time during backward and thus can share the same allreduce efficiently
     - Allreduces must overlap with backward compute for optimal training performance
     - DDP schedules allreduces using 'hooks' fired from the c++ autograd engine in pytorch, which
       operates when individual grads become 'ready'
     - Dynamo+AOTAutograd produces a single fused graph that runs 'atomically' from the perspective of the
       autograd engine, such that all gradients become 'ready' at the same time.  Hooks fire after the whole
       fused backward function executes, preventing any overlap of compute and communication

    Algorithm
     - DDPOptimizer starts off with an FX graph traced by dynamo which represents forward.  It can traverse
       this graph in reverse order to determine the true order that gradients will become ready during backward.
     - Parameter sizes are counted in reverse order, up to a bucket size limit, at which point a new bucket is started
       and a graph break introduced
     - Each of the subgraphs is compiled by the compiler provided to dynamo by the user, and then fused back together
       into an outer module that is returned to the user

    Notes
     - It would be better to enforce (by adding an API to DDP) that the bucket splits chosen here are used by DDP,
       and that DDP does not need to detect or optimize bucket order by observing execution at runtime, as it does
       in eager.
     - If Dynamo can't capture a whole graph for the portion of the model wrapped by DDP, this algorithm will currently
       produce splits that do not necessarily align with the buckets used by DDP.  This should result in performance
       degradation approaching the baseline case where graph-splits are not used, but not worse.
     - If the backend compiler fails to compile a single subgraph, it will execute eagerly despite the rest of the
       subgraphs being compiled
     - DDP has a 'parameters_and_buffers_to_ignore' field, which DDPOptimizer attempts to honor by reading markers
       left by DDP on individual parameters.  In cases where other transformations, such as reparameterization, are
       also used, the ignore markers could be lost.  If DDPOptimizer fails to ignore a parameter ignored by DDP,
       it is not catastrophic but could impact performance by choosing sub-optimal bucket splits.
     - DDPOptimizer always ignores all buffers, regardless of their ignore flag, since buffers do not require gradients,
       and therefore aren't allreduced by DDP.  (They are broadcast during forward, but this is not covered by
       DDPOptimizer)

    Debugging
     - Generally, it is easiest to debug DDPOptimizer in a single process program, using pdb.
     - In many cases, the log messages are helpful (they show bucket size assignments);
       just set the TORCH_LOGS env var to include any of 'dynamo', 'distributed', or 'dist_ddp'.
     - See `benchmarks/dynamo/distributed.py` for a simple harness that will run a toy model or a torchbench model
       in a single process (or with torchrun, in multiple processes)

    Args:
        bucket_bytes_cap (int): Controls the size of buckets, in bytes, used to determine graphbreaks.  Should be
            set to match the equivalent parameter on the original DDP module.

        backend_compile_fn (callable): A dynamo compiler function, to be invoked to compile each subgraph.

        first_bucket_cap (int): Controls the size of the first bucket.  Should match DDP's first bucket cap.  DDP
            special-cases the first bucket size since it is sometimes optimal to start a small allreduce early.

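    Example:
        A minimal single-process sketch; the toy model, process-group setup, and
        bucket cap below are illustrative assumptions, not part of this module.
        When torch._dynamo.config.optimize_ddp is enabled (the default) and the
        compiled module is wrapped in DDP, dynamo routes graph splitting through
        DDPOptimizer::

            import os

            import torch
            import torch.distributed as dist
            from torch.nn.parallel import DistributedDataParallel as DDP

            os.environ.setdefault("MASTER_ADDR", "localhost")
            os.environ.setdefault("MASTER_PORT", "29500")
            dist.init_process_group("gloo", rank=0, world_size=1)

            model = torch.nn.Sequential(
                torch.nn.Linear(512, 512), torch.nn.ReLU(), torch.nn.Linear(512, 512)
            )
            ddp_model = DDP(model, bucket_cap_mb=25)

            compiled = torch.compile(ddp_model)
            out = compiled(torch.randn(8, 512))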
    """

    def __init__(
        self,
        bucket_bytes_cap: int,
        backend_compile_fn,
        first_bucket_cap: Optional[int] = None,
    ) -> None:
        if first_bucket_cap is not None:
            self.first_bucket_cap = first_bucket_cap
        elif torch.distributed.is_available():
            # this constant comes from C10D lib which is not always built
            self.first_bucket_cap = torch.distributed._DEFAULT_FIRST_BUCKET_BYTES
        else:
            self.first_bucket_cap = bucket_bytes_cap

        self.bucket_bytes_cap = bucket_bytes_cap
        assert (
            self.first_bucket_cap <= self.bucket_bytes_cap
        ), "First bucket should be smaller/equal to other buckets to get comms warmed up ASAP"

        self.backend_compile_fn = backend_compile_fn

    def _ignore_parameter(self, parameter):
        return hasattr(parameter, "_ddp_ignored") and parameter._ddp_ignored

    def add_param(self, bucket, param, name):
        bucket.size += param.untyped_storage().nbytes()
        bucket.params.append(name)
        bucket.param_ids.append(id(param))

    def add_module_params_to_bucket(self, mod, bucket, processed_modules, prefix):
        processed_modules.add(mod)
        for name, param in mod.named_parameters():
            if param.requires_grad and not self._ignore_parameter(param):
                self.add_param(bucket, param, f"{prefix}_{name}")

    def add_param_args(self, bucket, node):
        for arg in node.args:
            if not isinstance(arg, torch.fx.node.Node):
                continue
            if arg.op != "placeholder":
                continue
            param = arg.meta["example_value"]
            if (
                isinstance(param, torch.nn.Parameter)
                and param.requires_grad
                and not self._ignore_parameter(param)
            ):
                self.add_param(bucket, param, arg.target)

    def compile_fn(self, gm: fx.GraphModule, example_inputs: List[torch.Tensor]):
        """
        Implements graph splitting, first determining a set of buckets by counting
        parameter sizes in reverse graph order, then invoking the user/backend compiler
        to compile each subgraph. Finally, stitches the compiled graphs into one graphmodule
        and returns its callable.
        """
        if has_higher_order_op(gm):
            # This indicates the presence of a higher order op. For now, we
            # have no way to break the higher order op into two buckets.
            # Allowing higher order ops in the graph also requires
            # changes in split_module, because the graph splitter
            # currently assumes that all the args of all ops are
            # tensors, but in the case of higher order ops, an arg could be
            # a graph module. As a workaround, we short-circuit here and raise.
            raise NotImplementedError(
                "DDPOptimizer backend: Found a higher order op in the graph. "
                "This is not supported. Please turn off DDP optimizer using "
                "torch._dynamo.config.optimize_ddp=False. Note that this can "
                "cause performance degradation because there will be one bucket "
                "for the entire Dynamo graph. Please refer to this issue - "
                "https://github.com/pytorch/pytorch/issues/104674."
            )

        # 1: compute the partition map according to DDP bucket logic
        buckets = [Bucket()]  # (size, param_names)
        processed_modules = set()
        for node in reversed(gm.graph.nodes):
            if node.op in ("output", "placeholder"):
                continue

            if (
                buckets[0].size >= self.bucket_bytes_cap
                or len(buckets) == 1
                and buckets[0].size >= self.first_bucket_cap
            ):
                if bucket_has_external_output(buckets[0]):
                    buckets.insert(0, Bucket())
                else:
                    # continue building this bucket past the point of filling its parameter capacity,
                    # to increase chances it contains at least one node that is either a global output or
                    # passed as input to a subsequent graph

                    if buckets[0].opcount_increased_to_capture_external_output == 0:
                        buckets[0].paramsize_before_opcount_increase = buckets[0].size
                    buckets[0].opcount_increased_to_capture_external_output += 1

            if node.op == "call_function":
                self.add_param_args(buckets[0], node)

            elif node.op == "call_module":
                target_mod = gm.get_submodule(node.target)
                if target_mod not in processed_modules:
                    self.add_module_params_to_bucket(
                        target_mod, buckets[0], processed_modules, node.target
                    )
            elif node.op == "call_method":
                if isinstance(node.args[0].target, str):
                    target_mod = None
                    try:
                        target_mod = gm.get_submodule(node.args[0].target)
                    except AttributeError:
                        pass
                    if target_mod is not None and target_mod not in processed_modules:
                        self.add_module_params_to_bucket(
                            target_mod, buckets[0], processed_modules, node.target
                        )
                    # This handles situations like  tmp = torch.mm(x, self.weight.t())
                    # t: "f32[512, 512]" = l_self_seq_2_weight.t();  l_self_seq_2_weight = None
                    # tmp: "f32[512, 512]" = torch.mm(input_2, t);  input_2 = t = None
                    self.add_param_args(buckets[0], node)

            elif node.op == "get_attr":
                maybe_param = getattr(gm, node.target)
                if (
                    isinstance(maybe_param, torch.nn.Parameter)
                    and maybe_param.requires_grad
                    and not self._ignore_parameter(maybe_param)
                ):
                    self.add_param(buckets[0], maybe_param, node.target)

            # All nodes have to be mapped to a bucket, even if they don't have their own params.
            # Ignored params still end up in buckets; we just don't count them towards the capacity.
            buckets[0].nodes.append(node)

        if len(buckets) > 1 and buckets[0].size == 0:
            # we collected a small preamble graph with ops that don't include parameters, so fuse it back into the next bucket
            buckets[1].nodes.extend(buckets[0].nodes)
            assert len(buckets[0].params) == 0, "Params should be empty if size is 0"
            del buckets[0]

        # stash buckets for testing/debugging purposes
        self.buckets = buckets
        pretty_print_buckets(buckets, self.bucket_bytes_cap)

        if len(buckets) == 1:
            # bypass split/fuse logic if there is only one bucket
            return self.backend_compile_fn(gm, example_inputs)

        # 2: partition the graphmodule according to bucket capacity
        partition_map = {}
        for idx, b in enumerate(buckets):
            for node in b.nodes:
                partition_map[node] = idx

        split_gm = fx.passes.split_module.split_module(
            gm, None, lambda node: partition_map[node]
        )

        debug_str = (
            f"\n---orig graph---\n{gm.graph}\n"
            + f"\n---split graph---\n{split_gm.graph}\n"
        )
        for name, module in split_gm.named_modules():
            if "." not in name and len(name):
                # only print the submod graphs, not their children
                debug_str += f"\n---{name} graph---\n{module.graph}\n"
        debug_str += "\n---------------\n"
        ddp_graph_log.debug(debug_str)

        trace_structured(
            "optimize_ddp_split_graph",
            payload_fn=lambda: split_gm.print_readable(print_output=False),
        )
        for name, module in split_gm.named_modules():
            if "." not in name and len(name):
                trace_structured(
                    "optimize_ddp_split_child",
                    lambda: {"name": name},
                    payload_fn=lambda: module.print_readable(print_output=False),
                )

        fake_mode = detect_fake_mode(example_inputs)
        if fake_mode is None:
            fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()

        submod_compiler = SubmodCompiler(split_gm, self.backend_compile_fn, fake_mode)
        submod_compiler.run(*example_inputs)
        split_gm.recompile()

        ddp_graph_log.debug(
            "\n---final graph---\n%s\n---------------\n", split_gm.graph
        )
        return split_gm