# mypy: allow-untyped-defs
r"""
Module ``torch.distributed.launch``.

``torch.distributed.launch`` is a module that spawns multiple distributed
training processes on each of the training nodes.

.. warning::

    This module is going to be deprecated in favor of :ref:`torchrun <launcher-api>`.

The utility can be used for single-node distributed training, in which one or
more processes per node will be spawned. The utility can be used for either
CPU training or GPU training. If the utility is used for GPU training,
each distributed process will be operating on a single GPU. This can achieve
well-improved single-node training performance. It can also be used in
multi-node distributed training, by spawning up multiple processes on each node
for well-improved multi-node distributed training performance as well.
This will especially be beneficial for systems with multiple Infiniband
interfaces that have direct-GPU support, since all of them can be utilized for
aggregated communication bandwidth.

In both cases of single-node distributed training or multi-node distributed
training, this utility will launch the given number of processes per node
(``--nproc-per-node``). If used for GPU training, this number needs to be less
than or equal to the number of GPUs on the current system (``nproc_per_node``),
and each process will be operating on a single GPU from *GPU 0 to
GPU (nproc_per_node - 1)*.

**How to use this module:**

1. Single-Node multi-process distributed training

::

    python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE
               YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other
               arguments of your training script)

2. Multi-Node multi-process distributed training: (e.g. two nodes)

Node 1: *(IP: 192.168.1.1, and has a free port: 1234)*

::

    python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE
               --nnodes=2 --node-rank=0 --master-addr="192.168.1.1"
               --master-port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3
               and all other arguments of your training script)

Node 2:

::

    python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE
               --nnodes=2 --node-rank=1 --master-addr="192.168.1.1"
               --master-port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3
               and all other arguments of your training script)

3. To look up what optional arguments this module offers:

::

    python -m torch.distributed.launch --help


**Important Notices:**

1. This utility and multi-process distributed (single-node or
multi-node) GPU training currently only achieves the best performance using
the NCCL distributed backend. Thus the NCCL backend is the recommended backend
to use for GPU training.
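
As a minimal illustration (not part of the original notices), a script can
select the backend accordingly, preferring NCCL when CUDA devices are available
and falling back to Gloo on CPU-only nodes; the chosen value would then be
passed to ``torch.distributed.init_process_group`` as shown in notice 3 below.

::

    >>> # xdoctest: +SKIP
    >>> # Sketch: prefer NCCL for GPU training, fall back to Gloo on CPU-only nodes.
    >>> import torch
    >>> backend = "nccl" if torch.cuda.is_available() else "gloo"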

2. In your training program, you must parse the command-line argument:
``--local-rank=LOCAL_PROCESS_RANK``, which will be provided by this module.
If your training program uses GPUs, you should ensure that your code only
runs on the GPU device of LOCAL_PROCESS_RANK. This can be done by:

Parsing the local_rank argument

::

    >>> # xdoctest: +SKIP
    >>> import argparse
    >>> parser = argparse.ArgumentParser()
    >>> parser.add_argument("--local-rank", "--local_rank", type=int)
    >>> args = parser.parse_args()

Set your device to local rank using either

::

    >>> torch.cuda.set_device(args.local_rank)  # before your code runs

or

::

    >>> with torch.cuda.device(args.local_rank):
    >>>     # your code to run
    >>>     ...

.. versionchanged:: 2.0.0

    The launcher will pass the ``--local-rank=<rank>`` argument to your script.
    From PyTorch 2.0.0 onwards, the dashed ``--local-rank`` is preferred over the
    previously used underscored ``--local_rank``.

    For backward compatibility, it may be necessary for users to handle both
    cases in their argument parsing code. This means including both ``"--local-rank"``
    and ``"--local_rank"`` in the argument parser. If only ``"--local_rank"`` is
    provided, the launcher will trigger an error: "error: unrecognized arguments:
    --local-rank=<rank>". For training code that only supports PyTorch 2.0.0+,
    including ``"--local-rank"`` should be sufficient.

3. In your training program, you are supposed to call the following function
at the beginning to start the distributed backend. It is strongly recommended
that ``init_method=env://``. Other init methods (e.g. ``tcp://``) may work,
but ``env://`` is the one that is officially supported by this module.

::

    >>> torch.distributed.init_process_group(backend='YOUR BACKEND',
    >>>                                      init_method='env://')

4. In your training program, you can either use regular distributed functions
or use :func:`torch.nn.parallel.DistributedDataParallel` module. If your
training program uses GPUs for training and you would like to use
:func:`torch.nn.parallel.DistributedDataParallel` module,
here is how to configure it.

::

    >>> model = torch.nn.parallel.DistributedDataParallel(model,
    >>>                                                   device_ids=[args.local_rank],
    >>>                                                   output_device=args.local_rank)

Please ensure that the ``device_ids`` argument is set to be the only GPU device id
that your code will be operating on. This is generally the local rank of the
process. In other words, the ``device_ids`` needs to be ``[args.local_rank]``,
and ``output_device`` needs to be ``args.local_rank`` in order to use this
utility.

5. Another way to pass ``local_rank`` to the subprocesses is via the environment
variable ``LOCAL_RANK``. This behavior is enabled when you launch the script with
``--use-env=True``. You must adjust the subprocess example above to replace
``args.local_rank`` with ``os.environ['LOCAL_RANK']``; the launcher
will not pass ``--local-rank`` when you specify this flag.

.. warning::

    ``local_rank`` is NOT globally unique: it is only unique per process
    on a machine. Thus, don't use it to decide if you should, e.g.,
    write to a networked filesystem. See
    https://github.com/pytorch/pytorch/issues/12042 for an example of
    how things can go wrong if you don't do this correctly.
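
To make the last two points concrete, here is a minimal sketch (not part of the
original notices) of a training script launched with ``--use-env``: the local
rank is read from the ``LOCAL_RANK`` environment variable and used only for
device placement, while job-wide decisions such as saving a checkpoint key off
the global rank. The ``model`` object and the ``"checkpoint.pt"`` path are
placeholders, and the process group is assumed to be already initialized.

::

    >>> # xdoctest: +SKIP
    >>> import os
    >>> import torch
    >>> local_rank = int(os.environ["LOCAL_RANK"])  # set by the launcher with --use-env
    >>> torch.cuda.set_device(local_rank)           # device placement: use the *local* rank
    >>> if torch.distributed.get_rank() == 0:       # job-wide decision: use the *global* rank
    >>>     torch.save(model.state_dict(), "checkpoint.pt")  # placeholder model/path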



"""

from typing_extensions import deprecated as _deprecated

from torch.distributed.run import get_args_parser, run


def parse_args(args):
    parser = get_args_parser()
    parser.add_argument(
        "--use-env",
        "--use_env",
        default=False,
        action="store_true",
        help="Use environment variable to pass "
        "'local rank'. For legacy reasons, the default value is False. "
        "If set to True, the script will not pass "
        "--local-rank as argument, and will instead set LOCAL_RANK.",
    )
    return parser.parse_args(args)


def launch(args):
    if args.no_python and not args.use_env:
        raise ValueError(
            "When using the '--no-python' flag,"
            " you must also set the '--use-env' flag."
        )
    run(args)


@_deprecated(
    "The module torch.distributed.launch is deprecated\n"
    "and will be removed in future. Use torchrun.\n"
    "Note that --use-env is set by default in torchrun.\n"
    "If your script expects `--local-rank` argument to be set, please\n"
    "change it to read from `os.environ['LOCAL_RANK']` instead. See \n"
    "https://pytorch.org/docs/stable/distributed.html#launch-utility for \n"
    "further instructions\n",
    category=FutureWarning,
)
def main(args=None):
    args = parse_args(args)
    launch(args)


if __name__ == "__main__":
    main()