# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
# implement convolution related ops for distributed tensor
from typing import cast, Dict, List, Tuple

import torch
import torch.distributed as dist
import torch.distributed.tensor._api as dtensor


aten = torch.ops.aten


def _requires_data_exchange(padding):
    # TODO: whether data exchange is required is currently determined by padding alone
    return padding[1] != 0


def _is_supported(input_size, kernel_size, stride, padding, dilation):
    if dilation[1] != 1:
        raise RuntimeError("Dilation must be 1 for tensor parallel convolution.")
    if padding[1] != 0:
        if stride[1] != 1:
            raise RuntimeError(
                "Stride must be 1 when there is padding for tensor parallel convolution."
            )
        if kernel_size[3] // 2 > input_size[3]:
            raise RuntimeError(
                "kernel_size[3] // 2 should be less than or equal to input_size[3] for tensor parallel convolution."
            )
    else:
        if not (input_size[3] % stride[1] == 0 and stride[1] == kernel_size[3]):
            raise RuntimeError(
                "It requires that input_size[3] is divisible by stride[1] and stride[1] equals kernel_size[3] "
                "when there is no padding for tensor parallel convolution."
            )
    return True


def _ring_send_recv_construct(in_tensor, d1, d2, left, right, rank, size):
    # dist comms and reconstruct local input tensor
    send_to_right = in_tensor[:, :, :, -d1:].contiguous()
    send_to_left = in_tensor[:, :, :, :d2].contiguous()
    recv_from_right = torch.zeros_like(send_to_left)
    recv_from_left = torch.zeros_like(send_to_right)

    send_op_right = dist.P2POp(dist.isend, send_to_right, right)
    send_op_left = dist.P2POp(dist.isend, send_to_left, left)
    recv_op_right = dist.P2POp(dist.irecv, recv_from_right, right)
    recv_op_left = dist.P2POp(dist.irecv, recv_from_left, left)

    reqs = dist.batch_isend_irecv(
        [send_op_right, send_op_left, recv_op_left, recv_op_right]
    )
    for req in reqs:
        req.wait()

    if rank == 0:
        in_tensor = torch.cat([in_tensor, recv_from_right], dim=-1)
    elif rank == size - 1:
        in_tensor = torch.cat([recv_from_left, in_tensor], dim=-1)
    else:
        in_tensor = torch.cat([recv_from_left, in_tensor, recv_from_right], dim=-1)

    return in_tensor


def _ring_send_recv_aggregate(grad_in_tensor, d1, d2, left, right, rank, size):
    # dist comms and aggregate gradients for edge pixels
    send_to_right = grad_in_tensor[:, :, :, -d2:].contiguous()
    send_to_left = grad_in_tensor[:, :, :, :d1].contiguous()
    recv_from_right = torch.zeros_like(send_to_left)
    recv_from_left = torch.zeros_like(send_to_right)

    send_op_right = dist.P2POp(dist.isend, send_to_right, right)
    send_op_left = dist.P2POp(dist.isend, send_to_left, left)
    recv_op_right = dist.P2POp(dist.irecv, recv_from_right, right)
    recv_op_left = dist.P2POp(dist.irecv, recv_from_left, left)

    reqs = dist.batch_isend_irecv(
        [send_op_right, send_op_left, recv_op_left, recv_op_right]
    )
    for req in reqs:
        req.wait()

    if rank == 0:
        grad_in_tensor = grad_in_tensor[:, :, :, :-d2]
        grad_in_tensor[:, :, :, -d1:] = torch.add(
            grad_in_tensor[:, :, :, -d1:], recv_from_right
        )
    elif rank == size - 1:
        grad_in_tensor = grad_in_tensor[:, :, :, d1:]
        grad_in_tensor[:, :, :, :d2] = torch.add(
            grad_in_tensor[:, :, :, :d2], recv_from_left
        )
    else:
        grad_in_tensor = grad_in_tensor[:, :, :, d1:-d2]
        grad_in_tensor[:, :, :, -d1:] = torch.add(
            grad_in_tensor[:, :, :, -d1:], recv_from_right
        )
        grad_in_tensor[:, :, :, :d2] = torch.add(
            grad_in_tensor[:, :, :, :d2], recv_from_left
        )

    return grad_in_tensor


def tp_convolution(
    op_call: torch._ops.OpOverload,
    local_tensor_args: Tuple[object, ...],
    local_tensor_kwargs: Dict[str, object],
) -> object:
    """Convolution forward on width-sharded local tensors, with a halo exchange
    between neighboring ranks when padding makes shard boundaries overlap."""
    assert op_call == aten.convolution.default
    assert len(local_tensor_args) == 9

    rank = dist.get_rank()
    size = dist.get_world_size()
    in_tensor = cast(torch.Tensor, local_tensor_args[0])
    weight = cast(torch.Tensor, local_tensor_args[1])
    stride, padding, dilation = local_tensor_args[3:6]

    assert _is_supported(in_tensor.shape, weight.shape, stride, padding, dilation)
    assert isinstance(padding, List)

    if not _requires_data_exchange(padding):
        local_results = op_call(*local_tensor_args, **local_tensor_kwargs)
        return local_results
    else:
        # step 0 compute the overlap pixels of the input tensor
        d = weight.shape[3] - 1
        d1 = d // 2
        d2 = d - d1
        assert d1 + d2 == d
        right = (rank + 1) % size
        left = (rank - 1 + size) % size

        # step 1 reconstruct local input tensor
        in_tensor = _ring_send_recv_construct(
            in_tensor, d1, d2, left, right, rank, size
        )

        # step 2 feed local input tensor to op_call
        local_tensor_args_list = list(local_tensor_args)
        local_tensor_args_list[0] = in_tensor
        local_tensor_args = cast(Tuple[object, ...], local_tensor_args_list)
        local_results = op_call(*local_tensor_args, **local_tensor_kwargs)

        # step 3 remove extra outputs from the results
        padding_w = padding[1]
        w = local_results.size(3)
        if rank == 0:
            local_results = local_results[:, :, :, : w - padding_w]
        elif rank == size - 1:
            local_results = local_results[:, :, :, padding_w:]
        else:
            local_results = local_results[:, :, :, padding_w : w - padding_w]

        return local_results


def tp_convolution_backward(
    op_call: torch._ops.OpOverload,
    local_tensor_args: Tuple[object, ...],
    local_tensor_kwargs: Dict[str, object],
) -> object:
    """Convolution backward on width-sharded local tensors; reconstructs the halo
    region of the input, pads grad_output, and aggregates gradients for edge pixels."""
    assert op_call == aten.convolution_backward.default
    assert len(local_tensor_args) == 11

    rank = dist.get_rank()
    size = dist.get_world_size()
    grad_out_tensor = cast(torch.Tensor, local_tensor_args[0])
    in_tensor = cast(torch.Tensor, local_tensor_args[1])
    weight = cast(torch.Tensor, local_tensor_args[2])
    stride, padding, dilation = local_tensor_args[4:7]

    assert _is_supported(in_tensor.shape, weight.shape, stride, padding, dilation)
    assert isinstance(padding, List)

    if not _requires_data_exchange(padding):
        local_results = op_call(*local_tensor_args, **local_tensor_kwargs)
        return local_results
    else:
        # step 0 compute the overlap pixels of the input tensor
        d = weight.shape[3] - 1
        d1 = d // 2
        d2 = d - d1
        assert d1 + d2 == d
        right = (rank + 1) % size
        left = (rank - 1 + size) % size

        # step 1 reconstruct local input tensor
        in_tensor = _ring_send_recv_construct(
            in_tensor, d1, d2, left, right, rank, size
        )

        # step 2 reconstruct local gradient output tensor
        N, C_out, H_out, _ = grad_out_tensor.shape
        padding_w = padding[1]
        if rank == 0:
            grad_out_tensor = torch.nn.functional.pad(
                grad_out_tensor, (0, padding_w), "constant", 0
            )
        elif rank == size - 1:
            grad_out_tensor = torch.nn.functional.pad(
                grad_out_tensor, (padding_w, 0), "constant", 0
            )
        else:
            grad_out_tensor = torch.nn.functional.pad(
                grad_out_tensor, (padding_w, padding_w), "constant", 0
            )

        # step 3 feed local input tensor to op_call
        local_tensor_args_list = list(local_tensor_args)
        local_tensor_args_list[0] = grad_out_tensor
        local_tensor_args_list[1] = in_tensor
        local_tensor_args = cast(Tuple[object, ...], local_tensor_args_list)
        local_results = op_call(*local_tensor_args, **local_tensor_kwargs)

        # step 4 aggregate gradients for edge pixels
        grad_in_tensor = local_results[0]
        grad_in_tensor = _ring_send_recv_aggregate(
            grad_in_tensor, d1, d2, left, right, rank, size
        )

        local_results = list(local_results)
        local_results[0] = grad_in_tensor
        local_results = cast(Tuple[object, ...], local_results)

        return local_results


def convolution_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> object:
    """DTensor dispatch handler for aten.convolution.default."""
    # extract local tensor and sharding infos to an OpInfo
    op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)

    # sharding propagation
    dtensor.DTensor._op_dispatcher.sharding_propagator.propagate(op_info)
    output_sharding = op_info.output_sharding
    assert output_sharding is not None, "output sharding should not be None"

    # local propagation
    local_results = tp_convolution(
        op_call, tuple(op_info.local_args), op_info.local_kwargs
    )

    return dtensor.DTensor._op_dispatcher.wrap(
        local_results, output_sharding.output_spec
    )


def convolution_backward_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> object:
    """DTensor dispatch handler for aten.convolution_backward.default."""
    # Redistribute grad_output tensor to the same placement as input tensor
    args = list(args)
    assert isinstance(args[0], dtensor.DTensor) and isinstance(args[1], dtensor.DTensor)
    args[0] = args[0].redistribute(args[1].device_mesh, args[1].placements)
    args = tuple(args)

    # extract local tensor and sharding infos to an OpInfo
    op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)

    # sharding propagation
    dtensor.DTensor._op_dispatcher.sharding_propagator.propagate(op_info)
    output_sharding = op_info.output_sharding
    assert output_sharding is not None, "output sharding should not be None"

    # local propagation
    local_results = tp_convolution_backward(
        op_call, tuple(op_info.local_args), op_info.local_kwargs
    )

    return dtensor.DTensor._op_dispatcher.wrap(
        local_results, output_sharding.output_spec
    )
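

# Usage sketch (illustrative only, not part of this module's API): the handlers
# above are meant to be registered with DTensor's op dispatcher for
# aten.convolution.default / aten.convolution_backward.default. Assuming a
# 2-GPU device mesh and an input DTensor sharded along the width dim (dim 3),
# a forward/backward pass would route through them roughly as below; names such
# as `mesh` and `conv` and the tensor shapes are hypothetical.
#
#   import torch
#   from torch.distributed.device_mesh import init_device_mesh
#   from torch.distributed.tensor import distribute_tensor, Replicate, Shard
#
#   mesh = init_device_mesh("cuda", (2,))
#   conv = torch.nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1).cuda()
#   weight = distribute_tensor(conv.weight, mesh, [Replicate()])
#   bias = distribute_tensor(conv.bias, mesh, [Replicate()])
#   x = distribute_tensor(
#       torch.randn(4, 3, 32, 32, device="cuda"), mesh, [Shard(3)]
#   )
#   out = torch.nn.functional.conv2d(x, weight, bias, stride=1, padding=1)
#   out.sum().backward()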