# mypy: allow-untyped-defs
from typing import Any, cast, List

import torch
import torch.distributed as dist
from torch._utils import _get_device_module

from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._tensor import DTensor
from torch.distributed._tensor._utils import compute_local_shape_and_global_offset
from torch.distributed.checkpoint.planner import _Checkpointable

from torch.utils._pytree import tree_map_only_

from .metadata import (
    BytesStorageMetadata,
    ChunkStorageMetadata,
    MetadataIndex,
    STATE_DICT_TYPE,
    STORAGE_TYPES,
    TensorProperties,
    TensorStorageMetadata,
)
from .planner import (
    LoadItemType,
    ReadItem,
    SavePlan,
    TensorWriteData,
    WriteItem,
    WriteItemType,
)
from .resharding import (
    _check_shard_metadata_pair_overlap,
    _shards_get_overlap_region_wrt_saved_tensor,
)

__all__: List[str] = ["create_read_items_for_chunk_list"]


def _create_chunk_from_tensor(tensor: torch.Tensor) -> ChunkStorageMetadata:
    # A plain tensor is described as a single chunk covering the whole tensor.
    return ChunkStorageMetadata(
        offsets=torch.Size([0] * len(tensor.size())), sizes=tensor.size()
    )


def _chunk_for_shard(shard_md: ShardMetadata) -> ChunkStorageMetadata:
    return ChunkStorageMetadata(
        offsets=torch.Size(shard_md.shard_offsets),
        sizes=torch.Size(shard_md.shard_sizes),
    )


def _sharded_tensor_metadata(
    sharded_tensor: ShardedTensor, shard_md: ShardMetadata
) -> TensorWriteData:
    shard_properties = sharded_tensor.metadata().tensor_properties

    properties = TensorProperties(
        dtype=shard_properties.dtype,
        layout=shard_properties.layout,
        requires_grad=shard_properties.requires_grad,
        memory_format=shard_properties.memory_format,
        pin_memory=shard_properties.pin_memory,
    )

    return TensorWriteData(
        chunk=_chunk_for_shard(shard_md),
        properties=properties,
        size=sharded_tensor.metadata().size,
    )


def _create_write_items_for_dtensor(fqn: str, tensor: DTensor) -> WriteItem:
    sizes, offsets = compute_local_shape_and_global_offset(
        tensor.shape, tensor.device_mesh, tensor.placements
    )
    sizes, offsets = torch.Size(sizes), torch.Size(offsets)

    return WriteItem(
        index=MetadataIndex(fqn, offsets),
        type=WriteItemType.SHARD,
        tensor_data=TensorWriteData(
            chunk=ChunkStorageMetadata(
                offsets=offsets,
                sizes=sizes,
            ),
            properties=TensorProperties.create_from_tensor(tensor.to_local()),
            size=tensor.size(),
        ),
    )


def _create_write_item_for_shard(
    fqn: str, sharded_tensor: ShardedTensor, shard_md: ShardMetadata
) -> WriteItem:
    offsets = torch.Size(shard_md.shard_offsets)
    return WriteItem(
        index=MetadataIndex(fqn, offsets),
        type=WriteItemType.SHARD,
        tensor_data=_sharded_tensor_metadata(sharded_tensor, shard_md),
    )


def _create_write_item_for_tensor(fqn: str, tensor: torch.Tensor) -> WriteItem:
    offsets = torch.Size([0] * len(tensor.size()))
    return WriteItem(
        index=MetadataIndex(fqn, offsets),
        type=WriteItemType.TENSOR,
        tensor_data=TensorWriteData(
            chunk=ChunkStorageMetadata(offsets=offsets, sizes=tensor.size()),
            properties=TensorProperties.create_from_tensor(tensor),
            size=tensor.size(),
        ),
    )


def _create_write_item_for_bytesio(fqn: str, bytes: Any) -> WriteItem:
    # The bytes value itself is not needed to plan the write; only the
    # index and the item type are recorded.
    return WriteItem(
        index=MetadataIndex(fqn),
        type=WriteItemType.BYTE_IO,
    )
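
# A minimal, hypothetical sketch (not part of this module's API) of what the
# write-item helpers above produce for a plain tensor; the FQN "weight" and
# the 4x4 tensor are made up for illustration.
def _example_write_item_sketch() -> None:
    item = _create_write_item_for_tensor("weight", torch.zeros(4, 4))
    # A plain tensor becomes one chunk that covers the whole tensor, so the
    # chunk offsets are all zero and the chunk sizes match the tensor size.
    assert item.type == WriteItemType.TENSOR
    assert item.tensor_data is not None
    assert item.tensor_data.chunk.offsets == torch.Size([0, 0])
    assert item.tensor_data.chunk.sizes == torch.Size([4, 4])
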
def _create_read_item_for_byteio(
    dest_index, dest_offset, storage_index, storage_offset, length
):
    return ReadItem(
        type=LoadItemType.BYTE_IO,
        dest_index=dest_index,
        dest_offsets=torch.Size((dest_offset,)),
        storage_index=storage_index,
        storage_offsets=torch.Size((storage_offset,)),
        lengths=torch.Size((length,)),
    )


def _create_read_item_for_tensor(
    dest_index, dest_offsets, storage_index, storage_offsets, lengths
):
    return ReadItem(
        type=LoadItemType.TENSOR,
        dest_index=dest_index,
        dest_offsets=torch.Size(dest_offsets),
        storage_index=storage_index,
        storage_offsets=torch.Size(storage_offsets),
        lengths=torch.Size(lengths),
    )


def create_read_items_for_chunk_list(
    fqn: str,
    checkpoint_md: TensorStorageMetadata,
    local_chunks: List[ChunkStorageMetadata],
) -> List[ReadItem]:
    """
    Create a list of ``ReadItem`` based on the checkpoint and local chunks.

    This applies the resharding algorithm and computes the reads needed
    to satisfy ``local_chunks`` with a checkpoint described by ``checkpoint_md``.

    Args:
        fqn (str): The state_dict FQN to pass to ``ReadItem``.
        checkpoint_md (TensorStorageMetadata): metadata for a given tensor
            from a checkpoint.
        local_chunks (List[ChunkStorageMetadata]): Local chunks that need to be
            loaded.

    Returns:
        A list of ``ReadItem`` that will satisfy all input chunks.
    """
    read_items = []
    # This is a naive quadratic algorithm that can be optimized later.
    for idx, shard in enumerate(local_chunks):
        for storage_idx, storage_md in enumerate(checkpoint_md.chunks):
            if not _check_shard_metadata_pair_overlap(shard, storage_md):
                continue

            storage_offsets = []
            dest_offsets = []
            lengths = []
            for (
                dim,
                offset_for_saved_tensor,
                offset_for_current_tensor,
                length,
            ) in _shards_get_overlap_region_wrt_saved_tensor(
                saved_shard=storage_md, current_shard=shard
            ):
                storage_offsets.append(offset_for_saved_tensor)
                dest_offsets.append(offset_for_current_tensor)
                lengths.append(length)

            read_items.append(
                _create_read_item_for_tensor(
                    dest_index=MetadataIndex(fqn, shard.offsets, idx),
                    dest_offsets=dest_offsets,
                    storage_index=MetadataIndex(fqn, storage_md.offsets, storage_idx),
                    storage_offsets=storage_offsets,
                    lengths=lengths,
                )
            )
    return read_items
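
# A hypothetical sketch (not part of this module's API) of the resharding
# computation above: a 1-D tensor of length 8 was saved as two chunks of 4,
# and a single local chunk wants all 8 elements, so one read is planned per
# overlapping saved chunk. All names and shapes are made up for illustration.
def _example_read_items_sketch() -> List[ReadItem]:
    checkpoint_md = TensorStorageMetadata(
        properties=TensorProperties(dtype=torch.float32),
        size=torch.Size([8]),
        chunks=[
            ChunkStorageMetadata(offsets=torch.Size([0]), sizes=torch.Size([4])),
            ChunkStorageMetadata(offsets=torch.Size([4]), sizes=torch.Size([4])),
        ],
    )
    local_chunks = [
        ChunkStorageMetadata(offsets=torch.Size([0]), sizes=torch.Size([8]))
    ]
    # Each ReadItem maps a region of a saved chunk (storage_offsets) onto a
    # region of the local chunk (dest_offsets) with matching lengths.
    read_items = create_read_items_for_chunk_list("weight", checkpoint_md, local_chunks)
    assert len(read_items) == 2
    return read_items
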
def _create_default_metadata_only_plan(state_dict: STATE_DICT_TYPE) -> SavePlan:
    requests = []
    for fqn, obj in state_dict.items():
        if isinstance(obj, DTensor):
            requests.append(_create_write_items_for_dtensor(fqn, obj))
        elif isinstance(obj, ShardedTensor):
            for shard_md in obj.metadata().shards_metadata:
                requests.append(_create_write_item_for_shard(fqn, obj, shard_md))
        elif isinstance(obj, torch.Tensor):
            requests.append(_create_write_item_for_tensor(fqn, obj))
        else:
            requests.append(_create_write_item_for_bytesio(fqn, obj))
    return SavePlan(requests)


def _create_write_items(fqn: str, object: Any) -> List[WriteItem]:
    if isinstance(object, _Checkpointable):
        return object._create_write_items(fqn, object)
    elif isinstance(object, DTensor):
        # DTensor can contain a local tensor that is a tensor subclass
        if isinstance(object.to_local(), _Checkpointable):
            return object.to_local()._create_write_items(fqn, object)  # type: ignore[arg-type]
        return [_create_write_items_for_dtensor(fqn, object)]
    elif isinstance(object, ShardedTensor):
        return [
            _create_write_item_for_shard(fqn, object, shard.metadata)
            for shard in object.local_shards()
        ]
    elif isinstance(object, torch.Tensor):
        return [_create_write_item_for_tensor(fqn, object)]
    else:
        return [_create_write_item_for_bytesio(fqn, object)]


def _create_chunk_from_dtensor(tensor: DTensor) -> ChunkStorageMetadata:
    sizes, offsets = compute_local_shape_and_global_offset(
        tensor.shape, tensor.device_mesh, tensor.placements
    )
    sizes, offsets = torch.Size(sizes), torch.Size(offsets)
    return ChunkStorageMetadata(
        offsets=offsets,
        sizes=sizes,
    )


def _create_chunk_list(tensor: torch.Tensor) -> List[ChunkStorageMetadata]:
    if isinstance(tensor, _Checkpointable):
        local_chunks = tensor._create_chunk_list(tensor)
    elif isinstance(tensor, DTensor):
        # DTensor can contain a local tensor that is a tensor subclass
        if isinstance(tensor.to_local(), _Checkpointable):
            return tensor.to_local()._create_chunk_list(tensor)  # type: ignore[arg-type]
        local_chunks = [_create_chunk_from_dtensor(tensor)]
    elif isinstance(tensor, ShardedTensor):
        local_chunks = [
            _chunk_for_shard(shard.metadata) for shard in tensor.local_shards()
        ]
    elif isinstance(tensor, torch.Tensor):
        local_chunks = [_create_chunk_from_tensor(tensor)]
    else:
        raise ValueError(
            "Unsupported type, expecting one of [Tensor, DTensor, ShardedTensor], "
            f"but got {type(tensor)}"
        )

    return local_chunks


def _create_read_items(fqn: str, md: STORAGE_TYPES, obj: Any) -> List[ReadItem]:
    if not isinstance(md, BytesStorageMetadata):
        try:
            local_chunks = _create_chunk_list(obj)
        except ValueError as ex:
            raise ValueError(
                f"Invalid checkpoint metadata for {fqn}, "
                f"expected BytesStorageMetadata but found {type(md)}"
            ) from ex

        return create_read_items_for_chunk_list(fqn, md, local_chunks)
    else:
        return [
            _create_read_item_for_byteio(
                dest_index=MetadataIndex(fqn),
                dest_offset=0,
                storage_index=MetadataIndex(fqn),
                storage_offset=0,
                length=0,
            )
        ]


def _init_state_dict(state_dict: STATE_DICT_TYPE) -> None:
    tree_map_only_(torch.Tensor, _init_meta_tensor, state_dict)


def _init_meta_tensor(value: Any) -> Any:
    """
    Initialize a tensor: for a ``torch.Tensor``/``DTensor`` on the meta device,
    allocate an empty tensor of the same shape on the current device; return
    any other value unchanged.
    """

    device = getattr(value, "device", None)
    # DCP does the initialization if it's a meta tensor/DTensor.
    if device == torch.device("meta"):
        device_type = dist.distributed_c10d._get_pg_default_device().type
        device = cast(torch.device, _get_device_module(device_type).current_device())
        if isinstance(value, DTensor):
            new_local_tensor = torch.empty_like(value.to_local(), device=device)
            # We need to pass shape and stride explicitly, since DTensor might be
            # sharded unevenly.
            dtensor = DTensor.from_local(
                new_local_tensor,
                device_mesh=value.device_mesh,
                placements=value.placements,
                shape=value.size(),
                stride=value.stride(),
            )
            return dtensor
        elif isinstance(value, torch.Tensor):
            tensor = torch.empty_like(value, device=device)
            return tensor
        else:
            raise RuntimeError(
                f"Found unsupported type {type(value)} for meta device loading."
            )
    else:
        return value
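
# A hypothetical sketch (not part of this module's API) of the read-planning
# dispatch in _create_read_items above. For a bytes-like entry the metadata is
# BytesStorageMetadata, so a single BYTE_IO read is planned; its offsets and
# length are placeholders, as byte blobs are read whole.
def _example_byteio_read_sketch() -> List[ReadItem]:
    reads = _create_read_items("extra_state", BytesStorageMetadata(), b"\x00\x01")
    assert len(reads) == 1
    assert reads[0].type == LoadItemType.BYTE_IO
    return reads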