# mypy: allow-untyped-defs
import collections
import warnings
from typing import Optional, Sequence, Union

import torch.cuda


__all__ = ["all_reduce", "reduce", "broadcast", "all_gather", "reduce_scatter"]

SUM = 0  # ncclRedOp_t


def is_available(tensors):
    """Return ``True`` if NCCL collectives can run over *tensors*.

    Requires a NCCL-enabled build, and that every tensor is a dense,
    contiguous CUDA tensor with no two tensors sharing a device.
    """
    if not hasattr(torch._C, "_nccl_all_reduce"):
        warnings.warn("PyTorch is not compiled with NCCL support")
        return False

    seen_devices = set()
    for t in tensors:
        # NCCL operates on dense, contiguous CUDA buffers only.
        if t.is_sparse or not t.is_contiguous() or not t.is_cuda:
            return False
        dev = t.get_device()
        # At most one tensor per device is allowed.
        if dev in seen_devices:
            return False
        seen_devices.add(dev)
    return True


def version():
    """Return the version of NCCL.

    Returns:
        tuple: ``(major, minor, patch)``, or ``(major, minor, patch, suffix)``
        when the build carries a non-empty version suffix.
    """
    packed = torch._C._nccl_version()
    major = packed >> 32
    minor = (packed >> 16) & 0xFFFF
    patch = packed & 0xFFFF
    suffix = torch._C._nccl_version_suffix().decode("utf-8")
    return (major, minor, patch) if not suffix else (major, minor, patch, suffix)


def unique_id():
    """Return a new NCCL unique id (used to bootstrap a communicator)."""
    return torch._C._nccl_unique_id()


def init_rank(num_ranks, uid, rank):
    """Initialize this process as *rank* out of *num_ranks* under the given *uid*."""
    return torch._C._nccl_init_rank(num_ranks, uid, rank)
def _check_sequence_type(inputs: Union[torch.Tensor, Sequence[torch.Tensor]]) -> None:
    """Raise ``TypeError`` unless *inputs* is a (non-tensor) collection."""
    not_container = not isinstance(inputs, collections.abc.Container)
    if not_container or isinstance(inputs, torch.Tensor):
        raise TypeError("Inputs should be a collection of tensors")


def all_reduce(inputs, outputs=None, op=SUM, streams=None, comms=None):
    """All-reduce *inputs* across devices with reduction *op* (default ``SUM``).

    When *outputs* is omitted, the reduction is written back into *inputs*.
    """
    _check_sequence_type(inputs)
    outputs = inputs if outputs is None else outputs
    _check_sequence_type(outputs)
    torch._C._nccl_all_reduce(inputs, outputs, op, streams, comms)


# `output` used to be `outputs`, taking in a list of tensors. So we have two
# arguments for BC reasons.
def reduce(
    inputs: Sequence[torch.Tensor],
    output: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]] = None,
    root: int = 0,
    op: int = SUM,
    streams: Optional[Sequence[torch.cuda.Stream]] = None,
    comms=None,
    *,
    outputs: Optional[Sequence[torch.Tensor]] = None,
) -> None:
    """Reduce *inputs* with *op* into a single tensor on device *root*.

    Args:
        inputs: One tensor per participating device.
        output: The tensor receiving the result on *root*. For backward
            compatibility a sequence of tensors is also accepted (deprecated);
            defaults to ``inputs[root]`` when omitted.
        root: Index of the device that receives the reduction.
        op: NCCL reduction op (default ``SUM``).
        streams: Optional per-device CUDA streams to run on.
        comms: Optional pre-initialized NCCL communicators.
        outputs: Deprecated keyword-only alias for the old list-of-tensors API.

    Raises:
        ValueError: If both ``output`` and ``outputs`` are given.
    """
    _check_sequence_type(inputs)
    _output: torch.Tensor
    if outputs is not None:
        if output is not None:
            raise ValueError(
                "'output' and 'outputs' can not be both specified. 'outputs' is deprecated in "
                "favor of 'output', taking in a single output tensor. The signature of reduce is: "
                "reduce(inputs, output=None, root=0, op=SUM, streams=None, comms=None)."
            )
        else:
            warnings.warn(
                "`nccl.reduce` with an output tensor list is deprecated. "
                # Fixed duplicated word ("instead instead") in the message.
                "Please specify a single output tensor with argument 'output' instead.",
                FutureWarning,
                stacklevel=2,
            )
            _output = outputs[root]
    elif not isinstance(output, torch.Tensor) and isinstance(
        output, collections.abc.Sequence
    ):
        # User called old API with positional arguments of list of output tensors.
        warnings.warn(
            "nccl.reduce with an output tensor list is deprecated. "
            "Please specify a single output tensor.",
            FutureWarning,
            stacklevel=2,
        )
        _output = output[root]
    else:
        _output = inputs[root] if output is None else output
    torch._C._nccl_reduce(inputs, _output, root, op, streams, comms)


def broadcast(
    inputs: Sequence[torch.Tensor], root: int = 0, streams=None, comms=None
) -> None:
    """Broadcast the tensor on device *root* to every tensor in *inputs*."""
    _check_sequence_type(inputs)
    torch._C._nccl_broadcast(inputs, root, streams, comms)
def all_gather(
    inputs: Sequence[torch.Tensor],
    outputs: Sequence[torch.Tensor],
    streams=None,
    comms=None,
) -> None:
    """Run NCCL all-gather from *inputs* into *outputs*."""
    for seq in (inputs, outputs):
        _check_sequence_type(seq)
    torch._C._nccl_all_gather(inputs, outputs, streams, comms)


def reduce_scatter(
    inputs: Sequence[torch.Tensor],
    outputs: Sequence[torch.Tensor],
    op: int = SUM,
    streams=None,
    comms=None,
) -> None:
    """Run NCCL reduce-scatter over *inputs* into *outputs* with reduction *op*."""
    for seq in (inputs, outputs):
        _check_sequence_type(seq)
    torch._C._nccl_reduce_scatter(inputs, outputs, op, streams, comms)