# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import ipaddress
import random
import re
import socket
import time
import weakref
from datetime import timedelta
from threading import Event, Thread, current_thread
from typing import Any, Callable, Dict, Optional, Tuple, Union


__all__ = ["parse_rendezvous_endpoint"]


def _parse_rendezvous_config(config_str: str) -> Dict[str, str]:
    """Extract key-value pairs from a rendezvous configuration string.

    Args:
        config_str:
            A string in format <key1>=<value1>,...,<keyN>=<valueN>. Keys and
            values are stripped of surrounding whitespace. Only the first
            ``=`` of each pair separates key from value, so a value may itself
            contain ``=``. A later duplicate key overwrites an earlier one.

    Returns:
        A dictionary mapping each key to its value; empty if ``config_str``
        is empty or consists only of whitespace.

    Raises:
        ValueError: If a pair has an empty key, or a missing or empty value.
    """
    config: Dict[str, str] = {}

    config_str = config_str.strip()
    if not config_str:
        return config

    for kv in config_str.split(","):
        key, *values = kv.split("=", 1)

        key = key.strip()
        if not key:
            raise ValueError(
                "The rendezvous configuration string must be in format "
                "<key1>=<value1>,...,<keyN>=<valueN>."
            )

        value: Optional[str] = values[0].strip() if values else None
        if not value:
            raise ValueError(
                f"The rendezvous configuration option '{key}' must have a value specified."
            )

        config[key] = value
    return config


def _try_parse_port(port_str: str) -> Optional[int]:
    """Try to extract the port number from ``port_str``.

    Returns:
        The integer value if ``port_str`` consists solely of 1 to 5 ASCII
        digits, otherwise ``None``. Checking the value against the 16-bit port
        range is left to the caller.
    """
    # ``fullmatch`` rather than ``match(r"^...$")``: ``$`` also matches just
    # before a trailing newline, which would let e.g. "80\n" through.
    if port_str and re.fullmatch(r"[0-9]{1,5}", port_str):
        return int(port_str)
    return None


def parse_rendezvous_endpoint(
    endpoint: Optional[str], default_port: int
) -> Tuple[str, int]:
    """Extract the hostname and the port number from a rendezvous endpoint.

    Args:
        endpoint:
            A string in format <hostname>[:<port>]. Surrounding whitespace is
            ignored. IPv6 addresses must be enclosed in brackets (e.g.
            ``[::1]`` or ``[::1]:29400``); otherwise the text after the last
            ``:`` is interpreted as the port.
        default_port:
            The port number to use if the endpoint does not include one.

    Returns:
        A tuple of hostname and port number. If ``endpoint`` is ``None`` or
        blank, ``("localhost", default_port)`` is returned. Brackets around an
        IPv6 address are removed from the returned hostname.

    Raises:
        ValueError: If the port is not an integer in the range [0, 65536), or
            if the hostname is empty or contains characters other than
            letters, digits, underscores, ``.``, ``:`` and ``-``.
    """
    if endpoint is not None:
        endpoint = endpoint.strip()

    if not endpoint:
        return ("localhost", default_port)

    host = endpoint
    port_str: Optional[str] = None
    # An endpoint that starts and ends with brackets represents a bare IPv6
    # address without a port; it must not be split on its internal colons.
    if not (endpoint[0] == "[" and endpoint[-1] == "]"):
        head, sep, tail = endpoint.rpartition(":")
        if sep:
            host, port_str = head, tail

    # Sanitize the IPv6 address.
    if len(host) > 1 and host[0] == "[" and host[-1] == "]":
        host = host[1:-1]

    if port_str is not None:
        parsed_port = _try_parse_port(port_str)
        # The upper bound in the message is exclusive: 65536 itself is rejected.
        if parsed_port is None or parsed_port >= 2**16:
            raise ValueError(
                f"The port number of the rendezvous endpoint '{endpoint}' must be an integer "
                "between 0 and 65536."
            )
        port = parsed_port
    else:
        port = default_port

    if not re.match(r"^[\w\.:-]+$", host):
        raise ValueError(
            f"The hostname of the rendezvous endpoint '{endpoint}' must be a dot-separated list of "
            "labels, an IPv4 address, or an IPv6 address."
        )

    return host, port


