# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
# implement convolution related ops for distributed tensor
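# NOTE: the handlers in this file implement 2D convolution (and its backward)
# for DTensor inputs that are assumed to be sharded along the last (width)
# dimension. Whenever the convolution needs columns owned by a neighboring
# rank (i.e. the width padding is non-zero), boundary columns ("halos") are
# exchanged with the neighbors via point-to-point comms before running the
# local op, and the backward pass ships the corresponding gradient columns
# back to the ranks that own them.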
from typing import cast, Dict, List, Tuple

import torch
import torch.distributed as dist
import torch.distributed.tensor._api as dtensor


aten = torch.ops.aten


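# A rank only needs activations from its neighbors when the convolution pads
# along the width; with zero width padding, _is_supported guarantees that
# every output column is computed from columns of a single shard.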
def _requires_data_exchange(padding):
    # TODO: whether data exchange is required is currently determined only by padding
    return padding[1] != 0


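# Shape conventions below: input_size is the local shard's (N, C, H, W) shape,
# kernel_size is the weight's (C_out, C_in, kH, kW) shape, and index 1 of
# stride/padding/dilation is the width component.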
def _is_supported(input_size, kernel_size, stride, padding, dilation):
    if dilation[1] != 1:
        raise RuntimeError("Dilation must be 1 for tensor parallel convolution.")
    if padding[1] != 0:
        if stride[1] != 1:
            raise RuntimeError(
                "Stride must be 1 when there is padding for tensor parallel convolution."
            )
        if kernel_size[3] // 2 > input_size[3]:
            raise RuntimeError(
                "kernel_size[3] // 2 should be less than or equal to input_size[3] for tensor parallel convolution."
            )
    else:
        if not (input_size[3] % stride[1] == 0 and stride[1] == kernel_size[3]):
            raise RuntimeError(
                "It requires that input_size[3] is divisible by stride[1] and stride[1] equals kernel_size[3] "
                "when there is no padding for tensor parallel convolution."
            )
    return True


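# Halo exchange for the forward pass. With d = kernel_width - 1 split into
# d1 + d2, every rank sends its rightmost d1 columns to the right neighbor and
# its leftmost d2 columns to the left neighbor, then concatenates the received
# halos around its local shard:
#
#   rank 0:        [local | halo from rank 1]
#   middle ranks:  [halo from left | local | halo from right]
#   last rank:     [halo from left | local]
#
# The ring wraps around (rank 0 also exchanges with the last rank), but the
# wrapped-around halos are simply not used when concatenating.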
def _ring_send_recv_construct(in_tensor, d1, d2, left, right, rank, size):
    # dist comms and reconstruct local input tensor
    send_to_right = in_tensor[:, :, :, -d1:].contiguous()
    send_to_left = in_tensor[:, :, :, :d2].contiguous()
    recv_from_right = torch.zeros_like(send_to_left)
    recv_from_left = torch.zeros_like(send_to_right)

    send_op_right = dist.P2POp(dist.isend, send_to_right, right)
    send_op_left = dist.P2POp(dist.isend, send_to_left, left)
    recv_op_right = dist.P2POp(dist.irecv, recv_from_right, right)
    recv_op_left = dist.P2POp(dist.irecv, recv_from_left, left)

    reqs = dist.batch_isend_irecv(
        [send_op_right, send_op_left, recv_op_left, recv_op_right]
    )
    for req in reqs:
        req.wait()

    if rank == 0:
        in_tensor = torch.cat([in_tensor, recv_from_right], dim=-1)
    elif rank == size - 1:
        in_tensor = torch.cat([recv_from_left, in_tensor], dim=-1)
    else:
        in_tensor = torch.cat([recv_from_left, in_tensor, recv_from_right], dim=-1)

    return in_tensor


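# Inverse of the halo exchange, used by the backward pass: the gradient
# columns each rank computed for its neighbors' halos (the first d1 and last
# d2 columns of the widened gradient) are sent back to the owning ranks and
# added onto their local input gradients, then the halo columns are trimmed
# so the gradient matches the original local shard width again.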
def _ring_send_recv_aggregate(grad_in_tensor, d1, d2, left, right, rank, size):
    # dist comms and aggregate gradients for edge pixels
    send_to_right = grad_in_tensor[:, :, :, -d2:].contiguous()
    send_to_left = grad_in_tensor[:, :, :, :d1].contiguous()
    recv_from_right = torch.zeros_like(send_to_left)
    recv_from_left = torch.zeros_like(send_to_right)

    send_op_right = dist.P2POp(dist.isend, send_to_right, right)
    send_op_left = dist.P2POp(dist.isend, send_to_left, left)
    recv_op_right = dist.P2POp(dist.irecv, recv_from_right, right)
    recv_op_left = dist.P2POp(dist.irecv, recv_from_left, left)

    reqs = dist.batch_isend_irecv(
        [send_op_right, send_op_left, recv_op_left, recv_op_right]
    )
    for req in reqs:
        req.wait()

    if rank == 0:
        grad_in_tensor = grad_in_tensor[:, :, :, :-d2]
        grad_in_tensor[:, :, :, -d1:] = torch.add(
            grad_in_tensor[:, :, :, -d1:], recv_from_right
        )
    elif rank == size - 1:
        grad_in_tensor = grad_in_tensor[:, :, :, d1:]
        grad_in_tensor[:, :, :, :d2] = torch.add(
            grad_in_tensor[:, :, :, :d2], recv_from_left
        )
    else:
        grad_in_tensor = grad_in_tensor[:, :, :, d1:-d2]
        grad_in_tensor[:, :, :, -d1:] = torch.add(
            grad_in_tensor[:, :, :, -d1:], recv_from_right
        )
        grad_in_tensor[:, :, :, :d2] = torch.add(
            grad_in_tensor[:, :, :, :d2], recv_from_left
        )

    # the caller expects the trimmed, aggregated gradient back
    return grad_in_tensor


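# Forward convolution on a width-sharded local tensor. When the width padding
# is zero the plain local call is already correct; otherwise the local input
# is widened with halos from the neighbors, the padded convolution runs
# locally, and the output columns produced by padding at shard-internal
# boundaries (which would not exist in the single-device result) are sliced
# off afterwards.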
def tp_convolution(
    op_call: torch._ops.OpOverload,
    local_tensor_args: Tuple[object, ...],
    local_tensor_kwargs: Dict[str, object],
) -> object:
    assert op_call == aten.convolution.default
    assert len(local_tensor_args) == 9

    rank = dist.get_rank()
    size = dist.get_world_size()
    in_tensor = cast(torch.Tensor, local_tensor_args[0])
    weight = cast(torch.Tensor, local_tensor_args[1])
    stride, padding, dilation = local_tensor_args[3:6]

    assert _is_supported(in_tensor.shape, weight.shape, stride, padding, dilation)
    assert isinstance(padding, List)

    if not _requires_data_exchange(padding):
        local_results = op_call(*local_tensor_args, **local_tensor_kwargs)
        return local_results
    else:
        # step 0 compute the overlap pixels of the input tensor
        d = weight.shape[3] - 1
        d1 = d // 2
        d2 = d - d1
        assert d1 + d2 == d
        right = (rank + 1) % size
        left = (rank - 1 + size) % size

        # step 1 reconstruct local input tensor
        in_tensor = _ring_send_recv_construct(
            in_tensor, d1, d2, left, right, rank, size
        )

        # step 2 feed local input tensor to op_call
        local_tensor_args_list = list(local_tensor_args)
        local_tensor_args_list[0] = in_tensor
        local_tensor_args = cast(Tuple[object, ...], local_tensor_args_list)
        local_results = op_call(*local_tensor_args, **local_tensor_kwargs)

        # step 3 remove extra outputs from the results
        padding_w = padding[1]
        w = local_results.size(3)
        if rank == 0:
            local_results = local_results[:, :, :, : w - padding_w]
        elif rank == size - 1:
            local_results = local_results[:, :, :, padding_w:]
        else:
            local_results = local_results[:, :, :, padding_w : w - padding_w]

        return local_results


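# Backward convolution on width-sharded local tensors: grad_output is
# zero-padded at shard-internal boundaries to undo the trimming done in the
# forward pass, the input is rebuilt with the same halos as in the forward
# pass, the local convolution_backward runs, and the halo columns of
# grad_input are shipped back and accumulated on the ranks that own them.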
def tp_convolution_backward(
    op_call: torch._ops.OpOverload,
    local_tensor_args: Tuple[object, ...],
    local_tensor_kwargs: Dict[str, object],
) -> object:
    assert op_call == aten.convolution_backward.default
    assert len(local_tensor_args) == 11

    rank = dist.get_rank()
    size = dist.get_world_size()
    grad_out_tensor = cast(torch.Tensor, local_tensor_args[0])
    in_tensor = cast(torch.Tensor, local_tensor_args[1])
    weight = cast(torch.Tensor, local_tensor_args[2])
    stride, padding, dilation = local_tensor_args[4:7]

    assert _is_supported(in_tensor.shape, weight.shape, stride, padding, dilation)
    assert isinstance(padding, List)

    if not _requires_data_exchange(padding):
        local_results = op_call(*local_tensor_args, **local_tensor_kwargs)
        return local_results
    else:
        # step 0 compute the overlap pixels of the input tensor
        d = weight.shape[3] - 1
        d1 = d // 2
        d2 = d - d1
        assert d1 + d2 == d
        right = (rank + 1) % size
        left = (rank - 1 + size) % size

        # step 1 reconstruct local input tensor
        in_tensor = _ring_send_recv_construct(
            in_tensor, d1, d2, left, right, rank, size
        )

        # step 2 reconstruct local gradient output tensor
        N, C_out, H_out, _ = grad_out_tensor.shape
        padding_w = padding[1]
        if rank == 0:
            grad_out_tensor = torch.nn.functional.pad(
                grad_out_tensor, (0, padding_w), "constant", 0
            )
        elif rank == size - 1:
            grad_out_tensor = torch.nn.functional.pad(
                grad_out_tensor, (padding_w, 0), "constant", 0
            )
        else:
            grad_out_tensor = torch.nn.functional.pad(
                grad_out_tensor, (padding_w, padding_w), "constant", 0
            )

        # step 3 feed local input tensor to op_call
        local_tensor_args_list = list(local_tensor_args)
        local_tensor_args_list[0] = grad_out_tensor
        local_tensor_args_list[1] = in_tensor
        local_tensor_args = cast(Tuple[object, ...], local_tensor_args_list)
        local_results = op_call(*local_tensor_args, **local_tensor_kwargs)

        # step 4 aggregate gradients for edge pixels
        grad_in_tensor = local_results[0]
        grad_in_tensor = _ring_send_recv_aggregate(
            grad_in_tensor, d1, d2, left, right, rank, size
        )

        local_results = list(local_results)
        local_results[0] = grad_in_tensor
        local_results = cast(Tuple[object, ...], local_results)

        return local_results


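# DTensor dispatch glue: unwrap the DTensor arguments into local tensors plus
# sharding metadata, run sharding propagation to obtain the output spec, call
# the tensor-parallel convolution on the local shards, and wrap the local
# result back into a DTensor. These handlers are intended to be invoked by
# DTensor's op dispatcher for aten.convolution.default and
# aten.convolution_backward.default instead of the default dispatch path.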
def convolution_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> object:
    # extract local tensors and sharding info into an OpInfo
    op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)

    # sharding propagation
    dtensor.DTensor._op_dispatcher.sharding_propagator.propagate(op_info)
    output_sharding = op_info.output_sharding
    assert output_sharding is not None, "output sharding should not be None"

    # local computation
    local_results = tp_convolution(
        op_call, tuple(op_info.local_args), op_info.local_kwargs
    )

    return dtensor.DTensor._op_dispatcher.wrap(
        local_results, output_sharding.output_spec
    )


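# The backward handler first redistributes grad_output to the input's
# placements, since tp_convolution_backward assumes grad_output and the input
# are sharded identically along the width dimension.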
def convolution_backward_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> object:
    # Redistribute grad_output tensor to the same placement as input tensor
    args = list(args)
    assert isinstance(args[0], dtensor.DTensor) and isinstance(args[1], dtensor.DTensor)
    args[0] = args[0].redistribute(args[1].device_mesh, args[1].placements)
    args = tuple(args)

    # extract local tensors and sharding info into an OpInfo
    op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)

    # sharding propagation
    dtensor.DTensor._op_dispatcher.sharding_propagator.propagate(op_info)
    output_sharding = op_info.output_sharding
    assert output_sharding is not None, "output sharding should not be None"

    # local computation
    local_results = tp_convolution_backward(
        op_call, tuple(op_info.local_args), op_info.local_kwargs
    )

    return dtensor.DTensor._op_dispatcher.wrap(
        local_results, output_sharding.output_spec
    )

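# Illustrative (untested) sketch of how these handlers get exercised, assuming
# a launched job with `world_size` ranks and shapes accepted by _is_supported;
# the names mesh/x/w/b/world_size are placeholders, not part of this module:
#
#   import torch.nn.functional as F
#   from torch.distributed.device_mesh import init_device_mesh
#   from torch.distributed.tensor import Replicate, Shard, distribute_tensor
#
#   mesh = init_device_mesh("cuda", (world_size,))
#   # input sharded along the width dimension (dim 3), weight/bias replicated
#   x = distribute_tensor(torch.randn(8, 3, 32, 32), mesh, [Shard(3)])
#   w = distribute_tensor(torch.randn(16, 3, 3, 3), mesh, [Replicate()])
#   b = distribute_tensor(torch.randn(16), mesh, [Replicate()])
#
#   # F.conv2d on DTensors dispatches aten.convolution.default, which the
#   # DTensor dispatcher can route to convolution_handler above; the backward
#   # pass goes through convolution_backward_handler in the same way.
#   y = F.conv2d(x, w, b, padding=1)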