xref: /aosp_15_r20/external/pytorch/torch/distributed/_shard/sharding_spec/_internals.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
from typing import List, Optional, Tuple

from torch.distributed._shard.metadata import ShardMetadata


def _check_shard_metadata_pair_overlap(shard1: ShardMetadata, shard2: ShardMetadata):
    """
    Checks if two shards overlap.
    """

    # For each dimension, check whether one shard lies entirely beyond the
    # other shard along that dimension. For a 2D shard, this amounts to
    # checking whether one shard is entirely above or entirely to the left
    # of the other.
    ndims = len(shard1.shard_offsets)
    for i in range(ndims):
        if shard1.shard_offsets[i] >= shard2.shard_offsets[i] + shard2.shard_sizes[i]:
            return False
        if shard2.shard_offsets[i] >= shard1.shard_offsets[i] + shard1.shard_sizes[i]:
            return False

    return True


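# Example (illustrative sketch; the "rank:<r>/<device>" placement strings are
# arbitrary placeholders): two 2D shards overlap only if their row ranges and
# column ranges both intersect.
#
#   >>> a = ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2, 4], placement="rank:0/cpu")
#   >>> b = ShardMetadata(shard_offsets=[1, 0], shard_sizes=[2, 4], placement="rank:1/cpu")
#   >>> c = ShardMetadata(shard_offsets=[2, 0], shard_sizes=[2, 4], placement="rank:1/cpu")
#   >>> _check_shard_metadata_pair_overlap(a, b)  # rows [0, 1] and [1, 2] intersect
#   True
#   >>> _check_shard_metadata_pair_overlap(a, c)  # rows [0, 1] and [2, 3] are disjoint
#   False
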
def _find_nd_overlapping_shards(
    shards: List[ShardMetadata], sharded_dims: List[int]
) -> Optional[Tuple[int, int]]:
    # Each shard has len(sharded_dims) tuples. Each tuple represents the
    # [begin, end] (inclusive) interval of that shard along that dimension.
    shard_intervals = [
        [
            (s.shard_offsets[dim], s.shard_offsets[dim] + s.shard_sizes[dim] - 1)
            for dim in sharded_dims
        ]
        for s in shards
    ]

    for i in range(len(shards)):
        shard_i = shard_intervals[i]
        for j in range(i + 1, len(shards)):
            shard_j = shard_intervals[j]
            # Two shards overlap only if their intervals intersect along
            # every sharded dimension; if the intervals are disjoint along
            # any dimension, the shards cannot overlap.
            overlap = True
            for interval_i, interval_j in zip(shard_i, shard_j):
                if interval_i[0] > interval_j[1] or interval_j[0] > interval_i[1]:
                    overlap = False
                    break
            if overlap:
                return (i, j)
    return None


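# Example (illustrative sketch; placements are arbitrary placeholders):
# three 2D shards partitioned over both dimensions, where the first and
# third shards intersect along both dims.
#
#   >>> shards = [
#   ...     ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2, 2], placement="rank:0/cpu"),
#   ...     ShardMetadata(shard_offsets=[0, 2], shard_sizes=[2, 2], placement="rank:1/cpu"),
#   ...     ShardMetadata(shard_offsets=[1, 1], shard_sizes=[2, 2], placement="rank:2/cpu"),
#   ... ]
#   >>> _find_nd_overlapping_shards(shards, sharded_dims=[0, 1])
#   (0, 2)
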
def _find_1d_overlapping_shards(
    shards: List[ShardMetadata], dim: int
) -> Optional[Tuple[int, int]]:
    # (begin, end, index_in_shards). Begin and end are inclusive.
    intervals = [
        (s.shard_offsets[dim], s.shard_offsets[dim] + s.shard_sizes[dim] - 1, i)
        for i, s in enumerate(shards)
    ]
    intervals.sort()
    for i in range(len(shards) - 1):
        if intervals[i][1] >= intervals[i + 1][0]:
            return (intervals[i][2], intervals[i + 1][2])
    return None


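# Example (illustrative sketch; placements are arbitrary placeholders):
# shards split along dim 0, where the last two shards both claim row 4.
#
#   >>> shards = [
#   ...     ShardMetadata(shard_offsets=[0, 0], shard_sizes=[3, 8], placement="rank:0/cpu"),
#   ...     ShardMetadata(shard_offsets=[3, 0], shard_sizes=[2, 8], placement="rank:1/cpu"),
#   ...     ShardMetadata(shard_offsets=[4, 0], shard_sizes=[4, 8], placement="rank:2/cpu"),
#   ... ]
#   >>> _find_1d_overlapping_shards(shards, dim=0)
#   (1, 2)
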
def validate_non_overlapping_shards_metadata(shards: List[ShardMetadata]):
    """
    Ensures none of the shards overlap with each other.

    Args:
        shards(List[ShardMetadata]): List of :class:`ShardMetadata` objects representing
            each shard.
    Raises:
        ``ValueError`` if there's overlap in any two shards.
    """
    if not shards or len(shards) == 1:
        return

    sharded_dims: List[int] = []
    for dim in range(len(shards[0].shard_offsets)):
        for i in range(1, len(shards)):
            if (
                shards[i].shard_offsets[dim] != shards[0].shard_offsets[dim]
                or shards[i].shard_sizes[dim] != shards[0].shard_sizes[dim]
            ):
                sharded_dims.append(dim)
                break

    pair: Optional[Tuple[int, int]] = None
    if len(sharded_dims) == 0:
        # All shards have identical offsets and sizes, so no dimension is
        # partitioned and every pair of shards overlaps. Pick the first two.
        pair = (0, 1)
    elif len(sharded_dims) == 1:
        # Shards are partitioned over only one dimension. Overlap can be found
        # using an O(n log n) overlapping-interval algorithm.
        pair = _find_1d_overlapping_shards(shards, sharded_dims[0])
    else:
        # Shards are partitioned over more than one dimension. Fall back to a
        # pairwise check. Even though O(n log n) algorithms (line sweep) exist
        # for 2D overlap, the implementation is not trivial and may not justify
        # the time savings in most cases.
        pair = _find_nd_overlapping_shards(shards, sharded_dims)

    if pair:
        raise ValueError(f"Shards {shards[pair[0]]} and {shards[pair[1]]} overlap")


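# Example (illustrative sketch; placements are arbitrary placeholders):
# two disjoint row-wise shards validate cleanly, while duplicating a shard
# trips the overlap check.
#
#   >>> ok = [
#   ...     ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2, 4], placement="rank:0/cpu"),
#   ...     ShardMetadata(shard_offsets=[2, 0], shard_sizes=[2, 4], placement="rank:1/cpu"),
#   ... ]
#   >>> validate_non_overlapping_shards_metadata(ok)   # no exception
#   >>> validate_non_overlapping_shards_metadata([ok[0], ok[0]])  # raises ValueError
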
def check_tensor(shards_metadata, tensor_dims) -> None:
    """
    Checks if the shards_metadata is compatible with the provided tensor dims.

    Args:
        shards_metadata(List[ShardMetadata]): List of :class:`ShardMetadata`
            objects representing each shard of the tensor.
        tensor_dims(Sequence of int): Dimensions of the tensor to verify.
    Raises:
        ``ValueError`` if not compatible.
    """

    # If the tensor's volume matches the total volume of all shards and
    # all shard boundaries are within tensor dims, we have a compatible
    # sharding spec for this tensor. Note that we have already verified
    # we don't have overlapping shards.
    tensor_rank = len(tensor_dims)
    shards_rank = len(shards_metadata[0].shard_offsets)
    if tensor_rank != shards_rank:
        raise ValueError(
            f"Rank of tensor is {tensor_rank}, but shards rank is {shards_rank}"
        )

    total_shard_volume = 0
    for shard in shards_metadata:
        shard_volume = 1
        for i, shard_length in enumerate(shard.shard_sizes):
            shard_volume *= shard_length
            if shard.shard_offsets[i] + shard.shard_sizes[i] > tensor_dims[i]:
                raise ValueError(
                    f"Shard offset {shard.shard_offsets[i]} and length "
                    f"{shard.shard_sizes[i]} exceed tensor dim: {tensor_dims[i]} for shard {shard}"
                )
        total_shard_volume += shard_volume

    tensor_volume = 1
    for size in tensor_dims:
        tensor_volume *= size

    if total_shard_volume != tensor_volume:
        # TODO: Can we improve this error message to point out the gaps?
        raise ValueError(
            f"Total volume of shards: {total_shard_volume} "
            f"does not match tensor volume: {tensor_volume}; in other words, "
            f"all the individual shards do not cover the entire tensor"
        )


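# Example (illustrative sketch; placements are arbitrary placeholders):
# two 2 x 4 shards exactly cover a 4 x 4 tensor, but against a 6 x 4 tensor
# the shard volumes (8 + 8 = 16) fall short of the tensor volume (24).
#
#   >>> shards = [
#   ...     ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2, 4], placement="rank:0/cpu"),
#   ...     ShardMetadata(shard_offsets=[2, 0], shard_sizes=[2, 4], placement="rank:1/cpu"),
#   ... ]
#   >>> check_tensor(shards, [4, 4])  # compatible: 8 + 8 == 16
#   >>> check_tensor(shards, [6, 4])  # raises ValueError: volume 16 != 24
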
def get_split_size(dim_size, chunks):
    """
    Computes the split size in line with ``torch.chunk``.

    Args:
        dim_size(int): Size of the dimension being chunked.
        chunks(int): Number of chunks to create for ``dim_size``.

    Returns:
        An int indicating the split size to use.
    """
    return (dim_size + chunks - 1) // chunks


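# Example: splitting a dimension of size 10 into 4 chunks uses a split size
# of ceil(10 / 4) = 3, matching how ``torch.chunk`` sizes its chunks.
#
#   >>> get_split_size(10, 4)
#   3
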
def get_chunked_dim_size(dim_size, split_size, idx):
    """
    Computes the dim size of the chunk for provided ``idx`` given ``dim_size``
    and ``split_size``.

    Args:
        dim_size(int): Size of the dimension being chunked.
        split_size(int): The chunk size for each chunk of ``dim_size``.
        idx(int): The index of the chunk whose dim size is being requested.

    Returns:
        An int indicating the dim size of the chunk.
    """
    return max(min(dim_size, split_size * (idx + 1)) - split_size * idx, 0)


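# Example: with dim_size=10 and split_size=3 (i.e. get_split_size(10, 4)),
# the chunks have sizes 3, 3, 3, 1, and any further index yields 0.
#
#   >>> [get_chunked_dim_size(10, 3, idx) for idx in range(5)]
#   [3, 3, 3, 1, 0]
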
def get_chunk_sharding_params(sharding_dim_size, world_size, spec, rank):
    """
    Generate the start position and chunk size for the current rank for
    chunk sharding.

    Args:
        sharding_dim_size(int): The length of the dimension which we shard on.
        world_size(int): Number of ranks.
        spec (:class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec`):
            sharding spec.
        rank(int): Rank of the current process.

    Returns:
        start_pos(int): start position of the sharded tensor on the given rank.
        chunk_size(int): chunk size of the sharded tensor on the given rank.
    """
    split_size = get_split_size(sharding_dim_size, world_size)
    current_offsets = 0
    start_pos = current_offsets
    for idx, placement in enumerate(spec.placements):
        chunk_size = get_chunked_dim_size(sharding_dim_size, split_size, idx)
        if rank == placement.rank():
            start_pos = current_offsets
            break
        current_offsets += chunk_size
    return start_pos, chunk_size  # type: ignore[possibly-undefined]
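# Example (illustrative sketch; assumes a ChunkShardingSpec built from
# "rank:<rank>/<device>" placement strings, which should not require an
# initialized process group): sharding a dimension of size 10 across 4 ranks
# produces chunks of size 3, 3, 3, 1, so rank 2 owns the chunk at offset 6.
#
#   >>> from torch.distributed._shard.sharding_spec import ChunkShardingSpec
#   >>> spec = ChunkShardingSpec(
#   ...     dim=0,
#   ...     placements=["rank:0/cpu", "rank:1/cpu", "rank:2/cpu", "rank:3/cpu"],
#   ... )
#   >>> get_chunk_sharding_params(10, 4, spec, rank=2)
#   (6, 3)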