import itertools
from typing import List, Optional, Set, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import _get_device_handle
from torch.distributed.tensor import DeviceMesh, DTensor, init_device_mesh
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

from ._fsdp_common import _is_composable_with_fsdp, FSDPMeshInfo, HSDPMeshInfo
from ._fsdp_state import _get_module_fsdp_state


def _get_post_forward_mesh_info(
    reshard_after_forward: Union[bool, int], mesh_info: FSDPMeshInfo
) -> Optional[FSDPMeshInfo]:
    shard_mesh_size = mesh_info.shard_mesh_size
    if not isinstance(reshard_after_forward, (bool, int)):
        raise ValueError(
            "reshard_after_forward should be a bool or an int representing the "
            f"group size to reshard to, not {reshard_after_forward}"
        )
    # NOTE: `isinstance(False, int)` returns `True`.
    if not isinstance(reshard_after_forward, bool) and isinstance(
        reshard_after_forward, int
    ):
        if (
            reshard_after_forward < 1
            or reshard_after_forward > shard_mesh_size
            or shard_mesh_size % reshard_after_forward != 0
        ):
            raise ValueError(
                "If passing reshard_after_forward as an int, it should be a "
                f"factor of {shard_mesh_size}, not {reshard_after_forward}"
            )
        elif reshard_after_forward == 1:
            reshard_after_forward = False
        elif reshard_after_forward == shard_mesh_size:
            reshard_after_forward = True
    post_forward_mesh_info = None
    if reshard_after_forward is True:
        post_forward_mesh_info = mesh_info
    elif reshard_after_forward is not False:  # int case
        # For HSDP, we can flatten the two replicate dims into the 0th dim
        post_forward_mesh_tensor = mesh_info.mesh.mesh.view(-1, reshard_after_forward)
        post_forward_mesh = DeviceMesh(
            mesh_info.mesh.device_type, post_forward_mesh_tensor
        )
        post_forward_mesh_info = HSDPMeshInfo(
            post_forward_mesh, shard_mesh_dim=1, replicate_mesh_dim=0
        )
    return post_forward_mesh_info
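# Example for `_get_post_forward_mesh_info` (hypothetical values, not executed
# here): with a 1D shard mesh of size 8, passing ``reshard_after_forward=2``
# views the flat mesh tensor as shape (4, 2), so after forward each parameter
# stays sharded over a group of 2 ranks and is replicated across the 4 groups:
#   mesh_info.mesh.mesh             -> tensor([0, 1, 2, 3, 4, 5, 6, 7])
#   mesh_info.mesh.mesh.view(-1, 2) -> tensor([[0, 1], [2, 3], [4, 5], [6, 7]])
# Passing 1 or 8 instead is normalized to ``False`` / ``True`` above.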


def _init_default_fully_shard_mesh() -> DeviceMesh:
    """Default to the global CUDA mesh if possible, else the global CPU mesh."""
    if not dist.distributed_c10d.is_initialized():
        dist.distributed_c10d.init_process_group()
    default_pg = dist.distributed_c10d._get_default_group()
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    mesh = init_device_mesh(device_type, mesh_shape=(default_pg.size(),))
    return mesh
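# Example for `_init_default_fully_shard_mesh` (hypothetical): in a job
# launched with 4 processes and CUDA available, this is equivalent to
# ``init_device_mesh("cuda", mesh_shape=(4,))`` spanning the default process
# group; without CUDA it falls back to a 4-rank CPU mesh.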


def _get_device_from_mesh(mesh: DeviceMesh) -> torch.device:
    if mesh.device_type == "cpu":
        return torch.device("cpu")
    # `_get_device_handle` returns the accelerator module for the mesh's
    # device type (e.g. `torch.cuda` for "cuda")
    device_handle = _get_device_handle(mesh.device_type)
    return torch.device(mesh.device_type, device_handle.current_device())
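# Example for `_get_device_from_mesh` (hypothetical): for a "cuda" mesh on a
# rank that has called ``torch.cuda.set_device(3)``, this returns
# ``torch.device("cuda", 3)``; for a CPU mesh it always returns
# ``torch.device("cpu")``.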


def _get_managed_modules(root_modules: Tuple[nn.Module, ...]) -> List[nn.Module]:
    modules: List[nn.Module] = []
    root_modules_set = set(root_modules)
    # Track visited modules to avoid visiting shared modules multiple times
    visited_modules: Set[nn.Module] = set()

    def dfs(module: nn.Module) -> None:
        """
        Runs a DFS to collect managed modules, not recursing into modules with
        a non-composable API or ``fully_shard`` already applied.
        """
        if not _is_composable_with_fsdp(module):
            return
        elif (
            module not in root_modules_set
            and _get_module_fsdp_state(module) is not None
        ):
            return  # nested `fully_shard` module
        visited_modules.add(module)
        for submodule in module.children():
            if submodule not in visited_modules:
                dfs(submodule)
        modules.append(module)

    for root_module in root_modules:
        dfs(root_module)
    return modules
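# Example for `_get_managed_modules` (hypothetical): for
# ``root = nn.Sequential(child0, child1)`` where ``fully_shard`` was already
# applied to ``child0``, the DFS skips ``child0`` (it is owned by the nested
# FSDP state) and returns ``[child1, root]`` in post-order; a module shared
# between two roots is collected only once.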


def _verify_managed_param(name: str, param: nn.Parameter) -> None:
    """
    Verify if the parameter is accepted by fully_shard. The only restriction
    for now is that the parameter cannot be a scalar (0-dim) tensor, i.e.
    ``param.ndim == 0``, since we need at least one dim to shard.
    """
    if len(param.shape) == 0:
        raise ValueError(
            "fully_shard doesn't support scalar parameters. "
            f"Change {name} to a 1D tensor with numel equal to 1."
        )
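# Example for `_verify_managed_param` (hypothetical):
# ``nn.Parameter(torch.tensor(1.0))`` is rejected here, whereas registering
# ``nn.Parameter(torch.tensor([1.0]))`` (shape ``(1,)``) satisfies the check.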


def _get_managed_states(
    modules: List[nn.Module],
) -> Tuple[List[nn.Parameter], List[torch.Tensor]]:
    params: List[nn.Parameter] = []
    buffers: List[torch.Tensor] = []
    # Track visited parameters/buffers to avoid visiting shared parameters and
    # buffers multiple times
    visited_params: Set[nn.Parameter] = set()
    visited_buffers: Set[torch.Tensor] = set()
    for module in modules:
        for name, param in module.named_parameters(recurse=False):
            if param not in visited_params:
                _verify_managed_param(name, param)
                params.append(param)
                visited_params.add(param)
        for buffer in module.buffers(recurse=False):
            if buffer not in visited_buffers:
                buffers.append(buffer)
                visited_buffers.add(buffer)
    return params, buffers
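# Example for `_get_managed_states` (hypothetical): with weight tying, e.g. an
# embedding and an output projection sharing one ``nn.Parameter`` across two
# managed modules, the shared parameter appears exactly once in the returned
# ``params`` list.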


def _move_states_to_device(
    params: List[nn.Parameter],
    buffers: List[torch.Tensor],
    device: torch.device,
) -> None:
    """
    We have FSDP move states to device for simpler and faster initialization
    since FSDP almost always uses CUDA for training. We move parameters/buffers
    rather than modules so that ignoring individual parameters/buffers can be
    supported in the future.
    """
    # Follow the logic in `nn.Module._apply`
    for tensor in itertools.chain(params, buffers):
        if tensor.device == device or tensor.device.type == "meta":
            # Keep meta-device tensors on meta device for deferred init
            continue
        if isinstance(tensor, DTensor):
            if (dtensor_mesh_type := tensor.device_mesh.device_type) != device.type:
                raise ValueError(
                    "Requires DTensor to have mesh of the same type as the FSDP mesh "
                    f"but got {dtensor_mesh_type} for DTensor and {device.type} for FSDP"
                )
            # A DTensor whose mesh type matches the FSDP mesh is expected to
            # already be on the FSDP device, so reaching here is an error
            raise AssertionError(
                f"Expects DTensor to be moved to {dtensor_mesh_type} but got {tensor.device}"
            )
        tensor_ = tensor
        if is_traceable_wrapper_subclass(tensor_):
            with torch.no_grad():  # avoid autograd increasing C++ refcount by 1
                tensor_on_device = nn.Parameter(tensor.to(device))
            torch.utils.swap_tensors(tensor, tensor_on_device)
        else:
            tensor.data = tensor.to(device)
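# Example for `_move_states_to_device` (hypothetical): CPU parameters are moved
# in place to e.g. ``torch.device("cuda", 0)`` (via ``.data`` assignment, or
# ``torch.utils.swap_tensors`` for traceable wrapper subclasses), while
# parameters constructed on the "meta" device are left untouched for deferred
# initialization.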