import itertools
from typing import List, Optional, Set, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import _get_device_handle
from torch.distributed.tensor import DeviceMesh, DTensor, init_device_mesh
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

from ._fsdp_common import _is_composable_with_fsdp, FSDPMeshInfo, HSDPMeshInfo
from ._fsdp_state import _get_module_fsdp_state


def _get_post_forward_mesh_info(
    reshard_after_forward: Union[bool, int], mesh_info: FSDPMeshInfo
) -> Optional[FSDPMeshInfo]:
    shard_mesh_size = mesh_info.shard_mesh_size
    if not isinstance(reshard_after_forward, (bool, int)):
        raise ValueError(
            "reshard_after_forward should be a bool or an int representing the "
            f"group size to reshard to, not {reshard_after_forward}"
        )
    # NOTE: `isinstance(False, int)` returns `True`.
    if not isinstance(reshard_after_forward, bool) and isinstance(
        reshard_after_forward, int
    ):
        if (
            reshard_after_forward < 1
            or reshard_after_forward > shard_mesh_size
            or shard_mesh_size % reshard_after_forward != 0
        ):
            raise ValueError(
                "If passing reshard_after_forward as an int, it should be a "
                f"factor of {shard_mesh_size}, not {reshard_after_forward}"
            )
        elif reshard_after_forward == 1:
            reshard_after_forward = False
        elif reshard_after_forward == shard_mesh_size:
            reshard_after_forward = True
    post_forward_mesh_info = None
    if reshard_after_forward is True:
        post_forward_mesh_info = mesh_info
    elif reshard_after_forward is not False:  # int case
        # For HSDP, we can flatten the two replicate dims into the 0th dim
        post_forward_mesh_tensor = mesh_info.mesh.mesh.view(-1, reshard_after_forward)
        post_forward_mesh = DeviceMesh(
            mesh_info.mesh.device_type, post_forward_mesh_tensor
        )
        post_forward_mesh_info = HSDPMeshInfo(
            post_forward_mesh, shard_mesh_dim=1, replicate_mesh_dim=0
        )
    return post_forward_mesh_info


def _init_default_fully_shard_mesh() -> DeviceMesh:
    """Default to the global CUDA mesh if possible, else the global CPU mesh."""
    if not dist.distributed_c10d.is_initialized():
        dist.distributed_c10d.init_process_group()
    default_pg = dist.distributed_c10d._get_default_group()
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    mesh = init_device_mesh(device_type, mesh_shape=(default_pg.size(),))
    return mesh


def _get_device_from_mesh(mesh: DeviceMesh) -> torch.device:
    if mesh.device_type == "cpu":
        return torch.device("cpu")
    device_handle = _get_device_handle(mesh.device_type)
    return torch.device(mesh.device_type, device_handle.current_device())


def _get_managed_modules(root_modules: Tuple[nn.Module, ...]) -> List[nn.Module]:
    modules: List[nn.Module] = []
    root_modules_set = set(root_modules)
    # Track visited modules to avoid visiting shared modules multiple times
    visited_modules: Set[nn.Module] = set()

    def dfs(module: nn.Module) -> None:
        """
        Runs a DFS to collect managed modules, not recursing into modules with
        a non-composable API or ``fully_shard`` already applied.
        """
        if not _is_composable_with_fsdp(module):
            return
        elif (
            module not in root_modules_set
            and _get_module_fsdp_state(module) is not None
        ):
            return  # nested `fully_shard` module
        visited_modules.add(module)
        for submodule in module.children():
            if submodule not in visited_modules:
                dfs(submodule)
        modules.append(module)

    for root_module in root_modules:
        dfs(root_module)
    return modules


def _verify_managed_param(name: str, param: nn.Parameter) -> None:
    """
    Verify that the parameter is accepted by ``fully_shard``. The only current
    restriction is that the parameter cannot be a scalar tensor (i.e. a 0-dim
    tensor) since we need at least one dim to shard.
    """
    if len(param.shape) == 0:
        raise ValueError(
            "fully_shard doesn't support scalar parameters. "
            f"Change {name} to a 1D tensor with numel equal to 1."
        )


def _get_managed_states(
    modules: List[nn.Module],
) -> Tuple[List[nn.Parameter], List[torch.Tensor]]:
    params: List[nn.Parameter] = []
    buffers: List[torch.Tensor] = []
    # Track visited parameters/buffers to avoid visiting shared parameters and
    # buffers multiple times
    visited_params: Set[nn.Parameter] = set()
    visited_buffers: Set[torch.Tensor] = set()
    for module in modules:
        for name, param in module.named_parameters(recurse=False):
            if param not in visited_params:
                _verify_managed_param(name, param)
                params.append(param)
                visited_params.add(param)
        for buffer in module.buffers(recurse=False):
            if buffer not in visited_buffers:
                buffers.append(buffer)
                visited_buffers.add(buffer)
    return params, buffers


def _move_states_to_device(
    params: List[nn.Parameter],
    buffers: List[torch.Tensor],
    device: torch.device,
) -> None:
    """
    We have FSDP move states to device for simpler and faster initialization
    since FSDP almost always uses CUDA for training. We move parameters/buffers
    rather than modules to support ignoring parameters/buffers in the future.
    """
    # Follow the logic in `nn.Module._apply`
    for tensor in itertools.chain(params, buffers):
        if tensor.device == device or tensor.device.type == "meta":
            # Keep meta-device tensors on meta device for deferred init
            continue
        if isinstance(tensor, DTensor):
            if (dtensor_mesh_type := tensor.device_mesh.device_type) != device.type:
                raise ValueError(
                    "Requires DTensor to have mesh of the same type as the FSDP mesh "
                    f"but got {dtensor_mesh_type} for DTensor and {device.type} for FSDP"
                )
            raise AssertionError(
                f"Expects DTensor to be moved to {dtensor_mesh_type} but got {tensor.device}"
            )
        tensor_ = tensor
        if is_traceable_wrapper_subclass(tensor_):
            with torch.no_grad():  # avoid autograd increasing C++ refcount by 1
                tensor_on_device = nn.Parameter(tensor.to(device))
            torch.utils.swap_tensors(tensor, tensor_on_device)
        else:
            tensor.data = tensor.to(device)