xref: /aosp_15_r20/external/pytorch/benchmarks/dynamo/distributed.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport argparse
2*da0073e9SAndroid Build Coastguard Workerimport logging
3*da0073e9SAndroid Build Coastguard Workerimport os
4*da0073e9SAndroid Build Coastguard Workerfrom functools import partial
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerimport torch
7*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo as dynamo
8*da0073e9SAndroid Build Coastguard Workerimport torch.utils._pytree as pytree
9*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.testing import reduce_to_scalar_loss
10*da0073e9SAndroid Build Coastguard Workerfrom torch.nn.parallel import DistributedDataParallel as DDP
11*da0073e9SAndroid Build Coastguard Workerfrom torch.profiler import profile, ProfilerActivity, record_function
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Workertry:
15*da0073e9SAndroid Build Coastguard Worker    from .common import timed
16*da0073e9SAndroid Build Coastguard Worker    from .dist_util import apply_fsdp, cleanup, get_model, model_iter_fn, setup
17*da0073e9SAndroid Build Coastguard Workerexcept ImportError:
18*da0073e9SAndroid Build Coastguard Worker    from common import timed
19*da0073e9SAndroid Build Coastguard Worker    from dist_util import apply_fsdp, cleanup, get_model, model_iter_fn, setup
20*da0073e9SAndroid Build Coastguard Worker
21*da0073e9SAndroid Build Coastguard Workerlog = logging.getLogger(__name__)
22*da0073e9SAndroid Build Coastguard Worker
23*da0073e9SAndroid Build Coastguard Worker
def torchviz_model(args, model, inputs, rank):
    """Build a torchviz graph of the scalar loss and render it on rank 0.

    The model is called with ``*inputs`` on every rank and the graph object
    is built on every rank; only rank 0 calls ``render``.

    Args:
        args: Parsed command-line namespace (not read by this function).
        model: Callable module exposing ``named_parameters()``.
        inputs: Positional arguments unpacked into the model call.
        rank: Rank of the calling process.
    """
    # Imported lazily so torchviz is only required when this is called.
    from torchviz import make_dot

    loss = reduce_to_scalar_loss(model(*inputs))
    graph = make_dot(
        loss,
        params=dict(model.named_parameters()),
        show_attrs=True,
        show_saved=True,
    )
    if rank != 0:
        return
    graph.render("torchviz.dot")
33*da0073e9SAndroid Build Coastguard Worker
34*da0073e9SAndroid Build Coastguard Worker
def profile_model(args, model, inputs, rank):
    """Profile ``args.repeat`` forward/backward passes and export a chrome trace.

    Each iteration is split into a "Forward" region (model call plus
    reduction of the outputs to a scalar loss) and a "Backward" region
    (``loss.backward()``). Every rank runs the loop; only rank 0 writes the
    trace to ``args.trace_file``.

    Args:
        args: Namespace providing ``repeat`` (iteration count) and
            ``trace_file`` (output path for the chrome trace).
        model: Callable invoked as ``model(*inputs)``.
        inputs: Positional arguments unpacked into the model call.
        rank: Rank of the calling process.
    """
    traced_activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
    with profile(activities=traced_activities) as prof:
        for _ in range(args.repeat):
            with record_function("Forward"):
                loss = reduce_to_scalar_loss(model(*inputs))
            with record_function("Backward"):
                loss.backward()
    if rank != 0:
        return
    prof.export_chrome_trace(args.trace_file)
