# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import functools
from typing import Any, cast, Dict, Iterable, List, NoReturn, Optional, Type, Union

import torch
import torch.nn as nn
from torch.distributed._composable import contract
from torch.distributed.tensor import DeviceMesh
from torch.distributed.utils import _get_root_modules

from ._fsdp_api import MixedPrecisionPolicy, OffloadPolicy
from ._fsdp_common import FSDPMeshInfo, HSDPMeshInfo
from ._fsdp_init import (
    _get_device_from_mesh,
    _get_managed_modules,
    _get_managed_states,
    _get_post_forward_mesh_info,
    _init_default_fully_shard_mesh,
    _move_states_to_device,
)
from ._fsdp_param_group import FSDPParamGroup
from ._fsdp_state import _get_module_fsdp_state, FSDPState


cls_to_fsdp_cls: Dict[Type, Type] = {}


# The decorator adds a state object to `module` that can be accessed via
# `fully_shard.state(module)`. The state object and module are 1:1.
@contract(state_cls=FSDPState)  # type: ignore[operator]
def fully_shard(
    module: Union[nn.Module, List[nn.Module]],
    *,
    mesh: Optional[DeviceMesh] = None,
    reshard_after_forward: Union[bool, int] = True,
    mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(),
    offload_policy: OffloadPolicy = OffloadPolicy(),
):
    """
    Shard module parameters across data parallel workers.

    This function applies fully sharded data parallelism (FSDP) or a variant of
    it to ``module``, a technique that trades extra communication for memory
    savings. Parameters are sharded across ``mesh``, and in turn, so are their
    gradients and optimizer states.

    The sharded parameters are all-gathered to construct the unsharded
    parameters for forward or backward computation. The unsharded parameters
    are freed after computation to save memory. The gradients are reduced
    across the mesh and divided by the mesh size for data parallelism. The
    optimizer step runs on the sharded parameters.

    Each call to ``fully_shard`` constructs one communication group that
    includes the parameters in ``module.parameters()`` except those already
    assigned to a group from a nested call. A group's parameters are
    all-gathered together in one collective, and likewise its gradients are
    reduced together in one collective. Constructing multiple groups across
    the model (e.g. "layer by layer") allows for peak memory savings and
    communication/computation overlap.

    Implementation-wise, the sharded parameters are represented as
    :class:`DTensor` s, sharded on dim-0, and the unsharded parameters are
    represented as :class:`Tensor` s. A module forward pre-hook all-gathers the
    parameters, and a module forward hook frees them. Similar backward hooks
    all-gather the parameters and later free them and reduce the gradients.

    Args:
        module (Union[nn.Module, List[nn.Module]]): The module or modules to
            shard with FSDP and group together for communication.
        mesh (Optional[DeviceMesh]): This data parallel mesh defines the
            sharding and device. If 1D, then parameters are fully sharded
            across the 1D mesh (FSDP). If 2D, then parameters are sharded
            across the 0th dim and replicated across the 1st dim (HSDP). The
            mesh's device type gives the device type used for communication;
            if a CUDA or CUDA-like device type, then we use the current device.
        reshard_after_forward (Union[bool, int]): This controls the parameter
            behavior after forward and can trade off memory and communication:
            - If ``True``, then this reshards parameters after forward and
            all-gathers them again in backward.
            - If ``False``, then this keeps the unsharded parameters in memory
            after forward and avoids the all-gather in backward.
            - If an ``int``, then this represents the world size to reshard to
            after forward. It should be a non-trivial divisor of the ``mesh``
            shard dim size (i.e. excluding 1 and the dim size itself). A choice
            may be the intra-node size (e.g. ``torch.cuda.device_count()``).
            This allows the all-gather in backward to be over a smaller world
            size at the cost of higher memory usage than setting to ``True``.
            - The root FSDP state has its value specially set to ``False`` as a
            heuristic since its parameters would typically be immediately
            all-gathered for backward.
            - After forward, the parameters registered to the module depend on
            this value: the sharded parameters if ``True``; the unsharded
            parameters if ``False``; and the parameters resharded to the
            smaller mesh otherwise. To modify the parameters between forward
            and backward, the registered parameters must be the sharded
            parameters. For ``False`` or an ``int``, this can be done by
            manually resharding via :meth:`reshard`.
        mp_policy (MixedPrecisionPolicy): This controls the mixed precision
            policy, which offers parameter/reduction mixed precision for this
            module. See :class:`MixedPrecisionPolicy` for details.
        offload_policy (OffloadPolicy): This controls the offloading policy,
            which offers parameter/gradient/optimizer state offloading. See
            :class:`OffloadPolicy` and its subclasses for details.
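
    A minimal usage sketch (``model``, its ``blocks`` attribute, ``mesh``, and
    ``inp`` are illustrative names, not part of this API)::

        # Shard each block into its own group first, then group the remaining
        # parameters (e.g. embeddings) under the root module.
        for block in model.blocks:
            fully_shard(block, mesh=mesh)
        fully_shard(model, mesh=mesh)

        # Construct the optimizer after sharding so that it holds references
        # to the sharded (DTensor) parameters.
        optim = torch.optim.AdamW(model.parameters())

        loss = model(inp).sum()
        loss.backward()
        optim.step()
        optim.zero_grad()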
    """
    if isinstance(module, (nn.ModuleList, nn.ModuleDict)):
        raise ValueError(
            f"fully_shard does not support containers that do not implement forward: {module}"
        )
    mesh = mesh or _init_default_fully_shard_mesh()
    if mesh.ndim not in (1, 2):
        raise ValueError(f"fully_shard expects a 1D or 2D DeviceMesh but got {mesh}")
    elif mesh.ndim == 1:
        mesh_info = FSDPMeshInfo(mesh, shard_mesh_dim=0)
    else:
        mesh_info = HSDPMeshInfo(mesh, shard_mesh_dim=1, replicate_mesh_dim=0)
    device = _get_device_from_mesh(mesh)
    post_forward_mesh_info = _get_post_forward_mesh_info(
        reshard_after_forward, mesh_info
    )

    arg_module = module
    modules = (
        (module,) if isinstance(module, nn.Module) else tuple(_get_root_modules(module))
    )
    state = fully_shard.state(modules[0])
    state.init(modules, device, mp_policy)

    managed_modules = _get_managed_modules(modules)
    params, buffers = _get_managed_states(managed_modules)
    _move_states_to_device(params, buffers, device)
    if params:
        state._fsdp_param_group = FSDPParamGroup(
            params,
            modules,
            mesh_info,
            post_forward_mesh_info,
            device,
            mp_policy,
            offload_policy,
        )

    # For Dynamo
    for managed_module in managed_modules:
        managed_module._is_fsdp_managed_module = True  # type: ignore[assignment]
        managed_module._fsdp_use_orig_params = True  # type: ignore[assignment]

    # Place FSDP leftmost for highest priority in the method resolution order
    for module in modules:
        cls = module.__class__
        new_cls = cls_to_fsdp_cls.get(cls, None)
        if not new_cls:
            dct = {"__deepcopy__": unimplemented_deepcopy}
            new_cls = type(f"FSDP{cls.__name__}", (FSDPModule, cls), dct)
            cls_to_fsdp_cls[cls] = new_cls
        module.__class__ = new_cls
    return arg_module


def unimplemented_deepcopy(*args: Any, **kwargs: Any) -> NoReturn:
    raise AssertionError(
        "FSDP does not support deepcopy. Please use state dict for serialization."
    )


class FSDPModule:
    def __new__(cls, *args, **kwargs):
        """
        Override ``__new__`` to remove the FSDP class and directly construct
        the original class for cases like indexing into a container module.
        """
        # Use index 2 since 0 is the dynamically constructed `FSDP<...>` class
        # and index 1 is the `FSDPModule` class itself
        orig_cls = cls.__mro__[2]
        self = orig_cls.__new__(orig_cls, *args, **kwargs)
        self.__init__(*args, **kwargs)
        return self

    def reshard(self) -> None:
        """
        Reshards the module's parameters, registering the sharded parameters
        to the module and freeing the unsharded parameters if needed. This
        method is *not* recursive.
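
        A sketch of re-registering the sharded parameters before mutating them
        between forward and backward, e.g. when ``reshard_after_forward=False``
        (``module`` and ``inp`` are illustrative names)::

            out = module(inp)
            module.reshard()  # registers the sharded DTensor parameters again
            # ... safely modify the (sharded) parameters here ...
            out.sum().backward()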
        """
        state = self._get_fsdp_state()
        if fsdp_param_group := state._fsdp_param_group:
            fsdp_param_group.reshard()

    def unshard(self, async_op: bool = False) -> Optional["UnshardHandle"]:
        """
        Unshards the module's parameters by allocating memory and all-gathering
        the parameters. This method is *not* recursive.

        Args:
            async_op (bool): If ``True``, then returns an :class:`UnshardHandle`
                that has a :meth:`wait` method to wait on the unshard op. If
                ``False``, then returns ``None`` and waits on the handle inside
                this function.

        .. warning:: This method is experimental and subject to change.

        .. note:: If ``async_op=True``, then the user does not have to call
            :meth:`wait` on the returned handle if waiting on the unshard op
            in the module's pre-forward is tolerable. FSDP will wait on the
            pending unshard op in the pre-forward automatically.
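
        A sketch of issuing the next module's all-gather early to overlap it
        with the current module's computation (``module1``, ``module2``, and
        ``x`` are illustrative names)::

            handle = module2.unshard(async_op=True)  # all-gather issued early
            y = module1(x)   # overlaps with module2's all-gather
            z = module2(y)   # pre-forward waits on the pending unshard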
        """
        state = self._get_fsdp_state()
        fsdp_param_group = state._fsdp_param_group
        if fsdp_param_group is not None:
            fsdp_param_group.lazy_init()
            fsdp_param_group.unshard(async_op=async_op)
        handle = UnshardHandle(fsdp_param_group)
        if async_op:
            return handle
        handle.wait()
        return None

    def set_is_last_backward(self, is_last_backward: bool) -> None:
        """
        Sets whether the next backward is the last one, meaning that FSDP
        should wait for gradient reduction to finish and clear internal data
        structures used for explicit prefetching.
        """
        state = self._get_fsdp_state()
        state._state_ctx.is_last_backward = is_last_backward

    def set_requires_gradient_sync(
        self, requires_gradient_sync: bool, *, recurse: bool = True
    ) -> None:
        """
        Sets if the module should sync gradients. This can be used to implement
        gradient accumulation without communication. For HSDP, this controls
        both reduce-scatter and all-reduce together.

        Args:
            requires_gradient_sync (bool): Whether to reduce gradients for the
                module's parameters.
            recurse (bool): Whether to set for all submodules or just the
                passed-in module.
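
        A sketch of gradient accumulation that only reduces gradients on the
        last microbatch (``model`` is an illustrative root FSDP module and
        ``microbatches`` an illustrative list of inputs)::

            for i, microbatch in enumerate(microbatches):
                is_last = i == len(microbatches) - 1
                model.set_requires_gradient_sync(is_last)
                model(microbatch).sum().backward()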
        """
        self_module = cast(nn.Module, self)
        modules = list(self_module.modules()) if recurse else [self_module]
        for module in modules:
            if isinstance(module, FSDPModule):
                state = module._get_fsdp_state()
                if fsdp_param_group := state._fsdp_param_group:
                    fsdp_param_group.reduce_grads = requires_gradient_sync
                    fsdp_param_group.all_reduce_grads = requires_gradient_sync

    def set_requires_all_reduce(
        self, requires_all_reduce: bool, *, recurse: bool = True
    ) -> None:
        """
        Sets if the module should all-reduce gradients. This can be used to
        implement gradient accumulation with only reduce-scatter and without
        all-reduce for HSDP.
        """
        self_module = cast(nn.Module, self)
        modules = list(self_module.modules()) if recurse else [self_module]
        for module in modules:
            if isinstance(module, FSDPModule):
                state = module._get_fsdp_state()
                if fsdp_param_group := state._fsdp_param_group:
                    fsdp_param_group.all_reduce_grads = requires_all_reduce

    def set_reshard_after_backward(
        self, reshard_after_backward: bool, *, recurse: bool = True
    ) -> None:
        """
        Sets if the module should reshard parameters after backward. This can
        be used during gradient accumulation to trade off higher memory for
        reduced communication.

        Args:
            reshard_after_backward (bool): Whether to reshard parameters after
                backward.
            recurse (bool): Whether to set for all submodules or just the
                passed-in module.
        """
        self_module = cast(nn.Module, self)
        modules = list(self_module.modules()) if recurse else [self_module]
        for module in modules:
            if isinstance(module, FSDPModule):
                state = module._get_fsdp_state()
                if fsdp_param_group := state._fsdp_param_group:
                    fsdp_param_group.reshard_after_backward = reshard_after_backward

    def set_modules_to_forward_prefetch(self, modules: List["FSDPModule"]) -> None:
        """
        Sets the FSDP modules for which this FSDP module should explicitly
        prefetch all-gathers in forward. The prefetching runs after this
        module's all-gather copy-out.

        Passing a singleton list containing the next FSDP module gives the same
        all-gather overlap as the default behavior, except the prefetched
        all-gather is issued earlier from the CPU. Passing a list with at least
        length two is required for more aggressive overlap and will use more
        reserved memory.

        Args:
            modules (List[FSDPModule]): FSDP modules to prefetch.
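
        A sketch of prefetching two blocks ahead in forward (``blocks`` is an
        illustrative list of FSDP-wrapped blocks)::

            for i, block in enumerate(blocks):
                block.set_modules_to_forward_prefetch(blocks[i + 1 : i + 3])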
        """
        _assert_all_fsdp_modules(modules)
        self._get_fsdp_state()._states_to_forward_prefetch = [
            module._get_fsdp_state() for module in modules
        ]

    def set_modules_to_backward_prefetch(self, modules: List["FSDPModule"]) -> None:
        """
        Sets the FSDP modules for which this FSDP module should explicitly
        prefetch all-gathers in backward. This overrides the default backward
        prefetching implementation that prefetches the next FSDP module based
        on the reverse post-forward order.

        Passing a singleton list containing the previous FSDP module gives the
        same all-gather overlap as the default behavior. Passing a list with at
        least length two is required for more aggressive overlap and will use
        more reserved memory.

        Args:
            modules (List[FSDPModule]): FSDP modules to prefetch.
        """
        _assert_all_fsdp_modules(modules)
        self._get_fsdp_state()._states_to_backward_prefetch = [
            module._get_fsdp_state() for module in modules
        ]

    def set_post_optim_event(self, event: torch.cuda.Event) -> None:
        """
        Sets a post-optimizer-step event for the root FSDP module's all-gather
        streams to wait on.

        By default, the root FSDP module makes the all-gather streams wait on
        the current stream to ensure that the optimizer step has finished
        before all-gathering. However, this may introduce false dependencies if
        there is unrelated computation after the optimizer step. This API
        allows the user to provide their own event to wait on. After the root
        waits on the event, the event is discarded, so this API should be
        called with a new event each iteration.

        Args:
            event (torch.cuda.Event): Event recorded after the optimizer step
                for the all-gather streams to wait on.
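
        A sketch of recording and registering a fresh event each iteration
        (``model`` is an illustrative root FSDP module and ``optim`` its
        optimizer)::

            optim.step()
            post_optim_event = torch.cuda.current_stream().record_event()
            model.set_post_optim_event(post_optim_event)
            # ... unrelated computation that should not delay all-gathers ...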
        """
        self._get_fsdp_state()._state_ctx.post_optim_event = event

    def set_reduce_scatter_divide_factor(self, factor: float) -> None:
        """
        Sets a custom divide factor for the reduce-scatter. This becomes a
        custom reduce op using NCCL's PreMulSum, which allows multiplying by
        the factor before reduction.

        Args:
            factor (float): Custom divide factor.
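
        A sketch of setting a custom factor, here the data parallel world size
        (``model`` is an illustrative root FSDP module, and ``dist`` is
        ``torch.distributed``)::

            model.set_reduce_scatter_divide_factor(float(dist.get_world_size()))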
        """
        state = self._get_fsdp_state()
        if (fsdp_param_group := state._fsdp_param_group) is not None:
            mul_factor = 1.0 / float(factor)
            reduce_op = torch.distributed._make_nccl_premul_sum(mul_factor)
            fsdp_param_group.reduce_scatter_reduce_op = reduce_op

    def _get_fsdp_state(self) -> FSDPState:
        if (state := _get_module_fsdp_state(cast(nn.Module, self))) is None:
            raise AssertionError(f"No FSDP state found on {self}")
        return state

    def _apply(self, *args: Any, **kwargs: Any) -> Any:
        # Reshard to ensure that sharded parameters are registered
        self.reshard()
        ret = super()._apply(*args, **kwargs)  # type: ignore[misc]
        state = self._get_fsdp_state()
        if not (fsdp_param_group := state._fsdp_param_group):
            return ret
        # TODO: Remove this padding logic once DTensor pads the local tensor:
        # https://github.com/pytorch/pytorch/issues/113045
        with torch.no_grad():
            for fsdp_param in fsdp_param_group.fsdp_params:
                fsdp_param.reset_sharded_param()
        return ret


class UnshardHandle:
    """
    A handle to wait on the unshard op.

    Args:
        fsdp_param_group (FSDPParamGroup, optional): FSDP parameter group to
            unshard. This should be ``None`` iff the FSDP module does not
            manage any parameters, meaning the unshard is a no-op.
    """

    def __init__(self, fsdp_param_group: Optional[FSDPParamGroup]):
        self._fsdp_param_group = fsdp_param_group

    def wait(self):
        """
        Waits on the unshard op.

        This ensures that the current stream can use the unsharded parameters,
        which are now registered to the module.
        """
        if self._fsdp_param_group is not None:
            self._fsdp_param_group.wait_for_unshard()
            # Avoid keeping a reference
            self._fsdp_param_group = None


def register_fsdp_forward_method(module: nn.Module, method_name: str) -> None:
    """
    Registers a method on ``module`` to be a forward method for FSDP.

    FSDP only knows to run its pre-forward and post-forward hooks on the
    default :meth:`nn.Module.forward` method. This function patches a
    user-specified method to run the pre/post-forward hooks before/after the
    method, respectively. If ``module`` is not an :class:`FSDPModule`, then
    this is a no-op.

    Args:
        module (nn.Module): Module to register the forward method on.
        method_name (str): Name of the forward method.
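
    A sketch of running FSDP's pre/post-forward hooks around a non-``forward``
    entry point (``model`` is an illustrative sharded module that defines a
    ``generate`` method, and ``prompt`` is an illustrative input)::

        fully_shard(model)
        register_fsdp_forward_method(model, "generate")
        out = model.generate(prompt)  # parameters are all-gathered and freed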
    """
    if not isinstance(module, FSDPModule):
        # Make no-op to allow including both when using/not using FSDP
        return
    if not hasattr(module, method_name):
        raise ValueError(f"{type(module)} does not have a method {method_name}")
    orig_method = getattr(module, method_name)

    @functools.wraps(orig_method)
    def wrapped_method(self, *args, **kwargs):
        fsdp_state = self._get_fsdp_state()
        args, kwargs = fsdp_state._pre_forward(self, args, kwargs)
        out = orig_method(*args, **kwargs)
        return fsdp_state._post_forward(self, args, out)

    # Use `__get__` to make `wrapped_method` an instance method
    setattr(
        module,
        method_name,
        wrapped_method.__get__(module, type(module)),  # type:ignore[attr-defined]
    )


def _assert_all_fsdp_modules(modules: Iterable[Any]) -> None:
    for module in modules:
        if not isinstance(module, FSDPModule):
            raise ValueError(f"Expects FSDPModule but got {type(module)}: {module}")