def _matches_machine_hostname(host: str) -> bool:
    """Indicate whether ``host`` matches the hostname of this machine.

    This function compares ``host`` to the hostname as well as to the IP
    addresses of this machine. Note that it may return a false negative if this
    machine has CNAME records beyond its FQDN or IP addresses assigned to
    secondary NICs.

    ``localhost`` and loopback IP literals always match. If ``host`` cannot be
    resolved, only the FQDN and literal-IP comparisons are performed. The
    lookup of this machine's own hostname is not guarded, so a resolution
    failure there (e.g. ``socket.gaierror``) propagates to the caller.
    """
    if host == "localhost":
        return True

    addr: Optional[Union[ipaddress.IPv4Address, ipaddress.IPv6Address]]
    try:
        addr = ipaddress.ip_address(host)
    except ValueError:
        addr = None

    if addr is not None and addr.is_loopback:
        return True

    this_host = socket.gethostname()
    # Checked before resolving ``host`` so the common case needs no DNS lookup.
    if host == this_host:
        return True

    try:
        host_addr_list = socket.getaddrinfo(
            host, None, proto=socket.IPPROTO_TCP, flags=socket.AI_CANONNAME
        )
    except (ValueError, socket.gaierror):
        # ``host`` is not resolvable; fall back to the comparisons below.
        host_addr_list = []

    host_ips = {host_addr_info[4][0] for host_addr_info in host_addr_list}

    addr_list = socket.getaddrinfo(
        this_host, None, proto=socket.IPPROTO_TCP, flags=socket.AI_CANONNAME
    )
    for _family, _type, _proto, canonname, sockaddr in addr_list:
        ip = sockaddr[0]

        # If we have an FQDN in the addr_info, compare it to `host`.
        if canonname and canonname == host:
            return True

        # Otherwise if `host` represents an IP address, compare it to our IP
        # address.
        if addr is not None and ip == str(addr):
            return True

        # If the IP address matches one of the provided host's IP addresses.
        if ip in host_ips:
            return True

    return False


def _delay(seconds: Union[float, Tuple[float, float]]) -> None:
    """Suspend the current thread for ``seconds``.

    Args:
        seconds:
            Either the delay, in seconds, or a tuple of a lower and an upper
            bound within which a random delay will be picked. Delays shorter
            than 10 milliseconds (including zero or negative values) are
            ignored and the function returns immediately.
    """
    if isinstance(seconds, tuple):
        seconds = random.uniform(*seconds)
    # Ignore delay requests that are less than 10 milliseconds.
    if seconds >= 0.01:
        time.sleep(seconds)


class _PeriodicTimer:
    """Represent a timer that periodically runs a specified function.

    The function is first invoked one ``interval`` after ``start()`` and then
    repeatedly, each run separated by ``interval``, on a daemon background
    thread until the timer is cancelled or garbage collected.

    Args:
        interval:
            The interval, in seconds, between each run.
        function:
            The function to run.
        *args, **kwargs:
            Positional and keyword arguments passed to ``function`` on every
            run.
    """

    # The state of the timer is held in a separate context object to avoid a
    # reference cycle between the timer and the background thread.
    class _Context:
        interval: float
        function: Callable[..., None]
        args: Tuple[Any, ...]
        kwargs: Dict[str, Any]
        stop_event: Event

    _name: Optional[str]
    _thread: Optional[Thread]
    _finalizer: Optional[weakref.finalize]

    # The context that is shared between the timer and the background thread.
    _ctx: _Context

    def __init__(
        self,
        interval: timedelta,
        function: Callable[..., None],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        self._name = None

        self._ctx = self._Context()
        self._ctx.interval = interval.total_seconds()
        self._ctx.function = function  # type: ignore[assignment]
        self._ctx.args = args
        self._ctx.kwargs = kwargs
        self._ctx.stop_event = Event()

        self._thread = None
        self._finalizer = None

    @property
    def name(self) -> Optional[str]:
        """Get the name of the timer."""
        return self._name

    def set_name(self, name: str) -> None:
        """Set the name of the timer.

        The specified name will be assigned to the background thread and serves
        for debugging and troubleshooting purposes.

        Raises:
            RuntimeError: If the timer has already been started.
        """
        if self._thread:
            raise RuntimeError("The timer has already started.")

        self._name = name

    def start(self) -> None:
        """Start the timer.

        Raises:
            RuntimeError: If the timer has already been started (a timer
                cannot be restarted, even after ``cancel()``).
        """
        if self._thread:
            raise RuntimeError("The timer has already started.")

        self._thread = Thread(
            target=self._run,
            name=self._name or "PeriodicTimer",
            args=(self._ctx,),
            daemon=True,
        )

        # We avoid using a regular finalizer (a.k.a. __del__) for stopping the
        # timer as joining a daemon thread during the interpreter shutdown can
        # cause deadlocks. The weakref.finalize is a superior alternative that
        # provides a consistent behavior regardless of the GC implementation.
        self._finalizer = weakref.finalize(
            self, self._stop_thread, self._thread, self._ctx.stop_event
        )

        # We do not attempt to stop our background thread during the interpreter
        # shutdown. At that point we do not even know whether it still exists.
        self._finalizer.atexit = False

        self._thread.start()

    def cancel(self) -> None:
        """Stop the timer at the next opportunity.

        Signals the background thread to stop and, unless called from that
        thread itself (e.g. from within ``function``), waits for it to exit.
        Calling this before ``start()`` or more than once is a no-op.
        """
        if self._finalizer:
            self._finalizer()

    @staticmethod
    def _run(ctx) -> None:
        # ``Event.wait`` returns False on timeout (run the function again) and
        # True once the stop event is set (exit the loop).
        while not ctx.stop_event.wait(ctx.interval):
            ctx.function(*ctx.args, **ctx.kwargs)

    @staticmethod
    def _stop_thread(thread, stop_event):
        stop_event.set()

        # ``cancel()`` may be invoked from within the timer's own function, and
        # the finalizer may run on whichever thread drops the last reference to
        # the timer. A thread cannot join itself (``Thread.join`` raises
        # RuntimeError), so only wait when called from another thread.
        if thread is not current_thread():
            thread.join()