xref: /aosp_15_r20/external/pytorch/torch/distributed/_composable/fsdp/_fsdp_param.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import itertools
3from dataclasses import dataclass, field
4from enum import auto, Enum
5from typing import Any, cast, List, Optional, Sequence, Tuple
6
7import torch
8import torch._dynamo.compiled_autograd as ca
9import torch.nn as nn
10from torch._prims_common import make_contiguous_strides_for
11from torch.distributed._functional_collectives import AsyncCollectiveTensor
12from torch.distributed.tensor import DTensor, Replicate, Shard
13from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
14from torch.distributed.tensor.device_mesh import _mesh_resources
15from torch.distributed.tensor.placement_types import _StridedShard, Placement
16
17from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy
18from ._fsdp_common import (
19    _chunk_with_empty,
20    _from_local_no_grad,
21    _get_dim0_chunked_size,
22    _raise_assert_with_print,
23    _to_dtype_if_needed,
24    FSDPMeshInfo,
25    HSDPMeshInfo,
26)
27
28
29"""
30[Note: FSDP tensors]
31FSDP considers the following tensors:
32- Original parameter: parameter passed to :class:`FSDPParam`, i.e. the one
33  on the module when applying FSDP
34- Sharded parameter: the original parameter sharded on dim-0 as a DTensor
35  over the main mesh
36- All-gather inputs: the ``torch.Tensor`` s passed to all-gather, derived
37  from the sharded parameter
38- All-gather output: the ``torch.Tensor`` s resulting from all-gathering the
39  all-gather inputs
40- Unsharded parameter: parameter used for forward/backward computation, derived
41  from the all-gather output; autograd leaf
42
43We define these tensors to describe the general framework that can accommodate
44extensions, where:
45- all-gather-inputs = pre-all-gather-transform(sharded-parameter)
46- unsharded-parameter = post-all-gather-transform(all-gather-outputs)
47
48For the default ``torch.Tensor`` case, there is only one all-gather input, and
49it shares the same underlying tensor data as the sharded parameter, meaning
50that they can be thought of as the same tensor. The same applies for the
51all-gather output and unsharded parameter. For non-``torch.Tensor`` extensions,
52these equivalences may no longer hold due to the pre/post-all-gather
53transforms, and some may have multiple all-gather inputs/outputs (e.g.
54quantized data and scales).
55
56[Note: FSDP and autograd]
57FSDP dynamically frees and allocates the unsharded parameter. Since autograd
58can pack a reference to it or a view to save for backward, we use storage
59resizing to implement the freeing/allocation since that preserves the aliasing.
60This implies that we construct the unsharded parameter object once and write to
61it in-place thereafter. For the default ``torch.Tensor`` original parameter
62case, the all-gather output and unsharded parameter share the same
63data, so we use storage resizing on the all-gather output.
64"""
65
66lib = torch.library.Library("fsdp", "FRAGMENT")  # noqa: TOR901
67
68lib.define("set_(Tensor(a!) tensor, Tensor data) -> ()")
69
70
71@torch.library.impl(lib, "set_", "Meta")
72@torch.library.impl(lib, "set_", "CUDA")
73@torch.library.impl(lib, "set_", "CPU")
74def set_(tensor, data):
75    tensor.set_(data)
76
77
78"""
79[Note: Avoiding functionalization for fsdp.set_ and inductor.resize_storage_bytes_(0)]
80
81Currently we don't functionalize the `fsdp.set_` op or the `inductor.resize_storage_bytes_(0)` op
82(i.e. they show up as mutation ops in the middle of the AOT joint graph).
83
84Reason:
85The Traceable FSDP2 compiled autograd BWD graph has the following traits:
86(1) Two inputs of the graph are aliased to each other (one from hook closed-over tensors, one from FWD saved tensors).
87(2) One of them is mutated (set_ and resize_(0) to handle the all-gathered param).
88(3) They are both subclasses.
89The combination of these traits is not supported by AOTAutograd (it's difficult to reason about subclass aliasing).
90So this doesn't work at all for Traceable FSDP2.
91
92The compromise we use is to avoid functionalization for the FSDP2 set_ and resize_(0) ops.
93This avoids the problem above, because from AOTAutograd's point of view there are no mutations
94that functionalization needs to handle. (Although we need to be careful not to DCE those mutable ops.)
95
96We can avoid this functionalization because:
97(1) The nn.Parameter is never used before its .set_() is called in eager code (i.e. no alias of it is created),
98so it's safe to call .set_() in the middle of the graph to swap out its storage and start using the nn.Parameter downstream.
99(2) We always re-allocate the buffer for nn.Parameter to store the AllGather output and to be used in downstream user ops.
100So calling resize-to-0 in the middle of the graph to free nn.Parameter memory after use should always be okay
101(since we always allocate anew next time we need it, we strictly don't need to keep the old tensor storage around anymore).
102
103Q: But doesn't the torch.compile stack have the "functional graph" assumption in many places?
104A: Yes - this is WIP, but we will try to get back to a functional graph as early as possible in the lowering process.
105Specifically, we believe we can move both the .set_ and .resize_(0) ops to the end of the graph in the AOT joint graph before the partitioner
106(i.e. effectively "re-functionalizing" those ops). Put another way, we avoid functionalization for those two ops just to
107make AOTAutograd alias analysis happy, and as soon as we are past that point, we "re-functionalize" the graph.
108This requires a custom FX pass, but we believe it's not hard to write and maintain.
109
110Q: What's the importance of partitioner not saving views of nn.Parameter as FWD saved tensors?
111A: This is critical: we do want to save the FWD nn.Parameter graph input (instead of its view) for BWD use,
112so that downstream ops in the BWD graph use the post-`.set_` nn.Parameter instead of any of its saved views as input.
113This is because .set_ will not update any of the nn.Parameter's views, so BWD downstream ops must use the original
114nn.Parameter in order to see the result of .set_.
115"""
116
117
118@torch.library.impl(lib, "set_", "Functionalize")
119def set__functionalize(tensor, data):
120    torch._sync(tensor)
121    torch._sync(data)
122    # AOTDispatcher needs to know if any inputs had their storages mutated.
123    # (Why? It sometimes detaches inputs before sending them into the graph,
124    #  when it sees that they do not need to have any gradients computed)
125    torch._functionalize_set_storage_changed(tensor)
126    tensor_inner = torch._from_functional_tensor(tensor)
127    data_inner = torch._from_functional_tensor(data)
128    with torch._C._ExcludeDispatchKeyGuard(
129        torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
130    ):
131        torch.ops.fsdp.set_.default(tensor_inner, data_inner)
132
133
134torch.fx.node.has_side_effect(torch.ops.fsdp.set_.default)
135
136
137class ShardedState(Enum):
138    """
139    - ``SHARDED``: The sharded parameter is registered to the module. It is the
140      only contributor to parameter memory.
141    - ``SHARDED_POST_FORWARD``: The unsharded parameter is resharded to a
142      smaller world size. Since this data should not be used for computation,
143      we do not register it to the module. Users should reshard the module
144      before any in-place modifications. Both it and the sharded parameter
145      contribute to parameter memory.
146    - ``UNSHARDED``: The unsharded parameter is registered to the module. Both
147      it and the sharded parameter contribute to parameter memory.
148    """
149
150    SHARDED = auto()
151    SHARDED_POST_FORWARD = auto()
152    UNSHARDED = auto()
153
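# Typical transitions, driven by ``FSDPParam.to_unsharded``,
# ``FSDPParam.to_sharded_post_forward``, and ``FSDPParam.to_sharded`` below
# (the SHARDED_POST_FORWARD hop only happens when a post-forward mesh is set):
#   SHARDED --(all-gather)--> UNSHARDED
#   UNSHARDED --(post-forward reshard to smaller mesh)--> SHARDED_POST_FORWARD
#   SHARDED_POST_FORWARD --(all-gather for backward)--> UNSHARDED
#   UNSHARDED --(reshard)--> SHARDED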
154
155@dataclass
156class ParamModuleInfo:
157    """
158    For a parameter, this stores the module and the parameter name to be able
159    to do a parameter swap via ``setattr(module, param_name, ...)`` or to get
160    the parameter via ``getattr(module, param_name)``. We additionally save
161    shared modules and shared parameter names to update them accordingly.
162    """
163
164    # Parameter names are unprefixed, e.g. "weight", not "lin.weight"
165    module: nn.Module
166    param_name: str
167    shared_modules: List[nn.Module] = field(default_factory=list)
168    shared_param_names: List[str] = field(default_factory=list)
169
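"""
[Note: ParamModuleInfo example]
A hypothetical illustration of how the fields above are used for a weight tied
between two modules (the modules and names are made up; FSDP performs this
swap via ``FSDPParam._setattr_on_modules`` below):

    lin1 = nn.Linear(4, 4, bias=False)
    lin2 = nn.Linear(4, 4, bias=False)
    lin2.weight = lin1.weight  # tied/shared parameter
    info = ParamModuleInfo(
        module=lin1,
        param_name="weight",
        shared_modules=[lin2],
        shared_param_names=["weight"],
    )
    new_param = nn.Parameter(torch.zeros(4, 4))
    for mod, name in zip(
        [info.module, *info.shared_modules],
        [info.param_name, *info.shared_param_names],
    ):
        setattr(mod, name, new_param)  # swap on the owning and shared modules
"""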
170
171@dataclass
172class ExtensionsData:
173    # User-defined metadata passed from pre to post-all-gather
174    all_gather_metadata: Optional[Any] = None
175    # Save the all-gather input sizes to unflatten the all-gather outputs to ND
176    all_gather_input_sizes: Sequence[torch.Size] = ()  # ND
177
178    def clear(self):
179        self.all_gather_metadata = None
180        self.all_gather_input_sizes = ()
181
182
183class FSDPParam:
184    """
185    This class manages a parameter with FSDP or FSDP variants applied,
186    implementing dim-0 per-parameter sharding.
187    """
188
189    orig_dtype: torch.dtype
190    param_dtype: Optional[torch.dtype]
191    reduce_dtype: Optional[torch.dtype]
192    _orig_size: torch.Size  # ND
193    sharded_size: torch.Size  # ND
194    contiguous_sharded_stride: Tuple[int, ...]
195    padded_sharded_param_size: torch.Size  # ND
196    sharded_post_forward_size: torch.Size  # ND
197    contiguous_sharded_post_forward_stride: Tuple[int, ...]
198    _sharded_param_data: torch.Tensor  # 1D
199    sharded_param: nn.Parameter  # ND
200    _sharded_post_forward_param_data: Optional[torch.Tensor]  # 1D
201    _sharded_post_forward_param: Optional[nn.Parameter]  # ND
202    _unsharded_param: nn.Parameter  # ND
203    unsharded_accumulated_grad: Optional[torch.Tensor]  # ND
204    _sharding_spec: DTensorSpec
205    # DTensor attributes (only defined for DTensor `param`):
206    _tp_spec: DTensorSpec
207    all_gather_outputs: List[torch.Tensor]  # 1D
208    # All-gather extension attributes
209    _extensions_data: ExtensionsData
210    _unsharded_inner_tensors: List[torch.Tensor]
211
212    def __init__(
213        self,
214        param: nn.Parameter,
215        module_info: ParamModuleInfo,
216        mesh_info: FSDPMeshInfo,
217        post_forward_mesh_info: Optional[FSDPMeshInfo],
218        device: torch.device,
219        mp_policy: MixedPrecisionPolicy,
220        offload_policy: OffloadPolicy,
221    ):
222        self._module_info: ParamModuleInfo = module_info
223        self.mesh_info = mesh_info
224        self.post_forward_mesh_info = post_forward_mesh_info
225        self.device = device
226        self.offload_to_cpu: bool = isinstance(offload_policy, CPUOffloadPolicy)
227        self.pin_memory = (
228            self.offload_to_cpu and cast(CPUOffloadPolicy, offload_policy).pin_memory
229        )
230        self.grad_offload_event: Optional[torch.cuda.Event] = None
231        self._init_sharded_param(param, device)
232        if self.post_forward_mesh_info:
233            self._init_sharded_post_forward_param_metadata(param)
234        self._init_extensions()
235        self.all_gather_outputs: List[torch.Tensor] = []
236        self.unsharded_accumulated_grad = None
237        self._param_fqn: Optional[str] = None  # prefixed from root module
238        # TODO: Remove this padding logic once DTensor pads the local tensor:
239        # https://github.com/pytorch/pytorch/issues/113045
240        self._post_load_hook_handle = (
241            module_info.module.register_load_state_dict_post_hook(
242                lambda *args, **kwargs: self.reset_sharded_param()
243            )
244        )
245
246    @torch.no_grad()
247    def _init_sharded_param(self, param: nn.Parameter, device: torch.device):
248        if param.device != device and param.device.type != "meta":
249            raise AssertionError(
250                f"Expects the parameter to already be moved to device {device} but got {param.device}"
251            )
252        # TODO: Replace the sharded DTensor parameter construction logic with
253        # `distribute_tensor` after https://github.com/pytorch/pytorch/issues/116101
254        # TODO: Simplify the following sharded parameter padding logic after
255        # https://github.com/pytorch/pytorch/issues/113045
256        self.is_dtensor = isinstance(param, DTensor)
257        if self.is_dtensor:
258            self._tp_spec = cast(DTensor, param)._spec
259            dp_mesh, tp_mesh = (self.mesh_info.mesh, self._tp_spec.mesh)
260            dp_global_mesh = _mesh_resources.get_root_mesh(dp_mesh)
261            tp_global_mesh = _mesh_resources.get_root_mesh(tp_mesh)
262            if dp_global_mesh != tp_global_mesh or (
263                dp_global_mesh is None or tp_global_mesh is None
264            ):
265                raise AssertionError(
266                    "FSDP requires the DP and TP mesh to have the same parent mesh but got: \n"
267                    f"DP's global mesh: {dp_global_mesh}\nTP's global mesh: {tp_global_mesh}"
268                )
269
270            name_dims_error = "FSDP requires named DeviceMesh dims for ND parallelism"
271            assert dp_mesh.mesh_dim_names is not None, name_dims_error
272            assert tp_mesh.mesh_dim_names is not None, name_dims_error
273            submesh_names = dp_mesh.mesh_dim_names + tp_mesh.mesh_dim_names
274            self._spmd_mesh = dp_global_mesh[submesh_names]
275            if len(self._tp_spec.placements) != 1:
276                raise NotImplementedError(
277                    f"FSDP only supports 1D TP, not {self._tp_spec.placements}"
278                )
279            split_factor = self._tp_spec.num_shards_map[0]
280            assert (
281                2 <= self._spmd_mesh.ndim <= 3
282            ), f"_spmd_mesh.ndim can only be 2 or 3 but got {self._spmd_mesh.ndim}."
283            self._spmd_placements: Tuple[Placement, ...]
284            dp_shard_tp_placement = (
285                (
286                    _StridedShard(0, split_factor=split_factor)
287                    if split_factor > 1
288                    else Shard(0)
289                ),
290                self._tp_spec.placements[0],
291            )
292            if self._spmd_mesh.ndim == 2:
293                self._spmd_placements = dp_shard_tp_placement
294            else:
295                assert self.mesh_info.replicate_mesh_dim == 0
296                self._spmd_placements = (Replicate(),) + dp_shard_tp_placement
297            self._sharding_spec = DTensorSpec(
298                self._spmd_mesh,
299                self._spmd_placements,
300                tensor_meta=self._tp_spec.tensor_meta,
301            )
302            # NOTE: FSDP+TP does not support uneven sharding for now
303            # TODO: enable uneven sharding for FSDP+TP
304            if split_factor > 1:  # FSDP has strided sharding on tensor dim 0
305                num_shards = self._sharding_spec.num_shards_map[0]
306                tensor_size_dim_0 = self._sharding_spec.shape[0]
307                if tensor_size_dim_0 % num_shards != 0:
308                    raise NotImplementedError(
309                        "FSDP+TP sharding does not support uneven sharding for now: "
310                        f"tensor dim 0 has size {tensor_size_dim_0} which cannot be "
311                        f"evenly sharded into {num_shards} shards."
312                    )
313
314            param_data = cast(DTensor, param)._local_tensor
315        else:
316            self._spmd_mesh = self.mesh_info.mesh
317            if isinstance(self.mesh_info, HSDPMeshInfo):
318                self._spmd_placements = (Replicate(), Shard(0))
319            else:
320                self._spmd_placements = (Shard(0),)
321            self._sharding_spec = DTensorSpec(
322                self._spmd_mesh,
323                self._spmd_placements,
324                tensor_meta=TensorMeta(
325                    param.size(),
326                    param.stride(),
327                    param.dtype,
328                ),
329            )
330            param_data = param
331        self._orig_size = param_data.size()
332        self._contiguous_orig_stride = make_contiguous_strides_for(self._orig_size)
333        shard_rank = self.mesh_info.shard_mesh_rank
334        shard_world_size = self.mesh_info.shard_mesh_size
335        chunks = _chunk_with_empty(param_data, shard_world_size, dim=0)
336        sharded_param = chunks[shard_rank]
337        self.sharded_size = _get_dim0_chunked_size(sharded_param, param_data.size())
338        self.contiguous_sharded_stride = make_contiguous_strides_for(self.sharded_size)
339        padded_sharded_size = chunks[0].size()  # 0th chunk always has the padded size
340        padded_sharded_param = param_data.new_zeros(padded_sharded_size)
341        self.padded_sharded_param_size = padded_sharded_param.size()
342        if sharded_param.numel() > 0:
343            padded_sharded_param[: sharded_param.size(0)].copy_(sharded_param)
344        if self.offload_to_cpu and not padded_sharded_param.is_meta:
345            padded_sharded_param = padded_sharded_param.cpu()
346            if self.pin_memory:
347                padded_sharded_param = padded_sharded_param.pin_memory()
348        self._sharded_param_data = padded_sharded_param.view(-1)
349        self.sharded_param = nn.Parameter(
350            self.to_sharded_dtensor(padded_sharded_param[: sharded_param.size(0)])
351        )
352        self.sharded_param.requires_grad_(param.requires_grad)
353        # Let `param_data` be freed normally when its ref count reaches 0 after
354        # the `fully_shard` call returns, which allows provided parameters to alias it
355        self._setattr_on_modules(self.sharded_param)
356        self.sharded_state = ShardedState.SHARDED
357
358    def _init_sharded_post_forward_param_metadata(self, param: torch.Tensor) -> None:
359        mesh_info = self.post_forward_mesh_info
360        assert mesh_info is not None  # mypy
361        param_data = param._local_tensor if isinstance(param, DTensor) else param
362        chunks = _chunk_with_empty(param_data, mesh_info.shard_mesh_size, dim=0)
363        self.sharded_post_forward_size = _get_dim0_chunked_size(
364            chunks[mesh_info.shard_mesh_rank], param_data.size()
365        )
366        self.contiguous_sharded_post_forward_stride = make_contiguous_strides_for(
367            self.sharded_post_forward_size
368        )
369
370    def init_dtype_attrs(self, mp_policy: MixedPrecisionPolicy):
371        param_dtype, reduce_dtype = (mp_policy.param_dtype, mp_policy.reduce_dtype)
372        self.orig_dtype = self.sharded_param.dtype
373        # Clamp `param_dtype` to `None` if no casting is required
374        if param_dtype == self.orig_dtype:
375            param_dtype = None
376        self.param_dtype = param_dtype
377        self.reduce_dtype = reduce_dtype
378        # None indicates that the mixed precision is not enabled
379
380    def _init_extensions(self) -> None:
381        inner_tensor = self._sharded_local_tensor
382        has_fsdp_pre_all_gather = hasattr(inner_tensor, "fsdp_pre_all_gather")
383        has_fsdp_post_all_gather = hasattr(inner_tensor, "fsdp_post_all_gather")
384        if has_fsdp_pre_all_gather != has_fsdp_post_all_gather:
385            raise AssertionError(
386                "Both fsdp_pre_all_gather and fsdp_post_all_gather should be defined "
387                f"if using all-gather extensions: {inner_tensor}"
388            )
389        if has_fsdp_pre_all_gather:
390            if self.padded_sharded_param_size != self._sharded_local_tensor.size():
391                raise NotImplementedError(
392                    "FSDP all-gather extensions require even sharding on dim-0.\n"
393                    f"{self._orig_size} is not divisible by FSDP world size {self.mesh_info.mesh.size()}."
394                )
395            self._extensions_data = ExtensionsData()
396        self._unsharded_inner_tensors: List[torch.Tensor] = []
397
398    def init_all_gather_outputs(
399        self,
400        all_gather_input_numels: List[int],
401        all_gather_input_dtypes: List[torch.dtype],
402        world_size: int,
403        device: torch.device,
404        force_recreate: bool = False,
405    ):
406        if not force_recreate and len(self.all_gather_outputs) > 0:
407            return  # already initialized
408        self.all_gather_outputs = [
409            torch.empty(torch.Size([numel * world_size]), dtype=dtype, device=device)
410            for numel, dtype in zip(all_gather_input_numels, all_gather_input_dtypes)
411        ]
412
413    def init_unsharded_param(self):
414        """
415        [Note: Invariants for torch.compile Traceable FSDP2]
416        1. Under compile, we always re-populate the content of `self._unsharded_param`
417           per AllGather using the slow path.
418        2. Under compile, we always recreate `self.all_gather_outputs` per AllGather.
419           This is to ensure the buffer creation is internal to the graph and
420           avoid `self.all_gather_outputs` being captured as a graph input.
421        3. Under compile, at the end of `free_unsharded_param()`, we always clean up
422           `self.all_gather_outputs` and `self._unsharded_inner_tensors`,
423           to avoid them being captured as graph output.
424
425        With these invariants, only these tensors will be inputs to the graph:
426        - Sharded parameters
427        - Placeholders for the `self._unsharded_param` nn.Parameter
428        """
429        if not ca.compiled_autograd_enabled and hasattr(
430            self, "_unsharded_param"
431        ):  # after the 1st all-gather
432            inner_tensor = self._sharded_local_tensor
433            if not hasattr(inner_tensor, "fsdp_post_all_gather"):
434                return  # already initialized
435            for tensor in self._unsharded_inner_tensors:
436                alloc_storage(tensor)
437            all_gather_outputs = self._unflatten_all_gather_outputs()
438            inner_tensor.fsdp_post_all_gather(
439                all_gather_outputs,
440                self._extensions_data.all_gather_metadata,
441                self.param_dtype or self.orig_dtype,
442                out=self._unsharded_param,
443            )
444            self._extensions_data.clear()
445            return
446        inner_tensor = self._sharded_local_tensor
447        if not ca.compiled_autograd_enabled and hasattr(
448            inner_tensor, "fsdp_post_all_gather"
449        ):
450            all_gather_outputs = self._unflatten_all_gather_outputs()
451            (
452                unsharded_tensor,
453                self._unsharded_inner_tensors,
454            ) = inner_tensor.fsdp_post_all_gather(
455                all_gather_outputs,
456                self._extensions_data.all_gather_metadata,
457                self.param_dtype or self.orig_dtype,
458            )
459            self._extensions_data.clear()
460        else:
461            # For the default path (no post-all-gather), the all-gather output
462            # gives the unsharded parameter data directly
463            assert len(self.all_gather_outputs) == 1, f"{len(self.all_gather_outputs)}"
464            unsharded_tensor = self.all_gather_outputs[0]
465        unsharded_param = torch.as_strided(
466            unsharded_tensor,
467            self._orig_size,
468            self._contiguous_orig_stride,
469            storage_offset=0,
470        )
471        if self.is_dtensor:
472            unsharded_param = _from_local_no_grad(unsharded_param, self._tp_spec)
473        if hasattr(self, "_unsharded_param"):
474            assert ca.compiled_autograd_enabled
475            with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
476                self._unsharded_param
477            ):
478                torch.ops.fsdp.set_.default(self._unsharded_param, unsharded_param)
479        else:
480            self._unsharded_param = nn.Parameter(
481                unsharded_param, requires_grad=self.sharded_param.requires_grad
482            )
483
484    def _unflatten_all_gather_outputs(self) -> Tuple[torch.Tensor, ...]:
485        return tuple(
486            t.view(-1, *s[1:])
487            for t, s in zip(
488                self.all_gather_outputs, self._extensions_data.all_gather_input_sizes
489            )
490        )
491
492    def to_sharded(self) -> None:
493        self._setattr_on_modules(self.sharded_param)
494        self.free_unsharded_param()
495        self.sharded_state = ShardedState.SHARDED
496
497    def to_sharded_post_forward(self) -> None:
498        if self.is_dtensor:
499            raise NotImplementedError(
500                "Resharding to smaller mesh with TP is not supported yet"
501            )
502        self._assert_in_states(ShardedState.UNSHARDED)
503        assert self.post_forward_mesh_info is not None  # mypy
504        assert len(self.all_gather_outputs) == 1
505        shard_world_size = self.post_forward_mesh_info.shard_mesh_size
506        if (numel := self.all_gather_outputs[0].numel()) % shard_world_size != 0:
507            _raise_assert_with_print(
508                f"All-gather output size ({numel}) must be divisible by the shard "
509                f"world size ({shard_world_size})"
510            )
511        shard_rank = self.post_forward_mesh_info.shard_mesh_rank
512        sharded_numel = numel // shard_world_size
513        self._sharded_post_forward_param_data = (
514            self.all_gather_outputs[0].narrow(
515                0, sharded_numel * shard_rank, sharded_numel
516            )
517        ).clone()  # clone to be able to free all-gather output
518        sharded_post_forward_tensor = torch.as_strided(
519            self._sharded_post_forward_param_data,
520            size=self.sharded_post_forward_size,
521            stride=self.contiguous_sharded_post_forward_stride,
522            storage_offset=0,
523        )
524        self._sharded_post_forward_param = nn.Parameter(
525            self.to_sharded_post_forward_dtensor(sharded_post_forward_tensor)
526        )
527        self._setattr_on_modules(self._sharded_post_forward_param)
528        self.free_unsharded_param()
529        self.sharded_state = ShardedState.SHARDED_POST_FORWARD
530
531    def to_unsharded(self) -> None:
532        # Assume that the data has been allocated and all-gathered
533        set_requires_grad_if_needed(self.sharded_param, self._unsharded_param)
534        self._setattr_on_modules(self._unsharded_param)
535        if self.sharded_state == ShardedState.SHARDED_POST_FORWARD:
536            # The data is allocated in the default stream via the post-forward
537            # reshard and must be kept alive for the next all-gather copy-in.
538            # Since we call this method after the copy-out, the data's lifetime
539            # is ensured without further synchronization.
540            self._sharded_post_forward_param = None
541            self._sharded_post_forward_param_data = None  # free
542        self.sharded_state = ShardedState.UNSHARDED
543
544    def _setattr_on_modules(self, param: nn.Parameter) -> None:
545        unsafe_setattr_param(
546            self._module_info.module, self._module_info.param_name, param
547        )
548        for shared_module, shared_param_name in zip(
549            self._module_info.shared_modules, self._module_info.shared_param_names
550        ):
551            unsafe_setattr_param(shared_module, shared_param_name, param)
552
553    def to_sharded_dtensor(self, tensor: torch.Tensor) -> DTensor:
554        """
555        Converts a local tensor representing either the sharded parameter or
556        sharded gradient to DTensor.
557        """
558        if tensor.shape != self.sharded_size:
559            _raise_assert_with_print(
560                f"Expects size {self.sharded_size} but got {tensor.shape}"
561            )
562        return _from_local_no_grad(
563            tensor,
564            self._sharding_spec,
565        )
566
567    def to_sharded_post_forward_dtensor(self, tensor: torch.Tensor) -> DTensor:
568        if tensor.shape != self.sharded_post_forward_size:
569            _raise_assert_with_print(
570                f"Expects size {self.sharded_post_forward_size} but got {tensor.shape}"
571            )
572        assert isinstance(self.post_forward_mesh_info, HSDPMeshInfo)
573        # TODO: Prefer this DTensor to be read-only and generalize the
574        # placement once we support TP.
575        post_forward_sharding_spec = DTensorSpec(
576            self.post_forward_mesh_info.mesh,
577            (Replicate(), Shard(0)),
578            tensor_meta=self._sharding_spec.tensor_meta,
579        )
580        return _from_local_no_grad(tensor, post_forward_sharding_spec)
581
582    def to_accumulated_grad_if_needed(self) -> None:
583        # Access `_unsharded_param` to bypass the sharded state check since we
584        # prefer to reshard before upcasting the gradient to save memory
585        if (
586            self.reduce_dtype is None
587            or self._unsharded_param.grad is None
588            or self._unsharded_param.grad.dtype == self.reduce_dtype
589        ):
590            return
591        unsharded_grad = self._unsharded_param.grad
592        self._unsharded_param.grad = None
593        self.unsharded_accumulated_grad = unsharded_grad.to(self.reduce_dtype)
594
595    def accumulate_unsharded_grad_if_needed(self) -> None:
596        if (
597            self.unsharded_accumulated_grad is not None
598            and self.unsharded_param.grad is not None
599        ):
600            self.unsharded_accumulated_grad += self.unsharded_param.grad
601            self.unsharded_param.grad = None
602
603    def alloc_all_gather_outputs(self) -> None:
604        for tensor in self.all_gather_outputs:
605            alloc_storage(tensor)
606
607    def free_unsharded_param(self) -> None:
608        for tensor in itertools.chain(
609            self.all_gather_outputs, self._unsharded_inner_tensors
610        ):
611            free_storage(tensor)
612        if ca.compiled_autograd_enabled:
613            self.all_gather_outputs = []
614            self._unsharded_inner_tensors = []
615
616    @property
617    def all_gather_inputs(self) -> List[torch.Tensor]:  # 1D
618        self._assert_in_states(ShardedState.SHARDED, ShardedState.SHARDED_POST_FORWARD)
619        if self.sharded_state == ShardedState.SHARDED:
620            if not ca.compiled_autograd_enabled and hasattr(
621                self._sharded_local_tensor, "fsdp_pre_all_gather"
622            ):
623                sharded_local_tensor = self._sharded_local_tensor
624                if self.offload_to_cpu:
625                    sharded_local_tensor = sharded_local_tensor.to(
626                        self.device, non_blocking=True
627                    )
628                (
629                    all_gather_inputs,
630                    self._extensions_data.all_gather_metadata,
631                ) = sharded_local_tensor.fsdp_pre_all_gather(self.mesh_info.mesh)
632                self._extensions_data.all_gather_input_sizes = [
633                    t.size() for t in all_gather_inputs
634                ]
635                return [t.view(-1) for t in all_gather_inputs]
636            sharded_param_data = self._sharded_param_data
637            if self.offload_to_cpu:
638                sharded_param_data = sharded_param_data.to(
639                    self.device, non_blocking=True
640                )
641            return [_to_dtype_if_needed(sharded_param_data, self.param_dtype)]
642        elif self.sharded_state == ShardedState.SHARDED_POST_FORWARD:
643            if not ca.compiled_autograd_enabled and hasattr(
644                self._sharded_local_tensor, "fsdp_pre_all_gather"
645            ):
646                raise NotImplementedError
647            all_gather_input = _to_dtype_if_needed(
648                cast(torch.Tensor, self._sharded_post_forward_param_data),
649                self.param_dtype,
650            )
651            return [all_gather_input]
652        return [torch.empty(0)]  # mypy
653
654    @property
655    def unsharded_param(self) -> nn.Parameter:  # ND
656        self._assert_in_states(ShardedState.UNSHARDED)
657        return self._unsharded_param
658
659    @property
660    def unsharded_grad_data(self) -> torch.Tensor:
661        grad = self.unsharded_param.grad
662        assert grad is not None, "Expects unsharded_param.grad to not be None"
663        return self._get_grad_inner_tensor(grad)
664
665    @property
666    def unsharded_accumulated_grad_data(self) -> torch.Tensor:
667        grad = self.unsharded_accumulated_grad
668        assert grad is not None, "Expects unsharded_accumulated_grad to not be None"
669        return self._get_grad_inner_tensor(grad)
670
671    def _get_grad_inner_tensor(self, grad: torch.Tensor) -> torch.Tensor:
672        if self.is_dtensor:
673            if isinstance(grad, AsyncCollectiveTensor):
674                grad = grad.wait()
675            assert isinstance(grad, DTensor), f"{type(grad)}"
676            if any(pl.is_partial() for pl in grad.placements):
677                placements = [
678                    Replicate() if pl.is_partial() else pl for pl in grad.placements
679                ]
680                grad = grad.redistribute(placements=placements)
681            grad = grad._local_tensor
682        return grad
683
684    @property
685    def _sharded_local_tensor(self) -> torch.Tensor:
686        return cast(DTensor, self.sharded_param)._local_tensor
687
688    def _assert_in_states(self, *states: ShardedState) -> None:
689        if self.sharded_state not in states:
690            _raise_assert_with_print(
691                f"Expects to be in one of {states}, not {self.sharded_state}"
692            )
693
694    def reset_sharded_param(self):
695        # For ops like `nn.Module._apply` or `load_state_dict(assign=True)`
696        # that change the sharded parameter tensor, we may need to re-pad the
697        # sharded local tensor and re-save the reference.
698        module_info = self._module_info
699        new_param = getattr(module_info.module, module_info.param_name)
700        if new_param is not self.sharded_param:
701            if torch.__future__.get_swap_module_params_on_conversion():
702                raise AssertionError(
703                    f"Expects swap_tensors to preserve object but got {new_param} "
704                    f"instead of {self.sharded_param}"
705                )
706            self.sharded_param = new_param
707        local_tensor = new_param._local_tensor
708        if local_tensor.is_meta:
709            return
710        padded_sharded_size = self.padded_sharded_param_size
711        if local_tensor.size() != padded_sharded_size:
712            padded_local_tensor = local_tensor.new_zeros(padded_sharded_size)
713            padded_local_tensor[: local_tensor.size(0)].copy_(local_tensor)
714            local_tensor = padded_local_tensor
715        if self.pin_memory and not local_tensor.is_pinned():
716            local_tensor = local_tensor.cpu().pin_memory()
717        self._sharded_param_data = local_tensor.view(-1)
718        assert isinstance(self.sharded_param, DTensor)  # mypy
719        self.sharded_param._local_tensor = local_tensor[: self.sharded_size[0]]
720
721    def __repr__(self):
722        return f"FSDPParam(fqn={self._param_fqn}, orig_size={self._orig_size})"
723
724
725def alloc_storage(tensor: torch.Tensor) -> None:
726    size = tensor.numel() * tensor.itemsize
727    if (storage := tensor.untyped_storage()).size() != size:
728        storage.resize_(size)
729
730
731def free_storage(tensor: torch.Tensor) -> None:
732    if (storage := tensor.untyped_storage()).size() != 0:
733        storage.resize_(0)
734
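"""
[Note: Why storage resizing preserves aliasing]
A toy illustration (not part of the FSDP runtime) of the property that
[Note: FSDP and autograd] relies on: freeing and re-allocating via
``untyped_storage().resize_`` keeps previously created views aliased to the
parameter, whereas rebinding a brand-new tensor object would not.

    p = torch.ones(4)
    v = p.view(2, 2)                     # e.g. a view saved for backward
    free_storage(p)                      # storage is now 0 bytes
    alloc_storage(p)                     # storage re-allocated (uninitialized)
    p.copy_(torch.arange(4.0))           # re-populate, e.g. from an all-gather
    assert torch.equal(v, p.view(2, 2))  # `v` still aliases `p`'s storage
"""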
735
736# NOTE: If the module did not override `nn.Module.__setattr__`, these helpers
737# bypass its checks, which incur non-trivial CPU overhead. For FSDP, we know we
738# do not need those checks when transitioning between sharded/unsharded parameters.
739def unsafe_setattr_param(
740    module: nn.Module, param_name: str, param: nn.Parameter
741) -> None:
742    if getattr(module.__setattr__, "__func__", None) is nn.Module.__setattr__:
743        module._parameters[param_name] = param
744    else:  # slow path
745        setattr(module, param_name, param)
746
747
748def set_requires_grad_if_needed(
749    src_tensor: torch.Tensor, dst_tensor: torch.Tensor
750) -> None:
751    # Only call `requires_grad_` if needed to avoid the Python <> C++ context
752    # switch overhead
753    if src_tensor.requires_grad != dst_tensor.requires_grad:
754        dst_tensor.requires_grad_(src_tensor.requires_grad)
755