#!/usr/bin/env python3
# mypy: allow-untyped-defs

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Superset of ``torch.distributed.launch``.

``torchrun`` provides a superset of the functionality of ``torch.distributed.launch``
with the following additional features:

1. Worker failures are handled gracefully by restarting all workers.

2. Worker ``RANK`` and ``WORLD_SIZE`` are assigned automatically.

3. The number of nodes is allowed to change between minimum and maximum sizes (elasticity).

.. note:: ``torchrun`` is a Python
   `console script <https://packaging.python.org/en/latest/specifications/entry-points/#use-for-scripts>`_
   for the main module
   `torch.distributed.run <https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py>`_
   declared in the ``entry_points`` configuration in
   `setup.py <https://github.com/pytorch/pytorch/blob/master/setup.py>`_.
   It is equivalent to invoking ``python -m torch.distributed.run``.


Transitioning from torch.distributed.launch to torchrun
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


``torchrun`` supports the same arguments as ``torch.distributed.launch`` **except**
for ``--use-env``, which is now deprecated. To migrate from ``torch.distributed.launch``
to ``torchrun``, follow these steps:

1. If your training script already reads ``local_rank`` from the ``LOCAL_RANK`` environment
   variable, you simply need to omit the ``--use-env`` flag, e.g.:

   +--------------------------------------------------------------------+--------------------------------------------+
   | ``torch.distributed.launch``                                       | ``torchrun``                               |
   +====================================================================+============================================+
   |                                                                    |                                            |
   | .. code-block:: shell-session                                      | .. code-block:: shell-session              |
   |                                                                    |                                            |
   |    $ python -m torch.distributed.launch --use-env train_script.py  |    $ torchrun train_script.py              |
   |                                                                    |                                            |
   +--------------------------------------------------------------------+--------------------------------------------+

2. If your training script reads the local rank from a ``--local-rank`` command-line argument,
   change it to read from the ``LOCAL_RANK`` environment variable instead, as
   demonstrated by the following code snippet:

   +-------------------------------------------------------+----------------------------------------------------+
   | ``torch.distributed.launch``                          | ``torchrun``                                       |
   +=======================================================+====================================================+
   |                                                       |                                                    |
   | .. code-block:: python                                | .. code-block:: python                             |
   |                                                       |                                                    |
   |    import argparse                                    |    import os                                       |
   |    parser = argparse.ArgumentParser()                 |    local_rank = int(os.environ["LOCAL_RANK"])      |
   |    parser.add_argument("--local-rank", type=int)      |                                                    |
   |    args = parser.parse_args()                         |                                                    |
   |                                                       |                                                    |
   |    local_rank = args.local_rank                       |                                                    |
   |                                                       |                                                    |
   +-------------------------------------------------------+----------------------------------------------------+

.. versionchanged:: 2.0.0

    The launcher will pass the ``--local-rank=<rank>`` argument to your script.
    From PyTorch 2.0.0 onwards, the dashed ``--local-rank`` is preferred over the
    previously used underscored ``--local_rank``.

    For backward compatibility, it may be necessary for users to handle both
    cases in their argument parsing code. This means including both ``"--local-rank"``
    and ``"--local_rank"`` in the argument parser. If the parser defines only
    ``"--local_rank"``, the ``--local-rank=<rank>`` argument passed by the launcher
    triggers the error "error: unrecognized arguments: --local-rank=<rank>". For
    training code that only supports PyTorch 2.0.0+, including ``"--local-rank"``
    should be sufficient.

    ::

        >>> # xdoctest: +SKIP
        >>> import argparse
        >>> parser = argparse.ArgumentParser()
        >>> parser.add_argument("--local-rank", "--local_rank", type=int)
        >>> args = parser.parse_args()

The aforementioned changes suffice to migrate from ``torch.distributed.launch`` to ``torchrun``.
To take advantage of the new features of ``torchrun``, such as elasticity, fault tolerance, and
error reporting, refer to:

* :ref:`elastic_train_script` for more information on authoring training scripts that are ``torchrun`` compliant.
* the rest of this page for more information on the features of ``torchrun``.


Usage
--------

Single-node multi-worker
++++++++++++++++++++++++++++++

::

    torchrun
        --standalone
        --nnodes=1
        --nproc-per-node=$NUM_TRAINERS
        YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)

Stacked single-node multi-worker
+++++++++++++++++++++++++++++++++++

To run multiple instances (separate jobs) of single-node, multi-worker training on the
same host, make sure that each instance (job) is set up on a different port to avoid
port conflicts (or worse, two jobs being merged into a single job). To do this, run with
``--rdzv-backend=c10d`` and specify a different port for each job by setting
``--rdzv-endpoint=localhost:$PORT_k``. For ``--nnodes=1``, it is often convenient to let
``torchrun`` pick a free random port automatically (by passing port ``0``, as below) instead
of manually assigning different ports for each run.

::

    torchrun
        --rdzv-backend=c10d
        --rdzv-endpoint=localhost:0
        --nnodes=1
        --nproc-per-node=$NUM_TRAINERS
        YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
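
Passing port ``0`` asks the operating system for an unused port. The sketch below
illustrates that mechanism with only the Python standard library. ``find_free_port`` is a
hypothetical helper written for this example, not part of ``torchrun``, and the way the
rendezvous backend actually binds its port may differ in detail.

.. code-block:: python

    import socket


    def find_free_port(host="127.0.0.1"):
        # Hypothetical helper (not a torchrun API): binding to port 0 makes the
        # OS assign an unused ephemeral port, which is read back before the
        # socket is closed again.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.bind((host, 0))
            return sock.getsockname()[1]


    port = find_free_port()
    print(f"free port: {port}")

Because the socket is released before the port is reused, another process could claim it
in the meantime, so a port picked this way is a hint rather than a reservation.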


Fault tolerant (fixed-size number of workers, no elasticity, tolerates 3 failures)
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

::

    torchrun
        --nnodes=$NUM_NODES
        --nproc-per-node=$NUM_TRAINERS
        --max-restarts=3
        --rdzv-id=$JOB_ID
        --rdzv-backend=c10d
        --rdzv-endpoint=$HOST_NODE_ADDR
        YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)

``HOST_NODE_ADDR``, in the form ``<host>[:<port>]`` (e.g. ``node1.example.com:29400``), specifies
the node and the port on which the C10d rendezvous backend should be instantiated and hosted. It
can be any node in your training cluster, but ideally you should pick a node with high bandwidth.

.. note::
   If no port number is specified, ``HOST_NODE_ADDR`` defaults to port 29400.
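
The ``<host>[:<port>]`` form and its default port can be illustrated with a small parsing
sketch. ``split_host_node_addr`` and ``DEFAULT_RDZV_PORT`` are hypothetical names written for
this example, not ``torchrun`` APIs; the sketch assumes a plain hostname or IPv4 address and
does not handle IPv6 literals.

.. code-block:: python

    DEFAULT_RDZV_PORT = 29400  # default port stated in the note above


    def split_host_node_addr(addr, default_port=DEFAULT_RDZV_PORT):
        # Hypothetical helper: split "<host>[:<port>]" into (host, port).
        # IPv6 literals such as "[::1]:29400" are not handled here.
        host, sep, port = addr.rpartition(":")
        if not sep:
            # No colon at all: the whole string is the host name.
            return addr, default_port
        if not port.isdigit():
            raise ValueError(f"invalid port in {addr!r}")
        return host, int(port)


    print(split_host_node_addr("node1.example.com:29400"))  # ('node1.example.com', 29400)
    print(split_host_node_addr("node1.example.com"))  # ('node1.example.com', 29400)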

Elastic (``min=1``, ``max=4``, tolerates up to 3 membership changes or failures)
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

::

    torchrun
        --nnodes=1:4
        --nproc-per-node=$NUM_TRAINERS
        --max-restarts=3
        --rdzv-id=$JOB_ID
        --rdzv-backend=c10d
        --rdzv-endpoint=$HOST_NODE_ADDR
        YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)

``HOST_NODE_ADDR``, in the form ``<host>[:<port>]`` (e.g. ``node1.example.com:29400``), specifies
the node and the port on which the C10d rendezvous backend should be instantiated and hosted. It
can be any node in your training cluster, but ideally you should pick a node with high bandwidth.

.. note::
   If no port number is specified, ``HOST_NODE_ADDR`` defaults to port 29400.
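
The ``--nnodes=1:4`` value above packs the minimum and maximum node counts into one string,
while a single number such as ``--nnodes=$NUM_NODES`` fixes both. The sketch below shows one
way to interpret such a value. ``parse_nnodes`` is a hypothetical helper for illustration,
and the launcher's own parsing and validation may differ.

.. code-block:: python

    def parse_nnodes(nnodes):
        # Hypothetical helper: "N" means a fixed size (min == max == N),
        # "MIN:MAX" means an elastic range of node counts.
        parts = nnodes.split(":")
        if len(parts) == 1:
            min_nodes = max_nodes = int(parts[0])
        elif len(parts) == 2:
            min_nodes, max_nodes = int(parts[0]), int(parts[1])
        else:
            raise ValueError(f"nnodes={nnodes!r} is not in the form N or MIN:MAX")
        if not 0 < min_nodes <= max_nodes:
            raise ValueError(f"invalid node range {min_nodes}:{max_nodes}")
        return min_nodes, max_nodes


    print(parse_nnodes("1:4"))  # (1, 4): elastic job
    print(parse_nnodes("2"))  # (2, 2): fixed-size job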

Note on rendezvous backend
------------------------------

For multi-node training, you need to specify:

1. ``--rdzv-id``: A unique job id (shared by all nodes participating in the job).
2. ``--rdzv-backend``: An implementation of
   :py:class:`torch.distributed.elastic.rendezvous.RendezvousHandler`.
3. ``--rdzv-endpoint``: The endpoint where the rendezvous backend is running; usually in the form
   ``host:port``.

Currently ``c10d`` (recommended), ``etcd-v2``, and ``etcd`` (legacy) rendezvous backends are
supported out of the box. To use ``etcd-v2`` or ``etcd``, set up an etcd server with the ``v2`` API
enabled (e.g. ``--enable-v2``).

.. warning::
   ``etcd-v2`` and ``etcd`` rendezvous use etcd API v2. You MUST enable the v2 API on the etcd
   server. Our tests use etcd v3.4.3.

.. warning::
   For etcd-based rendezvous we recommend using ``etcd-v2`` over ``etcd``. The two are
   functionally equivalent, but ``etcd-v2`` uses a revised implementation. ``etcd`` is in
   maintenance mode and will be removed in a future version.

Definitions
--------------

1. ``Node`` - A physical instance or a container; maps to the unit that the job manager works with.

2. ``Worker`` - A worker in the context of distributed training.

3. ``WorkerGroup`` - The set of workers that execute the same function (e.g. trainers).

4. ``LocalWorkerGroup`` - A subset of the workers in the worker group running on the same node.

5. ``RANK`` - The rank of the worker within a worker group.

6. ``WORLD_SIZE`` - The total number of workers in a worker group.

7. ``LOCAL_RANK`` - The rank of the worker within a local worker group.

8. ``LOCAL_WORLD_SIZE`` - The size of the local worker group.

9. ``rdzv_id`` - A user-defined id that uniquely identifies the worker group for a job. This id is
   used by each node to join as a member of a particular worker group.

10. ``rdzv_backend`` - The backend of the rendezvous (e.g. ``c10d``). This is typically a strongly
    consistent key-value store.

11. ``rdzv_endpoint`` - The rendezvous backend endpoint; usually in the form ``<host>:<port>``.

A ``Node`` runs ``LOCAL_WORLD_SIZE`` workers, which comprise a ``LocalWorkerGroup``. The union of
all ``LocalWorkerGroups`` across the nodes in the job comprises the ``WorkerGroup``.

Environment Variables
----------------------

The following environment variables are made available to you in your script:

1. ``LOCAL_RANK`` - The local rank.

2. ``RANK`` - The global rank.

3. ``GROUP_RANK`` - The rank of the worker group. A number between 0 and ``max_nnodes``. When
   running a single worker group per node, this is the rank of the node.

4. ``ROLE_RANK`` - The rank of the worker across all the workers that have the same role. The role
   of the worker is specified in the ``WorkerSpec``.

5. ``LOCAL_WORLD_SIZE`` - The local world size (i.e. the number of workers running locally); equal
   to ``--nproc-per-node`` specified on ``torchrun``.

6. ``WORLD_SIZE`` - The world size (total number of workers in the job).

7. ``ROLE_WORLD_SIZE`` - The total number of workers that were launched with the same role
   specified in ``WorkerSpec``.

8. ``MASTER_ADDR`` - The FQDN of the host that is running the worker with rank 0; used to
   initialize the Torch Distributed backend.

9. ``MASTER_PORT`` - The port on the ``MASTER_ADDR`` that can be used to host the C10d TCP store.

10. ``TORCHELASTIC_RESTART_COUNT`` - The number of worker group restarts so far.

11. ``TORCHELASTIC_MAX_RESTARTS`` - The configured maximum number of restarts.

12. ``TORCHELASTIC_RUN_ID`` - Equal to the rendezvous ``run_id`` (i.e. the unique job id).

13. ``PYTHON_EXEC`` - System executable override. If provided, the Python user script will
    use the value of ``PYTHON_EXEC`` as the executable. ``sys.executable`` is used by default.
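
Scripts usually read a handful of these variables once at startup. The sketch below collects
a few of them into a small record. ``TorchrunEnv`` and ``read_torchrun_env`` are hypothetical
names, and the single-process fallbacks used when a variable is missing (for example, when the
script is started with plain ``python``) are an assumption of this example rather than
``torchrun`` behavior.

.. code-block:: python

    import os
    from dataclasses import dataclass


    @dataclass(frozen=True)
    class TorchrunEnv:
        # Hypothetical container for a few of the variables listed above.
        rank: int
        local_rank: int
        world_size: int
        local_world_size: int
        restart_count: int


    def read_torchrun_env(environ=None):
        # Read the variables from environ (default: os.environ). The fallbacks
        # (rank 0, world size 1, no restarts) are an assumption of this sketch
        # for runs outside torchrun, not values that torchrun sets.
        env = os.environ if environ is None else environ
        return TorchrunEnv(
            rank=int(env.get("RANK", 0)),
            local_rank=int(env.get("LOCAL_RANK", 0)),
            world_size=int(env.get("WORLD_SIZE", 1)),
            local_world_size=int(env.get("LOCAL_WORLD_SIZE", 1)),
            restart_count=int(env.get("TORCHELASTIC_RESTART_COUNT", 0)),
        )


    info = read_torchrun_env()
    print(info)

Reading the values at every start, rather than caching them across runs, matters because
``RANK`` and ``WORLD_SIZE`` may change after restarts and membership changes.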

Deployment
------------

1. (Not needed for the C10d backend) Start the rendezvous backend server and get the endpoint (to be
   passed as ``--rdzv-endpoint`` to the launcher script).

2. Single-node multi-worker: Start the launcher on the host to start the agent process, which
   creates and monitors a local worker group.

3. Multi-node multi-worker: Start the launcher with the same arguments on all the nodes
   participating in training.

When using a job/cluster manager, the entry point command to the multi-node job should be this
launcher.

Failure Modes
---------------

1. Worker failure: For a training job with ``n`` workers, if ``k<=n`` workers fail, all workers
   are stopped and restarted up to ``max_restarts`` times.

2. Agent failure: An agent failure results in a local worker group failure. It is up to the job
   manager to either fail the entire job (gang semantics) or attempt to replace the node. Both
   behaviors are supported by the agent.

3. Node failure: Same as agent failure.

Membership Changes
--------------------

1. Node departure (scale-down): The agent is notified of the departure, all existing workers are
   stopped, a new ``WorkerGroup`` is formed, and all workers are started with a new ``RANK`` and
   ``WORLD_SIZE``.

2. Node arrival (scale-up): The new node is admitted to the job, all existing workers are stopped,
   a new ``WorkerGroup`` is formed, and all workers are started with a new ``RANK`` and
   ``WORLD_SIZE``.

Important Notices
--------------------

1. This utility and multi-process distributed (single-node or
   multi-node) GPU training currently achieve the best performance only with
   the NCCL distributed backend. Thus, NCCL is the recommended backend
   for GPU training.

2. The environment variables necessary to initialize a Torch process group are provided to you by
   this module, so there is no need to pass ``RANK`` manually. To initialize a process group in your
   training script, simply run:

   ::

       >>> # xdoctest: +SKIP("stub")
       >>> import torch.distributed as dist
       >>> dist.init_process_group(backend="nccl")  # or backend="gloo" for CPU training

3. In your training program, you can either use regular distributed functions
   or use the :class:`torch.nn.parallel.DistributedDataParallel` module. If your
   training program uses GPUs for training and you would like to use the
   :class:`torch.nn.parallel.DistributedDataParallel` module,
   here is how to configure it:

   ::

       import os
       import torch

       local_rank = int(os.environ["LOCAL_RANK"])
       # model is your torch.nn.Module, already placed on GPU local_rank.
       model = torch.nn.parallel.DistributedDataParallel(
           model, device_ids=[local_rank], output_device=local_rank
       )

   Please ensure that the ``device_ids`` argument is set to the only GPU device id
   that your code will be operating on. This is generally the local rank of the
   process. In other words, ``device_ids`` needs to be ``[int(os.environ["LOCAL_RANK"])]``,
   and ``output_device`` needs to be ``int(os.environ["LOCAL_RANK"])`` in order to use this
   utility.

4. On failures or membership changes, ALL surviving workers are killed immediately. Make sure to
   checkpoint your progress. The frequency of checkpoints should depend on your job's tolerance
   for lost work.

5. This module only supports homogeneous ``LOCAL_WORLD_SIZE``. That is, it is assumed that all
   nodes run the same number of local workers (per role).

6. ``RANK`` is NOT stable. Between restarts, the local workers on a node can be assigned a
   different range of ranks than before. NEVER hard code any assumptions about the stability of
   ranks or some correlation between ``RANK`` and ``LOCAL_RANK``.

7. When using elasticity (``min_size!=max_size``), DO NOT hard code assumptions about
   ``WORLD_SIZE``, as the world size can change as nodes are allowed to leave and join.

8. It is recommended for your script to have the following structure:

   ::

       # All names below are placeholders for your own code.
       def main():
           load_checkpoint(checkpoint_path)
           initialize()
           train()

       def train():
           for batch in dataset:
               train_step(batch)

               if should_checkpoint:
                   save_checkpoint(checkpoint_path)

9. (Recommended) On worker errors, this tool summarizes the details of the error
   (e.g. time, rank, host, pid, traceback, etc.). On each node, the first error (by timestamp)
   is heuristically reported as the "Root Cause" error. To get tracebacks as part of this
   error summary printout, you must decorate your main entrypoint function in your
   training script as shown in the example below. If not decorated, the summary
   will not include the traceback of the exception and will only contain the exit code.
   For details on torchelastic error handling, see: https://pytorch.org/docs/stable/elastic/errors.html

   ::

       from torch.distributed.elastic.multiprocessing.errors import record

       @record
       def main():
           # training code goes here
           pass

       if __name__ == "__main__":
           main()

"""
import logging
import os
import sys
import uuid
from argparse import ArgumentParser, REMAINDER
from importlib import metadata
from typing import Callable, List, Optional, Set, Tuple, Type, Union

import torch
from torch.distributed.argparse_util import check_env, env
from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, LogsSpecs, Std
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.elastic.rendezvous.utils import _parse_rendezvous_config
from torch.distributed.elastic.utils import macros
from torch.distributed.elastic.utils.logging import get_logger
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
from torch.utils.backend_registration import _get_custom_mod_func


logger = get_logger(__name__)


def get_args_parser() -> ArgumentParser:
    """Build the command-line argument parser."""
    parser = ArgumentParser(description="Torch Distributed Elastic Training Launcher")

    #
    # Worker/node size related arguments.
    #

    parser.add_argument(
        "--nnodes",
        action=env,
        type=str,
        default="1:1",
        help="Number of nodes, or the range of nodes in form <minimum_nodes>:<maximum_nodes>.",
    )
    parser.add_argument(
        "--nproc-per-node",
        "--nproc_per_node",
        action=env,
        type=str,
        default="1",
        help="Number of workers per node; supported values: [auto, cpu, gpu, int].",
    )

    #
    # Rendezvous related arguments
    #

    parser.add_argument(
        "--rdzv-backend",
        "--rdzv_backend",
        action=env,
        type=str,
        default="static",
        help="Rendezvous backend.",
    )
    parser.add_argument(
        "--rdzv-endpoint",
        "--rdzv_endpoint",
        action=env,
        type=str,
        default="",
        help="Rendezvous backend endpoint; usually in form <host>:<port>.",
    )
    parser.add_argument(
        "--rdzv-id",
        "--rdzv_id",
        action=env,
        type=str,
        default="none",
        help="User-defined group id.",
    )
    parser.add_argument(
        "--rdzv-conf",
        "--rdzv_conf",
        action=env,
        type=str,
        default="",
        help="Additional rendezvous configuration (<key1>=<value1>,<key2>=<value2>,...).",
    )
    parser.add_argument(
        "--standalone",
        action=check_env,
        help="Start a local standalone rendezvous backend that is represented by a C10d TCP store "
        "on a free port. Useful when launching a single-node, multi-worker job. If specified, "
If specified " 483*da0073e9SAndroid Build Coastguard Worker "--rdzv-backend, --rdzv-endpoint, --rdzv-id are auto-assigned and any explicitly set values " 484*da0073e9SAndroid Build Coastguard Worker "are ignored.", 485*da0073e9SAndroid Build Coastguard Worker ) 486*da0073e9SAndroid Build Coastguard Worker 487*da0073e9SAndroid Build Coastguard Worker # 488*da0073e9SAndroid Build Coastguard Worker # User-code launch related arguments. 489*da0073e9SAndroid Build Coastguard Worker # 490*da0073e9SAndroid Build Coastguard Worker 491*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 492*da0073e9SAndroid Build Coastguard Worker "--max-restarts", 493*da0073e9SAndroid Build Coastguard Worker "--max_restarts", 494*da0073e9SAndroid Build Coastguard Worker action=env, 495*da0073e9SAndroid Build Coastguard Worker type=int, 496*da0073e9SAndroid Build Coastguard Worker default=0, 497*da0073e9SAndroid Build Coastguard Worker help="Maximum number of worker group restarts before failing.", 498*da0073e9SAndroid Build Coastguard Worker ) 499*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 500*da0073e9SAndroid Build Coastguard Worker "--monitor-interval", 501*da0073e9SAndroid Build Coastguard Worker "--monitor_interval", 502*da0073e9SAndroid Build Coastguard Worker action=env, 503*da0073e9SAndroid Build Coastguard Worker type=float, 504*da0073e9SAndroid Build Coastguard Worker default=0.1, 505*da0073e9SAndroid Build Coastguard Worker help="Interval, in seconds, to monitor the state of workers.", 506*da0073e9SAndroid Build Coastguard Worker ) 507*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 508*da0073e9SAndroid Build Coastguard Worker "--start-method", 509*da0073e9SAndroid Build Coastguard Worker "--start_method", 510*da0073e9SAndroid Build Coastguard Worker action=env, 511*da0073e9SAndroid Build Coastguard Worker type=str, 512*da0073e9SAndroid Build Coastguard Worker default="spawn", 513*da0073e9SAndroid Build Coastguard Worker 
choices=["spawn", "fork", "forkserver"], 514*da0073e9SAndroid Build Coastguard Worker help="Multiprocessing start method to use when creating workers.", 515*da0073e9SAndroid Build Coastguard Worker ) 516*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 517*da0073e9SAndroid Build Coastguard Worker "--role", 518*da0073e9SAndroid Build Coastguard Worker action=env, 519*da0073e9SAndroid Build Coastguard Worker type=str, 520*da0073e9SAndroid Build Coastguard Worker default="default", 521*da0073e9SAndroid Build Coastguard Worker help="User-defined role for the workers.", 522*da0073e9SAndroid Build Coastguard Worker ) 523*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 524*da0073e9SAndroid Build Coastguard Worker "-m", 525*da0073e9SAndroid Build Coastguard Worker "--module", 526*da0073e9SAndroid Build Coastguard Worker action=check_env, 527*da0073e9SAndroid Build Coastguard Worker help="Change each process to interpret the launch script as a Python module, executing " 528*da0073e9SAndroid Build Coastguard Worker "with the same behavior as 'python -m'.", 529*da0073e9SAndroid Build Coastguard Worker ) 530*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 531*da0073e9SAndroid Build Coastguard Worker "--no-python", 532*da0073e9SAndroid Build Coastguard Worker "--no_python", 533*da0073e9SAndroid Build Coastguard Worker action=check_env, 534*da0073e9SAndroid Build Coastguard Worker help="Skip prepending the training script with 'python' - just execute it directly. 
Useful " 535*da0073e9SAndroid Build Coastguard Worker "when the script is not a Python script.", 536*da0073e9SAndroid Build Coastguard Worker ) 537*da0073e9SAndroid Build Coastguard Worker 538*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 539*da0073e9SAndroid Build Coastguard Worker "--run-path", 540*da0073e9SAndroid Build Coastguard Worker "--run_path", 541*da0073e9SAndroid Build Coastguard Worker action=check_env, 542*da0073e9SAndroid Build Coastguard Worker help="Run the training script with runpy.run_path in the same interpreter." 543*da0073e9SAndroid Build Coastguard Worker " Script must be provided as an abs path (e.g. /abs/path/script.py)." 544*da0073e9SAndroid Build Coastguard Worker " Takes precedence over --no-python.", 545*da0073e9SAndroid Build Coastguard Worker ) 546*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 547*da0073e9SAndroid Build Coastguard Worker "--log-dir", 548*da0073e9SAndroid Build Coastguard Worker "--log_dir", 549*da0073e9SAndroid Build Coastguard Worker action=env, 550*da0073e9SAndroid Build Coastguard Worker type=str, 551*da0073e9SAndroid Build Coastguard Worker default=None, 552*da0073e9SAndroid Build Coastguard Worker help="Base directory to use for log files (e.g. /var/log/torch/elastic). 
The same " 553*da0073e9SAndroid Build Coastguard Worker "directory is re-used for multiple runs (a unique job-level sub-directory is created with " 554*da0073e9SAndroid Build Coastguard Worker "rdzv_id as the prefix).", 555*da0073e9SAndroid Build Coastguard Worker ) 556*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 557*da0073e9SAndroid Build Coastguard Worker "-r", 558*da0073e9SAndroid Build Coastguard Worker "--redirects", 559*da0073e9SAndroid Build Coastguard Worker action=env, 560*da0073e9SAndroid Build Coastguard Worker type=str, 561*da0073e9SAndroid Build Coastguard Worker default="0", 562*da0073e9SAndroid Build Coastguard Worker help="Redirect std streams into a log file in the log directory (e.g. [-r 3] redirects " 563*da0073e9SAndroid Build Coastguard Worker "both stdout+stderr for all workers, [-r 0:1,1:2] redirects stdout for local rank 0 and " 564*da0073e9SAndroid Build Coastguard Worker "stderr for local rank 1).", 565*da0073e9SAndroid Build Coastguard Worker ) 566*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 567*da0073e9SAndroid Build Coastguard Worker "-t", 568*da0073e9SAndroid Build Coastguard Worker "--tee", 569*da0073e9SAndroid Build Coastguard Worker action=env, 570*da0073e9SAndroid Build Coastguard Worker type=str, 571*da0073e9SAndroid Build Coastguard Worker default="0", 572*da0073e9SAndroid Build Coastguard Worker help="Tee std streams into a log file and also to console (see --redirects for format).", 573*da0073e9SAndroid Build Coastguard Worker ) 574*da0073e9SAndroid Build Coastguard Worker 575*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 576*da0073e9SAndroid Build Coastguard Worker "--local-ranks-filter", 577*da0073e9SAndroid Build Coastguard Worker "--local_ranks_filter", 578*da0073e9SAndroid Build Coastguard Worker action=env, 579*da0073e9SAndroid Build Coastguard Worker type=str, 580*da0073e9SAndroid Build Coastguard Worker default="", 581*da0073e9SAndroid Build Coastguard Worker 
help="Only show logs from specified ranks in console (e.g. [--local_ranks_filter=0,1,2] will " 582*da0073e9SAndroid Build Coastguard Worker "only show logs from rank 0, 1 and 2). This will only apply to stdout and stderr, not to" 583*da0073e9SAndroid Build Coastguard Worker "log files saved via --redirect or --tee", 584*da0073e9SAndroid Build Coastguard Worker ) 585*da0073e9SAndroid Build Coastguard Worker 586*da0073e9SAndroid Build Coastguard Worker # 587*da0073e9SAndroid Build Coastguard Worker # Backwards compatible parameters with caffe2.distributed.launch. 588*da0073e9SAndroid Build Coastguard Worker # 589*da0073e9SAndroid Build Coastguard Worker 590*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 591*da0073e9SAndroid Build Coastguard Worker "--node-rank", 592*da0073e9SAndroid Build Coastguard Worker "--node_rank", 593*da0073e9SAndroid Build Coastguard Worker type=int, 594*da0073e9SAndroid Build Coastguard Worker action=env, 595*da0073e9SAndroid Build Coastguard Worker default=0, 596*da0073e9SAndroid Build Coastguard Worker help="Rank of the node for multi-node distributed training.", 597*da0073e9SAndroid Build Coastguard Worker ) 598*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 599*da0073e9SAndroid Build Coastguard Worker "--master-addr", 600*da0073e9SAndroid Build Coastguard Worker "--master_addr", 601*da0073e9SAndroid Build Coastguard Worker default="127.0.0.1", 602*da0073e9SAndroid Build Coastguard Worker type=str, 603*da0073e9SAndroid Build Coastguard Worker action=env, 604*da0073e9SAndroid Build Coastguard Worker help="Address of the master node (rank 0) that only used for static rendezvous. It should " 605*da0073e9SAndroid Build Coastguard Worker "be either the IP address or the hostname of rank 0. 
For single node multi-proc training " 606*da0073e9SAndroid Build Coastguard Worker "the --master-addr can simply be 127.0.0.1; IPv6 should have the pattern " 607*da0073e9SAndroid Build Coastguard Worker "`[0:0:0:0:0:0:0:1]`.", 608*da0073e9SAndroid Build Coastguard Worker ) 609*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 610*da0073e9SAndroid Build Coastguard Worker "--master-port", 611*da0073e9SAndroid Build Coastguard Worker "--master_port", 612*da0073e9SAndroid Build Coastguard Worker default=29500, 613*da0073e9SAndroid Build Coastguard Worker type=int, 614*da0073e9SAndroid Build Coastguard Worker action=env, 615*da0073e9SAndroid Build Coastguard Worker help="Port on the master node (rank 0) to be used for communication during distributed " 616*da0073e9SAndroid Build Coastguard Worker "training. It is only used for static rendezvous.", 617*da0073e9SAndroid Build Coastguard Worker ) 618*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 619*da0073e9SAndroid Build Coastguard Worker "--local-addr", 620*da0073e9SAndroid Build Coastguard Worker "--local_addr", 621*da0073e9SAndroid Build Coastguard Worker default=None, 622*da0073e9SAndroid Build Coastguard Worker type=str, 623*da0073e9SAndroid Build Coastguard Worker action=env, 624*da0073e9SAndroid Build Coastguard Worker help="Address of the local node. If specified, will use the given address for connection. " 625*da0073e9SAndroid Build Coastguard Worker "Else, will look up the local node address instead. 
Else, it will be default to local " 626*da0073e9SAndroid Build Coastguard Worker "machine's FQDN.", 627*da0073e9SAndroid Build Coastguard Worker ) 628*da0073e9SAndroid Build Coastguard Worker 629*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 630*da0073e9SAndroid Build Coastguard Worker "--logs-specs", 631*da0073e9SAndroid Build Coastguard Worker "--logs_specs", 632*da0073e9SAndroid Build Coastguard Worker default=None, 633*da0073e9SAndroid Build Coastguard Worker type=str, 634*da0073e9SAndroid Build Coastguard Worker help="torchrun.logs_specs group entrypoint name, value must be type of LogsSpecs. " 635*da0073e9SAndroid Build Coastguard Worker "Can be used to override custom logging behavior.", 636*da0073e9SAndroid Build Coastguard Worker ) 637*da0073e9SAndroid Build Coastguard Worker 638*da0073e9SAndroid Build Coastguard Worker # 639*da0073e9SAndroid Build Coastguard Worker # Positional arguments. 640*da0073e9SAndroid Build Coastguard Worker # 641*da0073e9SAndroid Build Coastguard Worker 642*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 643*da0073e9SAndroid Build Coastguard Worker "training_script", 644*da0073e9SAndroid Build Coastguard Worker type=str, 645*da0073e9SAndroid Build Coastguard Worker help="Full path to the (single GPU) training program/script to be launched in parallel, " 646*da0073e9SAndroid Build Coastguard Worker "followed by all the arguments for the training script.", 647*da0073e9SAndroid Build Coastguard Worker ) 648*da0073e9SAndroid Build Coastguard Worker 649*da0073e9SAndroid Build Coastguard Worker # Rest from the training program. 
    parser.add_argument("training_script_args", nargs=REMAINDER)

    return parser


def parse_args(args):
    parser = get_args_parser()
    return parser.parse_args(args)


def parse_min_max_nnodes(nnodes: str):
    arr = nnodes.split(":")

    if len(arr) == 1:
        min_nodes = max_nodes = int(arr[0])
    elif len(arr) == 2:
        min_nodes = int(arr[0])
        max_nodes = int(arr[1])
    else:
        raise RuntimeError(f'nnodes={nnodes} is not in "MIN:MAX" format')  # noqa: E231

    return min_nodes, max_nodes


def determine_local_world_size(nproc_per_node: str):
    try:
        logging.info("Using nproc_per_node=%s.", nproc_per_node)
        return int(nproc_per_node)
    except ValueError as e:
        if nproc_per_node == "cpu":
            num_proc = os.cpu_count()
            device_type = "cpu"
        elif nproc_per_node == "gpu":
            if not torch.cuda.is_available():
                raise ValueError("Cuda is not available.") from e
            device_type = "gpu"
            num_proc = torch.cuda.device_count()
        elif nproc_per_node == torch._C._get_privateuse1_backend_name():
            if not _get_custom_mod_func("is_available")():
                raise ValueError(f"{nproc_per_node} is not available.") from e
            device_type = nproc_per_node
            num_proc = _get_custom_mod_func("device_count")()
        elif nproc_per_node == "auto":
            if torch.cuda.is_available():
                num_proc = torch.cuda.device_count()
                device_type = "gpu"
            elif (
                hasattr(torch, torch._C._get_privateuse1_backend_name())
                and _get_custom_mod_func("is_available")()
            ):
                num_proc = _get_custom_mod_func("device_count")()
                device_type = torch._C._get_privateuse1_backend_name()
            else:
                num_proc = os.cpu_count()
                device_type = "cpu"
        else:
            raise ValueError(
                f"Unsupported nproc_per_node value: {nproc_per_node}"
            ) from e

        logger.info(
            "Using nproc_per_node=%s, setting nproc_per_node to %s since the instance has %s %s",
            nproc_per_node,
            num_proc,
            num_proc,
            device_type,
        )
        return num_proc


def get_rdzv_endpoint(args):
    if args.rdzv_backend == "static" and not args.rdzv_endpoint:
        return f"{args.master_addr}:{args.master_port}"  # noqa: E231
    return args.rdzv_endpoint


def get_use_env(args) -> bool:
    """
    Retrieve ``use_env`` from the args.

    ``use_env`` is a legacy argument. If ``use_env`` is False, the
    ``--local-rank`` argument will be passed to all worker processes.
    ``use_env`` is only used by ``torch.distributed.launch`` and will
    be deprecated in future releases.
    """
    if not hasattr(args, "use_env"):
        return True
    return args.use_env


def _get_logs_specs_class(logs_specs_name: Optional[str]) -> Type[LogsSpecs]:
    """
    Attempt to load the ``torchrun.logs_specs`` entrypoint whose name is ``logs_specs_name``.
    Provides a plugin mechanism for supplying a custom implementation of ``LogsSpecs``.

    Returns ``DefaultLogsSpecs`` when ``logs_specs_name`` is None.
    Raises ``ValueError`` when no entrypoint named ``logs_specs_name`` can be found.
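
    For example, a package could register its own class under this group in its
    ``setup.py``. This is a hypothetical sketch. The package ``my_logs_plugin``,
    the class ``MyLogsSpecs`` and the entrypoint name ``my_specs`` are illustrative.
    It assumes ``MyLogsSpecs`` subclasses ``LogsSpecs``::

        from setuptools import setup

        setup(
            name="my_logs_plugin",
            packages=["my_logs_plugin"],
            entry_points={
                # Group scanned by _get_logs_specs_class; "my_specs" is the lookup key.
                "torchrun.logs_specs": [
                    "my_specs = my_logs_plugin.specs:MyLogsSpecs",
                ],
            },
        )

    Once such a package is installed, ``torchrun --logs-specs=my_specs ...`` would
    resolve ``my_specs`` to ``MyLogsSpecs`` through ``importlib.metadata``.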
    """
    logs_specs_cls = None
    if logs_specs_name is not None:
        eps = metadata.entry_points()
        if hasattr(eps, "select"):  # >= 3.10
            group = eps.select(group="torchrun.logs_specs")
            if group.select(name=logs_specs_name):
                logs_specs_cls = group[logs_specs_name].load()

        elif specs := eps.get("torchrun.logs_specs"):  # < 3.10
            if entrypoint_list := [ep for ep in specs if ep.name == logs_specs_name]:
                logs_specs_cls = entrypoint_list[0].load()

        if logs_specs_cls is None:
            raise ValueError(
                f"Could not find entrypoint under 'torchrun.logs_specs[{logs_specs_name}]' key"
            )

        logging.info(
            "Using logs_spec '%s' mapped to %s", logs_specs_name, str(logs_specs_cls)
        )
    else:
        logs_specs_cls = DefaultLogsSpecs

    return logs_specs_cls


def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str]]:
    # If ``args`` is not passed, it defaults to ``sys.argv[1:]``.
    min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes)
    assert 0 < min_nodes <= max_nodes
    assert args.max_restarts >= 0

    if (
        hasattr(args, "master_addr")
        and args.rdzv_backend != "static"
        and not args.rdzv_endpoint
    ):
        logger.warning(
            "master_addr is only used for static rdzv_backend and when rdzv_endpoint "
            "is not specified."
        )

    nproc_per_node = determine_local_world_size(args.nproc_per_node)
    if "OMP_NUM_THREADS" not in os.environ and nproc_per_node > 1:
        omp_num_threads = 1
        logger.warning(
            "\n*****************************************\n"
            "Setting OMP_NUM_THREADS environment variable for each process to "
            "%s by default to avoid overloading your system. Please further tune "
            "the variable for optimal performance in your application as needed. \n"
            "*****************************************",
            omp_num_threads,
        )
        # This env variable will be passed down to the subprocesses
        os.environ["OMP_NUM_THREADS"] = str(omp_num_threads)

    log_line_prefix_template = os.getenv("TORCHELASTIC_LOG_LINE_PREFIX_TEMPLATE")

    rdzv_configs = _parse_rendezvous_config(args.rdzv_conf)

    if args.rdzv_backend == "static":
        rdzv_configs["rank"] = args.node_rank

    rdzv_endpoint = get_rdzv_endpoint(args)

    ranks: Optional[Set[int]] = None
    if args.local_ranks_filter:
        try:
            ranks = set(map(int, args.local_ranks_filter.split(",")))
            assert ranks
        except Exception as e:
            raise ValueError(
                "--local_ranks_filter must be a comma-separated list of integers, "
                "e.g. --local_ranks_filter=0,1,2"
            ) from e

    logs_specs_cls: Type[LogsSpecs] = _get_logs_specs_class(args.logs_specs)
    logs_specs = logs_specs_cls(
        log_dir=args.log_dir,
        redirects=Std.from_str(args.redirects),
        tee=Std.from_str(args.tee),
        local_ranks_filter=ranks,
    )

    config = LaunchConfig(
        min_nodes=min_nodes,
        max_nodes=max_nodes,
        nproc_per_node=nproc_per_node,
        run_id=args.rdzv_id,
        role=args.role,
        rdzv_endpoint=rdzv_endpoint,
        rdzv_backend=args.rdzv_backend,
        rdzv_configs=rdzv_configs,
        max_restarts=args.max_restarts,
        monitor_interval=args.monitor_interval,
        start_method=args.start_method,
        log_line_prefix_template=log_line_prefix_template,
        local_addr=args.local_addr,
        logs_specs=logs_specs,
    )

    with_python = not args.no_python
    cmd: Union[Callable, str]
    cmd_args = []
    use_env = get_use_env(args)
    if args.run_path:
        cmd = run_script_path
        cmd_args.append(args.training_script)
    else:
        if with_python:
            cmd = os.getenv("PYTHON_EXEC", sys.executable)
            cmd_args.append("-u")
            if args.module:
                cmd_args.append("-m")
            cmd_args.append(args.training_script)
        else:
            if args.module:
                raise ValueError(
                    "Don't use both the '--no-python' flag"
                    " and the '--module' flag at the same time."
                )
            cmd = args.training_script
    if not use_env:
        cmd_args.append(f"--local-rank={macros.local_rank}")
    cmd_args.extend(args.training_script_args)

    return config, cmd, cmd_args


def run_script_path(training_script: str, *training_script_args: str):
    """
    Run ``training_script`` as ``__main__`` inside the current interpreter.

    ``sys.argv`` is set to the script path followed by ``training_script_args``
    before the script is executed.

    Usage: `run_script_path("/abs/path/to/script.py", "--arg1", "val1")`
    """
    import runpy
    import sys

    sys.argv = [training_script] + [*training_script_args]
    runpy.run_path(sys.argv[0], run_name="__main__")


def run(args):
    torch.multiprocessing._set_thread_name("pt_elastic")

    if args.standalone:
        args.rdzv_backend = "c10d"
        args.rdzv_endpoint = "localhost:0"
        args.rdzv_id = str(uuid.uuid4())
        logger.info(
            "\n**************************************\n"
            "Rendezvous info:\n"
            "--rdzv-backend=%s "
            "--rdzv-endpoint=%s "
            "--rdzv-id=%s\n"
            "**************************************\n",
            args.rdzv_backend,
            args.rdzv_endpoint,
            args.rdzv_id,
        )

    config, cmd, cmd_args = config_from_args(args)
    elastic_launch(
        config=config,
        entrypoint=cmd,
    )(*cmd_args)


@record
def main(args=None):
    args = parse_args(args)
    run(args)


if __name__ == "__main__":
    main()