# xref: /aosp_15_r20/external/pytorch/torch/nn/parallel/_functions.py
# (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
import warnings
from typing import List, Optional

import torch
from torch._utils import _get_device_index
from torch.autograd import Function
from torch.nn.parallel import comm


class Broadcast(Function):
    """Broadcast the given CUDA tensors to a set of target GPUs.

    The backward pass reduce-adds the gradients from every target GPU back
    onto the original input device via ``ReduceAddCoalesced``.
    """

    @staticmethod
    def forward(ctx, target_gpus, *inputs):
        assert all(
            i.device.type != "cpu" for i in inputs
        ), "Broadcast function not implemented for CPU tensors"
        target_gpus = [_get_device_index(x, True) for x in target_gpus]
        ctx.target_gpus = target_gpus
        if len(inputs) == 0:
            return ()
        ctx.num_inputs = len(inputs)
        ctx.input_device = inputs[0].get_device()
        outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus)
        # Copies of inputs that do not require grad are marked as
        # non-differentiable so autograd skips them in backward.
        non_differentiables = []
        for idx, input_requires_grad in enumerate(ctx.needs_input_grad[1:]):
            if not input_requires_grad:
                for output in outputs:
                    non_differentiables.append(output[idx])
        ctx.mark_non_differentiable(*non_differentiables)
        # Flatten the per-device lists into one tuple: all inputs broadcast to
        # the first target GPU come first, then the second, and so on.
        return tuple([t for tensors in outputs for t in tensors])

    @staticmethod
    def backward(ctx, *grad_outputs):
        return (None,) + ReduceAddCoalesced.apply(
            ctx.input_device, ctx.num_inputs, *grad_outputs
        )


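# Usage sketch (not part of the original file): ``Broadcast.apply`` takes the
# list of target devices first, then the tensors, and returns the replicas
# flattened into one tuple (all replicas for the first device, then the second,
# and so on). Assuming a machine with at least two CUDA devices:
#
#     x = torch.randn(4, device="cuda:0", requires_grad=True)
#     x0, x1 = Broadcast.apply([0, 1], x)
#     (x0.sum() + x1.sum()).backward()  # grads reduce-add back onto cuda:0

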
class ReduceAddCoalesced(Function):
    """Sum groups of gradients from several GPUs onto a destination device.

    ``grads`` is a flat sequence laid out as consecutive groups of
    ``num_inputs`` tensors, one group per source GPU.
    """

    @staticmethod
    def forward(ctx, destination, num_inputs, *grads):
        # Remember which GPU each group came from so backward can broadcast
        # the incoming gradients back to those devices.
        ctx.target_gpus = [
            grads[i].get_device() for i in range(0, len(grads), num_inputs)
        ]

        grads_ = [grads[i : i + num_inputs] for i in range(0, len(grads), num_inputs)]
        return comm.reduce_add_coalesced(grads_, destination)

    @staticmethod
    def backward(ctx, *grad_outputs):
        return (
            None,
            None,
        ) + Broadcast.apply(ctx.target_gpus, *grad_outputs)


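# Usage sketch (not part of the original file): reduce-add one tensor per GPU
# onto device 0. The flat argument list is grouped as ``num_inputs`` tensors
# per source device; here ``num_inputs`` is 1 and two CUDA devices are assumed.
#
#     a = torch.ones(3, device="cuda:0")
#     b = torch.ones(3, device="cuda:1")
#     (summed,) = ReduceAddCoalesced.apply(0, 1, a, b)  # tensor of 2s on cuda:0

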
class Gather(Function):
    """Concatenate tensors from several GPUs onto a single target device."""

    @staticmethod
    def forward(ctx, target_device, dim, *inputs):
        assert all(
            i.device.type != "cpu" for i in inputs
        ), "Gather function not implemented for CPU tensors"
        if target_device == "cpu":
            ctx.target_device = "cpu"
        else:
            target_device = _get_device_index(target_device, True)
            ctx.target_device = target_device
        ctx.dim = dim
        ctx.input_gpus = tuple(i.get_device() for i in inputs)
        if all(t.dim() == 0 for t in inputs) and dim == 0:
            # Zero-dimensional tensors cannot be concatenated; promote them to
            # one-element vectors here and strip that dimension in backward.
            inputs = tuple(t.view(1) for t in inputs)
            warnings.warn(
                "Was asked to gather along dimension 0, but all "
                "input tensors were scalars; will instead unsqueeze "
                "and return a vector."
            )
            ctx.unsqueezed_scalar = True
        else:
            ctx.unsqueezed_scalar = False
        # Record each input's size along ``dim`` so backward can split the
        # gradient into matching chunks.
        ctx.input_sizes = tuple(i.size(ctx.dim) for i in inputs)
        return comm.gather(inputs, ctx.dim, ctx.target_device)

    @staticmethod
    def backward(ctx, grad_output):
        scattered_grads = Scatter.apply(
            ctx.input_gpus, ctx.input_sizes, ctx.dim, grad_output
        )
        if ctx.unsqueezed_scalar:
            scattered_grads = tuple(g[0] for g in scattered_grads)
        return (None, None) + scattered_grads


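# Usage sketch (not part of the original file): gather per-GPU chunks onto one
# device along dim 0, assuming two CUDA devices.
#
#     a = torch.randn(2, 5, device="cuda:0")
#     b = torch.randn(3, 5, device="cuda:1")
#     out = Gather.apply("cuda:0", 0, a, b)  # shape (5, 5) on cuda:0

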
class Scatter(Function):
    """Split a tensor along ``dim`` and distribute the chunks across target GPUs."""

    @staticmethod
    def forward(ctx, target_gpus, chunk_sizes, dim, input):
        target_gpus = [_get_device_index(x, True) for x in target_gpus]
        ctx.dim = dim
        ctx.input_device = input.get_device() if input.device.type != "cpu" else -1
        streams = None
        if torch.cuda.is_available() and ctx.input_device == -1:
            # Perform CPU to GPU copies in a background stream
            streams = [
                _get_stream(torch.device("cuda", device)) for device in target_gpus
            ]
        outputs = comm.scatter(input, target_gpus, chunk_sizes, ctx.dim, streams)
        # Synchronize with the copy stream
        if streams is not None:
            for i, output in enumerate(outputs):
                with torch.cuda.device(target_gpus[i]):
                    main_stream = torch.cuda.current_stream()
                    main_stream.wait_stream(streams[i])
                    # Keep the copied chunk's memory alive until the main
                    # stream has consumed it.
                    output.record_stream(main_stream)
        return outputs

    @staticmethod
    def backward(ctx, *grad_output):
        return None, None, None, Gather.apply(ctx.input_device, ctx.dim, *grad_output)


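# Usage sketch (not part of the original file): scatter a CPU tensor across two
# CUDA devices along dim 0; ``chunk_sizes=None`` splits as evenly as possible.
#
#     x = torch.randn(8, 4)
#     x0, x1 = Scatter.apply([0, 1], None, 0, x)  # (4, 4) on cuda:0 and cuda:1

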
# background streams used for copying
_streams: Optional[List[Optional[torch.Stream]]] = None


def _get_stream(device: torch.device):
    """Get a background stream for copying between CPU and target device."""
    global _streams
    if device.type == "cpu":
        return None
    device_mod = getattr(torch, device.type, None)
    if device_mod is None:
        return None
    if _streams is None:
        _streams = [None] * device_mod.device_count()
    if _streams[device.index] is None:
        _streams[device.index] = device_mod.Stream(device.index)
    return _streams[device.index]
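

# Usage sketch (not part of the original file): streams are created lazily, one
# per device index, and the same object is returned on later calls. Assuming a
# machine with CUDA device 0:
#
#     s = _get_stream(torch.device("cuda", 0))  # a torch.cuda.Stream on device 0
#     assert _get_stream(torch.device("cuda", 0)) is s
#     assert _get_stream(torch.device("cpu")) is None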