1# mypy: allow-untyped-defs 2import collections.abc 3import copy 4from typing import List, Optional, Sequence, TYPE_CHECKING 5 6import torch 7from torch.distributed import distributed_c10d as c10d, rpc 8from torch.distributed._shard.sharding_spec._internals import ( 9 check_tensor, 10 validate_non_overlapping_shards_metadata, 11) 12 13from .metadata import ShardedTensorMetadata, TensorProperties 14from .shard import Shard 15 16 17if TYPE_CHECKING: 18 from torch.distributed._shard.metadata import ShardMetadata 19 20 21def _parse_and_validate_remote_device(pg, remote_device): 22 if remote_device is None: 23 raise ValueError("remote device is None") 24 25 worker_name = remote_device.worker_name() 26 rank = remote_device.rank() 27 device = remote_device.device() 28 29 # Validate rank, skip validation if rank is not part of process group. 30 if rank is not None and not c10d._rank_not_in_group(pg): 31 pg_global_ranks = c10d.get_process_group_ranks(pg) 32 if rank not in pg_global_ranks: 33 raise ValueError( 34 f"Global rank {rank} does not exist in input process group: {pg_global_ranks}" 35 ) 36 37 if worker_name is not None: 38 if not rpc._is_current_rpc_agent_set(): 39 raise RuntimeError( 40 f"RPC framework needs to be initialized for using worker names: {worker_name}" 41 ) 42 43 workers = rpc._get_current_rpc_agent().get_worker_infos() 44 for worker in workers: 45 if worker.name == worker_name: 46 return worker.id, device 47 48 raise ValueError(f"Invalid worker name: {worker_name}") 49 50 return rank, device 51 52 53def _validate_output_tensor_for_gather( 54 my_rank: int, 55 dst_rank: int, 56 size: torch.Size, 57 dst_tensor: Optional[torch.Tensor], 58) -> None: 59 if dst_rank == my_rank: 60 if dst_tensor is None: 61 raise ValueError( 62 f"Argument ``dst_tensor`` must be specified on destination rank {dst_rank}" 63 ) 64 if tuple(size) != (dst_tensor.size()): 65 raise ValueError( 66 f"Argument ``dst_tensor`` have size {tuple(dst_tensor.size())}," 67 f"but should be {tuple(size)}" 68 ) 69 elif dst_tensor: 70 raise ValueError( 71 "Argument ``dst_tensor`` must NOT be specified " "on non-destination ranks." 72 ) 73 74 75def _flatten_tensor_size(size) -> torch.Size: 76 """ 77 Checks if tensor size is valid, then flatten/return a torch.Size object. 78 """ 79 if len(size) == 1 and isinstance(size[0], collections.abc.Sequence): 80 dims = list(*size) 81 else: 82 dims = list(size) 83 84 for dim in dims: 85 if not isinstance(dim, int): 86 raise TypeError(f"size has to be a sequence of ints, found: {dims}") 87 88 return torch.Size(dims) 89 90 91def _raise_if_mismatch(expected, actual, prop_name, ranks, is_local=True): 92 if is_local: 93 assert isinstance(ranks, int) 94 if expected != actual: 95 raise ValueError( 96 f"Local shards' tensor {prop_name} property need to be the same on rank:{ranks}! " 97 f"Found one local shard tensor {prop_name}={expected}, " 98 f"the other local shard tensor {prop_name}={actual}." 99 ) 100 else: 101 # compare failure check across ranks, ranks list should have two rank 102 assert len(ranks) == 2 103 if expected != actual: 104 raise ValueError( 105 f"ShardedTensor {prop_name} property does not match from different ranks! " 106 f"Found {prop_name}={expected} on rank:{ranks[0]}, " 107 f"and {prop_name}={actual} on rank:{ranks[1]}." 108 ) 109 110 111def build_metadata_from_local_shards( 112 local_shards: List[Shard], 113 global_size: torch.Size, 114 current_rank: int, 115 pg: c10d.ProcessGroup, 116) -> ShardedTensorMetadata: 117 assert len(local_shards) > 0, "must have local shards!" 118 local_shard_metadatas: List[ShardMetadata] = [] 119 120 first_shard_dtype = local_shards[0].tensor.dtype 121 first_shard_layout = local_shards[0].tensor.layout 122 first_shard_requires_grad = local_shards[0].tensor.requires_grad 123 first_shard_is_pinned = local_shards[0].tensor.is_pinned() 124 125 # 1). Validate local tensors and associated metadatas 126 for local_shard in local_shards: 127 local_shard_tensor = local_shard.tensor 128 local_shard_meta = local_shard.metadata 129 local_shard_metadatas.append(local_shard_meta) 130 rank, local_device = _parse_and_validate_remote_device( 131 pg, local_shard_meta.placement 132 ) 133 134 if ( 135 local_shard_tensor.layout != torch.strided 136 or local_shard_tensor.layout != first_shard_layout 137 ): 138 raise ValueError( 139 f"Only torch.strided layout is currently supported, but found " 140 f"{local_shard_tensor.layout} on rank:{current_rank}!" 141 ) 142 143 if not local_shard_tensor.is_contiguous(): 144 raise ValueError( 145 "Only torch.contiguous_format memory_format is currently supported!" 146 ) 147 148 if rank != current_rank: 149 raise ValueError( 150 f"Local shard metadata's rank does not match with the rank in its process group! " 151 f"Found current rank in the process group: {current_rank}, " 152 f"local ShardMetadata placement's rank: {rank}" 153 ) 154 if local_shard_tensor.device != local_device: 155 raise ValueError( 156 f"Local shard tensor device does not match with local Shard's placement! " 157 f"Found local shard tensor device: {local_shard_tensor.device}, " 158 f"local shard metadata placement device: {local_device}" 159 ) 160 161 _raise_if_mismatch( 162 local_shard_meta.shard_sizes, 163 list(local_shard_tensor.size()), 164 "size", 165 current_rank, 166 ) 167 _raise_if_mismatch( 168 local_shard_tensor.is_pinned(), 169 first_shard_is_pinned, 170 "pin_memory", 171 current_rank, 172 ) 173 _raise_if_mismatch( 174 local_shard_tensor.dtype, first_shard_dtype, "dtype", current_rank 175 ) 176 _raise_if_mismatch( 177 local_shard_tensor.requires_grad, 178 first_shard_requires_grad, 179 "requires_grad", 180 current_rank, 181 ) 182 183 # 2). Build a "local" ShardedTensorMetadata with all local shards on this rank, then 184 # do all_gather to collect local_sharded_tensor_metadata from all ranks 185 local_tensor_properties = TensorProperties( 186 dtype=first_shard_dtype, 187 layout=first_shard_layout, 188 requires_grad=first_shard_requires_grad, 189 memory_format=torch.contiguous_format, 190 pin_memory=first_shard_is_pinned, 191 ) 192 193 local_sharded_tensor_metadata = ShardedTensorMetadata( 194 shards_metadata=local_shard_metadatas, 195 size=global_size, 196 tensor_properties=local_tensor_properties, 197 ) 198 199 return local_sharded_tensor_metadata 200 201 202def build_global_metadata( 203 gathered_metadatas: Sequence[Optional[ShardedTensorMetadata]], 204): 205 global_sharded_tensor_metadata = None 206 global_metadata_rank = 0 207 208 for rank, rank_metadata in enumerate(gathered_metadatas): 209 if rank_metadata is None: 210 continue 211 212 if global_sharded_tensor_metadata is None: 213 global_sharded_tensor_metadata = copy.deepcopy(rank_metadata) 214 global_metadata_rank = rank 215 else: 216 _raise_if_mismatch( 217 global_sharded_tensor_metadata.size, 218 rank_metadata.size, 219 "global_size", 220 [global_metadata_rank, rank], 221 is_local=False, 222 ) 223 224 # don't need to check layout and memory format as we already checked in local shards validation stage 225 _raise_if_mismatch( 226 global_sharded_tensor_metadata.tensor_properties.dtype, 227 rank_metadata.tensor_properties.dtype, 228 "dtype", 229 [global_metadata_rank, rank], 230 is_local=False, 231 ) 232 233 _raise_if_mismatch( 234 global_sharded_tensor_metadata.tensor_properties.requires_grad, 235 rank_metadata.tensor_properties.requires_grad, 236 "requires_grad", 237 [global_metadata_rank, rank], 238 is_local=False, 239 ) 240 241 _raise_if_mismatch( 242 global_sharded_tensor_metadata.tensor_properties.pin_memory, 243 rank_metadata.tensor_properties.pin_memory, 244 "pin_memory", 245 [global_metadata_rank, rank], 246 is_local=False, 247 ) 248 # pass all validations, extend shards metadata 249 global_sharded_tensor_metadata.shards_metadata.extend( 250 rank_metadata.shards_metadata 251 ) 252 253 if global_sharded_tensor_metadata is not None: 254 # check if shards_metadata have overlap shards 255 validate_non_overlapping_shards_metadata( 256 global_sharded_tensor_metadata.shards_metadata 257 ) 258 259 # check if the shards_metadata is compatible with global size of the sharded tensor. 260 check_tensor( 261 global_sharded_tensor_metadata.shards_metadata, 262 global_sharded_tensor_metadata.size, 263 ) 264 else: 265 raise ValueError("ShardedTensor have no local shards on all ranks!") 266 267 return global_sharded_tensor_metadata 268