1# mypy: allow-untyped-defs 2import logging 3import math 4from dataclasses import dataclass 5from functools import lru_cache 6from typing import List, Optional 7 8import torch 9import torch.distributed._functional_collectives as funcol 10import torch.distributed.tensor._dtensor_spec as dtensor_spec 11from torch._C._distributed_c10d import _resolve_process_group 12from torch.distributed.device_mesh import _mesh_resources, DeviceMesh 13from torch.distributed.distributed_c10d import ( 14 _get_group_size_by_name, 15 broadcast, 16 get_global_rank, 17 get_group_rank, 18 get_rank, 19 GroupMember, 20 ProcessGroup, 21 scatter, 22 Work, 23) 24 25 26logger = logging.getLogger(__name__) 27 28 29if not torch._running_with_deploy(): 30 31 @torch.library.register_fake("_dtensor::shard_dim_alltoall") 32 def _shard_dim_alltoall_meta(input, gather_dim, shard_dim, group_name): 33 group_size = _get_group_size_by_name(group_name) 34 stacked_list = [torch.empty_like(input) for _ in range(group_size)] 35 group = _resolve_process_group(group_name) 36 group_rank = get_group_rank(group, get_rank()) 37 38 return torch.cat(stacked_list, dim=gather_dim).chunk(group_size, dim=shard_dim)[ 39 group_rank 40 ] 41 42else: 43 import warnings 44 45 warnings.warn( 46 "PyTorch Distributed functional collectives do not work with torch::deploy." 47 ) 48 49 50def shard_dim_alltoall(input, gather_dim, shard_dim, mesh, mesh_dim): 51 if mesh.device_type == "cpu": 52 # Gloo does not support alltoall, so falling back to allgather + chunk 53 54 # TODO: This logs way too much 55 logger.warning( 56 "CPU process group does not support alltoall yet, falling back with allgather + chunk!" 57 ) 58 out = funcol.all_gather_tensor(input, gather_dim, (mesh, mesh_dim)) 59 if isinstance(out, funcol.AsyncCollectiveTensor): 60 # stick to the same behavior for the alltoall case, remove this once we enable alltoall async 61 out = out.wait() 62 out = torch.chunk(out, mesh.size(mesh_dim), dim=shard_dim)[ 63 mesh.get_local_rank(mesh_dim) 64 ] 65 return out.contiguous() if not out.is_contiguous() else out 66 67 group_name = funcol._resolve_group_name((mesh, mesh_dim)) 68 # TODO: enable async op for shard_dim_alltoall 69 return torch.ops._dtensor.shard_dim_alltoall( 70 input, gather_dim, shard_dim, group_name 71 ) 72 73 74def mesh_scatter( 75 output: torch.Tensor, 76 scatter_list: List[torch.Tensor], 77 mesh: DeviceMesh, 78 mesh_dim: int = 0, 79 async_op: bool = False, 80) -> Optional[Work]: 81 """ 82 scatter a list of tensors to a device mesh dimension. We by default 83 use the first rank of the mesh dimension as the source of truth, i.e 84 for a 2d mesh [[0, 1], [2, 3]], if we scatter on mesh_dim = 1, we will 85 scatter the tensor list on rank 0 to rank 0/1, and tensor list on rank 86 2 to rank 2/3. 87 88 Args: 89 output (torch.Tensor): the tensor to receive the scattered list. 90 scatter_list (List[torch.Tensor]): the tensor list to be scattered. 91 mesh_dim (int, optional): indicate which mesh dimension we want 92 to scatter on, we by default choose the first rank on the 93 mesh dimension as source of truth. 94 95 Returns: 96 A :class:`Work` object 97 """ 98 # TODO: Ideally we should use the meta tensor way 99 # (to register a meta kernel for the collective op) 100 # so that it would avoid the communication. Need to 101 # remove the check below once that is done. 102 if output.is_meta: 103 return None 104 dim_group = mesh.get_group(mesh_dim) 105 assert isinstance(dim_group, ProcessGroup) 106 # src need to be global rank 107 src_for_dim = 0 108 109 if dim_group is not GroupMember.WORLD: 110 src_for_dim = get_global_rank(dim_group, 0) 111 112 if src_for_dim == get_rank(): 113 fut = scatter( 114 output, 115 scatter_list=scatter_list, 116 src=src_for_dim, 117 group=dim_group, 118 async_op=async_op, 119 ) 120 else: 121 fut = scatter( 122 output, 123 scatter_list=None, 124 src=src_for_dim, 125 group=dim_group, 126 async_op=async_op, 127 ) 128 129 return fut 130 131 132def mesh_broadcast( 133 tensor: torch.Tensor, 134 mesh: DeviceMesh, 135 mesh_dim: int = 0, 136 async_op: bool = False, 137) -> Optional[Work]: 138 """ 139 broadcast the tensor to a device mesh dimension. We by default 140 use the first rank of the mesh dimension as the source of truth, i.e 141 for a 2d mesh [[0, 1], [2, 3]], if we broadcast on mesh_dim = 1, we will 142 broadcast the tensor on rank 0 to rank 0/1, and tensor on rank 2 143 to rank 2/3. 144 145 Args: 146 tensor (torch.Tensor): tensor to broadcast. 147 mesh_dim (int, optional): indicate which mesh dimension we want 148 to scatter on, we by default choose the first rank on the 149 mesh dimension as source of truth. 150 151 Returns: 152 A :class:`Work` object 153 """ 154 # TODO: Ideally we should use the meta tensor way 155 # (to register a meta kernel for the collective op) 156 # so that it would avoid the communication. Need to 157 # remove the check below once that is done. 158 if tensor.is_meta: 159 return None 160 dim_group = mesh.get_group(mesh_dim) 161 assert isinstance(dim_group, ProcessGroup) 162 # src need to be global rank 163 src_for_dim = 0 164 if dim_group is not GroupMember.WORLD: 165 src_for_dim = get_global_rank(dim_group, 0) 166 167 return broadcast(tensor, src=src_for_dim, group=dim_group, async_op=async_op) 168 169 170def pad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor: 171 if pad_size == 0: 172 return tensor 173 pad = [0, 0] * (tensor.ndim - pad_dim) 174 pad[-1] = pad_size 175 return torch.nn.functional.pad(tensor, pad) 176 177 178def unpad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor: 179 if pad_size == 0: 180 return tensor 181 return tensor.narrow( 182 pad_dim, 183 start=0, 184 length=tensor.size(pad_dim) - pad_size, 185 ) 186 187 188def fill_empty_tensor_to_shards( 189 shards: List[torch.Tensor], shard_dim: int, num_empty_tensors: int 190) -> List[torch.Tensor]: 191 if num_empty_tensors == 0: 192 return shards 193 tensor_size = list(shards[0].size()) 194 tensor_size = [ 195 size if idx != shard_dim else 0 for idx, size in enumerate(tensor_size) 196 ] 197 tensor = shards[0].new_zeros(tensor_size) 198 for _ in range(num_empty_tensors): 199 shards.append(tensor) 200 return shards 201 202 203def check_tensor_meta( 204 local_tensor, check_shape_stride=False 205) -> Optional["dtensor_spec.TensorMeta"]: 206 local_metadata = { 207 "dtype": local_tensor.dtype, 208 "requires_grad": local_tensor.requires_grad, 209 } 210 211 if check_shape_stride: 212 local_metadata.update( 213 {"shape": local_tensor.shape, "stride": local_tensor.stride()} 214 ) 215 216 gathered_metadata = [None for _ in range(torch.distributed.get_world_size())] 217 torch.distributed.all_gather_object(gathered_metadata, local_metadata) 218 219 # Check if metadata is consistent across ranks 220 if not all(meta == local_metadata for meta in gathered_metadata): 221 raise ValueError( 222 "Inconsistent tensor metadata (including shape and stride) across ranks." 223 ) 224 return None 225 226 227def spec_to_bytes(spec: "dtensor_spec.DTensorSpec") -> int: 228 assert spec.tensor_meta is not None, "spec should have tensor meta defined!" 229 return spec.tensor_meta.dtype.itemsize * math.prod(spec.shape) 230 231 232@dataclass 233class MeshTopoInfo: 234 """ 235 Mesh information for collective cost estimation 236 """ 237 238 mesh: DeviceMesh 239 mesh_dim_devices: List[int] 240 mesh_dim_bandwidth: List[float] 241 mesh_dim_latency: List[float] 242 243 @staticmethod 244 @lru_cache(None) 245 def build_from_mesh(mesh: DeviceMesh) -> "MeshTopoInfo": 246 # Generate mesh topology info for intra-host/inter-host communication pattern 247 # Note that we made bunch of assumptions for simplicity: 248 # 1. we assume the mesh is homogeneous, and it's gpu/nccl model 249 # 2. we assume gpu arch is Ampere or Hopper 250 # 3. we assume collectives are all ring base algo for now 251 num_devices_per_host = _mesh_resources.num_devices_per_host(mesh.device_type) 252 # the base bw number (intra-node), GB/s 253 base_bw = 87.7 254 mesh_dim_bandwidth = [base_bw] * mesh.ndim 255 # the latency in terms of us (intra-node, nv-link) 256 mesh_dim_latency = [0.6] * mesh.ndim 257 mesh_dim_devices = [1] * mesh.ndim 258 259 total_num_devices = 1 260 for mesh_dim in reversed(range(mesh.ndim)): 261 num_devices = mesh.size(mesh_dim) 262 mesh_dim_devices[mesh_dim] = num_devices 263 total_num_devices *= num_devices 264 if total_num_devices > num_devices_per_host: 265 # magic number for inter-host communication bandwidth/latency factor 266 # This number assumes latest GPU arch, i.e. Ampere or Hopper 267 # TODO: see if we need to tweak this or offer a way for user 268 # to specify the bandwidths/latency 269 mesh_dim_bandwidth[mesh_dim] *= 0.22 270 # set to ethernet latency for inter-host 271 mesh_dim_latency[mesh_dim] = 2.7 272 273 return MeshTopoInfo( 274 mesh, mesh_dim_devices, mesh_dim_bandwidth, mesh_dim_latency 275 ) 276 277 278def allgather_cost(bytes_gb: float, mesh_topo: MeshTopoInfo, mesh_dim: int) -> float: 279 num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim] 280 mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim] 281 num_hops = num_devices_on_mesh_dim - 1 282 # base latency + comm latency 283 latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim] # us 284 bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth # s 285 return latency + bw * 1e6 # rescale to us 286 287 288def allreduce_cost(bytes_gb: float, mesh_topo: MeshTopoInfo, mesh_dim: int) -> float: 289 num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim] 290 mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim] 291 # allreduce have almost 2x comm bytes compare to allgather/reduce_scatter 292 num_hops = 2 * num_devices_on_mesh_dim - 1 293 294 latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim] 295 bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth 296 return latency + bw * 1e6 297 298 299def reduce_scatter_cost( 300 bytes_gb: float, 301 mesh_topo: MeshTopoInfo, 302 mesh_dim: int, 303) -> float: 304 num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim] 305 mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim] 306 num_hops = num_devices_on_mesh_dim - 1 307 # base latency + comm latency 308 latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim] 309 bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth 310 return latency + bw * 1e6 311 312 313def redistribute_cost( 314 current_spec: "dtensor_spec.DTensorSpec", 315 target_spec: "dtensor_spec.DTensorSpec", 316) -> float: 317 """ 318 This function returns the cost of redistribute from current to target DTensorSpec. 319 320 NOTE: 321 1. Only consider communication cost here, since computation costs for redistribute 322 are quite trival (i.e. we only need to narrow or simple division) 323 2. Only consider redistribute cost on same mesh, cross mesh communication cost is 324 not quite needed for operator strategy estimation/selection. 325 """ 326 if current_spec.mesh != target_spec.mesh: 327 # make infinite cost if meshes are not same 328 # TODO: see if we want to support this once there's cross mesh communication 329 return float("inf") 330 331 if current_spec.is_replicated(): 332 # short-cut: 333 # comm cost is 0 if current spec is already full replication 334 return 0.0 335 336 mesh_topo = MeshTopoInfo.build_from_mesh(current_spec.mesh) 337 cost = 0.0 338 comm_bytes_gb = ( 339 spec_to_bytes(current_spec) / current_spec.num_shards / 1024 / 1024 / 1024 340 ) 341 # Transformation that considered for redistribute cost: 342 # 1. allgather 2. alltoall 343 # 3. allreduce 4. reduce_scatter 344 for i, (current, target) in enumerate( 345 zip(current_spec.placements, target_spec.placements) 346 ): 347 if current == target: 348 continue 349 350 num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[i] 351 if current.is_shard() and target.is_replicate(): 352 # allgather gives larger comm bytes 353 comm_bytes_gb *= num_devices_on_mesh_dim 354 # add up allgather comm cost 355 cost += allgather_cost(comm_bytes_gb, mesh_topo, i) 356 elif current.is_shard() and target.is_shard(): 357 # should be alltoall comm, since we haven't implement it yet, add penalty 358 # to favor allgather instead 359 cost += allgather_cost(comm_bytes_gb, mesh_topo, i) + 1.0 360 elif current.is_partial() and target.is_replicate(): 361 # add up allreduce comm cost 362 cost += allreduce_cost(comm_bytes_gb, mesh_topo, i) 363 elif current.is_partial() and target.is_shard(): 364 # add up reduce_scatter comm cost 365 cost += reduce_scatter_cost(comm_bytes_gb, mesh_topo, i) 366 # after reduce_scatter the comm bytes for further collectives halved. 367 comm_bytes_gb /= num_devices_on_mesh_dim 368 elif current.is_shard() and target.is_partial(): 369 # ban shard -> partial as it does not make sense to perform 370 # this redistribute 371 return float("inf") 372 373 return cost 374