# mypy: allow-untyped-defs
import itertools
from dataclasses import dataclass, field
from enum import auto, Enum
from typing import Any, cast, List, Optional, Sequence, Tuple

import torch
import torch._dynamo.compiled_autograd as ca
import torch.nn as nn
from torch._prims_common import make_contiguous_strides_for
from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed.tensor import DTensor, Replicate, Shard
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor.device_mesh import _mesh_resources
from torch.distributed.tensor.placement_types import _StridedShard, Placement

from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy
from ._fsdp_common import (
    _chunk_with_empty,
    _from_local_no_grad,
    _get_dim0_chunked_size,
    _raise_assert_with_print,
    _to_dtype_if_needed,
    FSDPMeshInfo,
    HSDPMeshInfo,
)


"""
[Note: FSDP tensors]
FSDP considers the following tensors:
- Original parameter: parameter passed to :class:`FSDPParam`, i.e. the one
  on the module when applying FSDP
- Sharded parameter: sharding the original parameter on dim-0 as a DTensor
  over the main mesh
- All-gather inputs: the ``torch.Tensor`` s passed to all-gather, derived
  from the sharded parameter
- All-gather output: the ``torch.Tensor`` s resulting from all-gathering the
  all-gather inputs
- Unsharded parameter: parameter used for forward/backward computation, derived
  from the all-gather output; autograd leaf

We define these tensors to describe the general framework that can accommodate
extensions, where:
- all-gather-inputs = pre-all-gather-transform(sharded-parameter)
- unsharded-parameter = post-all-gather-transform(all-gather-outputs)

For the default ``torch.Tensor`` case, there is only one all-gather input, and
it shares the same underlying tensor data as the sharded parameter, meaning
that they can be thought of as the same tensors. The same applies for the
all-gather output and unsharded parameter. For non-``torch.Tensor`` extensions,
these equivalences may no longer hold due to the pre/post-all-gather
transforms, and some extensions may have multiple all-gather inputs/outputs
(e.g. quantized data and scales).

[Note: FSDP and autograd]
FSDP dynamically frees and allocates the unsharded parameter. Since autograd
can pack a reference to it or a view to save for backward, we use storage
resizing to implement the freeing/allocation since that preserves the aliasing.
This implies that we construct the unsharded parameter object once and write to
it in-place thereafter. For the default ``torch.Tensor`` original parameter
case, the all-gather output and unsharded parameter share the same data, so we
use storage resizing on the all-gather output.
"""
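
"""
[Note: All-gather extension sketch]
A minimal illustrative sketch (not part of this module; names and quantization
details are hypothetical) of a tensor subclass opting into the
pre/post-all-gather transforms described above, e.g. to all-gather quantized
data together with its scale. Only the hook names and call signatures mirror
what this file calls.

    class MyQuantizedTensor(torch.Tensor):
        def fsdp_pre_all_gather(self, mesh):
            # all-gather-inputs = pre-all-gather-transform(sharded-parameter)
            # Return (all_gather_inputs, metadata); the metadata is forwarded
            # unchanged to ``fsdp_post_all_gather``.
            return (self._data, self._scale), None

        def fsdp_post_all_gather(
            self, all_gather_outputs, metadata, param_dtype, *, out=None
        ):
            # unsharded-parameter = post-all-gather-transform(all-gather-outputs)
            data, scale = all_gather_outputs
            if out is not None:
                # Re-populate an existing unsharded parameter in place
                return None
            # Return the unsharded parameter plus the inner tensors whose
            # storage FSDP should free when resharding
            return MyQuantizedTensor(data, scale), (data, scale)
"""
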
lib = torch.library.Library("fsdp", "FRAGMENT")  # noqa: TOR901

lib.define("set_(Tensor(a!) tensor, Tensor data) -> ()")


@torch.library.impl(lib, "set_", "Meta")
@torch.library.impl(lib, "set_", "CUDA")
@torch.library.impl(lib, "set_", "CPU")
def set_(tensor, data):
    tensor.set_(data)


"""
[Note: Avoiding functionalization for fsdp.set_ and inductor.resize_storage_bytes_(0)]

Currently we don't functionalize the `fsdp.set_` op or the
`inductor.resize_storage_bytes_(0)` op (i.e. they show up as mutation ops in
the middle of the AOT joint graph).

Reason:
The Traceable FSDP2 compiled autograd BWD graph has the following traits:
(1) Two inputs of the graph are aliased to each other (one from hook closed-over tensors, one from FWD saved tensors).
(2) One of them is mutated (set_ and resize_(0) to handle the all-gathered param).
(3) They are both subclasses.
The combination of these traits is not supported by AOTAutograd (it's difficult to reason about subclass aliasing),
so this doesn't work at all for Traceable FSDP2.

The compromise we use is to avoid functionalization for the FSDP2 set_ and resize_(0) ops.
This avoids the problem above, because from AOTAutograd's point of view there are no mutations
that functionalization needs to handle. (Although we need to be careful not to DCE those mutable ops.)

We can avoid this functionalization because:
(1) The nn.Parameter is never used before its .set_() is called in eager code (i.e. no alias of it is created),
so it's safe to call .set_() in the middle of the graph to swap out its storage and start using the nn.Parameter downstream.
(2) We always re-allocate the buffer for the nn.Parameter to store the AllGather output and to be used in downstream user ops.
So calling resize-to-0 in the middle of the graph to free the nn.Parameter's memory after use should always be okay
(since we always allocate anew the next time we need it, we strictly don't need to keep the old tensor storage around).

Q: But doesn't the torch.compile stack have the "functional graph" assumption in many places?
A: Yes - this is WIP, but we will try to get back to a functional graph as early as possible in the lowering process.
Specifically, we believe we can move both the .set_ and .resize_(0) ops to the end of the AOT joint graph before the partitioner
(i.e. effectively "re-functionalizing" those ops). Put another way, we avoid functionalization for those two ops just to
make AOTAutograd's alias analysis happy, and as soon as we are past that point, we "re-functionalize" the graph.
This requires a custom FX pass, but we believe it's not hard to write and maintain.

Q: What's the importance of the partitioner not saving views of the nn.Parameter as FWD saved tensors?
A: This is critical: we do want to save the FWD nn.Parameter graph input (instead of its view) for BWD use,
so that downstream ops in the BWD graph use the post-`.set_` nn.Parameter instead of any of its saved views as input.
This is because .set_ does not update any of the nn.Parameter's views, so downstream BWD ops must use the original
nn.Parameter in order to see the result of .set_.
"""


@torch.library.impl(lib, "set_", "Functionalize")
def set__functionalize(tensor, data):
    torch._sync(tensor)
    torch._sync(data)
    # AOTDispatcher needs to know if any inputs had their storages mutated.
    # (Why? It sometimes detaches inputs before sending them into the graph,
    # when it sees that they do not need to have any gradients computed.)
    torch._functionalize_set_storage_changed(tensor)
    tensor_inner = torch._from_functional_tensor(tensor)
    data_inner = torch._from_functional_tensor(data)
    with torch._C._ExcludeDispatchKeyGuard(
        torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
    ):
        torch.ops.fsdp.set_.default(tensor_inner, data_inner)


torch.fx.node.has_side_effect(torch.ops.fsdp.set_.default)
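
"""
[Note: set_ aliasing sketch]
An illustrative eager-mode sketch (hypothetical tensors) of the aliasing point
in the Q&A above: ``fsdp.set_`` swaps which storage a tensor points to, so any
pre-existing view keeps referencing the old storage and never observes the
swap.

    p = torch.zeros(4)
    v = p.view(2, 2)                       # view of the original storage
    torch.ops.fsdp.set_(p, torch.ones(4))
    # `p` now aliases the new storage while `v` still aliases the old one,
    # which is why the BWD graph must consume the nn.Parameter graph input
    # itself rather than one of its saved views.
"""
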
class ShardedState(Enum):
    """
    - ``SHARDED``: The sharded parameter is registered to the module. It is the
      only contributor to parameter memory.
    - ``SHARDED_POST_FORWARD``: The unsharded parameter is resharded to a
      smaller world size. Since this data should not be used for computation,
      we do not register it to the module. Users should reshard the module
      before any in-place modifications. Both it and the sharded parameter
      contribute to parameter memory.
    - ``UNSHARDED``: The unsharded parameter is registered to the module. Both
      it and the sharded parameter contribute to parameter memory.
    """

    SHARDED = auto()
    SHARDED_POST_FORWARD = auto()
    UNSHARDED = auto()


@dataclass
class ParamModuleInfo:
    """
    For a parameter, this stores the module and the parameter name to be able
    to do a parameter swap via ``setattr(module, param_name, ...)`` or to get
    the parameter via ``getattr(module, param_name)``. We additionally save
    shared modules and shared parameter names to update them accordingly.
    """

    # Parameter names are unprefixed, e.g. "weight", not "lin.weight"
    module: nn.Module
    param_name: str
    shared_modules: List[nn.Module] = field(default_factory=list)
    shared_param_names: List[str] = field(default_factory=list)
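
# Illustrative example (hypothetical model): for tied weights where
# `model.lm_head.weight is model.embed.weight`, the owning entry could be
#   ParamModuleInfo(
#       module=model.embed,
#       param_name="weight",
#       shared_modules=[model.lm_head],
#       shared_param_names=["weight"],
#   )
# so that a swap via `setattr` reaches every module holding the parameter.
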
@dataclass
class ExtensionsData:
    # User-defined metadata passed from pre to post-all-gather
    all_gather_metadata: Optional[Any] = None
    # Save the all-gather input sizes to unflatten the all-gather outputs to ND
    all_gather_input_sizes: Sequence[torch.Size] = ()  # ND

    def clear(self):
        self.all_gather_metadata = None
        self.all_gather_input_sizes = ()


class FSDPParam:
    """
    This class manages a parameter with FSDP or FSDP variants applied,
    implementing dim-0 per-parameter sharding.
    """

    orig_dtype: torch.dtype
    param_dtype: Optional[torch.dtype]
    reduce_dtype: Optional[torch.dtype]
    _orig_size: torch.Size  # ND
    sharded_size: torch.Size  # ND
    contiguous_sharded_stride: Tuple[int, ...]
    padded_sharded_param_size: torch.Size  # ND
    sharded_post_forward_size: torch.Size  # ND
    contiguous_sharded_post_forward_stride: Tuple[int, ...]
    _sharded_param_data: torch.Tensor  # 1D
    sharded_param: nn.Parameter  # ND
    _sharded_post_forward_param_data: Optional[torch.Tensor]  # 1D
    _sharded_post_forward_param: Optional[nn.Parameter]  # ND
    _unsharded_param: nn.Parameter  # ND
    unsharded_accumulated_grad: Optional[torch.Tensor]  # ND
    _sharding_spec: DTensorSpec
    # DTensor attributes (only defined for DTensor `param`):
    _tp_spec: DTensorSpec
    all_gather_outputs: List[torch.Tensor]  # 1D
    # All-gather extension attributes
    _extensions_data: ExtensionsData
    _unsharded_inner_tensors: List[torch.Tensor]

    def __init__(
        self,
        param: nn.Parameter,
        module_info: ParamModuleInfo,
        mesh_info: FSDPMeshInfo,
        post_forward_mesh_info: Optional[FSDPMeshInfo],
        device: torch.device,
        mp_policy: MixedPrecisionPolicy,
        offload_policy: OffloadPolicy,
    ):
        self._module_info: ParamModuleInfo = module_info
        self.mesh_info = mesh_info
        self.post_forward_mesh_info = post_forward_mesh_info
        self.device = device
        self.offload_to_cpu: bool = isinstance(offload_policy, CPUOffloadPolicy)
        self.pin_memory = (
            self.offload_to_cpu and cast(CPUOffloadPolicy, offload_policy).pin_memory
        )
        self.grad_offload_event: Optional[torch.cuda.Event] = None
        self._init_sharded_param(param, device)
        if self.post_forward_mesh_info:
            self._init_sharded_post_forward_param_metadata(param)
        self._init_extensions()
        self.all_gather_outputs: List[torch.Tensor] = []
        self.unsharded_accumulated_grad = None
        self._param_fqn: Optional[str] = None  # prefixed from root module
        # TODO: Remove this padding logic once DTensor pads the local tensor:
        # https://github.com/pytorch/pytorch/issues/113045
        self._post_load_hook_handle = (
            module_info.module.register_load_state_dict_post_hook(
                lambda *args, **kwargs: self.reset_sharded_param()
            )
        )

    @torch.no_grad()
    def _init_sharded_param(self, param: nn.Parameter, device: torch.device):
        if param.device != device and param.device.type != "meta":
            raise AssertionError(
                f"Expects the parameter to already be moved to device {device} but got {param.device}"
            )
        # TODO: Replace the sharded DTensor parameter construction logic with
        # `distribute_tensor` after https://github.com/pytorch/pytorch/issues/116101
        # TODO: Simplify the following sharded parameter padding logic after
        # https://github.com/pytorch/pytorch/issues/113045
        self.is_dtensor = isinstance(param, DTensor)
        if self.is_dtensor:
            self._tp_spec = cast(DTensor, param)._spec
            dp_mesh, tp_mesh = (self.mesh_info.mesh, self._tp_spec.mesh)
            dp_global_mesh = _mesh_resources.get_root_mesh(dp_mesh)
            tp_global_mesh = _mesh_resources.get_root_mesh(tp_mesh)
            if dp_global_mesh != tp_global_mesh or (
                dp_global_mesh is None or tp_global_mesh is None
            ):
                raise AssertionError(
                    "FSDP requires the DP and TP mesh to have the same parent mesh but got: \n"
                    f"DP's global mesh: {dp_global_mesh}\nTP's global mesh: {tp_global_mesh}"
                )

            name_dims_error = "FSDP requires named DeviceMesh dims for ND parallelism"
            assert dp_mesh.mesh_dim_names is not None, name_dims_error
            assert tp_mesh.mesh_dim_names is not None, name_dims_error
            submesh_names = dp_mesh.mesh_dim_names + tp_mesh.mesh_dim_names
            self._spmd_mesh = dp_global_mesh[submesh_names]
            if len(self._tp_spec.placements) != 1:
                raise NotImplementedError(
                    f"FSDP only supports 1D TP, not {self._tp_spec.placements}"
                )
            split_factor = self._tp_spec.num_shards_map[0]
            assert (
                2 <= self._spmd_mesh.ndim <= 3
            ), f"_spmd_mesh.ndim can only be 2 or 3 but got {self._spmd_mesh.ndim}."
            self._spmd_placements: Tuple[Placement, ...]
            dp_shard_tp_placement = (
                (
                    _StridedShard(0, split_factor=split_factor)
                    if split_factor > 1
                    else Shard(0)
                ),
                self._tp_spec.placements[0],
            )
            if self._spmd_mesh.ndim == 2:
                self._spmd_placements = dp_shard_tp_placement
            else:
                assert self.mesh_info.replicate_mesh_dim == 0
                self._spmd_placements = (Replicate(),) + dp_shard_tp_placement
            self._sharding_spec = DTensorSpec(
                self._spmd_mesh,
                self._spmd_placements,
                tensor_meta=self._tp_spec.tensor_meta,
            )
            # NOTE: FSDP+TP does not support uneven sharding for now
            # TODO: enable uneven sharding for FSDP+TP
            if split_factor > 1:  # FSDP has strided sharding on tensor dim 0
                num_shards = self._sharding_spec.num_shards_map[0]
                tensor_size_dim_0 = self._sharding_spec.shape[0]
                if tensor_size_dim_0 % num_shards != 0:
                    raise NotImplementedError(
                        "FSDP+TP sharding does not support uneven sharding for now: "
                        f"tensor dim 0 has size {tensor_size_dim_0} which cannot be "
                        f"evenly sharded into {num_shards} shards."
                    )

            param_data = cast(DTensor, param)._local_tensor
        else:
            self._spmd_mesh = self.mesh_info.mesh
            if isinstance(self.mesh_info, HSDPMeshInfo):
                self._spmd_placements = (Replicate(), Shard(0))
            else:
                self._spmd_placements = (Shard(0),)
            self._sharding_spec = DTensorSpec(
                self._spmd_mesh,
                self._spmd_placements,
                tensor_meta=TensorMeta(
                    param.size(),
                    param.stride(),
                    param.dtype,
                ),
            )
            param_data = param
        self._orig_size = param_data.size()
        self._contiguous_orig_stride = make_contiguous_strides_for(self._orig_size)
        shard_rank = self.mesh_info.shard_mesh_rank
        shard_world_size = self.mesh_info.shard_mesh_size
        chunks = _chunk_with_empty(param_data, shard_world_size, dim=0)
        sharded_param = chunks[shard_rank]
        self.sharded_size = _get_dim0_chunked_size(sharded_param, param_data.size())
        self.contiguous_sharded_stride = make_contiguous_strides_for(self.sharded_size)
        padded_sharded_size = chunks[0].size()  # 0th always padded
        padded_sharded_param = param_data.new_zeros(padded_sharded_size)
        self.padded_sharded_param_size = padded_sharded_param.size()
        if sharded_param.numel() > 0:
            padded_sharded_param[: sharded_param.size(0)].copy_(sharded_param)
        if self.offload_to_cpu and not padded_sharded_param.is_meta:
            padded_sharded_param = padded_sharded_param.cpu()
            if self.pin_memory:
                padded_sharded_param = padded_sharded_param.pin_memory()
        self._sharded_param_data = padded_sharded_param.view(-1)
        self.sharded_param = nn.Parameter(
            self.to_sharded_dtensor(padded_sharded_param[: sharded_param.size(0)])
        )
        self.sharded_param.requires_grad_(param.requires_grad)
        # Let `param_data` be freed normally when its ref count reaches 0 when
        # the `fully_shard` call returns to allow provided parameters to alias
        self._setattr_on_modules(self.sharded_param)
        self.sharded_state = ShardedState.SHARDED
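
    # Illustrative walk-through of the dim-0 sharding above with hypothetical
    # shapes: for a (5, 4) parameter sharded over 2 ranks, `_chunk_with_empty`
    # yields dim-0 chunks of sizes (3, 4) and (2, 4), so
    # `padded_sharded_param_size == (3, 4)` on both ranks. Rank 1 copies its
    # (2, 4) shard into a zero-initialized (3, 4) buffer, keeps the flattened
    # buffer (12 elements) as `_sharded_param_data`, and registers the unpadded
    # view `padded[:2]`, wrapped as a DTensor, as `sharded_param`.
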
    def _init_sharded_post_forward_param_metadata(self, param: torch.Tensor) -> None:
        mesh_info = self.post_forward_mesh_info
        assert mesh_info is not None  # mypy
        param_data = param._local_tensor if isinstance(param, DTensor) else param
        chunks = _chunk_with_empty(param_data, mesh_info.shard_mesh_size, dim=0)
        self.sharded_post_forward_size = _get_dim0_chunked_size(
            chunks[mesh_info.shard_mesh_rank], param_data.size()
        )
        self.contiguous_sharded_post_forward_stride = make_contiguous_strides_for(
            self.sharded_post_forward_size
        )

    def init_dtype_attrs(self, mp_policy: MixedPrecisionPolicy):
        param_dtype, reduce_dtype = (mp_policy.param_dtype, mp_policy.reduce_dtype)
        self.orig_dtype = self.sharded_param.dtype
        # Clamp `param_dtype` to `None` if no casting is required
        if param_dtype == self.orig_dtype:
            param_dtype = None
        self.param_dtype = param_dtype
        self.reduce_dtype = reduce_dtype
        # None indicates that the mixed precision is not enabled

    def _init_extensions(self) -> None:
        inner_tensor = self._sharded_local_tensor
        has_fsdp_pre_all_gather = hasattr(inner_tensor, "fsdp_pre_all_gather")
        has_fsdp_post_all_gather = hasattr(inner_tensor, "fsdp_post_all_gather")
        if has_fsdp_pre_all_gather != has_fsdp_post_all_gather:
            raise AssertionError(
                "Both fsdp_pre_all_gather and fsdp_post_all_gather should be defined "
                f"if using all-gather extensions: {inner_tensor}"
            )
        if has_fsdp_pre_all_gather:
            if self.padded_sharded_param_size != self._sharded_local_tensor.size():
                raise NotImplementedError(
                    "FSDP all-gather extensions require even sharding on dim-0.\n"
                    f"{self._orig_size} is not divisible by FSDP world size {self.mesh_info.mesh.size()}."
                )
            self._extensions_data = ExtensionsData()
        self._unsharded_inner_tensors: List[torch.Tensor] = []

    def init_all_gather_outputs(
        self,
        all_gather_input_numels: List[int],
        all_gather_input_dtypes: List[torch.dtype],
        world_size: int,
        device: torch.device,
        force_recreate: bool = False,
    ):
        if not force_recreate and len(self.all_gather_outputs) > 0:
            return  # already initialized
        self.all_gather_outputs = [
            torch.empty(torch.Size([numel * world_size]), dtype=dtype, device=device)
            for numel, dtype in zip(all_gather_input_numels, all_gather_input_dtypes)
        ]
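
    # Illustrative sketch of the default (no-extension) path in
    # `init_unsharded_param` below with hypothetical shapes: for an original
    # (5, 4) parameter and world size 2, each rank contributes its padded
    # 12-element shard, so the flat all-gather output has 24 elements;
    # `torch.as_strided(output, (5, 4), (4, 1))` then views the first 20
    # elements as the ND unsharded parameter, and the trailing padding is
    # simply never read.
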
    def init_unsharded_param(self):
        """
        [Note: Invariants for torch.compile Traceable FSDP2]
        1. Under compile, we always re-populate the content of `self._unsharded_param`
           per AllGather using the slow path.
        2. Under compile, we always recreate `self.all_gather_outputs` per AllGather.
           This is to ensure the buffer creation is internal to the graph and
           avoid `self.all_gather_outputs` being captured as a graph input.
        3. Under compile, at the end of `free_unsharded_param()`, we always clean up
           `self.all_gather_outputs` and `self._unsharded_inner_tensors`,
           to avoid them being captured as graph outputs.

        With these invariants, only these tensors will be inputs to the graph:
        - Sharded parameters
        - Placeholders for the `self._unsharded_param` nn.Parameter
        """
        if not ca.compiled_autograd_enabled and hasattr(
            self, "_unsharded_param"
        ):  # after the 1st all-gather
            inner_tensor = self._sharded_local_tensor
            if not hasattr(inner_tensor, "fsdp_post_all_gather"):
                return  # already initialized
            for tensor in self._unsharded_inner_tensors:
                alloc_storage(tensor)
            all_gather_outputs = self._unflatten_all_gather_outputs()
            inner_tensor.fsdp_post_all_gather(
                all_gather_outputs,
                self._extensions_data.all_gather_metadata,
                self.param_dtype or self.orig_dtype,
                out=self._unsharded_param,
            )
            self._extensions_data.clear()
            return
        inner_tensor = self._sharded_local_tensor
        if not ca.compiled_autograd_enabled and hasattr(
            inner_tensor, "fsdp_post_all_gather"
        ):
            all_gather_outputs = self._unflatten_all_gather_outputs()
            (
                unsharded_tensor,
                self._unsharded_inner_tensors,
            ) = inner_tensor.fsdp_post_all_gather(
                all_gather_outputs,
                self._extensions_data.all_gather_metadata,
                self.param_dtype or self.orig_dtype,
            )
            self._extensions_data.clear()
        else:
            # For the default path (no post-all-gather), the all-gather output
            # gives the unsharded parameter data directly
            assert len(self.all_gather_outputs) == 1, f"{len(self.all_gather_outputs)}"
            unsharded_tensor = self.all_gather_outputs[0]
        unsharded_param = torch.as_strided(
            unsharded_tensor,
            self._orig_size,
            self._contiguous_orig_stride,
            storage_offset=0,
        )
        if self.is_dtensor:
            unsharded_param = _from_local_no_grad(unsharded_param, self._tp_spec)
        if hasattr(self, "_unsharded_param"):
            assert ca.compiled_autograd_enabled
            with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
                self._unsharded_param
            ):
                torch.ops.fsdp.set_.default(self._unsharded_param, unsharded_param)
        else:
            self._unsharded_param = nn.Parameter(
                unsharded_param, requires_grad=self.sharded_param.requires_grad
            )

    def _unflatten_all_gather_outputs(self) -> Tuple[torch.Tensor, ...]:
        return tuple(
            t.view(-1, *s[1:])
            for t, s in zip(
                self.all_gather_outputs, self._extensions_data.all_gather_input_sizes
            )
        )

    def to_sharded(self) -> None:
        self._setattr_on_modules(self.sharded_param)
        self.free_unsharded_param()
        self.sharded_state = ShardedState.SHARDED

    def to_sharded_post_forward(self) -> None:
        if self.is_dtensor:
            raise NotImplementedError(
                "Resharding to smaller mesh with TP is not supported yet"
            )
        self._assert_in_states(ShardedState.UNSHARDED)
        assert self.post_forward_mesh_info is not None  # mypy
        assert len(self.all_gather_outputs) == 1
        shard_world_size = self.post_forward_mesh_info.shard_mesh_size
        if (numel := self.all_gather_outputs[0].numel()) % shard_world_size != 0:
            _raise_assert_with_print(
                f"All-gather output size ({numel}) must be divisible by the shard "
                f"world size ({shard_world_size})"
            )
        shard_rank = self.post_forward_mesh_info.shard_mesh_rank
        sharded_numel = numel // shard_world_size
        self._sharded_post_forward_param_data = (
            self.all_gather_outputs[0].narrow(
                0, sharded_numel * shard_rank, sharded_numel
            )
        ).clone()  # clone to be able to free all-gather output
        sharded_post_forward_tensor = torch.as_strided(
            self._sharded_post_forward_param_data,
            size=self.sharded_post_forward_size,
            stride=self.contiguous_sharded_post_forward_stride,
            storage_offset=0,
        )
        self._sharded_post_forward_param = nn.Parameter(
            self.to_sharded_post_forward_dtensor(sharded_post_forward_tensor)
        )
        self._setattr_on_modules(self._sharded_post_forward_param)
        self.free_unsharded_param()
        self.sharded_state = ShardedState.SHARDED_POST_FORWARD

    def to_unsharded(self) -> None:
        # Assume that the data has been allocated and all-gathered
        set_requires_grad_if_needed(self.sharded_param, self._unsharded_param)
        self._setattr_on_modules(self._unsharded_param)
        if self.sharded_state == ShardedState.SHARDED_POST_FORWARD:
            # The data is allocated in the default stream via the post-forward
            # reshard and must be kept alive for the next all-gather copy-in.
            # Since we call this method after the copy-out, the data's lifetime
            # is ensured without further synchronization.
            self._sharded_post_forward_param = None
            self._sharded_post_forward_param_data = None  # free
        self.sharded_state = ShardedState.UNSHARDED

    def _setattr_on_modules(self, param: nn.Parameter) -> None:
        unsafe_setattr_param(
            self._module_info.module, self._module_info.param_name, param
        )
        for shared_module, shared_param_name in zip(
            self._module_info.shared_modules, self._module_info.shared_param_names
        ):
            unsafe_setattr_param(shared_module, shared_param_name, param)

    def to_sharded_dtensor(self, tensor: torch.Tensor) -> DTensor:
        """
        Converts a local tensor representing either the sharded parameter or
        sharded gradient to DTensor.
        """
        if tensor.shape != self.sharded_size:
            _raise_assert_with_print(
                f"Expects size {self.sharded_size} but got {tensor.shape}"
            )
        return _from_local_no_grad(
            tensor,
            self._sharding_spec,
        )

    def to_sharded_post_forward_dtensor(self, tensor: torch.Tensor) -> DTensor:
        if tensor.shape != self.sharded_post_forward_size:
            _raise_assert_with_print(
                f"Expects size {self.sharded_post_forward_size} but got {tensor.shape}"
            )
        assert isinstance(self.post_forward_mesh_info, HSDPMeshInfo)
        # TODO: Prefer this DTensor to be read-only and generalize the
        # placement once we support TP.
        post_forward_sharding_spec = DTensorSpec(
            self.post_forward_mesh_info.mesh,
            (Replicate(), Shard(0)),
            tensor_meta=self._sharding_spec.tensor_meta,
        )
        return _from_local_no_grad(tensor, post_forward_sharding_spec)

    def to_accumulated_grad_if_needed(self) -> None:
        # Access `_unsharded_param` to bypass the sharded state check since we
        # prefer to reshard before upcasting the gradient to save memory
        if (
            self.reduce_dtype is None
            or self._unsharded_param.grad is None
            or self._unsharded_param.grad.dtype == self.reduce_dtype
        ):
            return
        unsharded_grad = self._unsharded_param.grad
        self._unsharded_param.grad = None
        self.unsharded_accumulated_grad = unsharded_grad.to(self.reduce_dtype)

    def accumulate_unsharded_grad_if_needed(self) -> None:
        if (
            self.unsharded_accumulated_grad is not None
            and self.unsharded_param.grad is not None
        ):
            self.unsharded_accumulated_grad += self.unsharded_param.grad
            self.unsharded_param.grad = None

    def alloc_all_gather_outputs(self) -> None:
        for tensor in self.all_gather_outputs:
            alloc_storage(tensor)

    def free_unsharded_param(self) -> None:
        for tensor in itertools.chain(
            self.all_gather_outputs, self._unsharded_inner_tensors
        ):
            free_storage(tensor)
        if ca.compiled_autograd_enabled:
            self.all_gather_outputs = []
            self._unsharded_inner_tensors = []

    @property
    def all_gather_inputs(self) -> List[torch.Tensor]:  # 1D
        self._assert_in_states(ShardedState.SHARDED, ShardedState.SHARDED_POST_FORWARD)
        if self.sharded_state == ShardedState.SHARDED:
            if not ca.compiled_autograd_enabled and hasattr(
                self._sharded_local_tensor, "fsdp_pre_all_gather"
            ):
                sharded_local_tensor = self._sharded_local_tensor
                if self.offload_to_cpu:
                    sharded_local_tensor = sharded_local_tensor.to(
                        self.device, non_blocking=True
                    )
                (
                    all_gather_inputs,
                    self._extensions_data.all_gather_metadata,
                ) = sharded_local_tensor.fsdp_pre_all_gather(self.mesh_info.mesh)
                self._extensions_data.all_gather_input_sizes = [
                    t.size() for t in all_gather_inputs
                ]
                return [t.view(-1) for t in all_gather_inputs]
            sharded_param_data = self._sharded_param_data
            if self.offload_to_cpu:
                sharded_param_data = sharded_param_data.to(
                    self.device, non_blocking=True
                )
            return [_to_dtype_if_needed(sharded_param_data, self.param_dtype)]
        elif self.sharded_state == ShardedState.SHARDED_POST_FORWARD:
            if not ca.compiled_autograd_enabled and hasattr(
                self._sharded_local_tensor, "fsdp_pre_all_gather"
            ):
                raise NotImplementedError
            all_gather_input = _to_dtype_if_needed(
                cast(torch.Tensor, self._sharded_post_forward_param_data),
                self.param_dtype,
            )
            return [all_gather_input]
        return [torch.empty(0)]  # mypy

    @property
    def unsharded_param(self) -> nn.Parameter:  # ND
        self._assert_in_states(ShardedState.UNSHARDED)
        return self._unsharded_param

    @property
    def unsharded_grad_data(self) -> torch.Tensor:
        grad = self.unsharded_param.grad
        assert grad is not None, "Expects unsharded_param.grad to not be None"
        return self._get_grad_inner_tensor(grad)

    @property
    def unsharded_accumulated_grad_data(self) -> torch.Tensor:
        grad = self.unsharded_accumulated_grad
        assert grad is not None, "Expects unsharded_accumulated_grad to not be None"
        return self._get_grad_inner_tensor(grad)
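
    # NOTE (illustrative): with TP, the unsharded gradient can be a DTensor
    # whose placements include ``Partial`` (a pending reduction over the TP
    # mesh). `_get_grad_inner_tensor` redistributes such placements to
    # ``Replicate`` so that the local tensor holds fully reduced values before
    # FSDP reduce-scatters it over the data-parallel mesh.
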
    def _get_grad_inner_tensor(self, grad: torch.Tensor) -> torch.Tensor:
        if self.is_dtensor:
            if isinstance(grad, AsyncCollectiveTensor):
                grad = grad.wait()
            assert isinstance(grad, DTensor), f"{type(grad)}"
            if any(pl.is_partial() for pl in grad.placements):
                placements = [
                    Replicate() if pl.is_partial() else pl for pl in grad.placements
                ]
                grad = grad.redistribute(placements=placements)
            grad = grad._local_tensor
        return grad

    @property
    def _sharded_local_tensor(self) -> torch.Tensor:
        return cast(DTensor, self.sharded_param)._local_tensor

    def _assert_in_states(self, *states: ShardedState) -> None:
        if self.sharded_state not in states:
            _raise_assert_with_print(
                f"Expects to be in one of {states}, not {self.sharded_state}"
            )

    def reset_sharded_param(self):
        # For ops like `nn.Module._apply` or `load_state_dict(assign=True)`
        # that change the sharded parameter tensor, we may need to re-pad the
        # sharded local tensor and re-save the reference.
        module_info = self._module_info
        new_param = getattr(module_info.module, module_info.param_name)
        if new_param is not self.sharded_param:
            if torch.__future__.get_swap_module_params_on_conversion():
                raise AssertionError(
                    f"Expects swap_tensors to preserve object but got {new_param} "
                    f"instead of {self.sharded_param}"
                )
            self.sharded_param = new_param
        local_tensor = new_param._local_tensor
        if local_tensor.is_meta:
            return
        padded_sharded_size = self.padded_sharded_param_size
        if local_tensor.size() != padded_sharded_size:
            padded_local_tensor = local_tensor.new_zeros(padded_sharded_size)
            padded_local_tensor[: local_tensor.size(0)].copy_(local_tensor)
            local_tensor = padded_local_tensor
        if self.pin_memory and not local_tensor.is_pinned():
            local_tensor = local_tensor.cpu().pin_memory()
        self._sharded_param_data = local_tensor.view(-1)
        assert isinstance(self.sharded_param, DTensor)  # mypy
        self.sharded_param._local_tensor = local_tensor[: self.sharded_size[0]]

    def __repr__(self):
        return f"FSDPParam(fqn={self._param_fqn}, orig_size={self._orig_size})"


def alloc_storage(tensor: torch.Tensor) -> None:
    size = tensor.numel() * tensor.itemsize
    if (storage := tensor.untyped_storage()).size() != size:
        storage.resize_(size)


def free_storage(tensor: torch.Tensor) -> None:
    if (storage := tensor.untyped_storage()).size() != 0:
        storage.resize_(0)


# NOTE: These functions bypass the `nn.Module.__setattr__` checks, which incur
# non-trivial CPU overhead, as long as the module did not override
# `__setattr__`. For FSDP, we know we do not need those checks when
# transitioning between sharded/unsharded parameters.
def unsafe_setattr_param(
    module: nn.Module, param_name: str, param: nn.Parameter
) -> None:
    if getattr(module.__setattr__, "__func__", None) is nn.Module.__setattr__:
        module._parameters[param_name] = param
    else:  # slow path
        setattr(module, param_name, param)


def set_requires_grad_if_needed(
    src_tensor: torch.Tensor, dst_tensor: torch.Tensor
) -> None:
    # Only call `requires_grad_` if needed to avoid the Python <> C++ context
    # switch overhead
    if src_tensor.requires_grad != dst_tensor.requires_grad:
        dst_tensor.requires_grad_(src_tensor.requires_grad)
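
# Illustrative sketch of the storage-resizing trick from [Note: FSDP and
# autograd], using the helpers above (the tensors here are hypothetical):
#
#     p = torch.empty(16, device="cuda")
#     v = p[:4]            # e.g. a view that autograd saved for backward
#     free_storage(p)      # storage resized to 0 bytes; no memory is held
#     alloc_storage(p)     # the same storage object is resized back
#     p.copy_(new_data)    # `v` observes the new values because the storage
#                          # object, and hence the aliasing, was preserved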