# mypy: allow-untyped-defs


import collections
import enum
from typing import cast, Dict, List, Set, Tuple

import torch
import torch.distributed as dist

from . import api, constants as rpc_constants
from ._utils import _group_membership_management, _update_group_membership


__all__ = [
    "backend_registered",
    "register_backend",
    "construct_rpc_backend_options",
    "init_backend",
    "BackendValue",
    "BackendType",
]

BackendValue = collections.namedtuple(
    "BackendValue", ["construct_rpc_backend_options_handler", "init_backend_handler"]
)


def _backend_type_repr(self):
    return "BackendType." + self.name


_backend_type_doc = """
    An enum class of available backends.

    PyTorch ships with a builtin ``BackendType.TENSORPIPE`` backend.
    Additional ones can be registered using the
    :func:`~torch.distributed.rpc.backend_registry.register_backend` function.
"""

# Create an enum type, `BackendType`, with empty members.
# Can't handle Function Enum API (mypy bug #9079)
BackendType = enum.Enum(value="BackendType", names={})  # type: ignore[misc]
# Unable to assign a function to a method (mypy bug #2427)
BackendType.__repr__ = _backend_type_repr  # type: ignore[assignment]

if BackendType.__doc__:
    BackendType.__doc__ = _backend_type_doc

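# Rough sketch of how the dynamically (re)built enum behaves once a backend is
# registered. Illustrative only: the TENSORPIPE member shown here exists because
# of the `register_backend("TENSORPIPE", ...)` call at the bottom of this file,
# and each member's value is a `BackendValue` holding the two handlers.
#
#   >>> from torch.distributed.rpc import backend_registry
#   >>> backend_registry.BackendType.TENSORPIPE
#   BackendType.TENSORPIPE
#   >>> backend_registry.BackendType.TENSORPIPE.value._fields
#   ('construct_rpc_backend_options_handler', 'init_backend_handler')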

def backend_registered(backend_name):
    """
    Checks if backend_name is registered as an RPC backend.

    Args:
        backend_name (str): string to identify the RPC backend.
    Returns:
        True if the backend has been registered with ``register_backend``, else
        False.
    """
    return backend_name in BackendType.__members__.keys()


def register_backend(
    backend_name, construct_rpc_backend_options_handler, init_backend_handler
):
    """Registers a new RPC backend.

    Args:
        backend_name (str): backend string to identify the handler.
        construct_rpc_backend_options_handler (function):
            Handler that is invoked when
            ``construct_rpc_backend_options()`` is called for this backend.
        init_backend_handler (function): Handler that is invoked when the
            ``_init_rpc_backend()`` function is called with this backend.
            It constructs and returns the RPC agent.

    Returns:
        The ``BackendType`` member created for the newly registered backend.
    """
    global BackendType
    if backend_registered(backend_name):
        raise RuntimeError(f"RPC backend {backend_name}: already registered")
    # Create a new enum type, `BackendType`, with extended members.
    existing_enum_dict = {member.name: member.value for member in BackendType}
    extended_enum_dict = dict(
        {
            backend_name: BackendValue(
                construct_rpc_backend_options_handler=construct_rpc_backend_options_handler,
                init_backend_handler=init_backend_handler,
            )
        },
        **existing_enum_dict,
    )
    # Can't handle Function Enum API (mypy bug #9079)
    BackendType = enum.Enum(value="BackendType", names=extended_enum_dict)  # type: ignore[misc]
    # Unable to assign a function to a method (mypy bug #2427)
    BackendType.__repr__ = _backend_type_repr  # type: ignore[assignment]
    if BackendType.__doc__:
        BackendType.__doc__ = _backend_type_doc
    return BackendType[backend_name]

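# A minimal sketch of registering a hypothetical custom backend through the
# helpers above. The "FAKE" name and both handler stubs are made up for
# illustration; real handlers are expected to return backend options and an
# RPC agent, respectively.
#
#   >>> def _fake_options_handler(rpc_timeout, init_method, **kwargs):
#   ...     ...  # build and return an options object for this backend
#   >>> def _fake_init_handler(store, name, rank, world_size, rpc_backend_options):
#   ...     ...  # construct and return an RpcAgent for this process
#   >>> if not backend_registered("FAKE"):
#   ...     register_backend("FAKE", _fake_options_handler, _fake_init_handler)
#   >>> backend_registered("FAKE")
#   True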

def construct_rpc_backend_options(
    backend,
    rpc_timeout=rpc_constants.DEFAULT_RPC_TIMEOUT_SEC,
    init_method=rpc_constants.DEFAULT_INIT_METHOD,
    **kwargs,
):
    return backend.value.construct_rpc_backend_options_handler(
        rpc_timeout, init_method, **kwargs
    )


def init_backend(backend, *args, **kwargs):
    return backend.value.init_backend_handler(*args, **kwargs)

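# Both functions above are thin dispatchers: they look up the `BackendValue`
# stored on the enum member and forward the call to its handlers. Rough usage
# sketch (assumes the TENSORPIPE member registered at the bottom of this file):
#
#   >>> opts = construct_rpc_backend_options(BackendType.TENSORPIPE, rpc_timeout=60)
#   >>> # init_backend(BackendType.TENSORPIPE, store, name, rank, world_size, opts)
#   >>> # would then construct and return the TensorPipeAgent for this process.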

def _init_process_group(store, rank, world_size):
    # Initialize ProcessGroup.
    process_group_timeout = rpc_constants.DEFAULT_PROCESS_GROUP_TIMEOUT

    # We're using a bunch of private APIs here since `new_group` requires the
    # default group to be initialized.
    group = dist.ProcessGroupGloo(store, rank, world_size, process_group_timeout)

    assert group is not None, "Failed to initialize default ProcessGroup."

    if (rank != -1) and (rank != group.rank()):
        raise RuntimeError(f"rank argument {rank} doesn't match pg rank {group.rank()}")
    if (world_size != -1) and (world_size != group.size()):
        raise RuntimeError(
            f"world_size argument {world_size} doesn't match pg size {group.size()}"
        )
    return group


def _tensorpipe_construct_rpc_backend_options_handler(
    rpc_timeout,
    init_method,
    num_worker_threads=rpc_constants.DEFAULT_NUM_WORKER_THREADS,
    _transports=None,
    _channels=None,
    **kwargs,
):
    from . import TensorPipeRpcBackendOptions

    return TensorPipeRpcBackendOptions(
        rpc_timeout=rpc_timeout,
        init_method=init_method,
        num_worker_threads=num_worker_threads,
        _transports=_transports,
        _channels=_channels,
    )


def _tensorpipe_validate_devices(devices, device_count):
    return all(
        d.type == "cpu" or (d.type == "cuda" and 0 <= d.index < device_count)
        for d in devices
    )

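# Illustrative behaviour of the check above: CPU devices always pass, while a
# CUDA device must have an index within [0, device_count).
#
#   >>> _tensorpipe_validate_devices([torch.device("cpu"), torch.device("cuda:1")], 2)
#   True
#   >>> _tensorpipe_validate_devices([torch.device("cuda:2")], 2)
#   False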

# Detect whether any worker has an invalid device_map configuration, and return
# the reverse device maps.
def _tensorpipe_exchange_and_check_all_device_maps(
    my_name, my_device_count, my_device_maps, my_devices, group
):
    gathered: List[
        Tuple[str, int, Dict[str, Dict[torch.device, torch.device]], List[torch.device]]
    ] = [("", 0, {}, []) for _ in range(group.size())]
    dist.all_gather_object(
        gathered, (my_name, my_device_count, my_device_maps, my_devices), group
    )
    all_names = [name for name, _, _, _ in gathered]
    all_device_counts = {name: count for name, count, _, _ in gathered}
    all_device_maps = {name: map_ for name, _, map_, _ in gathered}
    all_devices = {name: devices for name, _, _, devices in gathered}

    _validate_device_maps(all_names, all_device_counts, all_device_maps, all_devices)

    # Passed all checks; construct the reverse mapping and the list of devices handled by this agent.
    reverse_device_maps = _create_reverse_mapping(my_name, all_names, all_device_maps)
    my_devices = _create_device_list(my_devices, my_device_maps, reverse_device_maps)
    return reverse_device_maps, my_devices


def _validate_device_maps(
    all_names, all_device_counts, all_device_maps, all_devices, is_static_group=True
):
    for node in all_names:
        devices = all_devices[node]
        if len(set(devices)) != len(devices):
            raise ValueError(
                f"Node {node} has duplicated devices\n" f"devices = {devices}"
            )
        if not _tensorpipe_validate_devices(devices, all_device_counts[node]):
            raise ValueError(
                f"Node {node} has devices with invalid indices\n"
                f"devices = {devices}\n"
                f"device count = {all_device_counts[node]}"
            )

    for source_node in all_names:
        # For a dynamic (non-static) group, do not check the target node name since it may not have joined yet.
        if is_static_group and not set(all_device_maps[source_node].keys()).issubset(
            all_names
        ):
            raise ValueError(
                f"Node {source_node} has invalid target node names in its device maps\n"
                f"device maps = {all_device_maps[source_node].keys()}\n"
                f"node names = {all_names}"
            )
        for target_node, map_ in all_device_maps[source_node].items():
            if len(set(map_.values())) != len(map_):
                raise ValueError(
                    f"Node {source_node} has duplicated target devices "
                    f"in its device map for {target_node}\n"
                    f"device map = {map_}"
                )
            if all_devices[source_node]:
                if not set(map_.keys()).issubset(all_devices[source_node]):
                    raise ValueError(
                        f"Node {source_node} has unexpected source devices "
                        f"in its device map for {target_node}\n"
                        f"device map = {map_}\n"
                        f"devices = {all_devices[source_node]}"
                    )
            elif not _tensorpipe_validate_devices(
                map_.keys(), all_device_counts[source_node]
            ):
                raise ValueError(
                    f"Node {source_node} has source devices with invalid indices "
                    f"in its device map for {target_node}\n"
                    f"device map = {map_}\n"
                    f"device count = {all_device_counts[source_node]}"
                )
            if all_devices.get(target_node, []):
                if not set(map_.values()).issubset(all_devices[target_node]):
                    raise ValueError(
                        f"Node {source_node} has unexpected target devices "
                        f"in its device map for {target_node}\n"
                        f"device map = {map_}\n"
                        f"devices = {all_devices[target_node]}"
                    )
            elif target_node in all_device_counts and not _tensorpipe_validate_devices(
                map_.values(), all_device_counts[target_node]
            ):
                raise ValueError(
                    f"Node {source_node} has target devices with invalid indices "
                    f"in its device map for {target_node}\n"
                    f"device map = {map_}\n"
                    f"device count = {all_device_counts[target_node]}"
                )

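# As an example, a device map entry that sends two source devices to the same
# target device trips the duplicated-target check above (hypothetical values):
#
#   {torch.device("cuda:0"): torch.device("cuda:0"),
#    torch.device("cuda:1"): torch.device("cuda:0")}  # duplicated target -> ValueError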

def _create_device_list(my_devices, my_device_maps, reverse_device_maps):
    if not my_devices:
        devices_set: Set[torch.device] = set()
        for map_ in my_device_maps.values():
            devices_set.update(map_.keys())
        for map_ in reverse_device_maps.values():
            devices_set.update(map_.keys())
        devices_set.discard(torch.device("cpu"))
        my_devices = list(devices_set)
    my_devices = sorted(my_devices, key=lambda d: d.index)
    return my_devices

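# Illustrative derivation when the caller did not set `my_devices` explicitly
# (hypothetical maps): the local devices are the keys of both map directions,
# CPU is dropped, and the result is sorted by device index.
#
#   >>> my_maps = {"worker1": {torch.device("cuda:1"): torch.device("cuda:0")}}
#   >>> rev_maps = {"worker1": {torch.device("cuda:0"): torch.device("cuda:0")}}
#   >>> _create_device_list([], my_maps, rev_maps)
#   [device(type='cuda', index=0), device(type='cuda', index=1)]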

def _create_reverse_mapping(my_name, all_names, all_device_maps):
    reverse_device_maps: Dict[str, Dict[torch.device, torch.device]] = {}
    for node in all_names:
        if my_name in all_device_maps[node]:
            reverse_device_maps[node] = {
                v: k for k, v in all_device_maps[node][my_name].items()
            }
    return reverse_device_maps

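# Illustrative reversal (hypothetical values): if "worker1" maps its cuda:0 to
# this worker's cuda:1, then on this worker ("worker0") the reverse map for
# "worker1" maps local cuda:1 back to the peer's cuda:0.
#
#   >>> maps = {"worker1": {"worker0": {torch.device("cuda:0"): torch.device("cuda:1")}}}
#   >>> _create_reverse_mapping("worker0", ["worker1"], maps)
#   {'worker1': {device(type='cuda', index=1): device(type='cuda', index=0)}}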

def _get_device_infos():
    from . import TensorPipeAgent

    agent = cast(TensorPipeAgent, api._get_current_rpc_agent())
    opts = agent._get_backend_options()
    device_count = torch.cuda.device_count()
    if torch.cuda.is_available() and opts.devices:
        torch.cuda.init()
    return device_count, opts.device_maps, opts.devices


def _set_devices_and_reverse_device_map(agent):
    from . import TensorPipeAgent

    agent = cast(TensorPipeAgent, agent)
    # Group state is retrieved from the local agent.
    # On initialization, the TensorPipe agent retrieves information from all existing workers, so the group state is valid.
    my_worker_info = agent.get_worker_info()
    my_name = my_worker_info.name
    all_worker_infos = agent.get_worker_infos()
    # One round to get the device_maps of all workers and construct the reverse device maps.
    all_device_counts, all_device_maps, all_devices, all_names = {}, {}, {}, []
    for worker_info in all_worker_infos:
        worker_name = worker_info.name
        if worker_name != my_name:
            # TODO: make async?
            device_count, device_map, devices = api.rpc_sync(
                worker_name, _get_device_infos
            )
        else:
            opts = agent._get_backend_options()
            device_count, device_map, devices = (
                torch.cuda.device_count(),
                opts.device_maps,
                opts.devices,
            )
        all_device_counts[worker_name] = device_count
        all_device_maps[worker_name] = device_map
        all_devices[worker_name] = devices
        all_names.append(worker_name)

    _validate_device_maps(
        all_names,
        all_device_counts,
        all_device_maps,
        all_devices,
        is_static_group=False,
    )
    reverse_device_maps = _create_reverse_mapping(my_name, all_names, all_device_maps)

    # Perform an RPC call to all workers, including this one, to propagate the newly
    # joined worker's information and device maps.
    for worker_name in all_names:
        # Set the device list for each worker.
        all_devices[worker_name] = _create_device_list(
            all_devices[worker_name], all_device_maps[worker_name], reverse_device_maps
        )
        api.rpc_sync(
            worker_name,
            _update_group_membership,
            args=(my_worker_info, all_devices[worker_name], reverse_device_maps, True),
        )


def _tensorpipe_init_backend_handler(
    store, name, rank, world_size, rpc_backend_options
):
    from . import TensorPipeAgent, TensorPipeRpcBackendOptions

    if not isinstance(store, dist.Store):
        raise TypeError(f"`store` must be a c10d::Store. {store}")

    if not isinstance(rpc_backend_options, TensorPipeRpcBackendOptions):
        raise TypeError(
            f"`rpc_backend_options` must be a `TensorPipeRpcBackendOptions`. {rpc_backend_options}"
        )

    device_count = torch.cuda.device_count()

    is_static_group = True if world_size else False
    # world_size is specified, so this is a static group (ranks cannot join and leave).
    if is_static_group:
        # The agent's join method is required to behave like a barrier and perform
        # collective operations, for which it relies on a process group, instead of
        # re-implementing this on top of RPCs.
        group = _init_process_group(store, rank, world_size)

        reverse_device_maps, devices = _tensorpipe_exchange_and_check_all_device_maps(
            name,
            device_count,
            rpc_backend_options.device_maps,
            rpc_backend_options.devices,
            group,
        )

        if torch.cuda.is_available() and devices:
            # It's necessary to initialize PyTorch CUDA states here (e.g.,
            # CUDACachingAllocator). If this is missing, we could hit errors like
            # "allocator not initialized", because other processes might send
            # CUDA-related RPC requests to this process before user code in this
            # process initializes its PyTorch CUDA states.
            torch.cuda.init()

        # TODO: add try-except and destroy _agent in all processes if any fails.
        agent = TensorPipeAgent(
            store,
            name,
            rank,
            world_size,
            rpc_backend_options,
            reverse_device_maps,
            devices,
        )

        api._init_rpc_states(agent)

        # Run one dummy round of RPC to initialize channels/transports. Without
        # this, it's easy to hit a timeout in rpc.shutdown() if there is no other RPC
        # on that process before rpc.shutdown(), as the agent initialization can
        # take longer than 5s.
        api._all_gather(None, timeout=rpc_backend_options.rpc_timeout)
        # Need a barrier here to make sure no peers leave before rank 0 finishes
        # _all_gather.
        group.barrier().wait()

        return agent
    # Initialization for dynamic RPC (ranks can join and leave).
    else:
        with _group_membership_management(store, name, True):
            # Construct the TensorPipeAgent with an empty reverse_device_map and devices;
            # these properties will be updated after initialization.
            agent = TensorPipeAgent(
                store,
                name,
                rank,
                world_size,
                rpc_backend_options,
                {},
                [],
            )
            api._init_rpc_states(agent)

            try:
                # Notify all workers in the group that this rank has joined, and set
                # devices and reverse_device_map. This is a synchronous operation that
                # completes once all existing ranks are updated.
                _set_devices_and_reverse_device_map(agent)
            except Exception:
                api.shutdown()
                raise
            return agent


register_backend(
    "TENSORPIPE",
    _tensorpipe_construct_rpc_backend_options_handler,
    _tensorpipe_init_backend_handler,
)

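# A rough end-to-end sketch of how the TensorPipe backend registered above is
# typically selected from user code (worker names, the endpoint, and the device
# map below are hypothetical; see torch.distributed.rpc.init_rpc for details):
#
#   >>> import torch.distributed.rpc as rpc
#   >>> options = rpc.TensorPipeRpcBackendOptions(
#   ...     init_method="tcp://localhost:29500",
#   ...     num_worker_threads=8,
#   ... )
#   >>> options.set_device_map("worker1", {0: 1})  # local cuda:0 -> remote cuda:1
#   >>> rpc.init_rpc(
#   ...     "worker0",
#   ...     rank=0,
#   ...     world_size=2,
#   ...     backend=rpc.BackendType.TENSORPIPE,
#   ...     rpc_backend_options=options,
#   ... )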