45*da0073e9SAndroid Build Coastguard Worker
46*da0073e9SAndroid Build Coastguard Worker
def run_model(args, model, inputs, key):
    """Run the timed benchmark for ``model`` on this process's rank.

    The steps, in order: read ``RANK`` / ``WORLD_SIZE`` from the environment,
    call ``setup``, move the model and every tensor in ``inputs`` to
    ``"<device>:<rank>"``, optionally wrap the model with FSDP or DDP,
    optionally wrap it with a dynamo backend, run 3 warmup iterations, run
    ``args.repeat`` timed iterations, optionally run torchviz and/or the
    profiler, and finally call ``cleanup``.

    Args:
        args: Parsed command-line namespace (device, fsdp/ddp, dynamo,
            verbose, repeat, torchviz, profile, ...).
        model: Module to benchmark.
        inputs: Pytree of model inputs; tensor leaves are moved to the
            rank's device, other leaves are passed through unchanged.
        key: Not used by this function; kept for interface compatibility.

    Returns:
        Whatever ``timed`` returns for the ``args.repeat`` iteration run
        (presumably total elapsed time, since the caller divides it by
        ``args.repeat`` -- confirm against ``common.timed``).
    """
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    setup(rank, world_size)
    if args.device == "cuda":
        # needed for FSDP
        torch.cuda.set_device(rank)

    target = f"{args.device}:{rank}"
    model = model.to(target)
    inputs = pytree.tree_map(
        lambda leaf: leaf.to(target) if torch.is_tensor(leaf) else leaf,
        inputs,
    )

    # FSDP takes precedence if both flags are somehow set.
    if args.fsdp:
        model = apply_fsdp(
            args,
            model,
            use_checkpointing=args.fsdp_checkpoint,
            use_wrap_policy=args.fsdp_wrap,
        )
    elif args.ddp:
        model = DDP(model)

    if args.verbose:
        print(model)

    if args.dynamo:
        dynamo.reset()
        if args.verbose:
            dynamo.config.verbose = True
            # NOTE(review): confirm `log_level` still exists on the dynamo
            # config of the torch version this benchmark runs against.
            dynamo.config.log_level = logging.DEBUG
        if args.dynamo_no_optimize_ddp:
            dynamo.config.optimize_ddp = False
        if args.fsdp and args.dynamo == "inductor":
            torch._inductor.config.triton.cudagraphs = False
            log.warning("disabling inductor cudagraphs for compatibility with FSDP")

        def print_compile(gm, ex):
            # Debug backend: print the captured graph, return the module unchanged.
            print(
                f"print_compile:\n{str(gm.graph)}\n-----------------------------------------"
            )
            return gm

        # "--dynamo print" selects the debug backend above; any other value
        # is handed to dynamo.optimize as given.
        backend = print_compile if args.dynamo == "print" else args.dynamo
        model = dynamo.optimize(backend)(model)

    # Warmup iterations; their result is discarded.
    timed(model, model_iter_fn, inputs, times=3, return_result=False)
    elapsed = timed(
        model, model_iter_fn, inputs, times=args.repeat, return_result=False
    )

    if args.torchviz:
        torchviz_model(args, model, inputs, rank)
    if args.profile:
        profile_model(args, model, inputs, rank)

    cleanup()
    return elapsed
114*da0073e9SAndroid Build Coastguard Worker
115*da0073e9SAndroid Build Coastguard Worker
def _positive_int(text):
    """argparse ``type`` callback that accepts only integers >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected an integer, got {text!r}"
        ) from None
    if value < 1:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {value}"
        )
    return value


def build_parser():
    """Return the command-line parser for the distributed benchmark script.

    Exactly one of ``--torchbench-model`` / ``--toy-model`` is required;
    ``--ddp`` and ``--fsdp`` are mutually exclusive and both optional.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default="cuda")
    parser.add_argument(
        "--dynamo",
        default=None,
        help="if set to a str, uses dynamo[str] backend. else, eager",
    )
    parser.add_argument("--verbose", action="store_true")
    # Parsed as int so the model loader receives a number rather than a
    # string; stays None when the flag is omitted.
    parser.add_argument("--batch-size", "--batch_size", type=int, default=None)
    parser.add_argument(
        "--torchviz", action="store_true", help="Dump autograd graph with torchviz"
    )
    parser.add_argument("--profile", action="store_true", help="Run the profiler")
    parser.add_argument(
        "--trace-file",
        "--trace_file",
        default="profile.json",
        help="Path of the chrome trace written when --profile is set",
    )
    # Must be an int: it is used as a range() bound and as a divisor.
    # Zero is rejected up front instead of failing with ZeroDivisionError
    # after the whole benchmark has run.
    parser.add_argument(
        "--repeat", type=_positive_int, default=10, help="Repeats for timing run"
    )
    parser.add_argument(
        "--dynamo-no-optimize-ddp",
        "--dynamo_no_optimize_ddp",
        action="store_true",
        help="Disable dynamo's ddp optimizer (enabled by default)",
    )
    parser.add_argument(
        "--fsdp-checkpoint",
        "--fsdp_checkpoint",
        action="store_true",
        help="Use gradient checkpointing via model-specific policy",
    )
    parser.add_argument(
        "--fsdp-wrap",
        "--fsdp_wrap",
        action="store_true",
        help="Apply fsdp to submodules via model-specific policy",
    )

    dist_arg = parser.add_mutually_exclusive_group()
    dist_arg.add_argument("--ddp", action="store_true")
    dist_arg.add_argument("--fsdp", action="store_true")

    model_arg = parser.add_mutually_exclusive_group(required=True)
    model_arg.add_argument(
        "--torchbench-model",
        "--torchbench_model",
        help="name of torchbench model, e.g. hf_Bert",
    )
    model_arg.add_argument(
        "--toy-model", "--toy_model", action="store_true", help="use toy model instead"
    )
    return parser


def main(argv=None):
    """Parse arguments, run the benchmark once and print the mean latency.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)

    model_name = args.torchbench_model
    if args.toy_model:
        model_name = "ToyModel"
    model, inputs = get_model(args)

    fn = partial(run_model, args, model, inputs)

    # Only used to build the run key, so the raw env string is fine here.
    world_size = os.getenv("WORLD_SIZE", 1)
    t_total = fn(f"{model_name}_{world_size}")
    print(f"mean latency {t_total / args.repeat} across {args.repeat} runs")


if __name__ == "__main__":
    main()
178