# mypy: allow-untyped-defs
from typing import Optional, Union

import torch


class _remote_device:
    """
    Represents a device on a remote worker.

    Args:
        remote_device (str or torch.device): Represents a device on a remote worker.
            The string format should be one of the following:

            1. "<workername>/<device>", where the device field can be parsed as torch.device type.
               E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0".
               In addition, the device field can be optional and the default value is "cpu".
            2. "rank:<rank>/<device>", where <rank> is the rank of the
               process and device can be parsed as torch.device type.
               E.g., "rank:0/cpu", "rank:0", "rank:0/cuda:0"
            3. <workername> and <rank> are optional and formats like "cpu"
               and "cuda:1", just represent local devices.

    Raises:
        TypeError: If ``remote_device`` is neither a ``str`` nor a ``torch.device``.
        ValueError: If the string has more than one "/", an empty worker
            name, or a worker field containing ":" that is not of the form
            "rank:<decimal digits>".

    Any exception raised by ``torch.device`` while validating the device
    field is propagated unchanged.
    """

    def __init__(self, remote_device: Union[str, torch.device]):
        PARSE_ERROR = (
            f"Could not parse remote_device: {remote_device}. The valid format is "
            "'<workername>/<device>' or 'rank:<rank>/<device>' or '<device>'"
        )
        self._worker_name = None
        self._rank = None
        self._device: Optional[Union[str, int, torch.device]] = None

        if isinstance(remote_device, torch.device):
            self._device = remote_device
        elif isinstance(remote_device, str):
            fields = remote_device.split("/")
            if len(fields) == 2:
                self._worker_name, self._device = fields
            elif len(fields) == 1:
                # A lone field is ambiguous: treat it as a device when
                # torch.device accepts it, otherwise as a worker name whose
                # device defaults to "cpu".
                if _remote_device._is_valid_local_device(fields[0]):
                    self._device = fields[0]
                else:
                    self._worker_name = fields[0]
                    self._device = "cpu"
            else:
                raise ValueError(PARSE_ERROR)
        else:
            raise TypeError(f"Invalid type for remote_device: {type(remote_device)}")

        # Do some basic sanity check (no empty string)
        if self._worker_name is not None and not self._worker_name:
            raise ValueError(PARSE_ERROR)

        # Validate the device (normalizes a string into a torch.device).
        self._device = torch.device(self._device)

        # Check for rank based format.
        if self._worker_name is not None:
            fields = self._worker_name.split(":")
            if len(fields) == 2:
                # rank:<rank>/device format, extract rank.
                # str.isdecimal() (not isdigit()) is used because it matches
                # exactly what int() can parse; isdigit() also accepts
                # characters such as superscript digits, which would make
                # int() fail with an unrelated error message.
                if fields[0] == "rank" and fields[1].isdecimal():
                    self._rank = int(fields[1])  # type: ignore[assignment]
                    self._worker_name = None
                else:
                    raise ValueError(PARSE_ERROR)
            elif len(fields) > 2:
                raise ValueError(PARSE_ERROR)

    @staticmethod
    def _is_valid_local_device(device):
        """Return ``True`` if ``torch.device(device)`` succeeds, ``False`` otherwise."""
        # Check for torch.device
        try:
            torch.device(device)
            return True
        except Exception:
            return False

    def worker_name(self) -> Optional[str]:
        """Return the name of remote worker representing the remote device and ``None`` if no worker name is available."""
        return self._worker_name

    def rank(self) -> Optional[int]:
        """
        Returns the rank of remote worker representing the remote device.
        Returns ``None`` if no rank is available.
        """
        return self._rank

    def device(self) -> torch.device:
        """Return the local device on the remote worker."""
        return self._device  # type: ignore[return-value]

    def __repr__(self):
        """Return the string form: "<worker>/<device>", "rank:<rank>/<device>" or "<device>"."""
        if self._device is not None:
            if self._worker_name is not None:
                return f"{self._worker_name}/{self._device}"
            elif self._rank is not None:
                return f"rank:{self._rank}/{self._device}"
            else:
                return str(self._device)
        else:
            # Fallback for an instance without a device; __init__ always
            # assigns one when it completes.
            if self._worker_name is not None:
                return f"{self._worker_name}"
            elif self._rank is not None:
                return f"{self._rank}"
            else:
                raise RuntimeError("Invalid state!")

    def __eq__(self, other):
        """Two instances are equal when worker name, device and rank all match."""
        return isinstance(other, _remote_device) and (
            self._worker_name == other._worker_name
            and self._device == other._device
            and self._rank == other._rank
        )

    def __hash__(self):
        """Combine the hashes of the same three fields compared by ``__eq__``."""
        return hash(self._worker_name) ^ hash(self._device) ^ hash(self._rank)