# mypy: allow-untyped-defs
import os
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Sequence, Union

import torch
from torch.distributed.checkpoint.stateful import StatefulT


__all__ = [
    "ChunkStorageMetadata",
    "TensorStorageMetadata",
    "BytesStorageMetadata",
    "Metadata",
    "MetadataIndex",
    "TensorProperties",
    "StorageMeta",
]


@dataclass
class ChunkStorageMetadata:
    """
    Each chunk is expected to have the same properties as the TensorStorageMetadata
    that includes it.
    """

    offsets: torch.Size
    sizes: torch.Size

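# Illustrative sketch (not part of the original file): two hypothetical
# ChunkStorageMetadata entries describing an (8, 4) tensor split row-wise
# into two (4, 4) shards.
#
#     >>> chunk0 = ChunkStorageMetadata(offsets=torch.Size([0, 0]), sizes=torch.Size([4, 4]))
#     >>> chunk1 = ChunkStorageMetadata(offsets=torch.Size([4, 0]), sizes=torch.Size([4, 4]))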

class _MEM_FORMAT_ENCODING(Enum):
    """Describe the memory format of a tensor."""

    TORCH_CONTIGUOUS_FORMAT = 0
    TORCH_CHANNELS_LAST = 1
    TORCH_PRESERVE_FORMAT = 2


@dataclass
class TensorProperties:
    """Properties used to create :class:`Tensor`"""

    # Regular tensor fields
    dtype: torch.dtype = field(default_factory=torch.get_default_dtype)
    # This field is deprecated.
    layout: torch.layout = field(default=torch.strided)
    # This field is deprecated.
    requires_grad: bool = False
    # This field is deprecated.
    memory_format: torch.memory_format = field(default=torch.contiguous_format)
    # This field is deprecated.
    pin_memory: bool = False

    def __getstate__(self):
        # torch.memory_format cannot be pickled, so encode it as an enum for serialization.
        memory_format = self.memory_format
        if memory_format == torch.contiguous_format:
            mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT
        elif memory_format == torch.channels_last:
            mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST
        elif memory_format == torch.preserve_format:
            mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT
        else:
            raise RuntimeError(f"Invalid torch.memory_format: {memory_format}")

        return (
            self.dtype,
            self.layout,
            self.requires_grad,
            mem_format_encoding,
            self.pin_memory,
        )

    def __setstate__(
        self,
        state,
    ):
        (
            self.dtype,
            self.layout,
            self.requires_grad,
            mem_format_encoding,
            self.pin_memory,
        ) = state

        if mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT:
            memory_format = torch.contiguous_format
        elif mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST:
            memory_format = torch.channels_last
        elif mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT:
            memory_format = torch.preserve_format
        else:
            raise RuntimeError(
                f"Invalid torch.memory_format encoding: {mem_format_encoding}"
            )

        self.memory_format = memory_format

    @staticmethod
    def create_from_tensor(tensor: torch.Tensor) -> "TensorProperties":
        # Note: memory_format is always recorded as contiguous_format here;
        # the source tensor's actual memory format is not captured.
        return TensorProperties(
            dtype=tensor.dtype,
            layout=tensor.layout,
            requires_grad=tensor.requires_grad,
            memory_format=torch.contiguous_format,
            pin_memory=tensor.is_pinned(),
        )


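# Illustrative sketch (not part of the original file): __getstate__/__setstate__
# above let TensorProperties survive pickling even though torch.memory_format
# itself cannot be pickled; the round trip below restores memory_format.
#
#     >>> import pickle
#     >>> props = TensorProperties.create_from_tensor(torch.zeros(2, 3))
#     >>> restored = pickle.loads(pickle.dumps(props))
#     >>> restored.memory_format == torch.contiguous_format
#     True

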
@dataclass
class TensorStorageMetadata:
    properties: TensorProperties
    size: torch.Size
    chunks: List[ChunkStorageMetadata]

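# Illustrative sketch (not part of the original file): the hypothetical chunk0
# and chunk1 shards from the ChunkStorageMetadata example above combine into a
# single TensorStorageMetadata entry for the full (8, 4) tensor.
#
#     >>> full = TensorStorageMetadata(
#     ...     properties=TensorProperties(dtype=torch.float32),
#     ...     size=torch.Size([8, 4]),
#     ...     chunks=[chunk0, chunk1],
#     ... )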

@dataclass
class BytesStorageMetadata:
    pass


STORAGE_TYPES = Union[TensorStorageMetadata, BytesStorageMetadata]
STATE_DICT_TYPE = Dict[str, Union[StatefulT, Any]]


@dataclass
class StorageMeta:
    checkpoint_id: Union[str, os.PathLike, None] = None
    save_id: Optional[str] = None
    load_id: Optional[str] = None


@dataclass
class Metadata:
    """This class represents the metadata of the checkpoint."""

    # Keys are the same as those of the `state_dict` it was generated from.
    state_dict_metadata: Dict[str, STORAGE_TYPES]
    # It is the responsibility of the planner and storage plugins to ensure
    # backward compatibility of planner_data and storage_data. DCP will
    # also ensure backward compatibility of the metadata in this file and
    # of the metadata of the built-in planner and storage plugins.
    planner_data: Any = None
    storage_data: Any = None
    storage_meta: Optional[StorageMeta] = None


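# Illustrative sketch (not part of the original file): a minimal Metadata
# object mapping the hypothetical fully qualified name "model.weight" to the
# TensorStorageMetadata example above; planner_data and storage_data keep
# their None defaults.
#
#     >>> meta = Metadata(state_dict_metadata={"model.weight": full})

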
@dataclass(frozen=True)
class MetadataIndex:
    """This class represents a lookup key for items in a state dict or Metadata."""

    fqn: str
    """Fully Qualified Name of the object"""

    offset: Optional[torch.Size] = None
    """If the object is a tensor, offset into the tensor we're looking for"""

    index: Optional[int] = field(hash=False, compare=False, default=None)
    """
    Index hint used to speed up lookups of tensor chunks (optional).

    A common representation of a sharded tensor is a list of chunks, so
    finding an entry in such a list requires a linear search.

    When constructing a MetadataIndex that points into such a list, one can
    provide the index as a hint; it is probed before falling back to the
    linear search, making lookups significantly faster.
    """

    def __init__(
        self,
        fqn: str,
        offset: Optional[Sequence[int]] = None,
        index: Optional[int] = None,
    ):
        # We must use object.__setattr__ due to frozen=True
        object.__setattr__(self, "fqn", fqn)
        object.__setattr__(self, "index", index)
        if offset is not None:
            object.__setattr__(self, "offset", torch.Size(offset))
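
# Illustrative sketch (not part of the original file): a MetadataIndex pointing
# at the second shard of the hypothetical "model.weight" tensor above; index=1
# is a hint so the chunk list is probed at position 1 before a linear search.
#
#     >>> idx = MetadataIndex("model.weight", offset=[4, 0], index=1)
#     >>> idx.offset
#     torch.Size([4, 0])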