# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates

import dataclasses
import io
import logging
import operator
from collections import ChainMap
from functools import reduce
from typing import Any, cast, Dict, List, Optional, Tuple, Union

import torch
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed.checkpoint._dedup_save_plans import dedup_save_plans
from torch.distributed.checkpoint._nested_dict import (
    FLATTEN_MAPPING,
    flatten_state_dict,
)
from torch.distributed.checkpoint._sharded_tensor_utils import _flatten_sharded_tensors
from torch.distributed.checkpoint._traverse import set_element
from torch.distributed.checkpoint.metadata import (
    BytesStorageMetadata,
    ChunkStorageMetadata,
    Metadata,
    MetadataIndex,
    STATE_DICT_TYPE,
    STORAGE_TYPES,
    StorageMeta,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import (
    LoadPlan,
    LoadPlanner,
    ReadItem,
    SavePlan,
    SavePlanner,
    WriteItem,
    WriteItemType,
)
from torch.distributed.checkpoint.planner_helpers import (
    _create_default_metadata_only_plan,
    _create_read_items,
    _create_write_items,
    _init_state_dict,
)
from torch.distributed.checkpoint.utils import find_state_dict_object
from torch.distributed.tensor import DTensor

from . import _version


logger: logging.Logger = logging.getLogger(__name__)


__all__ = [
    "DefaultSavePlanner",
    "DefaultLoadPlanner",
    "create_default_local_load_plan",
    "create_default_global_load_plan",
    "create_default_local_save_plan",
    "create_default_global_save_plan",
]


# TODO: Update docstrings for default_planner.py
class DefaultSavePlanner(SavePlanner):
    mappings: FLATTEN_MAPPING

    def __init__(
        self,
        flatten_state_dict: bool = True,
        flatten_sharded_tensors: bool = True,
        dedup_replicated_tensors: Optional[bool] = None,
        dedup_save_to_lowest_rank: bool = False,
    ) -> None:
        self.flatten_state_dict = flatten_state_dict
        self.flatten_sharded_tensors = flatten_sharded_tensors
        self.mappings = {}
        self.dedup_save_to_lowest_rank = dedup_save_to_lowest_rank
        if dedup_replicated_tensors is not None:
            logger.warning(
                "DefaultSavePlanner's `dedup_replicated_tensors` argument is being "
                "deprecated, and no longer has any effect. Please remove this argument "
                "from your call."
            )

    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        storage_meta: Optional[StorageMeta] = None,
        is_coordinator: bool = False,
    ) -> None:
        if self.flatten_state_dict:
            state_dict, self.mappings = flatten_state_dict(state_dict)
        if self.flatten_sharded_tensors:
            state_dict = _flatten_sharded_tensors(state_dict)
        self.state_dict = state_dict
        self.is_coordinator = is_coordinator

    def create_local_plan(self) -> SavePlan:
        plan = create_default_local_save_plan(self.state_dict, self.is_coordinator)
        if self.flatten_state_dict:
            plan = dataclasses.replace(plan, planner_data=self.mappings)
        self.plan = plan

        return self.plan

    def create_global_plan(
        self, all_plans: List[SavePlan]
    ) -> Tuple[List[SavePlan], Metadata]:
        all_plans = dedup_save_plans(all_plans, self.dedup_save_to_lowest_rank)

        global_plan, metadata = create_default_global_save_plan(all_plans)

        if self.flatten_state_dict:
            # Dict union via `|` does not work on Python 3.8 or older versions,
            # so merge the per-rank planner_data with ChainMap instead of:
            # merged_mappings = reduce(
            #     lambda x, y: x | y, (p.planner_data for p in global_plan)
            # )
            planner_data_dict = [p.planner_data for p in global_plan]
            merged_mappings = dict(ChainMap(*planner_data_dict))
            metadata = dataclasses.replace(metadata, planner_data=merged_mappings)

        if not _validate_global_plan(global_plan, metadata):
            raise ValueError("Failed to validate global plan")

        self.global_plan = global_plan
        self.metadata = metadata

        return self.global_plan, self.metadata

    def finish_plan(self, new_plan: SavePlan) -> SavePlan:
        self.plan = new_plan
        return new_plan

    def resolve_data(self, write_item: WriteItem) -> Union[torch.Tensor, io.BytesIO]:
        object = self.lookup_object(write_item.index)
        return self.transform_object(write_item, object)

    def lookup_object(self, index: MetadataIndex) -> Any:
        """Extension from the planner interface to make it easy to extend the default planner."""
        return find_state_dict_object(self.state_dict, index)

    def transform_object(self, write_item: WriteItem, object: Any):
        """Extension from the planner interface to make it easy to extend the default planner."""
        if write_item.type == WriteItemType.BYTE_IO:
            bytes = io.BytesIO()
            torch.save(object, bytes)
            object = bytes
        return object
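

# Example (illustrative sketch only; ``model``, ``optimizer`` and the checkpoint
# path are placeholders, not part of this module): ``DefaultSavePlanner`` can be
# passed to ``torch.distributed.checkpoint.save`` to customize planning, e.g. to
# dedup replicated tensors onto the lowest rank:
#
#     import torch.distributed.checkpoint as dcp
#     from torch.distributed.checkpoint import FileSystemWriter
#
#     state_dict = {"model": model.state_dict(), "optim": optimizer.state_dict()}
#     dcp.save(
#         state_dict,
#         storage_writer=FileSystemWriter("/tmp/checkpoint"),
#         planner=DefaultSavePlanner(dedup_save_to_lowest_rank=True),
#     )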


class DefaultLoadPlanner(LoadPlanner):
    """
    DefaultLoadPlanner that adds multiple features on top of LoadPlanner.

    In particular it adds the following:

    flatten_state_dict: Handle state_dict with nested dicts
    flatten_sharded_tensors: For FSDP in 2D parallel mode
    allow_partial_load: If False, raises a runtime error if a key is present in state_dict but not in the checkpoint.
    """

    original_state_dict: STATE_DICT_TYPE
    mappings: FLATTEN_MAPPING

    def __init__(
        self,
        flatten_state_dict: bool = True,
        flatten_sharded_tensors: bool = True,
        allow_partial_load: bool = False,
    ) -> None:
        self.flatten_state_dict = flatten_state_dict
        self.flatten_sharded_tensors = flatten_sharded_tensors
        self.original_state_dict = {}
        self.mappings = {}
        self.allow_partial_load = allow_partial_load

    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        metadata: Optional[Metadata] = None,
        is_coordinator: bool = False,
    ) -> None:
        _init_state_dict(state_dict)
        self.original_state_dict = state_dict

        if self.flatten_sharded_tensors:
            state_dict = _flatten_sharded_tensors(state_dict)

        if self.flatten_state_dict:
            state_dict, self.mappings = flatten_state_dict(state_dict)

        self.state_dict = state_dict
        self.metadata = metadata
        self.is_coordinator = is_coordinator

    def create_local_plan(self) -> LoadPlan:
        assert self.metadata is not None
        if self.flatten_state_dict:
            # To support checkpoints saved before v2.4, we have to figure out
            # whether any missing keys are due to an old checkpoint format.
            # The contract is:
            # 1. There are 4 cases when we find a missing key:
            #    1.1 Actual missing key with allow_partial_load=False
            #    1.2 Actual missing key with allow_partial_load=True
            #    1.3 Old checkpoint with allow_partial_load=False
            #    1.4 Old checkpoint with allow_partial_load=True
            # 2. If we find a missing key, we first convert the keys back to
            #    the key format of v2.3.
            # 3. If the previously missing keys are among the v2.3 keys, we
            #    assume this is an old checkpoint.
            # 4. Pass the state_dict to `create_default_local_load_plan()`,
            #    which has the logic to check missing keys for
            #    allow_partial_load.
            # So for cases 1.2 and 1.4, we delegate the allow_partial_load
            # check to `create_default_local_load_plan()`. The logic here only
            # determines whether the checkpoint belongs to v2.3 (or before) or
            # v2.4 (or after).
            current_keys = set(self.state_dict.keys())
            load_keys = set(self.metadata.state_dict_metadata.keys())
            missing_keys = load_keys - current_keys
            if missing_keys:
                _version._derived_version = "2_3"
                old_state_dict, old_mappings = flatten_state_dict(
                    self.original_state_dict
                )
                old_keys = set(old_state_dict.keys())
                if old_keys & missing_keys:
                    self.state_dict, self.mappings = old_state_dict, old_mappings
                # _derived_version is only used by flatten_state_dict now.
                # Set it back to None so that later saves use the new version.
                _version._derived_version = None

        return create_default_local_load_plan(
            self.state_dict, self.metadata, not self.allow_partial_load
        )

    def create_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
        return create_default_global_load_plan(global_plan)

    def finish_plan(self, new_plan: LoadPlan) -> LoadPlan:
        return new_plan

    def load_bytes(self, read_item: ReadItem, value: io.BytesIO) -> None:
        if self.flatten_state_dict:
            set_element(
                self.original_state_dict,
                self.mappings[read_item.dest_index.fqn],
                torch.load(value, weights_only=False),
            )
        else:
            self.state_dict[read_item.dest_index.fqn] = torch.load(
                value, weights_only=False
            )

    def resolve_tensor(self, read_item: ReadItem):
        tensor = self.lookup_tensor(read_item.dest_index)
        return self.transform_tensor(read_item, tensor)

    def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None:
        pass

    def lookup_tensor(self, index: MetadataIndex) -> torch.Tensor:
        """Extension from the planner interface to make it easy to extend the default planner."""
        return find_state_dict_object(self.state_dict, index)

    def transform_tensor(self, read_item: ReadItem, tensor: torch.Tensor):
        """Extension from the planner interface to make it easy to extend the default planner."""
        return narrow_tensor_by_index(tensor, read_item.dest_offsets, read_item.lengths)
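

# Example (illustrative sketch only; ``model``, ``optimizer`` and the checkpoint
# path are placeholders): ``DefaultLoadPlanner`` is what
# ``torch.distributed.checkpoint.load`` uses when no planner is supplied; pass
# it explicitly to tweak its behavior, e.g. to tolerate state_dict keys that
# are absent from the checkpoint:
#
#     import torch.distributed.checkpoint as dcp
#     from torch.distributed.checkpoint import FileSystemReader
#
#     state_dict = {"model": model.state_dict(), "optim": optimizer.state_dict()}
#     dcp.load(
#         state_dict,
#         storage_reader=FileSystemReader("/tmp/checkpoint"),
#         planner=DefaultLoadPlanner(allow_partial_load=True),
#     )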


class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):
    """
    Extension of DefaultLoadPlanner, which rebuilds state_dict from the saved metadata.
    Useful for loading a state_dict without first initializing a model, such as
    when converting a DCP checkpoint into a Torch save file.

    N.B. `state_dict` must be an empty dictionary when used with this LoadPlanner.

    .. warning::
        Because the entire state dict is initialized, it's recommended to use
        this LoadPlanner only on a single rank or process to avoid OOM.

    """

    def __init__(self, keys=None, *args, **kwargs):
        self.keys = keys
        super().__init__(*args, **kwargs)

    def _should_include_key(self, key: str, metadata: Metadata) -> bool:
        if self.keys is None:
            return True

        if key in self.keys:
            return True

        unflattened_keys: List[str] = []
        planner_data = metadata.planner_data.get(key)
        for unflattened_key in planner_data:
            if unflattened_keys:
                unflattened_keys.append(
                    ".".join([unflattened_keys[-1], str(unflattened_key)])
                )
            else:
                unflattened_keys.append(unflattened_key)

        if any(unflattened_key in self.keys for unflattened_key in unflattened_keys):
            return True

        return False

    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        metadata: Optional[Metadata] = None,
        is_coordinator: bool = False,
    ) -> None:
        assert not state_dict
        assert metadata is not None

        # Rebuild the state dict from the metadata.
        for k, v in metadata.state_dict_metadata.items():
            if not self._should_include_key(k, metadata):
                continue

            if isinstance(v, TensorStorageMetadata):
                v = torch.empty(v.size, dtype=v.properties.dtype)  # type: ignore[assignment]
            if k in metadata.planner_data:
                set_element(state_dict, metadata.planner_data[k], v)
            else:
                state_dict[k] = v

        super().set_up_planner(state_dict, metadata, is_coordinator)
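

# Example (sketch; the private ``_load_state_dict`` helper and its ``no_dist``
# flag are internal details that may change across versions, and the paths are
# placeholders): this planner is what
# ``torch.distributed.checkpoint.format_utils.dcp_to_torch_save`` relies on to
# materialize a checkpoint without a model, roughly like:
#
#     import torch
#     from torch.distributed.checkpoint import FileSystemReader
#     from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
#
#     state_dict: dict = {}
#     _load_state_dict(
#         state_dict,
#         storage_reader=FileSystemReader("/tmp/checkpoint"),
#         planner=_EmptyStateDictLoadPlanner(),
#         no_dist=True,
#     )
#     torch.save(state_dict, "/tmp/checkpoint.pt")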


def create_default_local_load_plan(
    state_dict: Dict[str, Any], metadata: Metadata, strict: bool = True
) -> LoadPlan:
    """
    Create the ``LoadPlan`` used by DefaultLoadPlanner.

    It produces one read item per value in ``state_dict`` using the metadata in ``metadata``.

    The default behavior is to match keys exactly between state_dict and metadata.
    It handles resharding by issuing multiple read requests against storage in order to match
    load requirements.
    """
    requests = []

    for fqn, obj in state_dict.items():
        # If strict=False, ignore state_dict keys that do not exist in the
        # checkpoint metadata.
        if fqn not in metadata.state_dict_metadata:
            if strict:
                raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.")
            else:
                continue

        md = metadata.state_dict_metadata[fqn]
        # Since DTensor supports submeshes, only call _create_read_items() when
        # the current rank is part of the mesh for the corresponding DTensor.
        if isinstance(obj, DTensor):
            if obj.device_mesh.get_coordinate() is not None:
                requests += _create_read_items(fqn, md, obj)
        else:
            requests += _create_read_items(fqn, md, obj)

    return LoadPlan(requests)


def create_default_global_load_plan(
    all_plans: List[LoadPlan],
) -> List[LoadPlan]:
    """
    Create the global load plan used by DefaultLoadPlanner.

    The default load behavior involves no global coordination, and this function
    currently doesn't change the local plans.
    """
    return all_plans
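

# Example (sketch; ``state_dict``, ``reader`` and the checkpoint path are
# placeholders): the load-plan helpers above can also be driven directly, e.g.
# to build a local plan against previously saved metadata while skipping keys
# that are not present in the checkpoint:
#
#     from torch.distributed.checkpoint import FileSystemReader
#
#     reader = FileSystemReader("/tmp/checkpoint")
#     metadata = reader.read_metadata()
#     local_plan = create_default_local_load_plan(state_dict, metadata, strict=False)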


def create_default_local_save_plan(
    state_dict: Dict[str, Any], is_coordinator: bool
) -> SavePlan:
    """
    Create the ``SavePlan`` used by DefaultSavePlanner.

    A write request is produced for every value in ``state_dict``; for DTensors,
    requests are produced only on ranks that are part of the corresponding device
    mesh. Deduplication of replicated values happens later, during global planning.
    """
    requests = []
    for fqn, obj in state_dict.items():
        # Since DTensor supports submeshes, only call _create_write_items() when
        # the current rank is part of the mesh for the corresponding DTensor.
        if isinstance(obj, DTensor):
            if obj.device_mesh.get_coordinate() is not None:
                requests += _create_write_items(fqn, obj)
        else:
            # For plain tensors and non-tensor values, add the request on all
            # ranks. The coordinator decides whether to deduplicate the values
            # based on the keys (see ``dedup_save_plans``).
            requests += _create_write_items(fqn, obj)

    return SavePlan(requests)


def create_default_global_save_plan(
    all_plans: List[SavePlan],
    rewrite_index_hints: bool = True,
) -> Tuple[List[SavePlan], Metadata]:
    """
    Create the global plan and metadata used by DefaultSavePlanner.

    Metadata is produced by concatenating the metadata of all ``WriteItem`` from the supplied plans.

    The only global planning change is to update index hints in all ``MetadataIndex`` objects if
    ``rewrite_index_hints`` is True.
    """
    md: Dict[str, STORAGE_TYPES] = {}
    new_plans = []
    for plan in all_plans:
        new_items = []
        for item in plan.items:
            if item.type != WriteItemType.SHARD:
                assert item.index.fqn not in md

            if item.type == WriteItemType.BYTE_IO:
                md[item.index.fqn] = BytesStorageMetadata()
                new_items.append(item)
            else:
                assert item.tensor_data is not None
                tensor_md = cast(
                    TensorStorageMetadata,
                    md.setdefault(
                        item.index.fqn,
                        TensorStorageMetadata(
                            properties=item.tensor_data.properties,
                            size=item.tensor_data.size,
                            chunks=[],
                        ),
                    ),
                )
                new_item = item
                if rewrite_index_hints:
                    new_index = dataclasses.replace(
                        item.index, index=len(tensor_md.chunks)
                    )
                    new_item = dataclasses.replace(item, index=new_index)
                new_items.append(new_item)

                assert item.tensor_data.chunk is not None, (
                    f"Cannot create metadata for tensor without bounds. "
                    f"FQN: {item.index.fqn}"
                )
                tensor_md.chunks.append(item.tensor_data.chunk)
        new_plans.append(dataclasses.replace(plan, items=new_items))
    return (new_plans, Metadata(md))
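

# Example (minimal sketch of the coordinator-side save flow; ``all_local_plans``
# stands for the per-rank ``SavePlan`` objects gathered from every rank):
#
#     deduped_plans = dedup_save_plans(all_local_plans, False)
#     global_plans, metadata = create_default_global_save_plan(deduped_plans)
#     assert _validate_global_plan(global_plans, metadata)
#
# This mirrors what ``DefaultSavePlanner.create_global_plan`` does above.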


def _create_default_local_metadata(state_dict: STATE_DICT_TYPE) -> Metadata:
    """Return the ``Metadata`` that DefaultSavePlanner would produce for ``state_dict``."""
    plan = _create_default_metadata_only_plan(state_dict)
    _, md = create_default_global_save_plan([plan])
    return md


def _check_box_overlap(box0: ChunkStorageMetadata, box1: ChunkStorageMetadata) -> bool:
    """Check whether two boxes overlap. Each box is given by its ``offsets`` and ``sizes``."""
    # For each dim, check whether one box lies entirely past the end of the
    # other box along that dim. For 2D boxes, for example, this checks whether
    # one box is fully above or fully to the left of the other.
    ndims = len(box0.offsets)
    for i in range(ndims):
        if box0.offsets[i] >= box1.offsets[i] + box1.sizes[i]:
            return False
        if box1.offsets[i] >= box0.offsets[i] + box0.sizes[i]:
            return False

    return True


def _check_box_bounds(
    outer_box_size: torch.Size, inner_box: ChunkStorageMetadata
) -> bool:
    """Check that ``inner_box`` lies entirely within a tensor of size ``outer_box_size``."""
    for i in range(len(outer_box_size)):
        if inner_box.offsets[i] < 0:
            return False
        if inner_box.sizes[i] < 0:
            return False
        if inner_box.offsets[i] + inner_box.sizes[i] > outer_box_size[i]:
            return False

    return True


def _validate_global_plan(global_plan: List[SavePlan], metadata: Metadata) -> bool:
    all_good = True
    for key, value in metadata.state_dict_metadata.items():
        if isinstance(value, BytesStorageMetadata):
            continue
        if len(value.size) == 0:
            continue
        chunks_volume = 0
        for chunk_idx, chunk0 in enumerate(value.chunks):
            # Check that the chunk is within the tensor bounds and accumulate
            # its volume.
            if not _check_box_bounds(value.size, chunk0):
                logger.warning(
                    "key:%s has out of bounds chunk: tensor-size:%s chunk: %s",
                    key,
                    value.size,
                    chunk0,
                )
                all_good = False
            chunks_volume += reduce(operator.mul, chunk0.sizes, 1)

            # Check for overlap with the remaining chunks.
            for chunk1 in value.chunks[chunk_idx + 1 :]:
                if _check_box_overlap(chunk0, chunk1):
                    logger.warning(
                        "key:%s has overlapping chunks: %s %s", key, chunk0, chunk1
                    )
                    all_good = False

        # Check whether the combined chunks cover the whole tensor.
        tensor_volume = reduce(operator.mul, value.size, 1)
        if chunks_volume != tensor_volume:
            logger.warning(
                "key:%s invalid fill tensor-volume: %s chunks-volume: %s",
                key,
                tensor_volume,
                chunks_volume,
            )
            all_good = False

    return all_good
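

# Example (sketch): two 1-D chunks of a tensor of size 8, covering elements
# [0, 4) and [4, 8). They do not overlap and together they fill the tensor, so
# the checks above accept them:
#
#     c0 = ChunkStorageMetadata(offsets=torch.Size([0]), sizes=torch.Size([4]))
#     c1 = ChunkStorageMetadata(offsets=torch.Size([4]), sizes=torch.Size([4]))
#     assert not _check_box_overlap(c0, c1)
#     assert _check_box_bounds(torch.Size([8]), c0)
#     assert _check_box_bounds(torch.Size([8]), c1)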