# mypy: allow-untyped-defs
from typing import List, Optional, Tuple

from torch.distributed._shard.metadata import ShardMetadata


def _check_shard_metadata_pair_overlap(shard1: ShardMetadata, shard2: ShardMetadata):
    """
    Checks if two shards overlap.
    """

    # For each dim of each shard, check if one shard resides on the other
    # end of the second shard with respect to that dim. As an example, for a
    # 2D shard we would check if one shard is above or to the left of the
    # other shard.
    ndims = len(shard1.shard_offsets)
    for i in range(ndims):
        if shard1.shard_offsets[i] >= shard2.shard_offsets[i] + shard2.shard_sizes[i]:
            return False
        if shard2.shard_offsets[i] >= shard1.shard_offsets[i] + shard1.shard_sizes[i]:
            return False

    return True


def _find_nd_overlapping_shards(
    shards: List[ShardMetadata], sharded_dims: List[int]
) -> Optional[Tuple[int, int]]:
    # Each shard has len(sharded_dims) tuples. Each tuple represents the
    # [begin, end] (inclusive) interval of that dimension.
    shard_intervals = [
        [
            (s.shard_offsets[dim], s.shard_offsets[dim] + s.shard_sizes[dim] - 1)
            for dim in sharded_dims
        ]
        for s in shards
    ]

    for i in range(len(shards)):
        shard_i = shard_intervals[i]
        for j in range(i + 1, len(shards)):
            shard_j = shard_intervals[j]
            # For each dim of each shard, check if one shard resides on the
            # other end of the second shard with respect to that dim. As an
            # example, for a 2D shard we would check if one shard is above or
            # to the left of the other shard.
            overlap = True
            for interval_i, interval_j in zip(shard_i, shard_j):
                if interval_i[0] > interval_j[1] or interval_j[0] > interval_i[1]:
                    overlap = False
                    break
            if overlap:
                return (i, j)
    return None


def _find_1d_overlapping_shards(
    shards: List[ShardMetadata], dim: int
) -> Optional[Tuple[int, int]]:
    # (begin, end, index_in_shards). Begin and end are inclusive.
    intervals = [
        (s.shard_offsets[dim], s.shard_offsets[dim] + s.shard_sizes[dim] - 1, i)
        for i, s in enumerate(shards)
    ]
    intervals.sort()
    for i in range(len(shards) - 1):
        if intervals[i][1] >= intervals[i + 1][0]:
            return (intervals[i][2], intervals[i + 1][2])
    return None


def validate_non_overlapping_shards_metadata(shards: List[ShardMetadata]):
    """
    Ensures none of the shards overlap with each other.

    Args:
        shards(List[ShardMetadata]): List of :class:`ShardMetadata` objects representing
            each shard.
    Raises:
        ``ValueError`` if there's overlap in any two shards.
    """
    if not shards or len(shards) == 1:
        return

    sharded_dims: List[int] = []
    for dim in range(len(shards[0].shard_offsets)):
        for i in range(1, len(shards)):
            if (
                shards[i].shard_offsets[dim] != shards[0].shard_offsets[dim]
                or shards[i].shard_sizes[dim] != shards[0].shard_sizes[dim]
            ):
                sharded_dims.append(dim)
                break

    pair: Optional[Tuple[int, int]] = None
    if len(sharded_dims) == 0:
        # All shards are identical and no dim is partitioned, so any two
        # shards overlap. Choose the first two.
        pair = (0, 1)
    elif len(sharded_dims) == 1:
        # Shards are partitioned over only one dimension. Overlap can be found
        # using an O(nlogn) overlapping interval algorithm.
        pair = _find_1d_overlapping_shards(shards, sharded_dims[0])
    else:
        # Shards are partitioned over more than one dimension. Fall back to a
        # pair-wise check. Even though O(nlogn) algorithms (line sweep) exist
        # for 2D overlap, the implementation is not trivial and may not justify
        # the time saving in most cases.
        pair = _find_nd_overlapping_shards(shards, sharded_dims)

    if pair:
        raise ValueError(f"Shards {shards[pair[0]]} and {shards[pair[1]]} overlap")
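

# Illustrative usage of the overlap validation above (a minimal sketch, not
# part of this module's API; the offsets, sizes, and placements are made-up
# values): a 4x4 tensor tiled by four non-overlapping 2x2 shards passes, and
# adding an overlapping shard raises.
#
#     shards = [
#         ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2, 2], placement="rank:0/cpu"),
#         ShardMetadata(shard_offsets=[0, 2], shard_sizes=[2, 2], placement="rank:1/cpu"),
#         ShardMetadata(shard_offsets=[2, 0], shard_sizes=[2, 2], placement="rank:2/cpu"),
#         ShardMetadata(shard_offsets=[2, 2], shard_sizes=[2, 2], placement="rank:3/cpu"),
#     ]
#     validate_non_overlapping_shards_metadata(shards)  # passes: shards tile the tensor
#
#     shards.append(ShardMetadata(shard_offsets=[1, 1], shard_sizes=[2, 2], placement="rank:0/cpu"))
#     validate_non_overlapping_shards_metadata(shards)  # raises ValueError (overlap)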


def check_tensor(shards_metadata, tensor_dims) -> None:
    """
    Checks if the shards_metadata is compatible with the provided tensor dims.

    Args:
        shards_metadata(List[ShardMetadata]): List of :class:`ShardMetadata`
            objects representing each shard of the tensor.
        tensor_dims(Sequence of int): Dimensions of the tensor to verify.
    Raises:
        ``ValueError`` if not compatible.
    """

    # If the tensor's volume matches the total volume of all shards and
    # all shard boundaries are within tensor dims, we have a compatible
    # sharding spec for this tensor. Note that we have already verified
    # we don't have overlapping shards.
    tensor_rank = len(tensor_dims)
    shards_rank = len(shards_metadata[0].shard_offsets)
    if tensor_rank != shards_rank:
        raise ValueError(
            f"Rank of tensor is {tensor_rank}, but shards rank is {shards_rank}"
        )

    total_shard_volume = 0
    for shard in shards_metadata:
        shard_volume = 1
        for i, shard_length in enumerate(shard.shard_sizes):
            shard_volume *= shard_length
            if shard.shard_offsets[i] + shard.shard_sizes[i] > tensor_dims[i]:
                raise ValueError(
                    f"Shard offset {shard.shard_offsets[i]} and length "
                    f"{shard.shard_sizes[i]} exceeds tensor dim: {tensor_dims[i]} for shard {shard}"
                )
        total_shard_volume += shard_volume

    tensor_volume = 1
    for size in tensor_dims:
        tensor_volume *= size

    if total_shard_volume != tensor_volume:
        # TODO: Can we improve this error message to point out the gaps?
        raise ValueError(
            f"Total volume of shards: {total_shard_volume} "
            f"does not match tensor volume: {tensor_volume}, in other words "
            f"all the individual shards do not cover the entire tensor"
        )


def get_split_size(dim_size, chunks):
    """
    Computes the split size in line with ``torch.chunk``.

    Args:
        dim_size(int): Size of the dimension being chunked.
        chunks(int): Number of chunks to create for ``dim_size``.

    Returns:
        An int indicating the split size to use.
    """
    return (dim_size + chunks - 1) // chunks


def get_chunked_dim_size(dim_size, split_size, idx):
    """
    Computes the dim size of the chunk for the provided ``idx`` given
    ``dim_size`` and ``split_size``.

    Args:
        dim_size(int): Size of the dimension being chunked.
        split_size(int): The chunk size for each chunk of ``dim_size``.
        idx(int): The index of the chunk whose dim size is being requested.

    Returns:
        An int indicating the dim size of the chunk.
    """
    return max(min(dim_size, split_size * (idx + 1)) - split_size * idx, 0)
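

# A worked example of the chunk arithmetic above (a sketch with made-up
# numbers): splitting a dimension of size 10 into 4 chunks mirrors
# ``torch.chunk``: the split size rounds up and the last chunk absorbs the
# remainder.
#
#     get_split_size(10, 4)                              # -> 3, i.e. ceil(10 / 4)
#     [get_chunked_dim_size(10, 3, i) for i in range(4)]  # -> [3, 3, 3, 1]
#     get_chunked_dim_size(10, 3, 4)                      # -> 0 (index past the end)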


def get_chunk_sharding_params(sharding_dim_size, world_size, spec, rank):
    """
    Generates the start position and chunk size for the given rank for
    chunk sharding.

    Args:
        sharding_dim_size(int): The length of the dimension which we shard on.
        world_size(int): Number of ranks.
        spec (:class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec`):
            sharding spec.
        rank(int): Rank of the current process.

    Returns:
        start_pos(int): start position of the sharded tensor on the given rank.
        chunk_size(int): chunk size of the sharded tensor on the given rank.
    """
    split_size = get_split_size(sharding_dim_size, world_size)
    current_offsets = 0
    start_pos = current_offsets
    for idx, placement in enumerate(spec.placements):
        chunk_size = get_chunked_dim_size(sharding_dim_size, split_size, idx)
        if rank == placement.rank():
            start_pos = current_offsets
            break
        current_offsets += chunk_size
    return start_pos, chunk_size  # type: ignore[possibly-undefined]
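

# Illustrative usage of get_chunk_sharding_params (a minimal sketch; the spec
# construction below is an assumption shown for demonstration only and is not
# executed here): with sharding_dim_size=10 and world_size=4, split_size is 3,
# so ranks 0-2 each own a chunk of 3 and rank 3 owns the remaining 1.
#
#     from torch.distributed._shard.sharding_spec import ChunkShardingSpec
#     spec = ChunkShardingSpec(
#         dim=0,
#         placements=["rank:0/cpu", "rank:1/cpu", "rank:2/cpu", "rank:3/cpu"],
#     )
#     get_chunk_sharding_params(10, 4, spec, rank=3)  # -> (9, 1)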