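"""Distributed training benchmark for TorchDynamo.

Runs a torchbench model (or a toy model) under DDP or FSDP, optionally
compiled with a TorchDynamo backend, and reports the mean per-iteration
latency. RANK and WORLD_SIZE are read from the environment, so the script
is meant to be started by a torchrun-style launcher, e.g. (illustrative
command; adjust the path and flags to your setup):

    torchrun --nproc_per_node=2 distributed.py --toy_model --ddp --dynamo inductor
"""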
import argparse
import logging
import os
from functools import partial

import torch
import torch._dynamo as dynamo
import torch.utils._pytree as pytree
from torch._dynamo.testing import reduce_to_scalar_loss
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.profiler import profile, ProfilerActivity, record_function


try:
    from .common import timed
    from .dist_util import apply_fsdp, cleanup, get_model, model_iter_fn, setup
except ImportError:
    from common import timed
    from dist_util import apply_fsdp, cleanup, get_model, model_iter_fn, setup

log = logging.getLogger(__name__)


def torchviz_model(args, model, inputs, rank):
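    """Render the autograd graph of one forward pass with torchviz on rank 0."""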
    from torchviz import make_dot

    outputs = model(*inputs)
    loss = reduce_to_scalar_loss(outputs)
    parameter_names = dict(model.named_parameters())
    dot = make_dot(loss, params=parameter_names, show_attrs=True, show_saved=True)
    if rank == 0:
        dot.render("torchviz.dot")


def profile_model(args, model, inputs, rank):
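    """Profile args.repeat forward/backward iterations with torch.profiler.

    Rank 0 exports the result as a chrome trace to args.trace_file.
    """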
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        for i in range(args.repeat):
            with record_function("Forward"):
                outputs = model(*inputs)
                loss = reduce_to_scalar_loss(outputs)
            with record_function("Backward"):
                loss.backward()
    if rank == 0:
        prof.export_chrome_trace(args.trace_file)


def run_model(args, model, inputs, key):
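    """Set up the process group, wrap (and optionally compile) the model, and
    return the total wall-clock time of the timed benchmark loop."""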
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    # result_q = []

    setup(rank, world_size)
    if args.device == "cuda":
        # needed for FSDP
        torch.cuda.set_device(rank)

    dev_rank = f"{args.device}:{rank}"
    model = model.to(dev_rank)

    def move_tensor(maybe_tensor):
        if torch.is_tensor(maybe_tensor):
            return maybe_tensor.to(dev_rank)
        return maybe_tensor

    inputs = pytree.tree_map(move_tensor, inputs)

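    # Wrap the model for distributed training: FSDP shards parameters across
    # ranks, while DDP replicates the model and all-reduces gradients.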
    if args.fsdp:
        model = apply_fsdp(
            args,
            model,
            use_checkpointing=args.fsdp_checkpoint,
            use_wrap_policy=args.fsdp_wrap,
        )
    elif args.ddp:
        model = DDP(model)

    if args.verbose:
        print(model)

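    # Optionally compile the wrapped model with the requested TorchDynamo backend.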
    if args.dynamo:
        dynamo.reset()
        if args.verbose:
            dynamo.config.verbose = True
            dynamo.config.log_level = logging.DEBUG
        if args.dynamo_no_optimize_ddp:
            dynamo.config.optimize_ddp = False
        if args.dynamo == "inductor" and args.fsdp:
            torch._inductor.config.triton.cudagraphs = False
            log.warning("disabling inductor cudagraphs for compatibility with FSDP")

        def print_compile(gm, ex):
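            """Debug backend: print the captured FX graph and return it unchanged."""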
            print(
                f"print_compile:\n{str(gm.graph)}\n-----------------------------------------"
            )
            return gm

        dynamo_ctx = dynamo.optimize(
            print_compile if args.dynamo == "print" else args.dynamo
        )
        model = dynamo_ctx(model)

    # warmup
    _ = timed(model, model_iter_fn, inputs, times=3, return_result=False)
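    # timed run: t_total is the total wall-clock time across args.repeat iterations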
    t_total = timed(
        model, model_iter_fn, inputs, times=args.repeat, return_result=False
    )
    if args.torchviz:
        torchviz_model(args, model, inputs, rank)
    if args.profile:
        profile_model(args, model, inputs, rank)

    cleanup()
    return t_total


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default="cuda")
    parser.add_argument(
        "--dynamo",
        default=None,
        help="TorchDynamo backend to compile with (e.g. 'inductor', or 'print' "
        "to just dump the captured graphs); if unset, run eager",
    )
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--batch-size", "--batch_size", type=int, default=None)
    parser.add_argument(
        "--torchviz", action="store_true", help="Dump autograd graph with torchviz"
    )
    parser.add_argument("--profile", action="store_true", help="Run the profiler")
    parser.add_argument(
        "--trace-file",
        "--trace_file",
        default="profile.json",
        help="Output path for the profiler's chrome trace",
    )
    parser.add_argument(
        "--repeat", type=int, default=10, help="Repeats for timing run"
    )
    parser.add_argument(
        "--dynamo-no-optimize-ddp",
        "--dynamo_no_optimize_ddp",
        action="store_true",
        help="Disable dynamo's ddp optimizer (enabled by default)",
    )
    parser.add_argument(
        "--fsdp-checkpoint",
        "--fsdp_checkpoint",
        action="store_true",
        help="Use gradient checkpointing via model-specific policy",
    )
    parser.add_argument(
        "--fsdp-wrap",
        "--fsdp_wrap",
        action="store_true",
        help="Apply fsdp to submodules via model-specific policy",
    )

    dist_arg = parser.add_mutually_exclusive_group()
    dist_arg.add_argument("--ddp", action="store_true")
    dist_arg.add_argument("--fsdp", action="store_true")

    model_arg = parser.add_mutually_exclusive_group(required=True)
    model_arg.add_argument(
        "--torchbench-model",
        "--torchbench_model",
        help="name of torchbench model, e.g. hf_Bert",
    )
    model_arg.add_argument(
        "--toy-model", "--toy_model", action="store_true", help="use toy model instead"
    )
    args = parser.parse_args()

    model_name = args.torchbench_model
    if args.toy_model:
        model_name = "ToyModel"
    model, inputs = get_model(args)

    fn = partial(run_model, args, model, inputs)

    world_size = os.getenv("WORLD_SIZE", 1)
    t_total = fn(f"{model_name}_{world_size}")
    print(f"mean latency {t_total / args.repeat} across {args.repeat} runs")