import warnings
from typing import List, Optional

import torch
from torch._utils import _get_device_index
from torch.autograd import Function
from torch.nn.parallel import comm


class Broadcast(Function):
    """Broadcasts tensors to the target GPUs; backward reduce-adds the
    gradients from all copies back onto the source device."""

    @staticmethod
    def forward(ctx, target_gpus, *inputs):
        assert all(
            i.device.type != "cpu" for i in inputs
        ), "Broadcast function not implemented for CPU tensors"
        target_gpus = [_get_device_index(x, True) for x in target_gpus]
        ctx.target_gpus = target_gpus
        if len(inputs) == 0:
            return ()
        ctx.num_inputs = len(inputs)
        ctx.input_device = inputs[0].get_device()
        outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus)
        # Mark every copy of an input that does not require grad as
        # non-differentiable (ctx.needs_input_grad[0] is target_gpus).
        non_differentiables = []
        for idx, input_requires_grad in enumerate(ctx.needs_input_grad[1:]):
            if not input_requires_grad:
                for output in outputs:
                    non_differentiables.append(output[idx])
        ctx.mark_non_differentiable(*non_differentiables)
        return tuple([t for tensors in outputs for t in tensors])

    @staticmethod
    def backward(ctx, *grad_outputs):
        return (None,) + ReduceAddCoalesced.apply(
            ctx.input_device, ctx.num_inputs, *grad_outputs
        )


class ReduceAddCoalesced(Function):
    """Sums groups of tensors living on different GPUs onto ``destination``;
    backward broadcasts the gradients back to those GPUs."""

    @staticmethod
    def forward(ctx, destination, num_inputs, *grads):
        ctx.target_gpus = [
            grads[i].get_device() for i in range(0, len(grads), num_inputs)
        ]

        # Regroup the flat grads into one group of num_inputs tensors per GPU.
        grads_ = [grads[i : i + num_inputs] for i in range(0, len(grads), num_inputs)]
        return comm.reduce_add_coalesced(grads_, destination)

    @staticmethod
    def backward(ctx, *grad_outputs):
        return (
            None,
            None,
        ) + Broadcast.apply(ctx.target_gpus, *grad_outputs)


class Gather(Function):
    """Concatenates tensors from multiple GPUs onto ``target_device`` along
    ``dim``; backward scatters the gradient back to the source GPUs."""

    @staticmethod
    def forward(ctx, target_device, dim, *inputs):
        assert all(
            i.device.type != "cpu" for i in inputs
        ), "Gather function not implemented for CPU tensors"
        if target_device == "cpu":
            ctx.target_device = "cpu"
        else:
            target_device = _get_device_index(target_device, True)
            ctx.target_device = target_device
        ctx.dim = dim
        ctx.input_gpus = tuple(i.get_device() for i in inputs)
        if all(t.dim() == 0 for t in inputs) and dim == 0:
            inputs = tuple(t.view(1) for t in inputs)
            warnings.warn(
                "Was asked to gather along dimension 0, but all "
                "input tensors were scalars; will instead unsqueeze "
                "and return a vector."
            )
            ctx.unsqueezed_scalar = True
        else:
            ctx.unsqueezed_scalar = False
        ctx.input_sizes = tuple(i.size(ctx.dim) for i in inputs)
        return comm.gather(inputs, ctx.dim, ctx.target_device)

    @staticmethod
    def backward(ctx, grad_output):
        scattered_grads = Scatter.apply(
            ctx.input_gpus, ctx.input_sizes, ctx.dim, grad_output
        )
        if ctx.unsqueezed_scalar:
            scattered_grads = tuple(g[0] for g in scattered_grads)
        return (None, None) + scattered_grads


class Scatter(Function):
    """Splits ``input`` into chunks placed on the target GPUs along ``dim``;
    backward gathers the chunk gradients back onto the input device."""

    @staticmethod
    def forward(ctx, target_gpus, chunk_sizes, dim, input):
        target_gpus = [_get_device_index(x, True) for x in target_gpus]
        ctx.dim = dim
        ctx.input_device = input.get_device() if input.device.type != "cpu" else -1
        streams = None
        if torch.cuda.is_available() and ctx.input_device == -1:
            # Perform CPU to GPU copies in a background stream
            streams = [
                _get_stream(torch.device("cuda", device)) for device in target_gpus
            ]
        outputs = comm.scatter(input, target_gpus, chunk_sizes, ctx.dim, streams)
        # Synchronize with the copy stream
        if streams is not None:
            for i, output in enumerate(outputs):
                with torch.cuda.device(target_gpus[i]):
                    main_stream = torch.cuda.current_stream()
                    main_stream.wait_stream(streams[i])
                    output.record_stream(main_stream)
        return outputs

    @staticmethod
    def backward(ctx, *grad_output):
        return None, None, None, Gather.apply(ctx.input_device, ctx.dim, *grad_output)


# background streams used for copying
_streams: Optional[List[Optional[torch.Stream]]] = None


def _get_stream(device: torch.device):
    """Get a background stream for copying between CPU and target device."""
    global _streams
    if device.type == "cpu":
        return None
    device_mod = getattr(torch, device.type, None)
    if device_mod is None:
        return None
    if _streams is None:
        _streams = [None] * device_mod.device_count()
    if _streams[device.index] is None:
        _streams[device.index] = device_mod.Stream(device.index)
    return _streams[device.index]
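

# ---------------------------------------------------------------------------
# Minimal usage sketch of the autograd primitives above, assuming a host with
# at least two CUDA devices; the shapes and the `* 2` computation are
# illustrative only. Scatter splits a CPU batch across GPUs and Gather
# concatenates the per-GPU results, so gradients flow back to the CPU input;
# Broadcast replicates a tensor onto each device and its backward reduce-adds
# the per-replica gradients onto the source device.
if __name__ == "__main__":
    if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
        # Scatter a CPU batch across GPUs 0 and 1 along dim 0
        # (chunk_sizes=None lets comm.scatter split into equal chunks).
        x = torch.randn(8, 4, requires_grad=True)
        x0, x1 = Scatter.apply([0, 1], None, 0, x)
        # Any per-device computation would go here; then gather onto GPU 0.
        y = Gather.apply(0, 0, x0 * 2, x1 * 2)
        y.sum().backward()  # grads route Gather -> Scatter back to `x`
        assert torch.allclose(x.grad, torch.full_like(x, 2.0))

        # Broadcast a tensor to both GPUs; a loss touching every replica
        # reduce-adds the per-replica gradients back onto cuda:0.
        w = torch.randn(4, device="cuda:0", requires_grad=True)
        w0, w1 = Broadcast.apply([0, 1], w)
        loss = w0.sum() + w1.sum().to("cuda:0")
        loss.backward()
        assert torch.allclose(w.grad, torch.full_like(w, 2.0))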