import abc
import io
import operator
from dataclasses import dataclass
from enum import auto, Enum
from functools import reduce
from typing import Any, List, Optional, Tuple, Union

import torch
from torch.distributed.checkpoint.metadata import (
    ChunkStorageMetadata,
    Metadata,
    MetadataIndex,
    STATE_DICT_TYPE,
    StorageMeta,
    TensorProperties,
)


__all__ = [
    "WriteItemType",
    "LoadItemType",
    "TensorWriteData",
    "WriteItem",
    "ReadItem",
    "SavePlan",
    "LoadPlan",
    "SavePlanner",
    "LoadPlanner",
]


class WriteItemType(Enum):
    # Write of a regular, unsharded tensor
    TENSOR = auto()
    # Write of a single shard of a sharded tensor
    SHARD = auto()
    # Write of raw bytes (a non-tensor object serialized to a byte stream)
    BYTE_IO = auto()


class LoadItemType(Enum):
    # Read of a tensor value
    TENSOR = auto()
    # Read of raw bytes (a non-tensor object)
    BYTE_IO = auto()


@dataclass(frozen=True)
class TensorWriteData:
    # Chunk (offsets and sizes) of the overall tensor covered by this write
    chunk: ChunkStorageMetadata
    # Properties (dtype, etc.) of the tensor being written
    properties: TensorProperties
    # Size of the overall tensor
    size: torch.Size


@dataclass(frozen=True)
class WriteItem:
    """Dataclass which holds information about what needs to be written to storage."""

    index: MetadataIndex
    type: WriteItemType

    # Value present if it's a tensor write
    tensor_data: Optional[TensorWriteData] = None

    def tensor_storage_size(self) -> Optional[int]:
        """
        Calculates the storage size of the underlying tensor, or None if this is not a tensor write.

        Returns:
            Optional[int] storage size, in bytes, of the underlying tensor if any.
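
        For example, a ``torch.float32`` tensor of shape ``(4, 4)`` has a storage
        size of ``4 * 4 * 4 = 64`` bytes (16 elements, 4 bytes each).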
        """
        if self.tensor_data is None:
            return None

        numels = reduce(operator.mul, self.tensor_data.size, 1)
        dtype_size = torch._utils._element_size(self.tensor_data.properties.dtype)
        return numels * dtype_size


@dataclass(frozen=True)
class ReadItem:
    """Dataclass which holds information about what needs to be read from storage into the state_dict."""

    type: LoadItemType

    # Index into the state_dict
    dest_index: MetadataIndex
    # Offsets into destination tensor
    dest_offsets: torch.Size

    # Index into the checkpoint
    storage_index: MetadataIndex
    # Offsets into the checkpoint data
    storage_offsets: torch.Size

    # Size of the hypercube to copy
    lengths: torch.Size
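
    # Illustrative example (hypothetical values): a ReadItem with
    # storage_offsets=(0, 0), dest_offsets=(4, 0) and lengths=(2, 4) copies a
    # 2x4 block from the start of the checkpoint chunk into the destination
    # tensor starting at row 4.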


@dataclass(frozen=True)
class SavePlan:
    items: List[WriteItem]
    # Opaque data used by the storage layer (see StorageWriter)
    storage_data: Any = None
    # Opaque, planner-specific data (see SavePlanner.create_local_plan)
    planner_data: Any = None


@dataclass
class LoadPlan:
    items: List[ReadItem]
    # Opaque data used by the storage layer (see StorageReader)
    storage_data: Any = None
    # Opaque, planner-specific data
    planner_data: Any = None


class SavePlanner(abc.ABC):
    """
    Abstract class defining the protocol used by save_state_dict to plan the save process.

    SavePlanners are stateful objects that can be used to customize the whole save process.

    SavePlanner acts as an access proxy to the state_dict, so any transformation done to it
    will be visible to the whole process.

    A planner subclass can expect the following sequence of calls during save_state_dict:

    1) set_up_planner - called on all ranks.
        Signals the start of a checkpoint save.

    2) create_local_plan - called on all ranks.
        Processes the state_dict and produces a `SavePlan` that will be sent for global planning.

    3) create_global_plan - called on the coordinator rank only.
        Takes the SavePlans from all ranks and makes any global decisions.

    4) finish_plan - called on all ranks.
        This gives each rank a chance to adjust to global planning decisions.

    5) resolve_data - called multiple times on each rank.
        Looks up a value in the `state_dict` for the storage layer to write.

    Users are recommended to extend DefaultSavePlanner instead of this interface directly as
    most changes can be expressed by changes in a single method.

    There are three usual patterns of extension:

    Rewriting state_dict. This is the simplest way to extend the save process as it
    doesn't require understanding the intricacies of how SavePlan works:

    >>> # xdoctest: +SKIP("undefined vars")
    >>> class RenamePlanner(DefaultSavePlanner):
    >>>     def set_up_planner(
    >>>         self,
    >>>         state_dict: STATE_DICT_TYPE,
    >>>         storage_meta: Optional[StorageMeta],
    >>>         is_coordinator: bool,
    >>>     ) -> None:
    >>>         # prefix all keys with ``foo_``
    >>>         super().set_up_planner({"foo_" + k: v for k, v in state_dict.items()}, storage_meta, is_coordinator)

    Modifying the local plan and lookup in tandem. This is useful when fine control over how data is persisted is needed:

    >>> # xdoctest: +SKIP("undefined vars")
    >>> class FP16Planner(DefaultSavePlanner):
    >>>     def create_local_plan(self):
    >>>         plan = super().create_local_plan()
    >>>         for p in plan.items:
    >>>             if p.tensor_data is not None:
    >>>                 p.tensor_data.properties.dtype = torch.float16
    >>>         return plan
    >>>
    >>>     def resolve_data(self, write_item):
    >>>         item = super().resolve_data(write_item)
    >>>         return item if write_item.type == WriteItemType.BYTE_IO else item.to(torch.float16)

    Using the global planning step to make central decisions that can't be made individually by each rank:

    >>> # xdoctest: +SKIP("undefined vars")
    >>> from itertools import zip_longest
    >>> from dataclasses import replace
    >>> class DDPLoadBalancingPlanner(DefaultSavePlanner):
    >>>     # This uses the default local plan behavior of having all non-sharded writes in rank 0
    >>>     # This sample doesn't handle ShardedTensors
    >>>     def create_global_plan(self, all_plans):
    >>>         iters = [iter(all_plans[0].items)] * len(all_plans)
    >>>         items_per_rank = [
    >>>             [item for item in items if item is not None]
    >>>             for items in zip(*zip_longest(*iters), strict=True)
    >>>         ]
    >>>         all_plans = [
    >>>             replace(plan, items=items)
    >>>             for plan, items in zip(all_plans, items_per_rank, strict=True)
    >>>         ]
    >>>         return super().create_global_plan(all_plans)

    Finally, some planners need to save additional metadata in the checkpoint; this is
    accomplished by having each rank contribute its data items in the local plan and
    having the global planner aggregate them:

    >>> # xdoctest: +SKIP("undefined vars")
    >>> class SaveExtraDataPlanner(DefaultSavePlanner):
    >>>     def create_local_plan(self) -> SavePlan:
    >>>         plan = super().create_local_plan()
    >>>         return replace(plan, planner_data="per-rank-data")
    >>>
    >>>     def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]:
    >>>         global_plan, metadata = super().create_global_plan(all_plans)
    >>>         merged_data = [p.planner_data for p in global_plan]
    >>>         metadata = replace(metadata, planner_data=merged_data)
    >>>         return global_plan, metadata
    """

    @abc.abstractmethod
    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        storage_meta: Optional[StorageMeta] = None,
        is_coordinator: bool = False,
    ) -> None:
        """
        Initialize this planner to save ``state_dict``.

        Implementations should save these values as they won't be provided later in the save process.

        This is called on all ranks.
        """

    @abc.abstractmethod
    def create_local_plan(self) -> SavePlan:
        """
        Compute the save plan for the current rank.

        This will be aggregated and passed to create_global_plan.
        Planner-specific data can be passed through SavePlan::planner_data.

        This is called on all ranks.
        """

    @abc.abstractmethod
    def create_global_plan(
        self, all_plans: List[SavePlan]
    ) -> Tuple[List[SavePlan], Metadata]:
        """
        Compute the global checkpoint plan and return the local plan of each rank.

        This is called on the coordinator rank only.
        """

    @abc.abstractmethod
    def finish_plan(self, new_plan: SavePlan) -> SavePlan:
        """
        Merge the plan created by `create_local_plan` and the result of `create_global_plan`.

        This is called on all ranks.
        """

    @abc.abstractmethod
    def resolve_data(self, write_item: WriteItem) -> Union[torch.Tensor, io.BytesIO]:
        """
        Transform and prepare ``write_item`` from ``state_dict`` for storage, ensuring idempotency and thread-safety.

        Look up the object associated with ``write_item`` in ``state_dict`` and apply any
        transformation (such as serialization) prior to the storage layer consuming it.

        Called on each rank multiple times, at least once per WriteItem in the final SavePlan.

        This method should be idempotent and thread-safe. StorageWriter implementations
        are free to call it as frequently as they need.

        Any transformation that allocates memory should be lazily done when this method
        is called in order to reduce peak memory required by checkpointing.

        When returning tensors, they can be on any device and in any format; they can also be views.
        It is the storage layer's responsibility to figure out how to save them.
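
        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> # A minimal sketch, assuming the planner saved ``self.state_dict`` in
            >>> # ``set_up_planner`` and uses a hypothetical ``lookup`` helper. It defers
            >>> # serialization until the storage layer actually asks for the data.
            >>> def resolve_data(self, write_item):
            >>>     obj = lookup(self.state_dict, write_item.index)
            >>>     if write_item.type == WriteItemType.BYTE_IO:
            >>>         bytes_io = io.BytesIO()
            >>>         torch.save(obj, bytes_io)  # allocate the buffer lazily, on request
            >>>         return bytes_io
            >>>     return obj  # tensors may be on any device, in any format, or views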
        """


class LoadPlanner(abc.ABC):
    """
    Abstract class defining the protocol used by load_state_dict to plan the load process.

    LoadPlanners are stateful objects that can be used to customize the whole load process.

    LoadPlanner acts as an access proxy to the state_dict, so any transformation done to it
    will be visible to the whole process.

    A planner subclass can expect the following sequence of calls during load_state_dict:

    1) set_up_planner - called on all ranks.
        Signals the start of loading a checkpoint.

    2) create_local_plan - called on all ranks.
        Processes the state_dict and produces a `LoadPlan` that will be sent for global planning.

    3) create_global_plan - called on the coordinator rank only.
        Takes the LoadPlans from all ranks and makes any global decisions.

    4) load_bytes - called multiple times on each rank.
        This is called once per non-tensor value in state_dict.

    5) resolve_tensor and commit_tensor - called multiple times on each rank.
        They are called in pairs, once for each tensor value in state_dict.

    Users are recommended to extend DefaultLoadPlanner instead of this interface directly as
    most changes can be expressed by changes in a single method.

    There are two usual patterns of extension:

    Rewriting state_dict. This is the simplest way to extend the load process as it
    doesn't require understanding the intricacies of how LoadPlan works. Since loading
    happens in place, we need to keep a reference to the original state_dict so we can
    write the loaded values back into it:

    >>> # xdoctest: +SKIP("undefined vars")
    >>> class RenamePlanner(DefaultLoadPlanner):
    >>>     def set_up_planner(
    >>>         self,
    >>>         state_dict: STATE_DICT_TYPE,
    >>>         metadata: Metadata,
    >>>         is_coordinator: bool,
    >>>     ) -> None:
    >>>         self.original_state_dict = state_dict
    >>>         state_dict = {"foo_" + k: v for k, v in state_dict.items()}
    >>>
    >>>         if self.flatten_sharded_tensors:
    >>>             state_dict = _flatten_sharded_tensors(state_dict)
    >>>
    >>>         if self.flatten_state_dict:
    >>>             state_dict, self.mappings = flatten_state_dict(state_dict)
    >>>
    >>>         self.state_dict = state_dict
    >>>         self.metadata = metadata
    >>>         self.is_coordinator = is_coordinator
    >>>
    >>>     def load_bytes(self, read_item, value):
    >>>         # Remove the "foo_" prefix
    >>>         self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value, weights_only=False)


    Modifying resolve_tensor and commit_tensor to handle load-time transformations:

    >>> # xdoctest: +SKIP("undefined vars")
    >>> class MetaModelMaterialize(DefaultLoadPlanner):
    >>>     def resolve_tensor(self, read_item):
    >>>         tensor = super().resolve_tensor(read_item)
    >>>         return torch.empty_like(tensor, device="cpu")
    >>>
    >>>     def commit_tensor(self, read_item, tensor):
    >>>         self.state_dict[read_item.dest_index.fqn] = tensor
    """

    @abc.abstractmethod
    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        metadata: Optional[Metadata] = None,
        is_coordinator: bool = False,
    ) -> None:
        """
        Initialize this instance to load data into ``state_dict``.

        N.B. This is called on every rank.
        """

    @abc.abstractmethod
    def create_local_plan(self) -> LoadPlan:
        """
        Create a LoadPlan based on state_dict and metadata provided by set_up_planner.

        N.B. This is called on every rank.
        """

    @abc.abstractmethod
    def create_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
        """
        Compute the global load plan and return plans for each rank.

        N.B. This is called on the coordinator rank only.
        """

    @abc.abstractmethod
    def finish_plan(self, central_plan: LoadPlan) -> LoadPlan:
        """Accept the plan from the coordinator and return the final LoadPlan."""

    @abc.abstractmethod
    def load_bytes(self, read_item: ReadItem, value: io.BytesIO) -> None:
        """
        Load the item described by ``read_item`` and ``value``.

        This method is expected to modify in-place the underlying state_dict.

        The contents of ``value`` are defined by the SavePlanner used to produce
        the checkpoint being loaded.
        """

    def resolve_bytes(self, read_item: ReadItem) -> io.BytesIO:
        """
        Return the BytesIO to be used by the StorageReader to load ``read_item``.

        The BytesIO should alias with one in the underlying state_dict as StorageReader will replace its contents.
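
        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> # A minimal sketch, assuming the planner keeps ``self.state_dict`` and
            >>> # that the value stored there for this item is an ``io.BytesIO``.
            >>> def resolve_bytes(self, read_item):
            >>>     return self.state_dict[read_item.dest_index.fqn]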
        """
        raise NotImplementedError("LoadPlanner.resolve_bytes is not implemented")

    @abc.abstractmethod
    def resolve_tensor(self, read_item: ReadItem) -> torch.Tensor:
        """
        Return the tensor described by ``read_item``, to be used by the StorageReader to load ``read_item``.

        The tensor should alias with one in the underlying state_dict as StorageReader will replace its contents.
        If, for any reason, that's not possible, the planner can use the ``commit_tensor`` method to copy the data
        back to the one in state_dict.
        """

    @abc.abstractmethod
    def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None:
        """
        Called once the StorageReader has finished loading data into ``tensor``.

        The provided tensor is the same one returned by the call to ``resolve_tensor``.
        This method is only needed if this LoadPlanner needs to post-process ``tensor`` prior to
        copying it back to the one in the state_dict.

        The contents of ``tensor`` will follow its device synchronization model.
        """