# mypy: allow-untyped-defs
from typing import Any, List, Optional, Set, Tuple

import torch.nn as nn
from torch.distributed.tensor.parallel._data_parallel_utils import (
    _flatten_tensor,
    _unflatten_tensor,
)


__all__ = []  # type: ignore[var-annotated]


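# Illustrative note for the helper below (the module and parameter names here
# are hypothetical, not part of this file): for a model with a submodule
# ``net`` that owns a parameter ``weight``, the dotted name "net.weight"
# resolves to ``(model.net, "weight")``, while a top-level name such as "bias"
# resolves to ``(model, "bias")`` unchanged.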
def _get_submodule_n_params(module: nn.Module, path: str):
    """
    Get the parent submodule and the direct attribute name of a parameter from
    its dotted path within ``module``.
    """
    if "." in path:
        path_list = path.split(".")
        parent_module_path = ".".join(path_list[:-1])
        module = module.get_submodule(parent_module_path)
        path = path_list[-1]
    return module, path


def _update_module_param(param_list: List[Tuple[nn.Module, str, nn.Parameter]]):
    """
    Update the listed parameters on their parent modules.
    """
    for item in param_list:
        parent_module, module_path, t = item
        assert hasattr(parent_module, module_path)
        # Delete the existing attribute first so that setattr can attach the
        # replacement; nn.Module refuses to overwrite a registered nn.Parameter
        # with a plain tensor otherwise.
        delattr(parent_module, module_path)
        setattr(parent_module, module_path, t)


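# Hook cycle sketch (a description of how the two helpers below are used by
# _pre_dp_module_transform, not additional behavior):
#   pre-forward:  _reconstruct_dtensor rebuilds every parameter that carries
#                 ``_st_info`` back into a DTensor,
#   post-forward: _localize_dtensor flattens DTensor parameters back into plain
#                 local tensors, so DDP itself only ever sees regular tensors.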
def _reconstruct_dtensor(module: nn.Module, _input: Any):
    """
    Reconstruct DTensor parameters from their flattened local tensors.
    """
    param_list = []
    # TODO: add perf optimizations to this iteration
    for name, t in module.named_parameters():
        if hasattr(t, "_st_info"):
            dtensor = _unflatten_tensor(t, t._st_info)
            param_list.append((*_get_submodule_n_params(module, name), dtensor))
    _update_module_param(param_list)  # type: ignore[arg-type]


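# Note: _localize_dtensor below stamps the sharding metadata returned by
# _flatten_tensor onto the replacement nn.Parameter as ``_st_info``; the
# presence of that attribute is exactly what _reconstruct_dtensor above checks
# when deciding which parameters to rebuild into DTensors.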
def _localize_dtensor(
    module: nn.Module, *_: Any, ignored_params: Optional[Set[nn.Parameter]] = None
):
    """
    Convert DTensor parameters to local tensors.
    """
    # ``*_`` absorbs the extra positional arguments (the forward hook's input
    # and output) so this function can also be registered directly as a forward
    # hook in _pre_dp_module_transform.
    if ignored_params is None:
        ignored_params = set()
    param_list = []
    for name, param in module.named_parameters():
        if param in ignored_params:
            continue
        t, sharding_info = _flatten_tensor(param)
        if sharding_info is not None:
            t = nn.Parameter(t)
            t._st_info = sharding_info  # type: ignore[attr-defined]
            param_list.append((*_get_submodule_n_params(module, name), t))
    _update_module_param(param_list)  # type: ignore[arg-type]


def _pre_dp_module_transform(module: nn.Module):
    """
    Enable composability between Tensor Parallelism (TP) and Data Parallelism
    (DP) in PyTorch when using DDP. Parameters which are DTensors need to be
    converted to local tensors before wrapping with the data parallelism API.
    Two hooks are then registered: one that converts local tensors back to
    DTensors before the forward pass, and one that converts DTensors back to
    local tensors after the forward pass. By integrating this way, we avoid any
    special handling of DTensor parameters by DDP and get DTensor's gradients
    propagated back to DP, e.g. into DDP's gradient buckets.

    For now, this API only works with ``DistributedDataParallel``. It will later
    support other DP methods such as FSDP.

    Args:
        module (:class:`nn.Module`):
            Module on which TP has been applied.

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> from torch.distributed.tensor.parallel import parallelize_module, PairwiseParallel
        >>> from torch.nn.parallel import DistributedDataParallel as DDP
        >>> from torch.distributed.tensor.parallel.ddp import _pre_dp_module_transform
        >>>
        >>> # Define the module.
        >>> m = module(...)
        >>> parallelize_module(m, PairwiseParallel())
        >>> _pre_dp_module_transform(m)  # transforms ``m`` in place
        >>> m = DDP(m)
        >>>
    """

    # The extra ``None, None`` arguments are absorbed by ``*_`` and unused; the
    # same function is registered below as a forward hook, where it receives
    # (module, input, output).
    _localize_dtensor(module, None, None)
    # TODO: add test cases and ensure that it works for nested modules
    module.register_forward_pre_hook(_reconstruct_dtensor)
    module.register_forward_hook(_localize_dtensor)
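# Minimal training-step sketch (illustrative only; ``model`` is assumed to have
# gone through parallelize_module and _pre_dp_module_transform before being
# wrapped in DDP, and ``data``, ``target``, ``loss_fn`` and ``optimizer`` are
# hypothetical placeholders):
#
#     out = model(data)                # pre-forward hook rebuilds DTensor params;
#                                      # forward hook localizes them again afterwards
#     loss_fn(out, target).backward()  # DDP only ever handles local tensors/grads
#     optimizer.step()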