xref: /aosp_15_r20/external/pytorch/torch/distributed/fsdp/_traversal_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1"""
2NOTE: This file must be imported like
3``import torch.distributed.fsdp._traversal_utils`` and not like
``from torch.distributed.fsdp._traversal_utils import ...`` to avoid circular
5imports. For brevity, we may import the file as ``traversal_utils``.
6"""
7
8import collections
9from typing import Deque, List, Set, Tuple
10
11import torch.nn as nn
12from torch.distributed._composable.contract import _get_registry
13from torch.distributed.fsdp._common_utils import _FSDPState, _get_module_fsdp_state
14
15
16"""
17[Note: FSDP State Traversal]
18For the wrapper code path, ``_FSDPState`` is the ``FullyShardedDataParallel``
19module wrapping a fully sharded module, and for the non-wrapper code path,
20``_FSDPState`` is an object that gets embedded on a fully sharded module.
21See [Note: Fully Sharded Module] for the definition.
22
23There are three common traversal idioms: Given a root module,
24- ``_get_fsdp_states()`` returns all ``_FSDPState`` s in the tree.
25- ``get_fsdp_root_states()`` returns all local root ``_FSDPState`` s in the
26tree (i.e. those with ``_is_root == True``).
- ``_get_fsdp_handles()`` returns all ``FlatParamHandle`` s in the tree.
28
29All of these methods must take in the root module (i.e. an ``nn.Module``) and
30not a general ``_FSDPState`` because ``_FSDPState`` does not support a graph
31traversal, whereas ``nn.Module`` has ``nn.Module.modules()`` for traversal.
32"""
33
34
35def _composable(module: nn.Module) -> bool:
36    """
37    Returns if ``module`` can compose with ``fully_shard``.
38    """
39    # TODO: Add any other composable APIs that are mutually exclusive.
40    registry = _get_registry(module)
41    if registry is None:
42        return True
43    return "replicate" not in registry
44
45
46# TODO (awgu): We may be able to remove this function if we retired the
47# `use_orig_params=False` code path since so far we only need the module for
48# `FlatParameter` registration, which is not needed for `use_orig_params=True`.
def _get_fsdp_states_with_modules(
    module: nn.Module,
) -> Tuple[List[_FSDPState], List[nn.Module]]:
    """
    Collects the ``_FSDPState`` instances found in the module graph rooted at
    ``module``, paired with the modules that own them.

    The graph is walked depth-first, expanding children in ``children()``
    order so that results follow the ``module.modules()`` traversal order
    (which is assumed to be depth-first). Each state appears at most once,
    attributed to the first module on which it was encountered.

    Returns:
        A tuple ``(states, owners)`` of two equally long lists:
        1. ``states``: the de-duplicated ``_FSDPState`` instances.
        2. ``owners``: the module owning the state at the same index.

    For the wrapper code path, both returned lists are the same, each
    containing all ``FullyShardedDataParallel`` instances. For the composable
    code path, this returns a list of all composable state instances and a list
    of the corresponding fully sharded modules. See [Note: Fully Sharded
    Module].

    NOTE: The traversal does not descend into any module annotated by an
    incompatible API (e.g. ``replicate``), and such a module contributes no
    state itself.
    """
    states: List[_FSDPState] = []
    owners: List[nn.Module] = []
    # Several modules may share one state; remember which states were already
    # reported so that the output stays free of duplicates
    seen_states: Set[_FSDPState] = set()
    # Shared modules make the module graph a non-tree; remember expanded
    # modules so that they are not scheduled again
    seen_modules: Set[nn.Module] = set()

    # Explicit LIFO stack: the most recently pushed module is expanded next,
    # which yields a depth-first walk (DFS rather than BFS to match the
    # `.modules()` order) that can stop at an incompatible API's module
    pending: List[nn.Module] = [module]
    while pending:
        current = pending.pop()
        seen_modules.add(current)
        if _composable(current):
            state = _get_module_fsdp_state(current)
            if state is not None and state not in seen_states:
                seen_states.add(state)
                states.append(state)
                owners.append(current)
            # Push children back-to-front so that the first child ends up on
            # top of the stack and is expanded first
            pending.extend(
                child
                for child in reversed(list(current.children()))
                if child not in seen_modules
            )
    return states, owners
95
96
def _get_fsdp_states(module: nn.Module) -> List[_FSDPState]:
    """
    Returns only the ``_FSDPState`` list computed by
    :func:`_get_fsdp_states_with_modules`, dropping the owning modules.
    """
    return _get_fsdp_states_with_modules(module)[0]
101
102
def _get_fsdp_handles(module: nn.Module) -> List:
    """
    Returns the ``_handle`` of every ``_FSDPState`` in the module tree rooted
    at ``module``, skipping states whose ``_handle`` is ``None``.

    The states are gathered with :func:`_get_fsdp_states`, so the handles
    follow the same ordering and traversal rules.
    """
    return [
        state._handle
        for state in _get_fsdp_states(module)
        if state._handle is not None
    ]
114