1# mypy: allow-untyped-defs 2import uuid 3from collections import OrderedDict 4from functools import wraps 5from typing import Callable, Dict, List, Optional, Sequence, Type, Union 6 7import torch 8import torch.nn as nn 9from torch.distributed._composable_state import _State 10from torch.distributed.utils import _get_root_modules 11 12 13def generate_state_key(string="__composable_api_state_key"): 14 return f"{string}_{str(uuid.uuid4())}" 15 16 17STATE_KEY = generate_state_key() 18REGISTRY_KEY = generate_state_key() 19 20 21# TODO: we can add additional info to RegistryItem to share across APIs. E.g., 22# we can add args and kwargs here, and then we can detect whether fully_shard 23# is combined with reentrant activation checkpointing and error out with a clear 24# message. 25class RegistryItem: 26 pass 27 28 29def contract(state_cls: Type[_State] = _State): 30 r""" 31 Decorate a function as a composable distributed API, where the first 32 argument of the function must be an :class:`nn.Module` instance or sequence 33 of :class:`nn.Module` instances. 34 35 The decorator verifies that the decorated function does not modify 36 fully-qualified names (FQNs) for parameters, buffers, or modules. The 37 decorated function can return different module instances than the input 38 modules; the FQN invariant will be enforced following the input order. 39 40 When a function ``func`` is decorated by ``@contract()``, a 41 ``.state(module: nn.Module)`` method will be installed to the decorated 42 function. Then you can retrieve and modify the state on a module by calling 43 ``func.state(module)``. 44 45 Example:: 46 >>> # xdoctest: +SKIP 47 >>> import torch.nn as nn 48 >>> 49 >>> class MyModel(nn.Module): 50 >>> def __init__(self) -> None: 51 >>> super().__init__() 52 >>> self.l1 = nn.Linear(10, 10) 53 >>> self.l2 = nn.Linear(10, 10) 54 >>> 55 >>> def forward(self, x): 56 >>> return self.l2(self.l1(x)) 57 >>> 58 >>> @contract() 59 >>> def my_feature(module: nn.Module) -> nn.Module: 60 >>> my_feature.state(module).some_state = "any value" 61 >>> return module 62 >>> 63 >>> model = MyModel() 64 >>> my_feature(model.l1) 65 >>> assert my_feature.state(model.l1).some_state == "any value" 66 >>> my_feature(model.l2) 67 >>> model(torch.randn(2, 10)).sum().backward() 68 """ 69 70 # wraps will make functions decorated with contract() pickleable - needed for integration with torch.package 71 @wraps(state_cls) 72 def inner(func): 73 @wraps(func) 74 def wrapper( 75 module: Union[nn.Module, Sequence[nn.Module]], *args, **kwargs 76 ) -> Optional[nn.Module]: 77 inp_module = module 78 if isinstance(module, nn.Module): 79 modules = [module] 80 else: 81 # If the user passes a sequence of modules, then we assume that 82 # we only need to insert the state object on the root modules 83 # (i.e. those without a parent) among the passed-in modules. 84 modules = _get_root_modules(list(module)) 85 state = state_cls() # shared across all modules 86 registry_item = RegistryItem() # shared across all modules 87 88 # `func` is allowed to return different module instances than the 89 # input modules as long as FQNs are preserved following the input 90 # module order 91 all_orig_named_params: List[Dict[str, nn.Parameter]] = [] 92 all_orig_named_buffers: List[Dict[str, torch.Tensor]] = [] 93 all_orig_named_modules: List[Dict[str, nn.Module]] = [] 94 95 for module in modules: 96 default_all_state: Dict[Callable, _State] = OrderedDict() 97 default_registry: Dict[str, RegistryItem] = OrderedDict() 98 all_state: Dict[Callable, _State] = module.__dict__.setdefault( # type: ignore[call-overload] 99 STATE_KEY, default_all_state 100 ) 101 if not isinstance(all_state, dict): 102 raise AssertionError( 103 f"Distributed composable API states corrupted: {all_state}" 104 ) 105 registry: Dict[str, RegistryItem] = module.__dict__.setdefault( # type: ignore[call-overload] 106 REGISTRY_KEY, default_registry 107 ) 108 if not isinstance(registry, dict): 109 raise AssertionError( 110 f"Distributed composable API registry corrupted: {registry}" 111 ) 112 if func in all_state or func.__name__ in registry: 113 raise AssertionError( 114 "Each distinct composable distributed API can only be applied to a " 115 f"module once. {func.__name__} has already been applied to the " 116 f"following module:\n{module}" 117 ) 118 all_state.setdefault(func, state) 119 registry.setdefault(func.__name__, registry_item) 120 121 all_orig_named_params.append(OrderedDict(module.named_parameters())) 122 all_orig_named_buffers.append(OrderedDict(module.named_buffers())) 123 all_orig_named_modules.append(OrderedDict(module.named_modules())) 124 125 updated = func(inp_module, *args, **kwargs) 126 if updated is None: 127 updated = inp_module 128 if isinstance(updated, nn.Module): 129 updated_modules = [updated] 130 else: 131 updated_modules = _get_root_modules(list(inp_module)) 132 133 all_new_named_params: List[Dict[str, nn.Parameter]] = [] 134 all_new_named_buffers: List[Dict[str, torch.Tensor]] = [] 135 all_new_named_modules: List[Dict[str, nn.Module]] = [] 136 for module in updated_modules: 137 all_new_named_params.append(OrderedDict(module.named_parameters())) 138 all_new_named_buffers.append(OrderedDict(module.named_buffers())) 139 all_new_named_modules.append(OrderedDict(module.named_modules())) 140 141 num_orig_modules = len(all_orig_named_modules) 142 num_new_modules = len(all_new_named_modules) 143 if num_orig_modules != num_new_modules: 144 raise AssertionError( 145 f"{func.__name__} should return the same number of modules as input modules" 146 f"Inputs: {num_orig_modules} modules\n" 147 f"Outputs: {num_new_modules} modules" 148 ) 149 150 def check_fqn(orig_fqns: List[str], new_fqns: List[str], check_key: str): 151 if orig_fqns == new_fqns: 152 return 153 154 orig_fqn_set, new_fqn_set = set(orig_fqns), set(new_fqns) 155 orig_only = orig_fqn_set - new_fqn_set 156 new_only = new_fqn_set - orig_fqn_set 157 if len(orig_only) or len(new_only): 158 raise RuntimeError( 159 f"{check_key}" 160 "Composable distributed API implementations cannot modify FQNs.\n" 161 f"FQNs only in original: {orig_only}\n" 162 f"FQNs only in new: {new_only}" 163 ) 164 else: 165 raise RuntimeError( 166 f"{check_key}" 167 "Composable distributed API implementations cannot modify " 168 "the order of FQNs.\n" 169 f"Original FQNs: {orig_only}\n" 170 f"New FQNs: {new_only}" 171 ) 172 173 for orig_named_params, new_named_params in zip( 174 all_orig_named_params, all_new_named_params 175 ): 176 check_fqn( 177 list(orig_named_params.keys()), 178 list(new_named_params.keys()), 179 "Checking parameters: ", 180 ) 181 for orig_named_buffers, new_named_buffers in zip( 182 all_orig_named_buffers, all_new_named_buffers 183 ): 184 check_fqn( 185 list(orig_named_buffers.keys()), 186 list(new_named_buffers.keys()), 187 "Checking buffers: ", 188 ) 189 for orig_named_modules, new_named_modules in zip( 190 all_orig_named_modules, all_new_named_modules 191 ): 192 check_fqn( 193 list(orig_named_modules.keys()), 194 list(new_named_modules.keys()), 195 "Checking modules: ", 196 ) 197 198 # TODO: verify that installed distributed paradigms are compatible with 199 # each other. 200 201 return updated 202 203 def get_state(module: nn.Module) -> Optional[_State]: 204 return module.__dict__.setdefault( # type: ignore[call-overload] 205 STATE_KEY, 206 {}, # TODO(@yhcharles): this is a temporary fix, need a better way 207 ).get( 208 func 209 ) # type: ignore[call-overload] 210 211 wrapper.state = get_state # type: ignore[attr-defined] 212 213 return wrapper 214 215 return inner 216 217 218def _get_registry(module: nn.Module) -> Optional[Dict[str, RegistryItem]]: 219 r""" 220 Get an ``OrderedDict`` of composable APIs that have been applied to the 221 ``module``, indexed by the API name. If no API has been applied, then this 222 returns ``None``. 223 """ 224 return getattr(module, REGISTRY_KEY, None) 225