xref: /aosp_15_r20/external/pytorch/torch/distributed/elastic/rendezvous/api.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2# Copyright (c) Facebook, Inc. and its affiliates.
3# All rights reserved.
4#
5# This source code is licensed under the BSD-style license found in the
6# LICENSE file in the root directory of this source tree.
7
8import socket
9from abc import ABC, abstractmethod
10from dataclasses import dataclass
11from typing import Any, Callable, ClassVar, Dict, Optional
12
13from torch.distributed import Store
14from torch.distributed.elastic.utils.distributed import get_free_port as _get_free_port
15
16
# Declared public API of this module: the names exported by
# ``from torch.distributed.elastic.rendezvous.api import *``.
__all__ = [
    "RendezvousClosedError",
    "RendezvousConnectionError",
    "RendezvousError",
    "RendezvousGracefulExitError",
    "RendezvousHandler",
    "RendezvousHandlerCreator",
    "RendezvousHandlerRegistry",
    "RendezvousInfo",
    "RendezvousParameters",
    "RendezvousStateError",
    "RendezvousStoreInfo",
    "RendezvousTimeoutError",
    "rendezvous_handler_registry",
]
32
33
class RendezvousError(Exception):
    """Represents the base type for rendezvous errors.

    Catching this type also catches every more specific rendezvous error
    defined alongside it (closed, timeout, connection, state, graceful exit).
    """


class RendezvousClosedError(RendezvousError):
    """Raised when a rendezvous is closed."""


class RendezvousTimeoutError(RendezvousError):
    """Raised when a rendezvous did not complete on time."""


class RendezvousConnectionError(RendezvousError):
    """Raised when the connection to a rendezvous backend has failed."""


class RendezvousStateError(RendezvousError):
    """Raised when the state of a rendezvous is corrupt."""


class RendezvousGracefulExitError(RendezvousError):
    """Raised when a node was not included in the rendezvous and exits gracefully.

    The exception is used as a mechanism to unwind the stack; it does not
    indicate a failure.
    """
59
60
@dataclass
class RendezvousStoreInfo:
    """Store address and port that can be used to bootstrap trainer distributed comms."""

    # Keys under which rank 0 publishes the bootstrap address/port in the store.
    MASTER_ADDR_KEY: ClassVar[str] = "MASTER_ADDR"
    MASTER_PORT_KEY: ClassVar[str] = "MASTER_PORT"
    master_addr: str
    master_port: int

    @staticmethod
    def build(
        rank: int,
        store: Store,
        local_addr: Optional[str],
        server_port: Optional[int] = None,
    ) -> "RendezvousStoreInfo":
        """Factory method: choose the bootstrap addr/port on rank 0 and share them with all ranks.

        Rank 0 writes its address and a port into ``store``. Every rank
        (including rank 0) then reads the published values back, so all ranks
        sharing the store return the same result.

        If master_addr/master_port are already known (e.g. when sharing an
        existing TCP store server), use the constructor directly instead.

        Args:
            rank: rank of the current node
            store: store to use for rendezvous
            local_addr: address of the current node; if not provided, it is
                resolved from the hostname via ``socket.getfqdn()``
            server_port: port to publish from rank 0. When ``None`` (or ``0``),
                an unused port is looked up on the rank 0 host. Ignored on
                other ranks.

        Returns:
            A :py:class:`RendezvousStoreInfo` holding the values published by rank 0.
        """
        # TODO swap to collectives comms API
        if rank == 0:
            addr = local_addr or socket.getfqdn()
            # Reuse a caller-supplied port (e.g. an existing TCP store server);
            # otherwise probe for a free one.
            port = server_port if server_port else _get_free_port()
            store.set(RendezvousStoreInfo.MASTER_ADDR_KEY, addr.encode(encoding="UTF-8"))  # type: ignore[arg-type]
            store.set(RendezvousStoreInfo.MASTER_PORT_KEY, str(port).encode(encoding="UTF-8"))  # type: ignore[arg-type]

        # NOTE(review): non-zero ranks rely on ``store.get`` waiting until rank 0
        # has set the keys (c10d Store semantics) — confirm for custom stores.
        addr = store.get(RendezvousStoreInfo.MASTER_ADDR_KEY).decode(encoding="UTF-8")
        port = int(
            store.get(RendezvousStoreInfo.MASTER_PORT_KEY).decode(encoding="UTF-8")
        )
        return RendezvousStoreInfo(master_addr=addr, master_port=port)
95
96
class RendezvousInfo:
    """Holds the information about the rendezvous.

    Instances are returned by :py:meth:`RendezvousHandler.next_rendezvous`.
    All fields are read-only properties.

    Args:
        store: store used by the torchelastic control plane
        rank: rank within the group
        world_size: global group size
        bootstrap_store_info: address/port information trainer code can use to
            bootstrap distributed comms; may be ``None``, as advertised by the
            ``bootstrap_store_info`` property
    """

    def __init__(
        self,
        store: Store,
        rank: int,
        world_size: int,
        bootstrap_store_info: Optional[RendezvousStoreInfo],
    ):
        self._store = store
        self._rank = rank
        self._world_size = world_size
        self._bootstrap_store_info = bootstrap_store_info

    @property
    def store(self) -> Store:
        """Store used by torchelastic control plane."""
        return self._store

    @property
    def rank(self) -> int:
        """Rank within a group."""
        return self._rank

    @property
    def world_size(self) -> int:
        """Global group size."""
        return self._world_size

    @property
    def bootstrap_store_info(self) -> Optional[RendezvousStoreInfo]:
        """Store information that can be used by trainer code to bootstrap distributed comms."""
        return self._bootstrap_store_info
