# mypy: allow-untyped-defs
import contextlib
import logging
from typing import Any, cast, Dict, List, NamedTuple, Optional, Set, Tuple

import torch
import torch._dynamo.compiled_autograd as ca
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp._common_utils import _named_parameters_with_duplicates
from torch.profiler import record_function
from torch.utils._pytree import tree_flatten, tree_unflatten
from torch.utils.hooks import RemovableHandle

from ._fsdp_api import MixedPrecisionPolicy, OffloadPolicy
from ._fsdp_collectives import (
    AllGatherResult,
    foreach_all_gather,
    foreach_all_gather_copy_out,
    foreach_reduce,
)
from ._fsdp_common import FSDPMeshInfo, HSDPMeshInfo, TrainingState
from ._fsdp_param import FSDPParam, ParamModuleInfo, ShardedState


logger = logging.getLogger("torch.distributed._composable.fsdp")

_ModuleToHandleDict = Dict[nn.Module, RemovableHandle]  # for state dict


"""
[Note: Overlapping all-gather copy-in and all-gather]
For implicit forward prefetching, we want to overlap the next copy-in with the
current all-gather. We do so using a separate copy-in stream. However, since
we have the all-gather input as a view into the output, we must make sure to
copy into different memory from the current all-gather's output. Thus, we keep
a reference to the current all-gather's output and have the next FSDP parameter
group free it after its copy-in. Finally, we have the last FSDP state flush the
reference to avoid holding onto memory after forward.
"""


class FSDPCommContext:
    """This has the communication state shared across FSDP states/parameter groups."""

    def lazy_init(self):
        if not torch.cuda.is_available():
            raise RuntimeError("FSDP requires CUDA for streams")
        # Setting the all-gather/reduce-scatter streams to be higher priority
        # can help avoid some issues where their copies in/out are delayed and
        # block computation (this is different from high-pri NCCL streams)
        high_priority = -1
        # All-gather state and copy-in stream allow overlapping the next
        # copy-in with the current all-gather in forward; copy-in overlaps with
        # reduce-scatter in backward without the separate copy-in stream
        self.all_gather_copy_in_stream = torch.cuda.Stream(priority=high_priority)
        # All-gather stream allows overlapping next all-gather with current
        # forward compute
        self.all_gather_stream = torch.cuda.Stream(priority=high_priority)
        # Reduce-scatter stream gives separate execution "thread" for post-
        # backward logic like pre/post-gradient division and reduce-scatter
        self.reduce_scatter_stream = torch.cuda.Stream(priority=high_priority)
        # Run the HSDP all-reduces concurrently with all-gather/reduce-scatter
        # since collectives use different network resources and can overlap
        # in the typical intra-node sharding / inter-node replication case
        self.all_reduce_stream = torch.cuda.Stream()
        # All-gather/reduce-scatter states keep references to collective
        # tensors produced in one stream and used in another and accompanying
        # CUDA events for synchronization
        self.all_gather_state: Optional[AllGatherState] = None
        self.reduce_scatter_state: Optional[ReduceScatterState] = None
        # Post-forward order for explicit backward prefetching
        self.post_forward_order: List[FSDPParamGroup] = []  # will cause ref cycles

    def get_all_gather_streams(
        self, training_state: TrainingState
    ) -> Tuple[torch.cuda.Stream, torch.cuda.Stream]:
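        """
        Returns the (all-gather copy-in stream, all-gather stream) to use for
        the given training state. Separate high-priority streams are returned
        in forward and pre-backward to support implicit prefetching; otherwise,
        the current stream is returned for both.
        """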
        if training_state in (TrainingState.FORWARD, TrainingState.PRE_BACKWARD):
            # Use separate streams for implicit prefetching
            return self.all_gather_copy_in_stream, self.all_gather_stream
        current_stream = torch.cuda.current_stream()
        return current_stream, current_stream


# See [Note: Overlapping all-gather copy-in and all-gather]
class AllGatherState(NamedTuple):
    all_gather_result: AllGatherResult
    event: torch.cuda.Event  # all-gather copy-out


class ReduceScatterState(NamedTuple):
    reduce_scatter_input: torch.Tensor
    event: torch.cuda.Event  # reduce-scatter event


class FSDPParamGroup:
    """This class represents a parameter group to communicate together."""

    _orig_dtype: torch.dtype
    _reduce_dtype: Optional[torch.dtype]

    def __init__(
        self,
        params: List[nn.Parameter],
        modules: Tuple[nn.Module, ...],
        mesh_info: FSDPMeshInfo,
        post_forward_mesh_info: Optional[FSDPMeshInfo],
        device: torch.device,
        mp_policy: MixedPrecisionPolicy,
        offload_policy: OffloadPolicy,
    ):
        self.modules = modules  # permit ref cycle because 1:1 lifetime
        param_module_infos = _get_param_module_infos(params, modules)
        self.fsdp_params = [
            FSDPParam(
                param,
                module_info,
                mesh_info,
                post_forward_mesh_info,
                device,
                mp_policy,
                offload_policy,
            )
            for param, module_info in zip(params, param_module_infos)
        ]
        self.mesh_info = mesh_info
        self.post_forward_mesh_info = post_forward_mesh_info
        self.device = device
        self.mp_policy = mp_policy
        self._training_state = TrainingState.IDLE
        # Group's sharded state always matches its parameters' sharded states
        self._sharded_state = ShardedState.SHARDED
        self._module_fqn: Optional[str] = None  # prefixed from root module
        # Only consider resetting sharded parameters once in lazy init since it
        # can incur nontrivial overhead to reset them
        self._reset_sharded_params: bool = False

        # - Hook state
        self._module_to_pre_save_state_dict_hook_handle: _ModuleToHandleDict = {}
        self._module_to_pre_load_state_dict_hook_handle: _ModuleToHandleDict = {}

        # - Communication and communication/computation overlap
        self.comm_ctx = FSDPCommContext()
        # Group's indices in the shared post-forward order
        self._post_forward_indices: List[int] = []
        # Whether to reduce gradients at all (whether for FSDP or HSDP)
        self.reduce_grads: bool = True
        # Whether to all-reduce gradients for HSDP; only used if
        # `self.reduce_grads` is true, in which case setting this to false
        # means reduce-scatter but no all-reduce
        self.all_reduce_grads: bool = True
        # Whether to reshard parameters after backward (only useful for
        # gradient accumulation)
        self.reshard_after_backward: bool = True
        # Optional custom reduce-scatter reduce op (e.g. to divide by a
        # factor other than the shard world size)
        self.reduce_scatter_reduce_op: Optional[dist.ReduceOp] = None

        # - CUDA events for stream synchronization
        # Holds the all-gather output buffer, sync objects, and metadata
        self._all_gather_result: Optional[AllGatherResult] = None
        # Holds the reduce-scatter/all-reduce view-out CUDA event that marks
        # the end of the group's post-backward (e.g. reduce-scatter, all-reduce,
        # and div), which should be waited on at the end of backward
        self._post_reduce_event: Optional[torch.cuda.Event] = None
        # Holds the reshard-after-forward CUDA event when resharding to a
        # different world size, which should be waited on in the next unshard
        self._reshard_after_forward_event: Optional[torch.cuda.Event] = None

        # Only for HSDP, if accumulating gradients without all-reduce, save the
        # partial reduce output (only reduce-scattered but not all-reduced)
        self._partial_reduce_output: Optional[torch.Tensor] = None

        # TODO: remove this hook and hook register once 2D state dict is supported.
        def _raise_not_implemented_if_2d(*args: Any, **kwargs: Any) -> None:
            raise NotImplementedError(
                "2D state_dict is under development. Please check "
                "https://github.com/pytorch/pytorch/issues/129627 for more details."
            )

        modules_with_2d_params: Set[nn.Module] = set()
        for fsdp_param in self.fsdp_params:
            module = fsdp_param._module_info.module
            if len(fsdp_param._spmd_placements) > 1:
                modules_with_2d_params.add(module)
        for module in modules_with_2d_params:
            module.register_state_dict_pre_hook(_raise_not_implemented_if_2d)
            module._register_load_state_dict_pre_hook(_raise_not_implemented_if_2d)

    # Initialization #
    def _init_mp_dtypes(self) -> None:
        for fsdp_param in self.fsdp_params:
            fsdp_param.init_dtype_attrs(self.mp_policy)
        orig_dtypes = {fsdp_param.orig_dtype for fsdp_param in self.fsdp_params}
        if len(orig_dtypes) != 1:
            # This can be relaxed if we copy-out for the reduce-scatter
            raise AssertionError(
                f"FSDP expects uniform original parameter dtype but got {orig_dtypes}"
            )
        self._orig_dtype = next(iter(orig_dtypes))
        reduce_dtypes = {fsdp_param.reduce_dtype for fsdp_param in self.fsdp_params}
        if len(reduce_dtypes) != 1:
            # This can be relaxed if we issue one reduce-scatter per reduce
            # dtype (but we would need a way for users to specify multiple
            # reduce dtypes)
            raise AssertionError(
                f"FSDP expects uniform reduce dtype but got {reduce_dtypes}"
            )
        self._reduce_dtype = next(iter(reduce_dtypes))

    def lazy_init(self):
        # Lazy init should be idempotent
        # Users may change or register parameters after construction time.
        # For example, DoRA (https://arxiv.org/abs/2402.09353) initializes
        # linear magnitudes based on other parameters (e.g. loaded from the
        # state dict).
        if self.is_sharded and not self._reset_sharded_params:
            for fsdp_param in self.fsdp_params:
                fsdp_param.reset_sharded_param()
            self._reset_sharded_params = True
        param_names_on_meta = [
            fsdp_param._param_fqn
            for fsdp_param in self.fsdp_params
            if fsdp_param.sharded_param.device.type == "meta"
        ]
        if param_names_on_meta:
            raise RuntimeError(
                "FSDP parameters should be materialized from meta device before training, "
                f"but the following were still on meta device: {param_names_on_meta}\n"
                "For example, call module.to_empty(device) to materialize to device and "
                "call module.reset_parameters() on each module to initialize values."
            )
        # Initialize mixed precision attributes lazily in case the user changes
        # the parameter dtypes after construction time but before forward
        self._init_mp_dtypes()
        self._register_state_dict_hooks()

    # Runtime #
    def unshard(self, async_op: bool = False):
        if self._all_gather_result is not None:  # already called, pending wait
            return
        if self.is_unsharded:
            return  # no-op
        if self._reshard_after_forward_event is not None:
            # Resharded parameter data is allocated in the default stream and
            # used in the all-gather streams
            self._wait_all_gather_streams_on_event(self._reshard_after_forward_event)
            self._reshard_after_forward_event = None
        with record_function(self._with_fqn("FSDP::all_gather")):
            self._all_gather_result = foreach_all_gather(
                self.fsdp_params,
                self._all_gather_process_group,
                async_op,
                *self.comm_ctx.get_all_gather_streams(self._training_state),
                self.device,
            )

    def wait_for_unshard(self):
        """
        1. In forward with implicit prefetching, to overlap the current copy-out
        with the next all-gather, we save a reference to the current all-gather
        result to free after the next copy-out.
        2. Otherwise (explicit prefetching or in backward), we free the
        all-gather result immediately after the current copy-out since we can
        already overlap the current copy-out with the previous reduce-scatter.
        """
        if not self._all_gather_result:
            return  # no preceding unshard
        if self._training_state == TrainingState.FORWARD:  # implicit prefetch
            if prev_all_gather_state := self.comm_ctx.all_gather_state:
                self._wait_all_gather_streams_on_event(prev_all_gather_state.event)
                self.comm_ctx.all_gather_state = None  # free the all-gather result
        with record_function(self._with_fqn("FSDP::all_gather_copy_out")):
            foreach_all_gather_copy_out(
                self._all_gather_result,
                self.fsdp_params,
                self._all_gather_process_group,
            )
        for fsdp_param in self.fsdp_params:
            fsdp_param.init_unsharded_param()
        self._to_unsharded()
        all_gather_copy_out_event = torch.cuda.Event()
        all_gather_copy_out_event.record()
        if self._training_state == TrainingState.FORWARD:
            self.comm_ctx.all_gather_state = AllGatherState(
                self._all_gather_result, all_gather_copy_out_event
            )
        else:
            self._wait_all_gather_streams_on_event(all_gather_copy_out_event)
        self._all_gather_result = None  # free unless saved in `all_gather_state`

    def _wait_all_gather_streams_on_event(self, event: torch.cuda.Event):
        # Calling `unshard` before lazy init means streams are not initialized
        if hasattr(self.comm_ctx, "all_gather_copy_in_stream"):
            self.comm_ctx.all_gather_copy_in_stream.wait_event(event)
        if hasattr(self.comm_ctx, "all_gather_stream"):
            self.comm_ctx.all_gather_stream.wait_event(event)

    def reshard(self):
        if self._training_state == TrainingState.FORWARD:
            if not self._reshard_after_forward:
                return
            if self._use_post_forward_mesh:
                self._to_sharded_post_forward()
                self._reshard_after_forward_event = torch.cuda.Event()
                self._reshard_after_forward_event.record()
                return
        self._to_sharded()

    def pre_forward(
        self, module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
        if not ca.compiled_autograd_enabled:
            logger.debug("%s", self._with_fqn("FSDP::pre_forward"))
        with record_function(self._with_fqn("FSDP::pre_forward")):
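            # Unshard (all-gather) this group's parameters and wait for the
            # unsharded parameters to be ready before the modules' forward
            # runs; the post-backward hook is threaded through the forward
            # inputs so that gradient reduction runs after these modules'
            # backward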
            self._training_state = TrainingState.FORWARD
            self.unshard()
            self.wait_for_unshard()
            args, kwargs = self._register_post_backward_hook(args, kwargs)
            return args, kwargs

    def post_forward(self, module: nn.Module, input: Any, output: Any):
        if not ca.compiled_autograd_enabled:
            logger.debug("%s", self._with_fqn("FSDP::post_forward"))
        with record_function(self._with_fqn("FSDP::post_forward")):
            self.reshard()
            self._record_post_forward()
            self._training_state = TrainingState.IDLE
            return output

    def _record_post_forward(self) -> None:
        # Since a group has one pre-backward unshard for each forward call
        # before the backward, we record each usage (with multiplicity)
        post_forward_index = len(self.comm_ctx.post_forward_order)
        self.comm_ctx.post_forward_order.append(self)
        self._post_forward_indices.append(post_forward_index)

    def pre_backward(self, default_prefetch: bool, *unused: Any):
        if self._training_state == TrainingState.PRE_BACKWARD:
            return
        if not ca.compiled_autograd_enabled:
            logger.debug("%s", self._with_fqn("FSDP::pre_backward"))
        with record_function(self._with_fqn("FSDP::pre_backward")):
            self._training_state = TrainingState.PRE_BACKWARD
            self.unshard()  # no-op if prefetched
            self.wait_for_unshard()
            if default_prefetch and not ca.compiled_autograd_enabled:
                self._backward_prefetch()

    def post_backward(self, *unused: Any):
        if not ca.compiled_autograd_enabled:
            logger.debug("%s", self._with_fqn("FSDP::post_backward"))
        self._training_state = TrainingState.POST_BACKWARD
        with record_function(self._with_fqn("FSDP::post_backward_accumulate")):
            for fsdp_param in self.fsdp_params:
                fsdp_param.accumulate_unsharded_grad_if_needed()
        with record_function(self._with_fqn("FSDP::post_backward_reshard")):
            if not self.reduce_grads:
                if self.reshard_after_backward:
                    self.reshard()
                for fsdp_param in self.fsdp_params:
                    fsdp_param.to_accumulated_grad_if_needed()
                return
            # Save the autograd-computed gradients before resharding to only
            # access the unsharded parameters when their data is present
            fsdp_params_with_grad: List[FSDPParam] = []
            unsharded_grads: List[torch.Tensor] = []
            for fsdp_param in self.fsdp_params:
                # May have an accumulated gradient of the reduce dtype if the
                # previous backward did not reduce-scatter
                if fsdp_param.unsharded_accumulated_grad is not None:
                    fsdp_params_with_grad.append(fsdp_param)
                    unsharded_grads.append(fsdp_param.unsharded_accumulated_grad_data)
                    fsdp_param.unsharded_accumulated_grad = None
                elif fsdp_param.unsharded_param.grad is not None:
                    fsdp_params_with_grad.append(fsdp_param)
                    unsharded_grads.append(fsdp_param.unsharded_grad_data)
                    fsdp_param.unsharded_param.grad = None
            if self.reshard_after_backward:
                self.reshard()
        if len(fsdp_params_with_grad) == 0:
            return
        with record_function(self._with_fqn("FSDP::post_backward_reduce")):
            if self.comm_ctx.reduce_scatter_state is not None:
                torch.cuda.current_stream().wait_event(
                    self.comm_ctx.reduce_scatter_state.event
                )
                self.comm_ctx.reduce_scatter_state = None
            (
                reduce_scatter_input,
                reduce_scatter_event,
                self._post_reduce_event,
                self._partial_reduce_output,
            ) = foreach_reduce(
                fsdp_params_with_grad,
                unsharded_grads,
                self._reduce_scatter_process_group,
                self.comm_ctx.reduce_scatter_stream,
                self._orig_dtype,
                self._reduce_dtype,
                self.device,
                self.reduce_scatter_reduce_op,
                self._all_reduce_process_group if self._is_hsdp else None,
                self.comm_ctx.all_reduce_stream,
                self.all_reduce_grads,
                self._partial_reduce_output,
            )
            self.comm_ctx.reduce_scatter_state = ReduceScatterState(
                reduce_scatter_input, reduce_scatter_event
            )

    def finalize_backward(self):
        if self._post_reduce_event is not None:
            torch.cuda.current_stream().wait_event(self._post_reduce_event)
            self._post_reduce_event = None
        for fsdp_param in self.fsdp_params:
            if fsdp_param.grad_offload_event is not None:
                fsdp_param.grad_offload_event.synchronize()
                fsdp_param.grad_offload_event = None
        self._post_forward_indices.clear()

    def _backward_prefetch(self) -> None:
        if self._training_state == TrainingState.PRE_BACKWARD:
            if not self._post_forward_indices:
                # Can be cleared if running multiple `backward`s
                return
            curr_index = self._post_forward_indices.pop()
            if (target_index := curr_index - 1) < 0:
                return
            # Prefetch naively using the reverse post-forward order, which may
            # have mistargeted prefetches if not all modules used in forward
            # are used in this backward
            target_fsdp_param_group = self.comm_ctx.post_forward_order[target_index]
            self._prefetch_unshard(target_fsdp_param_group, "backward")

    @staticmethod
    def _prefetch_unshard(
        target_fsdp_param_group: "FSDPParamGroup", pass_type: str
    ) -> None:
        if pass_type == "backward":
            training_state = TrainingState.PRE_BACKWARD
        elif pass_type == "forward":
            training_state = TrainingState.FORWARD
        else:
            raise ValueError(f"Unknown pass type: {pass_type}")
        target_fqn = target_fsdp_param_group._module_fqn
        with record_function(
            f"FSDP::{pass_type}_prefetch for {target_fqn}"
        ), target_fsdp_param_group.use_training_state(training_state):
            target_fsdp_param_group.unshard()

    # Utilities #
    def _to_sharded(self):
        if not self.is_sharded:
            for fsdp_param in self.fsdp_params:
                fsdp_param.to_sharded()
            self._sharded_state = ShardedState.SHARDED

    def _to_sharded_post_forward(self):
        if not self.is_sharded_post_forward:
            for fsdp_param in self.fsdp_params:
                fsdp_param.to_sharded_post_forward()
            self._sharded_state = ShardedState.SHARDED_POST_FORWARD

    def _to_unsharded(self):
        if not self.is_unsharded:
            for fsdp_param in self.fsdp_params:
                fsdp_param.to_unsharded()
            self._sharded_state = ShardedState.UNSHARDED

    @property
    def is_sharded(self) -> bool:
        return self._sharded_state == ShardedState.SHARDED

    @property
    def is_sharded_post_forward(self) -> bool:
        return self._sharded_state == ShardedState.SHARDED_POST_FORWARD

    @property
    def is_unsharded(self) -> bool:
        return self._sharded_state == ShardedState.UNSHARDED

    @contextlib.contextmanager
    def use_training_state(self, training_state: TrainingState):
        old_training_state = self._training_state
        self._training_state = training_state
        try:
            yield
        finally:
            self._training_state = old_training_state

    # Hook Registration #
    def _register_post_backward_hook(
        self, args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
        # Compile relies on `root_post_backward_callback` to call each
        # `FSDPParamGroup.post_backward`
        if ca.compiled_autograd_enabled:
            return args, kwargs
        if not torch.is_grad_enabled():
            return args, kwargs
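        # Thread the input tensors that require grad through
        # `RegisterPostBackwardFunction` so that its `backward` (which calls
        # this group's `post_backward`) only runs after gradients have flowed
        # back through these modules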
        args_list, args_spec = tree_flatten(args)
        kwargs_list, kwargs_spec = tree_flatten(kwargs)
        args_kwargs_list = list(args_list) + list(kwargs_list)
        inp_tensor_indices: List[int] = []
        inp_tensors: List[torch.Tensor] = []
        for i, obj in enumerate(args_kwargs_list):
            if torch.is_tensor(obj) and obj.requires_grad:
                inp_tensor_indices.append(i)
                inp_tensors.append(obj)
        if len(inp_tensors) == 0:
            return args, kwargs  # no tensors that require gradients
        inp_tensors = RegisterPostBackwardFunction.apply(self, *inp_tensors)
        for inp_tensor_idx, inp_tensor in zip(inp_tensor_indices, inp_tensors):
            args_kwargs_list[inp_tensor_idx] = inp_tensor
        args_list = args_kwargs_list[: len(args_list)]
        kwargs_list = args_kwargs_list[len(args_list) :]
        args = tree_unflatten(args_list, args_spec)
        kwargs = tree_unflatten(kwargs_list, kwargs_spec)
        return args, kwargs

    def _register_state_dict_hooks(self) -> None:
        num_pre_save_hooks = len(self._module_to_pre_save_state_dict_hook_handle)
        num_pre_load_hooks = len(self._module_to_pre_load_state_dict_hook_handle)
        assert (
            num_pre_save_hooks == num_pre_load_hooks
        ), f"Pre-save: {num_pre_save_hooks} pre-load: {num_pre_load_hooks}"
        if num_pre_save_hooks > 0:
            return  # already registered
        modules_with_fsdp_params: Set[nn.Module] = {
            fsdp_param._module_info.module for fsdp_param in self.fsdp_params
        }

        def to_sharded_hook(*args: Any, **kwargs: Any) -> None:
            self._to_sharded()

        for module in modules_with_fsdp_params:
            self._module_to_pre_save_state_dict_hook_handle[
                module
            ] = module.register_state_dict_pre_hook(to_sharded_hook)
            self._module_to_pre_load_state_dict_hook_handle[
                module
            ] = module._register_load_state_dict_pre_hook(to_sharded_hook)

    # Properties #
    @property
    def _reshard_after_forward(self) -> bool:
        return self.post_forward_mesh_info is not None

    @property
    def _use_post_forward_mesh(self) -> bool:
        return (
            self._reshard_after_forward
            and self.mesh_info != self.post_forward_mesh_info
        )

    @property
    def _is_hsdp(self) -> bool:
        return isinstance(self.mesh_info, HSDPMeshInfo)

    @property
    def _all_gather_process_group(self) -> dist.ProcessGroup:
        mesh_info = (
            cast(FSDPMeshInfo, self.post_forward_mesh_info)
            if self.is_sharded_post_forward
            else self.mesh_info
        )
        assert isinstance(mesh_info, FSDPMeshInfo)
        return mesh_info.shard_process_group

    @property
    def _reduce_scatter_process_group(self) -> dist.ProcessGroup:
        assert isinstance(self.mesh_info, FSDPMeshInfo)
        return self.mesh_info.shard_process_group

    @property
    def _all_reduce_process_group(self) -> dist.ProcessGroup:
        assert isinstance(self.mesh_info, HSDPMeshInfo)
        return self.mesh_info.replicate_process_group

    def _with_fqn(self, label: str) -> str:
        if self._module_fqn:
            return f"{label} ({self._module_fqn})"
        return label

    def __repr__(self):
        return f"FSDPParamGroup(fqn={self._module_fqn})"


def _get_param_module_infos(
    params: List[nn.Parameter], modules: Tuple[nn.Module, ...]
) -> List[ParamModuleInfo]:
    """
    Shared parameter: lin1.weight = lin2.weight
    Shared module: mlp.lin1 = mlp.lin2
    We do not remove duplicates when traversing both modules and parameters to
    find shared modules' parameters and shared parameters within a module.
    """
    params_set = set(params)
    param_to_module_info: Dict[nn.Parameter, ParamModuleInfo] = {}
    for module in modules:
        for _, submodule in module.named_modules(remove_duplicate=False):
            for param_name, param in _named_parameters_with_duplicates(
                submodule, recurse=False
            ):
                if param in params_set:
                    if param not in param_to_module_info:
                        param_to_module_info[param] = ParamModuleInfo(
                            submodule, param_name
                        )
                    else:
                        param_to_module_info[param].shared_modules.append(submodule)
                        param_to_module_info[param].shared_param_names.append(
                            param_name
                        )
    if len(param_to_module_info) != len(params):
        raise AssertionError(f"Some parameters are not in the module tree of {module}")
    return [param_to_module_info[param] for param in params]


class RegisterPostBackwardFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, param_group: FSDPParamGroup, *inputs: torch.Tensor):
        # All tensors in `inputs` should require gradient
        ctx.param_group = param_group
        return inputs

    @staticmethod
    def backward(ctx, *grads: torch.Tensor):
        ctx.param_group.post_backward()
        return (None,) + grads
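

# [Note: How post-backward fires]
# `RegisterPostBackwardFunction` is applied to the module inputs that require
# grad in `_register_post_backward_hook`, so its `backward` runs only after
# gradients have propagated back through the wrapped modules, i.e. after the
# unsharded parameter gradients for this group have been computed. Its
# `backward` then calls the group's `post_backward`, which reduce-scatters
# gradients and optionally reshards the parameters.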