# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates

import dataclasses
import io
import logging
import operator
from collections import ChainMap
from functools import reduce
from typing import Any, cast, Dict, List, Optional, Tuple, Union

import torch
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed.checkpoint._dedup_save_plans import dedup_save_plans
from torch.distributed.checkpoint._nested_dict import (
    FLATTEN_MAPPING,
    flatten_state_dict,
)
from torch.distributed.checkpoint._sharded_tensor_utils import _flatten_sharded_tensors
from torch.distributed.checkpoint._traverse import set_element
from torch.distributed.checkpoint.metadata import (
    BytesStorageMetadata,
    ChunkStorageMetadata,
    Metadata,
    MetadataIndex,
    STATE_DICT_TYPE,
    STORAGE_TYPES,
    StorageMeta,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import (
    LoadPlan,
    LoadPlanner,
    ReadItem,
    SavePlan,
    SavePlanner,
    WriteItem,
    WriteItemType,
)
from torch.distributed.checkpoint.planner_helpers import (
    _create_default_metadata_only_plan,
    _create_read_items,
    _create_write_items,
    _init_state_dict,
)
from torch.distributed.checkpoint.utils import find_state_dict_object
from torch.distributed.tensor import DTensor

from . import _version


logger: logging.Logger = logging.getLogger(__name__)


__all__ = [
    "DefaultSavePlanner",
    "DefaultLoadPlanner",
    "create_default_local_load_plan",
    "create_default_global_load_plan",
    "create_default_local_save_plan",
    "create_default_global_save_plan",
]


# TODO: Update docstrings for default_planner.py
class DefaultSavePlanner(SavePlanner):
    mappings: FLATTEN_MAPPING

    def __init__(
        self,
        flatten_state_dict: bool = True,
        flatten_sharded_tensors: bool = True,
        dedup_replicated_tensors: Optional[bool] = None,
        dedup_save_to_lowest_rank: bool = False,
    ) -> None:
        self.flatten_state_dict = flatten_state_dict
        self.flatten_sharded_tensors = flatten_sharded_tensors
        self.mappings = {}
        self.dedup_save_to_lowest_rank = dedup_save_to_lowest_rank
        if dedup_replicated_tensors is not None:
            logger.warning(
                "DefaultSavePlanner's `dedup_replicated_tensors` argument is being "
                "deprecated, and no longer has any effect. Please remove this argument "
                "from your call."
            )
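
    # Example (sketch): how this planner is typically handed to the checkpoint
    # save API. The import alias, checkpoint path, and `model` below are
    # illustrative assumptions, not part of this module.
    #
    #   import torch.distributed.checkpoint as dcp
    #
    #   planner = DefaultSavePlanner(dedup_save_to_lowest_rank=True)
    #   dcp.save(model.state_dict(), checkpoint_id="/tmp/ckpt", planner=planner)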

    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        storage_meta: Optional[StorageMeta] = None,
        is_coordinator: bool = False,
    ) -> None:
        if self.flatten_state_dict:
            state_dict, self.mappings = flatten_state_dict(state_dict)
        if self.flatten_sharded_tensors:
            state_dict = _flatten_sharded_tensors(state_dict)
        self.state_dict = state_dict
        self.is_coordinator = is_coordinator

    def create_local_plan(self) -> SavePlan:
        plan = create_default_local_save_plan(self.state_dict, self.is_coordinator)
        if self.flatten_state_dict:
            plan = dataclasses.replace(plan, planner_data=self.mappings)
        self.plan = plan

        return self.plan

    def create_global_plan(
        self, all_plans: List[SavePlan]
    ) -> Tuple[List[SavePlan], Metadata]:
        all_plans = dedup_save_plans(all_plans, self.dedup_save_to_lowest_rank)

        global_plan, metadata = create_default_global_save_plan(all_plans)

        if self.flatten_state_dict:
            # | does not work for Python 3.8 or older versions.
            # merged_mappings = reduce(
            #     lambda x, y: x | y, (p.planner_data for p in global_plan)
            # )
            planner_data_dict = [p.planner_data for p in global_plan]
            merged_mappings = dict(ChainMap(*planner_data_dict))
            metadata = dataclasses.replace(metadata, planner_data=merged_mappings)

        if not _validate_global_plan(global_plan, metadata):
            raise ValueError("Failed to validate global plan")

        self.global_plan = global_plan
        self.metadata = metadata

        return self.global_plan, self.metadata

    def finish_plan(self, new_plan: SavePlan) -> SavePlan:
        self.plan = new_plan
        return new_plan

    def resolve_data(self, write_item: WriteItem) -> Union[torch.Tensor, io.BytesIO]:
        object = self.lookup_object(write_item.index)
        return self.transform_object(write_item, object)

    def lookup_object(self, index: MetadataIndex) -> Any:
        """Extension from the planner interface to make it easy to extend the default planner."""
        return find_state_dict_object(self.state_dict, index)

    def transform_object(self, write_item: WriteItem, object: Any):
        """Extension from the planner interface to make it easy to extend the default planner."""
        if write_item.type == WriteItemType.BYTE_IO:
            bytes = io.BytesIO()
            torch.save(object, bytes)
            object = bytes
        return object


class DefaultLoadPlanner(LoadPlanner):
    """
    DefaultLoadPlanner that adds multiple features on top of LoadPlanner.

    In particular, it adds the following:

    flatten_state_dict: Handle state_dict with nested dicts
    flatten_sharded_tensors: For FSDP in 2D parallel mode
    allow_partial_load: If False, raises a runtime error if a key is present in state_dict but not in the checkpoint.
    """
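
    # Example (sketch): how this planner is typically handed to the checkpoint
    # load API. The import alias, checkpoint path, and `model` below are
    # illustrative assumptions, not part of this module.
    #
    #   import torch.distributed.checkpoint as dcp
    #
    #   state_dict = model.state_dict()
    #   planner = DefaultLoadPlanner(allow_partial_load=True)
    #   dcp.load(state_dict, checkpoint_id="/tmp/ckpt", planner=planner)
    #   model.load_state_dict(state_dict)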

    original_state_dict: STATE_DICT_TYPE
    mappings: FLATTEN_MAPPING

    def __init__(
        self,
        flatten_state_dict: bool = True,
        flatten_sharded_tensors: bool = True,
        allow_partial_load: bool = False,
    ) -> None:
        self.flatten_state_dict = flatten_state_dict
        self.flatten_sharded_tensors = flatten_sharded_tensors
        self.original_state_dict = {}
        self.mappings = {}
        self.allow_partial_load = allow_partial_load

    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        metadata: Optional[Metadata] = None,
        is_coordinator: bool = False,
    ) -> None:
        _init_state_dict(state_dict)
        self.original_state_dict = state_dict

        if self.flatten_sharded_tensors:
            state_dict = _flatten_sharded_tensors(state_dict)

        if self.flatten_state_dict:
            state_dict, self.mappings = flatten_state_dict(state_dict)

        self.state_dict = state_dict
        self.metadata = metadata
        self.is_coordinator = is_coordinator

    def create_local_plan(self) -> LoadPlan:
        assert self.metadata is not None
        if self.flatten_state_dict:
            # To support checkpoints that were saved before v2.4, we have to
            # determine whether the missing keys are due to an old checkpoint.
            # The contracts are:
            # 1. There are 4 cases when we find a missing key.
            #    1.1 Actually missing key, with allow_partial_load=False
            #    1.2 Actually missing key, with allow_partial_load=True
            #    1.3 Old checkpoint, with allow_partial_load=False
            #    1.4 Old checkpoint, with allow_partial_load=True
            # 2. If we find a missing key, we first convert the keys back to
            #    the key format of v2.3.
            # 3. If the previously missing keys are among the v2.3 keys, we
            #    assume this is an old checkpoint.
            # 4. Pass the state_dict to `create_default_local_load_plan()`,
            #    which has the logic to check missing keys for allow_partial_load.
            # So for cases 1.2 and 1.4, we delegate the allow_partial_load check to
            # `create_default_local_load_plan()`. The logic here is only to determine
            # whether the checkpoint belongs to 2.3 (or before) or 2.4 (or after).
            current_keys = set(self.state_dict.keys())
            load_keys = set(self.metadata.state_dict_metadata.keys())
            missing_keys = load_keys - current_keys
            if missing_keys:
                _version._derived_version = "2_3"
                old_state_dict, old_mappings = flatten_state_dict(
                    self.original_state_dict
                )
                old_keys = set(old_state_dict.keys())
                if old_keys & missing_keys:
                    self.state_dict, self.mappings = old_state_dict, old_mappings
                # _derived_version is only used by flatten_state_dict now.
                # Set it back to None so that later we can save to a new version.
                _version._derived_version = None

        return create_default_local_load_plan(
            self.state_dict, self.metadata, not self.allow_partial_load
        )

    def create_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
        return create_default_global_load_plan(global_plan)

    def finish_plan(self, new_plan: LoadPlan) -> LoadPlan:
        return new_plan

    def load_bytes(self, read_item: ReadItem, value: io.BytesIO) -> None:
        if self.flatten_state_dict:
            set_element(
                self.original_state_dict,
                self.mappings[read_item.dest_index.fqn],
                torch.load(value, weights_only=False),
            )
        else:
            self.state_dict[read_item.dest_index.fqn] = torch.load(
                value, weights_only=False
            )

    def resolve_tensor(self, read_item: ReadItem):
        tensor = self.lookup_tensor(read_item.dest_index)
        return self.transform_tensor(read_item, tensor)

    def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None:
        pass

    def lookup_tensor(self, index: MetadataIndex) -> torch.Tensor:
        """Extension from the planner interface to make it easy to extend the default planner."""
        return find_state_dict_object(self.state_dict, index)

    def transform_tensor(self, read_item: ReadItem, tensor: torch.Tensor):
        """Extension from the planner interface to make it easy to extend the default planner."""
        return narrow_tensor_by_index(tensor, read_item.dest_offsets, read_item.lengths)


class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):
    """
    Extension of DefaultLoadPlanner, which rebuilds state_dict from the saved metadata.
    Useful for loading a state_dict without first initializing a model, such as
    when converting a DCP checkpoint into a Torch save file.

    N.B. ``state_dict`` must be an empty dictionary when used with this LoadPlanner.

    .. warning::
        Because the entire state dict is initialized, it's recommended to only use
        this LoadPlanner on a single rank or process to avoid OOM.

    """
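
    # Example (sketch): rebuilding a full state_dict from a DCP checkpoint
    # without first constructing a model, similar to how the format-conversion
    # utilities use this planner. The reader, path, and use of the private
    # loader below are illustrative assumptions.
    #
    #   from torch.distributed.checkpoint import FileSystemReader
    #   from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
    #
    #   state_dict: STATE_DICT_TYPE = {}
    #   _load_state_dict(
    #       state_dict,
    #       storage_reader=FileSystemReader("/tmp/ckpt"),
    #       planner=_EmptyStateDictLoadPlanner(),
    #       no_dist=True,
    #   )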

    def __init__(self, keys=None, *args, **kwargs):
        self.keys = keys
        super().__init__(*args, **kwargs)

    def _should_include_key(self, key: str, metadata: Metadata) -> bool:
        if self.keys is None:
            return True

        if key in self.keys:
            return True

        unflattened_keys: List[str] = []
        planner_data = metadata.planner_data.get(key)
        for unflattened_key in planner_data:
            if unflattened_keys:
                unflattened_keys.append(
                    ".".join([unflattened_keys[-1], str(unflattened_key)])
                )
            else:
                unflattened_keys.append(unflattened_key)

        if any(unflattened_key in self.keys for unflattened_key in unflattened_keys):
            return True

        return False
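
    # For example (illustrative): with keys={"model"}, a flattened key such as
    # "model.layer1.weight" is included because one of its unflattened prefixes
    # ("model") is in ``keys``.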

    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        metadata: Optional[Metadata] = None,
        is_coordinator: bool = False,
    ) -> None:
        assert not state_dict
        assert metadata is not None

        # rebuild the state dict from the metadata
        for k, v in metadata.state_dict_metadata.items():
            if not self._should_include_key(k, metadata):
                continue

            if isinstance(v, TensorStorageMetadata):
                v = torch.empty(v.size, dtype=v.properties.dtype)  # type: ignore[assignment]
            if k in metadata.planner_data:
                set_element(state_dict, metadata.planner_data[k], v)
            else:
                state_dict[k] = v

        super().set_up_planner(state_dict, metadata, is_coordinator)


def create_default_local_load_plan(
    state_dict: Dict[str, Any], metadata: Metadata, strict: bool = True
) -> LoadPlan:
    """
    Create the ``LoadPlan`` used by DefaultLoadPlanner.

    It produces one read item per value in ``state_dict``, using the metadata in ``metadata``.

    The default behavior is to match keys exactly between ``state_dict`` and ``metadata``.
    It handles resharding by issuing multiple read requests against storage in order to match
    load requirements.
    """
    requests = []

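    # Resharding example (illustrative): if a tensor was saved as two chunks
    # covering rows [0, 4) and [4, 8) but the local tensor to be loaded spans
    # rows [2, 6), _create_read_items() emits one read item per overlapping
    # chunk, each covering only the overlapping region.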
    for fqn, obj in state_dict.items():
        # Ignore keys which do not exist in the checkpoint metadata if strict=False.
        if fqn not in metadata.state_dict_metadata:
            if strict:
                raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.")
            else:
                continue

        md = metadata.state_dict_metadata[fqn]
        # Since DTensor supports submesh, adding extra check to ensure _create_read_items()
        # gets called only when the current rank is part of the mesh for the corresponding DTensor.
        if isinstance(obj, DTensor):
            if obj.device_mesh.get_coordinate() is not None:
                requests += _create_read_items(fqn, md, obj)
        else:
            requests += _create_read_items(fqn, md, obj)

    return LoadPlan(requests)


def create_default_global_load_plan(
    all_plans: List[LoadPlan],
) -> List[LoadPlan]:
    """
    Create the global load plan used by DefaultLoadPlanner.

    The default load behavior involves no global coordination, and this function
    currently doesn't change the local plans.
    """
    return all_plans


def create_default_local_save_plan(
    state_dict: Dict[str, Any], is_coordinator: bool
) -> SavePlan:
    """
    Create the ``SavePlan`` used by DefaultSavePlanner.

    A write item is produced for every value in ``state_dict``. For ``DTensor`` values,
    write items are only created on ranks that are part of the corresponding device mesh;
    plain tensors and non-tensor values produce write items on every rank, and replicated
    values are deduplicated later during global planning.
    """
    requests = []
    for fqn, obj in state_dict.items():
        # Since DTensor supports submesh, adding extra check to ensure _create_write_items()
        # gets called only when the current rank is part of the mesh for the corresponding DTensor.
        if isinstance(obj, DTensor):
            if obj.device_mesh.get_coordinate() is not None:
                requests += _create_write_items(fqn, obj)
        else:
            # For plain tensors and non-tensor values, add the request on all
            # ranks. The coordinator will decide whether to deduplicate the
            # values based on the keys.
            requests += _create_write_items(fqn, obj)

    return SavePlan(requests)


def create_default_global_save_plan(
    all_plans: List[SavePlan],
    rewrite_index_hints: bool = True,
) -> Tuple[List[SavePlan], Metadata]:
    """
    Create the global plan and metadata used by DefaultSavePlanner.

    Metadata is produced by concatenating the metadata of all ``WriteItem`` objects from the supplied plans.

    The only global planning change is to update index hints in all ``MetadataIndex`` objects if
    ``rewrite_index_hints`` is True.
    """
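    # Illustrative example: if two ranks each contribute one SHARD write item
    # for the same fqn, that fqn's TensorStorageMetadata accumulates two chunks,
    # and with rewrite_index_hints=True the items' MetadataIndex.index hints are
    # rewritten to 0 and 1 to match each chunk's position in the chunks list.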
    md: Dict[str, STORAGE_TYPES] = {}
    new_plans = []
    for plan in all_plans:
        new_items = []
        for item in plan.items:
            if item.type != WriteItemType.SHARD:
                assert item.index.fqn not in md

            if item.type == WriteItemType.BYTE_IO:
                md[item.index.fqn] = BytesStorageMetadata()
                new_items.append(item)
            else:
                assert item.tensor_data is not None
                tensor_md = cast(
                    TensorStorageMetadata,
                    md.setdefault(
                        item.index.fqn,
                        TensorStorageMetadata(
                            properties=item.tensor_data.properties,
                            size=item.tensor_data.size,
                            chunks=[],
                        ),
                    ),
                )
                new_item = item
                if rewrite_index_hints:
                    new_index = dataclasses.replace(
                        item.index, index=len(tensor_md.chunks)
                    )
                    new_item = dataclasses.replace(item, index=new_index)
                new_items.append(new_item)

                assert (
                    item.tensor_data.chunk is not None
                ), f"""
                    Cannot create MD for tensor without bounds.
                    FQN: {item.index.fqn}
                """
                tensor_md.chunks.append(item.tensor_data.chunk)
        new_plans.append(dataclasses.replace(plan, items=new_items))
    return (new_plans, Metadata(md))


def _create_default_local_metadata(state_dict: STATE_DICT_TYPE) -> Metadata:
    """Return the ``Metadata`` that would be produced if ``DefaultSavePlanner`` were used to checkpoint ``state_dict``."""
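    # For example (sketch): _create_default_local_metadata({"w": torch.zeros(4, 4)})
    # returns a Metadata whose state_dict_metadata["w"] is a TensorStorageMetadata
    # with size torch.Size([4, 4]) and a single chunk covering the whole tensor.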
    plan = _create_default_metadata_only_plan(state_dict)
    _, md = create_default_global_save_plan([plan])
    return md


def _check_box_overlap(box0: ChunkStorageMetadata, box1: ChunkStorageMetadata) -> bool:
    """Check whether two chunks overlap. Each chunk is described by its offsets and sizes."""
    # For each dim, check whether one chunk lies entirely on one side of the
    # other chunk along that dim. For example, for 2D chunks we check whether
    # one chunk is completely above or completely to the left of the other.
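    # Illustrative example: chunks with offsets (0, 0) / sizes (2, 2) and
    # offsets (1, 1) / sizes (2, 2) overlap, while a chunk at offsets (2, 0)
    # with sizes (2, 2) does not overlap the first one, since it starts exactly
    # where the first one ends along dim 0.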
    ndims = len(box0.offsets)
    for i in range(ndims):
        if box0.offsets[i] >= box1.offsets[i] + box1.sizes[i]:
            return False
        if box1.offsets[i] >= box0.offsets[i] + box0.sizes[i]:
            return False

    return True


def _check_box_bounds(
    outer_box_size: torch.Size, inner_box: ChunkStorageMetadata
) -> bool:
    for i in range(len(outer_box_size)):
        if inner_box.offsets[i] < 0:
            return False
        if inner_box.sizes[i] < 0:
            return False
        if inner_box.offsets[i] + inner_box.sizes[i] > outer_box_size[i]:
            return False

    return True


def _validate_global_plan(global_plan: List[SavePlan], metadata: Metadata) -> bool:
    all_good = True
    for key, value in metadata.state_dict_metadata.items():
        if isinstance(value, BytesStorageMetadata):
            continue
        if len(value.size) == 0:
            continue
        chunks_volume = 0
        for chunk_idx, chunk0 in enumerate(value.chunks):
            # Check that the chunk is within bounds and accumulate its volume.
            if not _check_box_bounds(value.size, chunk0):
                logger.warning(
                    """
                        key:%s has out of bounds chunk:
                        tensor-size:%s chunk: %s
                    """,
                    key,
                    value.size,
                    chunk0,
                )
                all_good = False
            chunks_volume += reduce(operator.mul, chunk0.sizes, 1)

            # Check for overlap with the remaining chunks.
            for chunk1 in value.chunks[chunk_idx + 1 :]:
                if _check_box_overlap(chunk0, chunk1):
                    logger.warning(
                        "key:%s has overlapping chunks: %s %s", key, chunk0, chunk1
                    )
                    all_good = False

        # Check whether the combined chunks cover the whole tensor.
        tensor_volume = reduce(operator.mul, value.size, 1)
        if chunks_volume != tensor_volume:
            logger.warning(
                """
                    key:%s invalid fill tensor-volume:
                    %s chunks-volume: %s
                """,
                key,
                tensor_volume,
                chunks_volume,
            )
            all_good = False

    return all_good