# mypy: allow-untyped-defs
import os
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Sequence, Union

import torch
from torch.distributed.checkpoint.stateful import StatefulT


__all__ = [
    "ChunkStorageMetadata",
    "TensorStorageMetadata",
    "BytesStorageMetadata",
    "Metadata",
    "MetadataIndex",
    "TensorProperties",
    "StorageMeta",
]


@dataclass
class ChunkStorageMetadata:
    """
    Each chunk is expected to have the same properties of the TensorStorageMetadata
    that includes it.

    ``offsets`` is the position of this chunk within the overall tensor and
    ``sizes`` is the extent of the chunk along each dimension.
    """

    offsets: torch.Size
    sizes: torch.Size


class _MEM_FORMAT_ENCODING(Enum):
    """Describe the memory format of a tensor.

    Picklable stand-in for ``torch.memory_format`` values, which cannot be
    pickled directly (see ``TensorProperties.__getstate__``).
    """

    TORCH_CONTIGUOUS_FORMAT = 0
    TORCH_CHANNELS_LAST = 1
    TORCH_PRESERVE_FORMAT = 2


@dataclass
class TensorProperties:
    """Properties used to create :class:`Tensor`"""

    # Regular tensor fields
    dtype: torch.dtype = field(default_factory=torch.get_default_dtype)
    # This field is deprecated.
    layout: torch.layout = field(default=torch.strided)
    # This field is deprecated.
    requires_grad: bool = False
    # This field is deprecated.
    memory_format: torch.memory_format = field(default=torch.contiguous_format)
    # This field is deprecated.
    pin_memory: bool = False

    def __getstate__(self):
        """Return a picklable 5-tuple of this object's state.

        ``torch.memory_format`` cannot be pickled, so it is translated to the
        picklable ``_MEM_FORMAT_ENCODING`` enum.  The tuple layout
        ``(dtype, layout, requires_grad, mem_format_encoding, pin_memory)``
        is a serialized format — ``__setstate__`` unpacks it in the same order.

        Raises:
            RuntimeError: if ``memory_format`` is not one of the three
                formats known to ``_MEM_FORMAT_ENCODING``.
        """
        # Since torch.memory_format cannot be pickled!
        memory_format = self.memory_format
        if memory_format == torch.contiguous_format:
            mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT
        elif memory_format == torch.channels_last:
            mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST
        elif memory_format == torch.preserve_format:
            mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT
        else:
            raise RuntimeError(f"Invalid torch.memory_format: {memory_format}")

        return (
            self.dtype,
            self.layout,
            self.requires_grad,
            mem_format_encoding,
            self.pin_memory,
        )

    def __setstate__(
        self,
        state,
    ):
        """Restore state produced by ``__getstate__``.

        Unpacks the 5-tuple in the same order ``__getstate__`` emitted it and
        decodes ``_MEM_FORMAT_ENCODING`` back into a ``torch.memory_format``.

        Raises:
            RuntimeError: if the encoded memory format is not a known
                ``_MEM_FORMAT_ENCODING`` member.
        """
        (
            self.dtype,
            self.layout,
            self.requires_grad,
            mem_format_encoding,
            self.pin_memory,
        ) = state

        if mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT:
            memory_format = torch.contiguous_format
        elif mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST:
            memory_format = torch.channels_last
        elif mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT:
            memory_format = torch.preserve_format
        else:
            raise RuntimeError(
                f"Invalid torch.memory_format encoding: {mem_format_encoding}"
            )

        self.memory_format = memory_format

    @staticmethod
    def create_from_tensor(tensor: torch.Tensor) -> "TensorProperties":
        """Build a :class:`TensorProperties` describing ``tensor``.

        Note: ``memory_format`` is always recorded as
        ``torch.contiguous_format`` rather than read from the tensor
        (the field is deprecated — see the field comments above).
        """
        return TensorProperties(
            dtype=tensor.dtype,
            layout=tensor.layout,
            requires_grad=tensor.requires_grad,
            memory_format=torch.contiguous_format,
            pin_memory=tensor.is_pinned(),
        )


@dataclass
class TensorStorageMetadata:
    """Metadata for a stored tensor: its creation properties, its overall
    size, and the list of chunks it is split into."""

    properties: TensorProperties
    size: torch.Size
    chunks: List[ChunkStorageMetadata]


@dataclass
class BytesStorageMetadata:
    """Marker metadata for a state-dict entry stored as opaque bytes;
    carries no fields."""

    pass


# A state-dict entry is stored either as a tensor or as opaque bytes.
STORAGE_TYPES = Union[TensorStorageMetadata, BytesStorageMetadata]
STATE_DICT_TYPE = Dict[str, Union[StatefulT, Any]]


@dataclass
class StorageMeta:
    """Identifiers for a checkpoint's storage."""

    # Location of the checkpoint (path or URL-like id); may be None.
    checkpoint_id: Union[str, os.PathLike, None] = None
    # Id of the save run that produced the checkpoint, if known.
    save_id: Optional[str] = None
    # Id of the load run, if known.
    load_id: Optional[str] = None


@dataclass
class Metadata:
    """This class represents the metadata of the checkpoint."""

    # Keys are the same from the `state_dict` used.
    state_dict_metadata: Dict[str, STORAGE_TYPES]
    # It is the responsibility of the planner and storage plugins to ensure
    # backward compatibility of the planner_data and storage_data. DCP will
    # also ensure the backward compatibility of the metadata in this file and
    # the metadata of the built-in planner and storage plugins.
    planner_data: Any = None
    storage_data: Any = None
    storage_meta: Optional[StorageMeta] = None


@dataclass(frozen=True)
class MetadataIndex:
    """This class represents a lookup key for items in a state dict or Metadata."""

    fqn: str
    """Fully Qualified Name of the object"""

    offset: Optional[torch.Size] = None
    """If the object is a tensor, offset into the tensor we're looking for"""

    index: Optional[int] = field(hash=False, compare=False, default=None)
    """
    Index hint when searching for tensor chunk to speedup lookups (optional)

    A common representation of a sharded tensor is as a list of chunks so to
    find the index in such a list you need to linear search it.

    When constructing an instance of MetadataIndex that points to that list,
    one can provide the index as a hint and it will be probed first before
    the linear search and thus making it significantly faster.
    """

    def __init__(
        self,
        fqn: str,
        offset: Optional[Sequence[int]] = None,
        index: Optional[int] = None,
    ):
        """Create a lookup key.

        ``offset`` may be any integer sequence; it is normalized to
        ``torch.Size``.  When ``offset`` is None the instance attribute is
        left unset and attribute lookup falls back to the class-level
        default (None), which keeps equality/hash behavior consistent with
        the dataclass-declared default.
        """
        # We must use object.__setattr__ due to frozen=True
        object.__setattr__(self, "fqn", fqn)
        object.__setattr__(self, "index", index)
        if offset is not None:
            object.__setattr__(self, "offset", torch.Size(offset))