xref: /aosp_15_r20/external/pytorch/torch/distributed/_composable/fsdp/_fsdp_common.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
import math
import traceback
from dataclasses import dataclass
from enum import auto, Enum
from typing import Any, cast, List, Optional

import torch
import torch._dynamo.compiled_autograd as ca
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.contract import _get_registry
from torch.distributed.tensor import DeviceMesh, DTensor
from torch.distributed.tensor._dtensor_spec import DTensorSpec


@dataclass
class DataParallelMeshInfo:
    """Device mesh info shared across the data-parallel variants: the mesh plus
    which mesh dims (if any) are used for sharding and for replication."""

    mesh: DeviceMesh
    shard_mesh_dim: Optional[int] = None
    replicate_mesh_dim: Optional[int] = None

    def __post_init__(self):
        if self.shard_mesh_dim is None and self.replicate_mesh_dim is None:
            raise AssertionError(
                "At least one of shard_mesh_dim and replicate_mesh_dim must not be None"
            )


@dataclass
class FSDPMeshInfo(DataParallelMeshInfo):
    """Mesh info for the sharded (FSDP) dim: derives the shard process group,
    its size, and this rank's index within it."""

    def __post_init__(self):
        super().__post_init__()
        if self.shard_mesh_dim is None:
            raise AssertionError("Expects non-None shard_mesh_dim")
        self.shard_mesh_size: int = self.mesh.size(self.shard_mesh_dim)
        self.shard_process_group = self.mesh.get_group(self.shard_mesh_dim)
        self.shard_mesh_rank: int = self.shard_process_group.rank()


@dataclass
class DDPMeshInfo(DataParallelMeshInfo):
    """Mesh info for the replicated (DDP) dim: derives the replicate process
    group, its size, and this rank's index within it."""

    def __post_init__(self):
        super().__post_init__()
        if self.replicate_mesh_dim is None:
            raise AssertionError("Expects non-None replicate_mesh_dim")
        self.replicate_mesh_size: int = self.mesh.size(self.replicate_mesh_dim)
        self.replicate_process_group = self.mesh.get_group(self.replicate_mesh_dim)
        self.replicate_mesh_rank: int = self.replicate_process_group.rank()


@dataclass
class HSDPMeshInfo(FSDPMeshInfo, DDPMeshInfo):
    """Mesh info for hybrid sharded data parallelism (HSDP): sharded within one
    mesh dim and replicated across another."""

    def __post_init__(self):
        # Calls `FSDPMeshInfo` -> `DDPMeshInfo` -> `DataParallelMeshInfo`
        super().__post_init__()

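# Usage sketch (illustrative; the mesh construction and variable names below
# are assumptions, not part of this module):
#   from torch.distributed.device_mesh import init_device_mesh
#   mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("replicate", "shard"))
#   hsdp_info = HSDPMeshInfo(mesh, shard_mesh_dim=1, replicate_mesh_dim=0)
#   fsdp_info = FSDPMeshInfo(init_device_mesh("cuda", (8,)), shard_mesh_dim=0)
# Constructing any of these with its required mesh dim left as None raises the
# AssertionError from the corresponding __post_init__.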

class TrainingState(Enum):
    """Describes the training state of one FSDP state / parameter group."""

    # Transition to forward starting pre-forward until post-forward
    FORWARD = auto()
    # Transition to pre-backward when unsharding in backward
    PRE_BACKWARD = auto()
    # Transition to post-backward when resharding and reducing gradients
    POST_BACKWARD = auto()
    # Idle before/after forward or before pre-backward/after post-backward
    IDLE = auto()


def _raise_assert_with_print(*args: Any, **kwargs: Any):
    """Prints ``args`` prefixed with this rank, prints the current stack trace,
    and then raises an ``AssertionError`` with the same args."""
    print(f"[Rank {dist.get_rank()}] ", end="")
    print(*args, **kwargs)
    traceback.print_stack()
    raise AssertionError(*args, **kwargs)


def _is_composable_with_fsdp(module: nn.Module) -> bool:
    """Returns ``False`` if ``module`` is already managed by the composable
    ``replicate`` API, which does not compose with FSDP; returns ``True``
    otherwise."""
    registry = _get_registry(module)
    if registry is None:
        return True
    # Registry keys by function name
    return "replicate" not in registry

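# Illustrative behavior (sketch, assuming an initialized process group): a plain
# module has no composable-API registry and is accepted, while a module managed
# by the composable `replicate()` API registers under the key "replicate" and is
# rejected:
#   from torch.distributed._composable import replicate
#   _is_composable_with_fsdp(nn.Linear(4, 4))             # -> True
#   _is_composable_with_fsdp(replicate(nn.Linear(4, 4)))  # -> False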

def _get_dim0_padded_size(tensor_size: torch.Size, dim0_factor: int) -> torch.Size:
    """Returns ``tensor_size`` with dim-0 rounded up to a multiple of ``dim0_factor``."""
    padded_dim0 = math.ceil(tensor_size[0] / dim0_factor) * dim0_factor
    return cast(torch.Size, torch.Size([padded_dim0]) + tensor_size[1:])

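# Illustrative example: with a shard world size of 4, a (10, 3) parameter pads
# on dim-0 up to the next multiple of 4 so it can be evenly chunked:
#   _get_dim0_padded_size(torch.Size([10, 3]), 4)  # -> torch.Size([12, 3])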

def _chunk_with_empty(
    tensor: torch.Tensor, num_chunks: int, dim: int
) -> List[torch.Tensor]:
    """Chunks ``tensor`` along ``dim`` and appends empty tensors as needed so
    that exactly ``num_chunks`` chunks are returned."""
    chunks = list(torch.chunk(tensor, num_chunks, dim=dim))
    while len(chunks) < num_chunks:
        chunks.append(chunks[0].new_empty(0))
    return chunks

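# Illustrative example: chunking a (2, 3) tensor four ways along dim-0 yields
# two (1, 3) chunks, then empty placeholders are appended so every rank still
# gets a chunk:
#   chunks = _chunk_with_empty(torch.randn(2, 3), num_chunks=4, dim=0)
#   [tuple(c.shape) for c in chunks]  # -> [(1, 3), (1, 3), (0,), (0,)]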

def _get_dim0_chunked_size(
    chunk: torch.Tensor, unchunked_size: torch.Size
) -> torch.Size:
    if chunk.numel() > 0:
        return chunk.size()
    # For 0 numel, we need to preserve trailing dims for DTensor APIs
    return cast(torch.Size, torch.Size([0]) + unchunked_size[1:])

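# Illustrative example, continuing the sketch above: the empty placeholder
# chunks report a logical size of (0, 3) rather than (0,), which keeps the
# trailing dims that DTensor APIs expect:
#   _get_dim0_chunked_size(chunks[3], torch.Size([2, 3]))  # -> torch.Size([0, 3])
#   _get_dim0_chunked_size(chunks[0], torch.Size([2, 3]))  # -> torch.Size([1, 3])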

def _from_local_no_grad(
    local_tensor: torch.Tensor,
    sharding_spec: DTensorSpec,
) -> DTensor:
    """
    This function is similar to ``DTensor.from_local()`` except that, in eager
    mode, it avoids some CPU overhead by skipping default-argument handling and
    by not being differentiable.
    """

    if not ca.compiled_autograd_enabled:
        return DTensor(
            # Use the local tensor directly instead of constructing a new tensor
            # variable, e.g. with `view_as()`, since this is not differentiable
            local_tensor,
            sharding_spec,
            requires_grad=local_tensor.requires_grad,
        )
    else:
        return DTensor.from_local(
            local_tensor,
            sharding_spec.mesh,
            sharding_spec.placements,
            shape=sharding_spec.shape,
            stride=sharding_spec.stride,
        )

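# Usage sketch (illustrative; `dt` and `local_shard` are hypothetical): the
# sharding spec is typically taken from an existing DTensor, e.g. `dt._spec`,
# and the local shard is assumed to match that spec's local shape and dtype:
#   new_dt = _from_local_no_grad(local_shard, dt._spec)
# Unlike `DTensor.from_local()`, this does not create an autograd edge back to
# `local_shard` in eager mode.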

def _to_dtype_if_needed(
    tensor: torch.Tensor, dtype: Optional[torch.dtype]
) -> torch.Tensor:
    if dtype is not None and tensor.dtype != dtype:
        return tensor.to(dtype)
    return tensor

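# Illustrative example: a no-op when `dtype` is None or already matches;
# otherwise an out-of-place cast:
#   t = torch.randn(2)
#   _to_dtype_if_needed(t, None) is t             # -> True
#   _to_dtype_if_needed(t, torch.float16).dtype   # -> torch.float16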

def _cast_fp_tensor(dtype: torch.dtype, x: torch.Tensor) -> torch.Tensor:
    if (
        not isinstance(x, torch.Tensor)
        or not torch.is_floating_point(x)
        or x.dtype == dtype
    ):
        return x
    return x.to(dtype)

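# Illustrative example: only floating-point tensors are cast; integer tensors
# (and non-tensor inputs) pass through unchanged:
#   _cast_fp_tensor(torch.bfloat16, torch.randn(4)).dtype   # -> torch.bfloat16
#   _cast_fp_tensor(torch.bfloat16, torch.arange(4)).dtype  # -> torch.int64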