# mypy: allow-untyped-defs
from contextlib import contextmanager
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import distributed_c10d
from torch.distributed._shard.sharded_tensor import ShardedTensor

from .sharder import Sharder
from .sharding_plan import ShardingPlan
from .sharding_spec import ChunkShardingSpec, ShardingSpec


def _shard_tensor(
    tensor: torch.Tensor, sharding_spec: ShardingSpec, src_rank=0, process_group=None
) -> ShardedTensor:
    """
    Given a :class:`torch.Tensor`, shards that tensor according to the provided
    ``sharding_spec``. ``src_rank`` denotes the source rank whose data is used as
    the ground truth and scattered as shards across the rest of the ranks.

    Args:
        tensor (:class:`torch.Tensor`): Tensor that needs to be sharded.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the tensor that would be sharded and scattered
            across the rest of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    Returns:
        A :class:`ShardedTensor` sharded from the given tensor.

    .. warning::
        Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is
        currently supported as the ``sharding_spec``.
    """
    if not tensor.is_contiguous():
        raise ValueError("input tensor is not a contiguous Tensor")

    pg = (
        process_group
        if process_group is not None
        else distributed_c10d._get_default_group()
    )
    world_size = dist.get_world_size(pg)
    current_rank = dist.get_rank(pg)

    # Validate that src_rank and sharding_spec are the same across all ranks.
    gathered_list = [None] * world_size
    dist.all_gather_object(gathered_list, (src_rank, sharding_spec), group=pg)

    for idx, entry in enumerate(gathered_list):
        if src_rank != entry[0]:  # type: ignore[index]
            raise ValueError(
                f"src_rank={src_rank} on rank: {current_rank} does not "  # type: ignore[index]
                f"match with src_rank={entry[0]} on rank: {idx}"  # type: ignore[index]
            )
        if sharding_spec != entry[1]:  # type: ignore[index]
            raise ValueError(
                f"sharding_spec={sharding_spec} on rank: {current_rank} does not "  # type: ignore[index]
                f"match with sharding_spec={entry[1]} on rank: {idx}"  # type: ignore[index]
            )

    st = sharding_spec.shard(tensor, src_rank=src_rank, process_group=pg)

    return st
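

# Example sketch: one way ``_shard_tensor`` could be used on two ranks, assuming the
# default process group is already initialized and each rank owns one CUDA device.
# The tensor shape and placement strings below are illustrative only.
def _example_shard_tensor_usage() -> ShardedTensor:
    # Chunk dim 0 of a [4, 16] tensor across two ranks; rank 0's data is the ground
    # truth that gets scattered, so each rank ends up holding a [2, 16] local shard.
    spec = ChunkShardingSpec(
        dim=0,
        placements=["rank:0/cuda:0", "rank:1/cuda:1"],
    )
    return _shard_tensor(torch.rand(4, 16), spec, src_rank=0)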


def shard_parameter(
    module: torch.nn.Module,
    param_name: str,
    sharding_spec: ShardingSpec,
    src_rank=0,
    process_group=None,
):
    """
    Given a :class:`torch.nn.Module` and a ``param_name`` for a parameter in that
    module, shards that parameter according to the provided ``sharding_spec``.
    ``src_rank`` denotes the source rank whose data is used as the ground truth
    and scattered as shards across the rest of the ranks.

    This method replaces ``module.param_name`` with a
    :class:`torch.distributed._shard.sharded_tensor.ShardedTensor`.

    Args:
        module (:class:`torch.nn.Module`): Module whose parameter needs to be sharded.
        param_name (str): Name of the parameter of ``module`` that needs to be sharded.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the parameter that would be sharded and scattered
            across the rest of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    .. warning::
        Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is
        currently supported as the ``sharding_spec``.
    """
    # Perform some validation first.
    if not hasattr(module, param_name):
        raise AttributeError(f"{module._get_name()} has no attribute `{param_name}`")

    tensor = getattr(module, param_name)
    if not isinstance(tensor, torch.Tensor):
        raise ValueError(
            f"Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}"
        )

    if not tensor.is_contiguous():
        raise ValueError(f"param: {param_name} is not a contiguous Tensor")

    st = _shard_tensor(tensor, sharding_spec, src_rank, process_group)

    # Replace the param with a ShardedTensor.
    module.register_parameter(param_name, nn.Parameter(st))
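

# Example sketch: sharding a single parameter of a hypothetical ``nn.Linear`` module
# across two ranks, assuming the default process group is already initialized. The
# layer sizes and placement strings are illustrative only.
def _example_shard_linear_weight() -> torch.nn.Linear:
    linear = torch.nn.Linear(16, 32)
    spec = ChunkShardingSpec(
        dim=0,
        placements=["rank:0/cuda:0", "rank:1/cuda:1"],
    )
    # After this call ``linear.weight`` is an ``nn.Parameter`` wrapping a ShardedTensor:
    # rank 0's [32, 16] weight is scattered as two [16, 16] shards.
    shard_parameter(linear, "weight", spec)
    return linear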


# Tracks the current process group in the load context manager.
_CURRENT_PROCESS_GROUP: Optional[dist.ProcessGroup] = None


@contextmanager
def load_with_process_group(process_group):
    """
    Context manager to set the process group with which to load a ShardedTensor.
    """
    global _CURRENT_PROCESS_GROUP
    if _CURRENT_PROCESS_GROUP is not None:
        raise RuntimeError(
            'ProcessGroup already set by previous "load_with_process_group" '
            "context manager"
        )
    _CURRENT_PROCESS_GROUP = process_group
    try:
        yield process_group
    finally:
        _CURRENT_PROCESS_GROUP = None


def _get_current_process_group():
    """
    Retrieves the current process group set by ``load_with_process_group``.
    If not set, it returns the default group.
    """
    global _CURRENT_PROCESS_GROUP
    if _CURRENT_PROCESS_GROUP is None:
        return distributed_c10d._get_default_group()
    else:
        return _CURRENT_PROCESS_GROUP


def _reshard_output(
    module: torch.nn.Module, resharding_spec: ShardingSpec
) -> torch.nn.Module:
    """
    Hook a module with output resharding in the forward pass according
    to the given ``resharding_spec``.

    Args:
        module (:class:`torch.nn.Module`): Module whose output needs to be resharded.
        resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
            The specification describing how the output of the module will be resharded.

    Returns:
        A :class:`torch.nn.Module` object with the reshard API hooked.
    """

    def hook_func(_module, _input, output):
        if isinstance(output, ShardedTensor):
            return output.reshard(resharding_spec)
        return output

    module.register_forward_hook(hook_func)
    return module


def _collect_local_shard(module: torch.nn.Module) -> torch.nn.Module:
    """
    Hook a module with local shard collection in the forward pass.

    This API is typically used to convert a sharded representation back to a data parallel
    representation. In particular, it returns the local tensor for this Shard. If the
    size along the sharding dimension for the local tensor is 1, this dimension is removed
    from the final result. For example, a [4, 16] ShardedTensor across 4 ranks is typically
    a local Tensor of size [16] on each rank, not [1, 16] on each rank.

    Args:
        module (:class:`torch.nn.Module`): Module whose output is a ShardedTensor and whose
            local tensor value needs to be returned.

    Returns:
        A :class:`torch.nn.Module` object with the collection API hooked.
    """

    def hook_func(_module, _input, output):
        if isinstance(output, ShardedTensor):
            local_tensor = output.local_tensor()
            # Squeeze the number of dimensions manually; only applicable to ChunkShardingSpec.
            sharding_spec = output._sharding_spec
            if (
                isinstance(sharding_spec, ChunkShardingSpec)
                and local_tensor.size(sharding_spec.dim) == 1  # type: ignore[attr-defined, arg-type]
            ):
                local_tensor = local_tensor.squeeze(
                    output._sharding_spec.dim  # type: ignore[attr-defined]
                )
            return local_tensor

    module.register_forward_hook(hook_func)
    return module
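

# Example sketch: chaining the two hooks above so a module that produces a ShardedTensor
# returns a plain rank-local Tensor to its caller. The resharding spec below is
# illustrative and assumes two ranks with one CUDA device each.
def _example_local_output(module: torch.nn.Module) -> torch.nn.Module:
    reshard_spec = ChunkShardingSpec(
        dim=0,
        placements=["rank:0/cuda:0", "rank:1/cuda:1"],
    )
    # First reshard the output over dim 0, then unwrap it into the local shard.
    return _collect_local_shard(_reshard_output(module, reshard_spec))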


def shard_module(module: nn.Module, plan: ShardingPlan, src_rank=0, process_group=None):
    """
    Shards a given module according to the provided sharding ``plan``. This method
    first shards all the parameters according to the given sharding ``plan``. Then, if
    ``output_plan`` and ``return_local_tensor`` are specified in the sharding ``plan``, it
    will reshard the outputs of the listed modules according to ``output_plan`` and convert
    their outputs back to data parallel according to ``return_local_tensor``.

    Needs to be called on all ranks in an SPMD fashion.

    Args:
        module (:class:`torch.nn.Module`): The module to apply sharding to.
        plan (:class:`torch.distributed._shard.sharding_plan.ShardingPlan`):
            The ShardingPlan that maps each parameter name to the ShardingSpec
            to apply to that parameter.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the module that would be sharded and scattered across the rest
            of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
    """
    # Record Sharder paths for a sanity check on the plan, to ensure that items in the
    # plan do not conflict with the submodule tree that the Sharder is working with.
    sharder_paths = []
    for name, spec in plan.plan.items():
        if isinstance(spec, Sharder):
            sharder_paths.append(name)

    # shard the parameters according to the ShardingPlan
    for name, spec in plan.plan.items():
        if isinstance(spec, ShardingSpec):
            # if we found a sharding spec, try to shard the parameter
            module_path, _, param_name = name.rpartition(".")

            for sharder_path in sharder_paths:
                if module_path.startswith(sharder_path):
                    raise RuntimeError(
                        f"ShardingPlan is invalid, trying to shard a parameter: {name},"
                        f" but there's already a Sharder entry for module {sharder_path},"
                        f" parameter sharding should not conflict with the submodule tree"
                        f" that a Sharder is working with!"
                    )

            mod = module.get_submodule(module_path)
            shard_parameter(
                mod, param_name, spec, src_rank=src_rank, process_group=process_group
            )
        elif isinstance(spec, Sharder):
            parent_mod_path, _, mod_name = name.rpartition(".")
            if name == "":
                raise KeyError("Module path must not be empty for custom sharder!")
            mod = module.get_submodule(name)
            parent_mod = module.get_submodule(parent_mod_path)
            sharded_mod = spec.shard(mod)
            # swap this submodule with the sharded module
            setattr(parent_mod, mod_name, sharded_mod)
        else:
            raise TypeError(
                f"Only `ShardingSpec` and `Sharder` are supported to shard '{name}'"
            )

    # reshard the output if there's an entry in `output_plan` for this module
    if plan.output_plan is not None:
        for module_path, output_spec in plan.output_plan.items():
            if isinstance(output_spec, ShardingSpec):
                mod = module.get_submodule(module_path)
                _reshard_output(mod, output_spec)
            else:
                raise TypeError(
                    f"Only `ShardingSpec` is supported as output_plan for '{module_path}'"
                )
    # Convert the output back to data parallel for the modules that appear in
    # `return_local_tensor` of the plan; we call `_collect_local_shard`
    # to collect the local tensor for the output of those modules.
    if plan.return_local_tensor is not None:
        for module_path in plan.return_local_tensor:
            mod = module.get_submodule(module_path)
            _collect_local_shard(mod)
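

# Example sketch: sharding a small model with ``shard_module``, assuming two ranks with
# one CUDA device each and an initialized default process group. The toy module, the
# parameter name ("fc.weight"), and the placements are illustrative only.
def _example_shard_module() -> nn.Module:
    class _ToyModel(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.fc = nn.Linear(16, 32)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.fc(x)

    spec = ChunkShardingSpec(
        dim=0,
        placements=["rank:0/cuda:0", "rank:1/cuda:1"],
    )
    plan = ShardingPlan(
        # Shard ``fc.weight`` along dim 0 of its [32, 16] weight.
        plan={"fc.weight": spec},
        # Reshard the root module's output ("" refers to the root module itself) and
        # return the rank-local tensor, converting the output back to data parallel.
        output_plan={"": spec},
        return_local_tensor=[""],
    )
    model = _ToyModel()
    shard_module(model, plan)
    return model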