131
132
class RendezvousHandler(ABC):
    """Main rendezvous interface.

    Note:
        Distributed Torch users normally **do not** need to implement their own
        ``RendezvousHandler``. An implementation based on C10d Store is already
        provided, and is recommended for most users.
    """

    @abstractmethod
    def get_backend(self) -> str:
        """Return the name of the rendezvous backend.

        ``RendezvousHandlerRegistry.create_handler`` rejects a handler whose
        backend name differs from the name it was requested under.
        """

    @property
    def use_agent_store(self) -> bool:
        """Indicate whether the store returned by :py:meth:`next_rendezvous` can be shared with user applications.

        When ``True``, the store reference is intended to be shared with user
        applications and to remain available during the application lifecycle.
        Rendezvous handler implementations share store details as an instance of
        :py:class:`RendezvousStoreInfo`. By convention, applications use the
        ``MASTER_ADDR``/``MASTER_PORT`` env variables to look up the store.

        The base implementation returns ``False``.
        """
        return False

    @abstractmethod
    def next_rendezvous(self) -> RendezvousInfo:
        """Main entry-point into the rendezvous barrier.

        Blocks until the rendezvous is complete and the current process is
        included in the formed worker group, or a timeout occurs, or the
        rendezvous was marked closed.

        Returns:
            Instance of :py:class:`RendezvousInfo`.

        Raises:
            RendezvousClosedError:
                The rendezvous is closed.
            RendezvousConnectionError:
                The connection to the rendezvous backend has failed.
            RendezvousStateError:
                The rendezvous state is corrupt.
            RendezvousTimeoutError:
                The rendezvous did not complete on time.
        """

    @abstractmethod
    def is_closed(self) -> bool:
        """Check whether the rendezvous has been closed.

        A closed rendezvous means all future attempts to re-rendezvous within
        same job will fail.

        ``is_closed()`` and :py:meth:`set_closed` have semantics of eventual
        propagation and should not be used for synchronization. The intention is
        that if at least one node decides the job is finished, it will close the
        rendezvous, and other nodes will soon observe this and stop running as
        well.
        """

    @abstractmethod
    def set_closed(self):
        """Mark the rendezvous as closed."""

    @abstractmethod
    def num_nodes_waiting(self) -> int:
        """Return the number of nodes who arrived late at the rendezvous
        barrier, hence were not included in the current worker group.

        Callers should periodically call this method to check whether new
        nodes are waiting to join the job and if so admit them by calling
        :py:meth:`next_rendezvous()` (re-rendezvous).
        """

    @abstractmethod
    def get_run_id(self) -> str:
        """Return the run id of the rendezvous.

        The run id is a user-defined id that uniquely identifies an instance of
        a distributed application. It typically maps to a job id and is used to
        allow nodes to join the correct distributed application.
        """

    @abstractmethod
    def shutdown(self) -> bool:
        """Close all resources that were open for the rendezvous.

        Returns:
            A ``bool`` whose meaning is defined by the implementation
            (presumably whether the shutdown succeeded — confirm against the
            concrete handler).

        Example::

            rdzv_handler = ...
            try:
                rdzv_info = rdzv_handler.next_rendezvous()
                store, rank, world_size = (
                    rdzv_info.store,
                    rdzv_info.rank,
                    rdzv_info.world_size,
                )
            finally:
                rdzv_handler.shutdown()
        """
