# mypy: allow-untyped-defs
from typing import Dict, List, Optional, Union

import torch
from torch._C._distributed_rpc import _TensorPipeRpcBackendOptionsBase

from . import constants as rpc_constants


DeviceType = Union[int, str, torch.device]

__all__ = ["TensorPipeRpcBackendOptions"]


def _to_device(device: DeviceType) -> torch.device:
    device = torch.device(device)
    if device.type != "cuda":
        raise ValueError(
            "`set_devices` expects a list of CUDA devices, but got "
            f"device type {device.type}."
        )
    return device


def _to_device_map(
    device_map: Dict[DeviceType, DeviceType]
) -> Dict[torch.device, torch.device]:
    full_device_map: Dict[torch.device, torch.device] = {}
    reverse_map: Dict[torch.device, torch.device] = {}
    for k, v in device_map.items():
        k, v = torch.device(k), torch.device(v)
        if v in reverse_map:
            raise ValueError(
                "`device_map` only supports 1-to-1 mapping, "
                f"trying to map {k} and {reverse_map[v]} to {v}"
            )
        full_device_map[k] = v
        reverse_map[v] = k
    return full_device_map


def _to_device_list(devices: List[DeviceType]) -> List[torch.device]:
    return list(map(_to_device, devices))


class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
    r"""
    The backend options for
    :class:`~torch.distributed.rpc.TensorPipeAgent`, derived from
    :class:`~torch.distributed.rpc.RpcBackendOptions`.

    Args:
        num_worker_threads (int, optional): The number of threads in the
            thread-pool used by
            :class:`~torch.distributed.rpc.TensorPipeAgent` to execute
            requests (default: 16).
        rpc_timeout (float, optional): The default timeout, in seconds,
            for RPC requests (default: 60 seconds). If the RPC has not
            completed in this timeframe, an exception indicating so will
            be raised. Callers can override this timeout for individual
            RPCs in :meth:`~torch.distributed.rpc.rpc_sync` and
            :meth:`~torch.distributed.rpc.rpc_async` if necessary.
        init_method (str, optional): The URL to initialize the distributed
            store used for rendezvous. It takes any value accepted for the
            same argument of :meth:`~torch.distributed.init_process_group`
            (default: ``env://``).
        device_maps (Dict[str, Dict], optional): Device placement mappings from
            this worker to the callee. The key is the callee worker name and
            the value is a dictionary (``Dict`` of ``int``, ``str``, or
            ``torch.device``) that maps this worker's devices to the callee
            worker's devices. (default: ``None``)
        devices (List[int, str, or ``torch.device``], optional): All local
            CUDA devices used by the RPC agent. By default, it will be
            initialized to all local devices from its own ``device_maps`` and
            the corresponding devices from its peers' ``device_maps``. When
            processing CUDA RPC requests, the agent will properly synchronize
            CUDA streams for all devices in this ``List``.
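
    Example:
        A minimal sketch; the worker names, world size, and device indices
        below are illustrative assumptions, not requirements.

        >>> # xdoctest: +SKIP("distributed")
        >>> import torch.distributed.rpc as rpc
        >>> options = rpc.TensorPipeRpcBackendOptions(
        >>>     num_worker_threads=8,
        >>>     rpc_timeout=20.0,  # timeout in seconds
        >>>     # map this worker's cuda:0 to worker1's cuda:1
        >>>     device_maps={"worker1": {0: 1}},
        >>> )
        >>> rpc.init_rpc(
        >>>     "worker0",
        >>>     rank=0,
        >>>     world_size=2,
        >>>     backend=rpc.BackendType.TENSORPIPE,
        >>>     rpc_backend_options=options,
        >>> )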
    """

    def __init__(
        self,
        *,
        num_worker_threads: int = rpc_constants.DEFAULT_NUM_WORKER_THREADS,
        rpc_timeout: float = rpc_constants.DEFAULT_RPC_TIMEOUT_SEC,
        init_method: str = rpc_constants.DEFAULT_INIT_METHOD,
        device_maps: Optional[Dict[str, Dict[DeviceType, DeviceType]]] = None,
        devices: Optional[List[DeviceType]] = None,
        _transports: Optional[List] = None,
        _channels: Optional[List] = None,
    ):
        full_device_maps = (
            {}
            if device_maps is None
            else {k: _to_device_map(v) for k, v in device_maps.items()}
        )
        full_device_list = [] if devices is None else _to_device_list(devices)
        super().__init__(
            num_worker_threads,
            _transports,
            _channels,
            rpc_timeout,
            init_method,
            full_device_maps,
            full_device_list,
        )

    def set_device_map(self, to: str, device_map: Dict[DeviceType, DeviceType]):
        r"""
        Set device mapping between each RPC caller and callee pair. This
        function can be called multiple times to incrementally add
        device placement configurations.

        Args:
            to (str): Callee name.
            device_map (Dict of int, str, or torch.device): Device placement
                mappings from this worker to the callee. This map must be
                invertible.

        Example:
            >>> # xdoctest: +SKIP("distributed")
            >>> # both workers
            >>> def add(x, y):
            >>>     print(x)  # tensor([1., 1.], device='cuda:1')
            >>>     return x + y, (x + y).to(2)
            >>>
            >>> # on worker 0
            >>> options = TensorPipeRpcBackendOptions(
            >>>     num_worker_threads=8,
            >>>     device_maps={"worker1": {0: 1}},
            >>>     # maps worker0's cuda:0 to worker1's cuda:1
            >>> )
            >>> options.set_device_map("worker1", {1: 2})
            >>> # maps worker0's cuda:1 to worker1's cuda:2
            >>>
            >>> rpc.init_rpc(
            >>>     "worker0",
            >>>     rank=0,
            >>>     world_size=2,
            >>>     backend=rpc.BackendType.TENSORPIPE,
            >>>     rpc_backend_options=options
            >>> )
            >>>
            >>> x = torch.ones(2)
            >>> rets = rpc.rpc_sync("worker1", add, args=(x.to(0), 1))
            >>> # The first argument will be moved to cuda:1 on worker1. When
            >>> # the return values are sent back, they will follow the inverse
            >>> # of the device map, and hence will be placed on cuda:0 and
            >>> # cuda:1 on worker0.
            >>> print(rets[0])  # tensor([2., 2.], device='cuda:0')
            >>> print(rets[1])  # tensor([2., 2.], device='cuda:1')
        """
        full_device_map = _to_device_map(device_map)
        curr_device_maps = super().device_maps

        if to in curr_device_maps:
            for k, v in full_device_map.items():
                if k in curr_device_maps[to] and v != curr_device_maps[to][k]:
                    raise ValueError(
                        "`set_device_map` only supports 1-to-1 mapping, trying"
                        f" to map {k} to {v} and {curr_device_maps[to][k]}"
                    )

        super()._set_device_map(to, full_device_map)

    def set_devices(self, devices: List[DeviceType]):
        r"""
        Set local devices used by the TensorPipe RPC agent. When processing
        CUDA RPC requests, the TensorPipe RPC agent will properly synchronize
        CUDA streams for all devices in this ``List``.

        Args:
            devices (List of int, str, or torch.device): local devices used by
                the TensorPipe RPC agent.
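
        Example:
            A minimal sketch; the device indices are illustrative and any
            local CUDA devices may be given as ``int``, ``str``, or
            ``torch.device``.

            >>> # xdoctest: +SKIP("distributed")
            >>> options = TensorPipeRpcBackendOptions(num_worker_threads=8)
            >>> # restrict the agent to two local CUDA devices
            >>> options.set_devices([0, "cuda:1"])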
        """
        self.devices = _to_device_list(devices)