xref: /aosp_15_r20/external/pytorch/torch/distributed/_composable/fsdp/_fsdp_param_group.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
import contextlib
import logging
from typing import Any, cast, Dict, List, NamedTuple, Optional, Set, Tuple

import torch
import torch._dynamo.compiled_autograd as ca
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp._common_utils import _named_parameters_with_duplicates
from torch.profiler import record_function
from torch.utils._pytree import tree_flatten, tree_unflatten
from torch.utils.hooks import RemovableHandle

from ._fsdp_api import MixedPrecisionPolicy, OffloadPolicy
from ._fsdp_collectives import (
    AllGatherResult,
    foreach_all_gather,
    foreach_all_gather_copy_out,
    foreach_reduce,
)
from ._fsdp_common import FSDPMeshInfo, HSDPMeshInfo, TrainingState
from ._fsdp_param import FSDPParam, ParamModuleInfo, ShardedState


logger = logging.getLogger("torch.distributed._composable.fsdp")

_ModuleToHandleDict = Dict[nn.Module, RemovableHandle]  # for state dict


"""
[Note: Overlapping all-gather copy-in and all-gather]
For implicit forward prefetching, we want to overlap the next copy-in with the
current all-gather. We do so using a separate copy-in stream. However, since
we have the all-gather input as a view into the output, we must make sure to
copy into different memory from the current all-gather's output. Thus, we keep
a reference to the current all-gather's output and have the next FSDP parameter
group free it after its copy-in. Finally, we have the last FSDP state flush the
reference to avoid holding onto memory after forward.
"""

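
# [Illustration only, not used by FSDP] A minimal sketch of the overlap idea in
# the note above, under simplifying assumptions: `prev_out` stands in for the
# current all-gather's output, `next_inp` for the next group's copy-in, and the
# arithmetic stands in for the real collective kernels.
def _overlap_copy_in_sketch() -> None:
    if not torch.cuda.is_available():
        return
    copy_in_stream = torch.cuda.Stream()
    all_gather_stream = torch.cuda.Stream()
    prev_out = torch.empty(1 << 20, device="cuda")  # current all-gather output
    copy_out_event = torch.cuda.Event()
    copy_out_event.record()  # pretend the copy-out reading `prev_out` just ran
    with torch.cuda.stream(copy_in_stream):
        # The next copy-in must not be ordered before the copy-out that still
        # reads `prev_out`, hence the event wait
        copy_in_stream.wait_event(copy_out_event)
        next_inp = torch.randn(1 << 20, device="cuda")
    with torch.cuda.stream(all_gather_stream):
        # The "all-gather" consumes the copy-in result, so order it after it
        all_gather_stream.wait_stream(copy_in_stream)
        next_inp.mul_(2)
    del prev_out  # the previous output is only dropped after the next copy-in
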

class FSDPCommContext:
    """This has the communication state shared across FSDP states/parameter groups."""

    def lazy_init(self):
        if not torch.cuda.is_available():
            raise RuntimeError("FSDP requires CUDA for streams")
        # Setting the all-gather/reduce-scatter streams to be higher priority
        # can help avoid some issues where their copies in/out are delayed and
        # block computation (this is different from high-pri NCCL streams)
        high_priority = -1
        # All-gather state and copy-in stream allow overlapping the next
        # copy-in with the current all-gather in forward; copy-in overlaps with
        # reduce-scatter in backward without the separate copy-in stream
        self.all_gather_copy_in_stream = torch.cuda.Stream(priority=high_priority)
        # All-gather stream allows overlapping next all-gather with current
        # forward compute
        self.all_gather_stream = torch.cuda.Stream(priority=high_priority)
        # Reduce-scatter stream gives separate execution "thread" for post-
        # backward logic like pre/post-gradient division and reduce-scatter
        self.reduce_scatter_stream = torch.cuda.Stream(priority=high_priority)
        # Run the HSDP all-reduces concurrently with all-gather/reduce-scatter
        # since collectives use different network resources and can overlap
        # in the typical intra-node sharding / inter-node replication case
        self.all_reduce_stream = torch.cuda.Stream()
        # All-gather/reduce-scatter states keep references to collective
        # tensors produced in one stream and used in another and accompanying
        # CUDA events for synchronization
        self.all_gather_state: Optional[AllGatherState] = None
        self.reduce_scatter_state: Optional[ReduceScatterState] = None
        # Post-forward order for explicit backward prefetching
        self.post_forward_order: List[FSDPParamGroup] = []  # will cause ref cycles

    def get_all_gather_streams(
        self, training_state: TrainingState
    ) -> Tuple[torch.cuda.Stream, torch.cuda.Stream]:
        if training_state in (TrainingState.FORWARD, TrainingState.PRE_BACKWARD):
            # Use separate streams for implicit prefetching
            return self.all_gather_copy_in_stream, self.all_gather_stream
        current_stream = torch.cuda.current_stream()
        return current_stream, current_stream

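
# [Illustration only, not used by FSDP] How the stream selection above is meant
# to be consumed, assuming a CUDA device. In FORWARD/PRE_BACKWARD the two
# dedicated streams enable implicit prefetching; otherwise both returned values
# are the current stream, so the all-gather serializes with computation.
def _comm_ctx_usage_sketch() -> None:
    if not torch.cuda.is_available():
        return
    comm_ctx = FSDPCommContext()
    comm_ctx.lazy_init()
    copy_in_stream, all_gather_stream = comm_ctx.get_all_gather_streams(
        TrainingState.FORWARD
    )
    assert copy_in_stream is comm_ctx.all_gather_copy_in_stream
    assert all_gather_stream is comm_ctx.all_gather_stream
    # Outside forward/pre-backward there is no implicit prefetching
    stream_a, stream_b = comm_ctx.get_all_gather_streams(TrainingState.IDLE)
    assert stream_a is stream_b  # both are the current stream
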

# See [Note: Overlapping all-gather copy-in and all-gather]
class AllGatherState(NamedTuple):
    all_gather_result: AllGatherResult
    event: torch.cuda.Event  # all-gather copy-out


class ReduceScatterState(NamedTuple):
    reduce_scatter_input: torch.Tensor
    event: torch.cuda.Event  # reduce-scatter event


class FSDPParamGroup:
    """This class represents a parameter group to communicate together."""

    _orig_dtype: torch.dtype
    _reduce_dtype: Optional[torch.dtype]

    def __init__(
        self,
        params: List[nn.Parameter],
        modules: Tuple[nn.Module, ...],
        mesh_info: FSDPMeshInfo,
        post_forward_mesh_info: Optional[FSDPMeshInfo],
        device: torch.device,
        mp_policy: MixedPrecisionPolicy,
        offload_policy: OffloadPolicy,
    ):
        self.modules = modules  # permit ref cycle because 1:1 lifetime
        param_module_infos = _get_param_module_infos(params, modules)
        self.fsdp_params = [
            FSDPParam(
                param,
                module_info,
                mesh_info,
                post_forward_mesh_info,
                device,
                mp_policy,
                offload_policy,
            )
            for param, module_info in zip(params, param_module_infos)
        ]
        self.mesh_info = mesh_info
        self.post_forward_mesh_info = post_forward_mesh_info
        self.device = device
        self.mp_policy = mp_policy
        self._training_state = TrainingState.IDLE
        # Group's sharded state always matches its parameters' sharded states
        self._sharded_state = ShardedState.SHARDED
        self._module_fqn: Optional[str] = None  # prefixed from root module
        # Only consider resetting sharded parameters once in lazy init since it
        # can incur nontrivial overhead to reset them
        self._reset_sharded_params: bool = False

        # - Hook state
        self._module_to_pre_save_state_dict_hook_handle: _ModuleToHandleDict = {}
        self._module_to_pre_load_state_dict_hook_handle: _ModuleToHandleDict = {}

        # - Communication and communication/computation overlap
        self.comm_ctx = FSDPCommContext()
        # Group's indices in the shared post-forward order
        self._post_forward_indices: List[int] = []
        # Whether to reduce gradients at all (whether for FSDP or HSDP)
        self.reduce_grads: bool = True
        # Whether to all-reduce gradients for HSDP; only used if
        # `self.reduce_grads` is true, in which case setting this to false
        # means reduce-scatter but no all-reduce
        self.all_reduce_grads: bool = True
        # Whether to reshard parameters after backward (only useful for
        # gradient accumulation)
        self.reshard_after_backward: bool = True
        # Optional custom reduce-scatter reduce op (e.g. to divide by a
        # factor other than the shard world size)
        self.reduce_scatter_reduce_op: Optional[dist.ReduceOp] = None

        # - CUDA events for stream synchronization
        # Holds the all-gather output buffer, sync objects, and metadata
        self._all_gather_result: Optional[AllGatherResult] = None
        # Holds the reduce-scatter/all-reduce view-out CUDA event that marks the end of
        # the group's post-backward (e.g. reduce-scatter, all-reduce and div), which
        # should be waited on at the end of backward
        self._post_reduce_event: Optional[torch.cuda.Event] = None
        # Holds the reshard-after-forward CUDA event when resharding to a
        # different world size, which should be waited on in the next unshard
        self._reshard_after_forward_event: Optional[torch.cuda.Event] = None

        # Only for HSDP, if accumulating gradients without all-reduce, save the
        # partial reduce output (only reduce-scattered but not all-reduced)
        self._partial_reduce_output: Optional[torch.Tensor] = None

        # TODO: remove this hook and hook registration once 2D state dict is supported.
        def _raise_not_implemented_if_2d(*args: Any, **kwargs: Any) -> None:
            raise NotImplementedError(
                "2D state_dict is under development. Please check "
                "https://github.com/pytorch/pytorch/issues/129627 for more details."
            )

        modules_with_2d_params: Set[nn.Module] = set()
        for fsdp_param in self.fsdp_params:
            module = fsdp_param._module_info.module
            if len(fsdp_param._spmd_placements) > 1:
                modules_with_2d_params.add(module)
        for module in modules_with_2d_params:
            module.register_state_dict_pre_hook(_raise_not_implemented_if_2d)
            module._register_load_state_dict_pre_hook(_raise_not_implemented_if_2d)

    # Initialization #
    def _init_mp_dtypes(self) -> None:
        for fsdp_param in self.fsdp_params:
            fsdp_param.init_dtype_attrs(self.mp_policy)
        orig_dtypes = {fsdp_param.orig_dtype for fsdp_param in self.fsdp_params}
        if len(orig_dtypes) != 1:
            # This can be relaxed if we copy-out for the reduce-scatter
            raise AssertionError(
                f"FSDP expects uniform original parameter dtype but got {orig_dtypes}"
            )
        self._orig_dtype = next(iter(orig_dtypes))
        reduce_dtypes = {fsdp_param.reduce_dtype for fsdp_param in self.fsdp_params}
        if len(reduce_dtypes) != 1:
            # This can be relaxed if we issue one reduce-scatter per reduce
            # dtype (but we would need a way for users to specify multiple
            # reduce dtypes)
            raise AssertionError(
                f"FSDP expects uniform reduce dtype but got {reduce_dtypes}"
            )
        self._reduce_dtype = next(iter(reduce_dtypes))

    def lazy_init(self):
        # Lazy init should be idempotent
        # Users may change or register parameters after construction time.
        # For example, DoRA (https://arxiv.org/abs/2402.09353) initializes linear magnitudes based on
        # other parameters (e.g. loaded from the state dict).
        if self.is_sharded and not self._reset_sharded_params:
            for fsdp_param in self.fsdp_params:
                fsdp_param.reset_sharded_param()
            self._reset_sharded_params = True
        param_names_on_meta = [
            fsdp_param._param_fqn
            for fsdp_param in self.fsdp_params
            if fsdp_param.sharded_param.device.type == "meta"
        ]
        if param_names_on_meta:
            raise RuntimeError(
                "FSDP parameters should be materialized from meta device before training, "
                f"but the following were still on meta device: {param_names_on_meta}\n"
                "For example, call module.to_empty(device) to materialize to device and "
                "call module.reset_parameters() on each module to initialize values."
            )
        # Initialize mixed precision attributes lazily in case the user changes
        # the parameter dtypes after construction time but before forward
        self._init_mp_dtypes()
        self._register_state_dict_hooks()

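    # [Illustration only, not used at runtime] A minimal sketch of the
    # materialization pattern that the meta-device error message above points
    # to: give the parameters real storage with `to_empty` and then
    # re-initialize their values with `reset_parameters`.
    @staticmethod
    def _materialize_meta_module_sketch() -> None:
        with torch.device("meta"):
            linear = nn.Linear(16, 16)  # parameters allocated on the meta device
        assert linear.weight.device.type == "meta"
        linear.to_empty(device="cpu")  # allocate real (uninitialized) storage
        linear.reset_parameters()  # initialize parameter values
        assert linear.weight.device.type == "cpu"
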
    # Runtime #
    def unshard(self, async_op: bool = False):
        if self._all_gather_result is not None:  # already called, pending wait
            return
        if self.is_unsharded:
            return  # no-op
        if self._reshard_after_forward_event is not None:
            # Resharded parameter data is allocated in the default stream and
            # used in the all-gather streams
            self._wait_all_gather_streams_on_event(self._reshard_after_forward_event)
            self._reshard_after_forward_event = None
        with record_function(self._with_fqn("FSDP::all_gather")):
            self._all_gather_result = foreach_all_gather(
                self.fsdp_params,
                self._all_gather_process_group,
                async_op,
                *self.comm_ctx.get_all_gather_streams(self._training_state),
                self.device,
            )

    def wait_for_unshard(self):
        """
        1. In forward with implicit prefetching, to overlap the current copy-out
        with the next all-gather, we save a reference to the current all-gather
        result to free after the next copy-out.
        2. Otherwise (explicit prefetching or in backward), we free the
        all-gather result immediately after the current copy-out since we can
        already overlap the current copy-out with the previous reduce-scatter.
        """
        if not self._all_gather_result:
            return  # no preceding unshard
        if self._training_state == TrainingState.FORWARD:  # implicit prefetch
            if prev_all_gather_state := self.comm_ctx.all_gather_state:
                self._wait_all_gather_streams_on_event(prev_all_gather_state.event)
                self.comm_ctx.all_gather_state = None  # free the all-gather result
        with record_function(self._with_fqn("FSDP::all_gather_copy_out")):
            foreach_all_gather_copy_out(
                self._all_gather_result,
                self.fsdp_params,
                self._all_gather_process_group,
            )
        for fsdp_param in self.fsdp_params:
            fsdp_param.init_unsharded_param()
        self._to_unsharded()
        all_gather_copy_out_event = torch.cuda.Event()
        all_gather_copy_out_event.record()
        if self._training_state == TrainingState.FORWARD:
            self.comm_ctx.all_gather_state = AllGatherState(
                self._all_gather_result, all_gather_copy_out_event
            )
        else:
            self._wait_all_gather_streams_on_event(all_gather_copy_out_event)
        self._all_gather_result = None  # free unless saved in `all_gather_state`

    def _wait_all_gather_streams_on_event(self, event: torch.cuda.Event):
        # Calling `unshard` before lazy init means streams are not initialized
        if hasattr(self.comm_ctx, "all_gather_copy_in_stream"):
            self.comm_ctx.all_gather_copy_in_stream.wait_event(event)
        if hasattr(self.comm_ctx, "all_gather_stream"):
            self.comm_ctx.all_gather_stream.wait_event(event)

    def reshard(self):
        if self._training_state == TrainingState.FORWARD:
            if not self._reshard_after_forward:
                return
            if self._use_post_forward_mesh:
                self._to_sharded_post_forward()
                self._reshard_after_forward_event = torch.cuda.Event()
                self._reshard_after_forward_event.record()
                return
        self._to_sharded()

    def pre_forward(
        self, module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
        if not ca.compiled_autograd_enabled:
            logger.debug("%s", self._with_fqn("FSDP::pre_forward"))
        with record_function(self._with_fqn("FSDP::pre_forward")):
            self._training_state = TrainingState.FORWARD
            self.unshard()
            self.wait_for_unshard()
            args, kwargs = self._register_post_backward_hook(args, kwargs)
            return args, kwargs

    def post_forward(self, module: nn.Module, input: Any, output: Any):
        if not ca.compiled_autograd_enabled:
            logger.debug("%s", self._with_fqn("FSDP::post_forward"))
        with record_function(self._with_fqn("FSDP::post_forward")):
            self.reshard()
            self._record_post_forward()
            self._training_state = TrainingState.IDLE
            return output

    def _record_post_forward(self) -> None:
        # Since a group has one pre-backward unshard for each forward call
        # before the backward, we record each usage (with multiplicity)
        post_forward_index = len(self.comm_ctx.post_forward_order)
        self.comm_ctx.post_forward_order.append(self)
        self._post_forward_indices.append(post_forward_index)

    def pre_backward(self, default_prefetch: bool, *unused: Any):
        if self._training_state == TrainingState.PRE_BACKWARD:
            return
        if not ca.compiled_autograd_enabled:
            logger.debug("%s", self._with_fqn("FSDP::pre_backward"))
        with record_function(self._with_fqn("FSDP::pre_backward")):
            self._training_state = TrainingState.PRE_BACKWARD
            self.unshard()  # no-op if prefetched
            self.wait_for_unshard()
            if default_prefetch and not ca.compiled_autograd_enabled:
                self._backward_prefetch()

    def post_backward(self, *unused: Any):
        if not ca.compiled_autograd_enabled:
            logger.debug("%s", self._with_fqn("FSDP::post_backward"))
        self._training_state = TrainingState.POST_BACKWARD
        with record_function(self._with_fqn("FSDP::post_backward_accumulate")):
            for fsdp_param in self.fsdp_params:
                fsdp_param.accumulate_unsharded_grad_if_needed()
        with record_function(self._with_fqn("FSDP::post_backward_reshard")):
            if not self.reduce_grads:
                if self.reshard_after_backward:
                    self.reshard()
                for fsdp_param in self.fsdp_params:
                    fsdp_param.to_accumulated_grad_if_needed()
                return
            # Save the autograd-computed gradients before resharding to only
            # access the unsharded parameters when their data is present
            fsdp_params_with_grad: List[FSDPParam] = []
            unsharded_grads: List[torch.Tensor] = []
            for fsdp_param in self.fsdp_params:
                # May have an accumulated gradient of the reduce dtype if the
                # previous backward did not reduce-scatter
                if fsdp_param.unsharded_accumulated_grad is not None:
                    fsdp_params_with_grad.append(fsdp_param)
                    unsharded_grads.append(fsdp_param.unsharded_accumulated_grad_data)
                    fsdp_param.unsharded_accumulated_grad = None
                elif fsdp_param.unsharded_param.grad is not None:
                    fsdp_params_with_grad.append(fsdp_param)
                    unsharded_grads.append(fsdp_param.unsharded_grad_data)
                    fsdp_param.unsharded_param.grad = None
            if self.reshard_after_backward:
                self.reshard()
        if len(fsdp_params_with_grad) == 0:
            return
        with record_function(self._with_fqn("FSDP::post_backward_reduce")):
            if self.comm_ctx.reduce_scatter_state is not None:
                torch.cuda.current_stream().wait_event(
                    self.comm_ctx.reduce_scatter_state.event
                )
                self.comm_ctx.reduce_scatter_state = None
            (
                reduce_scatter_input,
                reduce_scatter_event,
                self._post_reduce_event,
                self._partial_reduce_output,
            ) = foreach_reduce(
                fsdp_params_with_grad,
                unsharded_grads,
                self._reduce_scatter_process_group,
                self.comm_ctx.reduce_scatter_stream,
                self._orig_dtype,
                self._reduce_dtype,
                self.device,
                self.reduce_scatter_reduce_op,
                self._all_reduce_process_group if self._is_hsdp else None,
                self.comm_ctx.all_reduce_stream,
                self.all_reduce_grads,
                self._partial_reduce_output,
            )
            self.comm_ctx.reduce_scatter_state = ReduceScatterState(
                reduce_scatter_input, reduce_scatter_event
            )

    def finalize_backward(self):
        if self._post_reduce_event is not None:
            torch.cuda.current_stream().wait_event(self._post_reduce_event)
            self._post_reduce_event = None
        for fsdp_param in self.fsdp_params:
            if fsdp_param.grad_offload_event is not None:
                fsdp_param.grad_offload_event.synchronize()
                fsdp_param.grad_offload_event = None
        self._post_forward_indices.clear()

    def _backward_prefetch(self) -> None:
        if self._training_state == TrainingState.PRE_BACKWARD:
            if not self._post_forward_indices:
                # Can be cleared if running multiple `backward`s
                return
            curr_index = self._post_forward_indices.pop()
            if (target_index := curr_index - 1) < 0:
                return
            # Prefetch naively using the reverse post-forward order, which may
            # have mistargeted prefetches if not all modules used in forward
            # are used in this backward
            target_fsdp_param_group = self.comm_ctx.post_forward_order[target_index]
            self._prefetch_unshard(target_fsdp_param_group, "backward")

    @staticmethod
    def _prefetch_unshard(
        target_fsdp_param_group: "FSDPParamGroup", pass_type: str
    ) -> None:
        if pass_type == "backward":
            training_state = TrainingState.PRE_BACKWARD
        elif pass_type == "forward":
            training_state = TrainingState.FORWARD
        else:
            raise ValueError(f"Unknown pass type: {pass_type}")
        target_fqn = target_fsdp_param_group._module_fqn
        with record_function(
            f"FSDP::{pass_type}_prefetch for {target_fqn}"
        ), target_fsdp_param_group.use_training_state(training_state):
            target_fsdp_param_group.unshard()

    # Utilities #
    def _to_sharded(self):
        if not self.is_sharded:
            for fsdp_param in self.fsdp_params:
                fsdp_param.to_sharded()
            self._sharded_state = ShardedState.SHARDED

    def _to_sharded_post_forward(self):
        if not self.is_sharded_post_forward:
            for fsdp_param in self.fsdp_params:
                fsdp_param.to_sharded_post_forward()
            self._sharded_state = ShardedState.SHARDED_POST_FORWARD

    def _to_unsharded(self):
        if not self.is_unsharded:
            for fsdp_param in self.fsdp_params:
                fsdp_param.to_unsharded()
            self._sharded_state = ShardedState.UNSHARDED

    @property
    def is_sharded(self) -> bool:
        return self._sharded_state == ShardedState.SHARDED

    @property
    def is_sharded_post_forward(self) -> bool:
        return self._sharded_state == ShardedState.SHARDED_POST_FORWARD

    @property
    def is_unsharded(self) -> bool:
        return self._sharded_state == ShardedState.UNSHARDED

    @contextlib.contextmanager
    def use_training_state(self, training_state: TrainingState):
        old_training_state = self._training_state
        self._training_state = training_state
        try:
            yield
        finally:
            self._training_state = old_training_state

    # Hook Registration #
    def _register_post_backward_hook(
        self, args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
        # Compile relies on `root_post_backward_callback` to call each
        # `FSDPParamGroup.post_backward`
        if ca.compiled_autograd_enabled:
            return args, kwargs
        if not torch.is_grad_enabled():
            return args, kwargs
        args_list, args_spec = tree_flatten(args)
        kwargs_list, kwargs_spec = tree_flatten(kwargs)
        args_kwargs_list = list(args_list) + list(kwargs_list)
        inp_tensor_indices: List[int] = []
        inp_tensors: List[torch.Tensor] = []
        for i, obj in enumerate(args_kwargs_list):
            if torch.is_tensor(obj) and obj.requires_grad:
                inp_tensor_indices.append(i)
                inp_tensors.append(obj)
        if len(inp_tensors) == 0:
            return args, kwargs  # no tensors that require gradients
        inp_tensors = RegisterPostBackwardFunction.apply(self, *inp_tensors)
        for inp_tensor_idx, inp_tensor in zip(inp_tensor_indices, inp_tensors):
            args_kwargs_list[inp_tensor_idx] = inp_tensor
        args_list = args_kwargs_list[: len(args_list)]
        kwargs_list = args_kwargs_list[len(args_list) :]
        args = tree_unflatten(args_list, args_spec)
        kwargs = tree_unflatten(kwargs_list, kwargs_spec)
        return args, kwargs

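    # [Illustration only, not used at runtime] The flatten/replace/unflatten
    # pattern used above, shown on a toy (args, kwargs) pair with a hypothetical
    # `mark(t)` transform standing in for `RegisterPostBackwardFunction.apply`.
    @staticmethod
    def _pytree_swap_sketch() -> None:
        def mark(t: torch.Tensor) -> torch.Tensor:
            return t.clone()  # stand-in for routing through an autograd function

        args = (torch.randn(2, requires_grad=True), 3)
        kwargs = {"x": torch.randn(2, requires_grad=True)}
        args_list, args_spec = tree_flatten(args)
        kwargs_list, kwargs_spec = tree_flatten(kwargs)
        flat = list(args_list) + list(kwargs_list)
        # Replace only the tensor leaves that require grad; keep everything else
        flat = [
            mark(obj) if torch.is_tensor(obj) and obj.requires_grad else obj
            for obj in flat
        ]
        new_args = tree_unflatten(flat[: len(args_list)], args_spec)
        new_kwargs = tree_unflatten(flat[len(args_list) :], kwargs_spec)
        assert len(new_args) == 2 and set(new_kwargs) == {"x"}
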
    def _register_state_dict_hooks(self) -> None:
        num_pre_save_hooks = len(self._module_to_pre_save_state_dict_hook_handle)
        num_pre_load_hooks = len(self._module_to_pre_load_state_dict_hook_handle)
        assert (
            num_pre_save_hooks == num_pre_load_hooks
        ), f"Pre-save: {num_pre_save_hooks} pre-load: {num_pre_load_hooks}"
        if num_pre_save_hooks > 0:
            return  # already registered
        modules_with_fsdp_params: Set[nn.Module] = {
            fsdp_param._module_info.module for fsdp_param in self.fsdp_params
        }

        def to_sharded_hook(*args: Any, **kwargs: Any) -> None:
            self._to_sharded()

        for module in modules_with_fsdp_params:
            self._module_to_pre_save_state_dict_hook_handle[
                module
            ] = module.register_state_dict_pre_hook(to_sharded_hook)
            self._module_to_pre_load_state_dict_hook_handle[
                module
            ] = module._register_load_state_dict_pre_hook(to_sharded_hook)

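    # [Illustration only, not used at runtime] How the pre-save hook above
    # behaves: a hook registered with `register_state_dict_pre_hook` runs
    # before the module's state dict is gathered, which is what lets the group
    # reshard parameters just in time for `state_dict()`.
    @staticmethod
    def _state_dict_pre_hook_sketch() -> None:
        calls: List[int] = []
        linear = nn.Linear(2, 2)
        handle = linear.register_state_dict_pre_hook(
            lambda *args, **kwargs: calls.append(1)
        )
        linear.state_dict()
        assert calls  # the hook ran before the state dict was produced
        handle.remove()
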
    # Properties #
    @property
    def _reshard_after_forward(self) -> bool:
        return self.post_forward_mesh_info is not None

    @property
    def _use_post_forward_mesh(self) -> bool:
        return (
            self._reshard_after_forward
            and self.mesh_info != self.post_forward_mesh_info
        )

    @property
    def _is_hsdp(self) -> bool:
        return isinstance(self.mesh_info, HSDPMeshInfo)

    @property
    def _all_gather_process_group(self) -> dist.ProcessGroup:
        mesh_info = (
            cast(FSDPMeshInfo, self.post_forward_mesh_info)
            if self.is_sharded_post_forward
            else self.mesh_info
        )
        assert isinstance(mesh_info, FSDPMeshInfo)
        return mesh_info.shard_process_group

    @property
    def _reduce_scatter_process_group(self) -> dist.ProcessGroup:
        assert isinstance(self.mesh_info, FSDPMeshInfo)
        return self.mesh_info.shard_process_group

    @property
    def _all_reduce_process_group(self) -> dist.ProcessGroup:
        assert isinstance(self.mesh_info, HSDPMeshInfo)
        return self.mesh_info.replicate_process_group

    def _with_fqn(self, label: str) -> str:
        if self._module_fqn:
            return f"{label} ({self._module_fqn})"
        return label

    def __repr__(self):
        return f"FSDPParamGroup(fqn={self._module_fqn})"


def _get_param_module_infos(
    params: List[nn.Parameter], modules: Tuple[nn.Module, ...]
) -> List[ParamModuleInfo]:
    """
    Shared parameter: lin1.weight = lin2.weight
    Shared module: mlp.lin1 = mlp.lin2
    We do not remove duplicates when traversing both modules and parameters to
    find shared modules' parameters and shared parameters within a module.
    """
    params_set = set(params)
    param_to_module_info: Dict[nn.Parameter, ParamModuleInfo] = {}
    for module in modules:
        for _, submodule in module.named_modules(remove_duplicate=False):
            for param_name, param in _named_parameters_with_duplicates(
                submodule, recurse=False
            ):
                if param in params_set:
                    if param not in param_to_module_info:
                        param_to_module_info[param] = ParamModuleInfo(
                            submodule, param_name
                        )
                    else:
                        param_to_module_info[param].shared_modules.append(submodule)
                        param_to_module_info[param].shared_param_names.append(
                            param_name
                        )
    if len(param_to_module_info) != len(params):
        raise AssertionError(f"Some parameters are not in the module tree of {module}")
    return [param_to_module_info[param] for param in params]

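
# [Illustration only, not used by FSDP] Why the duplicate-preserving traversal
# above matters: with a shared parameter (`lin2.weight = lin1.weight`), iterating
# submodules with `remove_duplicate=False` and then each submodule's own
# parameters surfaces every (module, name) pair that refers to the same object.
def _shared_param_traversal_sketch() -> None:
    mlp = nn.Sequential(nn.Linear(4, 4, bias=False), nn.Linear(4, 4, bias=False))
    mlp[1].weight = mlp[0].weight  # share the parameter across both modules
    owners = [
        (module_name, param_name)
        for module_name, submodule in mlp.named_modules(remove_duplicate=False)
        for param_name, _param in submodule.named_parameters(recurse=False)
    ]
    assert owners == [("0", "weight"), ("1", "weight")]
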

class RegisterPostBackwardFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, param_group: FSDPParamGroup, *inputs: torch.Tensor):
        # All tensors in `inputs` should require gradient
        ctx.param_group = param_group
        return inputs

    @staticmethod
    def backward(ctx, *grads: torch.Tensor):
        ctx.param_group.post_backward()
        return (None,) + grads

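
# [Illustration only, not used by FSDP] A self-contained sketch of the identity
# autograd-function trick above: routing an input through the function makes its
# `backward` run once gradients have flowed back past that input, which is how
# the group gets a post-backward callback without modifying the model itself.
def _post_backward_callback_sketch() -> None:
    fired: List[bool] = []

    class _Callback(torch.autograd.Function):
        @staticmethod
        def forward(ctx, *inputs: torch.Tensor):
            return inputs

        @staticmethod
        def backward(ctx, *grads: torch.Tensor):
            fired.append(True)  # stand-in for `FSDPParamGroup.post_backward`
            return grads

    x = torch.randn(4, requires_grad=True)
    (y,) = _Callback.apply(x)
    y.sum().backward()
    assert fired and x.grad is not None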