# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import functools
from typing import Any, cast, Dict, Iterable, List, NoReturn, Optional, Type, Union

import torch
import torch.nn as nn
from torch.distributed._composable import contract
from torch.distributed.tensor import DeviceMesh
from torch.distributed.utils import _get_root_modules

from ._fsdp_api import MixedPrecisionPolicy, OffloadPolicy
from ._fsdp_common import FSDPMeshInfo, HSDPMeshInfo
from ._fsdp_init import (
    _get_device_from_mesh,
    _get_managed_modules,
    _get_managed_states,
    _get_post_forward_mesh_info,
    _init_default_fully_shard_mesh,
    _move_states_to_device,
)
from ._fsdp_param_group import FSDPParamGroup
from ._fsdp_state import _get_module_fsdp_state, FSDPState


cls_to_fsdp_cls: Dict[Type, Type] = {}


# The decorator adds a state object to `module` that can be accessed via
# `fully_shard.state(module)`. The state object and module are 1:1.
@contract(state_cls=FSDPState)  # type: ignore[operator]
def fully_shard(
    module: Union[nn.Module, List[nn.Module]],
    *,
    mesh: Optional[DeviceMesh] = None,
    reshard_after_forward: Union[bool, int] = True,
    mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(),
    offload_policy: OffloadPolicy = OffloadPolicy(),
):
    """
    Shard module parameters across data parallel workers.

    This function applies fully sharded data parallelism (FSDP) or a variant to
    ``module``, a technique for memory savings at the cost of communication.
    Parameters are sharded across ``mesh``, and in turn, so are their gradients
    and optimizer states.

    The sharded parameters are all-gathered to construct the unsharded
    parameters for forward or backward computation. The unsharded parameters
    are freed after computation to save memory. The gradients are reduced
    across the mesh and divided by the mesh size for data parallelism. The
    optimizer step runs on the sharded parameters.

    Each call to ``fully_shard`` constructs one communication group that
    includes the parameters in ``module.parameters()`` except those already
    assigned to a group from a nested call. Each group's parameters and its
    gradients are communicated together in one collective, respectively.
    Constructing multiple groups across the model (e.g. "layer by layer")
    allows for peak memory savings and communication/computation overlap.

    Implementation-wise, the sharded parameters are represented as
    :class:`DTensor` s, sharded on dim-0, and the unsharded parameters are
    represented as :class:`Tensor` s. A module forward pre-hook all-gathers the
    parameters, and a module forward hook frees them. Similar backward hooks
    gather parameters and later free parameters/reduce gradients.

    Args:
        module (Union[nn.Module, List[nn.Module]]): The module or modules to
            shard with FSDP and group together for communication.
        mesh (Optional[DeviceMesh]): This data parallel mesh defines the
            sharding and device. If 1D, then parameters are fully sharded
            across the 1D mesh (FSDP). If 2D, then parameters are sharded
            across the 0th dim and replicated across the 1st dim (HSDP). The
            mesh's device type gives the device type used for communication;
            if a CUDA or CUDA-like device type, then we use the current device.
        reshard_after_forward (Union[bool, int]): This controls the parameter
            behavior after forward and can trade off memory and communication:
            - If ``True``, then this reshards parameters after forward and
              all-gathers in backward.
            - If ``False``, then this keeps the unsharded parameters in memory
              after forward and avoids the all-gather in backward.
            - If an ``int``, then this represents the world size to reshard to
              after forward. It should be a non-trivial divisor of the ``mesh``
              shard dim size (i.e. excluding 1 and the dim size itself). A choice
              may be the intra-node size (e.g. ``torch.cuda.device_count()``).
              This allows the all-gather in backward to be over a smaller world
              size at the cost of higher memory usage than setting to ``True``.
            - The root FSDP state has its value specially set to ``False`` as a
              heuristic since its parameters would typically be immediately
              all-gathered for backward.
            - After forward, the parameters registered to the module depend on
              this: the registered parameters are the sharded parameters if
              ``True``; the unsharded parameters if ``False``; and the
              parameters resharded to the smaller mesh otherwise. To modify the
              parameters between forward and backward, the registered
              parameters must be the sharded parameters. For ``False`` or an
              ``int``, this can be done by manually resharding via
              :meth:`reshard`.
        mp_policy (MixedPrecisionPolicy): This controls the mixed precision
            policy, which offers parameter/reduction mixed precision for this
            module. See :class:`MixedPrecisionPolicy` for details.
        offload_policy (OffloadPolicy): This controls the offloading policy,
            which offers parameter/gradient/optimizer state offloading. See
            :class:`OffloadPolicy` and its subclasses for details.
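
    .. note:: A minimal usage sketch, not an exhaustive recipe. The
        ``TransformerBlock`` module, ``data_loader``, and ``loss_fn`` names are
        placeholders, and an initialized distributed process group is assumed.
        Applying ``fully_shard`` per block before the root call creates one
        communication group per block, with the root call owning the remaining
        parameters::

            >>> # xdoctest: +SKIP("requires distributed init")
            >>> model = nn.Sequential(*[TransformerBlock() for _ in range(8)])
            >>> for block in model:
            ...     fully_shard(block)
            >>> fully_shard(model)
            >>> optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
            >>> for inp, target in data_loader:
            ...     loss_fn(model(inp), target).backward()
            ...     optim.step()
            ...     optim.zero_grad()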
104 """ 105 if isinstance(module, (nn.ModuleList, nn.ModuleDict)): 106 raise ValueError( 107 f"fully_shard does not support containers that do not implement forward: {module}" 108 ) 109 mesh = mesh or _init_default_fully_shard_mesh() 110 if mesh.ndim not in (1, 2): 111 raise ValueError(f"fully_shard expects a 1D or 2D DeviceMesh but got {mesh}") 112 elif mesh.ndim == 1: 113 mesh_info = FSDPMeshInfo(mesh, shard_mesh_dim=0) 114 else: 115 mesh_info = HSDPMeshInfo(mesh, shard_mesh_dim=1, replicate_mesh_dim=0) 116 device = _get_device_from_mesh(mesh) 117 post_forward_mesh_info = _get_post_forward_mesh_info( 118 reshard_after_forward, mesh_info 119 ) 120 121 arg_module = module 122 modules = ( 123 (module,) if isinstance(module, nn.Module) else tuple(_get_root_modules(module)) 124 ) 125 state = fully_shard.state(modules[0]) 126 state.init(modules, device, mp_policy) 127 128 managed_modules = _get_managed_modules(modules) 129 params, buffers = _get_managed_states(managed_modules) 130 _move_states_to_device(params, buffers, device) 131 if params: 132 state._fsdp_param_group = FSDPParamGroup( 133 params, 134 modules, 135 mesh_info, 136 post_forward_mesh_info, 137 device, 138 mp_policy, 139 offload_policy, 140 ) 141 142 # For Dynamo 143 for managed_module in managed_modules: 144 managed_module._is_fsdp_managed_module = True # type: ignore[assignment] 145 managed_module._fsdp_use_orig_params = True # type: ignore[assignment] 146 147 # Place FSDP leftmost for highest priority in the method resolution order 148 for module in modules: 149 cls = module.__class__ 150 new_cls = cls_to_fsdp_cls.get(cls, None) 151 if not new_cls: 152 dct = {"__deepcopy__": unimplemented_deepcopy} 153 new_cls = type(f"FSDP{cls.__name__}", (FSDPModule, cls), dct) 154 cls_to_fsdp_cls[cls] = new_cls 155 module.__class__ = new_cls 156 return arg_module 157 158 159def unimplemented_deepcopy(*args: Any, **kwargs: Any) -> NoReturn: 160 raise AssertionError( 161 "FSDP does not support deepcopy. Please use state dict for serialization." 162 ) 163 164 165class FSDPModule: 166 def __new__(cls, *args, **kwargs): 167 """ 168 Override ``__new__`` to remove the FSDP class and directly construct 169 the original class for cases like indexing into a container module. 170 """ 171 # Use index 2 since 0 is the dynamically constructed `FSDP<...>` class 172 # and index 1 is the `FSDPModule` class itself 173 orig_cls = cls.__mro__[2] 174 self = orig_cls.__new__(orig_cls, *args, **kwargs) 175 self.__init__(*args, **kwargs) 176 return self 177 178 def reshard(self) -> None: 179 """ 180 Reshards the module's parameters, registering the sharded parameters 181 to the module and freeing the unsharded parameters if needed. This 182 method is *not* recursive. 183 """ 184 state = self._get_fsdp_state() 185 if fsdp_param_group := state._fsdp_param_group: 186 fsdp_param_group.reshard() 187 188 def unshard(self, async_op: bool = False) -> Optional["UnshardHandle"]: 189 """ 190 Unshards the module's parameters by allocating memory and all-gathering 191 the parameters. This method is *not* recursive. 192 193 Args: 194 async_op (bool): If ``True``, then returns a :class:`UnshardHandle` 195 that has a :meth:`wait` method to wait on the unshard op. If 196 ``False``, then returns ``None`` and waits on the handle inside 197 this function. 198 199 .. warning:: This method is experimental and subject to change. 200 201 .. 
note:: If ``async_op=True``, then the user does not have to call 202 :meth:`wait` on the returned handle if waiting on the unshard op 203 in the module's pre-forward is tolerable. FSDP will wait on the 204 pending unshard op in the pre-forward automatically. 205 """ 206 state = self._get_fsdp_state() 207 fsdp_param_group = state._fsdp_param_group 208 if fsdp_param_group is not None: 209 fsdp_param_group.lazy_init() 210 fsdp_param_group.unshard(async_op=async_op) 211 handle = UnshardHandle(fsdp_param_group) 212 if async_op: 213 return handle 214 handle.wait() 215 return None 216 217 def set_is_last_backward(self, is_last_backward: bool) -> None: 218 """ 219 Sets whether the next backward is the last one, meaning that FSDP 220 should wait for gradient reduction to finish and clear internal data 221 structures used for explicit prefetching. 222 """ 223 state = self._get_fsdp_state() 224 state._state_ctx.is_last_backward = is_last_backward 225 226 def set_requires_gradient_sync( 227 self, requires_gradient_sync: bool, *, recurse: bool = True 228 ) -> None: 229 """ 230 Sets if the module should sync gradients. This can be used to implement 231 gradient accumulation without communication. For HSDP, this controls 232 both reduce-scatter and all-reduce together. 233 234 Args: 235 requires_gradient_sync (bool): Whether to reduce gradients for the 236 module's parameters. 237 recurse (bool): Whether to set for all submodules or just the 238 passed-in module. 239 """ 240 self_module = cast(nn.Module, self) 241 modules = list(self_module.modules()) if recurse else [self_module] 242 for module in modules: 243 if isinstance(module, FSDPModule): 244 state = module._get_fsdp_state() 245 if fsdp_param_group := state._fsdp_param_group: 246 fsdp_param_group.reduce_grads = requires_gradient_sync 247 fsdp_param_group.all_reduce_grads = requires_gradient_sync 248 249 def set_requires_all_reduce( 250 self, requires_all_reduce: bool, *, recurse: bool = True 251 ) -> None: 252 """ 253 Sets if the module should all-reduce gradients. This can be used to 254 implement gradient accumulation with only reduce-scatter but not 255 all-reduce for HSDP. 256 """ 257 self_module = cast(nn.Module, self) 258 modules = list(self_module.modules()) if recurse else [self_module] 259 for module in modules: 260 if isinstance(module, FSDPModule): 261 state = module._get_fsdp_state() 262 if fsdp_param_group := state._fsdp_param_group: 263 fsdp_param_group.all_reduce_grads = requires_all_reduce 264 265 def set_reshard_after_backward( 266 self, reshard_after_backward: bool, *, recurse: bool = True 267 ) -> None: 268 """ 269 Sets if the module should reshard parameters after backward. This can 270 be used during gradient accumulation to trade off higher memory for 271 reduced communication. 272 273 Args: 274 reshard_after_backward (bool): Whether to reshard parameters after 275 backward. 276 recurse (bool): Whether to set for all submodules or just the 277 passed-in module. 278 """ 279 self_module = cast(nn.Module, self) 280 modules = list(self_module.modules()) if recurse else [self_module] 281 for module in modules: 282 if isinstance(module, FSDPModule): 283 state = module._get_fsdp_state() 284 if fsdp_param_group := state._fsdp_param_group: 285 fsdp_param_group.reshard_after_backward = reshard_after_backward 286 287 def set_modules_to_forward_prefetch(self, modules: List["FSDPModule"]) -> None: 288 """ 289 Sets the FSDP modules for which this FSDP module should explicitly 290 prefetch all-gathers in forward. 
        The prefetching runs after this module's all-gather copy-out.

        Passing a singleton list containing the next FSDP module gives the same
        all-gather overlap behavior as the default overlap behavior, except the
        prefetched all-gather is issued earlier from the CPU. Passing a list
        with at least length two is required for more aggressive overlap and
        will use more reserved memory.

        Args:
            modules (List[FSDPModule]): FSDP modules to prefetch.
        """
        _assert_all_fsdp_modules(modules)
        self._get_fsdp_state()._states_to_forward_prefetch = [
            module._get_fsdp_state() for module in modules
        ]

    def set_modules_to_backward_prefetch(self, modules: List["FSDPModule"]) -> None:
        """
        Sets the FSDP modules for which this FSDP module should explicitly
        prefetch all-gathers in backward. This overrides the default backward
        prefetching implementation that prefetches the next FSDP module based
        on the reverse post-forward order.

        Passing a singleton list containing the previous FSDP module gives the
        same all-gather overlap behavior as the default overlap behavior.
        Passing a list with at least length two is required for more aggressive
        overlap and will use more reserved memory.

        Args:
            modules (List[FSDPModule]): FSDP modules to prefetch.
        """
        _assert_all_fsdp_modules(modules)
        self._get_fsdp_state()._states_to_backward_prefetch = [
            module._get_fsdp_state() for module in modules
        ]

    def set_post_optim_event(self, event: torch.cuda.Event) -> None:
        """
        Sets a post-optimizer-step event for the root FSDP module to wait the
        all-gather streams on.

        By default, the root FSDP module waits the all-gather streams on the
        current stream to ensure that the optimizer step has finished before
        all-gathering. However, this may introduce false dependencies if
        there is unrelated computation after the optimizer step. This API
        allows the user to provide their own event to wait on. After the root
        waits on the event, the event is discarded, so this API should be
        called with a new event each iteration.

        Args:
            event (torch.cuda.Event): Event recorded after the optimizer step
                to wait all-gather streams on.
        """
        self._get_fsdp_state()._state_ctx.post_optim_event = event

    def set_reduce_scatter_divide_factor(self, factor: float) -> None:
        """
        Sets a custom divide factor for the reduce-scatter. This becomes a
        custom reduce op using NCCL's PreMulSum, which allows multiplying by
        the factor before reduction.

        Args:
            factor (float): Custom divide factor.
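
        .. note:: A hedged usage sketch, with ``model`` assumed to be a root
            module already wrapped with ``fully_shard``. Since gradients are by
            default divided by the data-parallel mesh size, passing that size
            as the factor roughly reproduces the default averaging::

                >>> # xdoctest: +SKIP("requires distributed init")
                >>> dp_size = torch.distributed.get_world_size()
                >>> model.set_reduce_scatter_divide_factor(float(dp_size))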
354 """ 355 state = self._get_fsdp_state() 356 if (fsdp_param_group := state._fsdp_param_group) is not None: 357 mul_factor = 1.0 / float(factor) 358 reduce_op = torch.distributed._make_nccl_premul_sum(mul_factor) 359 fsdp_param_group.reduce_scatter_reduce_op = reduce_op 360 361 def _get_fsdp_state(self) -> FSDPState: 362 if (state := _get_module_fsdp_state(cast(nn.Module, self))) is None: 363 raise AssertionError(f"No FSDP state found on {self}") 364 return state 365 366 def _apply(self, *args: Any, **kwargs: Any) -> Any: 367 # Reshard to ensure that sharded parameters are registered 368 self.reshard() 369 ret = super()._apply(*args, **kwargs) # type: ignore[misc] 370 state = self._get_fsdp_state() 371 if not (fsdp_param_group := state._fsdp_param_group): 372 return ret 373 # TODO: Remove this padding logic once DTensor pads the local tensor: 374 # https://github.com/pytorch/pytorch/issues/113045 375 with torch.no_grad(): 376 for fsdp_param in fsdp_param_group.fsdp_params: 377 fsdp_param.reset_sharded_param() 378 return ret 379 380 381class UnshardHandle: 382 """ 383 A handle to wait on the unshard op. 384 385 Args: 386 fsdp_param_group (FSDPParamGroup, optional): FSDP parameter group to 387 unshard. This should be ``None`` iff the FSDP module does not 388 manage any parameters, meaning the unshard is a no-op. 389 """ 390 391 def __init__(self, fsdp_param_group: Optional[FSDPParamGroup]): 392 self._fsdp_param_group = fsdp_param_group 393 394 def wait(self): 395 """ 396 Waits on the unshard op. 397 398 This ensures that the current stream can use the unsharded parameters, 399 which are now registered to the module. 400 """ 401 if self._fsdp_param_group is not None: 402 self._fsdp_param_group.wait_for_unshard() 403 # Avoid keeping a reference 404 self._fsdp_param_group = None 405 406 407def register_fsdp_forward_method(module: nn.Module, method_name: str) -> None: 408 """ 409 Registers a method on ``module`` to be a forward method for FSDP. 410 411 FSDP only knows to run its pre-forward and post-forward hooks on the 412 default :meth:`nn.Module.forward` method. This function patches a user 413 specified method to run the pre/post-forward hooks before/after the method, 414 respectively. If ``module`` is not an :class:`FSDPModule`, then this is a 415 no-op. 416 417 Args: 418 module (nn.Module): Module to register the forward method on. 419 method_name (str): Name of the forward method. 420 """ 421 if not isinstance(module, FSDPModule): 422 # Make no-op to allow including both when using/not using FSDP 423 return 424 if not hasattr(module, method_name): 425 raise ValueError(f"{type(module)} does not have a method {method_name}") 426 orig_method = getattr(module, method_name) 427 428 @functools.wraps(orig_method) 429 def wrapped_method(self, *args, **kwargs): 430 fsdp_state = self._get_fsdp_state() 431 args, kwargs = fsdp_state._pre_forward(self, args, kwargs) 432 out = orig_method(*args, **kwargs) 433 return fsdp_state._post_forward(self, args, out) 434 435 # Use `__get__` to make `wrapped_method` an instance method 436 setattr( 437 module, 438 method_name, 439 wrapped_method.__get__(module, type(module)), # type:ignore[attr-defined] 440 ) 441 442 443def _assert_all_fsdp_modules(modules: Iterable[Any]) -> None: 444 for module in modules: 445 if not isinstance(module, FSDPModule): 446 raise ValueError(f"Expects FSDPModule but got {type(module)}: {module}") 447