# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import weakref
from typing import Any, cast, Dict, Iterable, List, NoReturn, Optional, Set, Tuple

import torch
import torch.nn as nn
from torch.distributed._composable_state import _State
from torch.nn.parallel import DistributedDataParallel

from .contract import _get_registry, contract


_ROOT_MODULE_PREFIX = ""


class _ReplicateState(_State):
    def __init__(self) -> None:
        super().__init__()
        self.module: nn.Module = nn.ParameterList()
        self.has_initialized: bool = False
        self._param_list: nn.ParameterList = nn.ParameterList()
        # TODO(@fegin): this variable was originally created for testing; we
        # should remove it if possible.
        self._orig_module = self.module
        self._param_names: List[str] = []
        self._no_sync: bool = False
        self._init_args: Optional[Tuple[Any, ...]] = None
        self._init_kwargs: Dict[str, Any] = {}
        self._comm_hook_args: List[Any] = []

    def _collect_params(
        self,
        module: nn.Module,
        ignored_modules: Set[nn.Module],
        ignored_params: Set[nn.Parameter],
        prefix: str = _ROOT_MODULE_PREFIX,
    ) -> None:
        # Skip if managed by the fully_shard API.
        if _is_fully_sharded(module):
            return

        # If a module is ignored, all descendants of the module are ignored.
        if module in ignored_modules:
            return

        recurse_prefix = (
            f"{prefix}." if prefix != _ROOT_MODULE_PREFIX else _ROOT_MODULE_PREFIX
        )

        for n, p in module.named_parameters(recurse=False):
            if p not in ignored_params:
                self._param_list.append(p)
                self._param_names.append(f"{recurse_prefix}{n}")

        for name, child_module in module.named_children():
            self._collect_params(
                child_module,
                ignored_modules,
                ignored_params,
                prefix=f"{recurse_prefix}{name}",
            )

    def lazy_init(self) -> None:
        @torch._disable_dynamo(recursive=True)
        def _lazy_init():
            assert self._init_args is not None
            self.init(*self._init_args, **self._init_kwargs)
            self.register_comm_hook()
            self._init_args = ()
            self._init_kwargs = {}

        _lazy_init()

    def init(
        self,
        module: nn.Module,
        ignored_modules: Set[nn.Module],
        **kwargs,
    ) -> None:
        if self.has_initialized:
            return

        self.has_initialized = True

        device_mesh = kwargs.get("device_mesh", None)
        self.module = module
        ignored_params = {p for m in ignored_modules for p in m.parameters()}
        for submodule in module.modules():
            if _is_fully_sharded(submodule):
                ignored_params.update(submodule.parameters())
        from torch.distributed.tensor.parallel.ddp import _localize_dtensor

        _localize_dtensor(module, ignored_params=ignored_params)
        self._collect_params(module, ignored_modules, ignored_params)

        if "device_id" in kwargs:
            # replicate() supports a small usability enhancement where the
            # user can pass in device_id as a Union[int, torch.device] even for
            # CPU devices so users don't have to change code for CPU/GPU runs.
            # We derive the right device_ids to feed into DDP to support this.
            if kwargs["device_id"] is not None:
                device_id = kwargs["device_id"]
                # Convert to device_ids that DDP expects.
                if isinstance(device_id, torch.device) and device_id.type == "cpu":
                    # CPU modules receive device_ids None
                    kwargs["device_ids"] = None
                else:
                    # GPU modules expect device_ids=[cuda_device]
                    kwargs["device_ids"] = [device_id]
            else:
                kwargs["device_ids"] = None
            kwargs.pop("device_id")

        self._ddp = DistributedDataParallel(self._param_list, **kwargs)
        # Weakref to the DDP instance is currently only used for testing.
        replicate.state(self.module)._ddp_weakref = weakref.ref(self._ddp)

    def register_comm_hook(self) -> None:
        for comm_args, comm_kwargs in self._comm_hook_args:
            self._ddp.register_comm_hook(*comm_args, **comm_kwargs)
        self._comm_hook_args.clear()

    def record_init_args(self, *args, **kwargs) -> None:
        self._init_args = args
        self._init_kwargs = kwargs

    def forward_pre_hook(
        self, module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> Any:
        if self._init_args or self._init_kwargs:
            self.lazy_init()
        self._ddp.require_backward_grad_sync = not self._no_sync
        return self._ddp._pre_forward(*args, **kwargs)

    def forward_post_hook(
        self,
        module: nn.Module,
        input: Tuple[torch.Tensor],
        output: torch.Tensor,
    ) -> torch.Tensor:
        return self._ddp._post_forward(output)


def unimplemented_deepcopy(*args: Any, **kwargs: Any) -> NoReturn:
    raise AssertionError(
        "DDP does not support deepcopy. Please use state dict for serialization."
    )


# Follow the same pattern as FSDP/fully_shard
class DDP:
    def __new__(cls, *args, **kwargs):
        """
        Override ``__new__`` to remove the DDP class and directly construct
        the original class for cases like indexing into a container module.
        """
        # Use index 2 since 0 is the dynamically constructed `DDP<...>` class
        # and index 1 is the `DDP` class itself
        orig_cls = cls.__mro__[2]
        return orig_cls.__new__(orig_cls, *args, **kwargs)

    def set_requires_gradient_sync(self, requires_gradient_sync: bool) -> None:
        """
        Sets if the module should sync gradients. This can be used to implement
        gradient accumulation without communication.

        Args:
            requires_gradient_sync (bool): Whether to reduce gradients for the
                module's parameters.
        """
        replicate.state(self)._no_sync = not requires_gradient_sync

    def register_comm_hook(self, *args, **kwargs) -> None:
        replicate.state(self)._comm_hook_args.append((args, kwargs))


@contract(state_cls=_ReplicateState)
def replicate(
    module: nn.Module,
    ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
    **kwargs,
) -> nn.Module:
    r"""Replicates a module

    Args:
        module (torch.nn.Module): module to replicate

    Example::
        >>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d)
        >>> module = nn.Linear(3, 3)
        >>> replicate(module)
    """
    torch._C._log_api_usage_once("torch.distributed.replicate")

    # TODO(fegin): using kwargs is not a good idea if we would like to make
    # replicate a formal API to replace DDP.
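    # Validate device_id eagerly so that an unsupported type fails here at wrap
    # time rather than during the lazy DDP construction on the first forward.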
198 if "device_id" in kwargs: 199 if not isinstance(kwargs["device_id"], (int, torch.device)): 200 raise RuntimeError( 201 "Expected device_id to be int or torch.device, " 202 f"but got {type(kwargs['device_id'])}" 203 ) 204 205 if _is_fully_sharded(module): 206 raise RuntimeError( 207 "Cannot apply `replicate()` on a Module already managed by `fully_shard`" 208 ) 209 210 if ignored_modules is None: 211 ignored_modules = {} 212 else: 213 ignored_modules = set(ignored_modules) 214 215 state = cast(_ReplicateState, replicate.state(module)) 216 module.register_forward_pre_hook(state.forward_pre_hook, with_kwargs=True) 217 device_mesh = kwargs.get("device_mesh", None) 218 if device_mesh is not None: 219 from torch.distributed.device_mesh import _mesh_resources 220 221 root_mesh = _mesh_resources.get_root_mesh(device_mesh) 222 # if a root mesh is not the same as device_mesh, 223 # meaning the device_mesh is sliced out from the root mesh. 224 if root_mesh != device_mesh: 225 # TODO: This is a temporary work around to enable DDP + TP. 226 # We should do the logic in DDP so that the 2D implementation is 227 # sound and the state_dict works out of the box. 228 # 229 # This won't conflict with what is done in DDP class as the module 230 # replicate is going to pass is NOT the original module. 231 from torch.distributed.tensor.parallel.ddp import ( 232 _localize_dtensor, 233 _reconstruct_dtensor, 234 ) 235 236 module.register_forward_pre_hook(_reconstruct_dtensor) 237 module.register_forward_hook(_localize_dtensor) 238 239 module.register_forward_hook(state.forward_post_hook) # type: ignore[arg-type] 240 241 state.record_init_args(module, ignored_modules, **kwargs) 242 243 # Place DDP leftmost for highest priority in the method resolution order 244 cls = module.__class__ 245 dct = {"__deepcopy__": unimplemented_deepcopy} 246 new_cls = type(f"DDP{cls.__name__}", (DDP, cls), dct) 247 module.__class__ = new_cls 248 return module 249 250 251def _is_fully_sharded(module: nn.Module) -> bool: 252 r"""Check if module is marked with fully_shard.""" 253 registry = _get_registry(module) 254 if registry is None: 255 return False 256 return "fully_shard" in registry 257