#!/usr/bin/env python3
# mypy: allow-untyped-defs

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

10"""
11Superset of ``torch.distributed.launch``.
12
13``torchrun`` provides a superset of the functionality as ``torch.distributed.launch``
14with the following additional functionalities:
15
161. Worker failures are handled gracefully by restarting all workers.
17
182. Worker ``RANK`` and ``WORLD_SIZE`` are assigned automatically.
19
203. Number of nodes is allowed to change between minimum and maximum sizes (elasticity).
21
22.. note:: ``torchrun`` is a python
23          `console script <https://packaging.python.org/en/latest/specifications/entry-points/#use-for-scripts>`_
24          to the main module
25          `torch.distributed.run <https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py>`_
26          declared in the ``entry_points`` configuration in
27          `setup.py <https://github.com/pytorch/pytorch/blob/master/setup.py>`_.
28          It is equivalent to invoking ``python -m torch.distributed.run``.
29
30
Transitioning from torch.distributed.launch to torchrun
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


``torchrun`` supports the same arguments as ``torch.distributed.launch`` **except**
for ``--use-env``, which is now deprecated. To migrate from ``torch.distributed.launch``
to ``torchrun``, follow these steps:

1.  If your training script already reads ``local_rank`` from the ``LOCAL_RANK`` environment variable,
    then simply omit the ``--use-env`` flag, e.g.:

    +--------------------------------------------------------------------+--------------------------------------------+
    |         ``torch.distributed.launch``                               |                ``torchrun``                |
    +====================================================================+============================================+
    |                                                                    |                                            |
    | .. code-block:: shell-session                                      | .. code-block:: shell-session              |
    |                                                                    |                                            |
    |    $ python -m torch.distributed.launch --use-env train_script.py  |    $ torchrun train_script.py              |
    |                                                                    |                                            |
    +--------------------------------------------------------------------+--------------------------------------------+

2.  If your training script reads the local rank from a ``--local-rank`` command-line argument,
    change it to read from the ``LOCAL_RANK`` environment variable instead, as
    demonstrated by the following code snippet:

    +-------------------------------------------------------+----------------------------------------------------+
    |         ``torch.distributed.launch``                  |                    ``torchrun``                    |
    +=======================================================+====================================================+
    |                                                       |                                                    |
    | .. code-block:: python                                | .. code-block:: python                             |
    |                                                       |                                                    |
    |                                                       |                                                    |
    |    import argparse                                    |     import os                                      |
    |    parser = argparse.ArgumentParser()                 |     local_rank = int(os.environ["LOCAL_RANK"])     |
    |    parser.add_argument("--local-rank", type=int)      |                                                    |
    |    args = parser.parse_args()                         |                                                    |
    |                                                       |                                                    |
    |    local_rank = args.local_rank                       |                                                    |
    |                                                       |                                                    |
    +-------------------------------------------------------+----------------------------------------------------+

.. versionchanged:: 2.0.0

    The launcher will pass the ``--local-rank=<rank>`` argument to your script.
    From PyTorch 2.0.0 onwards, the dashed ``--local-rank`` is preferred over the
    previously used underscored ``--local_rank``.

    For backward compatibility, it may be necessary for users to handle both
    cases in their argument parsing code. This means including both ``"--local-rank"``
    and ``"--local_rank"`` in the argument parser. If only ``"--local_rank"`` is
    provided, the launcher will trigger an error: "error: unrecognized arguments:
    --local-rank=<rank>". For training code that only supports PyTorch 2.0.0+,
    including ``"--local-rank"`` should be sufficient.

    ::

        >>> # xdoctest: +SKIP
        >>> import argparse
        >>> parser = argparse.ArgumentParser()
        >>> parser.add_argument("--local-rank", "--local_rank", type=int)
        >>> args = parser.parse_args()

The aforementioned changes suffice to migrate from ``torch.distributed.launch`` to ``torchrun``.
To take advantage of new ``torchrun`` features such as elasticity, fault tolerance, and error reporting,
please refer to:

* :ref:`elastic_train_script` for more information on authoring training scripts that are ``torchrun`` compliant.
* the rest of this page for more information on the features of ``torchrun``.


Usage
--------

Single-node multi-worker
++++++++++++++++++++++++++++++

::

    torchrun
        --standalone
        --nnodes=1
        --nproc-per-node=$NUM_TRAINERS
        YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)

Stacked single-node multi-worker
+++++++++++++++++++++++++++++++++++

To run multiple instances (separate jobs) of single-node, multi-worker training on the
same host, we need to make sure that each instance (job) is
set up on a different port to avoid port conflicts (or worse, two jobs being merged
into a single job). To do this, run with ``--rdzv-backend=c10d``
and specify a different port by setting ``--rdzv-endpoint=localhost:$PORT_k``.
For ``--nnodes=1``, it is often convenient to let ``torchrun`` pick a free random
port automatically instead of manually assigning a different port for each run.

::

    torchrun
        --rdzv-backend=c10d
        --rdzv-endpoint=localhost:0
        --nnodes=1
        --nproc-per-node=$NUM_TRAINERS
        YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)


Fault tolerant (fixed number of workers, no elasticity, tolerates 3 failures)
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

::

    torchrun
        --nnodes=$NUM_NODES
        --nproc-per-node=$NUM_TRAINERS
        --max-restarts=3
        --rdzv-id=$JOB_ID
        --rdzv-backend=c10d
        --rdzv-endpoint=$HOST_NODE_ADDR
        YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)

``HOST_NODE_ADDR``, in the form ``<host>[:<port>]`` (e.g. ``node1.example.com:29400``), specifies the node and
the port on which the C10d rendezvous backend should be instantiated and hosted. It can be any
node in your training cluster, but ideally you should pick a node that has high bandwidth.

.. note::
   If no port number is specified, ``HOST_NODE_ADDR`` defaults to 29400.

Elastic (``min=1``, ``max=4``, tolerates up to 3 membership changes or failures)
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

::

    torchrun
        --nnodes=1:4
        --nproc-per-node=$NUM_TRAINERS
        --max-restarts=3
        --rdzv-id=$JOB_ID
        --rdzv-backend=c10d
        --rdzv-endpoint=$HOST_NODE_ADDR
        YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)

``HOST_NODE_ADDR``, in the form ``<host>[:<port>]`` (e.g. ``node1.example.com:29400``), specifies the node and
the port on which the C10d rendezvous backend should be instantiated and hosted. It can be any
node in your training cluster, but ideally you should pick a node that has high bandwidth.

.. note::
   If no port number is specified, ``HOST_NODE_ADDR`` defaults to 29400.

Note on rendezvous backend
------------------------------

For multi-node training, you need to specify:

1. ``--rdzv-id``: A unique job id (shared by all nodes participating in the job)
2. ``--rdzv-backend``: An implementation of
   :py:class:`torch.distributed.elastic.rendezvous.RendezvousHandler`
3. ``--rdzv-endpoint``: The endpoint where the rendezvous backend is running; usually in form
   ``host:port``.

Currently ``c10d`` (recommended), ``etcd-v2``, and ``etcd`` (legacy) rendezvous backends are
supported out of the box. To use ``etcd-v2`` or ``etcd``, set up an etcd server with the v2 API
enabled (e.g. ``--enable-v2``).

.. warning::
   ``etcd-v2`` and ``etcd`` rendezvous use etcd API v2. You MUST enable the v2 API on the etcd
   server. Our tests use etcd v3.4.3.

.. warning::
   For etcd-based rendezvous we recommend using ``etcd-v2`` over ``etcd``, which is functionally
   equivalent but uses a revised implementation. ``etcd`` is in maintenance mode and will be
   removed in a future version.

Definitions
--------------

1. ``Node`` - A physical instance or a container; maps to the unit that the job manager works with.

2. ``Worker`` - A worker in the context of distributed training.

3. ``WorkerGroup`` - The set of workers that execute the same function (e.g. trainers).

4. ``LocalWorkerGroup`` - A subset of the workers in the worker group running on the same node.

5. ``RANK`` - The rank of the worker within a worker group.

6. ``WORLD_SIZE`` - The total number of workers in a worker group.

7. ``LOCAL_RANK`` - The rank of the worker within a local worker group.

8. ``LOCAL_WORLD_SIZE`` - The size of the local worker group.

9. ``rdzv_id`` - A user-defined id that uniquely identifies the worker group for a job. This id is
   used by each node to join as a member of a particular worker group.

10. ``rdzv_backend`` - The backend of the rendezvous (e.g. ``c10d``). This is typically a strongly
    consistent key-value store.

11. ``rdzv_endpoint`` - The rendezvous backend endpoint; usually in form ``<host>:<port>``.

A ``Node`` runs ``LOCAL_WORLD_SIZE`` workers which comprise a ``LocalWorkerGroup``. The union of
all ``LocalWorkerGroups`` in the nodes in the job comprises the ``WorkerGroup``.
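
For example, a hypothetical job launched with ``--nnodes=2`` and ``--nproc-per-node=4`` has
``WORLD_SIZE=8`` and ``LOCAL_WORLD_SIZE=4``; one possible assignment of the remaining values
(ranks are not guaranteed to be stable across restarts, see the notices below) is:

::

    # node with GROUP_RANK=0            # node with GROUP_RANK=1
    #   LOCAL_RANK: 0 1 2 3             #   LOCAL_RANK: 0 1 2 3
    #   RANK:       0 1 2 3             #   RANK:       4 5 6 7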

Environment Variables
----------------------

The following environment variables are made available to you in your script:

1. ``LOCAL_RANK`` - The local rank.

2. ``RANK`` - The global rank.

3. ``GROUP_RANK`` - The rank of the worker group. A number between 0 and ``max_nnodes``. When
   running a single worker group per node, this is the rank of the node.

4. ``ROLE_RANK`` - The rank of the worker across all the workers that have the same role. The role
   of the worker is specified in the ``WorkerSpec``.

5. ``LOCAL_WORLD_SIZE`` - The local world size (e.g. number of workers running locally); equal to
   ``--nproc-per-node`` specified on ``torchrun``.

6. ``WORLD_SIZE`` - The world size (total number of workers in the job).

7. ``ROLE_WORLD_SIZE`` - The total number of workers that were launched with the same role specified
   in ``WorkerSpec``.

8. ``MASTER_ADDR`` - The FQDN of the host that is running the worker with rank 0; used to initialize
   the Torch Distributed backend.

9. ``MASTER_PORT`` - The port on the ``MASTER_ADDR`` that can be used to host the C10d TCP store.

10. ``TORCHELASTIC_RESTART_COUNT`` - The number of worker group restarts so far.

11. ``TORCHELASTIC_MAX_RESTARTS`` - The configured maximum number of restarts.

12. ``TORCHELASTIC_RUN_ID`` - Equal to the rendezvous ``run_id`` (e.g. unique job id).

13. ``PYTHON_EXEC`` - System executable override. If provided, the Python user script will
    use the value of ``PYTHON_EXEC`` as the executable. ``sys.executable`` is used by default.
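
For example, a training script can consume these variables directly (a minimal sketch; the
``gloo`` backend and the all-reduce are arbitrary choices for illustration):

::

    import os

    import torch
    import torch.distributed as dist

    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # MASTER_ADDR/MASTER_PORT are already set by torchrun, so no explicit
    # init_method or store needs to be passed here.
    dist.init_process_group(backend="gloo")

    t = torch.ones(1) * rank
    dist.all_reduce(t)          # sums the per-rank tensors across all workers
    print(f"rank {rank}/{world_size} (local rank {local_rank}): {t.item()}")

    dist.destroy_process_group()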

Deployment
------------

1. (Not needed for the C10d backend) Start the rendezvous backend server and get the endpoint (to be
   passed as ``--rdzv-endpoint`` to the launcher script).

2. Single-node multi-worker: Start the launcher on the host to start the agent process which
   creates and monitors a local worker group.

3. Multi-node multi-worker: Start the launcher with the same arguments on all the nodes
   participating in training.

When using a job/cluster manager, the entry point command for the multi-node job should be this
launcher.

Failure Modes
---------------

1. Worker failure: For a training job with ``n`` workers, if ``k <= n`` workers fail, all workers
   are stopped and restarted, up to ``max_restarts`` times.

2. Agent failure: An agent failure results in a local worker group failure. It is up to the job
   manager to fail the entire job (gang semantics) or attempt to replace the node. Both behaviors
   are supported by the agent.

3. Node failure: Same as agent failure.

Membership Changes
--------------------

1. Node departure (scale-down): The agent is notified of the departure, all existing workers are
   stopped, a new ``WorkerGroup`` is formed, and all workers are started with a new ``RANK`` and
   ``WORLD_SIZE``.

2. Node arrival (scale-up): The new node is admitted to the job, all existing workers are stopped,
   a new ``WorkerGroup`` is formed, and all workers are started with a new ``RANK`` and
   ``WORLD_SIZE``.

Important Notices
--------------------

1. This utility and multi-process distributed (single-node or
   multi-node) GPU training currently only achieve the best performance using
   the NCCL distributed backend. Thus, NCCL is the recommended backend to
   use for GPU training.

2. The environment variables necessary to initialize a Torch process group are provided to you by
   this module; there is no need to pass ``RANK`` manually. To initialize a process group in your
   training script, simply run:

::

 >>> # xdoctest: +SKIP("stub")
 >>> import torch.distributed as dist
 >>> dist.init_process_group(backend="gloo|nccl")

3. In your training program, you can either use regular distributed functions
   or use the :func:`torch.nn.parallel.DistributedDataParallel` module. If your
   training program uses GPUs for training and you would like to use
   the :func:`torch.nn.parallel.DistributedDataParallel` module,
   here is how to configure it.

::

    local_rank = int(os.environ["LOCAL_RANK"])
    model = torch.nn.parallel.DistributedDataParallel(model,
                                                      device_ids=[local_rank],
                                                      output_device=local_rank)

Please ensure that the ``device_ids`` argument is set to the only GPU device id
that your code will be operating on. This is generally the local rank of the
process. In other words, ``device_ids`` needs to be ``[int(os.environ["LOCAL_RANK"])]``
and ``output_device`` needs to be ``int(os.environ["LOCAL_RANK"])`` in order to use this
utility.
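
For instance, a minimal end-to-end sketch (assuming a CUDA/NCCL setup; ``MyModel`` is a
placeholder for your own model class) could look like:

::

    import os

    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    dist.init_process_group(backend="nccl")

    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    model = MyModel().to(f"cuda:{local_rank}")          # MyModel is a placeholder
    model = DDP(model, device_ids=[local_rank], output_device=local_rank)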


4. On failures or membership changes, ALL surviving workers are killed immediately. Make sure to
   checkpoint your progress. The frequency of checkpoints should depend on your job's tolerance
   for lost work.

5. This module only supports a homogeneous ``LOCAL_WORLD_SIZE``. That is, it is assumed that all
   nodes run the same number of local workers (per role).

6. ``RANK`` is NOT stable. Between restarts, the local workers on a node can be assigned a
   different range of ranks than before. NEVER hard code any assumptions about the stability of
   ranks or any correlation between ``RANK`` and ``LOCAL_RANK``.

7. When using elasticity (``min_size!=max_size``), DO NOT hard code assumptions about
   ``WORLD_SIZE``, as the world size can change as nodes are allowed to leave and join.
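
For example, instead of hard-coding the world size, read it at runtime (a brief sketch):

::

    import os

    import torch.distributed as dist

    world_size = int(os.environ["WORLD_SIZE"])   # available before init_process_group
    # or, once the process group has been initialized:
    world_size = dist.get_world_size()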

8. It is recommended for your script to have the following structure:

::

  def main():
    load_checkpoint(checkpoint_path)
    initialize()
    train()

  def train():
    for batch in iter(dataset):
      train_step(batch)

      if should_checkpoint:
        save_checkpoint(checkpoint_path)
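
The helpers above (``load_checkpoint``, ``save_checkpoint``, ``initialize``, ``train_step``,
``should_checkpoint``) are user-defined placeholders. As a rough sketch (assuming ``model`` and
``optimizer`` are defined elsewhere in the script), the checkpoint helpers could be implemented
with ``torch.save``/``torch.load``, e.g.:

::

    import os

    import torch

    def save_checkpoint(checkpoint_path):
        torch.save({"model": model.state_dict(), "optim": optimizer.state_dict()},
                   checkpoint_path)

    def load_checkpoint(checkpoint_path):
        if os.path.exists(checkpoint_path):   # first run: nothing to resume from
            state = torch.load(checkpoint_path, map_location="cpu")
            model.load_state_dict(state["model"])
            optimizer.load_state_dict(state["optim"])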

9. (Recommended) On worker errors, this tool will summarize the details of the error
   (e.g. time, rank, host, pid, traceback, etc.). On each node, the first error (by timestamp)
   is heuristically reported as the "Root Cause" error. To get tracebacks as part of this
   error summary printout, you must decorate the main entrypoint function in your
   training script as shown in the example below. If not decorated, then the summary
   will not include the traceback of the exception and will only contain the exitcode.
   For details on torchelastic error handling, see: https://pytorch.org/docs/stable/elastic/errors.html

::

  from torch.distributed.elastic.multiprocessing.errors import record

  @record
  def main():
      # do train
      pass

  if __name__ == "__main__":
      main()

"""
import logging
import os
import sys
import uuid
from argparse import ArgumentParser, REMAINDER
from importlib import metadata
from typing import Callable, List, Optional, Set, Tuple, Type, Union

import torch
from torch.distributed.argparse_util import check_env, env
from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, LogsSpecs, Std
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.elastic.rendezvous.utils import _parse_rendezvous_config
from torch.distributed.elastic.utils import macros
from torch.distributed.elastic.utils.logging import get_logger
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
from torch.utils.backend_registration import _get_custom_mod_func


logger = get_logger(__name__)


def get_args_parser() -> ArgumentParser:
    """Parse the command line options."""
    parser = ArgumentParser(description="Torch Distributed Elastic Training Launcher")

    #
    # Worker/node size related arguments.
    #

    parser.add_argument(
        "--nnodes",
        action=env,
        type=str,
        default="1:1",
        help="Number of nodes, or the range of nodes in form <minimum_nodes>:<maximum_nodes>.",
    )
    parser.add_argument(
        "--nproc-per-node",
        "--nproc_per_node",
        action=env,
        type=str,
        default="1",
        help="Number of workers per node; supported values: [auto, cpu, gpu, int].",
    )

    #
    # Rendezvous related arguments
    #

    parser.add_argument(
        "--rdzv-backend",
        "--rdzv_backend",
        action=env,
        type=str,
        default="static",
        help="Rendezvous backend.",
    )
    parser.add_argument(
        "--rdzv-endpoint",
        "--rdzv_endpoint",
        action=env,
        type=str,
        default="",
        help="Rendezvous backend endpoint; usually in form <host>:<port>.",
    )
    parser.add_argument(
        "--rdzv-id",
        "--rdzv_id",
        action=env,
        type=str,
        default="none",
        help="User-defined group id.",
    )
    parser.add_argument(
        "--rdzv-conf",
        "--rdzv_conf",
        action=env,
        type=str,
        default="",
        help="Additional rendezvous configuration (<key1>=<value1>,<key2>=<value2>,...).",
    )
    parser.add_argument(
        "--standalone",
        action=check_env,
        help="Start a local standalone rendezvous backend that is represented by a C10d TCP store "
        "on a free port. Useful when launching a single-node, multi-worker job. If specified, "
        "--rdzv-backend, --rdzv-endpoint, and --rdzv-id are auto-assigned and any explicitly set "
        "values are ignored.",
    )

    #
    # User-code launch related arguments.
    #

    parser.add_argument(
        "--max-restarts",
        "--max_restarts",
        action=env,
        type=int,
        default=0,
        help="Maximum number of worker group restarts before failing.",
    )
    parser.add_argument(
        "--monitor-interval",
        "--monitor_interval",
        action=env,
        type=float,
        default=0.1,
        help="Interval, in seconds, to monitor the state of workers.",
    )
    parser.add_argument(
        "--start-method",
        "--start_method",
        action=env,
        type=str,
        default="spawn",
        choices=["spawn", "fork", "forkserver"],
        help="Multiprocessing start method to use when creating workers.",
    )
    parser.add_argument(
        "--role",
        action=env,
        type=str,
        default="default",
        help="User-defined role for the workers.",
    )
    parser.add_argument(
        "-m",
        "--module",
        action=check_env,
        help="Change each process to interpret the launch script as a Python module, executing "
        "with the same behavior as 'python -m'.",
    )
    parser.add_argument(
        "--no-python",
        "--no_python",
        action=check_env,
        help="Skip prepending the training script with 'python' - just execute it directly. Useful "
        "when the script is not a Python script.",
    )

    parser.add_argument(
        "--run-path",
        "--run_path",
        action=check_env,
        help="Run the training script with runpy.run_path in the same interpreter."
        " Script must be provided as an abs path (e.g. /abs/path/script.py)."
        " Takes precedence over --no-python.",
    )
    parser.add_argument(
        "--log-dir",
        "--log_dir",
        action=env,
        type=str,
        default=None,
        help="Base directory to use for log files (e.g. /var/log/torch/elastic). The same "
        "directory is re-used for multiple runs (a unique job-level sub-directory is created with "
        "rdzv_id as the prefix).",
    )
    parser.add_argument(
        "-r",
        "--redirects",
        action=env,
        type=str,
        default="0",
        help="Redirect std streams into a log file in the log directory (e.g. [-r 3] redirects "
        "both stdout+stderr for all workers, [-r 0:1,1:2] redirects stdout for local rank 0 and "
        "stderr for local rank 1).",
    )
    parser.add_argument(
        "-t",
        "--tee",
        action=env,
        type=str,
        default="0",
        help="Tee std streams into a log file and also to console (see --redirects for format).",
    )

    parser.add_argument(
        "--local-ranks-filter",
        "--local_ranks_filter",
        action=env,
        type=str,
        default="",
        help="Only show logs from specified ranks in console (e.g. [--local_ranks_filter=0,1,2] will "
        "only show logs from ranks 0, 1 and 2). This will only apply to stdout and stderr, not to "
        "log files saved via --redirects or --tee.",
    )

    #
    # Parameters kept for backwards compatibility with caffe2.distributed.launch.
    #

    parser.add_argument(
        "--node-rank",
        "--node_rank",
        type=int,
        action=env,
        default=0,
        help="Rank of the node for multi-node distributed training.",
    )
    parser.add_argument(
        "--master-addr",
        "--master_addr",
        default="127.0.0.1",
        type=str,
        action=env,
        help="Address of the master node (rank 0) that is only used for static rendezvous. It should "
        "be either the IP address or the hostname of rank 0. For single-node multi-proc training "
        "the --master-addr can simply be 127.0.0.1; IPv6 should have the pattern "
        "`[0:0:0:0:0:0:0:1]`.",
    )
    parser.add_argument(
        "--master-port",
        "--master_port",
        default=29500,
        type=int,
        action=env,
        help="Port on the master node (rank 0) to be used for communication during distributed "
        "training. It is only used for static rendezvous.",
    )
    parser.add_argument(
        "--local-addr",
        "--local_addr",
        default=None,
        type=str,
        action=env,
        help="Address of the local node. If specified, the given address will be used for "
        "connection. Otherwise, the local node address will be looked up, defaulting to the "
        "local machine's FQDN.",
    )

    parser.add_argument(
        "--logs-specs",
        "--logs_specs",
        default=None,
        type=str,
        help="torchrun.logs_specs group entrypoint name; the value must be a subclass of LogsSpecs. "
        "Can be used to provide custom logging behavior.",
    )

    #
    # Positional arguments.
    #

    parser.add_argument(
        "training_script",
        type=str,
        help="Full path to the (single GPU) training program/script to be launched in parallel, "
        "followed by all the arguments for the training script.",
    )

    # Rest from the training program.
    parser.add_argument("training_script_args", nargs=REMAINDER)

    return parser


def parse_args(args):
    parser = get_args_parser()
    return parser.parse_args(args)


def parse_min_max_nnodes(nnodes: str):
    arr = nnodes.split(":")

    if len(arr) == 1:
        min_nodes = max_nodes = int(arr[0])
    elif len(arr) == 2:
        min_nodes = int(arr[0])
        max_nodes = int(arr[1])
    else:
        raise RuntimeError(f'nnodes={nnodes} is not in "MIN:MAX" format')  # noqa: E231

    return min_nodes, max_nodes


def determine_local_world_size(nproc_per_node: str):
    try:
        logging.info("Using nproc_per_node=%s.", nproc_per_node)
        return int(nproc_per_node)
    except ValueError as e:
        if nproc_per_node == "cpu":
            num_proc = os.cpu_count()
            device_type = "cpu"
        elif nproc_per_node == "gpu":
            if not torch.cuda.is_available():
                raise ValueError("Cuda is not available.") from e
            device_type = "gpu"
            num_proc = torch.cuda.device_count()
        elif nproc_per_node == torch._C._get_privateuse1_backend_name():
            if not _get_custom_mod_func("is_available")():
                raise ValueError(f"{nproc_per_node} is not available.") from e
            device_type = nproc_per_node
            num_proc = _get_custom_mod_func("device_count")()
        elif nproc_per_node == "auto":
            if torch.cuda.is_available():
                num_proc = torch.cuda.device_count()
                device_type = "gpu"
            elif (
                hasattr(torch, torch._C._get_privateuse1_backend_name())
                and _get_custom_mod_func("is_available")()
            ):
                num_proc = _get_custom_mod_func("device_count")()
                device_type = torch._C._get_privateuse1_backend_name()
            else:
                num_proc = os.cpu_count()
                device_type = "cpu"
        else:
            raise ValueError(
                f"Unsupported nproc_per_node value: {nproc_per_node}"
            ) from e

        logger.info(
            "Using nproc_per_node=%s, setting nproc_per_node to %s since the instance has %s %s",
            nproc_per_node,
            num_proc,
            num_proc,
            device_type,
        )
        return num_proc


def get_rdzv_endpoint(args):
    if args.rdzv_backend == "static" and not args.rdzv_endpoint:
        return f"{args.master_addr}:{args.master_port}"  # noqa: E231
    return args.rdzv_endpoint


def get_use_env(args) -> bool:
    """
    Retrieve ``use_env`` from the args.

    ``use_env`` is a legacy argument: if ``use_env`` is False, the
    ``--local-rank`` argument is passed to all worker processes.
    ``use_env`` is only used by ``torch.distributed.launch`` and will
    be deprecated in future releases.
    """
    if not hasattr(args, "use_env"):
        return True
    return args.use_env


def _get_logs_specs_class(logs_specs_name: Optional[str]) -> Type[LogsSpecs]:
    """
    Attempt to load the ``torchrun.logs_specs`` entrypoint with the key given by ``logs_specs_name``.
    Provides a plugin mechanism for custom implementations of LogsSpecs.

    Returns ``DefaultLogsSpecs`` when ``logs_specs_name`` is None.
    Raises ValueError when no entrypoint for ``logs_specs_name`` can be found.
    """
    logs_specs_cls = None
    if logs_specs_name is not None:
        eps = metadata.entry_points()
        if hasattr(eps, "select"):  # >= 3.10
            group = eps.select(group="torchrun.logs_specs")
            if group.select(name=logs_specs_name):
                logs_specs_cls = group[logs_specs_name].load()

        elif specs := eps.get("torchrun.logs_specs"):  # < 3.10
            if entrypoint_list := [ep for ep in specs if ep.name == logs_specs_name]:
                logs_specs_cls = entrypoint_list[0].load()

        if logs_specs_cls is None:
            raise ValueError(
                f"Could not find entrypoint under 'torchrun.logs_specs[{logs_specs_name}]' key"
            )

        logging.info(
            "Using logs_spec '%s' mapped to %s", logs_specs_name, str(logs_specs_cls)
        )
    else:
        logs_specs_cls = DefaultLogsSpecs

    return logs_specs_cls
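
# A custom LogsSpecs implementation can be registered through the "torchrun.logs_specs"
# entry-point group. A hypothetical third-party package could declare, e.g. in its pyproject.toml:
#
#     [project.entry-points."torchrun.logs_specs"]
#     my_specs = "my_pkg.logs:MyLogsSpecs"
#
# ("my_pkg.logs:MyLogsSpecs" is a placeholder), and users would then select it with
# `torchrun --logs-specs=my_specs ...`.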


def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str]]:
    # If ``args`` not passed, defaults to ``sys.argv[1:]``
    min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes)
    assert 0 < min_nodes <= max_nodes
    assert args.max_restarts >= 0

    if (
        hasattr(args, "master_addr")
        and args.rdzv_backend != "static"
        and not args.rdzv_endpoint
    ):
        logger.warning(
            "master_addr is only used for static rdzv_backend and when rdzv_endpoint "
            "is not specified."
        )

    nproc_per_node = determine_local_world_size(args.nproc_per_node)
    if "OMP_NUM_THREADS" not in os.environ and nproc_per_node > 1:
        omp_num_threads = 1
        logger.warning(
            "\n*****************************************\n"
            "Setting OMP_NUM_THREADS environment variable for each process to be "
            "%s by default, to avoid your system being overloaded. "
            "Please further tune the variable for optimal performance in "
            "your application as needed. \n"
            "*****************************************",
            omp_num_threads,
        )
        # This env variable will be passed down to the subprocesses
        os.environ["OMP_NUM_THREADS"] = str(omp_num_threads)

    log_line_prefix_template = os.getenv("TORCHELASTIC_LOG_LINE_PREFIX_TEMPLATE")

    rdzv_configs = _parse_rendezvous_config(args.rdzv_conf)

    if args.rdzv_backend == "static":
        rdzv_configs["rank"] = args.node_rank

    rdzv_endpoint = get_rdzv_endpoint(args)

    ranks: Optional[Set[int]] = None
    if args.local_ranks_filter:
        try:
            ranks = set(map(int, args.local_ranks_filter.split(",")))
            assert ranks
        except Exception as e:
            raise ValueError(
                "--local_ranks_filter must be a comma-separated list of integers e.g. --local_ranks_filter=0,1,2"
            ) from e

    logs_specs_cls: Type[LogsSpecs] = _get_logs_specs_class(args.logs_specs)
    logs_specs = logs_specs_cls(
        log_dir=args.log_dir,
        redirects=Std.from_str(args.redirects),
        tee=Std.from_str(args.tee),
        local_ranks_filter=ranks,
    )

    config = LaunchConfig(
        min_nodes=min_nodes,
        max_nodes=max_nodes,
        nproc_per_node=nproc_per_node,
        run_id=args.rdzv_id,
        role=args.role,
        rdzv_endpoint=rdzv_endpoint,
        rdzv_backend=args.rdzv_backend,
        rdzv_configs=rdzv_configs,
        max_restarts=args.max_restarts,
        monitor_interval=args.monitor_interval,
        start_method=args.start_method,
        log_line_prefix_template=log_line_prefix_template,
        local_addr=args.local_addr,
        logs_specs=logs_specs,
    )

    with_python = not args.no_python
    cmd: Union[Callable, str]
    cmd_args = []
    use_env = get_use_env(args)
    if args.run_path:
        cmd = run_script_path
        cmd_args.append(args.training_script)
    else:
        if with_python:
            cmd = os.getenv("PYTHON_EXEC", sys.executable)
            cmd_args.append("-u")
            if args.module:
                cmd_args.append("-m")
            cmd_args.append(args.training_script)
        else:
            if args.module:
                raise ValueError(
                    "Don't use both the '--no-python' flag"
                    " and the '--module' flag at the same time."
                )
            cmd = args.training_script
    if not use_env:
        cmd_args.append(f"--local-rank={macros.local_rank}")
    cmd_args.extend(args.training_script_args)

    return config, cmd, cmd_args


def run_script_path(training_script: str, *training_script_args: str):
    """
    Run the provided `training_script` from within this interpreter.

    Usage: `run_script_path("/abs/path/to/script.py", "--arg1", "val1")`
    """
    import runpy
    import sys

    sys.argv = [training_script] + [*training_script_args]
    runpy.run_path(sys.argv[0], run_name="__main__")


def run(args):
    torch.multiprocessing._set_thread_name("pt_elastic")

    if args.standalone:
        args.rdzv_backend = "c10d"
        args.rdzv_endpoint = "localhost:0"
        args.rdzv_id = str(uuid.uuid4())
        logger.info(
            "\n**************************************\n"
            "Rendezvous info:\n"
            "--rdzv-backend=%s "
            "--rdzv-endpoint=%s "
            "--rdzv-id=%s\n"
            "**************************************\n",
            args.rdzv_backend,
            args.rdzv_endpoint,
            args.rdzv_id,
        )

    config, cmd, cmd_args = config_from_args(args)
    elastic_launch(
        config=config,
        entrypoint=cmd,
    )(*cmd_args)


@record
def main(args=None):
    args = parse_args(args)
    run(args)


if __name__ == "__main__":
    main()
