# mypy: allow-untyped-defs
import math
import traceback
from dataclasses import dataclass
from enum import auto, Enum
from typing import Any, cast, List, Optional

import torch
import torch._dynamo.compiled_autograd as ca
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.contract import _get_registry
from torch.distributed.tensor import DeviceMesh, DTensor
from torch.distributed.tensor._dtensor_spec import DTensorSpec


@dataclass
class DataParallelMeshInfo:
    mesh: DeviceMesh
    shard_mesh_dim: Optional[int] = None
    replicate_mesh_dim: Optional[int] = None

    def __post_init__(self):
        if self.shard_mesh_dim is None and self.replicate_mesh_dim is None:
            raise AssertionError(
                "At least one of shard_mesh_dim and replicate_mesh_dim must not be None"
            )


@dataclass
class FSDPMeshInfo(DataParallelMeshInfo):
    def __post_init__(self):
        super().__post_init__()
        if self.shard_mesh_dim is None:
            raise AssertionError("Expects non-None shard_mesh_dim")
        self.shard_mesh_size: int = self.mesh.size(self.shard_mesh_dim)
        self.shard_process_group = self.mesh.get_group(self.shard_mesh_dim)
        self.shard_mesh_rank: int = self.shard_process_group.rank()


@dataclass
class DDPMeshInfo(DataParallelMeshInfo):
    def __post_init__(self):
        super().__post_init__()
        if self.replicate_mesh_dim is None:
            raise AssertionError("Expects non-None replicate_mesh_dim")
        self.replicate_mesh_size: int = self.mesh.size(self.replicate_mesh_dim)
        self.replicate_process_group = self.mesh.get_group(self.replicate_mesh_dim)
        self.replicate_mesh_rank: int = self.replicate_process_group.rank()


@dataclass
class HSDPMeshInfo(FSDPMeshInfo, DDPMeshInfo):
    def __post_init__(self):
        # Calls `FSDPMeshInfo` -> `DDPMeshInfo` -> `DataParallelMeshInfo`
        super().__post_init__()


class TrainingState(Enum):
    """Describes the training state of one FSDP state / parameter group."""

    # Transition to forward starting pre-forward until post-forward
    FORWARD = auto()
    # Transition to pre-backward when unsharding in backward
    PRE_BACKWARD = auto()
    # Transition to post-backward when resharding and reducing gradients
    POST_BACKWARD = auto()
    # Idle before/after forward or before pre-backward/after post-backward
    IDLE = auto()


def _raise_assert_with_print(*args: Any, **kwargs: Any):
    print(f"[Rank {dist.get_rank()}] ", end="")
    print(*args, **kwargs)
    traceback.print_stack()
    raise AssertionError(*args, **kwargs)


def _is_composable_with_fsdp(module: nn.Module) -> bool:
    registry = _get_registry(module)
    if registry is None:
        return True
    # Registry keys by function name
    return "replicate" not in registry


def _get_dim0_padded_size(tensor_size: torch.Size, dim0_factor: int) -> torch.Size:
    padded_dim0 = math.ceil(tensor_size[0] / dim0_factor) * dim0_factor
    return cast(torch.Size, torch.Size([padded_dim0]) + tensor_size[1:])


def _chunk_with_empty(
    tensor: torch.Tensor, num_chunks: int, dim: int
) -> List[torch.Tensor]:
    chunks = list(torch.chunk(tensor, num_chunks, dim=dim))
    while len(chunks) < num_chunks:
        chunks.append(chunks[0].new_empty(0))
    return chunks


def _get_dim0_chunked_size(
    chunk: torch.Tensor, unchunked_size: torch.Size
) -> torch.Size:
    if chunk.numel() > 0:
        return chunk.size()
    # For 0 numel, we need to preserve trailing dims for DTensor APIs
    return cast(torch.Size, torch.Size([0]) + unchunked_size[1:])


def _from_local_no_grad(
    local_tensor: torch.Tensor,
    sharding_spec: DTensorSpec,
) -> DTensor:
    """
    This method is similar to ``DTensor.from_local()`` except that in eager mode
    it avoids some CPU overhead by avoiding default args and not being differentiable.
    """

    if not ca.compiled_autograd_enabled:
        return DTensor(
            # Use the local tensor directly instead of constructing a new tensor
            # variable, e.g. with `view_as()`, since this is not differentiable
            local_tensor,
            sharding_spec,
            requires_grad=local_tensor.requires_grad,
        )
    else:
        return DTensor.from_local(
            local_tensor,
            sharding_spec.mesh,
            sharding_spec.placements,
            shape=sharding_spec.shape,
            stride=sharding_spec.stride,
        )


def _to_dtype_if_needed(
    tensor: torch.Tensor, dtype: Optional[torch.dtype]
) -> torch.Tensor:
    if dtype is not None and tensor.dtype != dtype:
        return tensor.to(dtype)
    return tensor


def _cast_fp_tensor(dtype: torch.dtype, x: torch.Tensor) -> torch.Tensor:
    if (
        not isinstance(x, torch.Tensor)
        or not torch.is_floating_point(x)
        or x.dtype == dtype
    ):
        return x
    return x.to(dtype)
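

# Illustrative sketch of how the dim-0 padding/chunking helpers above compose
# when sharding a parameter across ranks. The parameter shape (10, 3) and the
# shard count of 6 are arbitrary example values, chosen so that `torch.chunk`
# returns fewer chunks than requested. Run this file directly to check the
# expected shapes; nothing here requires an initialized process group.
if __name__ == "__main__":
    param = torch.randn(10, 3)
    num_shards = 6

    # Pad dim-0 up to a multiple of the shard count: ceil(10 / 6) * 6 = 12.
    assert _get_dim0_padded_size(param.size(), num_shards) == torch.Size([12, 3])

    # `torch.chunk` yields only five 2-row chunks here, so `_chunk_with_empty`
    # appends an empty tensor to give every rank a chunk.
    chunks = _chunk_with_empty(param, num_shards, dim=0)
    assert len(chunks) == num_shards
    assert [c.size(0) for c in chunks] == [2, 2, 2, 2, 2, 0]

    # The empty chunk's logical size keeps the trailing dims of the unchunked
    # size, as expected by DTensor APIs.
    assert _get_dim0_chunked_size(chunks[-1], param.size()) == torch.Size([0, 3])

    # `_to_dtype_if_needed` and `_cast_fp_tensor` cast only when the target
    # dtype differs; otherwise they return the input unchanged.
    assert _to_dtype_if_needed(param, None) is param
    assert _to_dtype_if_needed(param, torch.bfloat16).dtype == torch.bfloat16
    assert _cast_fp_tensor(param.dtype, param) is param
    assert _cast_fp_tensor(torch.bfloat16, param).dtype == torch.bfloat16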