1# mypy: allow-untyped-defs 2from typing import Any, List, Optional, Set, Tuple 3 4import torch.nn as nn 5from torch.distributed.tensor.parallel._data_parallel_utils import ( 6 _flatten_tensor, 7 _unflatten_tensor, 8) 9 10 11__all__ = [] # type: ignore[var-annotated] 12 13 14def _get_submodule_n_params(module: nn.Module, path: str): 15 """ 16 Get submodule and the direct path of parameter from the module 17 """ 18 if "." in path: 19 path_list = path.split(".") 20 parent_module_path = ".".join(path_list[:-1]) 21 module = module.get_submodule(parent_module_path) 22 path = path_list[-1] 23 return module, path 24 25 26def _update_module_param(param_list: List[Tuple[nn.Module, str, nn.Parameter]]): 27 """ 28 Update parameters within the module 29 """ 30 for item in param_list: 31 parent_module, module_path, t = item 32 assert hasattr(parent_module, module_path) 33 delattr(parent_module, module_path) 34 setattr(parent_module, module_path, t) 35 36 37def _reconstruct_dtensor(module: nn.Module, _input: Any): 38 """ 39 Recontruct DTensor parameters from local tensors 40 """ 41 param_list = [] 42 # TODO: To add perf optimizations to this iterations 43 for name, t in module.named_parameters(): 44 if hasattr(t, "_st_info"): 45 dtensor = _unflatten_tensor(t, t._st_info) 46 param_list.append((*_get_submodule_n_params(module, name), dtensor)) 47 _update_module_param(param_list) # type: ignore[arg-type] 48 49 50def _localize_dtensor( 51 module: nn.Module, *_: Any, ignored_params: Optional[Set[nn.Parameter]] = None 52): 53 """ 54 Convert DTensor parameters to local tensors 55 """ 56 if ignored_params is None: 57 ignored_params = set() 58 param_list = [] 59 for name, param in module.named_parameters(): 60 if param in ignored_params: 61 continue 62 t, sharding_info = _flatten_tensor(param) 63 if sharding_info is not None: 64 t = nn.Parameter(t) 65 t._st_info = sharding_info # type: ignore[attr-defined] 66 param_list.append((*_get_submodule_n_params(module, name), t)) 67 _update_module_param(param_list) # type: ignore[arg-type] 68 69 70def _pre_dp_module_transform(module: nn.Module): 71 """ 72 Enable the composability between Tensor Parallelism (TP) and Data 73 Parallelism(DP) in PyTorch when using DDP. We need to convert Parameters which 74 are DTensors to local tensors before wrapping with data parallelism API. 75 We then register two hooks, one for converting local tensors back to DTensor 76 preforward and one to convert DTensors back to tensors after Forward. By 77 integrating this way, we avoid any special handling of DTensor parameters by DDP 78 and get DTensor's gradients propagated back to DP, e.g. gradient buckets of DDP. 79 80 For now, this API only works with ``DistributedDataParallel``. It will later support 81 other DP methods such as FSDP. 82 83 Args: 84 module (:class:`nn.Module`): 85 Module which has been applied TP on. 86 87 Example:: 88 >>> # xdoctest: +SKIP("distributed") 89 >>> from torch.distributed.tensor.parallel import parallelize_module, PairwiseParallel 90 >>> from torch.nn.parallel import DistributedDataParallel as DDP 91 >>> from torch.distributed.tensor.parallel.ddp import pre_dp_module_transform 92 >>> 93 >>> # Define the module. 94 >>> m = module(...) 95 >>> parallelize_module(m, PairwiseParallel()) 96 >>> m = pre_dp_module_transform(m) 97 >>> m = DDP(m) 98 >>> 99 """ 100 101 _localize_dtensor(module, None, None) 102 # TODO: To add test cases and ensure that it works for nested modules 103 module.register_forward_pre_hook(_reconstruct_dtensor) 104 module.register_forward_hook(_localize_dtensor) 105