#!/usr/bin/env python3
# mypy: allow-untyped-defs

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Superset of ``torch.distributed.launch``.

``torchrun`` provides a superset of the functionality of ``torch.distributed.launch``
with the following additional features:

1. Worker failures are handled gracefully by restarting all workers.

2. Worker ``RANK`` and ``WORLD_SIZE`` are assigned automatically.

3. The number of nodes is allowed to change between a minimum and a maximum size (elasticity).

.. note:: ``torchrun`` is a Python
          `console script <https://packaging.python.org/en/latest/specifications/entry-points/#use-for-scripts>`_
          to the main module
          `torch.distributed.run <https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py>`_
          declared in the ``entry_points`` configuration in
          `setup.py <https://github.com/pytorch/pytorch/blob/master/setup.py>`_.
          It is equivalent to invoking ``python -m torch.distributed.run``.


Transitioning from torch.distributed.launch to torchrun
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


``torchrun`` supports the same arguments as ``torch.distributed.launch`` **except**
for ``--use-env``, which is now deprecated. To migrate from ``torch.distributed.launch``
to ``torchrun``, follow these steps:

1.  If your training script already reads ``local_rank`` from the ``LOCAL_RANK`` environment variable,
    you only need to omit the ``--use-env`` flag, e.g.:

    +--------------------------------------------------------------------+--------------------------------------------+
    |         ``torch.distributed.launch``                               |                ``torchrun``                |
    +====================================================================+============================================+
    |                                                                    |                                            |
    | .. code-block:: shell-session                                      | .. code-block:: shell-session              |
    |                                                                    |                                            |
    |    $ python -m torch.distributed.launch --use-env train_script.py  |    $ torchrun train_script.py              |
    |                                                                    |                                            |
    +--------------------------------------------------------------------+--------------------------------------------+

2.  If your training script reads the local rank from a ``--local-rank`` command-line argument,
    change your training script to read it from the ``LOCAL_RANK`` environment variable, as
    demonstrated by the following code snippet:

    +-------------------------------------------------------+----------------------------------------------------+
    |         ``torch.distributed.launch``                  |                    ``torchrun``                    |
    +=======================================================+====================================================+
    |                                                       |                                                    |
    | .. code-block:: python                                | .. code-block:: python                             |
    |                                                       |                                                    |
    |                                                       |                                                    |
    |    import argparse                                    |     import os                                      |
    |    parser = argparse.ArgumentParser()                 |     local_rank = int(os.environ["LOCAL_RANK"])     |
    |    parser.add_argument("--local-rank", type=int)      |                                                    |
    |    args = parser.parse_args()                         |                                                    |
    |                                                       |                                                    |
    |    local_rank = args.local_rank                       |                                                    |
    |                                                       |                                                    |
    +-------------------------------------------------------+----------------------------------------------------+

.. versionchanged:: 2.0.0

    The launcher will pass the ``--local-rank=<rank>`` argument to your script.
    From PyTorch 2.0.0 onwards, the dashed ``--local-rank`` is preferred over the
    previously used underscored ``--local_rank``.

    For backward compatibility, users may need to handle both
    cases in their argument parsing code. This means including both ``"--local-rank"``
    and ``"--local_rank"`` in the argument parser. If the parser defines only ``"--local_rank"``,
    argument parsing in your script will fail with "error: unrecognized arguments:
    --local-rank=<rank>". For training code that only supports PyTorch 2.0.0+,
    including ``"--local-rank"`` should be sufficient.

    ::

        >>> # xdoctest: +SKIP
        >>> import argparse
        >>> parser = argparse.ArgumentParser()
        >>> parser.add_argument("--local-rank", "--local_rank", type=int)
        >>> args = parser.parse_args()
        >>> local_rank = args.local_rank  # the same attribute for both spellings

The aforementioned changes are enough to migrate from ``torch.distributed.launch`` to ``torchrun``.
To take advantage of new ``torchrun`` features such as elasticity, fault tolerance, and error reporting,
please refer to:

* :ref:`elastic_train_script` for more information on authoring training scripts that are ``torchrun`` compliant.
* the rest of this page for more information on the features of ``torchrun``.

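Some scripts need to keep working under both launchers. Such scripts can combine the two migration paths described above. The following is a minimal sketch and is not part of ``torchrun``. ``get_local_rank`` is a hypothetical helper. Two behaviors are assumptions of this example: it prefers ``LOCAL_RANK`` over the command-line argument, and it raises an error when neither is present.

```python
import argparse
import os
from typing import List, Mapping, Optional


def get_local_rank(
    argv: Optional[List[str]] = None,
    env: Optional[Mapping[str, str]] = None,
) -> int:
    # Hypothetical helper: prefer LOCAL_RANK (set by torchrun), then fall back to
    # --local-rank / --local_rank (passed by torch.distributed.launch without --use-env).
    env = os.environ if env is None else env
    if "LOCAL_RANK" in env:
        return int(env["LOCAL_RANK"])

    # Both spellings are stored in the same attribute, args.local_rank.
    # parse_known_args() ignores the script's other arguments instead of failing.
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("--local-rank", "--local_rank", type=int, default=None)
    args, _unknown = parser.parse_known_args(argv)
    if args.local_rank is None:
        raise RuntimeError("local rank not found in LOCAL_RANK or --local-rank")
    return args.local_rank
```

For instance, ``get_local_rank(["--local-rank=2", "--lr", "0.1"], env={})`` returns ``2``. Unrelated script arguments such as ``--lr`` are left to the script's own parser.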

Usage
--------

Single-node multi-worker
++++++++++++++++++++++++++++++

::

    torchrun
        --standalone
        --nnodes=1
        --nproc-per-node=$NUM_TRAINERS
        YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)

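``torchrun`` is equivalent to ``python -m torch.distributed.run``, so the single-node command can also be assembled from Python, for example to launch it with ``subprocess``. The sketch below only builds the argument list. ``standalone_cmd`` is a hypothetical helper, and it assumes ``torch`` is installed in the interpreter given by ``sys.executable``. Launcher options go before the script path, and the script's own arguments go after it.

```python
import sys
from typing import List, Sequence


def standalone_cmd(nproc_per_node: int, script: str, script_args: Sequence[str] = ()) -> List[str]:
    # Hypothetical helper mirroring:
    #   torchrun --standalone --nnodes=1 --nproc-per-node=N SCRIPT [script args...]
    if nproc_per_node < 1:
        raise ValueError("nproc_per_node must be >= 1")
    return [
        sys.executable, "-m", "torch.distributed.run",  # same entry point as the torchrun script
        "--standalone",
        "--nnodes=1",
        f"--nproc-per-node={nproc_per_node}",
        script,        # everything after the script path is passed to the script itself
        *script_args,
    ]
```

Passing the result to ``subprocess.run(standalone_cmd(4, "train.py", ["--epochs", "10"]), check=True)`` would start four workers on the local host.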


Stacked single-node multi-worker
+++++++++++++++++++++++++++++++++++

To run multiple instances (separate jobs) of single-node, multi-worker training on the
same host, make sure that each instance (job) is set up on a different port to avoid
port conflicts (or, worse, two jobs being merged into a single job). To do this, run
with ``--rdzv-backend=c10d`` and specify a different port for each job by setting
``--rdzv-endpoint=localhost:$PORT_k``. For ``--nnodes=1``, it is often convenient to let
``torchrun`` pick a free random port automatically instead of manually assigning
different ports for each run.

::

    torchrun
        --rdzv-backend=c10d
        --rdzv-endpoint=localhost:0
        --nnodes=1
        --nproc-per-node=$NUM_TRAINERS
        YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)

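You may prefer to assign ``$PORT_k`` yourself rather than passing ``localhost:0``. One common approach (an assumption of this sketch, not something ``torchrun`` requires) is to ask the operating system for an unused port by binding to port 0. The helper names are hypothetical. The port is released before ``torchrun`` binds it, so another process could take it in between. Letting ``torchrun`` pick the port with ``localhost:0`` avoids that race.

```python
import socket


def find_free_port(host: str = "127.0.0.1") -> int:
    # Binding to port 0 makes the OS choose an unused ephemeral port. The socket is
    # closed on return, so the port is only likely (not guaranteed) to stay free.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind((host, 0))
        return s.getsockname()[1]


def rdzv_endpoint_flag(port: int, host: str = "localhost") -> str:
    # Builds the --rdzv-endpoint option for one stacked single-node job.
    return f"--rdzv-endpoint={host}:{port}"
```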

Fault tolerant (fixed sized number of workers, no elasticity, tolerates 3 failures)
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

::

    torchrun
        --nnodes=$NUM_NODES
        --nproc-per-node=$NUM_TRAINERS
        --max-restarts=3
        --rdzv-id=$JOB_ID
        --rdzv-backend=c10d
        --rdzv-endpoint=$HOST_NODE_ADDR
        YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)

``HOST_NODE_ADDR``, in the form ``<host>[:<port>]`` (e.g. ``node1.example.com:29400``), specifies the node and
the port on which the C10d rendezvous backend should be instantiated and hosted. It can be any
node in your training cluster, but ideally you should pick a node that has high bandwidth.

.. note::
   If no port number is specified, the port in ``HOST_NODE_ADDR`` defaults to 29400.

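A small parser can illustrate the ``<host>[:<port>]`` format and the 29400 default described above. This is a hypothetical sketch, not the parser ``torchrun`` uses internally. Its handling of bracketed IPv6 literals such as ``[::1]:29400`` is an added assumption.

```python
from typing import Tuple

DEFAULT_PORT = 29400  # used when HOST_NODE_ADDR has no port


def parse_host_node_addr(addr: str, default_port: int = DEFAULT_PORT) -> Tuple[str, int]:
    # Hypothetical parser for "<host>[:<port>]".
    addr = addr.strip()
    if not addr:
        raise ValueError("HOST_NODE_ADDR must not be empty")
    if addr.startswith("["):
        # Bracketed IPv6 literal, e.g. "[::1]:29400" or "[::1]".
        host, _, rest = addr[1:].partition("]")
        port = int(rest[1:]) if rest.startswith(":") else default_port
        return host, port
    host, sep, port_str = addr.rpartition(":")
    if sep and host and ":" not in host:
        return host, int(port_str)
    # No port given (or a bare IPv6 literal such as "::1"): use the default.
    return addr, default_port
```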

Elastic (``min=1``, ``max=4``, tolerates up to 3 membership changes or failures)
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

::

    torchrun
        --nnodes=1:4
        --nproc-per-node=$NUM_TRAINERS
        --max-restarts=3
        --rdzv-id=$JOB_ID
        --rdzv-backend=c10d
        --rdzv-endpoint=$HOST_NODE_ADDR
        YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)

``HOST_NODE_ADDR``, in the form ``<host>[:<port>]`` (e.g. ``node1.example.com:29400``), specifies the node and
the port on which the C10d rendezvous backend should be instantiated and hosted. It can be any
node in your training cluster, but ideally you should pick a node that has high bandwidth.

.. note::
   If no port number is specified, the port in ``HOST_NODE_ADDR`` defaults to 29400.

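As the examples above show, ``--nnodes`` takes either a single number (fixed size) or ``MIN:MAX`` (elastic). The hypothetical ``parse_nnodes`` below sketches how such a value can be interpreted. The validation rule ``0 < MIN <= MAX`` is an assumption for illustration, not a copy of ``torchrun``'s own argument handling.

```python
from typing import Tuple


def parse_nnodes(nnodes: str) -> Tuple[int, int]:
    # Hypothetical parser: "N" means a fixed size, "MIN:MAX" an elastic range.
    parts = nnodes.split(":")
    if len(parts) == 1:
        min_nodes = max_nodes = int(parts[0])
    elif len(parts) == 2:
        min_nodes, max_nodes = int(parts[0]), int(parts[1])
    else:
        raise ValueError(f"nnodes={nnodes!r} is not in N or MIN:MAX format")
    if not 0 < min_nodes <= max_nodes:
        raise ValueError(f"expected 0 < MIN <= MAX, got {min_nodes}:{max_nodes}")
    return min_nodes, max_nodes
```

A job is elastic exactly when ``MIN != MAX``. For ``--nnodes=1:4`` the sketch returns ``(1, 4)``, and for ``--nnodes=2`` it returns ``(2, 2)``.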


Note on rendezvous backend
------------------------------

For multi-node training you need to specify:

1. ``--rdzv-id``: A unique job id (shared by all nodes participating in the job)
2. ``--rdzv-backend``: An implementation of
   :py:class:`torch.distributed.elastic.rendezvous.RendezvousHandler`
3. ``--rdzv-endpoint``: The endpoint where the rendezvous backend is running; usually in the form
   ``host:port``.

Currently, the ``c10d`` (recommended), ``etcd-v2``, and ``etcd`` (legacy) rendezvous backends are
supported out of the box. To use ``etcd-v2`` or ``etcd``, set up an etcd server with the ``v2`` API
enabled (e.g. ``--enable-v2``).

.. warning::
   ``etcd-v2`` and ``etcd`` rendezvous use etcd API v2. You MUST enable the v2 API on the etcd
   server. Our tests use etcd v3.4.3.

.. warning::
   For etcd-based rendezvous, we recommend using ``etcd-v2`` over ``etcd``. The two are functionally
   equivalent, but ``etcd-v2`` uses a revised implementation. ``etcd`` is in maintenance mode and will be
   removed in a future version.

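The three rendezvous settings must be identical on every node of a multi-node job. The hypothetical helper below assembles them into launcher flags. Following the warning above, it warns when the legacy ``etcd`` backend is selected. Backends other than the three built-in ones are outside the scope of this sketch, so it only warns about unrecognized names rather than rejecting them.

```python
import warnings
from typing import List

BUILTIN_BACKENDS = ("c10d", "etcd-v2", "etcd")


def rdzv_flags(rdzv_id: str, backend: str, endpoint: str) -> List[str]:
    # Hypothetical helper; pass the same values on every participating node.
    if not rdzv_id:
        raise ValueError("rdzv_id must be a non-empty job id shared by all nodes")
    if not endpoint:
        raise ValueError("endpoint is required, usually in the form host:port")
    if backend == "etcd":
        warnings.warn("'etcd' is in maintenance mode; prefer 'etcd-v2'", FutureWarning)
    elif backend not in BUILTIN_BACKENDS:
        warnings.warn(f"{backend!r} is not one of the built-in backends {BUILTIN_BACKENDS}")
    return [
        f"--rdzv-id={rdzv_id}",
        f"--rdzv-backend={backend}",
        f"--rdzv-endpoint={endpoint}",
    ]
```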


Definitions
--------------

1. ``Node`` - A physical instance or a container; maps to the unit that the job manager works with.

2. ``Worker`` - A worker in the context of distributed training.

3. ``WorkerGroup`` - The set of workers that execute the same function (e.g. trainers).

4. ``LocalWorkerGroup`` - A subset of the workers in the worker group running on the same node.

5. ``RANK`` - The rank of the worker within a worker group.

6. ``WORLD_SIZE`` - The total number of workers in a worker group.

7. ``LOCAL_RANK`` - The rank of the worker within a local worker group.

8. ``LOCAL_WORLD_SIZE`` - The size of the local worker group.

9. ``rdzv_id`` - A user-defined id that uniquely identifies the worker group for a job. This id is
   used by each node to join as a member of a particular worker group.

10. ``rdzv_backend`` - The backend of the rendezvous (e.g. ``c10d``). This is typically a strongly
    consistent key-value store.

11. ``rdzv_endpoint`` - The rendezvous backend endpoint; usually in the form ``<host>:<port>``.

A ``Node`` runs ``LOCAL_WORLD_SIZE`` workers, which make up a ``LocalWorkerGroup``. The union of
all ``LocalWorkerGroups`` on the nodes in the job makes up the ``WorkerGroup``.

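The definitions above imply a few invariants that can be checked mechanically:

* Every node has the same ``LOCAL_WORLD_SIZE`` (see Important Notices).
* Local ranks on a node run from 0 to ``LOCAL_WORLD_SIZE - 1``.
* Global ranks are unique and cover 0 to ``WORLD_SIZE - 1``.

The hypothetical checker below takes a made-up mapping of node names to ``(RANK, LOCAL_RANK)`` pairs. It deliberately assumes no formula linking ``RANK`` to ``LOCAL_RANK``, because that assignment is not stable across restarts.

```python
from typing import Dict, List, Tuple

# Hypothetical input: node name -> list of (RANK, LOCAL_RANK) pairs for its workers.
Assignment = Dict[str, List[Tuple[int, int]]]


def check_assignment(assignment: Assignment) -> Tuple[int, int]:
    # Returns (WORLD_SIZE, LOCAL_WORLD_SIZE) if the assignment is consistent.
    local_sizes = {len(workers) for workers in assignment.values()}
    if len(local_sizes) != 1:
        raise ValueError("every node must run the same number of workers (LOCAL_WORLD_SIZE)")
    (local_world_size,) = local_sizes

    global_ranks: List[int] = []
    for node, workers in assignment.items():
        # LOCAL_RANK values on each node must be exactly 0 .. LOCAL_WORLD_SIZE - 1.
        if sorted(local for _, local in workers) != list(range(local_world_size)):
            raise ValueError(f"LOCAL_RANKs on {node} must be 0..{local_world_size - 1}")
        global_ranks.extend(rank for rank, _ in workers)

    # RANK values across all nodes must be unique and cover 0 .. WORLD_SIZE - 1.
    world_size = len(global_ranks)
    if sorted(global_ranks) != list(range(world_size)):
        raise ValueError("RANKs must be unique and cover 0..WORLD_SIZE - 1")
    return world_size, local_world_size
```

For example, two nodes with two workers each give ``(4, 2)``. This holds whether node ``node0`` holds ranks ``0, 1`` or ``3, 1``.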


Environment Variables
----------------------

The following environment variables are made available to you in your script:

1. ``LOCAL_RANK`` - The local rank.

2. ``RANK`` - The global rank.

3. ``GROUP_RANK`` - The rank of the worker group. A number between 0 and ``max_nnodes - 1``. When
   running a single worker group per node, this is the rank of the node.

4. ``ROLE_RANK`` - The rank of the worker across all the workers that have the same role. The role
   of the worker is specified in the ``WorkerSpec``.

5. ``LOCAL_WORLD_SIZE`` - The local world size (i.e. the number of workers running locally); equal to
   ``--nproc-per-node`` specified on ``torchrun``.

6. ``WORLD_SIZE`` - The world size (total number of workers in the job).

7. ``ROLE_WORLD_SIZE`` - The total number of workers that were launched with the same role specified
   in ``WorkerSpec``.

8. ``MASTER_ADDR`` - The FQDN of the host that is running the worker with rank 0; used to initialize
   the Torch Distributed backend.

9. ``MASTER_PORT`` - The port on the ``MASTER_ADDR`` that can be used to host the C10d TCP store.

10. ``TORCHELASTIC_RESTART_COUNT`` - The number of worker group restarts so far.

11. ``TORCHELASTIC_MAX_RESTARTS`` - The configured maximum number of restarts.

12. ``TORCHELASTIC_RUN_ID`` - Equal to the rendezvous ``run_id`` (i.e. the unique job id).

13. ``PYTHON_EXEC`` - System executable override. If provided, the Python user script is run
    with the value of ``PYTHON_EXEC`` as the executable. ``sys.executable`` is used by default.

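A training script typically reads these variables at startup. The sketch below collects them into a frozen dataclass. ``TorchrunEnv`` is a hypothetical name, not a PyTorch API. ``PYTHON_EXEC`` is omitted because it only affects which interpreter runs the script. This sketch treats a ``KeyError`` as meaning the script was not started by ``torchrun``.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class TorchrunEnv:
    local_rank: int
    rank: int
    group_rank: int
    role_rank: int
    local_world_size: int
    world_size: int
    role_world_size: int
    master_addr: str
    master_port: int
    restart_count: int
    max_restarts: int
    run_id: str

    @classmethod
    def from_env(cls, env: Optional[Mapping[str, str]] = None) -> "TorchrunEnv":
        # Raises KeyError if a variable is missing, e.g. when not launched by torchrun.
        e = os.environ if env is None else env
        return cls(
            local_rank=int(e["LOCAL_RANK"]),
            rank=int(e["RANK"]),
            group_rank=int(e["GROUP_RANK"]),
            role_rank=int(e["ROLE_RANK"]),
            local_world_size=int(e["LOCAL_WORLD_SIZE"]),
            world_size=int(e["WORLD_SIZE"]),
            role_world_size=int(e["ROLE_WORLD_SIZE"]),
            master_addr=e["MASTER_ADDR"],
            master_port=int(e["MASTER_PORT"]),
            restart_count=int(e["TORCHELASTIC_RESTART_COUNT"]),
            max_restarts=int(e["TORCHELASTIC_MAX_RESTARTS"]),
            run_id=e["TORCHELASTIC_RUN_ID"],
        )
```

``RANK`` and ``WORLD_SIZE`` may change between restarts. Call ``TorchrunEnv.from_env()`` each time the script starts rather than storing the values in a checkpoint.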


Deployment
------------

1. (Not needed for the C10d backend) Start the rendezvous backend server and get the endpoint (to be
   passed as ``--rdzv-endpoint`` to the launcher script).

2. Single-node multi-worker: Start the launcher on the host to start the agent process, which
   creates and monitors a local worker group.

3. Multi-node multi-worker: Start the launcher with the same arguments on all the nodes
   participating in training.

When using a job/cluster manager, the entry point command to the multi-node job should be this
launcher.

Failure Modes
---------------

1. Worker failure: For a training job with ``n`` workers, if any ``k <= n`` workers fail, all workers
   are stopped and restarted, up to ``max_restarts`` times.

2. Agent failure: An agent failure results in a local worker group failure. It is up to the job
   manager to fail the entire job (gang semantics) or attempt to replace the node. Both behaviors
   are supported by the agent.

3. Node failure: Same as agent failure.

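The worker-failure rule can be modeled in a few lines. This is a simplified, hypothetical model of the policy described above, not the torchelastic agent. ``run_worker_group`` stands in for starting all local workers and waiting for them. Its argument plays the role of ``TORCHELASTIC_RESTART_COUNT``.

```python
from typing import Callable


def run_with_restarts(run_worker_group: Callable[[int], bool], max_restarts: int) -> int:
    # Simplified model: run_worker_group(restart_count) starts ALL workers and returns
    # True only if every one of them succeeds. Any failure (k <= n workers) stops the
    # whole group, which is then restarted, up to max_restarts times.
    restart_count = 0  # plays the role of TORCHELASTIC_RESTART_COUNT
    while True:
        if run_worker_group(restart_count):
            return restart_count
        if restart_count >= max_restarts:
            raise RuntimeError(f"worker group failed after {restart_count} restarts")
        restart_count += 1
```

Consider ``max_restarts=3``. A group that fails on its first two attempts succeeds on restart 2. A group that keeps failing is attempted four times in total and then raises.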


Membership Changes
--------------------

1. Node departure (scale-down): The agent is notified of the departure, all existing workers are
   stopped, a new ``WorkerGroup`` is formed, and all workers are started with a new ``RANK`` and
   ``WORLD_SIZE``.

2. Node arrival (scale-up): The new node is admitted to the job, all existing workers are stopped,
   a new ``WorkerGroup`` is formed, and all workers are started with a new ``RANK`` and
   ``WORLD_SIZE``.

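A toy assignment shows why ``RANK`` and ``WORLD_SIZE`` change on scale-down or scale-up. The scheme below (contiguous blocks of ranks in node order) is a hypothetical illustration only. torchelastic does not promise this or any other particular assignment.

```python
from typing import Dict, List


def assign_ranks(nodes: List[str], nproc_per_node: int) -> Dict[str, List[int]]:
    # Hypothetical scheme: node i (in the given order) gets global ranks
    # i * nproc_per_node .. (i + 1) * nproc_per_node - 1.
    return {
        node: list(range(i * nproc_per_node, (i + 1) * nproc_per_node))
        for i, node in enumerate(nodes)
    }
```

Even under this simple scheme, when node ``a`` leaves a three-node job with two workers per node, node ``b`` keeps its workers. However, their ranks move from ``[2, 3]`` to ``[0, 1]``, and ``WORLD_SIZE`` drops from 6 to 4.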


Important Notices
--------------------

1. For multi-process distributed GPU training with this utility (single-node or
   multi-node), the best performance is currently achieved with the NCCL
   distributed backend. NCCL is therefore the recommended backend for GPU
   training.

2. The environment variables necessary to initialize a Torch process group are provided to you by
   this module, so there is no need to pass ``RANK`` manually. To initialize a process group in your
   training script, simply run:

::

 >>> # xdoctest: +SKIP("stub")
 >>> import torch
 >>> import torch.distributed as dist
 >>> # NCCL for GPU training, Gloo otherwise; rank and world size are read from the environment
 >>> dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo")

3. In your training program, you can either use regular distributed functions
   or use the :class:`torch.nn.parallel.DistributedDataParallel` module. If your
   training program uses GPUs for training and you would like to use the
   :class:`torch.nn.parallel.DistributedDataParallel` module,
   here is how to configure it.

::

    import os
    import torch

    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)  # bind this process to its own GPU
    model = torch.nn.parallel.DistributedDataParallel(model.to(local_rank),  # model: your nn.Module
                                                      device_ids=[local_rank],
                                                      output_device=local_rank)

Please ensure that the ``device_ids`` argument is set to the only GPU device id
that your code will be operating on. This is generally the local rank of the
process. In other words, ``device_ids`` needs to be ``[int(os.environ["LOCAL_RANK"])]``,
and ``output_device`` needs to be ``int(os.environ["LOCAL_RANK"])`` in order to use this
utility.


345*da0073e9SAndroid Build Coastguard Worker4. On failures or membership changes ALL surviving workers are killed immediately. Make sure to
346*da0073e9SAndroid Build Coastguard Worker   checkpoint your progress. The frequency of checkpoints should depend on your job's tolerance
347*da0073e9SAndroid Build Coastguard Worker   for lost work.
348*da0073e9SAndroid Build Coastguard Worker
349*da0073e9SAndroid Build Coastguard Worker5. This module only supports homogeneous ``LOCAL_WORLD_SIZE``. That is, it is assumed that all
350*da0073e9SAndroid Build Coastguard Worker   nodes run the same number of local workers (per role).
351*da0073e9SAndroid Build Coastguard Worker
352*da0073e9SAndroid Build Coastguard Worker6. ``RANK`` is NOT stable. Between restarts, the local workers on a node can be assigned a
353*da0073e9SAndroid Build Coastguard Worker   different range of ranks than before. NEVER hard code any assumptions about the stable-ness of
354*da0073e9SAndroid Build Coastguard Worker   ranks or some correlation between ``RANK`` and ``LOCAL_RANK``.
355*da0073e9SAndroid Build Coastguard Worker
356*da0073e9SAndroid Build Coastguard Worker7. When using elasticity (``min_size!=max_size``) DO NOT hard code assumptions about
357*da0073e9SAndroid Build Coastguard Worker   ``WORLD_SIZE`` as the world size can change as nodes are allowed to leave and join.
358*da0073e9SAndroid Build Coastguard Worker
359*da0073e9SAndroid Build Coastguard Worker8. It is recommended for your script to have the following structure:
360*da0073e9SAndroid Build Coastguard Worker
361*da0073e9SAndroid Build Coastguard Worker::
362*da0073e9SAndroid Build Coastguard Worker
363*da0073e9SAndroid Build Coastguard Worker  def main():
364*da0073e9SAndroid Build Coastguard Worker    load_checkpoint(checkpoint_path)
365*da0073e9SAndroid Build Coastguard Worker    initialize()
366*da0073e9SAndroid Build Coastguard Worker    train()
367*da0073e9SAndroid Build Coastguard Worker
368*da0073e9SAndroid Build Coastguard Worker  def train():
369*da0073e9SAndroid Build Coastguard Worker    for batch in iter(dataset):
370*da0073e9SAndroid Build Coastguard Worker      train_step(batch)
371*da0073e9SAndroid Build Coastguard Worker
372*da0073e9SAndroid Build Coastguard Worker      if should_checkpoint:
373*da0073e9SAndroid Build Coastguard Worker        save_checkpoint(checkpoint_path)
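
As an illustration only, the structure above can be fleshed out into a small runnable
sketch. It assumes that ``checkpoint_path`` lives on a filesystem shared by all nodes, and
it uses ``json`` in place of ``torch.save``/``torch.load`` purely to stay dependency-free.
``load_checkpoint``, ``save_checkpoint`` and ``train`` are hypothetical helpers, not
``torch.distributed`` APIs. ``RANK`` is re-read from the environment each time ``train``
starts instead of being cached, since it can change between restarts (see item 6).

::

  # Hypothetical helpers for illustration; assumes checkpoint_path is on a
  # filesystem that every node can read and write.
  import json
  import os
  import tempfile


  def load_checkpoint(path):
      # Resume from the last saved state, or start fresh if none exists yet.
      if not os.path.exists(path):
          return {"step": 0}
      with open(path) as f:
          return json.load(f)


  def save_checkpoint(path, state):
      # Write to a temporary file in the same directory, then atomically swap it
      # in, so a worker killed mid-save leaves the previous checkpoint intact.
      fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(os.path.abspath(path)))
      with os.fdopen(fd, "w") as f:
          json.dump(state, f)
      os.replace(tmp_path, path)


  def train(checkpoint_path, num_steps, checkpoint_every=10):
      # RANK can change between restarts, so look it up each time train() starts.
      rank = int(os.environ.get("RANK", "0"))
      state = load_checkpoint(checkpoint_path)
      for step in range(state["step"], num_steps):
          # train_step(batch) would go here.
          state["step"] = step + 1
          # Some worker always holds rank 0, so exactly one worker writes.
          if rank == 0 and state["step"] % checkpoint_every == 0:
              save_checkpoint(checkpoint_path, state)
      return state

Because every worker resumes from the last saved ``state["step"]``, at most
``checkpoint_every`` steps are repeated after a restart; choose it according to your
job's tolerance for lost work (item 4).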

9. (Recommended) On worker errors, this tool will summarize the details of the error
   (e.g. time, rank, host, pid, traceback, etc.). On each node, the first error (by timestamp)
   is heuristically reported as the "Root Cause" error. To get tracebacks as part of this
   error summary printout, you must decorate your main entrypoint function in your
   training script as shown in the example below. If it is not decorated, the summary
   will not include the traceback of the exception and will only contain the exit code.
   For details on torchelastic error handling, see: https://pytorch.org/docs/stable/elastic/errors.html

::

  from torch.distributed.elastic.multiprocessing.errors import record

  @record
  def main():
      # do train
      pass

  if __name__ == "__main__":
      main()

"""
import logging
import os
import sys
import uuid
from argparse import ArgumentParser, REMAINDER
from importlib import metadata
from typing import Callable, List, Optional, Set, Tuple, Type, Union

import torch
from torch.distributed.argparse_util import check_env, env
from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, LogsSpecs, Std
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.elastic.rendezvous.utils import _parse_rendezvous_config
from torch.distributed.elastic.utils import macros
from torch.distributed.elastic.utils.logging import get_logger
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
from torch.utils.backend_registration import _get_custom_mod_func


logger = get_logger(__name__)


def get_args_parser() -> ArgumentParser:
    """Create the parser for the command line options."""
    parser = ArgumentParser(description="Torch Distributed Elastic Training Launcher")

    #
    # Worker/node size related arguments.
    #

    parser.add_argument(
        "--nnodes",
        action=env,
        type=str,
        default="1:1",
        help="Number of nodes, or the range of nodes in the form <minimum_nodes>:<maximum_nodes>.",
    )
    parser.add_argument(
        "--nproc-per-node",
        "--nproc_per_node",
        action=env,
        type=str,
        default="1",
        help="Number of workers per node; supported values: [auto, cpu, gpu, int].",
    )

    #
    # Rendezvous related arguments
    #

    parser.add_argument(
        "--rdzv-backend",
        "--rdzv_backend",
        action=env,
        type=str,
        default="static",
        help="Rendezvous backend.",
    )
    parser.add_argument(
        "--rdzv-endpoint",
        "--rdzv_endpoint",
        action=env,
        type=str,
        default="",
        help="Rendezvous backend endpoint; usually in the form <host>:<port>.",
    )
    parser.add_argument(
        "--rdzv-id",
        "--rdzv_id",
        action=env,
        type=str,
        default="none",
        help="User-defined group id.",
    )
    parser.add_argument(
        "--rdzv-conf",
        "--rdzv_conf",
        action=env,
        type=str,
        default="",
        help="Additional rendezvous configuration (<key1>=<value1>,<key2>=<value2>,...).",
    )
    parser.add_argument(
        "--standalone",
        action=check_env,
        help="Start a local standalone rendezvous backend that is represented by a C10d TCP store "
        "on a free port. Useful when launching a single-node, multi-worker job. If specified, "
        "--rdzv-backend, --rdzv-endpoint and --rdzv-id are auto-assigned and any explicitly set "
        "values are ignored.",
    )

    #
    # User-code launch related arguments.
    #

    parser.add_argument(
        "--max-restarts",
        "--max_restarts",
        action=env,
        type=int,
        default=0,
        help="Maximum number of worker group restarts before failing.",
    )
    parser.add_argument(
        "--monitor-interval",
        "--monitor_interval",
        action=env,
        type=float,
        default=0.1,
        help="Interval, in seconds, to monitor the state of workers.",
    )
    parser.add_argument(
        "--start-method",
        "--start_method",
        action=env,
        type=str,
        default="spawn",
        choices=["spawn", "fork", "forkserver"],
        help="Multiprocessing start method to use when creating workers.",
    )
    parser.add_argument(
        "--role",
        action=env,
        type=str,
        default="default",
        help="User-defined role for the workers.",
    )
    parser.add_argument(
        "-m",
        "--module",
        action=check_env,
        help="Change each process to interpret the launch script as a Python module, executing "
        "with the same behavior as 'python -m'.",
    )
    parser.add_argument(
        "--no-python",
        "--no_python",
        action=check_env,
        help="Skip prepending the training script with 'python'; just execute it directly. Useful "
        "when the script is not a Python script.",
    )

    parser.add_argument(
        "--run-path",
        "--run_path",
        action=check_env,
        help="Run the training script with runpy.run_path in the same interpreter."
        " Script must be provided as an absolute path (e.g. /abs/path/script.py)."
        " Takes precedence over --no-python.",
    )
    parser.add_argument(
        "--log-dir",
        "--log_dir",
        action=env,
        type=str,
        default=None,
        help="Base directory to use for log files (e.g. /var/log/torch/elastic). The same "
        "directory is re-used for multiple runs (a unique job-level sub-directory is created with "
        "rdzv_id as the prefix).",
    )
    parser.add_argument(
        "-r",
        "--redirects",
        action=env,
        type=str,
        default="0",
        help="Redirect std streams into a log file in the log directory (e.g. [-r 3] redirects "
        "both stdout+stderr for all workers, [-r 0:1,1:2] redirects stdout for local rank 0 and "
        "stderr for local rank 1).",
    )
    parser.add_argument(
        "-t",
        "--tee",
        action=env,
        type=str,
        default="0",
        help="Tee std streams into a log file and also to console (see --redirects for format).",
    )

    parser.add_argument(
        "--local-ranks-filter",
        "--local_ranks_filter",
        action=env,
        type=str,
        default="",
        help="Only show logs from specified ranks in console (e.g. [--local_ranks_filter=0,1,2] will "
        "only show logs from ranks 0, 1 and 2). This only applies to stdout and stderr, not to "
        "log files saved via --redirects or --tee.",
    )

    #
    # Parameters kept for backwards compatibility with caffe2.distributed.launch.
    #

    parser.add_argument(
        "--node-rank",
        "--node_rank",
        type=int,
        action=env,
        default=0,
        help="Rank of the node for multi-node distributed training.",
    )
    parser.add_argument(
        "--master-addr",
        "--master_addr",
        default="127.0.0.1",
        type=str,
        action=env,
        help="Address of the master node (rank 0); only used for static rendezvous. It should "
        "be either the IP address or the hostname of rank 0. For single-node multi-proc training, "
        "--master-addr can simply be 127.0.0.1; IPv6 should have the pattern "
        "`[0:0:0:0:0:0:0:1]`.",
    )
    parser.add_argument(
        "--master-port",
        "--master_port",
        default=29500,
        type=int,
        action=env,
        help="Port on the master node (rank 0) to be used for communication during distributed "
        "training. It is only used for static rendezvous.",
    )
    parser.add_argument(
        "--local-addr",
        "--local_addr",
        default=None,
        type=str,
        action=env,
        help="Address of the local node. If specified, will use the given address for connection. "
        "Else, will look up the local node address instead. Else, it will default to the local "
        "machine's FQDN.",
    )

    parser.add_argument(
        "--logs-specs",
        "--logs_specs",
        default=None,
        type=str,
        help="torchrun.logs_specs group entrypoint name; the entrypoint must resolve to a "
        "LogsSpecs subclass. Can be used to override the default logging behavior.",
    )

    #
    # Positional arguments.
    #

    parser.add_argument(
        "training_script",
        type=str,
        help="Full path to the (single GPU) training program/script to be launched in parallel, "
        "followed by all the arguments for the training script.",
    )

    # All remaining arguments are passed through to the training script.
    parser.add_argument("training_script_args", nargs=REMAINDER)

    return parser


def parse_args(args):
    parser = get_args_parser()
    return parser.parse_args(args)


def parse_min_max_nnodes(nnodes: str):
    arr = nnodes.split(":")

    if len(arr) == 1:
        min_nodes = max_nodes = int(arr[0])
    elif len(arr) == 2:
        min_nodes = int(arr[0])
        max_nodes = int(arr[1])
    else:
        raise RuntimeError(f'nnodes={nnodes} is not in "MIN:MAX" format')  # noqa: E231

    return min_nodes, max_nodes


def determine_local_world_size(nproc_per_node: str):
    try:
        logging.info("Using nproc_per_node=%s.", nproc_per_node)
        return int(nproc_per_node)
    except ValueError as e:
        if nproc_per_node == "cpu":
            num_proc = os.cpu_count()
            device_type = "cpu"
        elif nproc_per_node == "gpu":
            if not torch.cuda.is_available():
                raise ValueError("CUDA is not available.") from e
            device_type = "gpu"
            num_proc = torch.cuda.device_count()
        elif nproc_per_node == torch._C._get_privateuse1_backend_name():
            if not _get_custom_mod_func("is_available")():
                raise ValueError(f"{nproc_per_node} is not available.") from e
            device_type = nproc_per_node
            num_proc = _get_custom_mod_func("device_count")()
        elif nproc_per_node == "auto":
            if torch.cuda.is_available():
                num_proc = torch.cuda.device_count()
                device_type = "gpu"
            elif (
                hasattr(torch, torch._C._get_privateuse1_backend_name())
                and _get_custom_mod_func("is_available")()
            ):
                num_proc = _get_custom_mod_func("device_count")()
                device_type = torch._C._get_privateuse1_backend_name()
            else:
                num_proc = os.cpu_count()
                device_type = "cpu"
        else:
            raise ValueError(
                f"Unsupported nproc_per_node value: {nproc_per_node}"
            ) from e

        logger.info(
            "Using nproc_per_node=%s, setting nproc_per_node to %s since the instance has %s %s",
            nproc_per_node,
            num_proc,
            num_proc,
            device_type,
        )
        return num_proc


def get_rdzv_endpoint(args):
    if args.rdzv_backend == "static" and not args.rdzv_endpoint:
        return f"{args.master_addr}:{args.master_port}"  # noqa: E231
    return args.rdzv_endpoint


def get_use_env(args) -> bool:
    """
    Retrieve ``use_env`` from the args.

    ``use_env`` is a legacy argument. If ``use_env`` is False, the
    ``--local-rank`` argument will be passed to all worker processes.
    ``use_env`` is only used by ``torch.distributed.launch`` and will
    be deprecated in future releases.
    """
    if not hasattr(args, "use_env"):
        return True
    return args.use_env


def _get_logs_specs_class(logs_specs_name: Optional[str]) -> Type[LogsSpecs]:
    """
    Attempts to load the `torchrun.logs_specs` entrypoint whose name matches the `logs_specs_name` param.
    Provides a plugin mechanism for supplying a custom implementation of LogsSpecs.

    Returns `DefaultLogsSpecs` when `logs_specs_name` is None.
    Raises ValueError when no entrypoint for `logs_specs_name` can be found in the entrypoints.
    """
    logs_specs_cls = None
    if logs_specs_name is not None:
        eps = metadata.entry_points()
        if hasattr(eps, "select"):  # >= 3.10
            group = eps.select(group="torchrun.logs_specs")
            if group.select(name=logs_specs_name):
                logs_specs_cls = group[logs_specs_name].load()

        elif specs := eps.get("torchrun.logs_specs"):  # < 3.10
            if entrypoint_list := [ep for ep in specs if ep.name == logs_specs_name]:
                logs_specs_cls = entrypoint_list[0].load()

        if logs_specs_cls is None:
            raise ValueError(
                f"Could not find entrypoint under 'torchrun.logs_specs[{logs_specs_name}]' key"
            )

        logging.info(
            "Using logs_spec '%s' mapped to %s", logs_specs_name, str(logs_specs_cls)
        )
    else:
        logs_specs_cls = DefaultLogsSpecs

    return logs_specs_cls


def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str]]:
    # If ``args`` is not passed, it defaults to ``sys.argv[1:]``
776*da0073e9SAndroid Build Coastguard Worker    min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes)
777*da0073e9SAndroid Build Coastguard Worker    assert 0 < min_nodes <= max_nodes
778*da0073e9SAndroid Build Coastguard Worker    assert args.max_restarts >= 0

    if (
        hasattr(args, "master_addr")
        and args.rdzv_backend != "static"
        and not args.rdzv_endpoint
    ):
        logger.warning(
            "master_addr is only used for static rdzv_backend and when rdzv_endpoint "
            "is not specified."
        )

    nproc_per_node = determine_local_world_size(args.nproc_per_node)
    if "OMP_NUM_THREADS" not in os.environ and nproc_per_node > 1:
        omp_num_threads = 1
        logger.warning(
            "\n*****************************************\n"
            "Setting OMP_NUM_THREADS environment variable for each process to "
            "%s by default to avoid overloading your system. Please further "
            "tune the variable for optimal performance in your application "
            "as needed.\n"
            "*****************************************",
            omp_num_threads,
        )
        # This env variable will be passed down to the subprocesses
        os.environ["OMP_NUM_THREADS"] = str(omp_num_threads)

    log_line_prefix_template = os.getenv("TORCHELASTIC_LOG_LINE_PREFIX_TEMPLATE")

    rdzv_configs = _parse_rendezvous_config(args.rdzv_conf)

    if args.rdzv_backend == "static":
        rdzv_configs["rank"] = args.node_rank

    rdzv_endpoint = get_rdzv_endpoint(args)

    ranks: Optional[Set[int]] = None
    if args.local_ranks_filter:
        try:
            ranks = set(map(int, args.local_ranks_filter.split(",")))
            assert ranks
        except Exception as e:
            raise ValueError(
                "--local_ranks_filter must be a comma-separated list of integers e.g. --local_ranks_filter=0,1,2"
            ) from e

    logs_specs_cls: Type[LogsSpecs] = _get_logs_specs_class(args.logs_specs)
    logs_specs = logs_specs_cls(
        log_dir=args.log_dir,
        redirects=Std.from_str(args.redirects),
        tee=Std.from_str(args.tee),
        local_ranks_filter=ranks,
    )

    config = LaunchConfig(
        min_nodes=min_nodes,
        max_nodes=max_nodes,
        nproc_per_node=nproc_per_node,
        run_id=args.rdzv_id,
        role=args.role,
        rdzv_endpoint=rdzv_endpoint,
        rdzv_backend=args.rdzv_backend,
        rdzv_configs=rdzv_configs,
        max_restarts=args.max_restarts,
        monitor_interval=args.monitor_interval,
        start_method=args.start_method,
        log_line_prefix_template=log_line_prefix_template,
        local_addr=args.local_addr,
        logs_specs=logs_specs,
    )

    with_python = not args.no_python
    cmd: Union[Callable, str]
    cmd_args = []
    use_env = get_use_env(args)
    if args.run_path:
        cmd = run_script_path
        cmd_args.append(args.training_script)
    else:
        if with_python:
            cmd = os.getenv("PYTHON_EXEC", sys.executable)
            cmd_args.append("-u")
            if args.module:
                cmd_args.append("-m")
            cmd_args.append(args.training_script)
        else:
            if args.module:
                raise ValueError(
                    "Don't use both the '--no-python' flag"
                    " and the '--module' flag at the same time."
                )
            cmd = args.training_script
    if not use_env:
        cmd_args.append(f"--local-rank={macros.local_rank}")
    cmd_args.extend(args.training_script_args)

    return config, cmd, cmd_args


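# Illustrative sketch (hypothetical helper, not called by torchrun): the worker
# command line that ``config_from_args`` assembles when ``--run-path`` is not
# set, flattened into a single list of strings. Assumptions: ``python`` stands
# in for ``os.getenv("PYTHON_EXEC", sys.executable)``, and
# ``local_rank_placeholder`` stands in for ``macros.local_rank``, which is
# replaced with each worker's local rank before the worker is spawned.
def _example_worker_argv(
    training_script,
    training_script_args,
    *,
    module=False,
    no_python=False,
    use_env=True,
    python="python",
    local_rank_placeholder="${local_rank}",
):
    if no_python and module:
        raise ValueError(
            "Don't use both the '--no-python' flag"
            " and the '--module' flag at the same time."
        )
    if no_python:
        # The script itself is the executable, e.g. a shell script or binary.
        argv = [training_script]
    else:
        # Unbuffered interpreter, optionally running the target as a module.
        argv = [python, "-u"] + (["-m"] if module else []) + [training_script]
    if not use_env:
        # Without use_env, the local rank is also passed to the script as a flag.
        argv.append(f"--local-rank={local_rank_placeholder}")
    argv.extend(training_script_args)
    return argv

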
def run_script_path(training_script: str, *training_script_args: str):
    """
    Run the provided `training_script` from within this interpreter.

    Usage: `run_script_path("/abs/path/to/script.py", "--arg1", "val1")`
    """
    import runpy
    import sys

    sys.argv = [training_script, *training_script_args]
    runpy.run_path(sys.argv[0], run_name="__main__")


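# Illustrative sketch (hypothetical variant, not used by torchrun): the same
# ``runpy.run_path`` technique as ``run_script_path``, except that it restores
# the caller's ``sys.argv`` afterwards and returns the script's resulting
# globals. ``run_script_path`` leaves ``sys.argv`` replaced, which is typically
# harmless when each worker runs in its own process. Restoring it is only a
# design choice for callers that keep using the same interpreter.
def _example_run_path_with_argv(training_script, *training_script_args):
    import runpy
    import sys

    saved_argv = sys.argv
    sys.argv = [training_script, *training_script_args]
    try:
        # run_name="__main__" makes ``if __name__ == "__main__":`` blocks run.
        return runpy.run_path(training_script, run_name="__main__")
    finally:
        sys.argv = saved_argv

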
def run(args):
    torch.multiprocessing._set_thread_name("pt_elastic")

    if args.standalone:
        args.rdzv_backend = "c10d"
        args.rdzv_endpoint = "localhost:0"
        args.rdzv_id = str(uuid.uuid4())
        logger.info(
            "\n**************************************\n"
            "Rendezvous info:\n"
            "--rdzv-backend=%s "
            "--rdzv-endpoint=%s "
            "--rdzv-id=%s\n"
            "**************************************\n",
            args.rdzv_backend,
            args.rdzv_endpoint,
            args.rdzv_id,
        )

    config, cmd, cmd_args = config_from_args(args)
    elastic_launch(
        config=config,
        entrypoint=cmd,
    )(*cmd_args)


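# Illustrative sketch (hypothetical helper, not used by torchrun): the settings
# that the ``--standalone`` branch of ``run`` fills in, which are a ``c10d``
# rendezvous on ``localhost:0`` with a fresh ``uuid4`` run id. Port 0
# conventionally asks the operating system for any free port, so no port has
# to be chosen by hand. The socket bind below only demonstrates that
# convention; it is an assumption for illustration and is not how the c10d
# backend opens its store.
def _example_standalone_rdzv():
    import socket
    import uuid

    settings = {
        "rdzv_backend": "c10d",
        "rdzv_endpoint": "localhost:0",
        # A random id per launch, so separate launches get distinct run ids.
        "rdzv_id": str(uuid.uuid4()),
    }
    # Binding to port 0 lets the OS pick an unused ephemeral port.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        chosen_port = s.getsockname()[1]
    return settings, chosen_port

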
@record
def main(args=None):
    args = parse_args(args)
    run(args)


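# Illustrative sketch (hypothetical helper, not used by torchrun): why
# ``main(args=None)`` can be called with no arguments. It forwards ``args`` to
# ``parse_args``. argparse's ``ArgumentParser.parse_args(None)`` falls back to
# ``sys.argv[1:]``, so the console script parses the real command line while
# tests can still pass an explicit list. The two-argument parser below is an
# assumption made for illustration; it is far smaller than torchrun's parser.
def _example_parse(args=None):
    import argparse

    parser = argparse.ArgumentParser(prog="example_launcher")
    parser.add_argument("--nnodes", type=str, default="1:1")
    parser.add_argument("training_script", type=str)
    return parser.parse_args(args)

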
if __name__ == "__main__":
    main()