# mypy: allow-untyped-defs
import copy
from typing import Any, cast, List, Optional, Tuple

import torch
import torch.distributed as dist
import torch.distributed._shard.sharding_spec as shard_spec
import torch.distributed.distributed_c10d as c10d
from torch.distributed._shard.sharded_tensor import (
    Shard,
    ShardedTensor,
    ShardedTensorMetadata,
    TensorProperties,
)
from torch.distributed._shard.sharding_spec import ShardMetadata
from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec
from torch.distributed.device_mesh import _mesh_resources
from torch.distributed.fsdp._common_utils import _set_fsdp_flattened
from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions
from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
from torch.distributed.remote_device import _remote_device
from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard as DShard
from torch.distributed.tensor.parallel._data_parallel_utils import (
    _flatten_tensor,
    _unflatten_tensor,
)


__all__ = ["DTensorExtensions"]


def _get_box(tensor: DTensor) -> Tuple[torch.Size, torch.Size]:
    device_mesh = tensor.device_mesh
    assert device_mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"

    placement = tensor.placements[0]
    offsets = [0] * len(tensor.size())
    num_chunks = device_mesh.size(mesh_dim=0)

    if tensor.placements[0].is_shard():
        shard_dim = cast(DShard, placement).dim
        chunk_size = tensor.size(shard_dim) // num_chunks
        offsets[shard_dim] = chunk_size

    return (torch.Size(offsets), tensor._local_tensor.size())


def _get_box_for(tensor: DTensor, idx: int) -> Tuple[torch.Size, torch.Size]:
    offsets, size = _get_box(tensor)
    return (torch.Size([val * idx for val in offsets]), size)


def _get_local_box(tensor: DTensor) -> Tuple[torch.Size, torch.Size]:
    device_mesh = tensor.device_mesh
    coord = device_mesh.get_coordinate()
    assert coord is not None
    return _get_box_for(tensor, coord[0])
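
# Worked example (added for clarity; assumes the shard dim divides evenly across the mesh):
# for a 1D mesh of 4 ranks and a DTensor of global size (8, 4) with placements [Shard(0)],
# _get_box returns per-chunk offsets torch.Size([2, 0]) and local size torch.Size([2, 4]);
# _get_box_for(dt, 3) scales the offsets to (6, 0), and _get_local_box picks the box for
# this rank's mesh coordinate.
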

def _create_shard_md_from_dt(dt: DTensor, current_rank: int) -> ShardMetadata:
    mesh = dt.device_mesh
    assert mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"

    offsets, sizes = _get_local_box(dt)
    return ShardMetadata(
        shard_offsets=list(offsets),
        shard_sizes=list(sizes),
        placement=f"rank:{current_rank}/{dt._local_tensor.device}",
    )


def _create_sharded_tensor_md_from_dt(
    dt: DTensor, dt_pg: c10d.ProcessGroup
) -> ShardedTensorMetadata:
    # This is where it gets tricky: we have to produce a ShardedTensor that has full coverage
    # and yet has only one valid shard for the current rank.

    shards_md = []
    my_rank = dist.get_rank(dt_pg)
    scapegoat_rank = 0 if my_rank > 0 else 1

    if dt.placements[0].is_shard():
        shard_count = dt_pg.size()
    else:
        shard_count = 1

    for i in range(shard_count):
        offsets, sizes = _get_box_for(dt, i)
        shards_md.append(
            ShardMetadata(
                shard_offsets=list(offsets),
                shard_sizes=list(sizes),
                placement=(
                    f"rank:{scapegoat_rank if i > 0 else my_rank}/{dt._local_tensor.device}"
                ),
            )
        )

    return ShardedTensorMetadata(
        shards_metadata=shards_md,
        size=dt.size(),
        tensor_properties=TensorProperties(
            dtype=dt.dtype,
            layout=dt.layout,
            requires_grad=dt.requires_grad,
            # ignore memory_format and pin_memory as those are not supported by DT
        ),
    )


def _get_dt_pg(dt: DTensor) -> c10d.ProcessGroup:
    mesh = dt.device_mesh
    assert mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"
    return mesh.get_group()


def _rewrite_spec_if_needed(
    spec: shard_spec.ShardingSpec, tensor: torch.Tensor, rank: int
) -> shard_spec.ShardingSpec:
    """
    Rewrite ``spec`` to match the device of ``tensor``.

    FSDP.sharded_optim_state_dict sneakily ships the optimizer state to CPU, so if the original
    ShardingSpec produces CUDA metadata, ST construction bombs.
    """
    if not isinstance(spec, ChunkShardingSpec):
        return spec

    # let's see if we need to rewrite anything
    rewrite = False
    for p in spec.placements:
        p = cast(_remote_device, p)
        if p.rank() == rank and p.device() != tensor.device:
            rewrite = True
            break
    if rewrite:
        spec = copy.deepcopy(spec)
        for i, placement in enumerate(spec.placements):
            placement = cast(_remote_device, placement)
            if placement.rank() == rank and placement.device() != tensor.device:
                spec.placements[i] = _remote_device(f"rank:{rank}/{tensor.device}")

    return spec
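
# Note (added for clarity): _chunk_tensor below handles three cases. For a ShardedTensor or a
# DTensor input (i.e., a TP-sharded parameter), it chunk-shards the local tensor into an inner
# ShardedTensor and wraps it in an outer ShardedTensor whose metadata describes the original
# global tensor, so FSDP's per-rank chunking composes with the existing sharding. Plain tensors
# are simply chunk-sharded.
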

def _chunk_tensor(
    tensor: torch.Tensor,
    rank: int,
    world_size: int,
    num_devices_per_node: int,
    pg: dist.ProcessGroup,
) -> torch.Tensor:
    if type(tensor) is ShardedTensor:
        assert len(tensor.local_shards()) == 1

        inner_param = tensor.local_tensor()
        inner_st = _create_chunk_sharded_tensor(
            inner_param,
            rank,
            world_size,
            num_devices_per_node,
            pg,
        )

        outer_local_shard = tensor.local_shards()[0]
        shards: List[Shard] = [
            Shard(inner_st, copy.deepcopy(outer_local_shard.metadata))
        ]
        st_meta = copy.deepcopy(tensor.metadata())
        st_meta.tensor_properties.requires_grad = False

        st_outer = ShardedTensor._init_from_local_shards_and_global_metadata(
            shards,
            sharded_tensor_metadata=st_meta,
            process_group=tensor._process_group,
            init_rrefs=False,
        )
        return st_outer
    elif type(tensor) is DTensor:
        device_mesh = tensor.device_mesh
        assert device_mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"

        inner_param = tensor._local_tensor

        inner_st = _create_chunk_sharded_tensor(
            inner_param,
            rank,
            world_size,
            torch.cuda.device_count(),
            pg,
        )

        dt_pg = _get_dt_pg(tensor)
        # We do this differently here: we create an ST with no local shards and then patch it
        shards = [
            Shard(inner_st, _create_shard_md_from_dt(tensor, dist.get_rank(dt_pg)))
        ]

        st_meta = _create_sharded_tensor_md_from_dt(tensor, dt_pg)
        st_meta.tensor_properties.requires_grad = False

        st_outer = ShardedTensor._init_from_local_shards_and_global_metadata(
            shards,
            sharded_tensor_metadata=st_meta,
            process_group=dt_pg,
            init_rrefs=False,
        )

        return st_outer
    else:
        return _create_chunk_sharded_tensor(
            tensor,
            rank,
            world_size,
            num_devices_per_node,
            pg,
        )
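
# Note (added for clarity): for a 2D root mesh (FSDP x TP), _chunk_dtensor below maps a plain
# tensor to placements (Shard(0), Replicate()) and a TP-sharded DTensor to (Shard(0), tp_placement);
# with a 3D HSDP mesh the DTensor case becomes (Replicate(), Shard(0), tp_placement).
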

def _chunk_dtensor(
    tensor: torch.Tensor,
    rank: int,
    device_mesh: DeviceMesh,
) -> DTensor:
    """
    Shard a tensor into chunks along the first dimension.

    The local rank gets its corresponding chunk as the local tensor to create a DTensor.
    """
    root_mesh = _mesh_resources.get_root_mesh(device_mesh)
    if root_mesh is None:
        raise RuntimeError("No parent device_mesh is found for FSDP device_mesh.")
    if root_mesh.ndim < 2:
        raise RuntimeError(
            f"Found parent device_mesh of ndim={root_mesh.ndim}, "
            "but meshes must be at least 2D."
        )

    # We need to explicitly call .detach() to return a new tensor detached from the current graph.
    tensor = tensor.clone().detach()

    # When a layer is not involved in TP, its tensor will not be a DTensor.
    # e.g. when a layer is not specified in the parallelize_plan, TP will have no effect on the layer.
    # e.g. when you do PairwiseParallel on a 3-layer model, TP will have no effect on the third layer.
    if isinstance(tensor, torch.Tensor) and not isinstance(tensor, DTensor):
        # For tensors, the parameter is replicated across the TP dimension and sharded across the
        # FSDP dimension. TP is the inner dimension and FSDP is the outer dimension.
        # Therefore, the shard placements for the tensor are (Shard(0), Replicate()).
        replicate_placements = [Replicate() for _ in range(root_mesh.ndim)]
        shard_placements = [Replicate() for _ in range(root_mesh.ndim)]
        shard_placements[0] = DShard(0)  # type: ignore[call-overload]

        return DTensor.from_local(
            tensor, root_mesh, replicate_placements, run_check=False
        ).redistribute(
            device_mesh=root_mesh,
            placements=shard_placements,
        )

    else:
        tp_placements = tensor.placements
        tp_placement = tp_placements[0]

        tensor = tensor.to_local()

        # For DTensors, the parameter is sharded across the TP dimension first and then sharded
        # across the FSDP dimension. TP is the inner dimension and FSDP is the outer dimension.
        # Therefore, the shard placements for the tensor are (Shard(0), tp_placement).
        # For higher-dimensional meshes, it is replicated across the other dimensions. For example,
        # with HSDP the shard placements for the tensor are (Replicate(), Shard(0), tp_placement).
        replicate_placements = [Replicate() for _ in range(root_mesh.ndim)]
        replicate_placements[-1] = tp_placement  # type: ignore[call-overload]
        shard_placements = [Replicate() for i in range(root_mesh.ndim)]  # type: ignore[misc]
        shard_placements[-2] = DShard(0)  # type: ignore[call-overload]
        shard_placements[-1] = tp_placement  # type: ignore[call-overload]

        return DTensor.from_local(
            tensor, root_mesh, replicate_placements, run_check=False
        ).redistribute(
            device_mesh=root_mesh,
            placements=shard_placements,
        )


def _pre_load_state_dict(
    tensor: torch.Tensor,
) -> Tuple[torch.Tensor, List[Shard]]:
    shards = cast(ShardedTensor, tensor).local_shards()
    if len(shards) == 1 and type(shards[0].tensor) is ShardedTensor:
        inner_tensor = shards[0].tensor
        shards = inner_tensor.local_shards()  # pyre-ignore[16]
        tensor = inner_tensor

    return (tensor, shards if len(shards) > 0 else [])


def _all_gather_dtensor(
    tensor: DTensor,
    parent_mesh: Optional[DeviceMesh],
) -> torch.Tensor:
    """All-gather a DTensor along its FSDP dimension and return the local tensor."""
    assert parent_mesh == tensor.device_mesh

    placements = list(copy.deepcopy(tensor.placements))
    # FSDP + TP: [Shard(0), tp_placement] -> [Replicate(), tp_placement]
    # HSDP + TP: [Replicate(), Shard(0), tp_placement] -> [Replicate(), Replicate(), tp_placement]
    for i in range(0, len(placements) - 1):
        placements[i] = Replicate()
    tensor = tensor.redistribute(
        device_mesh=tensor.device_mesh,
        placements=placements,
    )

    return tensor.to_local()
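
# Note (added for clarity): DTensorExtensions below implements the FSDPExtensions hooks by
# delegating to the module-level helpers above: pre_flatten_transform / post_unflatten_transform
# wrap _flatten_tensor / _unflatten_tensor, while the chunking and gather hooks forward to
# _chunk_tensor, _chunk_dtensor, _pre_load_state_dict, and _all_gather_dtensor.
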

class DTensorExtensions(FSDPExtensions):
    """
    DTensorExtensions is the TensorFlattener extension needed for 2D FSDP + TP.

    This is the implementation of FSDPExtensions defined in
    https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fsdp_extensions.py
    """

    def __init__(self, device_handle) -> None:
        super().__init__()
        self.compute_stream = None
        self.device_handle = device_handle
        # We have to disable dynamo this way, as using the decorator would trigger a build
        # failure with torch deploy...
        self.post_unflatten_transform = torch._dynamo.disable(self.post_unflatten_transform)  # type: ignore[method-assign]

    def pre_flatten_transform(
        self,
        tensor: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        return _flatten_tensor(tensor)

    def post_unflatten_transform(
        self, tensor: torch.Tensor, param_extension: Any
    ) -> torch.Tensor:
        stream = self.compute_stream or self.device_handle.current_stream()
        with self.device_handle.stream(stream):
            # At runtime we put the unflatten-tensor call on the compute stream since
            # the unflattened tensor might contain computations in fwd/bwd that we
            # need to sync properly.
            # TODO: this is a short-term fix; we should make get_unflat_views
            # happen directly in the compute stream.
            result = _unflatten_tensor(
                tensor,
                param_extension,
                device_handle=self.device_handle,
                compute_stream=self.compute_stream,
            )
            _set_fsdp_flattened(result)
            return result

    def chunk_tensor(
        self,
        tensor: torch.Tensor,
        rank: int,
        world_size: int,
        num_devices_per_node: int,
        pg: dist.ProcessGroup,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        return _chunk_tensor(tensor, rank, world_size, num_devices_per_node, pg)

    def chunk_dtensor(
        self,
        tensor: torch.Tensor,
        rank: int,
        device_mesh: DeviceMesh,
    ) -> torch.Tensor:
        return _chunk_dtensor(tensor, rank, device_mesh)

    def pre_load_state_dict_transform(
        self,
        tensor: torch.Tensor,
    ) -> Tuple[torch.Tensor, List[Shard]]:
        return _pre_load_state_dict(tensor)

    def all_gather_dtensor(
        self,
        tensor: DTensor,
        parent_mesh: Optional[DeviceMesh],
    ) -> torch.Tensor:
        return _all_gather_dtensor(tensor, parent_mesh)