"""Benchmark a model under DDP or FSDP, optionally compiled with TorchDynamo.

Run as a script. The rank and world size are read from the ``RANK`` and
``WORLD_SIZE`` environment variables (defaulting to 0 and 1).
"""
import argparse
import logging
import os
from functools import partial

import torch
import torch._dynamo as dynamo
import torch.utils._pytree as pytree
from torch._dynamo.testing import reduce_to_scalar_loss
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.profiler import profile, ProfilerActivity, record_function


try:
    from .common import timed
    from .dist_util import apply_fsdp, cleanup, get_model, model_iter_fn, setup
except ImportError:
    # Fallback for running this file directly rather than as part of a package.
    from common import timed
    from dist_util import apply_fsdp, cleanup, get_model, model_iter_fn, setup

log = logging.getLogger(__name__)


def torchviz_model(args, model, inputs, rank):
    """Build a torchviz graph of one forward pass and render it on rank 0.

    Args:
        args: Parsed command-line namespace; not read by this function.
        model: Callable module, invoked as ``model(*inputs)``.
        inputs: Positional inputs for ``model``.
        rank: Rank of the current process; only rank 0 calls ``dot.render``.
    """
    # Function-scope import: torchviz is only needed when this is called.
    from torchviz import make_dot

    outputs = model(*inputs)
    loss = reduce_to_scalar_loss(outputs)
    parameter_names = dict(model.named_parameters())
    dot = make_dot(loss, params=parameter_names, show_attrs=True, show_saved=True)
    if rank == 0:
        dot.render("torchviz.dot")


def profile_model(args, model, inputs, rank):
    """Profile ``args.repeat`` forward/backward iterations of ``model``.

    Each iteration is recorded under the "Forward" and "Backward" labels with
    CPU and CUDA activities enabled. Rank 0 exports a chrome trace to
    ``args.trace_file``.

    Args:
        args: Parsed command-line namespace; reads ``repeat`` and ``trace_file``.
        model: Callable module, invoked as ``model(*inputs)``.
        inputs: Positional inputs for ``model``.
        rank: Rank of the current process; only rank 0 writes the trace.
    """
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        for _ in range(args.repeat):
            with record_function("Forward"):
                outputs = model(*inputs)
                loss = reduce_to_scalar_loss(outputs)
            with record_function("Backward"):
                loss.backward()
    if rank == 0:
        prof.export_chrome_trace(args.trace_file)


def run_model(args, model, inputs, key):
    """Set up the distributed environment, wrap the model and time it.

    The model and every tensor in ``inputs`` are moved to
    ``f"{args.device}:{rank}"``. The model is then optionally wrapped in FSDP
    or DDP and optionally compiled with dynamo before three warmup iterations
    and ``args.repeat`` timed iterations.

    Args:
        args: Parsed command-line namespace (see ``_build_arg_parser``).
        model: The module to benchmark.
        inputs: Pytree of model inputs; tensor leaves are moved to the device.
        key: Not read by this function; kept so existing callers keep working.

    Returns:
        Whatever ``timed`` returns for the measured run. The script's entry
        point divides it by ``args.repeat`` and reports a mean latency, so it
        is presumably a total elapsed time (confirm against ``common.timed``).
    """
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))

    setup(rank, world_size)
    # Everything after setup() runs under try/finally so that cleanup() is
    # still called when wrapping, compilation or timing raises.
    try:
        if args.device == "cuda":
            # needed for FSDP
            torch.cuda.set_device(rank)

        dev_rank = f"{args.device}:{rank}"
        model = model.to(dev_rank)

        def move_tensor(maybe_tensor):
            if torch.is_tensor(maybe_tensor):
                return maybe_tensor.to(dev_rank)
            return maybe_tensor

        inputs = pytree.tree_map(move_tensor, inputs)

        if args.fsdp:
            model = apply_fsdp(
                args,
                model,
                use_checkpointing=args.fsdp_checkpoint,
                use_wrap_policy=args.fsdp_wrap,
            )
        elif args.ddp:
            model = DDP(model)

        if args.verbose:
            print(model)

        if args.dynamo:
            dynamo.reset()
            if args.verbose:
                dynamo.config.verbose = True
                dynamo.config.log_level = logging.DEBUG
            if args.dynamo_no_optimize_ddp:
                dynamo.config.optimize_ddp = False
            if args.dynamo == "inductor" and args.fsdp:
                torch._inductor.config.triton.cudagraphs = False
                log.warning("disabling inductor cudagraphs for compatibility with FSDP")

            def print_compile(gm, ex):
                # Backend used for ``--dynamo print``: dump the captured graph
                # and return the graph module unchanged.
                print(
                    f"print_compile:\n{str(gm.graph)}\n-----------------------------------------"
                )
                return gm

            dynamo_ctx = dynamo.optimize(
                print_compile if args.dynamo == "print" else args.dynamo
            )
            model = dynamo_ctx(model)

        # warmup
        _ = timed(model, model_iter_fn, inputs, times=3, return_result=False)
        t_total = timed(
            model, model_iter_fn, inputs, times=args.repeat, return_result=False
        )
        if args.torchviz:
            torchviz_model(args, model, inputs, rank)
        if args.profile:
            profile_model(args, model, inputs, rank)

        return t_total
    finally:
        cleanup()


