# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import datetime
import random
import time
from base64 import b64decode, b64encode
from typing import Optional

import etcd  # type: ignore[import]

# pyre-ignore[21]: Could not find name `Store` in `torch.distributed`.
from torch.distributed import Store


# Delay (sleep) for a small random amount to reduce CAS failures.
# This does not affect correctness, but will reduce requests to etcd server.
def cas_delay():
    time.sleep(random.uniform(0, 0.1))


# pyre-fixme[11]: Annotation `Store` is not defined as a type.
class EtcdStore(Store):
    """
    Implement a c10 Store interface by piggybacking on the rendezvous etcd instance.

    This is the store object returned by ``EtcdRendezvous``.
    """

    def __init__(
        self,
        etcd_client,
        etcd_store_prefix,
        # Default timeout same as in c10d/Store.hpp
        timeout: Optional[datetime.timedelta] = None,
    ):
        """
        Create a store backed by ``etcd_client`` under ``etcd_store_prefix``.

        Args:
            etcd_client: client object used for all etcd requests.
            etcd_store_prefix: etcd key prefix under which entries are stored;
                a trailing ``/`` is appended if it is missing.
            timeout: if given, passed to ``set_timeout``; otherwise the
                timeout is left at whatever the base ``Store`` provides.
        """
        super().__init__()  # required for pybind trampoline.

        self.client = etcd_client
        self.prefix = etcd_store_prefix

        if timeout is not None:
            self.set_timeout(timeout)

        if not self.prefix.endswith("/"):
            self.prefix += "/"

    def set(self, key, value):
        """
        Write a key/value pair into ``EtcdStore``.

        Both key and value may be either Python ``str`` or ``bytes``.
        """
        self.client.set(key=self.prefix + self._encode(key), value=self._encode(value))

    def get(self, key) -> bytes:
        """
        Get a value by key, possibly doing a blocking wait.

        If key is not immediately present, will do a blocking wait
        for at most ``timeout`` duration or until the key is published.

        Returns:
            value ``(bytes)``

        Raises:
            LookupError - If key still not published after timeout
        """
        b64_key = self.prefix + self._encode(key)
        kvs = self._try_wait_get([b64_key])

        if kvs is None:
            raise LookupError(f"Key {key} not found in EtcdStore")

        return self._decode(kvs[b64_key])

    def add(self, key, num: int) -> int:
        """
        Atomically increment a value by an integer amount.

        The integer is represented as a string using base 10. If key is not present,
        a default value of ``0`` will be assumed.

        Returns:
            the new (incremented) value
        """
        b64_key = self._encode(key)
        # c10d Store assumes value is an integer represented as a decimal string
        try:
            # Assume default value "0", if this key didn't exist yet:
            node = self.client.write(
                key=self.prefix + b64_key,
                value=self._encode(str(num)),  # i.e. 0 + num
                prevExist=False,
            )
            return int(self._decode(node.value))
        except etcd.EtcdAlreadyExist:
            # Key already exists: fall through to the compare-and-swap loop.
            pass

        while True:
            # Note: c10d Store does not have a method to delete keys, so we
            # can be sure it's still there.
            node = self.client.get(key=self.prefix + b64_key)
            new_value = self._encode(str(int(self._decode(node.value)) + num))
            try:
                node = self.client.test_and_set(
                    key=node.key, value=new_value, prev_value=node.value
                )
                return int(self._decode(node.value))
            except etcd.EtcdCompareFailed:
                # The value changed between the read and the swap; back off
                # briefly and retry with a fresh read.
                cas_delay()

    def wait(self, keys, override_timeout: Optional[datetime.timedelta] = None):
        """
        Wait until all of the keys are published, or until timeout.

        Args:
            keys: keys (``str`` or ``bytes``) to wait for.
            override_timeout: if given, used instead of the store's timeout.

        Raises:
            LookupError - if timeout occurs
        """
        b64_keys = [self.prefix + self._encode(key) for key in keys]
        kvs = self._try_wait_get(b64_keys, override_timeout)
        if kvs is None:
            raise LookupError("Timeout while waiting for keys in EtcdStore")
        # No return value on success

    def check(self, keys) -> bool:
        """Check if all of the keys are immediately present (without waiting)."""
        b64_keys = [self.prefix + self._encode(key) for key in keys]
        kvs = self._try_wait_get(
            b64_keys,
            override_timeout=datetime.timedelta(microseconds=1),  # as if no wait
        )
        return kvs is not None

    #
    # Encode key/value data in base64, so we can store arbitrary binary data
    # in EtcdStore. Input can be `str` or `bytes`.
    # In case of `str`, utf-8 encoding is assumed.
    #
    def _encode(self, value) -> str:
        # isinstance (rather than an exact type comparison) so that subclasses
        # of bytes/str are accepted as well.
        if isinstance(value, bytes):
            return b64encode(value).decode()
        if isinstance(value, str):
            return b64encode(value.encode()).decode()
        raise ValueError("Value must be of type str or bytes")

    #
    # Decode a base64 string (of type `str` or `bytes`).
    # Return type is `bytes`, which is more convenient with the Store interface.
    #
    def _decode(self, value) -> bytes:
        if isinstance(value, bytes):
            return b64decode(value)
        if isinstance(value, str):
            return b64decode(value.encode())
        raise ValueError("Value must be of type str or bytes")

    #
    # Get all of the (base64-encoded) etcd keys at once, or wait until all the keys
    # are published or timeout occurs.
    # This is a helper method for the public interface methods.
    #
    # On success, a dictionary of {etcd key -> etcd value} is returned.
    # On timeout, None is returned.
    #
    def _try_wait_get(self, b64_keys, override_timeout=None):
        timeout = self.timeout if override_timeout is None else override_timeout  # type: ignore[attr-defined]
        deadline = time.time() + timeout.total_seconds()

        # De-duplicate the requested keys. The result dictionary can hold each
        # key only once, so comparing its size against a list containing
        # duplicates would never match and the call would always time out.
        # A set also makes the per-node membership test below O(1).
        wanted_keys = set(b64_keys)

        while True:
            # Read whole directory (of keys), filter only the ones waited for
            all_nodes = self.client.get(key=self.prefix)
            req_nodes = {
                node.key: node.value
                for node in all_nodes.children
                if node.key in wanted_keys
            }

            if len(req_nodes) == len(wanted_keys):
                # All keys are available
                return req_nodes

            watch_timeout = deadline - time.time()
            if watch_timeout <= 0:
                return None

            try:
                # Block until something under the prefix changes after the
                # index of the directory read above, then re-read.
                self.client.watch(
                    key=self.prefix,
                    recursive=True,
                    timeout=watch_timeout,
                    index=all_nodes.etcd_index + 1,
                )
            except etcd.EtcdWatchTimedOut:
                if time.time() >= deadline:
                    return None
                # Watch timed out before the deadline: poll again.
                continue
            except etcd.EtcdEventIndexCleared:
                # Requested watch index is no longer available: re-read the
                # directory and watch again from the fresh index.
                continue