import abc
import io
import operator
from dataclasses import dataclass
from enum import auto, Enum
from functools import reduce
from typing import Any, List, Optional, Tuple, Union

import torch
from torch.distributed.checkpoint.metadata import (
    ChunkStorageMetadata,
    Metadata,
    MetadataIndex,
    STATE_DICT_TYPE,
    StorageMeta,
    TensorProperties,
)


__all__ = [
    "WriteItemType",
    "LoadItemType",
    "TensorWriteData",
    "WriteItem",
    "ReadItem",
    "SavePlan",
    "LoadPlan",
    "SavePlanner",
    "LoadPlanner",
]


class WriteItemType(Enum):
    TENSOR = auto()
    SHARD = auto()
    BYTE_IO = auto()


class LoadItemType(Enum):
    TENSOR = auto()
    BYTE_IO = auto()


@dataclass(frozen=True)
class TensorWriteData:
    chunk: ChunkStorageMetadata
    properties: TensorProperties
    size: torch.Size


@dataclass(frozen=True)
class WriteItem:
    """Dataclass which holds information about what needs to be written to storage."""

    index: MetadataIndex
    type: WriteItemType

    # Value present if it's a tensor write
    tensor_data: Optional[TensorWriteData] = None

    def tensor_storage_size(self) -> Optional[int]:
        """
        Calculates the storage size of the underlying tensor, or None if this is not a tensor write.

        Returns:
            Optional[int] storage size, in bytes, of the underlying tensor if any.
        """
        if self.tensor_data is None:
            return None

        numels = reduce(operator.mul, self.tensor_data.size, 1)
        dtype_size = torch._utils._element_size(self.tensor_data.properties.dtype)
        return numels * dtype_size


@dataclass(frozen=True)
class ReadItem:
    # Read Item
    type: LoadItemType

    # Index into the state_dict
    dest_index: MetadataIndex
    # Offsets into destination tensor
    dest_offsets: torch.Size

    # Index into the checkpoint
    storage_index: MetadataIndex
    # Offset into the checkpoint data
    storage_offsets: torch.Size

    # Size of the hypercube to copy
    lengths: torch.Size


@dataclass(frozen=True)
class SavePlan:
    items: List[WriteItem]
    storage_data: Any = None
    planner_data: Any = None


@dataclass
class LoadPlan:
    items: List[ReadItem]
    storage_data: Any = None
    planner_data: Any = None


class SavePlanner(abc.ABC):
    """
    Abstract class defining the protocol used by save_state_dict to plan the save process.

    SavePlanners are stateful objects that can be used to customize the whole save process.

    SavePlanner acts as an access proxy to the state_dict, so any transformation done to it
    will be visible to the whole process.

    A planner subclass can expect the following sequence of calls during save_state_dict:

    1) set_up_planner - called on all ranks.
        Signals the start of a checkpoint save.

    2) create_local_plan - called on all ranks.
        Processes the state_dict and produces a `SavePlan` that will be sent for global planning.

    3) create_global_plan - called on the coordinator rank only.
        Takes the SavePlan from all ranks and makes any global decisions.

    4) finish_plan - called on all ranks.
        This gives each rank a chance to adjust to global planning decisions.

    5) resolve_data - called multiple times on each rank.
        Looks up a value in the ``state_dict`` for the storage layer to write.

    Users are recommended to extend DefaultSavePlanner instead of this interface directly as
    most changes can be expressed by changes in a single method.

    There are 3 usual patterns of extension:

    Rewriting state_dict. This is the simplest way to extend the save process as it
    doesn't require understanding the intricacies of how SavePlan works:

    >>> # xdoctest: +SKIP("undefined vars")
    >>> class RenamePlanner(DefaultSavePlanner):
    >>>     def set_up_planner(
    >>>         self,
    >>>         state_dict: STATE_DICT_TYPE,
    >>>         storage_meta: Optional[StorageMeta],
    >>>         is_coordinator: bool,
    >>>     ) -> None:
    >>>         # prefix all keys with ``foo_``
    >>>         super().set_up_planner({"foo_" + k: v for k, v in state_dict.items()}, storage_meta, is_coordinator)

    Modifying the local plan and the lookup in tandem. This is useful when fine control over how data is persisted is needed:

    >>> # xdoctest: +SKIP("undefined vars")
    >>> class FP16Planner(DefaultSavePlanner):
    >>>     def create_local_plan(self):
    >>>         plan = super().create_local_plan()
    >>>         for p in plan.items:
    >>>             if p.tensor_data is not None:
    >>>                 p.tensor_data.properties.dtype = torch.float16
    >>>         return plan
    >>>
    >>>     def resolve_data(self, write_item):
    >>>         item = super().resolve_data(write_item)
    >>>         return item if write_item.type == WriteItemType.BYTE_IO else item.to(torch.float16)

    Using the global planning step to make central decisions that can't be made individually by each rank:

    >>> # xdoctest: +SKIP("undefined vars")
    >>> from itertools import zip_longest
    >>> from dataclasses import replace
    >>> class DDPLoadBalancingPlanner(DefaultSavePlanner):
    >>>     # This uses the default local plan behavior of having all non-sharded writes in rank 0
    >>>     # This sample doesn't handle ShardedTensors
    >>>     def create_global_plan(self, all_plans):
    >>>         iters = [iter(all_plans[0].items)] * len(all_plans)
    >>>         items_per_rank = [
    >>>             [item for item in items if item is not None]
    >>>             for items in zip(*zip_longest(*iters), strict=True)
    >>>         ]
    >>>         all_plans = [
    >>>             replace(plan, items=items)
    >>>             for plan, items in zip(all_plans, items_per_rank, strict=True)
    >>>         ]
    >>>         return super().create_global_plan(all_plans)

    Finally, some planners need to save additional metadata in the checkpoint. This is
    accomplished by having each rank contribute its data items in the local plan and
    having the global planner aggregate them:

    >>> # xdoctest: +SKIP("undefined vars")
    >>> class SaveExtraDataPlanner(DefaultSavePlanner):
    >>>     def create_local_plan(self) -> SavePlan:
    >>>         plan = super().create_local_plan()
    >>>         return replace(plan, planner_data="per-rank-data")
    >>>
    >>>     def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]:
    >>>         global_plan, metadata = super().create_global_plan(all_plans)
    >>>         merged_data = [p.planner_data for p in global_plan]
    >>>         metadata = replace(metadata, planner_data=merged_data)
    >>>         return global_plan, metadata
    """

    @abc.abstractmethod
    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        storage_meta: Optional[StorageMeta] = None,
        is_coordinator: bool = False,
    ) -> None:
        """
        Initialize this planner to save ``state_dict``.

        Implementations should save those values as they won't be provided later in the save process.

        This is called on all ranks.
        """

    @abc.abstractmethod
    def create_local_plan(self) -> SavePlan:
        """
        Compute the save plan for the current rank.

        This will be aggregated and passed to create_global_plan.
        Planner-specific data can be passed through SavePlan::planner_data.

        This is called on all ranks.
        """

    @abc.abstractmethod
    def create_global_plan(
        self, all_plans: List[SavePlan]
    ) -> Tuple[List[SavePlan], Metadata]:
        """
        Compute the global checkpoint plan and return the local plan of each rank.

        This is called on the coordinator rank only.
        """

    @abc.abstractmethod
    def finish_plan(self, new_plan: SavePlan) -> SavePlan:
        """
        Merge the plan created by `create_local_plan` and the result of `create_global_plan`.

        This is called on all ranks.
        """

    @abc.abstractmethod
    def resolve_data(self, write_item: WriteItem) -> Union[torch.Tensor, io.BytesIO]:
        """
        Transform and prepare ``write_item`` from ``state_dict`` for storage, ensuring idempotency and thread-safety.

        Look up the object associated with ``write_item`` in ``state_dict`` and apply any
        transformation (such as serialization) prior to the storage layer consuming it.

        Called on each rank multiple times, at least once per WriteItem in the final SavePlan.

        This method should be idempotent and thread-safe. StorageWriter implementations
        are free to call it as frequently as they need.

        Any transformation that allocates memory should be done lazily when this method
        is called in order to reduce the peak memory required by checkpointing.

        Returned tensors can be on any device and in any format; they can also be views.
        It is the storage layer's responsibility to figure out how to save them.
        """


class LoadPlanner(abc.ABC):
    """
    Abstract class defining the protocol used by load_state_dict to plan the load process.

    LoadPlanners are stateful objects that can be used to customize the whole load process.

    LoadPlanner acts as an access proxy to the state_dict, so any transformation done to it
    will be visible to the whole process.

    A planner subclass can expect the following sequence of calls during load_state_dict:

    1) set_up_planner - called on all ranks.
        Signals the start of loading a checkpoint.

    2) create_local_plan - called on all ranks.
        Processes the state_dict and produces a `LoadPlan` that will be sent for global planning.

    3) create_global_plan - called on the coordinator rank only.
        Takes the LoadPlan from all ranks and makes any global decisions.

    4) load_bytes - called multiple times on each rank.
        This is called once per non-tensor value in state_dict.

    5) resolve_tensor and commit_tensor - called multiple times on each rank.
        They are called in pairs for each Tensor value in state_dict.

    Users are recommended to extend DefaultLoadPlanner instead of this interface directly as
    most changes can be expressed by changes in a single method.

    There are two usual patterns of extension:

    Rewriting state_dict. This is the simplest way to extend the load process as it
    doesn't require understanding the intricacies of how LoadPlan works.
    We need to keep
    a reference to the original state_dict, because loading happens in place and the loaded
    values must be written back into it:

    >>> # xdoctest: +SKIP("undefined vars")
    >>> class RenamePlanner(DefaultLoadPlanner):
    >>>     def set_up_planner(
    >>>         self,
    >>>         state_dict: STATE_DICT_TYPE,
    >>>         metadata: Metadata,
    >>>         is_coordinator: bool,
    >>>     ) -> None:
    >>>         self.original_state_dict = state_dict
    >>>         state_dict = {"foo_" + k: v for k, v in state_dict.items()}
    >>>
    >>>         if self.flatten_sharded_tensors:
    >>>             state_dict = _flatten_sharded_tensors(state_dict)
    >>>
    >>>         if self.flatten_state_dict:
    >>>             state_dict, self.mappings = flatten_state_dict(state_dict)
    >>>
    >>>         self.state_dict = state_dict
    >>>         self.metadata = metadata
    >>>         self.is_coordinator = is_coordinator
    >>>
    >>>     def load_bytes(self, read_item, value):
    >>>         # Remove the "foo_" prefix
    >>>         self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value, weights_only=False)

    Modifying resolve_tensor and commit_tensor to handle load-time transformation:

    >>> # xdoctest: +SKIP("undefined vars")
    >>> class MetaModelMaterialize(DefaultLoadPlanner):
    >>>     def resolve_tensor(self, read_item):
    >>>         tensor = super().resolve_tensor(read_item)
    >>>         return torch.empty_like(tensor, device="cpu")
    >>>
    >>>     def commit_tensor(self, read_item, tensor):
    >>>         self.state_dict[read_item.dest_index.fqn] = tensor
    """

    @abc.abstractmethod
    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        metadata: Optional[Metadata] = None,
        is_coordinator: bool = False,
    ) -> None:
        """
        Initialize this instance to load data into ``state_dict``.

        N.B. This is called on every rank.
        """

    @abc.abstractmethod
    def create_local_plan(self) -> LoadPlan:
        """
        Create a LoadPlan based on the state_dict and metadata provided by set_up_planner.

        N.B. This is called on every rank.
        """

    @abc.abstractmethod
    def create_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
        """
        Compute the global load plan and return plans for each rank.

        N.B. This is called on the coordinator rank only.
        """

    @abc.abstractmethod
    def finish_plan(self, central_plan: LoadPlan) -> LoadPlan:
        """Accept the plan from the coordinator and return the final LoadPlan."""

    @abc.abstractmethod
    def load_bytes(self, read_item: ReadItem, value: io.BytesIO) -> None:
        """
        Load the item described by ``read_item`` and ``value``.

        This method is expected to modify the underlying state_dict in place.

        The contents of ``value`` are defined by the SavePlanner used to produce
        the checkpoint being loaded.
        """

    def resolve_bytes(self, read_item: ReadItem) -> io.BytesIO:
        """
        Return the BytesIO to be used by the StorageReader to load `read_item`.

        The BytesIO should alias one in the underlying state_dict, as the StorageReader will replace its contents.
        """
        raise NotImplementedError("LoadPlanner.resolve_bytes is not implemented")

    @abc.abstractmethod
    def resolve_tensor(self, read_item: ReadItem) -> torch.Tensor:
        """
        Return the tensor described by ``read_item`` to be used by the StorageReader to load `read_item`.

        The tensor should alias one in the underlying state_dict, as the StorageReader will replace its contents.
        If, for any reason, that's not possible, the planner can use the ``commit_tensor`` method to copy the data
        back to the one in the state_dict.
        """

    @abc.abstractmethod
    def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None:
        """
        Called once the StorageReader has finished loading data into ``tensor``.

        The provided tensor is the same one returned by the call to ``resolve_tensor``.
        This method is only needed if this LoadPlanner needs to post-process ``tensor`` prior to
        copying it back to the one in the state_dict.

        The contents of ``tensor`` will follow its device synchronization model.
        """
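

# Illustrative sketch (not part of this module's public surface): how
# ``WriteItem.tensor_storage_size`` derives the byte size of a tensor write.
# The constructor calls below rely on the dataclasses imported from
# ``torch.distributed.checkpoint.metadata`` at the top of this file; the fqn
# "model.weight" and the 4x8 shape are made-up example values.
#
# >>> # xdoctest: +SKIP("illustrative example")
# >>> item = WriteItem(
# >>>     index=MetadataIndex(fqn="model.weight"),
# >>>     type=WriteItemType.TENSOR,
# >>>     tensor_data=TensorWriteData(
# >>>         chunk=ChunkStorageMetadata(offsets=torch.Size([0, 0]), sizes=torch.Size([4, 8])),
# >>>         properties=TensorProperties(dtype=torch.float32),
# >>>         size=torch.Size([4, 8]),
# >>>     ),
# >>> )
# >>> item.tensor_storage_size()  # 4 * 8 elements * 4 bytes per float32 element == 128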
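
# Illustrative sketch of where these planners plug in (assumptions: a
# single-process run, a writable /tmp/ckpt directory, and the RenamePlanner
# example subclasses from the SavePlanner and LoadPlanner docstrings above):
# custom planners are passed to ``torch.distributed.checkpoint.save``/``load``
# via their ``planner`` argument, which drives the call sequences documented above.
#
# >>> # xdoctest: +SKIP("illustrative example")
# >>> import torch.distributed.checkpoint as dcp
# >>> state_dict = {"weight": torch.rand(4, 8)}
# >>> dcp.save(
# >>>     state_dict,
# >>>     storage_writer=dcp.FileSystemWriter("/tmp/ckpt"),
# >>>     planner=RenamePlanner(),  # the save-side RenamePlanner example above
# >>> )
# >>> dcp.load(
# >>>     state_dict,
# >>>     storage_reader=dcp.FileSystemReader("/tmp/ckpt"),
# >>>     planner=RenamePlanner(),  # the load-side RenamePlanner example above
# >>> )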