xref: /aosp_15_r20/external/pytorch/torch/distributed/elastic/rendezvous/etcd_store.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2# Copyright (c) Facebook, Inc. and its affiliates.
3# All rights reserved.
4#
5# This source code is licensed under the BSD-style license found in the
6# LICENSE file in the root directory of this source tree.
7
8import datetime
9import random
10import time
11from base64 import b64decode, b64encode
12from typing import Optional
13
14import etcd  # type: ignore[import]
15
16# pyre-ignore[21]: Could not find name `Store` in `torch.distributed`.
17from torch.distributed import Store
18
19
# Back off for a short, randomized interval before retrying a compare-and-swap.
# Purely an optimization: correctness is unaffected, but fewer requests reach
# the etcd server.
def cas_delay():
    """Sleep for a random duration drawn uniformly from [0, 0.1] seconds."""
    jitter_seconds = random.uniform(0, 0.1)
    time.sleep(jitter_seconds)
24
25
# pyre-fixme[11]: Annotation `Store` is not defined as a type.
class EtcdStore(Store):
    """
    Implement a c10 Store interface by piggybacking on the rendezvous etcd instance.

    This is the store object returned by ``EtcdRendezvous``.

    Every key and value is base64-encoded before it is sent to etcd, and every
    key is stored underneath a common prefix (an etcd directory).
    """

    def __init__(
        self,
        etcd_client,
        etcd_store_prefix,
        # Default timeout same as in c10d/Store.hpp
        timeout: Optional[datetime.timedelta] = None,
    ):
        """
        Create a store backed by ``etcd_client``.

        Args:
            etcd_client: client object used for all etcd requests.
            etcd_store_prefix: etcd path under which all keys are stored; a
                trailing ``/`` is appended if missing.
            timeout: optional timeout for blocking operations. When ``None``,
                the timeout of the base ``Store`` is left untouched.
        """
        super().__init__()  # required for pybind trampoline.

        self.client = etcd_client
        self.prefix = etcd_store_prefix

        if timeout is not None:
            self.set_timeout(timeout)

        # Normalize the prefix so that keys can be appended by plain concatenation.
        if not self.prefix.endswith("/"):
            self.prefix += "/"

    def set(self, key, value):
        """
        Write a key/value pair into ``EtcdStore``.

        Both key and value may be either Python ``str`` or ``bytes``.
        """
        self.client.set(key=self.prefix + self._encode(key), value=self._encode(value))

    def get(self, key) -> bytes:
        """
        Get a value by key, possibly doing a blocking wait.

        If key is not immediately present, will do a blocking wait
        for at most ``timeout`` duration or until the key is published.

        Returns:
            value ``(bytes)``

        Raises:
            LookupError - If key still not published after timeout
        """
        b64_key = self.prefix + self._encode(key)
        kvs = self._try_wait_get([b64_key])

        if kvs is None:
            raise LookupError(f"Key {key} not found in EtcdStore")

        return self._decode(kvs[b64_key])

    def add(self, key, num: int) -> int:
        """
        Atomically increment a value by an integer amount.

        The integer is represented as a string using base 10. If key is not present,
        a default value of ``0`` will be assumed.

        Returns:
             the new (incremented) value
        """
        b64_key = self._encode(key)
        # c10d Store assumes value is an integer represented as a decimal string
        try:
            # Assume default value "0" if this key does not exist yet:
            node = self.client.write(
                key=self.prefix + b64_key,
                value=self._encode(str(num)),  # i.e. 0 + num
                prevExist=False,
            )
            return int(self._decode(node.value))
        except etcd.EtcdAlreadyExist:
            # Key already exists; fall through to the compare-and-swap loop.
            pass

        while True:
            # Note: c10d Store does not have a method to delete keys, so we
            # can be sure it's still there.
            node = self.client.get(key=self.prefix + b64_key)
            new_value = self._encode(str(int(self._decode(node.value)) + num))
            try:
                # Only write if nobody changed the value since we read it.
                node = self.client.test_and_set(
                    key=node.key, value=new_value, prev_value=node.value
                )
                return int(self._decode(node.value))
            except etcd.EtcdCompareFailed:
                # Lost the race against another writer: back off, then retry.
                cas_delay()

    def wait(self, keys, override_timeout: Optional[datetime.timedelta] = None):
        """
        Wait until all of the keys are published, or until timeout.

        Args:
            keys: keys (``str`` or ``bytes``) to wait for.
            override_timeout: if given, used instead of the store timeout.

        Raises:
            LookupError - if timeout occurs
        """
        b64_keys = [self.prefix + self._encode(key) for key in keys]
        kvs = self._try_wait_get(b64_keys, override_timeout)
        if kvs is None:
            raise LookupError("Timeout while waiting for keys in EtcdStore")
        # No return value on success

    def check(self, keys) -> bool:
        """Check if all of the keys are immediately present (without waiting)."""
        b64_keys = [self.prefix + self._encode(key) for key in keys]
        kvs = self._try_wait_get(
            b64_keys,
            override_timeout=datetime.timedelta(microseconds=1),  # as if no wait
        )
        return kvs is not None

    #
    # Encode key/value data in base64, so we can store arbitrary binary data
    # in EtcdStore. Input can be `str` or `bytes`.
    # In case of `str`, utf-8 encoding is assumed.
    #
    def _encode(self, value) -> str:
        """
        Return the base64 text form of ``value``.

        Raises:
            ValueError - if ``value`` is neither ``str`` nor ``bytes``
        """
        # isinstance (rather than an exact type comparison) also accepts
        # subclasses of bytes/str.
        if isinstance(value, bytes):
            return b64encode(value).decode()
        if isinstance(value, str):
            return b64encode(value.encode()).decode()
        raise ValueError("Value must be of type str or bytes")

    #
    # Decode a base64 string (of type `str` or `bytes`).
    # Return type is `bytes`, which is more convenient with the Store interface.
    #
    def _decode(self, value) -> bytes:
        """
        Return the raw bytes encoded by the base64 ``value``.

        Raises:
            ValueError - if ``value`` is neither ``str`` nor ``bytes``
        """
        if isinstance(value, bytes):
            return b64decode(value)
        if isinstance(value, str):
            return b64decode(value.encode())
        raise ValueError("Value must be of type str or bytes")

    #
    # Get all of the (base64-encoded) etcd keys at once, or wait until all the keys
    # are published or timeout occurs.
    # This is a helper method for the public interface methods.
    #
    # On success, a dictionary of {etcd key -> etcd value} is returned.
    # On timeout, None is returned.
    #
    def _try_wait_get(self, b64_keys, override_timeout=None):
        """
        Fetch all of ``b64_keys`` from etcd, waiting for missing ones.

        Args:
            b64_keys: full etcd keys (prefix + base64-encoded key) to look up.
                Duplicate entries are treated as a single key.
            override_timeout: ``datetime.timedelta`` used instead of the store
                timeout when not ``None``.

        Returns:
            A dict mapping each requested etcd key to its etcd value, or
            ``None`` if the deadline passed before all keys were present.
        """
        timeout = self.timeout if override_timeout is None else override_timeout  # type: ignore[attr-defined]
        deadline = time.time() + timeout.total_seconds()

        # De-duplicate the requested keys. The result dict can hold each key only
        # once, so comparing its size against a list containing duplicates would
        # never succeed. A set also gives O(1) membership tests below.
        wanted_keys = set(b64_keys)

        while True:
            # Read whole directory (of keys), filter only the ones waited for
            all_nodes = self.client.get(key=self.prefix)
            req_nodes = {
                node.key: node.value
                for node in all_nodes.children
                if node.key in wanted_keys
            }

            if len(req_nodes) == len(wanted_keys):
                # All keys are available
                return req_nodes

            watch_timeout = deadline - time.time()
            if watch_timeout <= 0:
                return None

            try:
                # Block until something changes under the prefix, starting right
                # after the index of the directory listing read above.
                self.client.watch(
                    key=self.prefix,
                    recursive=True,
                    timeout=watch_timeout,
                    index=all_nodes.etcd_index + 1,
                )
            except etcd.EtcdWatchTimedOut:
                if time.time() >= deadline:
                    return None
                else:
                    continue
            except etcd.EtcdEventIndexCleared:
                # Loop around and re-read the directory.
                continue
207                continue
208