# mypy: allow-untyped-defs
from contextlib import contextmanager
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import distributed_c10d
from torch.distributed._shard.sharded_tensor import ShardedTensor

from .sharder import Sharder
from .sharding_plan import ShardingPlan
from .sharding_spec import ChunkShardingSpec, ShardingSpec


def _shard_tensor(
    tensor: torch.Tensor, sharding_spec: ShardingSpec, src_rank=0, process_group=None
) -> ShardedTensor:
19    """
20    Given a :class:`torch.Tensor`, it shards that tensor according to the provided
21    ``sharding_spec``. ``src_rank`` denotes the source rank which would be
22    used as the ground truth of the data which would be scattered as shards
23    across the rest of the ranks.
24
25    Args:
26        tensor (:class:`torch.Tensor`): Tensor needs to be sharded.
27        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
28            describing how to shard the Tensor.
29
30    Keyword args:
31        src_rank (int, optional): The source rank which is used as the ground truth of
32            the data for the parameter that would be sharded and scattered
33            across the rest of the ranks.
34            Default: 0.
35        process_group (ProcessGroup, optional): The process group to work on. If None,
36            the default process group will be used.
37
38    Returns:
39        A :class:`ShardedTensor` sharded from the given tensor.
40
41    .. warning::
42        Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is
43        currently supported as the ``sharding_spec``.
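
    Example::
        >>> # xdoctest: +SKIP
        >>> # A minimal sketch, assuming the default process group is already
        >>> # initialized across two ranks with one GPU each; the placements
        >>> # below are purely illustrative.
        >>> from torch.distributed._shard.sharding_spec import ChunkShardingSpec
        >>> spec = ChunkShardingSpec(
        ...     dim=0,
        ...     placements=["rank:0/cuda:0", "rank:1/cuda:1"],
        ... )
        >>> st = _shard_tensor(torch.rand(8, 4), spec)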
44    """
45    if not tensor.is_contiguous():
46        raise ValueError("input tensor is not a contiguous Tensor")
47
48    pg = (
49        process_group
50        if process_group is not None
51        else distributed_c10d._get_default_group()
52    )
53    world_size = dist.get_world_size(pg)
54    current_rank = dist.get_rank(pg)
55
    # Validate that src_rank and sharding_spec are the same across all ranks.
    gathered_list = [None] * world_size
    dist.all_gather_object(gathered_list, (src_rank, sharding_spec), group=pg)

    for idx, entry in enumerate(gathered_list):
        if src_rank != entry[0]:  # type: ignore[index]
            raise ValueError(
                f"src_rank={src_rank} on rank: {current_rank} does not "  # type: ignore[index]
                f"match with src_rank={entry[0]} on rank: {idx}"  # type: ignore[index]
            )
        if sharding_spec != entry[1]:  # type: ignore[index]
            raise ValueError(
                f"sharding_spec={sharding_spec} on rank: {current_rank} does not "  # type: ignore[index]
                f"match with sharding_spec={entry[1]} on rank: {idx}"  # type: ignore[index]
            )

    st = sharding_spec.shard(tensor, src_rank=src_rank, process_group=pg)

    return st


def shard_parameter(
    module: torch.nn.Module,
    param_name: str,
    sharding_spec: ShardingSpec,
    src_rank=0,
    process_group=None,
):
84    """
85    Given a :class:`torch.nn.Module`, a ``param_name`` for a parameter in that
86    module, it shards that parameter according to the provided
87    ``sharding_spec``. ``src_rank`` denotes the source rank which would be
88    used as the ground truth of the data which would be scattered as shards
89    across the rest of the ranks.
90
91    This method replaces ``module.param_name`` with a
92    :class:`torch.distributed._sharded_tensor.ShardedTensor`
93
94    Args:
95        module (:class:`torch.nn.Module`): Module whose parameter needs to be sharded.
96        param_name (str): Name of the parameter of ``module`` that needs to be sharded.
97        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
98            describing how to shard the Tensor.
99
100    Keyword args:
101        src_rank (int, optional): The source rank which is used as the ground truth of
102            the data for the parameter that would be sharded and scattered
103            across the rest of the ranks.
104            Default: 0.
105        process_group (ProcessGroup, optional): The process group to work on. If None,
106            the default process group will be used.
107
108    .. warning::
109        Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is
110        currently supported as the ``sharding_spec``.
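
    Example::
        >>> # xdoctest: +SKIP
        >>> # A minimal sketch, assuming torch.distributed is already initialized
        >>> # across two ranks with one GPU each; ``fc`` is an illustrative module
        >>> # constructed identically on every rank.
        >>> from torch.distributed._shard.sharding_spec import ChunkShardingSpec
        >>> spec = ChunkShardingSpec(
        ...     dim=0,
        ...     placements=["rank:0/cuda:0", "rank:1/cuda:1"],
        ... )
        >>> fc = torch.nn.Linear(16, 16).cuda()
        >>> shard_parameter(fc, "weight", spec)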
111    """
112    # Perform some validation first.
113    if not hasattr(module, param_name):
114        raise AttributeError(f"{module._get_name()} has no attribute `{param_name}`")
115
116    tensor = getattr(module, param_name)
117    if not isinstance(tensor, torch.Tensor):
118        raise ValueError(
119            f"Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}"
120        )
121
122    if not tensor.is_contiguous():
123        raise ValueError(f"param: {param_name} is not a contiguous Tensor")
124
125    st = _shard_tensor(tensor, sharding_spec, src_rank, process_group)
126
127    # Replace param with ShardedTensor.
128    module.register_parameter(param_name, nn.Parameter(st))
129
130
# Tracks the current process group in the load context manager.
_CURRENT_PROCESS_GROUP: Optional[dist.ProcessGroup] = None


@contextmanager
def load_with_process_group(process_group):
    """
    Context manager to set the process group with which to load a ShardedTensor.
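
    Example::
        >>> # xdoctest: +SKIP
        >>> # A hypothetical sketch: ``pg`` is a process group created elsewhere
        >>> # (e.g. via dist.new_group) and "checkpoint.pt" is assumed to contain
        >>> # ShardedTensor state.
        >>> with load_with_process_group(pg):
        ...     state_dict = torch.load("checkpoint.pt")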
139    """
140    global _CURRENT_PROCESS_GROUP
141    if _CURRENT_PROCESS_GROUP is not None:
142        raise RuntimeError(
143            'ProcessGroup already set by previous "load_with_process_group" '
144            "context manager"
145        )
146    _CURRENT_PROCESS_GROUP = process_group
147    try:
148        yield process_group
149    finally:
150        _CURRENT_PROCESS_GROUP = None
151
152
153def _get_current_process_group():
154    """
155    Retrieves the current process group set by ``load_with_process_group``.
156    If not set, it just returns the default group.
157    """
158    global _CURRENT_PROCESS_GROUP
159    if _CURRENT_PROCESS_GROUP is None:
160        return distributed_c10d._get_default_group()
161    else:
162        return _CURRENT_PROCESS_GROUP
163
164
def _reshard_output(
    module: torch.nn.Module, resharding_spec: ShardingSpec
) -> torch.nn.Module:
    """
    Hook a module with output resharding in the forward pass according
    to the given ``resharding_spec``.

    Args:
        module (:class:`torch.nn.Module`): Module whose output needs to be resharded.
        resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
            The specification describing how the output of the module will be resharded.

    Returns:
        A :class:`torch.nn.Module` object with reshard API hooked.
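
    Example::
        >>> # xdoctest: +SKIP
        >>> # A minimal sketch: ``sharded_module`` is assumed to produce a
        >>> # ShardedTensor, which will be resharded along dim 0.
        >>> from torch.distributed._shard.sharding_spec import ChunkShardingSpec
        >>> reshard_spec = ChunkShardingSpec(
        ...     dim=0,
        ...     placements=["rank:0/cuda:0", "rank:1/cuda:1"],
        ... )
        >>> sharded_module = _reshard_output(sharded_module, reshard_spec)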
179    """
180
181    def hook_func(_module, _input, output):
182        if isinstance(output, ShardedTensor):
183            return output.reshard(resharding_spec)
184        return output
185
186    module.register_forward_hook(hook_func)
187    return module
188
189
def _collect_local_shard(module: torch.nn.Module) -> torch.nn.Module:
    """
    Hook a module with local shards collection in the forward pass.

    This API is typically used to convert a sharded representation back to a data
    parallel representation. In particular, it returns the local tensor for this rank's
    shard. If the size along the sharding dimension for the local tensor is 1, this
    dimension is removed from the final result. For example, a [4, 16] ShardedTensor
    across 4 ranks typically yields a local Tensor of size [16] on each rank, not [1, 16].

    Args:
        module (:class:`torch.nn.Module`): Module whose output is a ShardedTensor and
            whose local tensor value needs to be returned.

    Returns:
        A :class:`torch.nn.Module` object with collection API hooked.
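
    Example::
        >>> # xdoctest: +SKIP
        >>> # A minimal sketch: ``sharded_module`` is assumed to produce a
        >>> # ShardedTensor and ``inp`` is an illustrative input; after hooking,
        >>> # the forward pass returns the local shard.
        >>> sharded_module = _collect_local_shard(sharded_module)
        >>> local_out = sharded_module(inp)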
206    """
207
208    def hook_func(_module, _input, output):
209        if isinstance(output, ShardedTensor):
210            local_tensor = output.local_tensor()
211            # Squeeze the # of dimensions manually, only applicable to ChunkShardingSpec
212            sharding_spec = output._sharding_spec
213            if (
214                isinstance(sharding_spec, ChunkShardingSpec)
215                and local_tensor.size(sharding_spec.dim) == 1  # type: ignore[attr-defined, arg-type]
216            ):
217                local_tensor = local_tensor.squeeze(
218                    output._sharding_spec.dim  # type: ignore[attr-defined]
219                )
220            return local_tensor
221
222    module.register_forward_hook(hook_func)
223    return module
224
225
def shard_module(module: nn.Module, plan: ShardingPlan, src_rank=0, process_group=None):
    """
    Shards a given module according to the provided sharding `plan`. This method
    first shards all the parameters according to the given sharding `plan`. Then, if
    `output_plan` and `return_local_tensor` are specified in the sharding `plan`, it
    will tag the output of modules according to `output_plan` and convert the module's
    output back to data parallel according to `return_local_tensor`.

    Needs to be called on all ranks in an SPMD fashion.

    Args:
        module (:class:`torch.nn.Module`): The module to apply sharding to.
        plan (:class:`torch.distributed._shard.sharding_plan.ShardingPlan`):
            The ShardingPlan that specifies the param name to ShardingSpec mapping to
            apply to each parameter.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the module that would be sharded and scattered across the rest
            of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
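
    Example::
        >>> # xdoctest: +SKIP
        >>> # A minimal sketch, assuming two initialized ranks; ``my_module`` is an
        >>> # illustrative module with an ``fc1`` submodule whose weight is sharded.
        >>> from torch.distributed._shard.sharding_plan import ShardingPlan
        >>> from torch.distributed._shard.sharding_spec import ChunkShardingSpec
        >>> spec = ChunkShardingSpec(
        ...     dim=0,
        ...     placements=["rank:0/cuda:0", "rank:1/cuda:1"],
        ... )
        >>> plan = ShardingPlan(plan={"fc1.weight": spec})
        >>> shard_module(my_module, plan)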
249    """
250    # record Sharder paths for sanity check on the plan to ensure items in the plan
251    # does not conflict with the submodule tree that the Sharder is working with
252    sharder_paths = []
253    for name, spec in plan.plan.items():
254        if isinstance(spec, Sharder):
255            sharder_paths.append(name)
256
    # Shard the parameters according to the ShardingPlan.
    for name, spec in plan.plan.items():
        if isinstance(spec, ShardingSpec):
            # If we found a sharding spec, try to shard the parameter.
            module_path, _, param_name = name.rpartition(".")

            for sharder_path in sharder_paths:
                if module_path.startswith(sharder_path):
                    raise RuntimeError(
                        f"ShardingPlan is invalid, trying to shard a parameter: {name},"
                        f" but there's already a Sharder entry for module {sharder_path},"
                        f" parameter sharding should not conflict with the submodule tree"
                        f" that a Sharder is working with!"
                    )

            mod = module.get_submodule(module_path)
            shard_parameter(
                mod, param_name, spec, src_rank=src_rank, process_group=process_group
            )
        elif isinstance(spec, Sharder):
            parent_mod_path, _, mod_name = name.rpartition(".")
            if name == "":
                raise KeyError("Module path must not be empty for custom sharder!")
            mod = module.get_submodule(name)
            parent_mod = module.get_submodule(parent_mod_path)
            sharded_mod = spec.shard(mod)
            # Swap this submodule with the sharded module.
            setattr(parent_mod, mod_name, sharded_mod)
        else:
            raise TypeError(
                f"Only `ShardingSpec` and `Sharder` are supported to shard '{name}'"
            )

    # Reshard the output if there's an entry in `output_plan` for this module.
    if plan.output_plan is not None:
        for module_path, output_spec in plan.output_plan.items():
            if isinstance(output_spec, ShardingSpec):
                mod = module.get_submodule(module_path)
                _reshard_output(mod, output_spec)
            else:
                raise TypeError(
                    f"Only `ShardingSpec` is supported as output_plan for '{module_path}'"
                )
    # Convert the output back to data parallel for the modules that appear in
    # `return_local_tensor` of the plan by calling `_collect_local_shard`
    # to collect the local tensor for the output of those modules.
    if plan.return_local_tensor is not None:
        for module_path in plan.return_local_tensor:
            mod = module.get_submodule(module_path)
            _collect_local_shard(mod)
307