227
228
class RendezvousParameters:
    """Hold the parameters to construct a :py:class:`RendezvousHandler`.

    Args:
        backend:
            The name of the backend to use to handle the rendezvous.
        endpoint:
            The endpoint of the rendezvous, usually in form <hostname>[:<port>].
        run_id:
            The id of the rendezvous.
        min_nodes:
            The minimum number of nodes to admit to the rendezvous.
        max_nodes:
            The maximum number of nodes to admit to the rendezvous.
        local_addr:
            The address of the local node.
        **kwargs:
            Additional parameters for the specified backend.

    Raises:
        ValueError: If ``backend`` is empty, ``min_nodes`` is less than 1, or
            ``max_nodes`` is less than ``min_nodes``.
    """

    def __init__(
        self,
        backend: str,
        endpoint: str,
        run_id: str,
        min_nodes: int,
        max_nodes: int,
        local_addr: Optional[str] = None,
        **kwargs,
    ):
        if not backend:
            raise ValueError("The rendezvous backend name must be a non-empty string.")

        if min_nodes < 1:
            raise ValueError(
                f"The minimum number of rendezvous nodes ({min_nodes}) must be greater than zero."
            )
        if max_nodes < min_nodes:
            raise ValueError(
                f"The maximum number of rendezvous nodes ({max_nodes}) must be greater than or "
                f"equal to the minimum number of rendezvous nodes ({min_nodes})."
            )

        self.backend = backend
        self.endpoint = endpoint
        self.run_id = run_id
        self.min_nodes = min_nodes
        self.max_nodes = max_nodes
        # Backend-specific options, read via get()/get_as_bool()/get_as_int().
        self.config = kwargs
        self.local_addr = local_addr

    def get(self, key: str, default: Any = None) -> Any:
        """Return the value for ``key`` if ``key`` exists, else ``default``."""
        return self.config.get(key, default)

    def get_as_bool(self, key: str, default: Optional[bool] = None) -> Optional[bool]:
        """Return the value for ``key`` as a ``bool``.

        Accepted values: a ``bool``; the ints ``1``/``0``; or a string that
        case-insensitively equals one of ``1, true, t, yes, y`` (True) or
        ``0, false, f, no, n`` (False). ``None`` is returned unchanged.

        Raises:
            ValueError: If the value cannot be interpreted as a boolean.
        """
        value = self.get(key, default)
        # bool is checked before int because bool is a subclass of int.
        if value is None or isinstance(value, bool):
            return value
        if isinstance(value, int):
            if value == 1:
                return True
            if value == 0:
                return False
        elif isinstance(value, str):
            lowered = value.lower()
            if lowered in ("1", "true", "t", "yes", "y"):
                return True
            if lowered in ("0", "false", "f", "no", "n"):
                return False
        raise ValueError(
            f"The rendezvous configuration option '{key}' does not represent a valid boolean value."
        )

    def get_as_int(self, key: str, default: Optional[int] = None) -> Optional[int]:
        """Return the value for ``key`` as an ``int``.

        ``None`` is returned unchanged; any other value is converted with ``int()``.

        Raises:
            ValueError: If the value cannot be converted to an integer, including
                values of a type ``int()`` does not accept (e.g. a list).
        """
        value = self.get(key, default)
        if value is None:
            return value
        try:
            return int(value)
        except (TypeError, ValueError) as e:
            # int() raises TypeError for unsupported types; normalize to the
            # documented ValueError so callers only need to handle one type.
            raise ValueError(
                f"The rendezvous configuration option '{key}' does not represent a valid integer "
                "value."
            ) from e
315
316
# Signature of a backend factory passed to RendezvousHandlerRegistry.register:
# it receives the RendezvousParameters and returns a RendezvousHandler.
RendezvousHandlerCreator = Callable[[RendezvousParameters], RendezvousHandler]
318
319
class RendezvousHandlerRegistry:
    """Represent a registry of :py:class:`RendezvousHandler` backends.

    Maps backend names to creator callbacks. :py:meth:`create_handler` looks up
    the creator by ``RendezvousParameters.backend`` and invokes it.
    """

    _registry: Dict[str, RendezvousHandlerCreator]

    def __init__(self) -> None:
        self._registry = {}

    def register(self, backend: str, creator: RendezvousHandlerCreator) -> None:
        """Register a new rendezvous backend.

        Registering the same (equal) creator again under the same name is a no-op.

        Args:
            backend:
                The name of the backend.
            creator:
                The callback to invoke to construct the
                :py:class:`RendezvousHandler`.

        Raises:
            ValueError: If ``backend`` is empty, or is already registered with a
                different creator.
        """
        if not backend:
            raise ValueError("The rendezvous backend name must be a non-empty string.")

        current_creator: Optional[RendezvousHandlerCreator] = self._registry.get(backend)
        if current_creator is not None and current_creator != creator:
            raise ValueError(
                f"The rendezvous backend '{backend}' cannot be registered with '{creator}' as it "
                f"is already registered with '{current_creator}'."
            )

        self._registry[backend] = creator

    def create_handler(self, params: RendezvousParameters) -> RendezvousHandler:
        """Create a new :py:class:`RendezvousHandler`.

        Args:
            params: parameters whose ``backend`` selects the registered creator.

        Returns:
            The handler built by the registered creator.

        Raises:
            ValueError: If no creator is registered for ``params.backend``.
            RuntimeError: If the created handler reports a different backend name.
        """
        try:
            creator = self._registry[params.backend]
        except KeyError as e:
            raise ValueError(
                f"The rendezvous backend '{params.backend}' is not registered. Did you forget "
                f"to call `{self.register.__name__}`?"
            ) from e

        handler = creator(params)

        # Sanity check: the creator must build a handler for the backend it was
        # registered under. Query the name once and reuse it in the message.
        handler_backend = handler.get_backend()
        if handler_backend != params.backend:
            raise RuntimeError(
                f"The rendezvous backend '{handler_backend}' does not match the requested "
                f"backend '{params.backend}'."
            )

        return handler
375
376
# The default global registry instance used by launcher scripts to instantiate
# rendezvous handlers. Backends are added with ``register`` and handlers are
# built by backend name with ``create_handler``.
rendezvous_handler_registry = RendezvousHandlerRegistry()
380