# mypy: allow-untyped-defs
import functools
import operator
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Dict, List, TYPE_CHECKING

import torch
import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.op_registry_utils import _decorator_func

from ._internals import (
    check_tensor,
    get_chunked_dim_size,
    get_split_size,
    validate_non_overlapping_shards_metadata,
)


if TYPE_CHECKING:
    # Only import ShardedTensor when type checking; exclude it at
    # run time to avoid a circular dependency.
    from torch.distributed._shard.sharded_tensor import ShardedTensor


class PlacementSpec(ABC):  # noqa: B024
    """
    Base class representing the placement of an entity. Subclasses of this
    class can be used to specify customized placements which might not be
    covered by existing APIs.
    """


@dataclass
class DevicePlacementSpec(PlacementSpec):
    """
    Associates placement of an entity with a single device.

    Args:
        device(:class:`torch.distributed._remote_device`): The device to place the entity on.
    """

    device: torch.distributed._remote_device

    def __post_init__(self):
        if not isinstance(self.device, torch.distributed._remote_device):
            self.device = torch.distributed._remote_device(self.device)


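# Illustrative usage sketch (not part of the module API): __post_init__ above
# coerces a device string into a ``torch.distributed._remote_device``, so a
# plain "rank:<rank>/<device>" string may be passed directly.
def _example_device_placement_spec():
    spec = DevicePlacementSpec("rank:0/cuda:0")
    assert isinstance(spec.device, torch.distributed._remote_device)
    assert spec.device.rank() == 0

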
class ShardingSpec(ABC):
    """
    Base class representing sharding specifications.
    """

    @abstractmethod
    def build_metadata(
        self,
        tensor_sizes: torch.Size,
        tensor_properties: sharded_tensor_meta.TensorProperties,
    ) -> sharded_tensor_meta.ShardedTensorMetadata:
62        """
63        Given a global tensor size, define how to shard a tensor like this shape
64        across ranks, return ShardedTensorMetadata
65        Args:
66            tensor_sizes (:class:`torch.Size`):
67                The tensor shape to shard on, a `torch.Size` object that represents the
68                tensor shape to be sharded according to the ShardingSpec.
69            tensor_properties(:class:`torch.distributed._shard.sharded_tensor.TensorProperties):
70                Tensor properties used to create a ShardedTensor.
71        Returns:
72            A :class:`ShardedTensorMetadata` object that encodes the information about
73            the layout of the ShardedTensor and its properties.
74        """

    @abstractmethod
    def shard(
        self, tensor: torch.Tensor, src_rank: int = 0, process_group=None
    ) -> "ShardedTensor":
80        """
81        Given a global tensor on src_rank, shard this tensor
82        across ranks within the process group, return a ShardedTensor.
83        Args:
84            tensor (:class:`torch.Tensor`): Tensor needs to be sharded.
85        Keyword args:
86            src_rank (int, optional): The source rank which is used as the ground truth of
87                the data for the parameter that would be sharded and scattered
88                across the rest of the ranks.
89                Default: 0.
90            process_group (ProcessGroup, optional): The process group to work on. If None,
91                the default process group will be used.
92        Returns:
93            A :class:`ShardedTensor` sharded from the given tensor.
94        """
95
96
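# Illustrative sketch (not part of the module API): how a concrete spec such as
# ChunkShardingSpec fulfills the ShardingSpec contract. build_metadata() is a
# pure metadata computation, so no process group is needed here; shard() would
# additionally require an initialized process group.
def _example_build_metadata():
    from .chunk_sharding_spec import ChunkShardingSpec

    spec = ChunkShardingSpec(dim=0, placements=["rank:0/cpu", "rank:1/cpu"])
    metadata = spec.build_metadata(
        torch.Size([8, 4]), sharded_tensor_meta.TensorProperties()
    )
    # Two chunks of size [4, 4] at offsets [0, 0] and [4, 0].
    assert len(metadata.shards_metadata) == 2

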
# Ops customized for a particular ShardingSpec.
_CUSTOM_SHARDING_SPEC_OPS: Dict[str, Dict[Callable, Callable]] = {}


def _has_custom_op(sharding_spec, op):
    """
    Returns whether or not the ShardingSpec has a custom op implementation.
    """
    class_name = type(sharding_spec).__qualname__
    return (
        class_name in _CUSTOM_SHARDING_SPEC_OPS
        and op in _CUSTOM_SHARDING_SPEC_OPS[class_name]
    )


def _dispatch_custom_op(
    sharding_spec, op: Callable, types, args, kwargs, process_group
):
    """
    Calls the custom op for this ShardingSpec if it exists.
    """
    class_name = type(sharding_spec).__qualname__
    if not _has_custom_op(sharding_spec, op):
        raise RuntimeError(f"Custom op: {op} not registered for {class_name}")
    func = _CUSTOM_SHARDING_SPEC_OPS[class_name][op]
    return func(types, args, kwargs, process_group)


def custom_sharding_spec_op(sharding_spec_class, func):
    """
    Decorator to allow custom registration of ops.
    Args:
        sharding_spec_class(type): The ShardingSpec for which we need to add this custom op.
        func(Callable): The op to override (ex: torch.bmm)
    """
    class_name = sharding_spec_class.__qualname__
    if class_name not in _CUSTOM_SHARDING_SPEC_OPS:
        _CUSTOM_SHARDING_SPEC_OPS[class_name] = {}
    return functools.partial(
        _decorator_func, op=func, op_table=_CUSTOM_SHARDING_SPEC_OPS[class_name]
    )


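# Illustrative sketch (not executed at import): registering a handler for a
# hypothetical ShardingSpec subclass. ``_MySpec`` and ``_my_linear`` below are
# invented names used purely for demonstration.
def _example_custom_op_registration():
    class _MySpec(ShardingSpec):
        def build_metadata(self, tensor_sizes, tensor_properties):
            raise NotImplementedError

        def shard(self, tensor, src_rank=0, process_group=None):
            raise NotImplementedError

    @custom_sharding_spec_op(_MySpec, torch.nn.functional.linear)
    def _my_linear(types, args, kwargs, process_group):
        # A real handler would compute linear() over the sharded inputs here.
        raise NotImplementedError

    # The handler is now discoverable and dispatchable for _MySpec instances.
    assert _has_custom_op(_MySpec(), torch.nn.functional.linear)

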
@dataclass
class EnumerableShardingSpec(ShardingSpec):
    """
    This is a type of ShardingSpec that allows users to specify a generic
    sharding scheme by enumerating exactly how each shard is laid out.

    Args:
        shards(List[ShardMetadata]): List of :class:`ShardMetadata` objects representing
            each shard. Note that none of the shards should overlap.
    """

    shards: List[ShardMetadata]

    def __post_init__(self):
        if len(self.shards) == 0:
            raise ValueError(f"Empty shard list provided: {self.shards}")

        # Validate that every shard has the same rank (number of dimensions).
        rank = -1
        for shard in self.shards:
            if rank != -1 and rank != len(shard.shard_offsets):
                raise ValueError(
                    f"Found inconsistent ranks for shards: {rank} and {len(shard.shard_offsets)}"
                )
            rank = len(shard.shard_offsets)

        validate_non_overlapping_shards_metadata(self.shards)

    def build_metadata(
        self,
        tensor_sizes: torch.Size,
        tensor_properties: sharded_tensor_meta.TensorProperties,
    ) -> sharded_tensor_meta.ShardedTensorMetadata:
        # check if shards form a valid tensor
        check_tensor(self.shards, tensor_sizes)
        return sharded_tensor_meta.ShardedTensorMetadata(
            self.shards, tensor_sizes, tensor_properties
        )

    def shard(
        self, tensor: torch.Tensor, src_rank: int = 0, process_group=None
    ) -> "ShardedTensor":
        # TODO: figure out a generic and efficient way to scatter the shards for EnumerableShardingSpec
        raise NotImplementedError("EnumerableShardingSpec.shard not implemented yet!")


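# Illustrative sketch (not part of the module API): enumerating two
# non-overlapping shards of a 10 x 5 tensor split along dim 0 and building
# the corresponding metadata.
def _example_enumerable_sharding_spec():
    spec = EnumerableShardingSpec(
        [
            ShardMetadata(
                shard_offsets=[0, 0], shard_sizes=[5, 5], placement="rank:0/cpu"
            ),
            ShardMetadata(
                shard_offsets=[5, 0], shard_sizes=[5, 5], placement="rank:1/cpu"
            ),
        ]
    )
    metadata = spec.build_metadata(
        torch.Size([10, 5]), sharded_tensor_meta.TensorProperties()
    )
    assert len(metadata.shards_metadata) == 2

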
def _infer_sharding_spec_from_shards_metadata(shards_metadata):
    """
    Infer the sharding spec from the metadata of each shard of a ShardedTensor.
    If the tensor is sharded along exactly one dimension, we check whether it
    matches a ChunkShardingSpec: compute the total length along that dimension,
    perform a chunk sharding with the given placements, and compare the resulting
    chunk sizes and offsets against the given shards_metadata. If they do not
    match, we fall back to enumerable sharding.

    Args:
        shards_metadata (List[ShardMetadata]): List of metadata of the local shards.

    Returns:
        A :class:`torch.distributed._shard.sharding_spec.ShardingSpec` object
            describing how the tensor is sharded.
    """
    placements = []
    chunk_sharding_dim = None
    chunk_offset_list = []
    shard_size_list = []
    shard_offset_list = []
    # Collect the local shard metadata from the global sharded_tensor_metadata.
    for shard_metadata in shards_metadata:  # type: ignore[attr-defined]
        placements.append(shard_metadata.placement)
        local_offsets = shard_metadata.shard_offsets
        chunk_offset_list.append(sum(local_offsets))
        shard_size_list.append(shard_metadata.shard_sizes)
        shard_offset_list.append(shard_metadata.shard_offsets)
        shard_dims = [idx for idx, e in enumerate(local_offsets) if e != 0]
        # If the offset is [0, 0, ..., 0] (all zeros),
        # we cannot decide how the tensor is sharded.
        if len(shard_dims) == 0:
            continue
        # If the offset is [0, N, ..., 0, M, 0, ..., 0], the shard is offset
        # along more than one dimension, so this cannot be chunk sharding.
        if len(shard_dims) != 1:
            chunk_sharding_dim = None
            break
        # If the offset is [0, 0, ..., 0, M, 0, ..., 0], i.e. the shard is offset
        # along exactly one dimension, make sure all shards share that dimension.
        if chunk_sharding_dim is None:
            chunk_sharding_dim = shard_dims[0]
        elif chunk_sharding_dim != shard_dims[0]:
            chunk_sharding_dim = None
            break

    if chunk_sharding_dim is not None:
        # Ensure we infer the correct placement order from offsets
        placements = [
            x
            for _, x in sorted(
                zip(chunk_offset_list, placements), key=operator.itemgetter(0)
            )
        ]

        from .chunk_sharding_spec import ChunkShardingSpec

        chunk_spec = ChunkShardingSpec(
            dim=chunk_sharding_dim,
            placements=placements,
        )

        shard_sizes = sorted([x[chunk_sharding_dim] for x in shard_size_list])
        shard_total_length = sum(shard_sizes)
        shard_offsets = sorted([x[chunk_sharding_dim] for x in shard_offset_list])

        chunks = len(placements)
        split_size = get_split_size(shard_total_length, chunks)
        chunk_shard_sizes = sorted(
            [
                get_chunked_dim_size(shard_total_length, split_size, idx)
                for idx in range(chunks)
            ]
        )
        # Should match ChunkShardingSpec offsets calculation
        chunk_shard_offsets = [split_size * idx for idx in range(chunks)]
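        # Worked example for illustration: a sharded dimension of total length
        # 10 across 4 placements gives split_size = 3, chunk sizes 3, 3, 3, 1
        # and offsets 0, 3, 6, 9; the shards are only inferred to be a
        # ChunkShardingSpec if their sizes and offsets match this pattern.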
        if shard_sizes == chunk_shard_sizes and shard_offsets == chunk_shard_offsets:
            return chunk_spec
    return EnumerableShardingSpec(shards_metadata)

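
# Illustrative sketch (not part of the module API): metadata describing two
# equal chunks along dim 0 is recognized as chunk sharding; anything irregular
# falls back to EnumerableShardingSpec.
def _example_infer_sharding_spec():
    from .chunk_sharding_spec import ChunkShardingSpec

    shards = [
        ShardMetadata(
            shard_offsets=[0, 0], shard_sizes=[4, 4], placement="rank:0/cpu"
        ),
        ShardMetadata(
            shard_offsets=[4, 0], shard_sizes=[4, 4], placement="rank:1/cpu"
        ),
    ]
    spec = _infer_sharding_spec_from_shards_metadata(shards)
    assert isinstance(spec, ChunkShardingSpec) and spec.dim == 0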