xref: /aosp_15_r20/external/pytorch/torch/distributed/_composable/replicate.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import weakref
from typing import Any, cast, Dict, Iterable, List, NoReturn, Optional, Set, Tuple

import torch
import torch.nn as nn
from torch.distributed._composable_state import _State
from torch.nn.parallel import DistributedDataParallel

from .contract import _get_registry, contract


_ROOT_MODULE_PREFIX = ""


class _ReplicateState(_State):
    def __init__(self) -> None:
        super().__init__()
        self.module: nn.Module = nn.ParameterList()
        self.has_initialized: bool = False
        self._param_list: nn.ParameterList = nn.ParameterList()
        # TODO(@fegin): this variable was originally created for testing; we
        # should remove it if possible.
        self._orig_module = self.module
        self._param_names: List[str] = []
        self._no_sync: bool = False
        self._init_args: Optional[Tuple[Any, ...]] = None
        self._init_kwargs: Dict[str, Any] = {}
        self._comm_hook_args: List[Any] = []

    def _collect_params(
        self,
        module: nn.Module,
        ignored_modules: Set[nn.Module],
        ignored_params: Set[nn.Parameter],
        prefix: str = _ROOT_MODULE_PREFIX,
    ) -> None:
        # skip if managed by the fully_shard API
        if _is_fully_sharded(module):
            return

        # if a module is ignored, all descendants of the module are ignored.
        if module in ignored_modules:
            return

        recurse_prefix = (
            f"{prefix}." if prefix != _ROOT_MODULE_PREFIX else _ROOT_MODULE_PREFIX
        )
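        # For example, a parameter registered as "weight" on a child module
        # reached via prefix "layer1" is recorded below as "layer1.weight".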

        for n, p in module.named_parameters(recurse=False):
            if p not in ignored_params:
                self._param_list.append(p)
                self._param_names.append(f"{recurse_prefix}{n}")

        for name, child_module in module.named_children():
            self._collect_params(
                child_module,
                ignored_modules,
                ignored_params,
                prefix=f"{recurse_prefix}{name}",
            )

    def lazy_init(self) -> None:
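        """
        Runs the deferred initialization recorded by ``record_init_args``:
        constructs the underlying DDP instance and applies any communication
        hooks registered before initialization. The body is wrapped in
        ``torch._disable_dynamo`` so this setup is not traced by TorchDynamo.
        """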
        @torch._disable_dynamo(recursive=True)
        def _lazy_init():
            assert self._init_args is not None
            self.init(*self._init_args, **self._init_kwargs)
            self.register_comm_hook()
            self._init_args = ()
            self._init_kwargs = {}

        _lazy_init()

    def init(
        self,
        module: nn.Module,
        ignored_modules: Set[nn.Module],
        **kwargs,
    ) -> None:
        if self.has_initialized:
            return

        self.has_initialized = True

        device_mesh = kwargs.get("device_mesh", None)
        self.module = module
        ignored_params = {p for m in ignored_modules for p in m.parameters()}
        for submodule in module.modules():
            if _is_fully_sharded(submodule):
                ignored_params.update(submodule.parameters())
        from torch.distributed.tensor.parallel.ddp import _localize_dtensor

        _localize_dtensor(module, ignored_params=ignored_params)
        self._collect_params(module, ignored_modules, ignored_params)

        if "device_id" in kwargs:
            # replicate() supports a small usability enhancement where users
            # can pass device_id as a Union[int, torch.device] even for CPU
            # devices, so they don't have to change code for CPU/GPU runs.
            # We derive the right device_ids to feed into DDP to support this.
            if kwargs["device_id"] is not None:
                device_id = kwargs["device_id"]
                # Convert to device_ids that DDP expects.
                if isinstance(device_id, torch.device) and device_id.type == "cpu":
                    # CPU modules receive device_ids None
                    kwargs["device_ids"] = None
                else:
                    # GPU modules expect device_ids=[cuda_device]
                    kwargs["device_ids"] = [device_id]
            else:
                kwargs["device_ids"] = None
            kwargs.pop("device_id")

        self._ddp = DistributedDataParallel(self._param_list, **kwargs)
        # Weakref to the DDP instance is currently only used for testing.
        replicate.state(self.module)._ddp_weakref = weakref.ref(self._ddp)

    def register_comm_hook(self) -> None:
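        """
        Applies the hook arguments recorded by ``DDP.register_comm_hook``
        to the underlying DDP instance, then clears the pending list.
        """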
        for comm_args, comm_kwargs in self._comm_hook_args:
            self._ddp.register_comm_hook(*comm_args, **comm_kwargs)
        self._comm_hook_args.clear()

    def record_init_args(self, *args, **kwargs) -> None:
        self._init_args = args
        self._init_kwargs = kwargs

    def forward_pre_hook(
        self, module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> Any:
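        """
        Lazily initializes DDP on the first forward call, toggles gradient
        synchronization according to ``set_requires_gradient_sync``, and
        delegates to ``DistributedDataParallel._pre_forward``.
        """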
        if self._init_args or self._init_kwargs:
            self.lazy_init()
        self._ddp.require_backward_grad_sync = not self._no_sync
        return self._ddp._pre_forward(*args, **kwargs)

    def forward_post_hook(
        self,
        module: nn.Module,
        input: Tuple[torch.Tensor],
        output: torch.Tensor,
    ) -> torch.Tensor:
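        """Delegates output handling to ``DistributedDataParallel._post_forward``."""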
        return self._ddp._post_forward(output)


def unimplemented_deepcopy(*args: Any, **kwargs: Any) -> NoReturn:
    raise AssertionError(
        "DDP does not support deepcopy. Please use state dict for serialization."
    )


# Follow the same pattern as FSDP/fully_shard
class DDP:
    def __new__(cls, *args, **kwargs):
        """
        Override ``__new__`` to remove the DDP class and directly construct
        the original class for cases like indexing into a container module.
        """
        # Use index 2 since 0 is the dynamically constructed `DDP<...>` class
        # and index 1 is the `DDP` class itself
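        # For example, a module whose original class is ``nn.Linear`` is rewrapped
        # by ``replicate()`` as ``DDPLinear``, so ``cls.__mro__`` is
        # ``(DDPLinear, DDP, Linear, Module, object)`` and index 2 recovers ``Linear``.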
        orig_cls = cls.__mro__[2]
        return orig_cls.__new__(orig_cls, *args, **kwargs)

    def set_requires_gradient_sync(self, requires_gradient_sync: bool) -> None:
        """
        Sets if the module should sync gradients. This can be used to implement
        gradient accumulation without communication.

        Args:
            requires_gradient_sync (bool): Whether to reduce gradients for the
                module's parameters.
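
        A minimal sketch of gradient accumulation; ``MyModel``, ``batches``, and
        ``loss_fn`` are illustrative placeholders.

        Example::
            >>> # xdoctest: +SKIP("illustrative sketch")
            >>> model = replicate(MyModel())
            >>> for step, batch in enumerate(batches):
            ...     # Only all-reduce gradients on every 4th micro-batch.
            ...     model.set_requires_gradient_sync((step + 1) % 4 == 0)
            ...     loss_fn(model(batch)).backward()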
        """
        replicate.state(self)._no_sync = not requires_gradient_sync

    def register_comm_hook(self, *args, **kwargs) -> None:
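        """
        Records the hook arguments and forwards them to
        ``DistributedDataParallel.register_comm_hook`` once the underlying DDP
        instance is created on the first forward pass.

        A sketch using a built-in hook; ``module`` is assumed to already be
        wrapped by ``replicate()``.

        Example::
            >>> # xdoctest: +SKIP("illustrative sketch")
            >>> from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import (
            ...     fp16_compress_hook,
            ... )
            >>> module.register_comm_hook(None, fp16_compress_hook)
        """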
        replicate.state(self)._comm_hook_args.append((args, kwargs))


@contract(state_cls=_ReplicateState)
def replicate(
    module: nn.Module,
    ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
    **kwargs,
) -> nn.Module:
    r"""Replicates a module

    Args:
        module (torch.nn.Module): module to replicate
        ignored_modules (Optional[Iterable[torch.nn.Module]]): modules whose
            parameters (including those of their descendants) should not be
            replicated. Default: None

    Example::
        >>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d)
        >>> module = nn.Linear(3, 3)
        >>> replicate(module)
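        >>> # A sketch of placing the replica on a specific accelerator
        >>> # (the device index is illustrative):
        >>> # replicate(module, device_id=torch.device("cuda:0"))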
    """
    torch._C._log_api_usage_once("torch.distributed.replicate")

    # TODO(fegin): using kwargs is not a good idea if we would like to make
    # replicate a formal API to replace DDP.
    if "device_id" in kwargs:
        if not isinstance(kwargs["device_id"], (int, torch.device)):
            raise RuntimeError(
                "Expected device_id to be int or torch.device, "
                f"but got {type(kwargs['device_id'])}"
            )

    if _is_fully_sharded(module):
        raise RuntimeError(
            "Cannot apply `replicate()` on a Module already managed by `fully_shard`"
        )

    if ignored_modules is None:
        ignored_modules = set()
    else:
        ignored_modules = set(ignored_modules)

    state = cast(_ReplicateState, replicate.state(module))
    module.register_forward_pre_hook(state.forward_pre_hook, with_kwargs=True)
    device_mesh = kwargs.get("device_mesh", None)
    if device_mesh is not None:
        from torch.distributed.device_mesh import _mesh_resources

        root_mesh = _mesh_resources.get_root_mesh(device_mesh)
        # If the root mesh is not the same as device_mesh, then device_mesh
        # was sliced out from the root mesh.
        if root_mesh != device_mesh:
            # TODO: This is a temporary workaround to enable DDP + TP.
            # We should do the logic in DDP so that the 2D implementation is
            # sound and the state_dict works out of the box.
            #
            # This won't conflict with what is done in the DDP class, as the
            # module that replicate passes to DDP is NOT the original module.
            from torch.distributed.tensor.parallel.ddp import (
                _localize_dtensor,
                _reconstruct_dtensor,
            )

            module.register_forward_pre_hook(_reconstruct_dtensor)
            module.register_forward_hook(_localize_dtensor)

    module.register_forward_hook(state.forward_post_hook)  # type: ignore[arg-type]

    state.record_init_args(module, ignored_modules, **kwargs)

    # Place DDP leftmost for highest priority in the method resolution order
    cls = module.__class__
    dct = {"__deepcopy__": unimplemented_deepcopy}
    new_cls = type(f"DDP{cls.__name__}", (DDP, cls), dct)
    module.__class__ = new_cls
    return module


def _is_fully_sharded(module: nn.Module) -> bool:
    r"""Check if module is marked with fully_shard."""
    registry = _get_registry(module)
    if registry is None:
        return False
    return "fully_shard" in registry
257