# mypy: allow-untyped-defs
from dataclasses import dataclass
from functools import reduce
from typing import List, Optional, Union

from torch.distributed.remote_device import _remote_device


@dataclass
class ShardMetadata:
    """
    Represents a shard of the overall Tensor including its
    offsets, lengths and device placement.

    Args:
        shard_offsets(List[int]): Offsets in the original tensor indicating
            the start offsets for this shard. Should have the same rank as
            the original tensor.
        shard_sizes(List[int]): Integers indicating the size of each
            dimension for this shard. Should have the same rank as the
            original tensor.
        placement(:class:`torch.distributed._remote_device`):
            Specifies the placement of this shard.
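
    Example::

        >>> # Illustrative sketch (values are hypothetical): two shards of a
        >>> # 10 x 4 tensor split along dim 0; "rank:<rank>/<device>" is one
        >>> # placement format _remote_device accepts.
        >>> top = ShardMetadata([0, 0], [5, 4], placement="rank:0/cuda:0")
        >>> bottom = ShardMetadata([5, 0], [5, 4], placement="rank:1/cuda:1")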
    """

    __slots__ = ["shard_offsets", "shard_sizes", "placement"]

    shard_offsets: List[int]
    shard_sizes: List[int]
    placement: Optional[_remote_device]

    def __init__(
        self,
        shard_offsets: List[int],
        shard_sizes: List[int],
        placement: Optional[Union[str, _remote_device]] = None,
    ):
        self.shard_offsets = shard_offsets
        self.shard_sizes = shard_sizes
        # Accept either a preconstructed _remote_device or its string form.
        if isinstance(placement, str):
            self.placement = _remote_device(placement)
        else:
            self.placement = placement
        if len(self.shard_offsets) != len(self.shard_sizes):
            raise ValueError(
                f"shard_offsets and shard_sizes should have "
                f"the same number of elements, found {len(self.shard_offsets)} "
                f"and {len(self.shard_sizes)} respectively"
            )

        for i in range(len(self.shard_offsets)):
            if self.shard_offsets[i] < 0:
                raise ValueError("shard_offsets should be >= 0")
            if self.shard_sizes[i] < 0:
                raise ValueError("shard_sizes should be >= 0")

    def __hash__(self):
        def _hash_reduce(a, b):
            # Shift-and-add combiner: shift the running accumulator left by
            # 8 bits, then fold in the next component's hash.
            return (a << 8) + hash(b)

        res = reduce(_hash_reduce, self.shard_offsets, 37)
        res = reduce(_hash_reduce, self.shard_sizes, res)
        res = _hash_reduce(res, self.placement)
        return res
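

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only; the shapes and placement
    # strings below are hypothetical, not taken from this module). Two
    # shards split a 10 x 4 tensor along dim 0; "rank:<rank>/<device>" is
    # one placement format _remote_device accepts.
    top = ShardMetadata([0, 0], [5, 4], placement="rank:0/cuda:0")
    bottom = ShardMetadata([5, 0], [5, 4], placement="rank:1/cuda:1")

    # __hash__ folds offsets, sizes and placement together with the
    # shift-and-add reducer, so distinct shards should hash differently.
    print(hash(top), hash(bottom))

    # Construction rejects offset/size lists of mismatched rank.
    try:
        ShardMetadata([0], [5, 4])
    except ValueError as err:
        print(f"rejected as expected: {err}")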