# mypy: allow-untyped-defs
from typing import Dict, List, Optional, Union

import torch
from torch._C._distributed_rpc import _TensorPipeRpcBackendOptionsBase

from . import constants as rpc_constants


DeviceType = Union[int, str, torch.device]

__all__ = ["TensorPipeRpcBackendOptions"]


def _to_device(device: DeviceType) -> torch.device:
    device = torch.device(device)
    if device.type != "cuda":
        raise ValueError(
            "`set_devices` expects a list of CUDA devices, but got "
            f"device type {device.type}."
        )
    return device


def _to_device_map(
    device_map: Dict[DeviceType, DeviceType]
) -> Dict[torch.device, torch.device]:
    full_device_map: Dict[torch.device, torch.device] = {}
    reverse_map: Dict[torch.device, torch.device] = {}
    for k, v in device_map.items():
        k, v = torch.device(k), torch.device(v)
        # Enforce a 1-to-1 mapping: each callee device may be the target of
        # at most one caller device.
        if v in reverse_map:
            raise ValueError(
                "`device_map` only supports 1-to-1 mapping, "
                f"trying to map {k} and {reverse_map[v]} to {v}"
            )
        full_device_map[k] = v
        reverse_map[v] = k
    return full_device_map


def _to_device_list(devices: List[DeviceType]) -> List[torch.device]:
    return list(map(_to_device, devices))


class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
    r"""
    The backend options for
    :class:`~torch.distributed.rpc.TensorPipeAgent`, derived from
    :class:`~torch.distributed.rpc.RpcBackendOptions`.

    Args:
        num_worker_threads (int, optional): The number of threads in the
            thread-pool used by
            :class:`~torch.distributed.rpc.TensorPipeAgent` to execute
            requests (default: 16).
        rpc_timeout (float, optional): The default timeout, in seconds,
            for RPC requests (default: 60 seconds). If the RPC has not
            completed in this timeframe, an exception indicating so will
            be raised. Callers can override this timeout for individual
            RPCs in :meth:`~torch.distributed.rpc.rpc_sync` and
            :meth:`~torch.distributed.rpc.rpc_async` if necessary.
        init_method (str, optional): The URL to initialize the distributed
            store used for rendezvous. It takes any value accepted by the
            same argument of :meth:`~torch.distributed.init_process_group`
            (default: ``env://``).
        device_maps (Dict[str, Dict], optional): Device placement mappings from
            this worker to the callee. The key is the callee worker name and the
            value is a dictionary (``Dict`` of ``int``, ``str``, or
            ``torch.device``) that maps this worker's devices to the callee
            worker's devices. (default: ``None``)
        devices (List of int, str, or ``torch.device``, optional): All local
            CUDA devices used by the RPC agent. By default, it will be
            initialized to all local devices from its own ``device_maps`` and
            the corresponding devices from its peers' ``device_maps``. When
            processing CUDA RPC requests, the agent will properly synchronize
            CUDA streams for all devices in this ``List``.
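
    Below is an illustrative sketch of constructing the options and passing
    them to :meth:`~torch.distributed.rpc.init_rpc`; the peer name
    ``"worker1"``, ranks, world size, and device indices are placeholders.

    Example:
        >>> # xdoctest: +SKIP("distributed")
        >>> import torch.distributed.rpc as rpc
        >>> options = rpc.TensorPipeRpcBackendOptions(
        >>>     num_worker_threads=8,
        >>>     rpc_timeout=30,  # seconds
        >>>     device_maps={"worker1": {0: 1}},
        >>>     # maps this worker's cuda:0 to worker1's cuda:1
        >>>     devices=["cuda:0"],
        >>> )
        >>> rpc.init_rpc(
        >>>     "worker0",
        >>>     rank=0,
        >>>     world_size=2,
        >>>     backend=rpc.BackendType.TENSORPIPE,
        >>>     rpc_backend_options=options,
        >>> )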
78 """ 79 80 def __init__( 81 self, 82 *, 83 num_worker_threads: int = rpc_contants.DEFAULT_NUM_WORKER_THREADS, 84 rpc_timeout: float = rpc_contants.DEFAULT_RPC_TIMEOUT_SEC, 85 init_method: str = rpc_contants.DEFAULT_INIT_METHOD, 86 device_maps: Optional[Dict[str, Dict[DeviceType, DeviceType]]] = None, 87 devices: Optional[List[DeviceType]] = None, 88 _transports: Optional[List] = None, 89 _channels: Optional[List] = None, 90 ): 91 full_device_maps = ( 92 {} 93 if device_maps is None 94 else {k: _to_device_map(v) for k, v in device_maps.items()} 95 ) 96 full_device_list = [] if devices is None else _to_device_list(devices) 97 super().__init__( 98 num_worker_threads, 99 _transports, 100 _channels, 101 rpc_timeout, 102 init_method, 103 full_device_maps, 104 full_device_list, 105 ) 106 107 def set_device_map(self, to: str, device_map: Dict[DeviceType, DeviceType]): 108 r""" 109 Set device mapping between each RPC caller and callee pair. This 110 function can be called multiple times to incrementally add 111 device placement configurations. 112 113 Args: 114 to (str): Callee name. 115 device_map (Dict of int, str, or torch.device): Device placement 116 mappings from this worker to the callee. This map must be 117 invertible. 118 119 Example: 120 >>> # xdoctest: +SKIP("distributed") 121 >>> # both workers 122 >>> def add(x, y): 123 >>> print(x) # tensor([1., 1.], device='cuda:1') 124 >>> return x + y, (x + y).to(2) 125 >>> 126 >>> # on worker 0 127 >>> options = TensorPipeRpcBackendOptions( 128 >>> num_worker_threads=8, 129 >>> device_maps={"worker1": {0: 1}} 130 >>> # maps worker0's cuda:0 to worker1's cuda:1 131 >>> ) 132 >>> options.set_device_map("worker1", {1: 2}) 133 >>> # maps worker0's cuda:1 to worker1's cuda:2 134 >>> 135 >>> rpc.init_rpc( 136 >>> "worker0", 137 >>> rank=0, 138 >>> world_size=2, 139 >>> backend=rpc.BackendType.TENSORPIPE, 140 >>> rpc_backend_options=options 141 >>> ) 142 >>> 143 >>> x = torch.ones(2) 144 >>> rets = rpc.rpc_sync("worker1", add, args=(x.to(0), 1)) 145 >>> # The first argument will be moved to cuda:1 on worker1. When 146 >>> # sending the return value back, it will follow the invert of 147 >>> # the device map, and hence will be moved back to cuda:0 and 148 >>> # cuda:1 on worker0 149 >>> print(rets[0]) # tensor([2., 2.], device='cuda:0') 150 >>> print(rets[1]) # tensor([2., 2.], device='cuda:1') 151 """ 152 full_device_map = _to_device_map(device_map) 153 curr_device_maps = super().device_maps 154 155 if to in curr_device_maps: 156 for k, v in full_device_map.items(): 157 if k in curr_device_maps[to] and v != curr_device_maps[to][k]: 158 raise ValueError( 159 "`set_device_map` only supports 1-to-1 mapping, trying" 160 f" to map {k} to {v} and {curr_device_maps[to][k]}" 161 ) 162 163 super()._set_device_map(to, full_device_map) 164 165 def set_devices(self, devices: List[DeviceType]): 166 r""" 167 Set local devices used by the TensorPipe RPC agent. When processing 168 CUDA RPC requests, the TensorPipe RPC agent will properly synchronize 169 CUDA streams for all devices in this ``List``. 170 171 Args: 172 devices (List of int, str, or torch.device): local devices used by 173 the TensorPipe RPC agent. 174 """ 175 self.devices = _to_device_list(devices) 176