from typing import cast, List, Sequence, Tuple

import torch
import torch.distributed.tensor._api as dtensor
from torch._prims_common import ShapeType
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor.placement_types import (
    _StridedShard,
    Partial,
    Placement,
    Replicate,
    Shard,
)


# TODO: audit existing code base to see if we can safely remove this API.
def compute_local_shape(
    global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement]
) -> Tuple[int, ...]:
    """
    Compute the shape of a local shard of the given DTensor on its current
    coordinate of the mesh.
    """
    my_coordinate = mesh.get_coordinate()

    if my_coordinate is None:
        # if rank is not in the mesh, return an empty shape
        return (0,)
    else:
        local_shape = list(global_shape)  # start with the global shape
        ndim = len(global_shape)
        for idx, placement in enumerate(placements):
            mesh_dim_size = mesh.size(idx)
            if isinstance(placement, Shard):
                shard_dim = placement.dim
                assert (
                    shard_dim < ndim
                ), f"Sharding dim {shard_dim} greater than tensor ndim {ndim}"
                local_shard_size, _ = placement._local_shard_size_on_dim(
                    local_shape[shard_dim], mesh_dim_size, my_coordinate[idx]
                )
                assert isinstance(local_shard_size, int)
                local_shape[shard_dim] = local_shard_size

        return tuple(local_shape)
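

# Example (a minimal sketch, not part of this module's API): assuming a process
# group is already initialized and we build a hypothetical 1-D mesh of 4 ranks,
# e.g. ``mesh = torch.distributed.device_mesh.init_device_mesh("cuda", (4,))``:
#
#   compute_local_shape((8, 3), mesh, [Shard(0)])
#
# returns ``(2, 3)`` on every rank, since dim-0 of size 8 is split evenly into
# 4 shards of 2 rows each. For uneven sizes, Shard follows ``torch.chunk``
# semantics, so later ranks may receive smaller (possibly empty) shards.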


def compute_local_shape_and_global_offset(
    global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement]
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    """
    Compute the local tensor shape and the global offsets into the original tensor
    of a DTensor on its current global rank. This is useful for checkpointing purposes.

    Example (2 hosts with 4 GPUs each):
    # Below is a DeviceMesh with mesh_shape of (2, 4)
    mesh = DeviceMesh(
        device_type="cuda",
        mesh=[
            [0, 1, 2, 3],
            [4, 5, 6, 7],
        ],
    )

    Let's say we distribute a global_tensor of shape (8, 4) over the above DeviceMesh
    with placements of [Shard(0), Shard(0)].
    The local shape and global offset will be as follows:
    rank0 -- local_shape:[1, 4], global_offset:[0, 0]
    rank1 -- local_shape:[1, 4], global_offset:[1, 0]
    rank2 -- local_shape:[1, 4], global_offset:[2, 0]
    rank3 -- local_shape:[1, 4], global_offset:[3, 0]
    rank4 -- local_shape:[1, 4], global_offset:[4, 0]
    rank5 -- local_shape:[1, 4], global_offset:[5, 0]
    rank6 -- local_shape:[1, 4], global_offset:[6, 0]
    rank7 -- local_shape:[1, 4], global_offset:[7, 0]

    Let's say we distribute a global_tensor of shape (2,) over the above DeviceMesh
    with placements of [Shard(0)]. Not all ranks will have a non-empty local tensor.
    The local shape and global offset will be as follows:
    rank0 -- local_shape:[1,], global_offset:[0,]
    rank1 -- local_shape:[1,], global_offset:[1,]
    rank2 -- local_shape:[0,], global_offset:[2,]
    rank3 -- local_shape:[0,], global_offset:[2,]
    rank4 -- local_shape:[0,], global_offset:[2,]
    rank5 -- local_shape:[0,], global_offset:[2,]
    rank6 -- local_shape:[0,], global_offset:[2,]
    rank7 -- local_shape:[0,], global_offset:[2,]
    """
    my_coordinate = mesh.get_coordinate()

    if my_coordinate is None:
        # if rank is not in the mesh, return empty shape and offset
        return ((), ())
    else:
        local_shape = list(global_shape)
        global_offset = [0] * len(global_shape)
        shard_idx_stride_by_mesh_dim = [
            [0] * mesh.ndim for _ in range(len(global_shape))
        ]  # index by (shard_dim, mesh_dim)
        num_shards_by_tensor_dim = [1] * len(global_shape)

        for idx, placement in enumerate(placements):
            mesh_dim_size = mesh.size(idx)
            if isinstance(placement, Shard):
                shard_dim = placement.dim
                local_offset = [0] * len(global_shape)
                assert shard_dim < len(
                    local_shape
                ), f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}"
                shard_size, shard_offset = placement._local_shard_size_on_dim(
                    local_shape[shard_dim],
                    mesh_dim_size,
                    my_coordinate[idx],
                    return_offset=True,
                )

                local_shape[shard_dim] = shard_size
                local_offset[shard_dim] = shard_offset

                # On a given dimension, if local_offset[shard_dim] is smaller than
                # global_offset[shard_dim], it means that this dimension has already
                # been sharded by a previous placement. In that case we cannot simply
                # replace global_offset[shard_dim] with local_offset[shard_dim];
                # instead, we need to add local_offset[shard_dim] to the existing
                # global_offset[shard_dim].
                if global_offset[shard_dim] <= local_offset[shard_dim]:
                    global_offset[shard_dim] = local_offset[shard_dim]
                else:
                    global_offset[shard_dim] += local_offset[shard_dim]

                num_shards_by_tensor_dim[shard_dim] *= mesh_dim_size

        # NOTE: the offset computation relies on the local shard index, which is
        # unproblematic when strided sharding is not present. To compute it correctly,
        # we assume that the ``_StridedShard.split_factor`` field encodes how many
        # partitions each local tensor will be further split into when sharding on
        # higher mesh dimensions. However, this number is only correct if the DTensor
        # is not sharded after the strided sharding completes. For example,
        # [Shard(0), _StridedShard(0, split_factor=2), Shard(0)] is the placements
        # where the DTensor's dim-0 is first sharded on device mesh dim-0, then on
        # device mesh dim-2, and last on mesh dim-1. We define the
        # "_StridedShard(0, split_factor=2), Shard(0)" part as the strided sharding
        # part because strided sharding happens on mesh dim-1 and it was caused by
        # the fact that sharding on mesh dim-2 occurred ahead of it. In this case,
        # there is no further sharding after this strided sharding part and
        # ``split_factor`` correctly encodes the number. Another example is
        # [_StridedShard(0, split_factor=2), Shard(0), Shard(0)] where the DTensor's
        # dim-0 is first sharded on mesh dim-1, then on mesh dim-0, and last on mesh
        # dim-2. This violates our assumption that no further sharding shall occur
        # after the strided sharding part, and ``split_factor`` won't correctly
        # encode the number of further splits. So far, the only case where a
        # _StridedShard placement would appear is FSDP2 + TP on a 2D mesh, and the
        # above case could only happen on a mesh of 3 or more dimensions.
        # TODO: change this function to correctly address this.
        # TODO: this logic can be applied to contiguous sharding as well
        strided_sharding = any(isinstance(p, _StridedShard) for p in placements)
        if strided_sharding:
            strided_part_seen = [False] * len(global_shape)
            strided_part_end = [False] * len(global_shape)
            for idx, placement in enumerate(placements):
                mesh_dim_size = mesh.size(idx)
                if isinstance(placement, Shard):
                    shard_dim = placement.dim

                    if strided_part_end[shard_dim]:
                        raise NotImplementedError(
                            f"Strided sharding does not allow Shard() to appear after "
                            f"the strided part has ended. {placement} at idx {idx} in "
                            f"{placements} violates this assumption."
                        )

                    if strided_part_seen[shard_dim]:
                        strided_part_end[shard_dim] = True

                    if isinstance(placement, _StridedShard):
                        strided_part_seen[shard_dim] = True
                        shard_idx_stride_by_mesh_dim[shard_dim][
                            idx
                        ] = num_shards_by_tensor_dim[shard_dim] // (
                            placement.split_factor * mesh_dim_size
                        )
                    else:
                        num_shards_by_tensor_dim[shard_dim] //= mesh_dim_size
                        shard_idx_stride_by_mesh_dim[shard_dim][
                            idx
                        ] = num_shards_by_tensor_dim[shard_dim]

            shard_idx = [
                sum([x * y for x, y in zip(shard_idx_stride, my_coordinate)])
                for shard_dim, shard_idx_stride in enumerate(
                    shard_idx_stride_by_mesh_dim
                )
            ]

            global_offset = [x * y for x, y in zip(local_shape, shard_idx)]

        return tuple(local_shape), tuple(global_offset)
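

# Worked example for the strided-sharding path above (a sketch; the 2x2 mesh,
# tensor size, and split_factor are illustrative assumptions mirroring the
# FSDP2 + TP layout, not values taken from this module). Consider a global dim-0
# of size 4 on a (dp=2, tp=2) mesh with placements
# [_StridedShard(0, split_factor=2), Shard(0)]: dim-0 is logically sharded first
# on mesh dim-1 (tp), then on mesh dim-0 (dp).
#
#   num_shards_by_tensor_dim[0] = 2 * 2 = 4
#   mesh dim-0 (_StridedShard): stride = 4 // (split_factor=2 * mesh_dim_size=2) = 1
#   mesh dim-1 (Shard):         num_shards becomes 4 // 2 = 2, so stride = 2
#   shard_idx = 1 * coord[0] + 2 * coord[1]
#
# With local_shape[0] == 1 on every rank, global_offset[0] equals shard_idx, i.e.
# coordinates (0, 0), (0, 1), (1, 0), (1, 1) map to offsets 0, 2, 1, 3 -- the
# strided layout rather than the contiguous 0, 1, 2, 3.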


def compute_global_tensor_info(
    tensor: torch.Tensor, mesh: DeviceMesh, placements: Sequence[Placement]
) -> Tuple[List[int], List[int]]:
    """
    Compute the global size and stride of a DTensor from the given local tensor.
    The local size is multiplied by `world_size` on each sharding dim.
    The local stride is also multiplied by `world_size` on every dimension whose
    stride is larger than or equal to the stride of the sharding dim (i.e. the
    dimensions "outside" the sharding dim).

    For example, if we have a local tensor with size (4, 8, 2) and stride (16, 1, 8),
    and the DTensor placements are [Shard(2)] with world_size 2,
    then the global size is (4, 8, 4) and the stride is (16 * 2, 1, 8).

    Args:
        tensor (:class:`torch.Tensor`):
            Local tensor which DTensor will be constructed from.
        mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology
            of devices for the DTensor.
        placements (Sequence[:class:`Placement`]):
            The attribute of the DTensor that describes its layout
            on the mesh topology.

    Return:
        tensor_shape: A List of int which specifies the size of the DTensor built
            on top of the local tensor.
        tensor_stride: A List of int which specifies the stride of the DTensor.
    """
    tensor_shape = list(tensor.size())
    tensor_stride = list(tensor.stride())
    for idx, placement in enumerate(placements):
        mesh_dim_size = mesh.size(idx)
        if placement.is_shard():
            shard_placement = cast(Shard, placement)
            if shard_placement.dim < 0:
                raise AssertionError(
                    "Shard placements should have negative dims normalized in "
                    f"the user-facing APIs: {shard_placement}"
                )
            shard_dim = shard_placement.dim

            assert (
                shard_dim < tensor.ndim
            ), f"Sharding dim {shard_dim} greater than tensor ndim {tensor.ndim} for placement number {idx}."

            local_dim_size = tensor_shape[shard_dim]
            tensor_shape[shard_dim] = local_dim_size * mesh_dim_size

            # recover the global stride by scaling up every stride that is larger
            # than or equal to the stride on shard_dim
            for i in range(len(tensor_stride)):
                if i != shard_dim and tensor_stride[i] >= tensor_stride[shard_dim]:
                    # rescale the stride by the mesh dim size
                    tensor_stride[i] = tensor_stride[i] * mesh_dim_size
        elif not isinstance(placement, (Replicate, Partial)):
            raise RuntimeError(f"placement type {type(placement)} not supported!")
    return tensor_shape, tensor_stride
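

# Stride bookkeeping example (a sketch with made-up numbers, not module code):
# a contiguous local tensor of size (4, 8, 2) has stride (16, 2, 1). With
# placements [Shard(1)] on a mesh dim of size 2, the global size becomes
# (4, 16, 2); only dim-0's stride is rescaled (16 >= stride[1] == 2), giving a
# global stride of (32, 2, 1) -- exactly the contiguous stride of a (4, 16, 2)
# tensor.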


def try_find_mesh_from_args(
    op_call: torch._ops.OpOverload, args: Sequence[object]
) -> DeviceMesh:
    """
    Find the device mesh object from args.
    Raise a ValueError if no mesh is found.
    NOTE: we can optimize this search if needed
    """
    for arg in args:
        if isinstance(arg, (dtensor.DTensor, DTensorSpec)):
            return arg.device_mesh
        elif (
            isinstance(arg, (list, tuple))
            and len(arg) > 0
            and isinstance(arg[0], (dtensor.DTensor, DTensorSpec))
        ):
            return arg[0].device_mesh

    raise ValueError(f"Cannot find device mesh from args for op: {op_call}.")


def compute_local_stride(
    global_stride: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement]
) -> Tuple[int, ...]:
    """
    Compute the stride of a local tensor shard, given the global stride of the DTensor.
    NOTE: Currently this function assumes the DTensor is evenly shardable.
    """
    stride_divisors = [1] * len(global_stride)
    for mesh_idx, p in enumerate(placements):
        if p.is_shard():
            i = cast(Shard, p).dim
            # tensor dimension i is sharded on mesh dimension mesh_idx,
            # so we need to divide all the strides larger than stride[i]
            # (by the submesh size)
            for j in range(len(global_stride)):
                if global_stride[j] > global_stride[i]:
                    stride_divisors[j] *= mesh.size(mesh_idx)
    return tuple(
        global_stride[i] // stride_divisors[i] for i in range(len(global_stride))
    )


def normalize_to_torch_size(size) -> torch.Size:  # type: ignore[no-untyped-def]
    """
    Unify the variable types of the size argument into torch.Size.
    Acceptable types include:
        int, Sequence[int], Tuple[int], Tuple[Sequence[int]],
        or torch.Size
    """
    if isinstance(size, torch.Size):
        return size

    if isinstance(size, int):
        torch_size = [size]
    elif len(size) == 1 and isinstance(size[0], Sequence):
        torch_size = list(size[0])
    else:
        torch_size = list(size)
    return torch.Size(torch_size)
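

# Usage sketches (illustrative values; the mesh is assumed to be already
# initialized, e.g. a 1-D mesh of 4 ranks):
#
#   # compute_local_stride: a global tensor of shape (2, 8, 4) has contiguous
#   # stride (32, 4, 1); sharding dim 1 over a mesh dim of size 4 divides only
#   # the strides larger than stride[1], so the local stride is (8, 4, 1) --
#   # the contiguous stride of the (2, 2, 4) local shard.
#   # compute_local_stride((32, 4, 1), mesh, [Shard(1)])  # -> (8, 4, 1)
#
#   normalize_to_torch_size(4)          # torch.Size([4])
#   normalize_to_torch_size((2, 3))     # torch.Size([2, 3])
#   normalize_to_torch_size([(2, 3)])   # torch.Size([2, 3])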