1""" 2NOTE: This file must be imported like 3``import torch.distributed.fsdp._traversal_utils`` and not like 4``from torch.distirbuted.fsdp._traversal_utils import ...`` to avoid circular 5imports. For brevity, we may import the file as ``traversal_utils``. 6""" 7 8import collections 9from typing import Deque, List, Set, Tuple 10 11import torch.nn as nn 12from torch.distributed._composable.contract import _get_registry 13from torch.distributed.fsdp._common_utils import _FSDPState, _get_module_fsdp_state 14 15 16""" 17[Note: FSDP State Traversal] 18For the wrapper code path, ``_FSDPState`` is the ``FullyShardedDataParallel`` 19module wrapping a fully sharded module, and for the non-wrapper code path, 20``_FSDPState`` is an object that gets embedded on a fully sharded module. 21See [Note: Fully Sharded Module] for the definition. 22 23There are three common traversal idioms: Given a root module, 24- ``_get_fsdp_states()`` returns all ``_FSDPState`` s in the tree. 25- ``get_fsdp_root_states()`` returns all local root ``_FSDPState`` s in the 26tree (i.e. those with ``_is_root == True``). 27- ``_get_fsdp_handles()``returns all ``FlatParamHandle`` s in the tree. 28 29All of these methods must take in the root module (i.e. an ``nn.Module``) and 30not a general ``_FSDPState`` because ``_FSDPState`` does not support a graph 31traversal, whereas ``nn.Module`` has ``nn.Module.modules()`` for traversal. 32""" 33 34 35def _composable(module: nn.Module) -> bool: 36 """ 37 Returns if ``module`` can compose with ``fully_shard``. 38 """ 39 # TODO: Add any other composable APIs that are mutually exclusive. 40 registry = _get_registry(module) 41 if registry is None: 42 return True 43 return "replicate" not in registry 44 45 46# TODO (awgu): We may be able to remove this function if we retired the 47# `use_orig_params=False` code path since so far we only need the module for 48# `FlatParameter` registration, which is not needed for `use_orig_params=True`. 49def _get_fsdp_states_with_modules( 50 module: nn.Module, 51) -> Tuple[List[_FSDPState], List[nn.Module]]: 52 """ 53 Returns a tuple containing: 54 1. A list of the ``_FSDPState`` instances in the module tree rooted at 55 ``module`` without any duplicates and following the ``module.modules()`` 56 traversal order (which is assumed to be depth-first). 57 2. A corresponding list of the modules owning the states in the first list. 58 59 For the wrapper code path, both returned lists are the same, each 60 containing all ``FullyShardedDataParallel`` instances. For the composable 61 code path, this returns a list of all composable state instances and a list 62 of the corresponding fully sharded modules. See [Note: Fully Sharded 63 Module]. 64 65 NOTE: The traversal does not proceed into any module annotated by an 66 incompatible API (e.g. ``replicate``). 67 """ 68 fsdp_states: List[_FSDPState] = [] 69 fsdp_modules: List[nn.Module] = [] 70 # Track the visited FSDP states since multiple modules may share the same 71 # one and we want to return a de-duplicated list 72 visited_fsdp_states: Set[_FSDPState] = set() 73 # Track the visited modules in case of shared modules, which implies the 74 # module graph is no longer a tree 75 visited_modules: Set[nn.Module] = set() 76 77 # Perform depth-first search from `module` to ensure that we do not 78 # traverse into an incompatible API's subtree (use DFS instead of BFS to 79 # match `.modules()` order) 80 deque: Deque[nn.Module] = collections.deque([module]) 81 while deque: 82 submodule = deque.popleft() 83 visited_modules.add(submodule) 84 if not _composable(submodule): 85 continue 86 for child_module in reversed(list(submodule.children())): 87 if child_module not in visited_modules: 88 deque.appendleft(child_module) 89 optional_state = _get_module_fsdp_state(submodule) 90 if optional_state is not None and optional_state not in visited_fsdp_states: 91 visited_fsdp_states.add(optional_state) 92 fsdp_states.append(optional_state) 93 fsdp_modules.append(submodule) 94 return fsdp_states, fsdp_modules 95 96 97def _get_fsdp_states(module: nn.Module) -> List[_FSDPState]: 98 """See :func:`_get_fsdp_states_with_modules`.""" 99 fsdp_states, _ = _get_fsdp_states_with_modules(module) 100 return fsdp_states 101 102 103def _get_fsdp_handles(module: nn.Module) -> List: 104 """ 105 Returns all ``FlatParamHandle`` s in the module tree rooted at ``module`` 106 following the rules in :func:`_get_fsdp_state`. 107 """ 108 handles = [ 109 fsdp_state._handle 110 for fsdp_state in _get_fsdp_states(module) 111 if fsdp_state._handle is not None 112 ] 113 return handles 114