def _build_arg_parser():
    """Create the command-line parser for this benchmark script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default="cuda")
    parser.add_argument(
        "--dynamo",
        default=None,
        help="if set to a str, uses dynamo[str] backend. else, eager",
    )
    parser.add_argument("--verbose", action="store_true")
    # NOTE(review): no ``type`` is given, so a value passed on the command
    # line arrives as ``str``. Confirm what ``get_model`` expects.
    parser.add_argument("--batch-size", "--batch_size", default=None)
    parser.add_argument(
        "--torchviz", action="store_true", help="Dump autograd graph with torchviz"
    )
    parser.add_argument("--profile", action="store_true", help="Run the profiler")
    parser.add_argument(
        "--trace-file",
        "--trace_file",
        default="profile.json",
        help="Path of the chrome trace written when --profile is set",
    )
    # ``type=int`` matters: the value feeds ``range(...)`` and a division, and
    # argparse would otherwise hand over a string for user-supplied values.
    parser.add_argument(
        "--repeat", type=int, default=10, help="Repeats for timing run"
    )
    parser.add_argument(
        "--dynamo-no-optimize-ddp",
        "--dynamo_no_optimize_ddp",
        action="store_true",
        help="Disable dynamo's ddp optimizer (enabled by default)",
    )
    parser.add_argument(
        "--fsdp-checkpoint",
        "--fsdp_checkpoint",
        action="store_true",
        help="Use gradient checkpointing via model-specific policy",
    )
    parser.add_argument(
        "--fsdp-wrap",
        "--fsdp_wrap",
        action="store_true",
        help="Apply fsdp to submodules via model-specific policy",
    )

    dist_arg = parser.add_mutually_exclusive_group()
    dist_arg.add_argument("--ddp", action="store_true")
    dist_arg.add_argument("--fsdp", action="store_true")

    model_arg = parser.add_mutually_exclusive_group(required=True)
    model_arg.add_argument(
        "--torchbench-model",
        "--torchbench_model",
        help="name of torchbench model, e.g. hf_Bert",
    )
    model_arg.add_argument(
        "--toy-model", "--toy_model", action="store_true", help="use toy model instead"
    )
    return parser


def main():
    """Parse arguments, run the benchmark and print the mean latency."""
    args = _build_arg_parser().parse_args()

    model_name = "ToyModel" if args.toy_model else args.torchbench_model
    model, inputs = get_model(args)

    fn = partial(run_model, args, model, inputs)

    world_size = os.getenv("WORLD_SIZE", 1)
    t_total = fn(f"{model_name}_{world_size}")
    print(f"mean latency {t_total / args.repeat} across {args.repeat} runs")


if __name__ == "__main__":
    main()