# mypy: allow-untyped-defs
r"""
Module ``torch.distributed.launch``.

``torch.distributed.launch`` is a module that spawns multiple distributed
training processes on each of the training nodes.

.. warning::

    This module is going to be deprecated in favor of :ref:`torchrun <launcher-api>`.
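
    As a migration sketch (assuming a single-node launch and that your script
    reads the local rank from the ``LOCAL_RANK`` environment variable, which is
    the default behavior under ``torchrun``), the equivalent command is roughly::

        torchrun --nproc-per-node=NUM_GPUS_YOU_HAVE YOUR_TRAINING_SCRIPT.py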

The utility can be used for single-node distributed training, in which one or
more processes per node will be spawned. The utility can be used for either
CPU training or GPU training. If the utility is used for GPU training,
each distributed process will operate on a single GPU, which can significantly
improve single-node training performance. It can also be used in
multi-node distributed training by spawning multiple processes on each node,
improving multi-node distributed training performance as well.
This is especially beneficial for systems with multiple InfiniBand
interfaces that have direct-GPU support, since all of them can be utilized for
aggregated communication bandwidth.

In both cases of single-node distributed training or multi-node distributed
training, this utility will launch the given number of processes per node
(``--nproc-per-node``). If used for GPU training, this number needs to be less
than or equal to the number of GPUs on the current system (``nproc_per_node``),
and each process will operate on a single GPU, from *GPU 0 to
GPU (nproc_per_node - 1)*.
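
For example, a quick way to see how many GPUs are available on the current node
(a small helper sketch, not part of this module)::

    >>> # xdoctest: +SKIP
    >>> import torch
    >>> torch.cuda.device_count()  # an upper bound for --nproc-per-node in GPU training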

**How to use this module:**

1. Single-Node multi-process distributed training

::

    python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE
               YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other
               arguments of your training script)

2. Multi-Node multi-process distributed training: (e.g. two nodes)

Node 1: *(IP: 192.168.1.1, and has a free port: 1234)*

::

    python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE
               --nnodes=2 --node-rank=0 --master-addr="192.168.1.1"
               --master-port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3
               and all other arguments of your training script)

Node 2:

::

    python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE
               --nnodes=2 --node-rank=1 --master-addr="192.168.1.1"
               --master-port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3
               and all other arguments of your training script)

3. To look up what optional arguments this module offers:

::

    python -m torch.distributed.launch --help


**Important Notices:**

1. This utility and multi-process distributed (single-node or
multi-node) GPU training currently only achieve the best performance using
the NCCL distributed backend. Thus, the NCCL backend is the recommended backend
to use for GPU training.
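
As a minimal sketch of acting on this recommendation (the backend choice shown
here is illustrative; CPU-only training commonly uses ``gloo`` instead)::

    >>> # xdoctest: +SKIP
    >>> import torch
    >>> import torch.distributed
    >>> backend = "nccl" if torch.cuda.is_available() else "gloo"
    >>> torch.distributed.init_process_group(backend=backend, init_method="env://")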

2. In your training program, you must parse the command-line argument:
``--local-rank=LOCAL_PROCESS_RANK``, which will be provided by this module.
If your training program uses GPUs, you should ensure that your code only
runs on the GPU device of LOCAL_PROCESS_RANK. This can be done by:

Parsing the local_rank argument

::

    >>> # xdoctest: +SKIP
    >>> import argparse
    >>> parser = argparse.ArgumentParser()
    >>> parser.add_argument("--local-rank", "--local_rank", type=int)
    >>> args = parser.parse_args()

Set your device to local rank using either

::

    >>> torch.cuda.set_device(args.local_rank)  # before your code runs

or

::

    >>> with torch.cuda.device(args.local_rank):
    >>>    # your code to run
    >>>    ...

.. versionchanged:: 2.0.0

    The launcher will pass the ``--local-rank=<rank>`` argument to your script.
    From PyTorch 2.0.0 onwards, the dashed ``--local-rank`` is preferred over the
    previously used underscored ``--local_rank``.

    For backward compatibility, it may be necessary for users to handle both
    cases in their argument parsing code. This means including both ``"--local-rank"``
    and ``"--local_rank"`` in the argument parser. If only ``"--local_rank"`` is
    provided, the launcher will trigger an error: "error: unrecognized arguments:
    --local-rank=<rank>". For training code that only supports PyTorch 2.0.0+,
    including ``"--local-rank"`` should be sufficient.

3. In your training program, you are supposed to call the following function
at the beginning to start the distributed backend. It is strongly recommended
to use ``init_method='env://'``. Other init methods (e.g. ``tcp://``) may work,
but ``env://`` is the one that is officially supported by this module.

::

    >>> torch.distributed.init_process_group(backend='YOUR BACKEND',
    >>>                                      init_method='env://')

4. In your training program, you can either use regular distributed functions
or use the :func:`torch.nn.parallel.DistributedDataParallel` module. If your
training program uses GPUs for training and you would like to use the
:func:`torch.nn.parallel.DistributedDataParallel` module,
here is how to configure it.

::

    >>> model = torch.nn.parallel.DistributedDataParallel(model,
    >>>                                                   device_ids=[args.local_rank],
    >>>                                                   output_device=args.local_rank)

Please ensure that the ``device_ids`` argument is set to be the only GPU device id
that your code will be operating on. This is generally the local rank of the
process. In other words, ``device_ids`` needs to be ``[args.local_rank]``,
and ``output_device`` needs to be ``args.local_rank`` in order to use this
utility.

5. Another way to pass ``local_rank`` to the subprocesses is via the environment
variable ``LOCAL_RANK``. This behavior is enabled when you launch the script with
``--use-env=True``. You must adjust the subprocess example above to replace
``args.local_rank`` with ``os.environ['LOCAL_RANK']``; the launcher
will not pass ``--local-rank`` when you specify this flag.
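
As a minimal sketch of this variant (assuming the launcher was invoked with
``--use-env=True`` so that ``LOCAL_RANK`` is set in each subprocess)::

    >>> # xdoctest: +SKIP
    >>> import os
    >>> import torch
    >>> local_rank = int(os.environ["LOCAL_RANK"])
    >>> torch.cuda.set_device(local_rank)  # instead of args.local_rank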

.. warning::

    ``local_rank`` is NOT globally unique: it is only unique per process
    on a machine.  Thus, don't use it to decide if you should, e.g.,
    write to a networked filesystem.  See
    https://github.com/pytorch/pytorch/issues/12042 for an example of
    how things can go wrong if you don't do this correctly.
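
    For example, a minimal sketch of gating a checkpoint write on the *global*
    rank instead (the filename is illustrative, ``model`` refers to the example
    above, and the process group is assumed to be initialized already)::

        >>> # xdoctest: +SKIP
        >>> import torch
        >>> import torch.distributed
        >>> if torch.distributed.get_rank() == 0:
        >>>     # only one process across all nodes writes to the shared filesystem
        >>>     torch.save(model.state_dict(), "checkpoint.pt")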

"""

from typing_extensions import deprecated as _deprecated

from torch.distributed.run import get_args_parser, run


def parse_args(args):
    parser = get_args_parser()
    parser.add_argument(
        "--use-env",
        "--use_env",
        default=False,
        action="store_true",
        help="Use environment variable to pass "
        "'local rank'. For legacy reasons, the default value is False. "
        "If set to True, the script will not pass "
        "--local-rank as argument, and will instead set LOCAL_RANK.",
    )
    return parser.parse_args(args)


def launch(args):
    if args.no_python and not args.use_env:
        raise ValueError(
            "When using the '--no-python' flag,"
            " you must also set the '--use-env' flag."
        )
    run(args)


@_deprecated(
    "The module torch.distributed.launch is deprecated\n"
    "and will be removed in a future release. Use torchrun.\n"
    "Note that --use-env is set by default in torchrun.\n"
    "If your script expects the `--local-rank` argument to be set, please\n"
    "change it to read from `os.environ['LOCAL_RANK']` instead. See\n"
    "https://pytorch.org/docs/stable/distributed.html#launch-utility for\n"
    "further instructions\n",
    category=FutureWarning,
)
def main(args=None):
    args = parse_args(args)
    launch(args)


if __name__ == "__main__":
    main()