# mypy: allow-untyped-defs
from dataclasses import dataclass
from functools import reduce
from typing import List, Optional, Union

from torch.distributed.remote_device import _remote_device


@dataclass
class ShardMetadata:
    """
    Represents a shard of the overall Tensor including its
    offsets, lengths and device placement.

    Args:
        shard_offsets(List[int]): Offsets in the original tensor indicating
            the start offsets for this shard. Should have the same rank as
            the original tensor.
        shard_sizes(List[int]): Integers indicating the size of each
            dimension for this shard. Should have the same rank as the
            original tensor.
        placement(:class:`torch.distributed._remote_device`):
            Specifies the placement of this shard.
    """

    __slots__ = ["shard_offsets", "shard_sizes", "placement"]

    shard_offsets: List[int]
    shard_sizes: List[int]
    placement: Optional[_remote_device]

    def __init__(
        self,
        shard_offsets: List[int],
        shard_sizes: List[int],
        placement: Optional[Union[str, _remote_device]] = None,
    ):
        self.shard_offsets = shard_offsets
        self.shard_sizes = shard_sizes
        if isinstance(placement, str):
            self.placement = _remote_device(placement)
        else:
            self.placement = placement
        if len(self.shard_offsets) != len(self.shard_sizes):
            raise ValueError(
                f"shard_offsets and shard_sizes should have "
                f"the same number of elements, found {len(self.shard_offsets)} "
                f"and {len(self.shard_sizes)} respectively"
            )

        for i in range(len(self.shard_offsets)):
            if self.shard_offsets[i] < 0:
                raise ValueError("shard_offsets should be >= 0")
            if self.shard_sizes[i] < 0:
                raise ValueError("shard_sizes should be >= 0")

    def __hash__(self):
        # Fold each element's hash into the accumulator by shifting left
        # and adding, so both the values and their order contribute.
        def _hash_reduce(a, b):
            return (a << 8) + hash(b)

        res = reduce(_hash_reduce, self.shard_offsets, 37)
        res = reduce(_hash_reduce, self.shard_sizes, res)
        res = _hash_reduce(res, self.placement)
        return res
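

# A minimal usage sketch (illustrative only, not part of the module API):
# construct a ShardMetadata describing the top-left 5x5 block of a 10x10
# tensor placed on rank 0's first GPU. The "rank:0/cuda:0" placement string
# is an assumed example value parsed by torch.distributed._remote_device.
if __name__ == "__main__":
    shard = ShardMetadata(
        shard_offsets=[0, 0],
        shard_sizes=[5, 5],
        placement="rank:0/cuda:0",
    )
    # The string placement is normalized to a _remote_device instance,
    # and __hash__ lets ShardMetadata be used in sets and as dict keys.
    print(shard.placement, hash(shard))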