1# mypy: allow-untyped-defs 2"""Distributed Collective Communication (c10d).""" 3 4import collections.abc 5import contextlib 6import hashlib 7import io 8import itertools 9import logging 10import os 11import pickle 12import sys 13import time 14import warnings 15from collections import namedtuple 16from datetime import timedelta 17from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union 18from typing_extensions import deprecated 19 20import torch 21from torch._C import _DistStoreError as DistStoreError 22from torch._C._distributed_c10d import ( 23 _DistributedBackendOptions, 24 _register_process_group, 25 _resolve_process_group, 26 _unregister_all_process_groups, 27 _unregister_process_group, 28 AllgatherOptions, 29 AllreduceCoalescedOptions, 30 AllreduceOptions, 31 AllToAllOptions, 32 BarrierOptions, 33 BroadcastOptions, 34 DebugLevel, 35 GatherOptions, 36 get_debug_level, 37 PrefixStore, 38 ProcessGroup, 39 ReduceOp, 40 ReduceOptions, 41 ReduceScatterOptions, 42 ScatterOptions, 43 Store, 44 Work, 45) 46from torch._utils_internal import set_pytorch_distributed_envs_from_justknobs 47from torch.utils._typing_utils import not_none 48 49from .c10d_logger import _exception_logger, _time_logger 50from .constants import default_pg_nccl_timeout, default_pg_timeout 51from .rendezvous import register_rendezvous_handler, rendezvous # noqa: F401 52 53 54__all__ = [ 55 "Backend", 56 "BackendConfig", 57 "GroupMember", 58 "P2POp", 59 "all_gather", 60 "all_gather_coalesced", 61 "all_gather_object", 62 "all_reduce", 63 "all_reduce_coalesced", 64 "all_to_all", 65 "all_to_all_single", 66 "barrier", 67 "batch_isend_irecv", 68 "broadcast", 69 "send_object_list", 70 "recv_object_list", 71 "broadcast_object_list", 72 "destroy_process_group", 73 "gather", 74 "gather_object", 75 "get_backend_config", 76 "get_backend", 77 "get_rank", 78 "get_world_size", 79 "get_pg_count", 80 "group", 81 "init_process_group", 82 "irecv", 83 "is_gloo_available", 84 "is_initialized", 85 "is_mpi_available", 86 "is_backend_available", 87 "is_nccl_available", 88 "is_torchelastic_launched", 89 "is_ucc_available", 90 "isend", 91 "monitored_barrier", 92 "new_group", 93 "new_subgroups", 94 "new_subgroups_by_enumeration", 95 "recv", 96 "reduce", 97 "reduce_scatter", 98 "scatter", 99 "scatter_object_list", 100 "send", 101 "supports_complex", 102 "AllreduceCoalescedOptions", 103 "AllreduceOptions", 104 "AllToAllOptions", 105 "BarrierOptions", 106 "BroadcastOptions", 107 "GatherOptions", 108 "PrefixStore", 109 "ProcessGroup", 110 "ReduceOp", 111 "ReduceOptions", 112 "ReduceScatterOptions", 113 "ScatterOptions", 114 "Store", 115 "DebugLevel", 116 "get_debug_level", 117 "Work", 118 "default_pg_timeout", 119 "get_group_rank", 120 "get_global_rank", 121 "get_process_group_ranks", 122 "reduce_op", 123 "all_gather_into_tensor", 124 "reduce_scatter_tensor", 125 "get_node_local_rank", 126 "split_group", 127] 128 129_MPI_AVAILABLE = True 130_NCCL_AVAILABLE = True 131_GLOO_AVAILABLE = True 132_UCC_AVAILABLE = True 133 134_pickler = pickle.Pickler 135_unpickler = pickle.Unpickler 136 137 138# Change __module__ of all imported types from torch._C._distributed_c10d that are public 139def _export_c_types() -> None: 140 _public_types_to_change_module = [ 141 AllreduceCoalescedOptions, 142 AllreduceOptions, 143 AllToAllOptions, 144 BarrierOptions, 145 BroadcastOptions, 146 GatherOptions, 147 PrefixStore, 148 ProcessGroup, 149 ReduceOp, 150 ReduceOptions, 151 ReduceScatterOptions, 152 ScatterOptions, 153 Store, 154 DebugLevel, 155 
get_debug_level, 156 Work, 157 ] 158 for type in _public_types_to_change_module: 159 type.__module__ = "torch.distributed.distributed_c10d" 160 161 162_export_c_types() 163 164try: 165 from torch._C._distributed_c10d import ProcessGroupMPI 166 167 ProcessGroupMPI.__module__ = "torch.distributed.distributed_c10d" 168 __all__ += ["ProcessGroupMPI"] 169except ImportError: 170 _MPI_AVAILABLE = False 171 172try: 173 from torch._C._distributed_c10d import ProcessGroupNCCL 174 175 ProcessGroupNCCL.__module__ = "torch.distributed.distributed_c10d" 176 __all__ += ["ProcessGroupNCCL"] 177except ImportError: 178 _NCCL_AVAILABLE = False 179 180try: 181 from torch._C._distributed_c10d import _ProcessGroupWrapper, ProcessGroupGloo 182 183 ProcessGroupGloo.__module__ = "torch.distributed.distributed_c10d" 184 __all__ += ["ProcessGroupGloo"] 185except ImportError: 186 _GLOO_AVAILABLE = False 187 188try: 189 from torch._C._distributed_c10d import ProcessGroupUCC 190 191 ProcessGroupUCC.__module__ = "torch.distributed.distributed_c10d" 192 __all__ += ["ProcessGroupUCC"] 193except ImportError: 194 _UCC_AVAILABLE = False 195 196logger = logging.getLogger(__name__) 197 198PG_WRAPPER_STORE_PREFIX = "pg_wrapper" 199 200 201# Some reduce ops are not supported by complex numbers and will result in an error. 202# We currently provide complex support to the distributed API by viewing 203# complex tensors as real (torch.view_as_real), meaning that calling 204# these unsupported ops will return garbage values rather than error out. 205# (e.g. max(2+3i, 3+2i) = 3+3i) 206# We'd like calls to unsupported ops to error out accordingly, 207# rather than returning garbage values. 208def supports_complex(reduceOp: ReduceOp) -> bool: 209 """Return true if reduce ops is supported. False otherwise.""" 210 denyList = [ 211 ReduceOp.MAX, 212 ReduceOp.MIN, 213 ReduceOp.PRODUCT, 214 ReduceOp.BAND, 215 ReduceOp.BOR, 216 ReduceOp.BXOR, 217 ] 218 return reduceOp not in denyList 219 220 221class Backend(str): 222 """ 223 An enum-like class for backends. 224 225 Available backends: GLOO, NCCL, UCC, MPI, and other registered backends. 226 227 The values of this class are lowercase strings, e.g., ``"gloo"``. They can 228 be accessed as attributes, e.g., ``Backend.NCCL``. 229 230 This class can be directly called to parse the string, e.g., 231 ``Backend(backend_str)`` will check if ``backend_str`` is valid, and 232 return the parsed lowercase string if so. It also accepts uppercase strings, 233 e.g., ``Backend("GLOO")`` returns ``"gloo"``. 234 235 .. note:: The entry ``Backend.UNDEFINED`` is present but only used as 236 initial value of some fields. Users should neither use it directly 237 nor assume its existence. 
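
    A short illustrative example of the parsing behavior described above::

        >>> Backend("GLOO")
        'gloo'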
238 """ 239 240 UNDEFINED = "undefined" 241 GLOO = "gloo" 242 NCCL = "nccl" 243 UCC = "ucc" 244 MPI = "mpi" 245 246 _BackendPlugin = namedtuple("_BackendPlugin", ["creator_fn", "extended_api"]) 247 248 _plugins: Dict[str, _BackendPlugin] = {} 249 250 backend_list = [UNDEFINED, GLOO, NCCL, UCC, MPI] 251 252 default_device_backend_map: Dict[str, str] = { 253 "cpu": GLOO, 254 "cuda": NCCL, 255 } 256 257 backend_capability: Dict[str, List[str]] = { 258 GLOO: ["cpu", "cuda"], 259 NCCL: ["cuda"], 260 UCC: ["cpu", "cuda"], 261 MPI: ["cpu", "cuda"], 262 } 263 264 backend_type_map: Dict[str, ProcessGroup.BackendType] = { 265 UNDEFINED: ProcessGroup.BackendType.UNDEFINED, 266 GLOO: ProcessGroup.BackendType.GLOO, 267 NCCL: ProcessGroup.BackendType.NCCL, 268 UCC: ProcessGroup.BackendType.UCC, 269 } 270 271 def __new__(cls, name: str): 272 """Create and return a new instance of the class.""" 273 if not isinstance(name, str): 274 raise ValueError("Backend constructor parameter must be string-ish") 275 value = getattr(Backend, name.upper(), Backend.UNDEFINED) 276 277 if value == Backend.UNDEFINED: 278 value = name.lower() 279 return value 280 281 @classmethod 282 def register_backend( 283 cls, 284 name, 285 func, 286 extended_api=False, 287 devices: Optional[Union[str, List[str]]] = None, 288 ) -> None: 289 """ 290 Register a new backend with the given name and instantiating function. 291 292 This class method is used by 3rd party ``ProcessGroup`` extension to 293 register new backends. 294 295 Args: 296 name (str): Backend name of the ``ProcessGroup`` extension. It 297 should match the one in ``init_process_group()``. 298 func (function): Function handler that instantiates the backend. 299 The function should be implemented in the backend 300 extension and takes four arguments, including 301 ``store``, ``rank``, ``world_size``, and ``timeout``. 302 extended_api (bool, optional): Whether the backend supports extended argument structure. 303 Default: ``False``. If set to ``True``, the backend 304 will get an instance of ``c10d::DistributedBackendOptions``, and 305 a process group options object as defined by the backend implementation. 306 device (str or list of str, optional): device type this backend 307 supports, e.g. "cpu", "cuda", etc. If `None`, 308 assuming both "cpu" and "cuda" 309 310 .. note:: This support of 3rd party backend is experimental and subject to change. 311 312 """ 313 # Allow UCC plugin if Pytorch is not built with native support. 314 # TODO: remove this exception once UCC plugin is fully deprecated. 315 if name != Backend.UCC or (name == Backend.UCC and is_ucc_available()): 316 assert not hasattr( 317 Backend, name.upper() 318 ), f"{name.upper()} c10d backend already exist" 319 assert ( 320 name.upper() not in Backend._plugins 321 ), f"{name.upper()} c10d backend creator function already exist" 322 323 setattr(Backend, name.upper(), name.lower()) 324 Backend.backend_list.append(name.lower()) 325 if devices is not None: 326 for device in devices: 327 if device != "cpu" and device != "cuda": 328 Backend.default_device_backend_map[device] = name.lower() 329 Backend.backend_type_map[name.lower()] = ProcessGroup.BackendType.CUSTOM 330 331 # Update device capability matrix in Backend class 332 if devices is None: 333 # This is more of a backward support for groups like `threaded`: 334 # assume default devices "cpu" and "cuda", but warn 335 warnings.warn( 336 f"Device capability of {name} unspecified, assuming `cpu` and " 337 "`cuda`. 
Please specify it via the `devices` argument of " 338 "`register_backend`." 339 ) 340 Backend.backend_capability[name.lower()] = ["cpu", "cuda"] 341 elif isinstance(devices, str): 342 # Single device string specified. Simply convert to list. 343 Backend.backend_capability[name.lower()] = [devices] 344 else: 345 Backend.backend_capability[name.lower()] = devices 346 347 Backend._plugins[name.upper()] = Backend._BackendPlugin(func, extended_api) 348 349 350class BackendConfig: 351 """Backend configuration class.""" 352 353 def __init__(self, backend: Backend): 354 """Init.""" 355 self.device_backend_map: Dict[str, Backend] = {} 356 backend = str(backend) 357 358 if backend == Backend.UNDEFINED: 359 # default config when backend is not specified 360 # supported since PyTorch 2.0 361 for device, default_backend in Backend.default_device_backend_map.items(): 362 if is_backend_available(default_backend): 363 if ( 364 default_backend == Backend.NCCL 365 and not torch.cuda.is_available() 366 ): 367 continue 368 self.device_backend_map[device] = Backend(default_backend) 369 elif backend.lower() in Backend.backend_list: 370 # Cases for when backend is a single string (without device types) 371 # e.g. "nccl", "gloo", "ucc", "mpi" 372 supported_devices = Backend.backend_capability[backend.lower()] 373 backend_val = Backend(backend) 374 self.device_backend_map = dict.fromkeys(supported_devices, backend_val) 375 elif ":" in backend.lower(): 376 # Backend specified in "device:backend" format 377 # make sure the backend string is in the correct format 378 # "{device_type1}:{backend1},{device_type2}:{backend2}" 379 # e.g. "cpu:gloo,cuda:nccl" 380 backend_str_error_message = f"""The custom backend string argument is invalid: {backend}. 381 Custom backend string is an experimental feature where the backend string must be in the format: 382 "<device_type1>:<backend1>,<device_type2>:<backend2>...". e.g. 'cpu:gloo,cuda:nccl'""" 383 384 # parse the backend string and populate the device_backend_map 385 for device_backend_pair_str in backend.lower().split(","): 386 device_backend_pair = device_backend_pair_str.split(":") 387 if len(device_backend_pair) != 2: 388 raise ValueError( 389 f"Invalid device:backend pairing: \ 390 {device_backend_pair_str}. {backend_str_error_message}" 391 ) 392 device, backend = device_backend_pair 393 if device in self.device_backend_map: 394 raise ValueError( 395 f"Duplicate device type {device} \ 396 in backend string: {backend}. {backend_str_error_message}" 397 ) 398 self.device_backend_map[device] = Backend(backend) 399 else: 400 # User specified a single backend name whose device capability is 401 # unknown, assuming it can support the default devices of PyTorch 402 # (cpu and cuda) 403 warnings.warn( 404 f"Device capability of {backend} unknown, assuming `cpu` and " 405 "`cuda`. You can specify it in `device:backend` format in " 406 "`init_process_group` call." 
407 ) 408 backend_val = Backend(backend) 409 self.device_backend_map = { 410 "cpu": backend_val, 411 "cuda": backend_val, 412 "xpu": backend_val, 413 } 414 415 logger.info("Using backend config: %s", self.device_backend_map) 416 417 def __repr__(self): 418 """Return all the device:backend pairs separated by commas.""" 419 return ",".join( 420 f"{device}:{backend}" for device, backend in self.device_backend_map.items() 421 ) 422 423 def get_device_backend_map(self) -> Dict[str, Backend]: 424 """Return backend map of the device.""" 425 return self.device_backend_map 426 427 428class _reduce_op: 429 r""" 430 Deprecated enum-like class. 431 432 For reduction operations: ``SUM``, ``PRODUCT``, ``MIN``, and ``MAX``. 433 434 :class:`~torch.distributed.ReduceOp` is recommended to use instead. 435 """ 436 437 def __init__(self) -> None: 438 # __members__ is a dict storing key-value pairs for enum classes 439 for k, v in ReduceOp.RedOpType.__members__.items(): 440 setattr(self, k, v) 441 self.__members__ = ReduceOp.RedOpType.__members__ 442 443 @deprecated( 444 "`torch.distributed.reduce_op` is deprecated, " 445 "please use `torch.distributed.ReduceOp` instead", 446 category=FutureWarning, 447 ) 448 def __getattribute__(self, key): 449 return object.__getattribute__(self, key) 450 451 452reduce_op = _reduce_op() 453 454 455class P2POp: 456 """ 457 A class to build point-to-point operations for ``batch_isend_irecv``. 458 459 This class builds the type of P2P operation, communication buffer, peer rank, 460 Process Group, and tag. Instances of this class will be passed to 461 ``batch_isend_irecv`` for point-to-point communications. 462 463 Args: 464 op (Callable): A function to send data to or receive data from a peer process. 465 The type of ``op`` is either ``torch.distributed.isend`` or 466 ``torch.distributed.irecv``. 467 tensor (Tensor): Tensor to send or receive. 468 peer (int): Destination or source rank. 469 group (ProcessGroup, optional): The process group to work on. If None, 470 the default process group will be used. 471 tag (int, optional): Tag to match send with recv. 472 """ 473 474 def __init__( 475 self, 476 op: Callable, 477 tensor: torch.Tensor, 478 peer: int, 479 group: Optional[ProcessGroup] = None, 480 tag: int = 0, 481 ): 482 """Init.""" 483 self.op = op 484 self.tensor = tensor 485 self.peer = peer 486 self.group = group 487 self.tag = tag 488 489 def __new__( 490 cls, 491 op: Callable, 492 tensor: torch.Tensor, 493 peer: int, 494 group: Optional[ProcessGroup] = None, 495 tag: int = 0, 496 ): 497 """Create and return a new instance of the class.""" 498 _check_op(op) 499 _check_single_tensor(tensor, "tensor") 500 return object.__new__(cls) 501 502 def __repr__(self): 503 my_group_rank = get_rank(self.group) 504 peer_group_rank = ( 505 get_group_rank(self.group, self.peer) if self.group else self.peer 506 ) 507 op_name = self.op.__name__ 508 group_name = self.group.group_name if self.group else "default_pg" 509 if "send" in op_name: 510 s = my_group_rank 511 d = peer_group_rank 512 elif "recv" in op_name: 513 s = peer_group_rank 514 d = my_group_rank 515 else: 516 return super().__repr__() 517 518 return f"P2POp({op_name} pg={group_name}, s={s}, d={d}, {self.tensor.shape}, {self.tensor.dtype})" 519 520 521class _CollOp: 522 """ 523 A class to capture collective operations. 524 525 Args: 526 op (Callable): A collective function, e.g. ``torch.distributed.all_reduce``. 527 tensor (Tensor): Tensor to operate on. 
        dst_tensor (Tensor, optional): Provided when source and destination tensors are not the same.
        redop (ReduceOp, optional): reduce operation.
        root (int, optional): root of broadcast or reduce.
    """

    def __init__(
        self,
        op: Callable,
        tensor: torch.Tensor,
        dst_tensor: Optional[torch.Tensor] = None,
        redop: Optional[ReduceOp] = None,
        root: Optional[int] = None,
    ):
        self.op = op
        self.tensor = tensor
        self.dst_tensor = dst_tensor
        self.redop = redop
        self.root = root


# DO NOT USE THESE FIELDS DIRECTLY.
# Use them through the _world object to make sure the _world override mechanism is respected.
_pg_map: Dict[ProcessGroup, Tuple[str, Store]] = {}
_pg_names: Dict[ProcessGroup, str] = {}
_pg_group_ranks: Dict[ProcessGroup, Dict[int, int]] = {}
# For a pg, it is a map from ProcessGroup to its backend config (stored as a string)
_pg_backend_config: Dict[ProcessGroup, str] = {}
_group_count = 0
_tags_to_pg: Dict[str, List[ProcessGroup]] = {}
_pg_to_tag: Dict[ProcessGroup, str] = {}
_backend: Optional[str] = None


class _World:
    """
    Container class for c10d process group state.

    This is used during registration and lookup of PG state.

    .. warning:: This is an experimental API intended to expose the inner workings
       of c10d and is subject to change.
    """

    def __init__(self) -> None:
        self._default_pg = None
        self._pg_coalesce_state: Dict[ProcessGroup, List[_CollOp]] = {}
        self._pg_default_device: Dict[ProcessGroup, torch.device] = {}

    @property
    def default_pg(self) -> Optional[ProcessGroup]:
        """
        Process group that includes all ranks of the cluster.

        This default ProcessGroup is used by c10d APIs when a ProcessGroup is needed
        but None is provided.
        """
        return self._default_pg

    @default_pg.setter
    def default_pg(self, value) -> None:
        self._default_pg = value

    @property
    def pg_map(self) -> Dict[ProcessGroup, Tuple[str, Store]]:
        """
        Provide the mapping from ProcessGroup to backend name and store.

        For NCCL and GLOO pg, it is a map from ProcessGroup to (Backend, Store).
        For MPI pg, it is a map from ProcessGroup to (Backend, None).

        TODO don't expose the map, expose fine grained ops
        """
        global _pg_map
        return _pg_map

    @property
    def pg_names(self) -> Dict[ProcessGroup, str]:
        """
        Process group's names, map from ProcessGroup to str.

        TODO don't expose the map, expose fine grained ops
        """
        global _pg_names
        return _pg_names

    @property
    def pg_group_ranks(self) -> Dict[ProcessGroup, Dict[int, int]]:
        """
        Process group's global rank to local rank mapping.

        TODO don't expose the map, expose fine grained ops
        """
        global _pg_group_ranks
        return _pg_group_ranks

    @property
    def pg_backend_config(self) -> Dict[ProcessGroup, str]:
        """
        Process group's backend config.

        TODO don't expose the map, expose fine grained ops
        """
        global _pg_backend_config
        return _pg_backend_config

    @property
    def group_count(self) -> int:
        """
        Process group count for default naming.
637 638 TODO don't expose group_count, use something else instead 639 """ 640 global _group_count 641 return _group_count 642 643 @group_count.setter 644 def group_count(self, value: int) -> None: 645 """Use to compute the name of ProcessGroups when using global synchronization.""" 646 global _group_count 647 _group_count = value 648 649 @property 650 def tags_to_pg(self) -> Dict[str, List[ProcessGroup]]: 651 global _tags_to_pg 652 return _tags_to_pg 653 654 @property 655 def pg_to_tag(self) -> Dict[ProcessGroup, str]: 656 global _pg_to_tag 657 return _pg_to_tag 658 659 @property 660 def pg_coalesce_state(self) -> Dict[ProcessGroup, List[_CollOp]]: 661 return self._pg_coalesce_state 662 663 @property 664 def pg_default_device(self) -> Dict[ProcessGroup, torch.device]: 665 return self._pg_default_device 666 667 @property 668 def pg_config_info(self) -> List[Dict[str, Any]]: 669 """ 670 Return a list of dict with process groups and backends. 671 672 Along with their unique IDs and configurations (types and ranks). 673 """ 674 config_info: List[Dict[str, Any]] = [] 675 default_pg_size = _get_group_size(None) 676 for pg in self.pg_map.keys(): 677 ranks = self.pg_group_ranks[pg] 678 config_info.append( 679 { 680 "pg_name": self.pg_names[pg], 681 "pg_desc": pg.group_desc, 682 "backend_config": self.pg_backend_config[pg], 683 "ranks": list(ranks.keys()) 684 if len(ranks) != default_pg_size 685 else [], # 'ranks' is an empty list when all ranks are involved in a pg 686 "group_size": len(ranks), 687 "group_count": self.group_count, 688 } 689 ) 690 return config_info 691 692 693_world = _World() 694"""Holds the singleton instance of ``_World`` used by c10. Experimental extension point to override it""" 695 696 697class _WorldMeta(type): 698 """ 699 Meta class of ``group`` and ``GroupMember``. 700 701 Allows them to have the class property ``WORLD``. 702 """ 703 704 # Points to the default PG once initialized. 705 @property 706 def WORLD(cls) -> Optional[ProcessGroup]: 707 return _world.default_pg 708 709 @WORLD.setter 710 def WORLD(cls, pg: Optional[ProcessGroup]): 711 _world.default_pg = pg 712 713 714class group(metaclass=_WorldMeta): 715 """Group class. Placeholder.""" 716 717 718class GroupMember(metaclass=_WorldMeta): 719 """Group member class.""" 720 721 NON_GROUP_MEMBER = -100 722 723 724def _get_default_timeout(backend: Backend) -> timedelta: 725 # see note on nccl vs other backend timeout (constants.py) 726 if backend == Backend.NCCL: 727 if not isinstance(default_pg_nccl_timeout, timedelta): 728 # TODO moco benchmark on CPU initializes pgnccl backend today, triggered this assert in CI before it was 729 # changed to be a warning. We should fix the moco model. 730 warnings.warn( 731 "Attempted to get default timeout for nccl backend, but NCCL support is not compiled" 732 ) 733 return default_pg_timeout 734 return default_pg_nccl_timeout 735 else: 736 return default_pg_timeout 737 738 739def _check_valid_timeout(timeout: Any) -> None: 740 if not isinstance(timeout, timedelta): 741 raise TypeError( 742 f"Expected timeout argument to be of type datetime.timedelta, got {timeout}" 743 ) 744 745 746# Default process group state 747_default_pg_init_method: Optional[str] = None 748 749STORE_BASED_BARRIER_PREFIX = "store_based_barrier_key" 750 751 752def _get_pg_default_device(group: Optional[ProcessGroup] = None) -> torch.device: 753 """ 754 Return the device to use with ``group`` for control flow usage (object collectives, barrier). 755 756 There are selection rules: 757 1. 
If user specifies exactly one backend in ``init_process_group`` call: 758 use that backend 759 2. Else if user specifies multiple "device:backend" pairs in init_process_group: 760 If "cpu" is among those pairs, use "cpu" (because the object is in cpu memory); 761 Otherwise, use the first backend (sort of a random pick). 762 763 Args: 764 group (ProcessGroup, optional): The process group to work on. If None, 765 the default process group will be used. 766 767 Returns: 768 torch.device: The device to use with ``group``. 769 770 """ 771 group = group or _get_default_group() 772 if group in _world.pg_default_device: 773 # Previously searched and cached; just return 774 return _world.pg_default_device[group] 775 776 if not isinstance(group, ProcessGroup): 777 # Provide backward compatibility to cases where `group` passed in is 778 # actually a Backend (like `ProcessGroupGloo`) rather than a 779 # `ProcessGroup` in PT 2.0 sense 780 warnings.warn( 781 f"You are using a Backend {type(group)} as a ProcessGroup. " 782 "This usage is deprecated since PyTorch 2.0. Please use a public API " 783 "of PyTorch Distributed instead.", 784 FutureWarning, 785 stacklevel=3, 786 ) 787 # Most users create Gloo with private API for object collectives 788 _world.pg_default_device[group] = torch.device("cpu") 789 return _world.pg_default_device[group] 790 791 """ 792 ``group._device_types`` is a property pybind that returns the devices 793 ("cpu", "cuda", etc) supported by ``group``. Can be multiple if the 794 ``group`` supports multiple devices. 795 """ 796 devices = group._device_types 797 798 if len(devices) == 1: 799 # User fixed exactly one backend in `init_process_group` 800 _world.pg_default_device[group] = devices[0] 801 elif len(devices) == 0: 802 # No backend has been registered with this PG (maybe because no 803 # collective has been run?) We pick cpu as the default and hopefully 804 # this would lazily init Gloo or other available cpu backend. 805 _world.pg_default_device[group] = torch.device("cpu") 806 elif torch.device("cpu") in devices: 807 # There are multiple backends in this PG and cpu is among them. 808 # cpu is preferred as the object is in cpu memory. No need for device 809 # copy. 810 _world.pg_default_device[group] = torch.device("cpu") 811 else: 812 # No cpu in the backend list. Randomly pick the first backend 813 _world.pg_default_device[group] = devices[0] 814 815 logger.info( 816 "Using device %s for object " "collectives.", _world.pg_default_device[group] 817 ) 818 return _world.pg_default_device[group] 819 820 821@_time_logger 822def _store_based_barrier( 823 rank, 824 store, 825 group_name, 826 rendezvous_count, 827 timeout, 828 logging_interval=timedelta(seconds=10), 829) -> None: 830 """ 831 Store based barrier for synchronizing processes. 832 833 Barrier based on store which is used for synchronizing processes after 834 ``init_process_group`` or ``new_group``. Intended to be used only with 835 those two methods and is not a generic alternative to ``barrier()``. 836 """ 837 store_key = f"{STORE_BASED_BARRIER_PREFIX}:{group_name}" 838 store.add(store_key, 1) 839 logger.debug("Added key: %s to store for rank: %s", store_key, rank) 840 841 # Now wait for all workers to check in with the store. 
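    # In outline: each rank has already incremented the shared counter key above;
    # below, the rank that observes the count reach ``world_size`` sets a
    # ``last_worker`` sentinel key, and every rank then waits on that sentinel,
    # logging progress every ``logging_interval`` until ``timeout`` expires.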
842 world_size = rendezvous_count 843 worker_count = store.add(store_key, 0) 844 845 last_worker_key = f"{store_key}:last_worker" 846 if worker_count == world_size: 847 store.set(last_worker_key, "1") 848 849 # adjust the timeout to be at least 10secs + 1sec per thousand ranks to reduce the odds of timeout 850 # this value was empirically found while scale testing. 851 logging_interval = max(logging_interval, timedelta(seconds=10 + world_size / 1000)) 852 853 start = time.time() 854 while True: 855 try: 856 # This will throw an exception after the logging_interval in which we print out 857 # the status of the group or time out officially, throwing runtime error 858 store.wait([last_worker_key], logging_interval) 859 break 860 except RuntimeError as e: 861 worker_count = store.add(store_key, 0) 862 # Print status periodically to keep track. 863 logger.debug( 864 "Waiting in store based barrier to initialize process group for " 865 "rank: %s, key: %s (world_size=%s, num_workers_joined=%s, timeout=%s error=%s)", 866 rank, 867 store_key, 868 world_size, 869 worker_count, 870 timeout, 871 e, 872 ) 873 874 if timedelta(seconds=(time.time() - start)) > timeout: 875 raise DistStoreError( # noqa: B904 876 "Timed out initializing process group in store based barrier on " 877 f"rank {rank}, for key: {store_key} (world_size={world_size}, " 878 f"num_workers_joined={worker_count}, timeout={timeout} error={e})" 879 ) 880 881 logger.info( 882 "Rank %s: Completed store-based barrier for key:%s with %s nodes.", 883 rank, 884 store_key, 885 world_size, 886 ) 887 888 889def _rank_not_in_group(group: Optional[ProcessGroup]) -> bool: 890 """Check if the current process's rank is not in a given group.""" 891 if group is None: 892 return False 893 return group == GroupMember.NON_GROUP_MEMBER 894 895 896def _warn_not_in_group(op_name) -> None: 897 global_rank = -1 if GroupMember.WORLD is None else GroupMember.WORLD.rank() 898 warnings.warn( 899 f"Running {op_name} on global rank {global_rank} which does not " 900 "belong to the given group." 901 ) 902 903 904def get_group_rank(group: ProcessGroup, global_rank: int) -> int: 905 """ 906 Translate a global rank into a group rank. 907 908 ``global_rank`` must be part of ``group`` otherwise this raises RuntimeError. 909 910 Args: 911 group (ProcessGroup): ProcessGroup to find the relative rank. 912 global_rank (int): Global rank to query. 913 914 Returns: 915 Group rank of ``global_rank`` relative to ``group`` 916 917 N.B. calling this function on the default process group returns identity 918 """ 919 if group is GroupMember.WORLD: 920 return global_rank 921 if group not in _world.pg_group_ranks: 922 raise ValueError( 923 f"Group {group} is not registered, please create group with torch.distributed.new_group API" 924 ) 925 group_ranks = _world.pg_group_ranks[group] 926 if global_rank not in group_ranks: 927 raise ValueError(f"Global rank {global_rank} is not part of group {group}") 928 929 return group_ranks[global_rank] 930 931 932def get_global_rank(group: ProcessGroup, group_rank: int) -> int: 933 """ 934 Translate a group rank into a global rank. 935 936 ``group_rank`` must be part of `group` otherwise this raises RuntimeError. 937 938 Args: 939 group (ProcessGroup): ProcessGroup to find the global rank from. 940 group_rank (int): Group rank to query. 941 942 Returns: 943 Global rank of ``group_rank`` relative to ``group`` 944 945 N.B. 
calling this function on the default process group returns identity 946 """ 947 if group is GroupMember.WORLD: 948 return group_rank 949 if group not in _world.pg_group_ranks: 950 raise ValueError( 951 f"Group {group} is not registered, please create group with torch.distributed.new_group API" 952 ) 953 for rank, grp_rank in _world.pg_group_ranks[group].items(): 954 if grp_rank == group_rank: 955 return rank 956 raise ValueError(f"Group rank {group_rank} is not part of group {group}") 957 958 959# TODO: remove this once the ecosystem moves away from it. 960@deprecated( 961 "`torch.distributed.distributed_c10d._get_global_rank` is deprecated, " 962 "please use `torch.distributed.distributed_c10d.get_global_rank` instead", 963 category=FutureWarning, 964) 965def _get_global_rank(group, rank) -> int: 966 """Use get_global_rank as this method is deprecated.""" 967 return get_global_rank(group, rank) 968 969 970def get_process_group_ranks(group: ProcessGroup) -> List[int]: 971 """ 972 Get all ranks associated with ``group``. 973 974 Args: 975 group (ProcessGroup): ProcessGroup to get all ranks from. 976 977 Returns: 978 List of global ranks ordered by group rank. 979 """ 980 return list(_world.pg_group_ranks[group].keys()) 981 982 983def _get_group_size(group) -> int: 984 """Get a given group's world size.""" 985 if group is GroupMember.WORLD or group is None: 986 default_pg = _get_default_group() 987 return default_pg.size() 988 return group.size() 989 990 991def _get_group_size_by_name(group_name: str) -> int: 992 group = _resolve_process_group(group_name) 993 return group.size() 994 995 996def _resolve_group_name_by_ranks_and_tag(ranks: List[int], tag: str) -> str: 997 # TODO(yifu): remove this function once ranks + tag is not a supported 998 # identifier for process group for functional collectives. 999 group = _find_pg_by_ranks_and_tag(tag, ranks) 1000 if group is None: 1001 raise ValueError("") 1002 return group.group_name 1003 1004 1005def _check_single_tensor(param, param_name) -> None: 1006 """Check that the parameter ``param_name`` is a single tensor.""" 1007 if not isinstance(param, torch.Tensor): 1008 raise TypeError( 1009 f"""Invalid function argument. Expected parameter `{param_name}` of type torch.Tensor 1010 but got {type(param)} instead.""" 1011 ) 1012 1013 1014def _check_tensor_list(param, param_name) -> None: 1015 """Check that the parameter ``param_name`` is a list of tensors.""" 1016 if not isinstance(param, list): 1017 raise TypeError( 1018 f"""Invalid function argument. Expected parameter `{param_name}` of type List[torch.Tensor] 1019 but got {type(param)} instead.""" 1020 ) 1021 elif not all(isinstance(p, torch.Tensor) for p in param): 1022 raise TypeError( 1023 f"""Invalid function argument. 
Expected parameter `{param_name}` of type List[torch.Tensor]
             but got {type(param)} with elements of type {[type(p) for p in param]}."""
        )


def _as_iterable(obj) -> collections.abc.Iterable:
    return obj if isinstance(obj, list) else (obj,)


def _ensure_all_tensors_same_dtype(*tensors) -> None:
    last_dtype = None
    for tensor in itertools.chain.from_iterable(map(_as_iterable, tensors)):
        tensor_dtype = tensor.dtype
        # Mixing complex and its element type is allowed
        if tensor_dtype.is_complex:
            tensor_dtype = (
                torch.float32 if tensor_dtype == torch.complex64 else torch.complex128
            )

        if last_dtype is None:
            last_dtype = tensor_dtype
        else:
            if last_dtype != tensor_dtype:
                raise ValueError(
                    "Invalid usage of tensors with different dtypes. "
                    f"Found {last_dtype} and {tensor.dtype}"
                )


def _check_op(op) -> None:
    """Check that the ``op`` is either isend or irecv."""
    if op not in [isend, irecv]:
        raise ValueError(
            "Invalid ``op``. Expected ``op`` "
            "to be of type ``torch.distributed.isend`` or "
            "``torch.distributed.irecv``."
        )


def _check_p2p_op_list(p2p_op_list) -> None:
    """
    Check that the ``p2p_op_list`` is a list of P2POp instances.

    Also, check that all ops use the same group.
    """
    if not isinstance(p2p_op_list, list) or not all(
        isinstance(p2p_op, P2POp) for p2p_op in p2p_op_list
    ):
        raise ValueError(
            "Invalid ``p2p_op_list``. Each op is expected "
            "to be of type ``torch.distributed.P2POp``."
        )

    group = p2p_op_list[0].group
    if not all(group == p2p_op.group for p2p_op in p2p_op_list):
        raise ValueError("All ops need to use the same group.")


def is_mpi_available() -> bool:
    """Check if the MPI backend is available."""
    return _MPI_AVAILABLE


def is_nccl_available() -> bool:
    """Check if the NCCL backend is available."""
    return _NCCL_AVAILABLE


def is_gloo_available() -> bool:
    """Check if the Gloo backend is available."""
    return _GLOO_AVAILABLE


def is_ucc_available() -> bool:
    """Check if the UCC backend is available."""
    return _UCC_AVAILABLE


def is_backend_available(backend: str) -> bool:
    """
    Check backend availability.

    Checks if the given backend is available, covering both the built-in backends
    and third-party backends registered through ``Backend.register_backend``.

    Args:
        backend (str): Backend name.
    Returns:
        bool: Returns true if the backend is available, otherwise false.
    """
    # If torch.distributed exposes an ``is_<backend>_available`` function for this
    # backend, return its result directly
    available_func = getattr(torch.distributed, f"is_{backend.lower()}_available", None)
    if available_func:
        return available_func()

    return backend.lower() in Backend.backend_list


def is_initialized() -> bool:
    """Check if the default process group has been initialized."""
    return GroupMember.WORLD is not None


def is_torchelastic_launched() -> bool:
    """
    Check whether this process was launched with ``torch.distributed.elastic`` (aka torchelastic).

    The existence of ``TORCHELASTIC_RUN_ID`` environment
    variable is used as a proxy to determine whether the current process
    was launched with torchelastic.
This is a reasonable proxy since 1133 ``TORCHELASTIC_RUN_ID`` maps to the rendezvous id which is always a 1134 non-null value indicating the job id for peer discovery purposes.. 1135 """ 1136 return os.getenv("TORCHELASTIC_RUN_ID") is not None 1137 1138 1139def _is_barrier_after_init() -> int: 1140 # Environment variable to control whether process group should perform a 1141 # barrier after its init. Default value is 0, i.e. no barrier. If you 1142 # experience issue with this setting, you may set 1143 # `TORCH_DIST_INIT_BARRIER=1` to add the barrier. 1144 return int(os.getenv("TORCH_DIST_INIT_BARRIER", "0")) 1145 1146 1147def _get_default_group() -> ProcessGroup: 1148 """Get the default process group created by init_process_group.""" 1149 if not is_initialized(): 1150 raise ValueError( 1151 "Default process group has not been initialized, " 1152 "please make sure to call init_process_group." 1153 ) 1154 if TYPE_CHECKING: 1155 return not_none(GroupMember.WORLD) 1156 else: 1157 return GroupMember.WORLD 1158 1159 1160def _get_default_store() -> Store: 1161 """Get the default store created by init_process_group.""" 1162 if not is_initialized(): 1163 raise ValueError( 1164 "Default process group has not been initialized, " 1165 "please make sure to call init_process_group." 1166 ) 1167 default_pg = _get_default_group() 1168 _, default_store = _world.pg_map[default_pg] 1169 return default_store 1170 1171 1172def _update_default_pg(pg) -> None: 1173 _world.default_pg = pg 1174 rank = pg.rank() if pg is not None and pg != GroupMember.NON_GROUP_MEMBER else -1 1175 torch._C._distributed_c10d._set_global_rank(rank) 1176 1177 1178def get_backend_config(group: Optional[ProcessGroup] = None) -> str: 1179 """ 1180 Return the backend configuration of the given process group. 1181 1182 Args: 1183 group (ProcessGroup, optional): The process group to work on. The 1184 default is the general main process group. If another specific group 1185 is specified, the calling process must be part of :attr:`group`. 1186 1187 Returns: 1188 The backend configuration of the given process group as a lower case string. 1189 1190 """ 1191 pg = group or _get_default_group() 1192 if _rank_not_in_group(pg): 1193 raise ValueError("Invalid process group specified") 1194 backend_config = _world.pg_backend_config.get(pg) 1195 return str(not_none(backend_config)) 1196 1197 1198def get_backend(group: Optional[ProcessGroup] = None) -> Backend: 1199 """ 1200 Return the backend of the given process group. 1201 1202 Args: 1203 group (ProcessGroup, optional): The process group to work on. The 1204 default is the general main process group. If another specific group 1205 is specified, the calling process must be part of :attr:`group`. 1206 1207 Returns: 1208 The backend of the given process group as a lower case string. 1209 1210 """ 1211 pg = group or _get_default_group() 1212 if _rank_not_in_group(pg): 1213 raise ValueError("Invalid process group specified") 1214 pg_store = _world.pg_map[pg] if pg in _world.pg_map else None 1215 return Backend(not_none(pg_store)[0]) 1216 1217 1218def _get_process_group_uid(pg: ProcessGroup) -> int: 1219 backend = None 1220 try: 1221 backend = pg._get_backend(torch.device("cuda")) 1222 except RuntimeError: 1223 pass 1224 if is_nccl_available() and isinstance(backend, ProcessGroupNCCL): 1225 return backend.uid 1226 return -1 1227 1228 1229def _get_pg_config(group: Optional[ProcessGroup] = None) -> Dict[str, Any]: 1230 """ 1231 Return the pg configuration of the given process group. 
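
    The returned dict has the keys shown below (values are illustrative only):

        {
            "pg_name": "0",
            "pg_desc": "default_pg",
            "backend_config": "cpu:gloo,cuda:nccl",
            "pg_size": 8,
            "ranks": [0, 1, 2, 3, 4, 5, 6, 7],
        }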

    """
    pg = group or _get_default_group()
    return {
        "pg_name": _get_process_group_name(pg),
        "pg_desc": pg.group_desc,
        "backend_config": get_backend_config(pg),
        "pg_size": _get_group_size(pg),
        "ranks": get_process_group_ranks(pg),
    }


def _get_all_pg_configs() -> List[Dict[str, Any]]:
    """
    Return the pg configuration of all the process groups.

    """
    config_info: List[Dict[str, Any]] = []
    for pg in _world.pg_map.keys():
        config_info.append(_get_pg_config(pg))
    return config_info


def get_pg_count() -> int:
    """
    Return the number of process groups.

    """
    return _world.group_count


def get_node_local_rank(fallback_rank: Optional[int] = None) -> int:
    """
    Return the local rank of the current process relative to the node.

    Semantically, this is a useful concept for mapping processes to devices.
    For example, on a node with 8 accelerators you could use the node local rank to decide
    which accelerator device to bind the process to.

    In practice, the actual assignment of node local ranks is handled by the process launcher outside of PyTorch,
    and communicated via the `LOCAL_RANK` environment variable.

    Torchrun will automatically populate `LOCAL_RANK`, but other launchers may not. If `LOCAL_RANK` is unspecified,
    this API will fall back to the provided kwarg 'fallback_rank' if specified, otherwise it will raise an error. The
    intent is to allow writing an application that runs either in single or multi device contexts without error.

    """
    if "LOCAL_RANK" in os.environ:
        return int(os.environ["LOCAL_RANK"])
    elif fallback_rank is not None:
        return int(fallback_rank)
    raise RuntimeError(
        "LOCAL_RANK is not in the environment. Consider passing fallback_rank to allow `get_node_local_rank` to work, "
        "assuming you are not running in a multi-device context and want the code to run locally instead."
    )


def _add_ephemeral_timeout_for_all_pgs(timeout: timedelta) -> None:
    """
    This API adds an ephemeral timeout extension for all PGs locally
    on one rank. The timeout gets reset once the first collective issued
    after this API is called has finished.
    NOTE: Setting the timeout is only supported for CUDA backends for now.
    NOTE: While this feature
    provides flexibility in specific scenarios, it introduces statefulness
    to timeout setting. Therefore, it is advisable to use this API sparingly
    and consider alternative approaches, such as directly setting the timeout
    or utilizing a barrier collective (one can set any timeout to the barrier),
    whenever feasible.

    Args:
        timeout (timedelta): The delta of timeout to extend.

    Returns:
        None.
    """
    for pg in _world.pg_map.keys():
        devices = pg._device_types
        if torch.device("cuda") in devices:
            backend = pg._get_backend(torch.device("cuda"))
            if is_nccl_available() and isinstance(backend, ProcessGroupNCCL):
                backend._add_ephemeral_timeout(timeout)


def _set_pg_timeout(timeout: timedelta, group: Optional[ProcessGroup] = None) -> None:
    """
    Set the timeout for the given process group when users want to use a different timeout instead of
    default values.

    Args:
        timeout (timedelta): Timeout for operations executed against the process group which
            users want to set.
Default value is 10 minutes for NCCL and 30 minutes for other backends. 1324 This is the duration after which collectives will be aborted asynchronously and the process will crash. 1325 This is done since CUDA execution is async and it is no longer safe to continue executing user code since 1326 failed async NCCL operations might result in subsequent CUDA operations running on corrupted data. 1327 When TORCH_NCCL_BLOCKING_WAIT is set, the process will block and wait for this timeout. 1328 1329 group (ProcessGroup, optional): The process group to work on. The 1330 default is the general main process group. If another specific group 1331 is specified, the calling process must be part of :attr:`group`. 1332 1333 Returns: 1334 None 1335 """ 1336 if group is None: 1337 group = _get_default_group() 1338 if _rank_not_in_group(group): 1339 raise ValueError("Invalid process group specified") 1340 assert isinstance(group, ProcessGroup) 1341 devices = group._device_types 1342 backends = set() 1343 if torch.device("cpu") in devices and is_gloo_available(): 1344 backend = group._get_backend(torch.device("cpu")) 1345 if isinstance(backend, ProcessGroupGloo): 1346 backends.add(backend) 1347 if torch.device("cuda") in devices: 1348 backend = group._get_backend(torch.device("cuda")) 1349 if is_nccl_available() and isinstance(backend, ProcessGroupNCCL): 1350 backends.add(backend) # type: ignore[arg-type] 1351 elif is_gloo_available() and isinstance(backend, ProcessGroupGloo): 1352 backends.add(backend) # type: ignore[arg-type] 1353 if len(backends) == 0: 1354 warnings.warn("Set timeout is now only supported for either nccl or gloo.") 1355 for backend in backends: 1356 backend._set_default_timeout(timeout) 1357 1358 1359@_exception_logger 1360@_time_logger 1361def init_process_group( 1362 backend: Optional[str] = None, 1363 init_method: Optional[str] = None, 1364 timeout: Optional[timedelta] = None, 1365 world_size: int = -1, 1366 rank: int = -1, 1367 store: Optional[Store] = None, 1368 group_name: str = "", 1369 pg_options: Optional[Any] = None, 1370 device_id: Optional[torch.device] = None, 1371) -> None: 1372 """ 1373 Initialize the default distributed process group. 1374 1375 This will also initialize the distributed package. 1376 1377 There are 2 main ways to initialize a process group: 1378 1. Specify ``store``, ``rank``, and ``world_size`` explicitly. 1379 2. Specify ``init_method`` (a URL string) which indicates where/how 1380 to discover peers. Optionally specify ``rank`` and ``world_size``, 1381 or encode all required parameters in the URL and omit them. 1382 1383 If neither is specified, ``init_method`` is assumed to be "env://". 1384 1385 1386 Args: 1387 backend (str or Backend, optional): The backend to use. Depending on 1388 build-time configurations, valid values include ``mpi``, ``gloo``, 1389 ``nccl``, and ``ucc``. If the backend is not provided, then both a ``gloo`` 1390 and ``nccl`` backend will be created, see notes below for how multiple 1391 backends are managed. This field can be given as a lowercase string 1392 (e.g., ``"gloo"``), which can also be accessed via 1393 :class:`Backend` attributes (e.g., ``Backend.GLOO``). If using 1394 multiple processes per machine with ``nccl`` backend, each process 1395 must have exclusive access to every GPU it uses, as sharing GPUs 1396 between processes can result in deadlocks. ``ucc`` backend is 1397 experimental. 1398 init_method (str, optional): URL specifying how to initialize the 1399 process group. 
Default is "env://" if no 1400 ``init_method`` or ``store`` is specified. 1401 Mutually exclusive with ``store``. 1402 world_size (int, optional): Number of processes participating in 1403 the job. Required if ``store`` is specified. 1404 rank (int, optional): Rank of the current process (it should be a 1405 number between 0 and ``world_size``-1). 1406 Required if ``store`` is specified. 1407 store(Store, optional): Key/value store accessible to all workers, used 1408 to exchange connection/address information. 1409 Mutually exclusive with ``init_method``. 1410 timeout (timedelta, optional): Timeout for operations executed against 1411 the process group. Default value is 10 minutes for NCCL and 30 minutes for other backends. 1412 This is the duration after which collectives will be aborted asynchronously and the process will crash. 1413 This is done since CUDA execution is async and it is no longer safe to continue executing user code since 1414 failed async NCCL operations might result in subsequent CUDA operations running on corrupted data. 1415 When TORCH_NCCL_BLOCKING_WAIT is set, the process will block and wait for this timeout. 1416 1417 group_name (str, optional, deprecated): Group name. This argument is ignored 1418 pg_options (ProcessGroupOptions, optional): process group options 1419 specifying what additional options need to be passed in during 1420 the construction of specific process groups. As of now, the only 1421 options we support is ``ProcessGroupNCCL.Options`` for the ``nccl`` 1422 backend, ``is_high_priority_stream`` can be specified so that 1423 the nccl backend can pick up high priority cuda streams when 1424 there're compute kernels waiting. For other availble options to config nccl, 1425 See https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t 1426 device_id (torch.device, optional): a single, specific device 1427 to "bind" this process to, allowing for backend-specific 1428 optimizations. Currently this has two effects, only under 1429 NCCL: the communicator is immediately formed (calling 1430 ``ncclCommInit*`` immediately rather than the normal lazy 1431 call) and sub-groups will use ``ncclCommSplit`` when 1432 possible to avoid unnecessary overhead of group creation. If you 1433 want to know NCCL initialization error early, you can also use this 1434 field. 1435 1436 .. note:: To enable ``backend == Backend.MPI``, PyTorch needs to be built from source 1437 on a system that supports MPI. 1438 1439 .. note:: Support for multiple backends is experimental. Currently when no backend is 1440 specified, both ``gloo`` and ``nccl`` backends will be created. The ``gloo`` backend 1441 will be used for collectives with CPU tensors and the ``nccl`` backend will be used 1442 for collectives with CUDA tensors. A custom backend can be specified by passing in 1443 a string with format "<device_type>:<backend_name>,<device_type>:<backend_name>", e.g. 1444 "cpu:gloo,cuda:custom_backend". 1445 1446 """ 1447 1448 global _world 1449 1450 global _backend 1451 global _default_pg_init_method 1452 1453 if GroupMember.WORLD is not None: 1454 raise ValueError("trying to initialize the default process group twice!") 1455 1456 set_pytorch_distributed_envs_from_justknobs() 1457 1458 # Depending on the import order, some trace_rules functions may be evaluated 1459 # during the import phase. In such a case, these functions may not correctly 1460 # add the distributed related rules due to import circular dependency. 
1461 # We need to clear the lru_cache during the runtime to ensure the correctness 1462 # of these trace_rules. 1463 # 1464 # Since this API must be called before all distributed code being compiled, 1465 # clearing the cache here should be safe. 1466 if "torch._dynamo" in sys.modules: 1467 torch._dynamo.trace_rules.clear_lru_cache() 1468 1469 assert (store is None) or ( 1470 init_method is None 1471 ), "Cannot specify both init_method and store." 1472 1473 if store is not None: 1474 assert world_size > 0, "world_size must be positive if using store" 1475 assert rank >= 0, "rank must be non-negative if using store" 1476 elif init_method is None: 1477 init_method = "env://" 1478 1479 if backend: 1480 backend = Backend(backend) 1481 else: 1482 backend = Backend("undefined") 1483 1484 if timeout is None: 1485 timeout = _get_default_timeout(backend) 1486 1487 _check_valid_timeout(timeout) 1488 1489 """ 1490 Group name is not visible to users unless they access 1491 internals of c10d. This means we can ignore the value 1492 they provide as it not exposed in a public way. 1493 """ 1494 group_name = _process_group_name([], use_hashed_name=False) 1495 if backend == Backend.MPI: 1496 if world_size != -1 or rank != -1: 1497 warnings.warn( 1498 f"For MPI backend, world_size ({world_size}) and rank ({rank}) " 1499 "are ignored since they are assigned by the " 1500 "MPI runtime." 1501 ) 1502 1503 default_pg, _ = _new_process_group_helper( 1504 -1, 1505 -1, 1506 [], 1507 backend, 1508 None, 1509 group_name, 1510 timeout=timeout, 1511 group_desc="default_pg", 1512 ) 1513 _update_default_pg(default_pg) 1514 else: 1515 # backward compatible API 1516 if store is None: 1517 rendezvous_iterator = rendezvous( 1518 not_none(init_method), rank, world_size, timeout=timeout 1519 ) 1520 store, rank, world_size = next(rendezvous_iterator) 1521 store.set_timeout(timeout) 1522 1523 # Use a PrefixStore to avoid accidental overrides of keys used by 1524 # different systems (e.g. RPC) in case the store is multi-tenant. 1525 store = PrefixStore("default_pg", store) 1526 1527 default_pg, _ = _new_process_group_helper( 1528 world_size, 1529 rank, 1530 [], 1531 backend, 1532 store, 1533 group_name, 1534 pg_options=pg_options, 1535 timeout=timeout, 1536 device_id=device_id, 1537 group_desc="default_pg", 1538 ) 1539 _update_default_pg(default_pg) 1540 1541 _world.pg_group_ranks[GroupMember.WORLD] = {i: i for i in range(GroupMember.WORLD.size())} # type: ignore[attr-defined, index] 1542 _backend = _world.pg_map[not_none(GroupMember.WORLD)][0] 1543 _default_pg_init_method = init_method 1544 1545 old_hook = sys.excepthook 1546 excepthook_prefix = f"[rank{get_rank()}]" 1547 1548 def _distributed_excepthook(*args): 1549 old_stderr = sys.stderr 1550 sys.stderr = buf = io.StringIO() 1551 try: 1552 old_hook(*args) 1553 finally: 1554 sys.stderr = old_stderr 1555 msg = buf.getvalue() 1556 msg = "\n".join( 1557 f"{excepthook_prefix}: {s}" if s != "" else "" for s in msg.split("\n") 1558 ) 1559 sys.stderr.write(msg) 1560 sys.stderr.flush() 1561 1562 sys.excepthook = _distributed_excepthook 1563 1564 if _is_barrier_after_init() == 1: 1565 # barrier at the end to ensure that once we return from this method, all 1566 # process groups including global variables (if any) are updated 1567 # correctly on all ranks. 1568 # Update 04/2023: for large-scale runs, this barrier (esp. store-based 1569 # barrier) may be costly and/or unscalable. 
Also, in a lot of cases,
        # these barriers may be unnecessary, as proven by a green CI after
        # removal. An environment variable `TORCH_DIST_INIT_BARRIER` has been
        # added which enables this barrier only when set to 1.
        logger.debug(
            "Performing barrier after ProcessGroup initialization since "
            "TORCH_DIST_INIT_BARRIER = 1"
        )
        if backend == Backend.MPI:
            # MPI backend doesn't use store.
            barrier()
        else:
            # Use store based barrier here since barrier() uses a bunch of
            # default devices and messes up NCCL internal state.
            _store_based_barrier(rank, store, group_name, world_size, timeout)


def _get_split_source(pg):
    split_from = None
    if pg.bound_device_id:
        split_from = pg._get_backend(pg.bound_device_id)
    elif pg is _world.default_pg:
        try:
            split_from = pg._get_backend(torch.device("cuda"))
        except RuntimeError:
            # no cuda device associated with this backend
            pass

    if not split_from or not split_from.supports_splitting:
        return None

    # If necessary, find a backend to split from by peeling process
    # group wrappers from our potentially wrapped process group.
    while _GLOO_AVAILABLE and isinstance(split_from, _ProcessGroupWrapper):
        split_from = split_from.wrapped_pg

    return split_from


def _shutdown_backend(pg):
    """
    Try to shut down the backend of a process group.
    Currently, only ProcessGroupNCCL backend is supported.
    No op for other backends.
    """
    backend = None
    try:
        backend = pg._get_backend(torch.device("cuda"))
    except RuntimeError:
        pass
    if is_nccl_available() and isinstance(backend, ProcessGroupNCCL):
        # explicitly call shutdown to ensure that NCCL resources are released
        backend._shutdown()


def _new_process_group_helper(
    group_size,
    group_rank,
    global_ranks_in_group,
    backend,
    store,
    group_name,
    pg_options=None,
    timeout=None,
    pg_tag=None,
    device_id=None,
    group_desc=None,
):
    """
    Create a new distributed process group.

    This function must be called by ALL processes in the global group, even if
    the calling process is not part of the newly created group. In that case,
    this function returns GroupMember.NON_GROUP_MEMBER.

    This function is called with ``global_ranks_in_group == []`` for the default group.
    """
    global _world

    if group_name in _world.pg_names.values():
        raise ValueError(
            "The specified group name has already been "
            "created, please use a different group name"
        )

    if device_id is not None and (device_id.index is None or device_id.type != "cuda"):
        raise ValueError(
            "init_process_group device_id parameter must be a cuda device with an "
            "id, e.g.
cuda:0, not just cuda or cpu" 1658 ) 1659 1660 # Note: _new_process_group_helper is only called from init_process_group, which always provides a timeout value 1661 _check_valid_timeout(timeout) 1662 1663 if pg_tag not in [None, ""]: 1664 # creating with the same tag and rank set results in the same underlying PG 1665 existing_group = _find_pg_by_ranks_and_tag(pg_tag, global_ranks_in_group) 1666 if existing_group: 1667 _, prefix_store = _world.pg_map[existing_group] 1668 return existing_group, prefix_store 1669 1670 group_desc = "undefined" if group_desc is None else group_desc 1671 1672 # The list of group ranks is empty if we're creating the default group. 1673 is_default_group = len(global_ranks_in_group) == 0 1674 1675 # nccl and potentially other backends allow creation of 1676 # communicators based on pre-existing ones, which can save 1677 # initialization time. Due to lazy initialization of 1678 # communicators in some backends, we have to be careful and only 1679 # split when we *know* the backends already are connected _on all 1680 # ranks_. We can only know this if the group we are making is the 1681 # entire world or if we have bound a device id to the world (which 1682 # causes early connection initialization). 1683 if is_initialized() and ( 1684 len(global_ranks_in_group) == _get_default_group().size() 1685 or _get_default_group().bound_device_id 1686 ): 1687 split_from = _get_split_source(_get_default_group()) 1688 else: 1689 split_from = None 1690 1691 # If this is a subgroup (which means group_ranks is specified), 1692 # we check if the current process is a member of the new group. 1693 if not is_default_group: 1694 global_rank = _get_default_group().rank() 1695 if global_rank not in global_ranks_in_group: 1696 # If we are using `ncclCommSplit` (or similar split from 1697 # other APIs) to create the communicator, we will need to 1698 # call `ncclCommSplit` on *all* ranks in this new group's 1699 # parent group, even those not in the new group. This is 1700 # a requirement of the NCCL API as otherwise we would get 1701 # out of sync. 1702 if split_from: 1703 split_from.perform_nocolor_split(_get_default_group().bound_device_id) 1704 return GroupMember.NON_GROUP_MEMBER, None 1705 1706 prefix_store = PrefixStore(f"{group_name}/", store) 1707 base_pg_options = ProcessGroup.Options(backend=str(backend)) 1708 base_pg_options._timeout = timeout 1709 pg: ProcessGroup = ProcessGroup( 1710 prefix_store, group_rank, group_size, base_pg_options 1711 ) 1712 if device_id: 1713 pg.bound_device_id = device_id 1714 backend_config = BackendConfig(backend) 1715 backend_class: torch._C._distributed_c10d.Backend 1716 for device, backend_str in backend_config.get_device_backend_map().items(): 1717 # Use the group name as prefix in the default store, such that 1718 # a single store can be reused by multiple groups. 1719 backend_prefix_store = PrefixStore(f"{device}/", prefix_store) 1720 1721 if backend_str == Backend.MPI: 1722 if not is_mpi_available(): 1723 raise RuntimeError( 1724 "Distributed package doesn't have MPI built in." 1725 " MPI is only included if you build PyTorch from" 1726 " source on a host that has MPI installed." 
1727 ) 1728 backend_class = ProcessGroupMPI.create(global_ranks_in_group) 1729 backend_type = ProcessGroup.BackendType.MPI 1730 if not backend_class: 1731 return GroupMember.NON_GROUP_MEMBER, None 1732 # create new process group with accurate rank and size 1733 if pg.rank() == -1 and pg.size() == -1: 1734 pg = ProcessGroup( 1735 backend_prefix_store, 1736 backend_class.rank(), 1737 backend_class.size(), 1738 base_pg_options, 1739 ) 1740 elif backend_str == Backend.GLOO: 1741 # TODO: remove this check after lazy initialization is supported 1742 # if pg_options is not None: 1743 # raise RuntimeError("GLOO options not supported") 1744 backend_class = ProcessGroupGloo( 1745 backend_prefix_store, group_rank, group_size, timeout=timeout 1746 ) 1747 backend_type = ProcessGroup.BackendType.GLOO 1748 elif backend_str == Backend.NCCL: 1749 if not is_nccl_available(): 1750 raise RuntimeError("Distributed package doesn't have NCCL built in") 1751 if pg_options is not None: 1752 assert isinstance( 1753 pg_options, ProcessGroupNCCL.Options 1754 ), "Expected pg_options argument to be of type ProcessGroupNCCL.Options" 1755 if pg_options._timeout != timeout: 1756 warnings.warn( 1757 "pg_options._timeout was specified, " 1758 "but timeout kwarg has a default value that will always override it. " 1759 ) 1760 else: 1761 # default pg_options for NCCL 1762 pg_options = ProcessGroupNCCL.Options() 1763 pg_options.is_high_priority_stream = False 1764 pg_options._timeout = timeout 1765 1766 if split_from: 1767 pg_options.split_from = split_from 1768 pg_options.split_color = _process_group_color(global_ranks_in_group) 1769 pg_options.global_ranks_in_group = global_ranks_in_group 1770 pg_options.group_name = group_name 1771 backend_class = ProcessGroupNCCL( 1772 backend_prefix_store, group_rank, group_size, pg_options 1773 ) 1774 backend_type = ProcessGroup.BackendType.NCCL 1775 elif backend_str == Backend.UCC and is_ucc_available(): 1776 # TODO: once UCC plugin is fully deprecated, remove 1777 # is_ucc_available() from above elif-condition and raise 1778 # RuntimeError if is_ucc_available() returns false. 1779 1780 backend_class = ProcessGroupUCC( 1781 backend_prefix_store, group_rank, group_size, timeout=timeout 1782 ) 1783 backend_type = ProcessGroup.BackendType.UCC 1784 else: 1785 assert ( 1786 backend_str.upper() in Backend._plugins 1787 ), f"Unknown c10d backend type {backend_str.upper()}" 1788 1789 backend_plugin = Backend._plugins[backend_str.upper()] 1790 creator_fn = backend_plugin.creator_fn 1791 extended_api = backend_plugin.extended_api 1792 backend_type = ProcessGroup.BackendType.CUSTOM 1793 1794 if not extended_api: 1795 backend_class = creator_fn( 1796 backend_prefix_store, group_rank, group_size, timeout 1797 ) 1798 else: 1799 dist_backend_opts = _DistributedBackendOptions() 1800 dist_backend_opts.store = backend_prefix_store 1801 dist_backend_opts.group_rank = group_rank 1802 dist_backend_opts.group_size = group_size 1803 dist_backend_opts.timeout = timeout 1804 dist_backend_opts.group_id = group_name 1805 dist_backend_opts.global_ranks_in_group = global_ranks_in_group 1806 1807 backend_class = creator_fn(dist_backend_opts, pg_options) 1808 1809 # Set sequence numbers for gloo and nccl backends. 
        if backend_str == Backend.GLOO:
            assert isinstance(backend_class, ProcessGroupGloo)
            backend_class._set_sequence_number_for_group()
        elif backend_str == Backend.NCCL:
            assert isinstance(backend_class, ProcessGroupNCCL)
            backend_class._set_sequence_number_for_group()

        # If the type is a subclass of ProcessGroup then return this process group immediately
        # TODO: This defaults to the old behavior for PythonProcessGroups which overwrites the
        # ProcessGroup instance
        if issubclass(type(backend_class), ProcessGroup):
            pg = backend_class  # type: ignore[assignment]
            break

        # Process group wrapper initialization for supported PGs when TORCH_DISTRIBUTED_DEBUG is set
        if (
            backend_str in [Backend.GLOO, Backend.NCCL, Backend.UCC]
            or backend_str.upper() in Backend._plugins
        ):
            # In debug mode and if GLOO is available, wrap in a wrapper PG that
            # enables enhanced collective checking for debuggability.
            if get_debug_level() == DebugLevel.DETAIL:
                if not _GLOO_AVAILABLE:
                    logger.info(
                        """TORCH_DISTRIBUTED_DEBUG was set to DETAIL, but
                                GLOO is not available. Build with Gloo to
                                create a wrapper process group in debug mode
                                to aid collective desynchronization debugging."""
                    )
                else:
                    backend_class = _create_process_group_wrapper(
                        wrapped_pg=backend_class,
                        store_prefix=group_name,
                        store=backend_prefix_store,
                        rank=group_rank,
                        world_size=group_size,
                        timeout=timeout,
                    )

        # register only a single backend when all get_device_backend_map values are the same
        if len(set(backend_config.get_device_backend_map().values())) == 1:
            for device in backend_config.get_device_backend_map().keys():
                pg._register_backend(torch.device(device), backend_type, backend_class)

            # break out of outer loop to not create any more backends
            break

        pg._register_backend(torch.device(device), backend_type, backend_class)

    # set group_name and group_desc to backend
    assert group_name is not None
    assert group_desc is not None
    pg._set_group_name(group_name)
    pg._set_group_desc(group_desc)

    if device_id and pg._get_backend(device_id).supports_splitting:
        eager_backend = pg._get_backend(device_id)
        eager_backend.eager_connect_single_device(device_id)

    # update global state
    _world.pg_map[pg] = (backend, prefix_store)
    _world.pg_names[pg] = group_name
    _register_process_group(group_name, pg)

    _world.pg_backend_config[pg] = str(backend_config)
    # "" is the default tag for user PGs
    if pg_tag in [None, ""]:
        pg_tag = f"ptd:{group_name}"
        _world.tags_to_pg.setdefault("", []).append(pg)
    else:
        pg_tag = f"user:{pg_tag}"

    _world.tags_to_pg.setdefault(pg_tag, []).append(pg)
    _world.pg_to_tag[pg] = pg_tag
    return pg, prefix_store


def destroy_process_group(group: Optional[ProcessGroup] = None):
    """
    Destroy a given process group, and deinitialize the distributed package.

    Args:
        group (ProcessGroup, optional): The process group to be destroyed, if
                                        group.WORLD is given, all process
                                        groups including the default one will
                                        be destroyed.
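
    .. note:: The snippet below is an illustrative sketch rather than a full
        shutdown recipe; it assumes the default process group was created
        earlier with :func:`init_process_group` and that all outstanding
        collectives on it have completed.

    Example::
        >>> # xdoctest: +SKIP("need process group init")
        >>> import torch.distributed as dist
        >>> dist.init_process_group(backend="gloo")
        >>> # ... run collectives on the default (WORLD) group ...
        >>> dist.destroy_process_group()  # no ``group`` given: tear down WORLD
        >>> dist.is_initialized()
        False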
1896 """ 1897 global _world 1898 1899 if group == GroupMember.NON_GROUP_MEMBER: 1900 return 1901 1902 if group is None: 1903 pg = GroupMember.WORLD 1904 else: 1905 pg = group 1906 1907 assert pg is not None 1908 if _world.pg_map.get(pg, None) is None: 1909 raise ValueError("Invalid process group specified") 1910 1911 # When users register Python onCompletion hooks, those hooks will run on a 1912 # different thread than the main thread. Today, the ProcessGroup dtor does 1913 # wait for that thread. However, the dtor might finish after the Python 1914 # Interpreter exits. After that grabbing the GIL for the Python hook will crash. 1915 # We can either revive the interpreter when running hooks or keep the main one 1916 # alive until all works and hooks are done. The current implementation does the 1917 # latter. Therefore, we explicitly call _wait_for_pending_works() here to wait 1918 # for the pending hooks to finish. 1919 if pg.name().lower() == "nccl" and pg._has_hooks(): 1920 pg._wait_for_pending_works() 1921 1922 if group is None or group == GroupMember.WORLD: 1923 # shutdown all backends in the order of pg names. shutting down in order because 1924 # ncclCommAbort() was a 'collective' call in some versions of NCCL. 1925 for pg_to_shutdown in sorted( 1926 _world.pg_names, key=lambda x: _world.pg_names[x], reverse=True 1927 ): 1928 _shutdown_backend(pg_to_shutdown) 1929 1930 _update_default_pg(None) 1931 _world.pg_map.clear() 1932 _world.pg_names.clear() 1933 _world.pg_group_ranks.clear() 1934 _world.pg_backend_config.clear() 1935 _world.pg_to_tag.clear() 1936 _world.tags_to_pg.clear() 1937 _world.pg_coalesce_state.clear() 1938 _world.pg_default_device.clear() 1939 _unregister_all_process_groups() 1940 1941 # when process group doesn't have an explicit name (only WORLD (default) 1942 # process group can have an explicit name), we use global _world.group_count 1943 # to generate the name. We need to reset the counter on destruction to 1944 # allow consistent value to be generated when we re-create process 1945 # groups after some trainers recover from failure 1946 # 1947 # We only reset this when WORLD is being destroyed because if this 1948 # process group is in good state, we aren't dealing with failures. 1949 _world.group_count = 0 1950 else: 1951 _shutdown_backend(pg) 1952 del _world.pg_map[pg] 1953 del _world.pg_names[pg] 1954 del _world.pg_group_ranks[pg] 1955 del _world.pg_backend_config[pg] 1956 if pg in _world.pg_default_device: 1957 del _world.pg_default_device[pg] 1958 if pg in _world.pg_coalesce_state.keys(): 1959 warnings.warn( 1960 "Some coalesced collectives haven't been launched when " 1961 "ProcessGroup is destroyed. They will be cleaned." 1962 ) 1963 del _world.pg_coalesce_state[pg] 1964 1965 tag = _world.pg_to_tag.get(pg) 1966 del _world.pg_to_tag[pg] 1967 if tag is not None: 1968 try: 1969 _world.tags_to_pg[tag].remove(pg) 1970 if tag.startswith("ptd:"): 1971 _world.tags_to_pg[""].remove(pg) 1972 except Exception: 1973 pass 1974 _unregister_process_group(pg.group_name) 1975 1976 1977def get_rank(group: Optional[ProcessGroup] = None) -> int: 1978 """ 1979 Return the rank of the current process in the provided ``group``, default otherwise. 1980 1981 Rank is a unique identifier assigned to each process within a distributed 1982 process group. They are always consecutive integers ranging from 0 to 1983 ``world_size``. 1984 1985 Args: 1986 group (ProcessGroup, optional): The process group to work on. If None, 1987 the default process group will be used. 
1988 1989 Returns: 1990 The rank of the process group 1991 -1, if not part of the group 1992 1993 """ 1994 if _rank_not_in_group(group): 1995 return -1 1996 1997 default_pg = _get_default_group() 1998 if group is None or group is GroupMember.WORLD: 1999 return default_pg.rank() 2000 2001 return get_group_rank(group, default_pg.rank()) 2002 2003 2004def get_world_size(group: Optional[ProcessGroup] = None) -> int: 2005 """ 2006 Return the number of processes in the current process group. 2007 2008 Args: 2009 group (ProcessGroup, optional): The process group to work on. If None, 2010 the default process group will be used. 2011 2012 Returns: 2013 The world size of the process group 2014 -1, if not part of the group 2015 2016 """ 2017 if _rank_not_in_group(group): 2018 return -1 2019 2020 return _get_group_size(group) 2021 2022 2023def isend( 2024 tensor: torch.Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: int = 0 2025) -> Optional[Work]: 2026 """ 2027 Send a tensor asynchronously. 2028 2029 .. warning:: 2030 Modifying ``tensor`` before the request completes causes undefined 2031 behavior. 2032 2033 .. warning:: 2034 ``tag`` is not supported with the NCCL backend. 2035 2036 Args: 2037 tensor (Tensor): Tensor to send. 2038 dst (int): Destination rank on global process group (regardless of ``group`` argument) 2039 group (ProcessGroup, optional): The process group to work on. If None, 2040 the default process group will be used. 2041 tag (int, optional): Tag to match send with remote recv 2042 2043 Returns: 2044 A distributed request object. 2045 None, if not part of the group 2046 2047 """ 2048 _check_single_tensor(tensor, "tensor") 2049 if _rank_not_in_group(group): 2050 _warn_not_in_group("isend") 2051 return None 2052 2053 if tensor.is_complex(): 2054 tensor = torch.view_as_real(tensor) 2055 2056 if group is None or group is GroupMember.WORLD: 2057 pg = _get_default_group() 2058 else: 2059 pg = group 2060 dst = get_group_rank(pg, dst) 2061 2062 return pg.send([tensor], dst, tag) 2063 2064 2065def irecv( 2066 tensor: torch.Tensor, 2067 src: Optional[int] = None, 2068 group: Optional[ProcessGroup] = None, 2069 tag: int = 0, 2070) -> Optional[Work]: 2071 """ 2072 Receives a tensor asynchronously. 2073 2074 .. warning:: 2075 ``tag`` is not supported with the NCCL backend. 2076 2077 Args: 2078 tensor (Tensor): Tensor to fill with received data. 2079 src (int, optional): Source rank on global process group (regardless of ``group`` argument). 2080 Will receive from any process if unspecified. 2081 group (ProcessGroup, optional): The process group to work on. If None, 2082 the default process group will be used. 2083 tag (int, optional): Tag to match recv with remote send 2084 2085 Returns: 2086 A distributed request object. 
2087 None, if not part of the group 2088 2089 """ 2090 _check_single_tensor(tensor, "tensor") 2091 if _rank_not_in_group(group): 2092 _warn_not_in_group("irecv") 2093 return None 2094 2095 if tensor.is_complex(): 2096 tensor = torch.view_as_real(tensor) 2097 2098 if group is None or group is GroupMember.WORLD: 2099 pg = _get_default_group() 2100 else: 2101 pg = group 2102 2103 if src is None: 2104 return pg.recv_anysource([tensor], tag) 2105 else: 2106 if pg is GroupMember.WORLD: 2107 return pg.recv([tensor], src, tag) 2108 else: 2109 group_src_rank = get_group_rank(pg, src) 2110 return pg.recv([tensor], group_src_rank, tag) 2111 2112 2113@_exception_logger 2114def send( 2115 tensor: torch.Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: int = 0 2116) -> None: 2117 """ 2118 Send a tensor synchronously. 2119 2120 .. warning:: 2121 ``tag`` is not supported with the NCCL backend. 2122 2123 Args: 2124 tensor (Tensor): Tensor to send. 2125 dst (int): Destination rank on global process group (regardless of ``group`` argument). 2126 Destination rank should not be the same as the rank of the current process. 2127 group (ProcessGroup, optional): The process group to work on. If None, 2128 the default process group will be used. 2129 tag (int, optional): Tag to match send with remote recv 2130 2131 """ 2132 if get_rank() == dst: 2133 raise ValueError( 2134 "Invalid destination rank: destination rank should not be the same as " 2135 "the rank of the current process." 2136 ) 2137 2138 _check_single_tensor(tensor, "tensor") 2139 if _rank_not_in_group(group): 2140 _warn_not_in_group("send") 2141 return None 2142 2143 if tensor.is_complex(): 2144 tensor = torch.view_as_real(tensor) 2145 2146 if group is None or group is GroupMember.WORLD: 2147 default_pg = _get_default_group() 2148 default_pg.send([tensor], dst, tag).wait() 2149 else: 2150 group_dst_rank = get_group_rank(group, dst) 2151 group.send([tensor], group_dst_rank, tag).wait() 2152 2153 2154@_exception_logger 2155def recv( 2156 tensor: torch.Tensor, 2157 src: Optional[int] = None, 2158 group: Optional[ProcessGroup] = None, 2159 tag: int = 0, 2160) -> int: 2161 """ 2162 Receives a tensor synchronously. 2163 2164 .. warning:: 2165 ``tag`` is not supported with the NCCL backend. 2166 2167 Args: 2168 tensor (Tensor): Tensor to fill with received data. 2169 src (int, optional): Source rank on global process group (regardless of ``group`` argument). 2170 Will receive from any process if unspecified. 2171 group (ProcessGroup, optional): The process group to work on. If None, 2172 the default process group will be used. 
2173 tag (int, optional): Tag to match recv with remote send 2174 2175 Returns: 2176 Sender rank 2177 -1, if not part of the group 2178 2179 """ 2180 _check_single_tensor(tensor, "tensor") 2181 if _rank_not_in_group(group): 2182 _warn_not_in_group("recv") 2183 return -1 2184 2185 if tensor.is_complex(): 2186 tensor = torch.view_as_real(tensor) 2187 2188 pg = group or _get_default_group() 2189 2190 if src is None: 2191 work = pg.recv_anysource([tensor], tag) 2192 work.wait() 2193 src_rank = work._source_rank() 2194 if group is None or group is GroupMember.WORLD: 2195 return src_rank 2196 else: 2197 return get_global_rank(pg, src_rank) 2198 else: 2199 if group is None or group is GroupMember.WORLD: 2200 pg.recv([tensor], src, tag).wait() 2201 else: 2202 group_src_rank = get_group_rank(pg, src) 2203 pg.recv([tensor], group_src_rank, tag).wait() 2204 return src 2205 2206 2207class _IllegalWork(Work): 2208 def __getattribute__(self, name): 2209 if name in [ 2210 "is_success", 2211 "exception", 2212 "wait", 2213 "source_rank", 2214 "_source_rank", 2215 "result", 2216 "synchronize", 2217 ]: 2218 raise ValueError(f"Illegal to call {name} on IllegalWork object") 2219 2220 2221class _CoalescingManager: 2222 def __init__(self) -> None: 2223 self.works: List[Work] = [] 2224 2225 def append(self, work: Work): 2226 if work: 2227 self.works.append(work) 2228 2229 def wait(self): 2230 for work in self.works: 2231 work.wait() 2232 2233 2234@contextlib.contextmanager 2235def _coalescing_manager( 2236 group: Optional[ProcessGroup] = None, 2237 device: Optional[torch.device] = None, 2238 async_ops: Optional[bool] = False, 2239): 2240 """ 2241 Context manager used to coalesce collectives or P2P operations when possible. 2242 2243 Args: 2244 group (`ProcessGroup`, optional): The process group to work on. If None, 2245 the default process group will be used. 2246 device (`torch.device`, optional): Default is None, set to a device if 2247 there isn't a `**_coalesced` implementation by the backend. 2248 async_ops (`bool`, optional): whether the coalesced ops are async ops. 2249 2250 Examples: 2251 >>> # xdoctest: +SKIP("no rank") 2252 >>> # Synchronous ops 2253 >>> with _coalescing_manager(): 2254 >>> for i in range(num_colls): 2255 >>> dist.all_reduce(tensors[i]) 2256 >>> # Asynchronous ops 2257 >>> with _coalescing_manager(async_ops=True) as cm: 2258 >>> for i in range(num_colls): 2259 >>> dist.all_reduce(tensors[i]) 2260 >>> cm.wait() 2261 2262 .. warning:: 2263 :func:`_coalescing_manager` currently do not support coalescing 2264 all-reduces with different reduce operators, e.g. `ReduceOp.SUM` mixed 2265 with `ReduceOp.PRODUCT`. 2266 """ 2267 group = group or _get_default_group() 2268 op_list = _world.pg_coalesce_state.setdefault(group, []) 2269 if op_list: 2270 raise ValueError( 2271 "ProcessGroup has non-empty op list at the start of coalescing" 2272 ) 2273 if device: 2274 group._start_coalescing(device) 2275 cm = _CoalescingManager() 2276 yield cm 2277 op_list = _world.pg_coalesce_state.pop(group) 2278 if op_list: 2279 # Collectives supporting "Fast Path" coalescing are captured. 2280 # See implementation in corresponding collective APIs. 
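        # (While the ``with`` block above was active, each such collective
        # appended a ``_CollOp`` entry to ``_world.pg_coalesce_state[group]``
        # instead of launching work; the branches below replay those entries
        # as a single coalesced call on the backend.)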
2281 # Currently supported: 2282 # - coalesced `all_reduce` 2283 # - coalesced `all_gather_into_tensor` 2284 # - coalesced `reduce_scatter_tensor` 2285 op0 = op_list[0].op 2286 if op0 == all_reduce: 2287 tensors = [] 2288 for op in op_list: 2289 tensors.append(op.tensor) 2290 all_reduce_opts = AllreduceCoalescedOptions() 2291 all_reduce_opts.reduceOp = not_none(op_list[0].redop) 2292 work = group.allreduce_coalesced(tensors, all_reduce_opts) 2293 elif op0 == all_gather_into_tensor: 2294 inputs = [] 2295 outputs = [] 2296 for op in op_list: 2297 inputs.append(op.tensor) 2298 outputs.append(not_none(op.dst_tensor)) 2299 work = group.allgather_into_tensor_coalesced(outputs, inputs) 2300 elif op0 == reduce_scatter_tensor: 2301 inputs = [] 2302 outputs = [] 2303 for op in op_list: 2304 inputs.append(op.tensor) 2305 outputs.append(not_none(op.dst_tensor)) 2306 reduce_opts = ReduceScatterOptions() 2307 reduce_opts.reduceOp = not_none(op_list[0].redop) 2308 work = group.reduce_scatter_tensor_coalesced(outputs, inputs, reduce_opts) 2309 else: 2310 raise AssertionError( 2311 f"Coalescing manager does not support fast-path coalescing of {op0}, " 2312 f"yet {op0} is still recorded in op list. This is an internal error of c10d." 2313 ) 2314 2315 if device: 2316 # Old style of letting each coll inside the context manager to call into C++ counterpart via python binding 2317 work = group._end_coalescing(device) 2318 2319 if async_ops: 2320 cm.append(work) # type: ignore[possibly-undefined] 2321 else: 2322 work.wait() # type: ignore[possibly-undefined] 2323 2324 2325def batch_isend_irecv(p2p_op_list): 2326 """ 2327 Send or Receive a batch of tensors asynchronously and return a list of requests. 2328 2329 Process each of the operations in ``p2p_op_list`` and return the corresponding 2330 requests. NCCL, Gloo, and UCC backend are currently supported. 2331 2332 Args: 2333 p2p_op_list: A list of point-to-point operations(type of each operator is 2334 ``torch.distributed.P2POp``). The order of the isend/irecv in the list 2335 matters and it needs to match with corresponding isend/irecv on the 2336 remote end. 2337 2338 Returns: 2339 A list of distributed request objects returned by calling the corresponding 2340 op in the op_list. 2341 2342 Examples: 2343 >>> # xdoctest: +SKIP("no rank") 2344 >>> send_tensor = torch.arange(2, dtype=torch.float32) + 2 * rank 2345 >>> recv_tensor = torch.randn(2, dtype=torch.float32) 2346 >>> send_op = dist.P2POp(dist.isend, send_tensor, (rank + 1)%world_size) 2347 >>> recv_op = dist.P2POp(dist.irecv, recv_tensor, (rank - 1 + world_size)%world_size) 2348 >>> reqs = batch_isend_irecv([send_op, recv_op]) 2349 >>> for req in reqs: 2350 >>> req.wait() 2351 >>> recv_tensor 2352 tensor([2, 3]) # Rank 0 2353 tensor([0, 1]) # Rank 1 2354 2355 .. note:: Note that when this API is used with the NCCL PG backend, users must set 2356 the current GPU device with `torch.cuda.set_device`, otherwise it will 2357 lead to unexpected hang issues. 2358 2359 In addition, if this API is the first collective call in the ``group`` 2360 passed to ``dist.P2POp``, all ranks of the ``group`` must participate in 2361 this API call; otherwise, the behavior is undefined. If this API call is 2362 not the first collective call in the ``group``, batched P2P operations 2363 involving only a subset of ranks of the ``group`` are allowed. 
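
        A minimal sketch of the device-binding pattern described in the note
        above, for the NCCL backend (assumes one GPU per rank and that ``rank``
        and ``world_size`` come from an already-initialized default group):

        >>> # xdoctest: +SKIP("no rank")
        >>> torch.cuda.set_device(rank)  # bind this rank to its GPU before NCCL P2P
        >>> send_tensor = torch.ones(2, device="cuda") * rank
        >>> recv_tensor = torch.zeros(2, device="cuda")
        >>> send_op = dist.P2POp(dist.isend, send_tensor, (rank + 1) % world_size)
        >>> recv_op = dist.P2POp(dist.irecv, recv_tensor, (rank - 1 + world_size) % world_size)
        >>> reqs = batch_isend_irecv([send_op, recv_op])
        >>> for req in reqs:
        >>>     req.wait()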
2364 """ 2365 _check_p2p_op_list(p2p_op_list) 2366 group = p2p_op_list[0].group 2367 device = p2p_op_list[0].tensor.device 2368 if device.type == "cuda": 2369 # NCCL style coalescing 2370 with _coalescing_manager(group, device, async_ops=True) as cm: 2371 for p2p_op in p2p_op_list: 2372 p2p_op.op(p2p_op.tensor, p2p_op.peer, p2p_op.group, p2p_op.tag) 2373 return cm.works 2374 else: 2375 # Backward support for Gloo 2376 reqs = [] 2377 for p2p_op in p2p_op_list: 2378 work = p2p_op.op(p2p_op.tensor, p2p_op.peer, p2p_op.group, p2p_op.tag) 2379 if work: 2380 reqs.append(work) 2381 return reqs 2382 2383 2384@_exception_logger 2385def broadcast(tensor, src, group=None, async_op=False): 2386 """ 2387 Broadcasts the tensor to the whole group. 2388 2389 ``tensor`` must have the same number of elements in all processes 2390 participating in the collective. 2391 2392 Args: 2393 tensor (Tensor): Data to be sent if ``src`` is the rank of current 2394 process, and tensor to be used to save received data otherwise. 2395 src (int): Source rank on global process group (regardless of ``group`` argument). 2396 group (ProcessGroup, optional): The process group to work on. If None, 2397 the default process group will be used. 2398 async_op (bool, optional): Whether this op should be an async op 2399 2400 Returns: 2401 Async work handle, if async_op is set to True. 2402 None, if not async_op or if not part of the group 2403 2404 """ 2405 _check_single_tensor(tensor, "tensor") 2406 if _rank_not_in_group(group): 2407 _warn_not_in_group("broadcast") 2408 return 2409 2410 opts = BroadcastOptions() 2411 opts.rootRank = src 2412 opts.rootTensor = 0 2413 opts.asyncOp = async_op 2414 2415 if group is None or group is GroupMember.WORLD: 2416 default_pg = _get_default_group() 2417 work = default_pg.broadcast([tensor], opts) 2418 else: 2419 group_src_rank = get_group_rank(group, src) 2420 opts.rootRank = group_src_rank 2421 work = group.broadcast([tensor], opts) 2422 if async_op: 2423 return work 2424 else: 2425 work.wait() 2426 2427 2428@_exception_logger 2429def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False): 2430 """ 2431 Reduces the tensor data across all machines in a way that all get the final result. 2432 2433 After the call ``tensor`` is going to be bitwise identical in all processes. 2434 2435 Complex tensors are supported. 2436 2437 Args: 2438 tensor (Tensor): Input and output of the collective. The function 2439 operates in-place. 2440 op (optional): One of the values from 2441 ``torch.distributed.ReduceOp`` 2442 enum. Specifies an operation used for element-wise reductions. 2443 group (ProcessGroup, optional): The process group to work on. If None, 2444 the default process group will be used. 2445 async_op (bool, optional): Whether this op should be an async op 2446 2447 Returns: 2448 Async work handle, if async_op is set to True. 2449 None, if not async_op or if not part of the group 2450 2451 Examples: 2452 >>> # xdoctest: +SKIP("no rank") 2453 >>> # All tensors below are of torch.int64 type. 2454 >>> # We have 2 process groups, 2 ranks. 2455 >>> device = torch.device(f'cuda:{rank}') 2456 >>> tensor = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank 2457 >>> tensor 2458 tensor([1, 2], device='cuda:0') # Rank 0 2459 tensor([3, 4], device='cuda:1') # Rank 1 2460 >>> dist.all_reduce(tensor, op=ReduceOp.SUM) 2461 >>> tensor 2462 tensor([4, 6], device='cuda:0') # Rank 0 2463 tensor([4, 6], device='cuda:1') # Rank 1 2464 2465 >>> # All tensors below are of torch.cfloat type. 
        >>> # We have 2 process groups, 2 ranks.
        >>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat, device=device) + 2 * rank * (1+1j)
        >>> tensor
        tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0
        tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1
        >>> dist.all_reduce(tensor, op=ReduceOp.SUM)
        >>> tensor
        tensor([4.+4.j, 6.+6.j], device='cuda:0') # Rank 0
        tensor([4.+4.j, 6.+6.j], device='cuda:1') # Rank 1

    """
    _check_single_tensor(tensor, "tensor")
    if _rank_not_in_group(group):
        _warn_not_in_group("all_reduce")
        return

    if tensor.is_complex():
        if not supports_complex(op):
            raise ValueError(f"all_reduce does not support {op} on complex tensors")
        tensor = torch.view_as_real(tensor)

    opts = AllreduceOptions()
    opts.reduceOp = op
    if group is None:
        group = _get_default_group()

    if group in _world.pg_coalesce_state.keys():
        # We are in coalescing context, do not issue single operation, just append a collective representation
        coll = _CollOp(all_reduce, tensor, None, op, None)
        _world.pg_coalesce_state[group].append(coll)
        if async_op:
            return _IllegalWork()
        else:
            return None

    work = group.allreduce([tensor], opts)

    if async_op:
        return work
    else:
        work.wait()


@_exception_logger
@deprecated(
    "`torch.distributed.all_reduce_coalesced` will be deprecated. If you must "
    "use it, please revisit our documentation later at "
    "https://pytorch.org/docs/main/distributed.html#collective-functions",
    category=FutureWarning,
)
def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False):
    """
    WARNING: at this time individual shape checking is not implemented across nodes.

    For example, if the rank 0 node passes [torch.rand(4), torch.rand(2)] and the
    rank 1 node passes [torch.rand(2), torch.rand(2), torch.rand(2)], the allreduce
    operation will proceed without complaint and return erroneous outputs. This lack
    of shape checking results in significant performance improvements but users of this
    function should take extra care to ensure that each node passes in tensors whose
    shapes match across nodes.

    Reduces each tensor in tensors (residing on the same device) across all machines
    in such a way that all get the final result.

    After the call each tensor in tensors is going to be bitwise identical
    in all processes.

    Complex tensors are supported.

    Args:
        tensors (Union[List[Tensor], Tensor]): Input and output of the collective.
            The function operates in-place.
        op (Optional[ReduceOp]): One of the values from
            ``torch.distributed.ReduceOp`` enum. Specifies an operation used for
            element-wise reductions.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (Optional[bool]): Whether this op should be an async op.

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group.
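
    An illustrative sketch (assumes an initialized default group and that every
    rank passes tensors with matching shapes, as required above):

    Example::
        >>> # xdoctest: +SKIP("need process group init")
        >>> tensors = [torch.ones(2), torch.ones(3) * 2]
        >>> dist.all_reduce_coalesced(tensors, op=ReduceOp.SUM)
        >>> # With world_size 2, tensors is now [tensor([2., 2.]), tensor([4., 4., 4.])]
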
    """
    if isinstance(tensors, torch.Tensor):
        tensors = [tensors]
    _check_tensor_list(tensors, "tensor")
    _ensure_all_tensors_same_dtype(tensors)
    if _rank_not_in_group(group):
        _warn_not_in_group("all_reduce_coalesced")
        return

    if any(t.is_complex() for t in tensors) and not supports_complex(op):
        raise ValueError(f"all_reduce does not support {op} on complex tensors")

    tensors = [t if not t.is_complex() else torch.view_as_real(t) for t in tensors]

    opts = AllreduceCoalescedOptions()
    opts.reduceOp = op
    group = group or _get_default_group()
    work = group.allreduce_coalesced(tensors, opts)

    if async_op:
        return work.get_future()
    else:
        work.wait()


@_exception_logger
def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
    """
    Reduces the tensor data across all machines.

    Only the process with rank ``dst`` is going to receive the final result.

    Args:
        tensor (Tensor): Input and output of the collective. The function
            operates in-place.
        dst (int): Destination rank on global process group (regardless of ``group`` argument)
        op (optional): One of the values from
            ``torch.distributed.ReduceOp``
            enum. Specifies an operation used for element-wise reductions.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    """
    _check_single_tensor(tensor, "tensor")
    if _rank_not_in_group(group):
        _warn_not_in_group("reduce")
        return

    opts = ReduceOptions()
    opts.reduceOp = op
    opts.rootRank = dst

    if group is None or group is GroupMember.WORLD:
        default_pg = _get_default_group()
        work = default_pg.reduce([tensor], opts)
    else:
        group_dst_rank = get_group_rank(group, dst)
        opts.rootRank = group_dst_rank
        work = group.reduce([tensor], opts)

    if async_op:
        return work
    else:
        work.wait()


def _object_to_tensor(obj, device, group):
    f = io.BytesIO()
    _pickler(f).dump(obj)
    byte_storage = torch.ByteStorage._from_buffer(f.getvalue())  # type: ignore[attr-defined]
    # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype.
    # Otherwise, it will cause 100X slowdown.
    # See: https://github.com/pytorch/pytorch/issues/65696
    byte_tensor = torch.ByteTensor(byte_storage).to(device)
    if get_debug_level() == DebugLevel.DETAIL and is_nccl_available():
        backend = get_backend(group)
        if backend == Backend.NCCL:
            hash = torch._C._distributed_c10d._hash_tensors([byte_tensor])
            logger.warning(
                "_object_to_tensor size: %s hash value: %s", byte_tensor.numel(), hash
            )
    local_size = torch.LongTensor([byte_tensor.numel()]).to(device)
    return byte_tensor, local_size


def _tensor_to_object(tensor, tensor_size, group):
    if get_debug_level() == DebugLevel.DETAIL and is_nccl_available():
        backend = get_backend(group)
        if backend == Backend.NCCL:
            hash = torch._C._distributed_c10d._hash_tensors([tensor])
            logger.warning(
                "_tensor_to_object size: %s hash value: %s", tensor.numel(), hash
            )
    tensor = tensor.cpu()
    buf = tensor.numpy().tobytes()[:tensor_size]
    return _unpickler(io.BytesIO(buf)).load()


@_exception_logger
def all_gather_object(object_list, obj, group=None):
    """
    Gathers picklable objects from the whole group into a list.

    Similar to :func:`all_gather`, but Python objects can be passed in.
    Note that the object must be picklable in order to be gathered.

    Args:
        object_list (list[Any]): Output list. It should be correctly sized as the
            size of the group for this collective and will contain the output.
        obj (Any): Picklable Python object to be broadcast from current process.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Default is ``None``.

    Returns:
        None. If the calling rank is part of this group, the output of the
        collective will be populated into the input ``object_list``. If the
        calling rank is not part of the group, the passed in ``object_list`` will
        be unmodified.

    .. note:: Note that this API differs slightly from the :func:`all_gather`
        collective since it does not provide an ``async_op`` handle and thus
        will be a blocking call.

    .. note:: For NCCL-based process groups, internal tensor representations
        of objects must be moved to the GPU device before communication takes
        place. In this case, the device used is given by
        ``torch.cuda.current_device()`` and it is the user's responsibility to
        ensure that this is set so that each rank has an individual GPU, via
        ``torch.cuda.set_device()``.

    .. warning::
        :func:`all_gather_object` uses ``pickle`` module implicitly, which is
        known to be insecure. It is possible to construct malicious pickle data
        which will execute arbitrary code during unpickling. Only call this
        function with data you trust.

    .. warning::
        Calling :func:`all_gather_object` with GPU tensors is not well supported
        and inefficient as it incurs GPU -> CPU transfer since tensors would be
        pickled. Please consider using :func:`all_gather` instead.

    Example::
        >>> # xdoctest: +SKIP("need process group init")
        >>> # Note: Process group initialization omitted on each rank.
        >>> import torch.distributed as dist
        >>> # Assumes world_size of 3.
        >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
        >>> output = [None for _ in gather_objects]
        >>> dist.all_gather_object(output, gather_objects[dist.get_rank()])
        >>> output
        ['foo', 12, {1: 2}]
    """
    if _rank_not_in_group(group):
        _warn_not_in_group("all_gather_object")
        return

    current_device = _get_pg_default_device(group)
    input_tensor, local_size = _object_to_tensor(obj, current_device, group)

    # Gather all local sizes. This is so that we can find the max size, and index
    # until the correct size when deserializing the tensors.
    group_size = get_world_size(group=group)
    object_sizes_tensor = torch.zeros(
        group_size, dtype=torch.long, device=current_device
    )
    object_size_list = [
        object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size)
    ]
    # Allgather tensor sizes
    all_gather(object_size_list, local_size, group=group)
    max_object_size = int(max(object_size_list).item())  # type: ignore[type-var]
    # Resize tensor to max size across all ranks.
    input_tensor.resize_(max_object_size)
    coalesced_output_tensor = torch.empty(
        max_object_size * group_size, dtype=torch.uint8, device=current_device
    )
    # Output tensors are nonoverlapping views of coalesced_output_tensor
    output_tensors = [
        coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)]
        for i in range(group_size)
    ]
    all_gather(output_tensors, input_tensor, group=group)
    # Deserialize outputs back to object.
    for i, tensor in enumerate(output_tensors):
        tensor = tensor.type(torch.uint8)
        tensor_size = object_size_list[i]
        object_list[i] = _tensor_to_object(tensor, tensor_size, group)


@_exception_logger
def gather_object(obj, object_gather_list=None, dst=0, group=None):
    """
    Gathers picklable objects from the whole group in a single process.

    Similar to :func:`gather`, but Python objects can be passed in. Note that the
    object must be picklable in order to be gathered.

    Args:
        obj (Any): Input object. Must be picklable.
        object_gather_list (list[Any]): Output list. On the ``dst`` rank, it
            should be correctly sized as the size of the group for this
            collective and will contain the output. Must be ``None`` on non-dst
            ranks. (default is ``None``)
        dst (int, optional): Destination rank on global process group (regardless of ``group`` argument). (default is 0)
        group: (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Default is ``None``.

    Returns:
        None. On the ``dst`` rank, ``object_gather_list`` will contain the
        output of the collective.

    .. note:: Note that this API differs slightly from the gather collective
        since it does not provide an async_op handle and thus will be a blocking
        call.

    .. note:: For NCCL-based process groups, internal tensor representations
        of objects must be moved to the GPU device before communication takes
        place. In this case, the device used is given by
        ``torch.cuda.current_device()`` and it is the user's responsibility to
        ensure that this is set so that each rank has an individual GPU, via
        ``torch.cuda.set_device()``.

    .. warning::
        :func:`gather_object` uses ``pickle`` module implicitly, which is
        known to be insecure.
It is possible to construct malicious pickle data 2779 which will execute arbitrary code during unpickling. Only call this 2780 function with data you trust. 2781 2782 .. warning:: 2783 Calling :func:`gather_object` with GPU tensors is not well supported 2784 and inefficient as it incurs GPU -> CPU transfer since tensors would be 2785 pickled. Please consider using :func:`gather` instead. 2786 2787 Example:: 2788 >>> # xdoctest: +SKIP("need process group init") 2789 >>> # Note: Process group initialization omitted on each rank. 2790 >>> import torch.distributed as dist 2791 >>> # Assumes world_size of 3. 2792 >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object 2793 >>> output = [None for _ in gather_objects] 2794 >>> dist.gather_object( 2795 ... gather_objects[dist.get_rank()], 2796 ... output if dist.get_rank() == 0 else None, 2797 ... dst=0 2798 ... ) 2799 >>> # On rank 0 2800 >>> output 2801 ['foo', 12, {1: 2}] 2802 """ 2803 if _rank_not_in_group(group): 2804 _warn_not_in_group("gather_object") 2805 return 2806 2807 # Ensure object_gather_list is specified appropriately. 2808 my_rank = get_rank() 2809 _validate_output_list_for_rank(my_rank, dst, object_gather_list) 2810 current_device = _get_pg_default_device(group) 2811 input_tensor, local_size = _object_to_tensor(obj, current_device, group) 2812 2813 # Gather all local sizes. This is so that we can find the max size, and index 2814 # until the correct size when deserializing the tensors. 2815 group_size = get_world_size(group=group) 2816 object_sizes_tensor = torch.zeros( 2817 group_size, dtype=torch.long, device=current_device 2818 ) 2819 object_size_list = [ 2820 object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) 2821 ] 2822 # Allgather tensor sizes. An all-gather is needed here despite this being a 2823 # gather, since each rank needs to broadcast a tensor of the same (maximal) 2824 # size. 2825 all_gather(object_size_list, local_size, group=group) 2826 max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] 2827 # Resize tensor to max size across all ranks. 2828 input_tensor.resize_(max_object_size) 2829 # Avoid populating output tensors if the result won't be gathered on this rank. 2830 if my_rank == dst: 2831 coalesced_output_tensor = torch.empty( 2832 max_object_size * group_size, dtype=torch.uint8, device=current_device 2833 ) 2834 # Output tensors are nonoverlapping views of coalesced_output_tensor 2835 output_tensors = [ 2836 coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] 2837 for i in range(group_size) 2838 ] 2839 # All ranks call gather with equal-sized tensors. 2840 gather( 2841 input_tensor, 2842 gather_list=output_tensors if my_rank == dst else None, # type: ignore[possibly-undefined] 2843 dst=dst, 2844 group=group, 2845 ) 2846 if my_rank != dst: 2847 return 2848 for i, tensor in enumerate(output_tensors): 2849 tensor = tensor.type(torch.uint8) 2850 tensor_size = object_size_list[i] 2851 object_gather_list[i] = _tensor_to_object(tensor, tensor_size, group) 2852 2853 2854@_exception_logger 2855def send_object_list(object_list, dst, group=None, device=None): 2856 """ 2857 Sends picklable objects in ``object_list`` synchronously. 2858 2859 Similar to :func:`send`, but Python objects can be passed in. 2860 Note that all objects in ``object_list`` must be picklable in order to be 2861 sent. 2862 2863 Args: 2864 object_list (List[Any]): List of input objects to sent. 2865 Each object must be picklable. Receiver must provide lists of equal sizes. 
2866 dst (int): Destination rank to send ``object_list`` to. 2867 Destination rank is based on global process group (regardless of ``group`` argument) 2868 group: (ProcessGroup, optional): The process group to work on. If None, 2869 the default process group will be used. Default is ``None``. 2870 device (``torch.device``, optional): If not None, the objects are 2871 serialized and converted to tensors which are moved to the 2872 ``device`` before sending. Default is ``None``. 2873 2874 Returns: 2875 ``None``. 2876 2877 .. note:: For NCCL-based process groups, internal tensor representations 2878 of objects must be moved to the GPU device before communication takes 2879 place. In this case, the device used is given by 2880 ``torch.cuda.current_device()`` and it is the user's responsibility to 2881 ensure that this is set so that each rank has an individual GPU, via 2882 ``torch.cuda.set_device()``. 2883 2884 .. warning:: 2885 :func:`send_object_list` uses ``pickle`` module implicitly, which 2886 is known to be insecure. It is possible to construct malicious pickle 2887 data which will execute arbitrary code during unpickling. Only call this 2888 function with data you trust. 2889 2890 .. warning:: 2891 Calling :func:`send_object_list` with GPU tensors is not well supported 2892 and inefficient as it incurs GPU -> CPU transfer since tensors would be 2893 pickled. Please consider using :func:`send` instead. 2894 2895 Example:: 2896 >>> # xdoctest: +SKIP("need process group init") 2897 >>> # Note: Process group initialization omitted on each rank. 2898 >>> import torch.distributed as dist 2899 >>> # Assumes backend is not NCCL 2900 >>> device = torch.device("cpu") 2901 >>> if dist.get_rank() == 0: 2902 >>> # Assumes world_size of 2. 2903 >>> objects = ["foo", 12, {1: 2}] # any picklable object 2904 >>> dist.send_object_list(objects, dst=1, device=device) 2905 >>> else: 2906 >>> objects = [None, None, None] 2907 >>> dist.recv_object_list(objects, src=0, device=device) 2908 >>> objects 2909 ['foo', 12, {1: 2}] 2910 """ 2911 if get_rank() == dst: 2912 raise ValueError( 2913 "Invalid destination rank: destination rank should not be the same as " 2914 "the rank of the current process." 2915 ) 2916 2917 if _rank_not_in_group(group): 2918 _warn_not_in_group("send_object_list") 2919 return 2920 2921 # Current device selection. 2922 # To preserve backwards compatibility, ``device`` is default to ``None`` 2923 # in which case we run current logic of device selection, i.e. 2924 # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the 2925 # case it is not ``None`` we move the size and object tensors to be 2926 # sent to this device. 2927 current_device = device or _get_pg_default_device(group) 2928 # Serialize object_list elements to tensors on src rank. 2929 tensor_list, size_list = zip( 2930 *[_object_to_tensor(obj, current_device, group) for obj in object_list] 2931 ) 2932 object_sizes_tensor = torch.cat(size_list) 2933 2934 # Send object sizes 2935 send(object_sizes_tensor, dst=dst, group=group) 2936 2937 # Concatenate and send serialized object tensors 2938 # Note: torch.cat will do an extra memory copy to the current device, if the tensor_list 2939 # has only one element, we can skip the copy. 
2940 if len(tensor_list) == 1: # type: ignore[possibly-undefined] 2941 object_tensor = tensor_list[0] 2942 else: 2943 object_tensor = torch.cat(tensor_list) 2944 2945 send(object_tensor, dst=dst, group=group) 2946 2947 2948@_exception_logger 2949def recv_object_list(object_list, src=None, group=None, device=None): 2950 """ 2951 Receives picklable objects in ``object_list`` synchronously. 2952 2953 Similar to :func:`recv`, but can receive Python objects. 2954 2955 Args: 2956 object_list (List[Any]): List of objects to receive into. 2957 Must provide a list of sizes equal to the size of the list being sent. 2958 src (int, optional): Source rank from which to recv ``object_list``. 2959 Source rank is based on global process group (regardless of ``group`` argument) 2960 Will receive from any rank if set to None. Default is ``None``. 2961 group: (ProcessGroup, optional): The process group to work on. If None, 2962 the default process group will be used. Default is ``None``. 2963 device (``torch.device``, optional): If not None, receives on this device. 2964 Default is ``None``. 2965 2966 Returns: 2967 Sender rank. -1 if rank is not part of the group. If rank is part of the group, 2968 ``object_list`` will contain the sent objects from ``src`` rank. 2969 2970 .. note:: For NCCL-based process groups, internal tensor representations 2971 of objects must be moved to the GPU device before communication takes 2972 place. In this case, the device used is given by 2973 ``torch.cuda.current_device()`` and it is the user's responsibility to 2974 ensure that this is set so that each rank has an individual GPU, via 2975 ``torch.cuda.set_device()``. 2976 2977 .. warning:: 2978 :func:`recv_object_list` uses ``pickle`` module implicitly, which 2979 is known to be insecure. It is possible to construct malicious pickle 2980 data which will execute arbitrary code during unpickling. Only call this 2981 function with data you trust. 2982 2983 .. warning:: 2984 Calling :func:`recv_object_list` with GPU tensors is not well supported 2985 and inefficient as it incurs GPU -> CPU transfer since tensors would be 2986 pickled. Please consider using :func:`recv` instead. 2987 2988 Example:: 2989 >>> # xdoctest: +SKIP("need process group init") 2990 >>> # Note: Process group initialization omitted on each rank. 2991 >>> import torch.distributed as dist 2992 >>> # Assumes backend is not NCCL 2993 >>> device = torch.device("cpu") 2994 >>> if dist.get_rank() == 0: 2995 >>> # Assumes world_size of 2. 2996 >>> objects = ["foo", 12, {1: 2}] # any picklable object 2997 >>> dist.send_object_list(objects, dst=1, device=device) 2998 >>> else: 2999 >>> objects = [None, None, None] 3000 >>> dist.recv_object_list(objects, src=0, device=device) 3001 >>> objects 3002 ['foo', 12, {1: 2}] 3003 """ 3004 if _rank_not_in_group(group): 3005 _warn_not_in_group("recv_object_list") 3006 return -1 3007 3008 # Current device selection. 3009 # To preserve backwards compatibility, ``device`` is default to ``None`` 3010 # in which case we run current logic of device selection, i.e. 3011 # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the 3012 # case it is not ``None`` we move the size and object tensors to be 3013 # received to this device. 
3014 current_device = device or _get_pg_default_device(group) 3015 object_sizes_tensor = torch.empty( 3016 len(object_list), dtype=torch.long, device=current_device 3017 ) 3018 3019 # Receive object sizes 3020 rank_sizes = recv(object_sizes_tensor, src=src, group=group) 3021 3022 # Tensor to receive serialized objects into. 3023 object_tensor = torch.empty( # type: ignore[call-overload] 3024 torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type] 3025 dtype=torch.uint8, 3026 device=current_device, 3027 ) 3028 3029 rank_objects = recv(object_tensor, src=src, group=group) 3030 assert ( 3031 rank_sizes == rank_objects 3032 ), "Mismatch in return ranks for object sizes and objects." 3033 # Deserialize objects using their stored sizes. 3034 offset = 0 3035 for i, obj_size in enumerate(object_sizes_tensor): 3036 obj_view = object_tensor[offset : offset + obj_size] 3037 obj_view = obj_view.type(torch.uint8) 3038 offset += obj_size 3039 object_list[i] = _tensor_to_object(obj_view, obj_size, group) 3040 return rank_objects 3041 3042 3043@_exception_logger 3044def broadcast_object_list(object_list, src=0, group=None, device=None): 3045 """ 3046 Broadcasts picklable objects in ``object_list`` to the whole group. 3047 3048 Similar to :func:`broadcast`, but Python objects can be passed in. 3049 Note that all objects in ``object_list`` must be picklable in order to be 3050 broadcasted. 3051 3052 Args: 3053 object_list (List[Any]): List of input objects to broadcast. 3054 Each object must be picklable. Only objects on the ``src`` rank will 3055 be broadcast, but each rank must provide lists of equal sizes. 3056 src (int): Source rank from which to broadcast ``object_list``. 3057 Source rank is based on global process group (regardless of ``group`` argument) 3058 group: (ProcessGroup, optional): The process group to work on. If None, 3059 the default process group will be used. Default is ``None``. 3060 device (``torch.device``, optional): If not None, the objects are 3061 serialized and converted to tensors which are moved to the 3062 ``device`` before broadcasting. Default is ``None``. 3063 3064 Returns: 3065 ``None``. If rank is part of the group, ``object_list`` will contain the 3066 broadcasted objects from ``src`` rank. 3067 3068 .. note:: For NCCL-based process groups, internal tensor representations 3069 of objects must be moved to the GPU device before communication takes 3070 place. In this case, the device used is given by 3071 ``torch.cuda.current_device()`` and it is the user's responsibility to 3072 ensure that this is set so that each rank has an individual GPU, via 3073 ``torch.cuda.set_device()``. 3074 3075 .. note:: Note that this API differs slightly from the :func:`broadcast` 3076 collective since it does not provide an ``async_op`` handle and thus 3077 will be a blocking call. 3078 3079 .. warning:: 3080 :func:`broadcast_object_list` uses ``pickle`` module implicitly, which 3081 is known to be insecure. It is possible to construct malicious pickle 3082 data which will execute arbitrary code during unpickling. Only call this 3083 function with data you trust. 3084 3085 .. warning:: 3086 Calling :func:`broadcast_object_list` with GPU tensors is not well supported 3087 and inefficient as it incurs GPU -> CPU transfer since tensors would be 3088 pickled. Please consider using :func:`broadcast` instead. 3089 3090 Example:: 3091 >>> # xdoctest: +SKIP("need process group init") 3092 >>> # Note: Process group initialization omitted on each rank. 
3093 >>> import torch.distributed as dist 3094 >>> if dist.get_rank() == 0: 3095 >>> # Assumes world_size of 3. 3096 >>> objects = ["foo", 12, {1: 2}] # any picklable object 3097 >>> else: 3098 >>> objects = [None, None, None] 3099 >>> # Assumes backend is not NCCL 3100 >>> device = torch.device("cpu") 3101 >>> dist.broadcast_object_list(objects, src=0, device=device) 3102 >>> objects 3103 ['foo', 12, {1: 2}] 3104 """ 3105 if _rank_not_in_group(group): 3106 _warn_not_in_group("broadcast_object_list") 3107 return 3108 3109 # Current device selection. 3110 # To preserve backwards compatibility, ``device`` is default to ``None`` 3111 # in which case we run current logic of device selection, i.e. 3112 # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the 3113 # case it is not ``None`` we move the size and object tensors to be 3114 # broadcasted to this device. 3115 current_device = device or _get_pg_default_device(group) 3116 my_rank = get_rank() 3117 # Serialize object_list elements to tensors on src rank. 3118 if my_rank == src: 3119 tensor_list, size_list = zip( 3120 *[_object_to_tensor(obj, current_device, group) for obj in object_list] 3121 ) 3122 object_sizes_tensor = torch.cat(size_list) 3123 else: 3124 object_sizes_tensor = torch.empty( 3125 len(object_list), dtype=torch.long, device=current_device 3126 ) 3127 3128 # Broadcast object sizes 3129 broadcast(object_sizes_tensor, src=src, group=group) 3130 3131 # Concatenate and broadcast serialized object tensors 3132 # Note: torch.cat will do an extra memory copy to the current device, if the tensor_list 3133 # has only one element, we can skip the copy. 3134 if my_rank == src: 3135 if len(tensor_list) == 1: # type: ignore[possibly-undefined] 3136 object_tensor = tensor_list[0] 3137 else: 3138 object_tensor = torch.cat(tensor_list) 3139 else: 3140 object_tensor = torch.empty( # type: ignore[call-overload] 3141 torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type] 3142 dtype=torch.uint8, 3143 device=current_device, 3144 ) 3145 3146 broadcast(object_tensor, src=src, group=group) 3147 # Deserialize objects using their stored sizes. 3148 offset = 0 3149 if my_rank != src: 3150 for i, obj_size in enumerate(object_sizes_tensor): 3151 obj_view = object_tensor[offset : offset + obj_size] 3152 obj_view = obj_view.type(torch.uint8) 3153 offset += obj_size 3154 object_list[i] = _tensor_to_object(obj_view, obj_size, group) 3155 3156 3157@_exception_logger 3158def scatter_object_list( 3159 scatter_object_output_list, scatter_object_input_list, src=0, group=None 3160): 3161 """ 3162 Scatters picklable objects in ``scatter_object_input_list`` to the whole group. 3163 3164 Similar to :func:`scatter`, but Python objects can be passed in. On 3165 each rank, the scattered object will be stored as the first element of 3166 ``scatter_object_output_list``. Note that all objects in 3167 ``scatter_object_input_list`` must be picklable in order to be scattered. 3168 3169 Args: 3170 scatter_object_output_list (List[Any]): Non-empty list whose first 3171 element will store the object scattered to this rank. 3172 scatter_object_input_list (List[Any]): List of input objects to scatter. 3173 Each object must be picklable. Only objects on the ``src`` rank will 3174 be scattered, and the argument can be ``None`` for non-src ranks. 3175 src (int): Source rank from which to scatter ``scatter_object_input_list``. 3176 Source rank is based on global process group (regardless of ``group`` argument). 
3177 group: (ProcessGroup, optional): The process group to work on. If None, 3178 the default process group will be used. Default is ``None``. 3179 3180 Returns: 3181 ``None``. If rank is part of the group, ``scatter_object_output_list`` 3182 will have its first element set to the scattered object for this rank. 3183 3184 .. note:: Note that this API differs slightly from the scatter collective 3185 since it does not provide an ``async_op`` handle and thus will be a 3186 blocking call. 3187 3188 .. warning:: 3189 :func:`scatter_object_list` uses ``pickle`` module implicitly, which 3190 is known to be insecure. It is possible to construct malicious pickle 3191 data which will execute arbitrary code during unpickling. Only call this 3192 function with data you trust. 3193 3194 .. warning:: 3195 Calling :func:`scatter_object_list` with GPU tensors is not well supported 3196 and inefficient as it incurs GPU -> CPU transfer since tensors would be 3197 pickled. Please consider using :func:`scatter` instead. 3198 3199 Example:: 3200 >>> # xdoctest: +SKIP("need process group init") 3201 >>> # Note: Process group initialization omitted on each rank. 3202 >>> import torch.distributed as dist 3203 >>> if dist.get_rank() == 0: 3204 >>> # Assumes world_size of 3. 3205 >>> objects = ["foo", 12, {1: 2}] # any picklable object 3206 >>> else: 3207 >>> # Can be any list on non-src ranks, elements are not used. 3208 >>> objects = [None, None, None] 3209 >>> output_list = [None] 3210 >>> dist.scatter_object_list(output_list, objects, src=0) 3211 >>> # Rank i gets objects[i]. For example, on rank 2: 3212 >>> output_list 3213 [{1: 2}] 3214 """ 3215 if _rank_not_in_group(group): 3216 _warn_not_in_group("scatter_object_list") 3217 return 3218 3219 if ( 3220 not isinstance(scatter_object_output_list, list) 3221 or len(scatter_object_output_list) < 1 3222 ): 3223 raise ValueError( 3224 "Expected argument scatter_object_output_list to be a list of size at least 1." 3225 ) 3226 3227 my_rank = get_rank() 3228 pg_device = _get_pg_default_device(group) 3229 if my_rank == src: 3230 tensor_list, tensor_sizes = zip( 3231 *[ 3232 _object_to_tensor(obj, pg_device, group) 3233 for obj in scatter_object_input_list 3234 ] 3235 ) 3236 tensor_list, tensor_sizes = list(tensor_list), list(tensor_sizes) 3237 3238 # Src rank broadcasts the maximum tensor size. This is because all ranks are 3239 # expected to call into scatter() with equal-sized tensors. 
3240 if my_rank == src: 3241 max_tensor_size = max(tensor_sizes) # type: ignore[possibly-undefined] 3242 for tensor in tensor_list: # type: ignore[possibly-undefined] 3243 tensor.resize_(max_tensor_size) 3244 else: 3245 max_tensor_size = torch.tensor([0], dtype=torch.long, device=pg_device) 3246 broadcast(max_tensor_size, src=src, group=group) 3247 3248 # Scatter actual serialized objects 3249 output_tensor = torch.empty( 3250 max_tensor_size.item(), dtype=torch.uint8, device=pg_device 3251 ) 3252 scatter( 3253 output_tensor, 3254 scatter_list=None if my_rank != src else tensor_list, # type: ignore[possibly-undefined] 3255 src=src, 3256 group=group, 3257 ) 3258 3259 # Scatter per-object sizes to trim tensors when deserializing back to object 3260 obj_tensor_size = torch.tensor([0], dtype=torch.long, device=pg_device) 3261 scatter( 3262 obj_tensor_size, 3263 scatter_list=None if my_rank != src else tensor_sizes, # type: ignore[possibly-undefined] 3264 src=src, 3265 group=group, 3266 ) 3267 3268 # Deserialize back to object 3269 scatter_object_output_list[0] = _tensor_to_object( 3270 output_tensor, obj_tensor_size, group 3271 ) 3272 3273 3274@_exception_logger 3275def all_gather(tensor_list, tensor, group=None, async_op=False): 3276 """ 3277 Gathers tensors from the whole group in a list. 3278 3279 Complex and uneven sized tensors are supported. 3280 3281 Args: 3282 tensor_list (list[Tensor]): Output list. It should contain 3283 correctly-sized tensors to be used for output of the collective. 3284 Uneven sized tensors are supported. 3285 tensor (Tensor): Tensor to be broadcast from current process. 3286 group (ProcessGroup, optional): The process group to work on. If None, 3287 the default process group will be used. 3288 async_op (bool, optional): Whether this op should be an async op 3289 3290 Returns: 3291 Async work handle, if async_op is set to True. 3292 None, if not async_op or if not part of the group 3293 3294 Examples: 3295 >>> # xdoctest: +SKIP("need process group init") 3296 >>> # All tensors below are of torch.int64 dtype. 3297 >>> # We have 2 process groups, 2 ranks. 3298 >>> device = torch.device(f'cuda:{rank}') 3299 >>> tensor_list = [torch.zeros(2, dtype=torch.int64, device=device) for _ in range(2)] 3300 >>> tensor_list 3301 [tensor([0, 0], device='cuda:0'), tensor([0, 0], device='cuda:0')] # Rank 0 3302 [tensor([0, 0], device='cuda:1'), tensor([0, 0], device='cuda:1')] # Rank 1 3303 >>> tensor = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank 3304 >>> tensor 3305 tensor([1, 2], device='cuda:0') # Rank 0 3306 tensor([3, 4], device='cuda:1') # Rank 1 3307 >>> dist.all_gather(tensor_list, tensor) 3308 >>> tensor_list 3309 [tensor([1, 2], device='cuda:0'), tensor([3, 4], device='cuda:0')] # Rank 0 3310 [tensor([1, 2], device='cuda:1'), tensor([3, 4], device='cuda:1')] # Rank 1 3311 3312 >>> # All tensors below are of torch.cfloat dtype. 3313 >>> # We have 2 process groups, 2 ranks. 
3314 >>> tensor_list = [torch.zeros(2, dtype=torch.cfloat, device=device) for _ in range(2)] 3315 >>> tensor_list 3316 [tensor([0.+0.j, 0.+0.j], device='cuda:0'), tensor([0.+0.j, 0.+0.j], device='cuda:0')] # Rank 0 3317 [tensor([0.+0.j, 0.+0.j], device='cuda:1'), tensor([0.+0.j, 0.+0.j], device='cuda:1')] # Rank 1 3318 >>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat, device=device) + 2 * rank * (1+1j) 3319 >>> tensor 3320 tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0 3321 tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1 3322 >>> dist.all_gather(tensor_list, tensor) 3323 >>> tensor_list 3324 [tensor([1.+1.j, 2.+2.j], device='cuda:0'), tensor([3.+3.j, 4.+4.j], device='cuda:0')] # Rank 0 3325 [tensor([1.+1.j, 2.+2.j], device='cuda:1'), tensor([3.+3.j, 4.+4.j], device='cuda:1')] # Rank 1 3326 3327 """ 3328 _check_tensor_list(tensor_list, "tensor_list") 3329 _check_single_tensor(tensor, "tensor") 3330 _ensure_all_tensors_same_dtype(tensor_list, tensor) 3331 if _rank_not_in_group(group): 3332 _warn_not_in_group("all_gather") 3333 return 3334 3335 tensor_list = [ 3336 t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list 3337 ] 3338 tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor) 3339 3340 group = group or _get_default_group() 3341 work = group.allgather([tensor_list], [tensor]) 3342 3343 if async_op: 3344 return work 3345 else: 3346 work.wait() 3347 3348 3349@_exception_logger 3350def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=False): 3351 """ 3352 Gather tensors from all ranks and put them in a single output tensor. 3353 3354 This function requires all tensors to be the same size on each process. 3355 3356 Args: 3357 output_tensor (Tensor): Output tensor to accommodate tensor elements 3358 from all ranks. It must be correctly sized to have one of the 3359 following forms: 3360 (i) a concatenation of all the input tensors along the primary 3361 dimension; for definition of "concatenation", see ``torch.cat()``; 3362 (ii) a stack of all the input tensors along the primary dimension; 3363 for definition of "stack", see ``torch.stack()``. 3364 Examples below may better explain the supported output forms. 3365 input_tensor (Tensor): Tensor to be gathered from current rank. 3366 Different from the ``all_gather`` API, the input tensors in this 3367 API must have the same size across all ranks. 3368 group (ProcessGroup, optional): The process group to work on. If None, 3369 the default process group will be used. 3370 async_op (bool, optional): Whether this op should be an async op 3371 3372 Returns: 3373 Async work handle, if async_op is set to True. 3374 None, if not async_op or if not part of the group 3375 3376 Examples: 3377 >>> # xdoctest: +SKIP("need process group init") 3378 >>> # All tensors below are of torch.int64 dtype and on CUDA devices. 3379 >>> # We have two ranks. 
3380 >>> device = torch.device(f'cuda:{rank}') 3381 >>> tensor_in = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank 3382 >>> tensor_in 3383 tensor([1, 2], device='cuda:0') # Rank 0 3384 tensor([3, 4], device='cuda:1') # Rank 1 3385 >>> # Output in concatenation form 3386 >>> tensor_out = torch.zeros(world_size * 2, dtype=torch.int64, device=device) 3387 >>> dist.all_gather_into_tensor(tensor_out, tensor_in) 3388 >>> tensor_out 3389 tensor([1, 2, 3, 4], device='cuda:0') # Rank 0 3390 tensor([1, 2, 3, 4], device='cuda:1') # Rank 1 3391 >>> # Output in stack form 3392 >>> tensor_out2 = torch.zeros(world_size, 2, dtype=torch.int64, device=device) 3393 >>> dist.all_gather_into_tensor(tensor_out2, tensor_in) 3394 >>> tensor_out2 3395 tensor([[1, 2], 3396 [3, 4]], device='cuda:0') # Rank 0 3397 tensor([[1, 2], 3398 [3, 4]], device='cuda:1') # Rank 1 3399 3400 .. warning:: 3401 The Gloo backend does not support this API. 3402 3403 """ 3404 _check_single_tensor(input_tensor, "input_tensor") 3405 _check_single_tensor(output_tensor, "output_tensor") 3406 if _rank_not_in_group(group): 3407 _warn_not_in_group("all_gather_into_tensor") 3408 return 3409 3410 output_tensor = ( 3411 output_tensor 3412 if not output_tensor.is_complex() 3413 else torch.view_as_real(output_tensor) 3414 ) 3415 input_tensor = ( 3416 input_tensor 3417 if not input_tensor.is_complex() 3418 else torch.view_as_real(input_tensor) 3419 ) 3420 3421 opts = AllgatherOptions() 3422 opts.asyncOp = async_op 3423 3424 group = group or _get_default_group() 3425 3426 if group in _world.pg_coalesce_state.keys(): 3427 # We are in coalescing context, do not issue single operation, just append a collective representation 3428 coll = _CollOp(all_gather_into_tensor, input_tensor, output_tensor) 3429 _world.pg_coalesce_state[group].append(coll) 3430 if async_op: 3431 return _IllegalWork() 3432 else: 3433 return None 3434 3435 work = group._allgather_base(output_tensor, input_tensor, opts) 3436 3437 if async_op: 3438 return work 3439 else: 3440 work.wait() 3441 3442 3443@_exception_logger 3444@deprecated( 3445 "`torch.distributed._all_gather_base` is a private function and will be deprecated. " 3446 "Please use `torch.distributed.all_gather_into_tensor` instead.", 3447 category=FutureWarning, 3448) 3449def _all_gather_base(output_tensor, input_tensor, group=None, async_op=False): 3450 """ 3451 Single tensor all gather. Gathers a single tensor from all ranks, and puts them in a single output tensor. 3452 3453 Args: 3454 output_tensor (Tensor): Output tensor. It should contain 3455 correctly-sized tensors to be used for output of the collective. 3456 input_tensor (Tensor): Tensor to be broadcast from current process. 3457 group (ProcessGroup, optional): The process group to work on. If None, 3458 the default process group will be used. 3459 async_op (bool, optional): Whether this op should be an async op 3460 3461 Returns: 3462 Async work handle, if async_op is set to True. 3463 None, if not async_op or if not part of the group 3464 3465 .. warning:: 3466 `_all_gather_base` is a private function. Users should use 3467 `all_gather_into_tensor` instead. 3468 3469 """ 3470 return all_gather_into_tensor(output_tensor, input_tensor, group, async_op) 3471 3472 3473@_exception_logger 3474@deprecated( 3475 "`torch.distributed.all_gather_coalesced` will be deprecated. 
If you must use it, " 3476 "please revisit our documentation later at " 3477 "https://pytorch.org/docs/main/distributed.html#collective-functions", 3478 category=FutureWarning, 3479) 3480def all_gather_coalesced( 3481 output_tensor_lists, input_tensor_list, group=None, async_op=False 3482): 3483 """ 3484 Gathers input tensors from the whole group in a list in a coalesced manner. 3485 3486 Complex tensors are supported. 3487 3488 Args: 3489 output_tensor_lists (list[list[Tensor]]): Output list. It should contain 3490 correctly-sized tensors to be used for output of the collective. 3491 input_tensor_list (list[Tensor]): Tensors to be broadcast from 3492 current process. At least one tensor has to be non empty. 3493 group (ProcessGroup, optional): The process group to work on. If None, 3494 the default process group will be used. 3495 async_op (bool, optional): Whether this op should be an async op. 3496 3497 Returns: 3498 Async work handle, if async_op is set to True. 3499 None, if not async_op or if not part of the group 3500 3501 Example: 3502 we have 2 process groups, 2 ranks. 3503 rank 0 passes: 3504 input_tensor_list = [[[1, 1], [1, 1]], [2], [3, 3]] 3505 output_tensor_lists = 3506 [[[[-1, -1], [-1, -1]], [-1], [-1, -1]], 3507 [[[-1, -1], [-1, -1]], [-1], [-1, -1]]] 3508 rank 1 passes: 3509 input_tensor_list = [[[3, 3], [3, 3]], [5], [1, 1]] 3510 output_tensor_lists = 3511 [[[[-1, -1], [-1, -1]], [-1], [-1, -1]], 3512 [[[-1, -1], [-1, -1]], [-1], [-1, -1]]] 3513 both rank 0 and 1 get: 3514 output_tensor_lists = 3515 [[[1, 1], [1, 1]], [2], [3, 3]], 3516 [[3, 3], [3, 3]], [5], [1, 1]]]. 3517 3518 WARNING: at this time individual shape checking is not implemented across nodes. 3519 For example, if the rank 0 node passes [torch.rand(4), torch.rand(2)] and the 3520 rank 1 node passes [torch.rand(2), torch.rand(2), torch.rand(2)], the 3521 all_gather_coalesced operation will proceed without complaint and return 3522 erroneous outputs. This lack of shape checking results in significant 3523 performance improvements but users of this function should take extra care 3524 to ensure that each node passes in tensors whose shapes match across nodes. 3525 """ 3526 # We only check basic compatibility with C++ params here, C++ code will 3527 # do shape and type checking. 3528 if _rank_not_in_group(group): 3529 _warn_not_in_group("all_gather_coalesced") 3530 return 3531 _check_tensor_list(input_tensor_list, "input_tensor_list") 3532 _ensure_all_tensors_same_dtype(input_tensor_list) 3533 if not isinstance(output_tensor_lists, list): 3534 raise TypeError( 3535 "Invalid function argument: output_tensor_lists should be a list" 3536 ) 3537 for output_tensor_list in output_tensor_lists: 3538 _check_tensor_list(output_tensor_list, "output_tensor_lists") 3539 _ensure_all_tensors_same_dtype(output_tensor_list) 3540 3541 output_tensor_lists = [ 3542 [t if not t.is_complex() else torch.view_as_real(t) for t in l] 3543 for l in output_tensor_lists 3544 ] 3545 input_tensor_list = [ 3546 t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list 3547 ] 3548 3549 group = group or _get_default_group() 3550 work = group.allgather_coalesced(output_tensor_lists, input_tensor_list) 3551 3552 if async_op: 3553 return work.get_future() 3554 else: 3555 work.wait() 3556 3557 3558def _validate_output_list_for_rank(my_rank, dst, gather_list): 3559 if dst == my_rank: 3560 if not gather_list: 3561 raise ValueError( 3562 "Argument ``gather_list`` must be specified on destination rank." 
3563 ) 3564 elif gather_list: 3565 raise ValueError( 3566 "Argument ``gather_list`` must NOT be specified " 3567 "on non-destination ranks." 3568 ) 3569 3570 3571@_exception_logger 3572def gather(tensor, gather_list=None, dst=0, group=None, async_op=False): 3573 """ 3574 Gathers a list of tensors in a single process. 3575 3576 This function requires all tensors to be the same size on each process. 3577 3578 Args: 3579 tensor (Tensor): Input tensor. 3580 gather_list (list[Tensor], optional): List of appropriately, 3581 same-sized tensors to use for gathered data 3582 (default is None, must be specified on the destination rank) 3583 dst (int, optional): Destination rank on global process group (regardless of ``group`` argument). (default is 0) 3584 group (ProcessGroup, optional): The process group to work on. If None, 3585 the default process group will be used. 3586 async_op (bool, optional): Whether this op should be an async op 3587 3588 Returns: 3589 Async work handle, if async_op is set to True. 3590 None, if not async_op or if not part of the group 3591 3592 """ 3593 _check_single_tensor(tensor, "tensor") 3594 3595 # Parameter ``gather_list`` may be left unspecified on non-dst ranks. 3596 if gather_list: 3597 _check_tensor_list(gather_list, "gather_list") 3598 else: 3599 gather_list = [] 3600 _ensure_all_tensors_same_dtype(tensor, gather_list) 3601 3602 if _rank_not_in_group(group): 3603 _warn_not_in_group("gather") 3604 return 3605 3606 my_rank = get_rank() 3607 _validate_output_list_for_rank(my_rank, dst, gather_list) 3608 output_tensors = [gather_list] if dst == my_rank else [] 3609 input_tensors = [tensor] 3610 3611 opts = GatherOptions() 3612 opts.rootRank = dst 3613 3614 if group is None or group is GroupMember.WORLD: 3615 default_pg = _get_default_group() 3616 work = default_pg.gather(output_tensors, input_tensors, opts) 3617 else: 3618 group_dst_rank = get_group_rank(group, dst) 3619 opts.rootRank = group_dst_rank 3620 work = group.gather(output_tensors, input_tensors, opts) 3621 3622 if async_op: 3623 return work 3624 else: 3625 work.wait() 3626 3627 3628@_exception_logger 3629def scatter(tensor, scatter_list=None, src=0, group=None, async_op=False): 3630 """ 3631 Scatters a list of tensors to all processes in a group. 3632 3633 Each process will receive exactly one tensor and store its data in the 3634 ``tensor`` argument. 3635 3636 Complex tensors are supported. 3637 3638 Args: 3639 tensor (Tensor): Output tensor. 3640 scatter_list (list[Tensor]): List of tensors to scatter (default is 3641 None, must be specified on the source rank) 3642 src (int): Source rank on global process group (regardless of ``group`` argument). 3643 Default is 0 3644 group (ProcessGroup, optional): The process group to work on. If None, 3645 the default process group will be used. 3646 async_op (bool, optional): Whether this op should be an async op 3647 3648 Returns: 3649 Async work handle, if async_op is set to True. 3650 None, if not async_op or if not part of the group 3651 3652 .. note:: Note that all Tensors in scatter_list must have the same size. 3653 3654 Example:: 3655 >>> # xdoctest: +SKIP("need process group init") 3656 >>> # Note: Process group initialization omitted on each rank. 3657 >>> import torch.distributed as dist 3658 >>> tensor_size = 2 3659 >>> t_ones = torch.ones(tensor_size) 3660 >>> t_fives = torch.ones(tensor_size) * 5 3661 >>> output_tensor = torch.zeros(tensor_size) 3662 >>> if dist.get_rank() == 0: 3663 >>> # Assumes world_size of 2. 
3664 >>> # Only tensors, all of which must be the same size. 3665 >>> scatter_list = [t_ones, t_fives] 3666 >>> else: 3667 >>> scatter_list = None 3668 >>> dist.scatter(output_tensor, scatter_list, src=0) 3669 >>> # Rank i gets scatter_list[i]. For example, on rank 1: 3670 >>> output_tensor 3671 tensor([5., 5.]) 3672 3673 """ 3674 _check_single_tensor(tensor, "tensor") 3675 3676 # Parameter ``scatter_list`` may be left unspecified on non-src ranks. 3677 if scatter_list: 3678 _check_tensor_list(scatter_list, "scatter_list") 3679 else: 3680 scatter_list = [] 3681 _ensure_all_tensors_same_dtype(tensor, scatter_list) 3682 3683 if _rank_not_in_group(group): 3684 _warn_not_in_group("scatter") 3685 return 3686 scatter_list = [ 3687 t if not t.is_complex() else torch.view_as_real(t) for t in scatter_list 3688 ] 3689 tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor) 3690 3691 my_rank = get_rank() 3692 if src == my_rank: 3693 if not scatter_list: 3694 raise ValueError( 3695 "Argument ``scatter_list`` must be specified on source rank." 3696 ) 3697 input_tensors = [scatter_list] 3698 output_tensors = [tensor] 3699 else: 3700 if scatter_list: 3701 raise ValueError( 3702 "Argument ``scatter_list`` must NOT be specified " 3703 "on non-source ranks." 3704 ) 3705 input_tensors = [] 3706 output_tensors = [tensor] 3707 3708 opts = ScatterOptions() 3709 opts.rootRank = src 3710 opts.asyncOp = async_op 3711 3712 if group is None or group is GroupMember.WORLD: 3713 default_pg = _get_default_group() 3714 work = default_pg.scatter(output_tensors, input_tensors, opts) 3715 else: 3716 group_src_rank = get_group_rank(group, src) 3717 opts.rootRank = group_src_rank 3718 work = group.scatter(output_tensors, input_tensors, opts) 3719 3720 if async_op: 3721 return work 3722 else: 3723 work.wait() 3724 3725 3726@_exception_logger 3727def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=False): 3728 """ 3729 Reduces, then scatters a list of tensors to all processes in a group. 3730 3731 Args: 3732 output (Tensor): Output tensor. 3733 input_list (list[Tensor]): List of tensors to reduce and scatter. 3734 op (optional): One of the values from 3735 ``torch.distributed.ReduceOp`` 3736 enum. Specifies an operation used for element-wise reductions. 3737 group (ProcessGroup, optional): The process group to work on. If None, 3738 the default process group will be used. 3739 async_op (bool, optional): Whether this op should be an async op. 3740 3741 Returns: 3742 Async work handle, if async_op is set to True. 3743 None, if not async_op or if not part of the group. 3744 3745 """ 3746 _check_single_tensor(output, "output") 3747 _check_tensor_list(input_list, "input_list") 3748 _ensure_all_tensors_same_dtype(output, input_list) 3749 if _rank_not_in_group(group): 3750 _warn_not_in_group("reduce_scatter") 3751 return 3752 3753 opts = ReduceScatterOptions() 3754 opts.reduceOp = op 3755 3756 group = group or _get_default_group() 3757 work = group.reduce_scatter([output], [input_list], opts) 3758 3759 if async_op: 3760 return work 3761 else: 3762 work.wait() 3763 3764 3765@_exception_logger 3766def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=False): 3767 """ 3768 Reduces, then scatters a tensor to all ranks in a group. 3769 3770 Args: 3771 output (Tensor): Output tensor. It should have the same size across all 3772 ranks. 3773 input (Tensor): Input tensor to be reduced and scattered. Its size 3774 should be output tensor size times the world size. 
The input tensor 3775 can have one of the following shapes: 3776 (i) a concatenation of the output tensors along the primary 3777 dimension, or 3778 (ii) a stack of the output tensors along the primary dimension. 3779 For definition of "concatenation", see ``torch.cat()``. 3780 For definition of "stack", see ``torch.stack()``. 3781 group (ProcessGroup, optional): The process group to work on. If None, 3782 the default process group will be used. 3783 async_op (bool, optional): Whether this op should be an async op. 3784 3785 Returns: 3786 Async work handle, if async_op is set to True. 3787 None, if not async_op or if not part of the group. 3788 3789 Examples: 3790 >>> # xdoctest: +SKIP("need process group init") 3791 >>> # All tensors below are of torch.int64 dtype and on CUDA devices. 3792 >>> # We have two ranks. 3793 >>> device = torch.device(f'cuda:{rank}') 3794 >>> tensor_out = torch.zeros(2, dtype=torch.int64, device=device) 3795 >>> # Input in concatenation form 3796 >>> tensor_in = torch.arange(world_size * 2, dtype=torch.int64, device=device) 3797 >>> tensor_in 3798 tensor([0, 1, 2, 3], device='cuda:0') # Rank 0 3799 tensor([0, 1, 2, 3], device='cuda:1') # Rank 1 3800 >>> dist.reduce_scatter_tensor(tensor_out, tensor_in) 3801 >>> tensor_out 3802 tensor([0, 2], device='cuda:0') # Rank 0 3803 tensor([4, 6], device='cuda:1') # Rank 1 3804 >>> # Input in stack form 3805 >>> tensor_in = torch.reshape(tensor_in, (world_size, 2)) 3806 >>> tensor_in 3807 tensor([[0, 1], 3808 [2, 3]], device='cuda:0') # Rank 0 3809 tensor([[0, 1], 3810 [2, 3]], device='cuda:1') # Rank 1 3811 >>> dist.reduce_scatter_tensor(tensor_out, tensor_in) 3812 >>> tensor_out 3813 tensor([0, 2], device='cuda:0') # Rank 0 3814 tensor([4, 6], device='cuda:1') # Rank 1 3815 3816 .. warning:: 3817 The Gloo backend does not support this API. 3818 3819 """ 3820 _check_single_tensor(output, "output") 3821 _check_single_tensor(input, "input") 3822 3823 if _rank_not_in_group(group): 3824 _warn_not_in_group("reduce_scatter_tensor") 3825 return 3826 3827 opts = ReduceScatterOptions() 3828 opts.reduceOp = op 3829 opts.asyncOp = async_op 3830 3831 group = group or _get_default_group() 3832 3833 # Check if we are in coalescing context 3834 # If we are, do not issue single operation, just append a collective representation 3835 if group in _world.pg_coalesce_state.keys(): 3836 coll = _CollOp(reduce_scatter_tensor, input, output, op, None) 3837 _world.pg_coalesce_state[group].append(coll) 3838 if async_op: 3839 return _IllegalWork() 3840 else: 3841 return None 3842 3843 work = group._reduce_scatter_base(output, input, opts) 3844 3845 if async_op: 3846 return work 3847 else: 3848 work.wait() 3849 3850 3851@deprecated( 3852 "`torch.distributed._reduce_scatter_base` is a private function and will be deprecated. " 3853 "Please use `torch.distributed.reduce_scatter_tensor` instead.", 3854 category=FutureWarning, 3855) 3856def _reduce_scatter_base(output, input, op=ReduceOp.SUM, group=None, async_op=False): 3857 """ 3858 Reduces, then scatters a flattened tensor to all processes in a group. 3859 3860 Args: 3861 output (Tensor): Output tensor. 3862 input (Tensor): Input tensor that is of size output tensor size times world size 3863 group (ProcessGroup, optional): The process group to work on. If None, 3864 the default process group will be used. 3865 async_op (bool, optional): Whether this op should be an async op. 3866 3867 Returns: 3868 Async work handle, if async_op is set to True. 3869 None, if not async_op or if not part of the group. 
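Since this wrapper simply forwards its arguments to :func:`reduce_scatter_tensor`, an equivalent call using the public API is (illustrative)::

    >>> # xdoctest: +SKIP("need process group init")
    >>> dist.reduce_scatter_tensor(output, input, op=op, group=group, async_op=async_op)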
3870 3871 .. warning:: 3872 `_reduce_scatter_base` is a private function. Users should use 3873 `reduce_scatter_tensor` instead. 3874 3875 """ 3876 return reduce_scatter_tensor(output, input, op, group, async_op) 3877 3878 3879@_exception_logger 3880def all_to_all_single( 3881 output, 3882 input, 3883 output_split_sizes=None, 3884 input_split_sizes=None, 3885 group=None, 3886 async_op=False, 3887): 3888 """ 3889 Split input tensor and then scatter the split list to all processes in a group. 3890 3891 Later the received tensors are concatenated from all the processes in the group 3892 and returned as a single output tensor. 3893 3894 Complex tensors are supported. 3895 3896 Args: 3897 output (Tensor): Gathered concatenated output tensor. 3898 input (Tensor): Input tensor to scatter. 3899 output_split_sizes: (list[Int], optional): Output split sizes for dim 0 3900 if specified None or empty, dim 0 of ``output`` tensor must divide 3901 equally by ``world_size``. 3902 input_split_sizes: (list[Int], optional): Input split sizes for dim 0 3903 if specified None or empty, dim 0 of ``input`` tensor must divide 3904 equally by ``world_size``. 3905 group (ProcessGroup, optional): The process group to work on. If None, 3906 the default process group will be used. 3907 async_op (bool, optional): Whether this op should be an async op. 3908 3909 Returns: 3910 Async work handle, if async_op is set to True. 3911 None, if not async_op or if not part of the group. 3912 3913 .. warning:: 3914 `all_to_all_single` is experimental and subject to change. 3915 3916 Examples: 3917 >>> # xdoctest: +SKIP("Undefined rank") 3918 >>> input = torch.arange(4) + rank * 4 3919 >>> input 3920 tensor([0, 1, 2, 3]) # Rank 0 3921 tensor([4, 5, 6, 7]) # Rank 1 3922 tensor([8, 9, 10, 11]) # Rank 2 3923 tensor([12, 13, 14, 15]) # Rank 3 3924 >>> output = torch.empty([4], dtype=torch.int64) 3925 >>> dist.all_to_all_single(output, input) 3926 >>> output 3927 tensor([0, 4, 8, 12]) # Rank 0 3928 tensor([1, 5, 9, 13]) # Rank 1 3929 tensor([2, 6, 10, 14]) # Rank 2 3930 tensor([3, 7, 11, 15]) # Rank 3 3931 3932 >>> # Essentially, it is similar to following operation: 3933 >>> scatter_list = list(input.chunk(world_size)) 3934 >>> gather_list = list(output.chunk(world_size)) 3935 >>> for i in range(world_size): 3936 >>> dist.scatter(gather_list[i], scatter_list if i == rank else [], src = i) 3937 3938 >>> # Another example with uneven split 3939 >>> input 3940 tensor([0, 1, 2, 3, 4, 5]) # Rank 0 3941 tensor([10, 11, 12, 13, 14, 15, 16, 17, 18]) # Rank 1 3942 tensor([20, 21, 22, 23, 24]) # Rank 2 3943 tensor([30, 31, 32, 33, 34, 35, 36]) # Rank 3 3944 >>> input_splits 3945 [2, 2, 1, 1] # Rank 0 3946 [3, 2, 2, 2] # Rank 1 3947 [2, 1, 1, 1] # Rank 2 3948 [2, 2, 2, 1] # Rank 3 3949 >>> output_splits 3950 [2, 3, 2, 2] # Rank 0 3951 [2, 2, 1, 2] # Rank 1 3952 [1, 2, 1, 2] # Rank 2 3953 [1, 2, 1, 1] # Rank 3 3954 >>> output = ... 3955 >>> dist.all_to_all_single(output, input, output_splits, input_splits) 3956 >>> output 3957 tensor([ 0, 1, 10, 11, 12, 20, 21, 30, 31]) # Rank 0 3958 tensor([ 2, 3, 13, 14, 22, 32, 33]) # Rank 1 3959 tensor([ 4, 15, 16, 23, 34, 35]) # Rank 2 3960 tensor([ 5, 17, 18, 24, 36]) # Rank 3 3961 3962 3963 >>> # Another example with tensors of torch.cfloat type. 
3964 >>> input = torch.tensor([1+1j, 2+2j, 3+3j, 4+4j], dtype=torch.cfloat) + 4 * rank * (1+1j) 3965 >>> input 3966 tensor([1+1j, 2+2j, 3+3j, 4+4j]) # Rank 0 3967 tensor([5+5j, 6+6j, 7+7j, 8+8j]) # Rank 1 3968 tensor([9+9j, 10+10j, 11+11j, 12+12j]) # Rank 2 3969 tensor([13+13j, 14+14j, 15+15j, 16+16j]) # Rank 3 3970 >>> output = torch.empty([4], dtype=torch.cfloat) 3971 >>> dist.all_to_all_single(output, input) 3972 >>> output 3973 tensor([1+1j, 5+5j, 9+9j, 13+13j]) # Rank 0 3974 tensor([2+2j, 6+6j, 10+10j, 14+14j]) # Rank 1 3975 tensor([3+3j, 7+7j, 11+11j, 15+15j]) # Rank 2 3976 tensor([4+4j, 8+8j, 12+12j, 16+16j]) # Rank 3 3977 """ 3978 if _rank_not_in_group(group): 3979 _warn_not_in_group("all_to_all_single") 3980 return 3981 3982 opts = AllToAllOptions() 3983 _check_single_tensor(output, "output") 3984 _check_single_tensor(input, "input") 3985 _ensure_all_tensors_same_dtype(output, input) 3986 3987 if input.is_complex(): 3988 input = torch.view_as_real(input) 3989 if output.is_complex(): 3990 output = torch.view_as_real(output) 3991 3992 output_split_sizes = [] if output_split_sizes is None else output_split_sizes 3993 input_split_sizes = [] if input_split_sizes is None else input_split_sizes 3994 3995 group = group or _get_default_group() 3996 work = group.alltoall_base( 3997 output, input, output_split_sizes, input_split_sizes, opts 3998 ) 3999 4000 if async_op: 4001 return work 4002 else: 4003 work.wait() 4004 4005 4006 @_exception_logger 4007 def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False): 4008 """ 4009 Scatters a list of input tensors to all processes in a group and returns a gathered list of tensors in the output list. 4010 4011 Complex tensors are supported. 4012 4013 Args: 4014 output_tensor_list (list[Tensor]): List of tensors to be gathered one 4015 per rank. 4016 input_tensor_list (list[Tensor]): List of tensors to scatter one per rank. 4017 group (ProcessGroup, optional): The process group to work on. If None, 4018 the default process group will be used. 4019 async_op (bool, optional): Whether this op should be an async op. 4020 4021 Returns: 4022 Async work handle, if async_op is set to True. 4023 None, if not async_op or if not part of the group. 4024 4025 .. warning:: 4026 `all_to_all` is experimental and subject to change.
4027 4028 Examples: 4029 >>> # xdoctest: +SKIP("Undefined rank") 4030 >>> input = torch.arange(4) + rank * 4 4031 >>> input = list(input.chunk(4)) 4032 >>> input 4033 [tensor([0]), tensor([1]), tensor([2]), tensor([3])] # Rank 0 4034 [tensor([4]), tensor([5]), tensor([6]), tensor([7])] # Rank 1 4035 [tensor([8]), tensor([9]), tensor([10]), tensor([11])] # Rank 2 4036 [tensor([12]), tensor([13]), tensor([14]), tensor([15])] # Rank 3 4037 >>> output = list(torch.empty([4], dtype=torch.int64).chunk(4)) 4038 >>> dist.all_to_all(output, input) 4039 >>> output 4040 [tensor([0]), tensor([4]), tensor([8]), tensor([12])] # Rank 0 4041 [tensor([1]), tensor([5]), tensor([9]), tensor([13])] # Rank 1 4042 [tensor([2]), tensor([6]), tensor([10]), tensor([14])] # Rank 2 4043 [tensor([3]), tensor([7]), tensor([11]), tensor([15])] # Rank 3 4044 4045 >>> # Essentially, it is similar to following operation: 4046 >>> scatter_list = input 4047 >>> gather_list = output 4048 >>> for i in range(world_size): 4049 >>> dist.scatter(gather_list[i], scatter_list if i == rank else [], src=i) 4050 4051 >>> input 4052 tensor([0, 1, 2, 3, 4, 5]) # Rank 0 4053 tensor([10, 11, 12, 13, 14, 15, 16, 17, 18]) # Rank 1 4054 tensor([20, 21, 22, 23, 24]) # Rank 2 4055 tensor([30, 31, 32, 33, 34, 35, 36]) # Rank 3 4056 >>> input_splits 4057 [2, 2, 1, 1] # Rank 0 4058 [3, 2, 2, 2] # Rank 1 4059 [2, 1, 1, 1] # Rank 2 4060 [2, 2, 2, 1] # Rank 3 4061 >>> output_splits 4062 [2, 3, 2, 2] # Rank 0 4063 [2, 2, 1, 2] # Rank 1 4064 [1, 2, 1, 2] # Rank 2 4065 [1, 2, 1, 1] # Rank 3 4066 >>> input = list(input.split(input_splits)) 4067 >>> input 4068 [tensor([0, 1]), tensor([2, 3]), tensor([4]), tensor([5])] # Rank 0 4069 [tensor([10, 11, 12]), tensor([13, 14]), tensor([15, 16]), tensor([17, 18])] # Rank 1 4070 [tensor([20, 21]), tensor([22]), tensor([23]), tensor([24])] # Rank 2 4071 [tensor([30, 31]), tensor([32, 33]), tensor([34, 35]), tensor([36])] # Rank 3 4072 >>> output = ... 4073 >>> dist.all_to_all(output, input) 4074 >>> output 4075 [tensor([0, 1]), tensor([10, 11, 12]), tensor([20, 21]), tensor([30, 31])] # Rank 0 4076 [tensor([2, 3]), tensor([13, 14]), tensor([22]), tensor([32, 33])] # Rank 1 4077 [tensor([4]), tensor([15, 16]), tensor([23]), tensor([34, 35])] # Rank 2 4078 [tensor([5]), tensor([17, 18]), tensor([24]), tensor([36])] # Rank 3 4079 4080 >>> # Another example with tensors of torch.cfloat type. 
4081 >>> input = torch.tensor([1+1j, 2+2j, 3+3j, 4+4j], dtype=torch.cfloat) + 4 * rank * (1+1j) 4082 >>> input = list(input.chunk(4)) 4083 >>> input 4084 [tensor([1+1j]), tensor([2+2j]), tensor([3+3j]), tensor([4+4j])] # Rank 0 4085 [tensor([5+5j]), tensor([6+6j]), tensor([7+7j]), tensor([8+8j])] # Rank 1 4086 [tensor([9+9j]), tensor([10+10j]), tensor([11+11j]), tensor([12+12j])] # Rank 2 4087 [tensor([13+13j]), tensor([14+14j]), tensor([15+15j]), tensor([16+16j])] # Rank 3 4088 >>> output = list(torch.empty([4], dtype=torch.cfloat).chunk(4)) 4089 >>> dist.all_to_all(output, input) 4090 >>> output 4091 [tensor([1+1j]), tensor([5+5j]), tensor([9+9j]), tensor([13+13j])] # Rank 0 4092 [tensor([2+2j]), tensor([6+6j]), tensor([10+10j]), tensor([14+14j])] # Rank 1 4093 [tensor([3+3j]), tensor([7+7j]), tensor([11+11j]), tensor([15+15j])] # Rank 2 4094 [tensor([4+4j]), tensor([8+8j]), tensor([12+12j]), tensor([16+16j])] # Rank 3 4095 4096 """ 4097 if _rank_not_in_group(group): 4098 _warn_not_in_group("all_to_all") 4099 return 4100 4101 opts = AllToAllOptions() 4102 _check_tensor_list(output_tensor_list, "output_tensor_list") 4103 _check_tensor_list(input_tensor_list, "input_tensor_list") 4104 _ensure_all_tensors_same_dtype(output_tensor_list, input_tensor_list) 4105 4106 input_tensor_list = [ 4107 t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list 4108 ] 4109 output_tensor_list = [ 4110 t if not t.is_complex() else torch.view_as_real(t) for t in output_tensor_list 4111 ] 4112 4113 group = group or _get_default_group() 4114 work = group.alltoall(output_tensor_list, input_tensor_list, opts) 4115 4116 if async_op: 4117 return work 4118 else: 4119 work.wait() 4120 4121 4122 @_exception_logger 4123 def barrier(group=GroupMember.WORLD, async_op=False, device_ids=None): 4124 """ 4125 Synchronize all processes. 4126 4127 This collective blocks processes until the whole group enters this function 4128 if async_op is False, or, if async_op is True, until wait() is called on the returned work handle. 4129 4130 Args: 4131 group (ProcessGroup, optional): The process group to work on. If None, 4132 the default process group will be used. 4133 async_op (bool, optional): Whether this op should be an async op 4134 device_ids ([int], optional): List of device/GPU ids. 4135 4136 Returns: 4137 Async work handle, if async_op is set to True. 4138 None, if not async_op or if not part of the group 4139 4140 .. note:: `ProcessGroupNCCL` now relies on stream synchronization instead of 4141 device synchronization to block the CPU. Thus, please do not assume that 4142 `barrier()` would perform a device synchronization. 4143 """ 4144 if _rank_not_in_group(group): 4145 _warn_not_in_group("barrier") 4146 return 4147 4148 opts = BarrierOptions() 4149 opts.device = _get_pg_default_device(group) 4150 if device_ids is not None: 4151 if isinstance(device_ids, list): 4152 opts.device_ids = device_ids 4153 else: 4154 raise TypeError( 4155 "Invalid function argument: device_ids type should be List[int]" 4156 ) 4157 4158 group = group or _get_default_group() 4159 work = group.barrier(opts=opts) 4160 4161 if async_op: 4162 return work 4163 else: 4164 work.wait() 4165 4166 4167 def monitored_barrier(group=GroupMember.WORLD, timeout=None, wait_all_ranks=False): 4168 """ 4169 Synchronize processes similar to ``torch.distributed.barrier``, but with a configurable timeout. 4170 4171 It is able to report ranks that did not pass this barrier within the provided timeout.
4172 Specifically, non-zero ranks will block until a send/recv is processed from rank 0. 4173 Rank 0 will block until all send/recv from other ranks are processed, and will report 4174 failures for ranks that failed to respond in time. Note that if one rank does not reach the 4175 monitored_barrier (for example due to a hang), all other ranks would fail in monitored_barrier. 4176 4177 This collective will block all processes/ranks in the group, until the 4178 whole group exits the function successfully, making it useful for debugging 4179 and synchronizing. However, it can have a performance impact and should only 4180 be used for debugging or scenarios that require full synchronization points 4181 on the host-side. For debugging purposes, this barrier can be inserted 4182 before the application's collective calls to check if any ranks are 4183 desynchronized. 4184 4185 .. note:: This collective is only supported with the GLOO backend. 4186 4187 Args: 4188 group (ProcessGroup, optional): The process group to work on. If 4189 ``None``, the default process group will be used. 4190 timeout (datetime.timedelta, optional): Timeout for monitored_barrier. 4191 If ``None``, the default process group timeout will be used. 4192 wait_all_ranks (bool, optional): Whether to collect all failed ranks or 4193 not. By default, this is ``False`` and ``monitored_barrier`` on rank 0 4194 will throw on the first failed rank it encounters in order to fail 4195 fast. By setting ``wait_all_ranks=True`` ``monitored_barrier`` will 4196 collect all failed ranks and throw an error containing information 4197 about all failed ranks. 4198 4199 Returns: 4200 ``None``. 4201 4202 Example:: 4203 >>> # xdoctest: +SKIP("need process group init") 4204 >>> # Note: Process group initialization omitted on each rank. 4205 >>> import torch.distributed as dist 4206 >>> if dist.get_rank() != 1: 4207 >>> dist.monitored_barrier() # Raises exception indicating that 4208 >>> # rank 1 did not call into monitored_barrier. 4209 >>> # Example with wait_all_ranks=True 4210 >>> if dist.get_rank() == 0: 4211 >>> dist.monitored_barrier(wait_all_ranks=True) # Raises exception 4212 >>> # indicating that ranks 1, 2, ... world_size - 1 did not call into 4213 >>> # monitored_barrier. 4214 """ 4215 # Need to call rank not in group before using the group, otherwise 4216 # "Invalid process group" error is raised. 4217 if _rank_not_in_group(group): 4218 _warn_not_in_group("monitored_barrier") 4219 return 4220 4221 if get_backend(group) != Backend.GLOO: 4222 raise ValueError("monitored_barrier is only implemented for GLOO backend.") 4223 4224 if timeout is None: 4225 timeout = _get_default_timeout(get_backend(group)) 4226 elif isinstance(timeout, float): 4227 # TODO(whc) apparently some existing test case for monitored_barrier passes in a timeout in float format? 4228 warnings.warn( 4229 "Please specify timeout arg as a timedelta. 
" 4230 f"Converting current value of {timeout} assuming it represents seconds", 4231 ) 4232 timeout = timedelta(seconds=timeout) 4233 4234 _check_valid_timeout(timeout) 4235 4236 group_to_use = _get_default_group() if group is None else group 4237 return group_to_use.monitored_barrier(timeout, wait_all_ranks=wait_all_ranks) 4238 4239 4240def _create_process_group_wrapper( 4241 wrapped_pg: torch._C._distributed_c10d.Backend, 4242 store_prefix: str, 4243 store: Store, 4244 rank: int, 4245 world_size: int, 4246 timeout: timedelta = default_pg_timeout, 4247): 4248 assert _GLOO_AVAILABLE, "ProcessGroupWrapper unsupported without GLOO backend." 4249 4250 # (whc) this appears to be just for the gloo backend? if so, `default_pg_timeout` is appropriate... 4251 4252 # Create a separate prefix store for the helper process group. 4253 prefix = f"{PG_WRAPPER_STORE_PREFIX}:{store_prefix}" 4254 store = PrefixStore(prefix, store) 4255 helper_pg = ProcessGroupGloo(store, rank, world_size, timeout=timeout) 4256 # Wrap the underlying pg with ProcessGroupWrapper. 4257 wrapped_pg = _ProcessGroupWrapper(wrapped_pg, helper_pg) 4258 return wrapped_pg 4259 4260 4261# helper function for deterministically hashing a list of ranks 4262def _hash_ranks(ranks: List[int]): 4263 return hashlib.sha1(bytes("_".join(map(str, ranks)), "utf-8")).hexdigest() 4264 4265 4266# Takes a list of ranks and computes an integer color 4267def _process_group_color(ranks: List[int]) -> int: 4268 # Convert our hash to an int, but avoid negative numbers by shifting a bit. 4269 return int(_hash_ranks(ranks), 16) % (sys.maxsize >> 1) 4270 4271 4272def _process_group_name(ranks, use_hashed_name): 4273 global _world 4274 if use_hashed_name: 4275 pg_name = _hash_ranks(ranks) 4276 while pg_name in _world.pg_names.values(): 4277 pg_name = hashlib.sha1(bytes(pg_name + "_", "utf-8")).hexdigest() 4278 else: 4279 pg_name = str(_world.group_count) 4280 _world.group_count += 1 4281 return pg_name 4282 4283 4284def _get_backend_from_str(backend: Optional[str] = None) -> Backend: 4285 # Default to the same backend as the global process group 4286 # if backend is not specified. 4287 if not backend: 4288 backend = get_backend(_get_default_group()) 4289 return Backend(backend) 4290 4291 4292def _is_safe_to_split() -> bool: 4293 """ 4294 Checks if it is safe to split the any process group in the world. 4295 This is only safe if the default pg has a bound device id, otherwise 4296 users must be aware that a pg is only splittable after the first collective is 4297 issued. 4298 """ 4299 return False if _get_default_group().bound_device_id is None else True 4300 4301 4302@_time_logger 4303def split_group( 4304 parent_pg: Optional[ProcessGroup] = None, 4305 split_ranks: Optional[list] = None, 4306 timeout: Optional[timedelta] = None, 4307 pg_options: Optional[Any] = None, 4308 group_desc: Optional[str] = None, 4309) -> Optional[ProcessGroup]: 4310 """ 4311 Create a new process group splitted from the given parent process group. 4312 4313 warning:: This is an experimental API and only the ``NCCL`` backend supports this API. 4314 Other backends will raise an error. 4315 Users of this API must gurantee that all ranks in the parent group enter this API call, 4316 and the split of the sub groups is the same accross all ranks in the parent group. 4317 4318 Args: 4319 parent_pg (ProcessGroup, optional): The parent process group. If None, 4320 the default process group will be used. 
Users need to guarantee that 4321 the parent group is fully initialized (e.g., communicators are initialized) 4322 split_ranks (list[list[int]]): the split ranks, which is a list of list of ranks. 4323 Users need to make sure the validity of the split ranks such that one 4324 split (represented by one inner list of ints) does not overlap with any other split. 4325 Note that the ranks in each split are the group ranks (instead of global ranks) 4326 in the parent pg. For example, if the parent group has 4 ranks, split_ranks can be 4327 [[0, 1], [2, 3]]. Note [[0,1]] is also a valid split, in which case ranks 2, 3 would 4328 return a non-group member. 4329 timeout (timedelta, optional): see `init_process_group` for details and default value. 4330 pg_options (ProcessGroupOptions, optional): only ``ProcessGroupNCCL.Options`` is supported now, 4331 specifying what additional options need to be passed in during 4332 the construction of specific process groups. i.e. ``is_high_priority_stream`` 4333 can be specified so that the process group can pick up high priority cuda streams. 4334 For other available options to configure NCCL, 4335 See https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t 4336 group_desc (str, optional): a string to describe the process group. 4337 4338 Returns: 4339 ProcessGroup if the current rank is within one split/subgroup given by split_ranks, 4340 or None if the current rank is not part of any ``split_ranks``. 4341 4342 """ 4343 # check inputs 4344 if split_ranks is None: 4345 raise ValueError("split_ranks cannot be None") 4346 4347 global _world 4348 default_pg = _get_default_group() 4349 device_id = default_pg.bound_device_id 4350 if not device_id: 4351 raise RuntimeError( 4352 "No device associated with the default pg, not safe to split any process groups" 4353 ) 4354 default_backend, default_store = _world.pg_map[default_pg] 4355 global_rank = default_pg.rank() 4356 global_world_size = default_pg.size() 4357 4358 if not parent_pg: 4359 parent_pg = default_pg 4360 if parent_pg not in _world.pg_group_ranks: 4361 raise ValueError(f"Group {parent_pg} is not registered") 4362 4363 parent_global_to_group_ranks = _world.pg_group_ranks[parent_pg] 4364 parent_group_to_global_ranks = { 4365 group_rank: global_rank 4366 for global_rank, group_rank in parent_global_to_group_ranks.items() 4367 } 4368 4369 if global_rank not in parent_global_to_group_ranks: 4370 raise ValueError( 4371 f"Global rank {global_rank} is not part of the parent group {parent_pg}" 4372 ) 4373 4374 parent_group_rank = parent_global_to_group_ranks[global_rank] 4375 parent_backend = parent_pg._get_backend(torch.device("cuda")) 4376 4377 # if the parent backend does not support splitting, raise error 4378 # currently this API only supports the NCCL backend 4379 if ( 4380 not parent_backend 4381 or not parent_backend.supports_splitting 4382 or not isinstance(parent_backend, ProcessGroupNCCL) 4383 ): 4384 raise RuntimeError( 4385 "No backend for the parent process group or its backend does not support splitting" 4386 ) 4387 4388 # set the group_desc before the color or no_color split 4389 group_desc = ( 4390 f"{parent_pg.group_desc}:split:{parent_backend.comm_split_count()}" 4391 if group_desc is None 4392 else group_desc 4393 ) 4394 4395 parent_backend_str, _ = _world.pg_map[parent_pg] 4396 # same type of backend as the parent process group 4397 backend = Backend(parent_backend_str) 4398 backend_config = BackendConfig(backend) 4399 4400 if pg_options is not None: 4401 assert isinstance( 4402
pg_options, ProcessGroupNCCL.Options 4403 ), "Expected pg_options argument to be of type ProcessGroupNCCL.Options" 4404 else: 4405 # default pg_options same as the parent process group 4406 pg_options = parent_backend.options 4407 4408 # this timeout defaulting/validation is used for all the new_groups/new_subgroups variants, 4409 # which may just pass their timeout value (or None) 4410 if timeout is None: 4411 timeout = _get_default_timeout(backend) 4412 _check_valid_timeout(timeout) 4413 4414 # find my group of ranks and my group local rank in split_ranks 4415 my_group = None 4416 group_rank = -1 4417 4418 for split_group in split_ranks: 4419 if len(split_group) == 0: 4420 raise ValueError("the split group cannot be empty") 4421 if len(split_group) > global_world_size: 4422 raise ValueError( 4423 "the split group's size should be less or equal to the world_size set by init_process_group" 4424 ) 4425 if len(split_group) != len(set(split_group)): 4426 raise ValueError("the split group cannot have duplicate ranks") 4427 split_group = sorted(split_group) 4428 if parent_group_rank in split_group: 4429 my_group = split_group 4430 group_rank = split_group.index(parent_group_rank) 4431 break 4432 # if my rank does not belong to any sub group, 4433 # no_color split should be called 4434 if my_group is None or group_rank == -1: 4435 parent_backend.perform_nocolor_split(device_id) 4436 return None 4437 4438 group_name = _process_group_name(my_group, use_hashed_name=False) 4439 global_ranks_in_my_group = [parent_group_to_global_ranks[rank] for rank in my_group] 4440 4441 prefix_store = PrefixStore(f"{group_name}/", default_store) 4442 base_pg_options = ProcessGroup.Options(backend=str(backend)) 4443 base_pg_options._timeout = timeout 4444 pg: ProcessGroup = ProcessGroup( 4445 prefix_store, group_rank, len(my_group), base_pg_options 4446 ) 4447 pg.bound_device_id = device_id 4448 4449 pg_options._timeout = timeout 4450 pg_options.split_from = parent_backend 4451 pg_options.split_color = _process_group_color(my_group) 4452 pg_options.global_ranks_in_group = global_ranks_in_my_group 4453 pg_options.group_name = group_name 4454 backend_class = ProcessGroupNCCL( 4455 prefix_store, group_rank, len(my_group), pg_options 4456 ) 4457 backend_type = ProcessGroup.BackendType.NCCL 4458 backend_class._set_sequence_number_for_group() 4459 4460 pg._register_backend(torch.device("cuda"), backend_type, backend_class) 4461 4462 # set group_name and group_desc to backend 4463 assert group_name is not None 4464 assert group_desc is not None 4465 pg._set_group_name(group_name) 4466 pg._set_group_desc(group_desc) 4467 4468 # always eagerly initialize the backend in split_group 4469 eager_backend = pg._get_backend(device_id) 4470 eager_backend.eager_connect_single_device(device_id) 4471 4472 # update global state 4473 _world.pg_map[pg] = (backend, prefix_store) 4474 _world.pg_names[pg] = group_name 4475 _register_process_group(group_name, pg) 4476 _world.pg_backend_config[pg] = str(backend_config) 4477 pg_tag = f"ptd:{group_name}" 4478 _world.tags_to_pg.setdefault(pg_tag, []).append(pg) 4479 _world.pg_to_tag[pg] = pg_tag 4480 4481 # Create the global rank to group rank mapping 4482 _world.pg_group_ranks[pg] = { 4483 global_rank: group_rank 4484 for group_rank, global_rank in enumerate(global_ranks_in_my_group) 4485 } 4486 4487 return pg 4488 4489 4490@_time_logger 4491def new_group( 4492 ranks=None, 4493 timeout=None, 4494 backend=None, 4495 pg_options=None, 4496 use_local_synchronization=False, 4497 group_desc=None, 
4498): 4499 """ 4500 Create a new distributed group. 4501 4502 This function requires that all processes in the main group (i.e. all 4503 processes that are part of the distributed job) enter this function, even 4504 if they are not going to be members of the group. Additionally, groups 4505 should be created in the same order in all processes. 4506 4507 .. warning:: 4508 Safe concurrent usage: 4509 When using multiple process groups with the ``NCCL`` backend, the user 4510 must ensure a globally consistent execution order of collectives across 4511 ranks. 4512 4513 If multiple threads within a process issue collectives, explicit 4514 synchronization is necessary to ensure consistent ordering. 4515 4516 When using async variants of torch.distributed communication APIs, 4517 a work object is returned and the communication kernel is 4518 enqueued on a separate CUDA stream, allowing overlap of communication 4519 and computation. Once one or more async ops have been issued on one process 4520 group, they must be synchronized with other cuda streams by calling `work.wait()` 4521 before using another process group. 4522 4523 See `Using multiple NCCL communicators concurrently <https://docs.nvid 4524 ia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#using 4525 -multiple-nccl-communicators-concurrently>`_ for more details. 4526 4527 Args: 4528 ranks (list[int]): List of ranks of group members. If ``None``, will be 4529 set to all ranks. Default is ``None``. 4530 timeout (timedelta, optional): see `init_process_group` for details and default value. 4531 backend (str or Backend, optional): The backend to use. Depending on 4532 build-time configurations, valid values are ``gloo`` and ``nccl``. 4533 By default uses the same backend as the global group. This field 4534 should be given as a lowercase string (e.g., ``"gloo"``), which can 4535 also be accessed via :class:`Backend` attributes (e.g., 4536 ``Backend.GLOO``). If ``None`` is passed in, the backend 4537 corresponding to the default process group will be used. Default is 4538 ``None``. 4539 pg_options (ProcessGroupOptions, optional): process group options 4540 specifying what additional options need to be passed in during 4541 the construction of specific process groups. i.e. for the ``nccl`` 4542 backend, ``is_high_priority_stream`` can be specified so that 4543 process group can pick up high priority cuda streams. For other available options to configure NCCL, 4544 See https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t 4545 use_local_synchronization (bool, optional): perform a group-local 4546 barrier at the end of the process group creation. This is different 4547 in that non-member ranks don't need to call into the API and don't 4548 join the barrier. 4549 group_desc (str, optional): a string to describe the process group. 4550 4551 Returns: 4552 A handle of the distributed group that can be given to collective calls, or 4553 GroupMember.NON_GROUP_MEMBER if the rank is not part of ``ranks``. 4554 4555 N.B. use_local_synchronization doesn't work with MPI. 4556 4557 N.B. While use_local_synchronization=True can be significantly faster with larger 4558 clusters and small process groups, care must be taken since it changes cluster behavior 4559 as non-member ranks don't join the group barrier(). 4560 4561 N.B. use_local_synchronization=True can lead to deadlocks when each rank creates 4562 multiple overlapping process groups. To avoid that, make sure all ranks follow the 4563 same global creation order.
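Example (an illustrative sketch; assumes the default process group is already initialized with a world size of at least 2)::

    >>> # xdoctest: +SKIP("need process group init")
    >>> import torch.distributed as dist
    >>> # Every rank must call new_group, even ranks that are not in ``ranks``.
    >>> group = dist.new_group(ranks=[0, 1])
    >>> if dist.get_rank() in (0, 1):
    >>>     tensor = torch.ones(1)
    >>>     dist.all_reduce(tensor, group=group)
    >>>     # tensor is now tensor([2.]) on ranks 0 and 1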
4564 """ 4565 return _new_group_with_tag( 4566 ranks, 4567 timeout, 4568 backend, 4569 pg_options, 4570 None, 4571 use_local_synchronization=use_local_synchronization, 4572 group_desc=group_desc, 4573 ) 4574 4575 4576def _new_group_with_tag( 4577 ranks=None, 4578 timeout=None, 4579 backend=None, 4580 pg_options=None, 4581 pg_tag=None, 4582 use_local_synchronization=False, 4583 group_desc=None, 4584): 4585 """ 4586 Variant of ``new_group`` that exposes tag creation. 4587 4588 :: N.B. The mechanism is experimental and tied to the functional collectives effort, see 4589 ``torch.distributed._functional_collectives`` for reference on how to use it. 4590 """ 4591 global _world 4592 4593 default_pg = _get_default_group() 4594 device_id = default_pg.bound_device_id 4595 default_backend, default_store = _world.pg_map[default_pg] 4596 global_rank = default_pg.rank() 4597 global_world_size = default_pg.size() 4598 4599 # Default to the same backend as the global process group 4600 # if the backend is not specified. 4601 if not backend: 4602 backend = default_backend 4603 backend = Backend(backend) 4604 4605 # this timeout defaulting/validation is used for all the new_groups/new_subgroups variants, 4606 # which may just pass their timeout value (or None) 4607 if timeout is None: 4608 timeout = _get_default_timeout(backend) 4609 _check_valid_timeout(timeout) 4610 4611 if use_local_synchronization: 4612 # MPI backend doesn't have have a way for us to perform a partial sync 4613 if backend == Backend.MPI: 4614 raise ValueError( 4615 "MPI backend doesn't support use_local_synchronization=True" 4616 ) 4617 if ranks is not None and get_rank() not in ranks: 4618 return None 4619 4620 # checks the input ranks 4621 if ranks is not None: 4622 ranks = sorted(ranks) 4623 group_world_size = len(ranks) 4624 if group_world_size > global_world_size: 4625 raise ValueError( 4626 "the new group's world size should be less or " 4627 "equal to the world size set by " 4628 "init_process_group" 4629 ) 4630 # check ranks' sanity 4631 for rank in ranks: 4632 if rank < 0 or rank >= global_world_size: 4633 raise ValueError( 4634 "The new group's rank should be within " 4635 "the world_size set by init_process_group" 4636 ) 4637 if global_rank in ranks: 4638 group_rank = ranks.index(global_rank) 4639 else: 4640 group_rank = None 4641 else: 4642 ranks = list(range(global_world_size)) 4643 group_world_size = global_world_size 4644 group_rank = global_rank 4645 4646 group_name = _process_group_name(ranks, use_hashed_name=use_local_synchronization) 4647 4648 pg, pg_store = _new_process_group_helper( 4649 group_world_size, 4650 group_rank, 4651 ranks, 4652 backend, 4653 default_store, 4654 group_name, 4655 pg_options=pg_options, 4656 timeout=timeout, 4657 pg_tag=pg_tag, 4658 device_id=device_id, 4659 group_desc=group_desc, 4660 ) 4661 4662 # Create the global rank to group rank mapping 4663 _world.pg_group_ranks[pg] = { 4664 global_rank: group_rank for group_rank, global_rank in enumerate(ranks) 4665 } 4666 4667 if _is_barrier_after_init() == 1: 4668 # barrier at the end to ensure that once we return from this method, all 4669 # process groups including global variables (if any) are updated 4670 # correctly on all ranks. 4671 # Update 04/2023: for large-scale runs, this barrier (esp. store-based 4672 # barrier) may be costly and/or unscalable. Also, in a lot of cases, 4673 # these barriers may be unnecessary, as proven by a green CI after 4674 # removal. 
An environment variable `TORCH_DIST_INIT_BARRIER` has been 4675 # added which enables this barrier only when set to 1. 4676 logger.info( 4677 "Performing barrier after ProcessGroup initialization since " 4678 "TORCH_DIST_INIT_BARRIER = 1" 4679 ) 4680 if backend == Backend.MPI: 4681 # MPI doesn't have store. 4682 barrier() 4683 else: 4684 barrier_store = pg_store if use_local_synchronization else default_store 4685 world_size = len(ranks) if use_local_synchronization else get_world_size() 4686 # Use store based barrier here since barrier() used a bunch of 4687 # default devices and messes up NCCL internal state. 4688 _store_based_barrier( 4689 global_rank, barrier_store, group_name, world_size, timeout 4690 ) 4691 4692 return pg 4693 4694 4695 def new_subgroups( 4696 group_size=None, 4697 group=None, 4698 timeout=None, 4699 backend=None, 4700 pg_options=None, 4701 group_desc=None, 4702 ): 4703 """ 4704 Create subgroups of equal size. 4705 4706 By default, it creates intra-machine subgroups, 4707 each of which contains all the ranks of a machine, based on the assumption 4708 that each machine has the same number of devices. 4709 4710 This is a convenience API that calls ``new_group`` to generate multiple subgroups. 4711 It requires that all processes in the main group (i.e. all 4712 processes that are part of the distributed job) enter this function, even 4713 if they are not going to be members of the group. 4714 4715 .. warning:: 4716 If ``group_size`` is passed in, the world size must be divisible by ``group_size``. 4717 If no ``group_size`` is passed in, it is assumed that you are creating a group based 4718 on CUDA and the group size is determined by the number of CUDA devices, and if not all 4719 the machines have the same number of devices, the subgroup division will be 4720 different across nodes and can cause unexpected behaviors. Therefore, if you are 4721 creating a subgroup that does not depend on CUDA (such as Gloo on CPU), please 4722 pass in ``group_size`` correctly. 4723 4724 .. warning:: 4725 See warning `Safe concurrent usage` for `new_group` API for important details about 4726 using multiple process groups concurrently in a safe manner. 4727 4728 Args: 4729 group_size (int, optional): The size of each subgroup. If ``None``, 4730 the default subgroup size is equal to the number of devices on each machine, 4731 based on the assumption that each machine has exactly the same 4732 number of devices. Default is ``None``. 4733 timeout (timedelta, optional): see `init_process_group` for details and default value. 4734 backend (str or Backend, optional): The backend to use. Depending on 4735 build-time configurations, valid values are ``gloo`` and ``nccl``. 4736 By default uses the same backend as the global group. This field 4737 should be given as a lowercase string (e.g., ``"gloo"``), which can 4738 also be accessed via :class:`Backend` attributes (e.g., 4739 ``Backend.GLOO``). If ``None`` is passed in, the backend 4740 corresponding to the default process group will be used. Default is 4741 ``None``. 4742 pg_options (ProcessGroupOptions, optional): process group options 4743 specifying what additional options need to be passed in during 4744 the construction of specific process groups. i.e. for the ``nccl`` 4745 backend, ``is_high_priority_stream`` can be specified so that 4746 process group can pick up high priority cuda streams. 4747 group_desc (str, optional): A string describing the group.


def new_subgroups_by_enumeration(
    ranks_per_subgroup_list,
    timeout=None,
    backend=None,
    pg_options=None,
    group_desc=None,
):
    """
    Create subgroups by dividing the global world.

    The division is specified by a nested list of ranks. The subgroups cannot
    overlap, and not every rank has to be in a subgroup.

    This is a convenience API that calls ``new_group`` to generate multiple subgroups.
    It requires that all processes in the main group (i.e. all
    processes that are part of the distributed job) enter this function, even
    if they are not going to be members of the group.

    .. warning::
        See the `Safe concurrent usage` warning on the `new_group` API for important
        details about using multiple process groups concurrently in a safe manner.

    Args:
        ranks_per_subgroup_list (list[list[int]]): A nested list of ranks of
            group members.
        timeout (timedelta, optional): see `init_process_group` for details and default value.
        backend (str or Backend, optional): The backend to use. Depending on
            build-time configurations, valid values are ``gloo`` and ``nccl``.
            By default uses the same backend as the global group. This field
            should be given as a lowercase string (e.g., ``"gloo"``), which can
            also be accessed via :class:`Backend` attributes (e.g.,
            ``Backend.GLOO``). If ``None`` is passed in, the backend
            corresponding to the default process group will be used. Default is
            ``None``.
        pg_options (ProcessGroupOptions, optional): process group options
            specifying what additional options need to be passed in during
            the construction of specific process groups, i.e. for the ``nccl``
            backend, ``is_high_priority_stream`` can be specified so that
            the process group can pick up high-priority CUDA streams.
        group_desc (str, optional): A string describing the group. Each subgroup will
            inherit its ``group_desc``.

    Returns:
        The subgroup containing the current rank, and all the subgroups used for cleanup.

    Examples:
        >>> # Create two subgroups, where each has 2 processes.
        >>> # xdoctest: +SKIP("need process group init")
        >>> cur_subgroup, subgroups = dist.new_subgroups_by_enumeration([[0, 2], [1, 3]])
        >>> rank = dist.get_rank()
        >>> tensor = torch.ones(1, device=rank) * rank
        >>> dist.all_reduce(tensor, group=cur_subgroup)
        >>> tensor
        tensor([2])  # Subgroup 0: ranks 0 and 2
        tensor([4])  # Subgroup 1: ranks 1 and 3
    """
    if ranks_per_subgroup_list is None or len(ranks_per_subgroup_list) == 0:
        raise ValueError("The arg 'ranks_per_subgroup_list' cannot be empty")

    subgroups = []
    cur_subgroup = None
    # Create a mapping from rank to subgroup to check if there is any subgroup overlap.
    rank_to_ranks_dict = {}  # type: ignore[var-annotated]
    for ranks in ranks_per_subgroup_list:
        subgroup = new_group(
            ranks=ranks,
            timeout=timeout,
            backend=backend,
            pg_options=pg_options,
            group_desc=group_desc,
        )
        subgroups.append(subgroup)
        my_rank = get_rank()
        for rank in ranks:
            if rank in rank_to_ranks_dict:
                raise ValueError(
                    f"Rank {rank} has appeared in both subgroup {rank_to_ranks_dict[rank]} and {ranks}"
                )
            rank_to_ranks_dict[rank] = ranks
            if my_rank == rank:
                cur_subgroup = subgroup
                logger.info("Rank %s is assigned to subgroup %s", rank, ranks)

    return cur_subgroup, subgroups
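

# An illustrative sketch (not called anywhere in this module) mirroring the docstring
# example above: a 4-process job pairs even and odd global ranks. The rank layout and the
# CPU tensors are assumptions made for the example.
def _example_new_subgroups_by_enumeration():
    import torch
    import torch.distributed as dist

    cur_subgroup, _ = dist.new_subgroups_by_enumeration([[0, 2], [1, 3]])
    t = torch.ones(1) * dist.get_rank()
    dist.all_reduce(t, group=cur_subgroup)  # ranks 0/2 see 2.0, ranks 1/3 see 4.0
    return t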


def _find_pg_by_ranks_and_tag(tag: str, ranks: List[int]) -> Optional[ProcessGroup]:
    if len(tag) > 0 and not tag.startswith("ptd:") and not tag.startswith("user:"):
        tag = f"user:{tag}"

    for group in _world.tags_to_pg.get(tag, []):
        if group.size() != len(ranks):
            continue

        group_ranks = get_process_group_ranks(group)
        good = all(r in group_ranks for r in ranks)
        if good:
            return group
    return None


def _find_or_create_pg_by_ranks_and_tag(
    tag: str, ranks: List[int], stride: int
) -> ProcessGroup:
    assert (
        len(ranks) % stride == 0
    ), f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})"

    my_rank = get_rank()
    my_ranks = None

    if stride == len(ranks):
        my_ranks = ranks.copy()
        assert my_rank in my_ranks, "rankset doesn't include the current node"
    else:
        for i in range(0, len(ranks), stride):
            rank_set = ranks[i : i + stride]
            if my_rank in rank_set:
                my_ranks = rank_set
        assert my_ranks is not None, "rankset doesn't include the current node"

    my_ranks = sorted(my_ranks)

    pg = _find_pg_by_ranks_and_tag(tag, my_ranks)
    if pg is not None:
        return pg
    if tag == "":
        raise ValueError("Cannot automatically create PG with empty tag")
    # TODO copy settings and timeout from default PG
    return _new_group_with_tag(my_ranks, pg_tag=tag)
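

# A worked example of the stride handling in _find_or_create_pg_by_ranks_and_tag (a
# standalone sketch, not part of the public API): with ranks = [0, 1, ..., 7] and
# stride = 4, the global rank list is cut into [0, 1, 2, 3] and [4, 5, 6, 7], and each
# caller keeps only the slice containing its own rank before looking up the tagged group.
def _example_stride_slice(ranks, stride, my_rank):
    # Mirrors the slicing loop above; returns the sorted slice containing ``my_rank``.
    for i in range(0, len(ranks), stride):
        rank_set = ranks[i : i + stride]
        if my_rank in rank_set:
            return sorted(rank_set)
    return None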


def _get_group_tag(pg: ProcessGroup) -> str:
    """Return the tag associated with ``pg``."""
    tag = _world.pg_to_tag[pg]
    if tag.startswith("user:"):
        tag = tag[5:]
    return tag


def _get_process_group_name(pg: ProcessGroup) -> str:
    return _world.pg_names.get(pg, "None")


def _get_process_group_store(pg: ProcessGroup) -> Store:
    return _world.pg_map[pg][1]


# These ops are not friendly to TorchDynamo, so we disallow them in the FX graph
# and let them run eagerly under torch.compile.
dynamo_unsupported_distributed_c10d_ops = [
    recv,
    all_gather_object,
    all_gather_coalesced,
    all_to_all_single,
    all_reduce,
    gather_object,
    all_to_all,
    all_reduce_coalesced,
    gather,
    send_object_list,
    recv_object_list,
    broadcast_object_list,
    barrier,
    scatter,
    scatter_object_list,
    reduce,
    all_gather,
    reduce_scatter,
    all_gather_into_tensor,
    broadcast,
    reduce_scatter_tensor,
    send,
]
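

# An illustrative sketch (never executed at import time) of why the ops above are listed:
# under torch.compile, a function that calls one of these collectives is expected to fall
# back to eager execution for that call rather than being captured in the FX graph. The
# initialized process group and the CPU tensor are assumptions made for the example.
def _example_dynamo_unsupported_op():
    import torch
    import torch.distributed as dist

    @torch.compile
    def step(x):
        dist.all_reduce(x)  # listed in dynamo_unsupported_distributed_c10d_ops above
        return x * 2

    return step(torch.ones(4) * dist.get_rank())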