1#!/usr/bin/env python3 2# mypy: allow-untyped-defs 3 4# Copyright (c) Facebook, Inc. and its affiliates. 5# All rights reserved. 6# 7# This source code is licensed under the BSD-style license found in the 8# LICENSE file in the root directory of this source tree. 9 10""" 11Superset of ``torch.distributed.launch``. 12 13``torchrun`` provides a superset of the functionality as ``torch.distributed.launch`` 14with the following additional functionalities: 15 161. Worker failures are handled gracefully by restarting all workers. 17 182. Worker ``RANK`` and ``WORLD_SIZE`` are assigned automatically. 19 203. Number of nodes is allowed to change between minimum and maximum sizes (elasticity). 21 22.. note:: ``torchrun`` is a python 23 `console script <https://packaging.python.org/en/latest/specifications/entry-points/#use-for-scripts>`_ 24 to the main module 25 `torch.distributed.run <https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py>`_ 26 declared in the ``entry_points`` configuration in 27 `setup.py <https://github.com/pytorch/pytorch/blob/master/setup.py>`_. 28 It is equivalent to invoking ``python -m torch.distributed.run``. 29 30 31Transitioning from torch.distributed.launch to torchrun 32~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 33 34 35``torchrun`` supports the same arguments as ``torch.distributed.launch`` **except** 36for ``--use-env`` which is now deprecated. To migrate from ``torch.distributed.launch`` 37to ``torchrun`` follow these steps: 38 391. If your training script is already reading ``local_rank`` from the ``LOCAL_RANK`` environment variable. 40 Then you need simply omit the ``--use-env`` flag, e.g.: 41 42 +--------------------------------------------------------------------+--------------------------------------------+ 43 | ``torch.distributed.launch`` | ``torchrun`` | 44 +====================================================================+============================================+ 45 | | | 46 | .. code-block:: shell-session | .. code-block:: shell-session | 47 | | | 48 | $ python -m torch.distributed.launch --use-env train_script.py | $ torchrun train_script.py | 49 | | | 50 +--------------------------------------------------------------------+--------------------------------------------+ 51 522. If your training script reads local rank from a ``--local-rank`` cmd argument. 53 Change your training script to read from the ``LOCAL_RANK`` environment variable as 54 demonstrated by the following code snippet: 55 56 +-------------------------------------------------------+----------------------------------------------------+ 57 | ``torch.distributed.launch`` | ``torchrun`` | 58 +=======================================================+====================================================+ 59 | | | 60 | .. code-block:: python | .. code-block:: python | 61 | | | 62 | | | 63 | import argparse | import os | 64 | parser = argparse.ArgumentParser() | local_rank = int(os.environ["LOCAL_RANK"]) | 65 | parser.add_argument("--local-rank", type=int) | | 66 | args = parser.parse_args() | | 67 | | | 68 | local_rank = args.local_rank | | 69 | | | 70 +-------------------------------------------------------+----------------------------------------------------+ 71 72.. versionchanged:: 2.0.0 73 74 The launcher will pass the ``--local-rank=<rank>`` argument to your script. 75 From PyTorch 2.0.0 onwards, the dashed ``--local-rank`` is preferred over the 76 previously used underscored ``--local_rank``. 77 78 For backward compatibility, it may be necessary for users to handle both 79 cases in their argument parsing code. This means including both ``"--local-rank"`` 80 and ``"--local_rank"`` in the argument parser. If only ``"--local_rank"`` is 81 provided, the launcher will trigger an error: "error: unrecognized arguments: 82 --local-rank=<rank>". For training code that only supports PyTorch 2.0.0+, 83 including ``"--local-rank"`` should be sufficient. 84 85 :: 86 87 >>> # xdoctest: +SKIP 88 >>> import argparse 89 >>> parser = argparse.ArgumentParser() 90 >>> parser.add_argument("--local-rank", "--local_rank", type=int) 91 >>> args = parser.parse_args() 92 93The aformentioned changes suffice to migrate from ``torch.distributed.launch`` to ``torchrun``. 94To take advantage of new features such as elasticity, fault-tolerance, and error reporting of ``torchrun`` 95please refer to: 96 97* :ref:`elastic_train_script` for more information on authoring training scripts that are ``torchrun`` compliant. 98* the rest of this page for more information on the features of ``torchrun``. 99 100 101Usage 102-------- 103 104Single-node multi-worker 105++++++++++++++++++++++++++++++ 106 107:: 108 109 torchrun 110 --standalone 111 --nnodes=1 112 --nproc-per-node=$NUM_TRAINERS 113 YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...) 114 115Stacked single-node multi-worker 116+++++++++++++++++++++++++++++++++++ 117 118To run multiple instances (separate jobs) of single-node, multi-worker on the 119same host, we need to make sure that each instance (job) is 120setup on different ports to avoid port conflicts (or worse, two jobs being merged 121as a single job). To do this you have to run with ``--rdzv-backend=c10d`` 122and specify a different port by setting ``--rdzv-endpoint=localhost:$PORT_k``. 123For ``--nodes=1``, its often convenient to let ``torchrun`` pick a free random 124port automatically instead of manually assigning different ports for each run. 125 126:: 127 128 torchrun 129 --rdzv-backend=c10d 130 --rdzv-endpoint=localhost:0 131 --nnodes=1 132 --nproc-per-node=$NUM_TRAINERS 133 YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...) 134 135 136Fault tolerant (fixed sized number of workers, no elasticity, tolerates 3 failures) 137++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 138 139:: 140 141 torchrun 142 --nnodes=$NUM_NODES 143 --nproc-per-node=$NUM_TRAINERS 144 --max-restarts=3 145 --rdzv-id=$JOB_ID 146 --rdzv-backend=c10d 147 --rdzv-endpoint=$HOST_NODE_ADDR 148 YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...) 149 150``HOST_NODE_ADDR``, in form <host>[:<port>] (e.g. node1.example.com:29400), specifies the node and 151the port on which the C10d rendezvous backend should be instantiated and hosted. It can be any 152node in your training cluster, but ideally you should pick a node that has a high bandwidth. 153 154.. note:: 155 If no port number is specified ``HOST_NODE_ADDR`` defaults to 29400. 156 157Elastic (``min=1``, ``max=4``, tolerates up to 3 membership changes or failures) 158+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 159 160:: 161 162 torchrun 163 --nnodes=1:4 164 --nproc-per-node=$NUM_TRAINERS 165 --max-restarts=3 166 --rdzv-id=$JOB_ID 167 --rdzv-backend=c10d 168 --rdzv-endpoint=$HOST_NODE_ADDR 169 YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...) 170 171``HOST_NODE_ADDR``, in form <host>[:<port>] (e.g. node1.example.com:29400), specifies the node and 172the port on which the C10d rendezvous backend should be instantiated and hosted. It can be any 173node in your training cluster, but ideally you should pick a node that has a high bandwidth. 174 175.. note:: 176 If no port number is specified ``HOST_NODE_ADDR`` defaults to 29400. 177 178Note on rendezvous backend 179------------------------------ 180 181For multi-node training you need to specify: 182 1831. ``--rdzv-id``: A unique job id (shared by all nodes participating in the job) 1842. ``--rdzv-backend``: An implementation of 185 :py:class:`torch.distributed.elastic.rendezvous.RendezvousHandler` 1863. ``--rdzv-endpoint``: The endpoint where the rendezvous backend is running; usually in form 187 ``host:port``. 188 189Currently ``c10d`` (recommended), ``etcd-v2``, and ``etcd`` (legacy) rendezvous backends are 190supported out of the box. To use ``etcd-v2`` or ``etcd``, setup an etcd server with the ``v2`` api 191enabled (e.g. ``--enable-v2``). 192 193.. warning:: 194 ``etcd-v2`` and ``etcd`` rendezvous use etcd API v2. You MUST enable the v2 API on the etcd 195 server. Our tests use etcd v3.4.3. 196 197.. warning:: 198 For etcd-based rendezvous we recommend using ``etcd-v2`` over ``etcd`` which is functionally 199 equivalent, but uses a revised implementation. ``etcd`` is in maintenance mode and will be 200 removed in a future version. 201 202Definitions 203-------------- 204 2051. ``Node`` - A physical instance or a container; maps to the unit that the job manager works with. 206 2072. ``Worker`` - A worker in the context of distributed training. 208 2093. ``WorkerGroup`` - The set of workers that execute the same function (e.g. trainers). 210 2114. ``LocalWorkerGroup`` - A subset of the workers in the worker group running on the same node. 212 2135. ``RANK`` - The rank of the worker within a worker group. 214 2156. ``WORLD_SIZE`` - The total number of workers in a worker group. 216 2177. ``LOCAL_RANK`` - The rank of the worker within a local worker group. 218 2198. ``LOCAL_WORLD_SIZE`` - The size of the local worker group. 220 2219. ``rdzv_id`` - A user-defined id that uniquely identifies the worker group for a job. This id is 222 used by each node to join as a member of a particular worker group. 223 2249. ``rdzv_backend`` - The backend of the rendezvous (e.g. ``c10d``). This is typically a strongly 225 consistent key-value store. 226 22710. ``rdzv_endpoint`` - The rendezvous backend endpoint; usually in form ``<host>:<port>``. 228 229A ``Node`` runs ``LOCAL_WORLD_SIZE`` workers which comprise a ``LocalWorkerGroup``. The union of 230all ``LocalWorkerGroups`` in the nodes in the job comprise the ``WorkerGroup``. 231 232Environment Variables 233---------------------- 234 235The following environment variables are made available to you in your script: 236 2371. ``LOCAL_RANK`` - The local rank. 238 2392. ``RANK`` - The global rank. 240 2413. ``GROUP_RANK`` - The rank of the worker group. A number between 0 and ``max_nnodes``. When 242 running a single worker group per node, this is the rank of the node. 243 2444. ``ROLE_RANK`` - The rank of the worker across all the workers that have the same role. The role 245 of the worker is specified in the ``WorkerSpec``. 246 2475. ``LOCAL_WORLD_SIZE`` - The local world size (e.g. number of workers running locally); equals to 248 ``--nproc-per-node`` specified on ``torchrun``. 249 2506. ``WORLD_SIZE`` - The world size (total number of workers in the job). 251 2527. ``ROLE_WORLD_SIZE`` - The total number of workers that was launched with the same role specified 253 in ``WorkerSpec``. 254 2558. ``MASTER_ADDR`` - The FQDN of the host that is running worker with rank 0; used to initialize 256 the Torch Distributed backend. 257 2589. ``MASTER_PORT`` - The port on the ``MASTER_ADDR`` that can be used to host the C10d TCP store. 259 26010. ``TORCHELASTIC_RESTART_COUNT`` - The number of worker group restarts so far. 261 26211. ``TORCHELASTIC_MAX_RESTARTS`` - The configured maximum number of restarts. 263 26412. ``TORCHELASTIC_RUN_ID`` - Equal to the rendezvous ``run_id`` (e.g. unique job id). 265 26613. ``PYTHON_EXEC`` - System executable override. If provided, the python user script will 267 use the value of ``PYTHON_EXEC`` as executable. The `sys.executable` is used by default. 268 269Deployment 270------------ 271 2721. (Not needed for the C10d backend) Start the rendezvous backend server and get the endpoint (to be 273 passed as ``--rdzv-endpoint`` to the launcher script) 274 2752. Single-node multi-worker: Start the launcher on the host to start the agent process which 276 creates and monitors a local worker group. 277 2783. Multi-node multi-worker: Start the launcher with the same arguments on all the nodes 279 participating in training. 280 281When using a job/cluster manager the entry point command to the multi-node job should be this 282launcher. 283 284Failure Modes 285--------------- 286 2871. Worker failure: For a training job with ``n`` workers, if ``k<=n`` workers fail all workers 288 are stopped and restarted up to ``max_restarts``. 289 2902. Agent failure: An agent failure results in a local worker group failure. It is up to the job 291 manager to fail the entire job (gang semantics) or attempt to replace the node. Both behaviors 292 are supported by the agent. 293 2943. Node failure: Same as agent failure. 295 296Membership Changes 297-------------------- 298 2991. Node departure (scale-down): The agent is notified of the departure, all existing workers are 300 stopped, a new ``WorkerGroup`` is formed, and all workers are started with a new ``RANK`` and 301 ``WORLD_SIZE``. 302 3032. Node arrival (scale-up): The new node is admitted to the job, all existing workers are stopped, 304 a new ``WorkerGroup`` is formed, and all workers are started with a new ``RANK`` and 305 ``WORLD_SIZE``. 306 307Important Notices 308-------------------- 309 3101. This utility and multi-process distributed (single-node or 311 multi-node) GPU training currently only achieves the best performance using 312 the NCCL distributed backend. Thus NCCL backend is the recommended backend to 313 use for GPU training. 314 3152. The environment variables necessary to initialize a Torch process group are provided to you by 316 this module, no need for you to pass ``RANK`` manually. To initialize a process group in your 317 training script, simply run: 318 319:: 320 321 >>> # xdoctest: +SKIP("stub") 322 >>> import torch.distributed as dist 323 >>> dist.init_process_group(backend="gloo|nccl") 324 3253. In your training program, you can either use regular distributed functions 326 or use :func:`torch.nn.parallel.DistributedDataParallel` module. If your 327 training program uses GPUs for training and you would like to use 328 :func:`torch.nn.parallel.DistributedDataParallel` module, 329 here is how to configure it. 330 331:: 332 333 local_rank = int(os.environ["LOCAL_RANK"]) 334 model = torch.nn.parallel.DistributedDataParallel(model, 335 device_ids=[local_rank], 336 output_device=local_rank) 337 338Please ensure that ``device_ids`` argument is set to be the only GPU device id 339that your code will be operating on. This is generally the local rank of the 340process. In other words, the ``device_ids`` needs to be ``[int(os.environ("LOCAL_RANK"))]``, 341and ``output_device`` needs to be ``int(os.environ("LOCAL_RANK"))`` in order to use this 342utility 343 344 3454. On failures or membership changes ALL surviving workers are killed immediately. Make sure to 346 checkpoint your progress. The frequency of checkpoints should depend on your job's tolerance 347 for lost work. 348 3495. This module only supports homogeneous ``LOCAL_WORLD_SIZE``. That is, it is assumed that all 350 nodes run the same number of local workers (per role). 351 3526. ``RANK`` is NOT stable. Between restarts, the local workers on a node can be assigned a 353 different range of ranks than before. NEVER hard code any assumptions about the stable-ness of 354 ranks or some correlation between ``RANK`` and ``LOCAL_RANK``. 355 3567. When using elasticity (``min_size!=max_size``) DO NOT hard code assumptions about 357 ``WORLD_SIZE`` as the world size can change as nodes are allowed to leave and join. 358 3598. It is recommended for your script to have the following structure: 360 361:: 362 363 def main(): 364 load_checkpoint(checkpoint_path) 365 initialize() 366 train() 367 368 def train(): 369 for batch in iter(dataset): 370 train_step(batch) 371 372 if should_checkpoint: 373 save_checkpoint(checkpoint_path) 374 3759. (Recommended) On worker errors, this tool will summarize the details of the error 376 (e.g. time, rank, host, pid, traceback, etc). On each node, the first error (by timestamp) 377 is heuristically reported as the "Root Cause" error. To get tracebacks as part of this 378 error summary print out, you must decorate your main entrypoint function in your 379 training script as shown in the example below. If not decorated, then the summary 380 will not include the traceback of the exception and will only contain the exitcode. 381 For details on torchelastic error handling see: https://pytorch.org/docs/stable/elastic/errors.html 382 383:: 384 385 from torch.distributed.elastic.multiprocessing.errors import record 386 387 @record 388 def main(): 389 # do train 390 pass 391 392 if __name__ == "__main__": 393 main() 394 395""" 396import logging 397import os 398import sys 399import uuid 400from argparse import ArgumentParser, REMAINDER 401from importlib import metadata 402from typing import Callable, List, Optional, Set, Tuple, Type, Union 403 404import torch 405from torch.distributed.argparse_util import check_env, env 406from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, LogsSpecs, Std 407from torch.distributed.elastic.multiprocessing.errors import record 408from torch.distributed.elastic.rendezvous.utils import _parse_rendezvous_config 409from torch.distributed.elastic.utils import macros 410from torch.distributed.elastic.utils.logging import get_logger 411from torch.distributed.launcher.api import elastic_launch, LaunchConfig 412from torch.utils.backend_registration import _get_custom_mod_func 413 414 415logger = get_logger(__name__) 416 417 418def get_args_parser() -> ArgumentParser: 419 """Parse the command line options.""" 420 parser = ArgumentParser(description="Torch Distributed Elastic Training Launcher") 421 422 # 423 # Worker/node size related arguments. 424 # 425 426 parser.add_argument( 427 "--nnodes", 428 action=env, 429 type=str, 430 default="1:1", 431 help="Number of nodes, or the range of nodes in form <minimum_nodes>:<maximum_nodes>.", 432 ) 433 parser.add_argument( 434 "--nproc-per-node", 435 "--nproc_per_node", 436 action=env, 437 type=str, 438 default="1", 439 help="Number of workers per node; supported values: [auto, cpu, gpu, int].", 440 ) 441 442 # 443 # Rendezvous related arguments 444 # 445 446 parser.add_argument( 447 "--rdzv-backend", 448 "--rdzv_backend", 449 action=env, 450 type=str, 451 default="static", 452 help="Rendezvous backend.", 453 ) 454 parser.add_argument( 455 "--rdzv-endpoint", 456 "--rdzv_endpoint", 457 action=env, 458 type=str, 459 default="", 460 help="Rendezvous backend endpoint; usually in form <host>:<port>.", 461 ) 462 parser.add_argument( 463 "--rdzv-id", 464 "--rdzv_id", 465 action=env, 466 type=str, 467 default="none", 468 help="User-defined group id.", 469 ) 470 parser.add_argument( 471 "--rdzv-conf", 472 "--rdzv_conf", 473 action=env, 474 type=str, 475 default="", 476 help="Additional rendezvous configuration (<key1>=<value1>,<key2>=<value2>,...).", 477 ) 478 parser.add_argument( 479 "--standalone", 480 action=check_env, 481 help="Start a local standalone rendezvous backend that is represented by a C10d TCP store " 482 "on a free port. Useful when launching single-node, multi-worker job. If specified " 483 "--rdzv-backend, --rdzv-endpoint, --rdzv-id are auto-assigned and any explicitly set values " 484 "are ignored.", 485 ) 486 487 # 488 # User-code launch related arguments. 489 # 490 491 parser.add_argument( 492 "--max-restarts", 493 "--max_restarts", 494 action=env, 495 type=int, 496 default=0, 497 help="Maximum number of worker group restarts before failing.", 498 ) 499 parser.add_argument( 500 "--monitor-interval", 501 "--monitor_interval", 502 action=env, 503 type=float, 504 default=0.1, 505 help="Interval, in seconds, to monitor the state of workers.", 506 ) 507 parser.add_argument( 508 "--start-method", 509 "--start_method", 510 action=env, 511 type=str, 512 default="spawn", 513 choices=["spawn", "fork", "forkserver"], 514 help="Multiprocessing start method to use when creating workers.", 515 ) 516 parser.add_argument( 517 "--role", 518 action=env, 519 type=str, 520 default="default", 521 help="User-defined role for the workers.", 522 ) 523 parser.add_argument( 524 "-m", 525 "--module", 526 action=check_env, 527 help="Change each process to interpret the launch script as a Python module, executing " 528 "with the same behavior as 'python -m'.", 529 ) 530 parser.add_argument( 531 "--no-python", 532 "--no_python", 533 action=check_env, 534 help="Skip prepending the training script with 'python' - just execute it directly. Useful " 535 "when the script is not a Python script.", 536 ) 537 538 parser.add_argument( 539 "--run-path", 540 "--run_path", 541 action=check_env, 542 help="Run the training script with runpy.run_path in the same interpreter." 543 " Script must be provided as an abs path (e.g. /abs/path/script.py)." 544 " Takes precedence over --no-python.", 545 ) 546 parser.add_argument( 547 "--log-dir", 548 "--log_dir", 549 action=env, 550 type=str, 551 default=None, 552 help="Base directory to use for log files (e.g. /var/log/torch/elastic). The same " 553 "directory is re-used for multiple runs (a unique job-level sub-directory is created with " 554 "rdzv_id as the prefix).", 555 ) 556 parser.add_argument( 557 "-r", 558 "--redirects", 559 action=env, 560 type=str, 561 default="0", 562 help="Redirect std streams into a log file in the log directory (e.g. [-r 3] redirects " 563 "both stdout+stderr for all workers, [-r 0:1,1:2] redirects stdout for local rank 0 and " 564 "stderr for local rank 1).", 565 ) 566 parser.add_argument( 567 "-t", 568 "--tee", 569 action=env, 570 type=str, 571 default="0", 572 help="Tee std streams into a log file and also to console (see --redirects for format).", 573 ) 574 575 parser.add_argument( 576 "--local-ranks-filter", 577 "--local_ranks_filter", 578 action=env, 579 type=str, 580 default="", 581 help="Only show logs from specified ranks in console (e.g. [--local_ranks_filter=0,1,2] will " 582 "only show logs from rank 0, 1 and 2). This will only apply to stdout and stderr, not to" 583 "log files saved via --redirect or --tee", 584 ) 585 586 # 587 # Backwards compatible parameters with caffe2.distributed.launch. 588 # 589 590 parser.add_argument( 591 "--node-rank", 592 "--node_rank", 593 type=int, 594 action=env, 595 default=0, 596 help="Rank of the node for multi-node distributed training.", 597 ) 598 parser.add_argument( 599 "--master-addr", 600 "--master_addr", 601 default="127.0.0.1", 602 type=str, 603 action=env, 604 help="Address of the master node (rank 0) that only used for static rendezvous. It should " 605 "be either the IP address or the hostname of rank 0. For single node multi-proc training " 606 "the --master-addr can simply be 127.0.0.1; IPv6 should have the pattern " 607 "`[0:0:0:0:0:0:0:1]`.", 608 ) 609 parser.add_argument( 610 "--master-port", 611 "--master_port", 612 default=29500, 613 type=int, 614 action=env, 615 help="Port on the master node (rank 0) to be used for communication during distributed " 616 "training. It is only used for static rendezvous.", 617 ) 618 parser.add_argument( 619 "--local-addr", 620 "--local_addr", 621 default=None, 622 type=str, 623 action=env, 624 help="Address of the local node. If specified, will use the given address for connection. " 625 "Else, will look up the local node address instead. Else, it will be default to local " 626 "machine's FQDN.", 627 ) 628 629 parser.add_argument( 630 "--logs-specs", 631 "--logs_specs", 632 default=None, 633 type=str, 634 help="torchrun.logs_specs group entrypoint name, value must be type of LogsSpecs. " 635 "Can be used to override custom logging behavior.", 636 ) 637 638 # 639 # Positional arguments. 640 # 641 642 parser.add_argument( 643 "training_script", 644 type=str, 645 help="Full path to the (single GPU) training program/script to be launched in parallel, " 646 "followed by all the arguments for the training script.", 647 ) 648 649 # Rest from the training program. 650 parser.add_argument("training_script_args", nargs=REMAINDER) 651 652 return parser 653 654 655def parse_args(args): 656 parser = get_args_parser() 657 return parser.parse_args(args) 658 659 660def parse_min_max_nnodes(nnodes: str): 661 arr = nnodes.split(":") 662 663 if len(arr) == 1: 664 min_nodes = max_nodes = int(arr[0]) 665 elif len(arr) == 2: 666 min_nodes = int(arr[0]) 667 max_nodes = int(arr[1]) 668 else: 669 raise RuntimeError(f'nnodes={nnodes} is not in "MIN:MAX" format') # noqa: E231 670 671 return min_nodes, max_nodes 672 673 674def determine_local_world_size(nproc_per_node: str): 675 try: 676 logging.info("Using nproc_per_node=%s.", nproc_per_node) 677 return int(nproc_per_node) 678 except ValueError as e: 679 if nproc_per_node == "cpu": 680 num_proc = os.cpu_count() 681 device_type = "cpu" 682 elif nproc_per_node == "gpu": 683 if not torch.cuda.is_available(): 684 raise ValueError("Cuda is not available.") from e 685 device_type = "gpu" 686 num_proc = torch.cuda.device_count() 687 elif nproc_per_node == torch._C._get_privateuse1_backend_name(): 688 if not _get_custom_mod_func("is_available")(): 689 raise ValueError(f"{nproc_per_node} is not available.") from e 690 device_type = nproc_per_node 691 num_proc = _get_custom_mod_func("device_count")() 692 elif nproc_per_node == "auto": 693 if torch.cuda.is_available(): 694 num_proc = torch.cuda.device_count() 695 device_type = "gpu" 696 elif ( 697 hasattr(torch, torch._C._get_privateuse1_backend_name()) 698 and _get_custom_mod_func("is_available")() 699 ): 700 num_proc = _get_custom_mod_func("device_count")() 701 device_type = torch._C._get_privateuse1_backend_name() 702 else: 703 num_proc = os.cpu_count() 704 device_type = "cpu" 705 else: 706 raise ValueError( 707 f"Unsupported nproc_per_node value: {nproc_per_node}" 708 ) from e 709 710 logger.info( 711 "Using nproc_per_node=%s, setting nproc_per_node to %s since the instance has %s %s", 712 nproc_per_node, 713 num_proc, 714 num_proc, 715 device_type, 716 ) 717 return num_proc 718 719 720def get_rdzv_endpoint(args): 721 if args.rdzv_backend == "static" and not args.rdzv_endpoint: 722 return f"{args.master_addr}:{args.master_port}" # noqa: E231 723 return args.rdzv_endpoint 724 725 726def get_use_env(args) -> bool: 727 """ 728 Retrieve ``use_env`` from the args. 729 730 ``use_env`` is a legacy argument, if ``use_env`` is False, the 731 ``--node-rank`` argument will be transferred to all worker processes. 732 ``use_env`` is only used by the ``torch.distributed.launch`` and will 733 be deprecated in future releases. 734 """ 735 if not hasattr(args, "use_env"): 736 return True 737 return args.use_env 738 739 740def _get_logs_specs_class(logs_specs_name: Optional[str]) -> Type[LogsSpecs]: 741 """ 742 Attemps to load `torchrun.logs_spec` entrypoint with key of `logs_specs_name` param. 743 Provides plugin mechanism to provide custom implementation of LogsSpecs. 744 745 Returns `DefaultLogsSpecs` when logs_spec_name is None. 746 Raises ValueError when entrypoint for `logs_spec_name` can't be found in entrypoints. 747 """ 748 logs_specs_cls = None 749 if logs_specs_name is not None: 750 eps = metadata.entry_points() 751 if hasattr(eps, "select"): # >= 3.10 752 group = eps.select(group="torchrun.logs_specs") 753 if group.select(name=logs_specs_name): 754 logs_specs_cls = group[logs_specs_name].load() 755 756 elif specs := eps.get("torchrun.logs_specs"): # < 3.10 757 if entrypoint_list := [ep for ep in specs if ep.name == logs_specs_name]: 758 logs_specs_cls = entrypoint_list[0].load() 759 760 if logs_specs_cls is None: 761 raise ValueError( 762 f"Could not find entrypoint under 'torchrun.logs_specs[{logs_specs_name}]' key" 763 ) 764 765 logging.info( 766 "Using logs_spec '%s' mapped to %s", logs_specs_name, str(logs_specs_cls) 767 ) 768 else: 769 logs_specs_cls = DefaultLogsSpecs 770 771 return logs_specs_cls 772 773 774def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str]]: 775 # If ``args`` not passed, defaults to ``sys.argv[:1]`` 776 min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes) 777 assert 0 < min_nodes <= max_nodes 778 assert args.max_restarts >= 0 779 780 if ( 781 hasattr(args, "master_addr") 782 and args.rdzv_backend != "static" 783 and not args.rdzv_endpoint 784 ): 785 logger.warning( 786 "master_addr is only used for static rdzv_backend and when rdzv_endpoint " 787 "is not specified." 788 ) 789 790 nproc_per_node = determine_local_world_size(args.nproc_per_node) 791 if "OMP_NUM_THREADS" not in os.environ and nproc_per_node > 1: 792 omp_num_threads = 1 793 logger.warning( 794 "\n*****************************************\n" 795 "Setting OMP_NUM_THREADS environment variable for each process to be " 796 "%s in default, to avoid your system being overloaded, " 797 "please further tune the variable for optimal performance in " 798 "your application as needed. \n" 799 "*****************************************", 800 omp_num_threads, 801 ) 802 # This env variable will be passed down to the subprocesses 803 os.environ["OMP_NUM_THREADS"] = str(omp_num_threads) 804 805 log_line_prefix_template = os.getenv("TORCHELASTIC_LOG_LINE_PREFIX_TEMPLATE") 806 807 rdzv_configs = _parse_rendezvous_config(args.rdzv_conf) 808 809 if args.rdzv_backend == "static": 810 rdzv_configs["rank"] = args.node_rank 811 812 rdzv_endpoint = get_rdzv_endpoint(args) 813 814 ranks: Optional[Set[int]] = None 815 if args.local_ranks_filter: 816 try: 817 ranks = set(map(int, args.local_ranks_filter.split(","))) 818 assert ranks 819 except Exception as e: 820 raise ValueError( 821 "--local_ranks_filter must be a comma-separated list of integers e.g. --local_ranks_filter=0,1,2" 822 ) from e 823 824 logs_specs_cls: Type[LogsSpecs] = _get_logs_specs_class(args.logs_specs) 825 logs_specs = logs_specs_cls( 826 log_dir=args.log_dir, 827 redirects=Std.from_str(args.redirects), 828 tee=Std.from_str(args.tee), 829 local_ranks_filter=ranks, 830 ) 831 832 config = LaunchConfig( 833 min_nodes=min_nodes, 834 max_nodes=max_nodes, 835 nproc_per_node=nproc_per_node, 836 run_id=args.rdzv_id, 837 role=args.role, 838 rdzv_endpoint=rdzv_endpoint, 839 rdzv_backend=args.rdzv_backend, 840 rdzv_configs=rdzv_configs, 841 max_restarts=args.max_restarts, 842 monitor_interval=args.monitor_interval, 843 start_method=args.start_method, 844 log_line_prefix_template=log_line_prefix_template, 845 local_addr=args.local_addr, 846 logs_specs=logs_specs, 847 ) 848 849 with_python = not args.no_python 850 cmd: Union[Callable, str] 851 cmd_args = [] 852 use_env = get_use_env(args) 853 if args.run_path: 854 cmd = run_script_path 855 cmd_args.append(args.training_script) 856 else: 857 if with_python: 858 cmd = os.getenv("PYTHON_EXEC", sys.executable) 859 cmd_args.append("-u") 860 if args.module: 861 cmd_args.append("-m") 862 cmd_args.append(args.training_script) 863 else: 864 if args.module: 865 raise ValueError( 866 "Don't use both the '--no-python' flag" 867 " and the '--module' flag at the same time." 868 ) 869 cmd = args.training_script 870 if not use_env: 871 cmd_args.append(f"--local-rank={macros.local_rank}") 872 cmd_args.extend(args.training_script_args) 873 874 return config, cmd, cmd_args 875 876 877def run_script_path(training_script: str, *training_script_args: str): 878 """ 879 Run the provided `training_script` from within this interpreter. 880 881 Usage: `script_as_function("/abs/path/to/script.py", "--arg1", "val1")` 882 """ 883 import runpy 884 import sys 885 886 sys.argv = [training_script] + [*training_script_args] 887 runpy.run_path(sys.argv[0], run_name="__main__") 888 889 890def run(args): 891 torch.multiprocessing._set_thread_name("pt_elastic") 892 893 if args.standalone: 894 args.rdzv_backend = "c10d" 895 args.rdzv_endpoint = "localhost:0" 896 args.rdzv_id = str(uuid.uuid4()) 897 logger.info( 898 "\n**************************************\n" 899 "Rendezvous info:\n" 900 "--rdzv-backend=%s " 901 "--rdzv-endpoint=%s " 902 "--rdzv-id=%s\n" 903 "**************************************\n", 904 args.rdzv_backend, 905 args.rdzv_endpoint, 906 args.rdzv_id, 907 ) 908 909 config, cmd, cmd_args = config_from_args(args) 910 elastic_launch( 911 config=config, 912 entrypoint=cmd, 913 )(*cmd_args) 914 915 916@record 917def main(args=None): 918 args = parse_args(args) 919 run(args) 920 921 922if __name__ == "__main__": 923 main() 924