# mypy: allow-untyped-defs
from typing import Any, cast, List

import torch
import torch.distributed as dist
from torch._utils import _get_device_module

from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._tensor import DTensor
from torch.distributed._tensor._utils import compute_local_shape_and_global_offset
from torch.distributed.checkpoint.planner import _Checkpointable

from torch.utils._pytree import tree_map_only_

from .metadata import (
    BytesStorageMetadata,
    ChunkStorageMetadata,
    MetadataIndex,
    STATE_DICT_TYPE,
    STORAGE_TYPES,
    TensorProperties,
    TensorStorageMetadata,
)
from .planner import (
    LoadItemType,
    ReadItem,
    SavePlan,
    TensorWriteData,
    WriteItem,
    WriteItemType,
)
from .resharding import (
    _check_shard_metadata_pair_overlap,
    _shards_get_overlap_region_wrt_saved_tensor,
)

__all__: List[str] = ["create_read_items_for_chunk_list"]


def _create_chunk_from_tensor(tensor: torch.Tensor) -> ChunkStorageMetadata:
    return ChunkStorageMetadata(
        offsets=torch.Size([0] * len(tensor.size())), sizes=tensor.size()
    )


def _chunk_for_shard(shard_md: ShardMetadata) -> ChunkStorageMetadata:
    return ChunkStorageMetadata(
        offsets=torch.Size(shard_md.shard_offsets),
        sizes=torch.Size(shard_md.shard_sizes),
    )


def _sharded_tensor_metadata(
    sharded_tensor: ShardedTensor, shard_md: ShardMetadata
) -> TensorWriteData:
    shard_properties = sharded_tensor.metadata().tensor_properties

    properties = TensorProperties(
        dtype=shard_properties.dtype,
        layout=shard_properties.layout,
        requires_grad=shard_properties.requires_grad,
        memory_format=shard_properties.memory_format,
        pin_memory=shard_properties.pin_memory,
    )

    return TensorWriteData(
        chunk=_chunk_for_shard(shard_md),
        properties=properties,
        size=sharded_tensor.metadata().size,
    )


def _create_write_items_for_dtensor(fqn: str, tensor: DTensor) -> WriteItem:
    sizes, offsets = compute_local_shape_and_global_offset(
        tensor.shape, tensor.device_mesh, tensor.placements
    )
    sizes, offsets = torch.Size(sizes), torch.Size(offsets)

    return WriteItem(
        index=MetadataIndex(fqn, offsets),
        type=WriteItemType.SHARD,
        tensor_data=TensorWriteData(
            chunk=ChunkStorageMetadata(
                offsets=offsets,
                sizes=sizes,
            ),
            properties=TensorProperties.create_from_tensor(tensor.to_local()),
            size=tensor.size(),
        ),
    )


def _create_write_item_for_shard(
    fqn: str, sharded_tensor: ShardedTensor, shard_md: ShardMetadata
) -> WriteItem:
    offsets = torch.Size(shard_md.shard_offsets)
    return WriteItem(
        index=MetadataIndex(fqn, offsets),
        type=WriteItemType.SHARD,
        tensor_data=_sharded_tensor_metadata(sharded_tensor, shard_md),
    )


def _create_write_item_for_tensor(fqn: str, tensor: torch.Tensor) -> WriteItem:
    offsets = torch.Size([0] * len(tensor.size()))
    return WriteItem(
        index=MetadataIndex(fqn, offsets),
        type=WriteItemType.TENSOR,
        tensor_data=TensorWriteData(
            chunk=ChunkStorageMetadata(offsets=offsets, sizes=tensor.size()),
            properties=TensorProperties.create_from_tensor(tensor),
            size=tensor.size(),
        ),
    )


def _create_write_item_for_bytesio(fqn: str, bytes: Any) -> WriteItem:
    return WriteItem(
        index=MetadataIndex(fqn),
        type=WriteItemType.BYTE_IO,
    )


def _create_read_item_for_byteio(
    dest_index, dest_offset, storage_index, storage_offset, length
):
    return ReadItem(
        type=LoadItemType.BYTE_IO,
        dest_index=dest_index,
        dest_offsets=torch.Size((dest_offset,)),
        storage_index=storage_index,
        storage_offsets=torch.Size((storage_offset,)),
        lengths=torch.Size((length,)),
    )


def _create_read_item_for_tensor(
    dest_index, dest_offsets, storage_index, storage_offsets, lengths
):
    return ReadItem(
        type=LoadItemType.TENSOR,
        dest_index=dest_index,
        dest_offsets=torch.Size(dest_offsets),
        storage_index=storage_index,
        storage_offsets=torch.Size(storage_offsets),
        lengths=torch.Size(lengths),
    )


def create_read_items_for_chunk_list(
    fqn: str,
    checkpoint_md: TensorStorageMetadata,
    local_chunks: List[ChunkStorageMetadata],
) -> List[ReadItem]:
    """
    Create a list of ``ReadItem`` based on the checkpoint and local chunks.

    This applies the resharding algorithm and computes the reads needed
    to satisfy ``local_chunks`` with a checkpoint described by ``checkpoint_md``.

    Args:
        fqn (str): The state_dict FQN to pass to ``ReadItem``.
        checkpoint_md (TensorStorageMetadata): Metadata for a given tensor
            from a checkpoint.
        local_chunks (List[ChunkStorageMetadata]): Local chunks that need to be
            loaded.

    Returns:
        A list of ``ReadItem`` that will satisfy all input chunks.
    """
    read_items = []
    # this is a naive quadratic algo that can be optimized later
    for idx, shard in enumerate(local_chunks):
        for storage_idx, storage_md in enumerate(checkpoint_md.chunks):
            if not _check_shard_metadata_pair_overlap(shard, storage_md):
                continue

            storage_offsets = []
            dest_offsets = []
            lengths = []
            for (
                dim,
                offset_for_saved_tensor,
                offset_for_current_tensor,
                length,
            ) in _shards_get_overlap_region_wrt_saved_tensor(
                saved_shard=storage_md, current_shard=shard
            ):
                storage_offsets.append(offset_for_saved_tensor)
                dest_offsets.append(offset_for_current_tensor)
                lengths.append(length)

            read_items.append(
                _create_read_item_for_tensor(
                    dest_index=MetadataIndex(fqn, shard.offsets, idx),
                    dest_offsets=dest_offsets,
                    storage_index=MetadataIndex(fqn, storage_md.offsets, storage_idx),
                    storage_offsets=storage_offsets,
                    lengths=lengths,
                )
            )
    return read_items


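# Usage sketch (illustrative only; the metadata values below are hypothetical
# and the snippet is kept in comments so module behavior is unchanged): given
# a checkpoint chunk layout and a different local chunk layout, the function
# computes the reads needed to reassemble the local chunks.
#
#   checkpoint_md = TensorStorageMetadata(
#       properties=TensorProperties(dtype=torch.float32),
#       size=torch.Size([8]),
#       chunks=[
#           ChunkStorageMetadata(offsets=torch.Size([0]), sizes=torch.Size([4])),
#           ChunkStorageMetadata(offsets=torch.Size([4]), sizes=torch.Size([4])),
#       ],
#   )
#   local_chunks = [
#       ChunkStorageMetadata(offsets=torch.Size([2]), sizes=torch.Size([4]))
#   ]
#   read_items = create_read_items_for_chunk_list("weight", checkpoint_md, local_chunks)
#   # Two ReadItems: global range [2, 4) is read from the first saved chunk and
#   # [4, 6) from the second, each with matching dest/storage offsets and lengths.

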
def _create_default_metadata_only_plan(state_dict: STATE_DICT_TYPE) -> SavePlan:
    requests = []
    for fqn, obj in state_dict.items():
        if isinstance(obj, DTensor):
            requests.append(_create_write_items_for_dtensor(fqn, obj))
        elif isinstance(obj, ShardedTensor):
            for shard_md in obj.metadata().shards_metadata:
                requests.append(_create_write_item_for_shard(fqn, obj, shard_md))
        elif isinstance(obj, torch.Tensor):
            requests.append(_create_write_item_for_tensor(fqn, obj))
        else:
            requests.append(_create_write_item_for_bytesio(fqn, obj))
    return SavePlan(requests)


def _create_write_items(fqn: str, object: Any) -> List[WriteItem]:
    if isinstance(object, _Checkpointable):
        return object._create_write_items(fqn, object)
    elif isinstance(object, DTensor):
        # DTensor can contain a local tensor that is a tensor subclass
        if isinstance(object.to_local(), _Checkpointable):
            return object.to_local()._create_write_items(fqn, object)  # type: ignore[arg-type]
        return [_create_write_items_for_dtensor(fqn, object)]
    elif isinstance(object, ShardedTensor):
        return [
            _create_write_item_for_shard(fqn, object, shard.metadata)
            for shard in object.local_shards()
        ]
    elif isinstance(object, torch.Tensor):
        return [_create_write_item_for_tensor(fqn, object)]
    else:
        return [_create_write_item_for_bytesio(fqn, object)]


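# Illustrative sketch (hypothetical tensor and fqn; kept in comments): for a
# plain tensor, the dispatch above yields a single TENSOR write item whose
# chunk covers the whole tensor.
#
#   items = _create_write_items("weight", torch.zeros(4, 4))
#   # items[0].type == WriteItemType.TENSOR
#   # items[0].tensor_data.chunk.offsets == torch.Size([0, 0])
#   # items[0].tensor_data.chunk.sizes == torch.Size([4, 4])

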
def _create_chunk_from_dtensor(tensor: DTensor) -> ChunkStorageMetadata:
    sizes, offsets = compute_local_shape_and_global_offset(
        tensor.shape, tensor.device_mesh, tensor.placements
    )
    sizes, offsets = torch.Size(sizes), torch.Size(offsets)
    return ChunkStorageMetadata(
        offsets=offsets,
        sizes=sizes,
    )


def _create_chunk_list(tensor: torch.Tensor) -> List[ChunkStorageMetadata]:
    if isinstance(tensor, _Checkpointable):
        local_chunks = tensor._create_chunk_list(tensor)
    elif isinstance(tensor, DTensor):
        # DTensor can contain a local tensor that is a tensor subclass
        if isinstance(tensor.to_local(), _Checkpointable):
            return tensor.to_local()._create_chunk_list(tensor)  # type: ignore[arg-type]
        local_chunks = [_create_chunk_from_dtensor(tensor)]
    elif isinstance(tensor, ShardedTensor):
        local_chunks = [
            _chunk_for_shard(shard.metadata) for shard in tensor.local_shards()
        ]
    elif isinstance(tensor, torch.Tensor):
        local_chunks = [_create_chunk_from_tensor(tensor)]
    else:
        raise ValueError(
            "Unsupported Type, expecting one of [Tensor, DTensor, ShardedTensor], "
            f"but got {type(tensor)}"
        )

    return local_chunks


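# Illustrative sketch (hypothetical tensor; kept in comments): a plain tensor
# yields a single chunk spanning the whole tensor, which is what the read
# planner reshards against the checkpoint chunks.
#
#   _create_chunk_list(torch.zeros(2, 3))
#   # -> [ChunkStorageMetadata(offsets=torch.Size([0, 0]), sizes=torch.Size([2, 3]))]

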
def _create_read_items(fqn: str, md: STORAGE_TYPES, obj: Any) -> List[ReadItem]:
    if not isinstance(md, BytesStorageMetadata):
        try:
            local_chunks = _create_chunk_list(obj)
        except ValueError as ex:
            raise ValueError(
                f"Invalid checkpoint metadata for {fqn}, "
                + f"expected BytesStorageMetadata but found {type(md)}",
            ) from ex

        return create_read_items_for_chunk_list(fqn, md, local_chunks)
    else:
        return [
            _create_read_item_for_byteio(
                dest_index=MetadataIndex(fqn),
                dest_offset=0,
                storage_index=MetadataIndex(fqn),
                storage_offset=0,
                length=0,
            )
        ]


def _init_state_dict(state_dict: STATE_DICT_TYPE) -> None:
    tree_map_only_(torch.Tensor, _init_meta_tensor, state_dict)


def _init_meta_tensor(value: Any) -> Any:
    """
    Initialize the tensor and move it to a real device if the given
    torch.Tensor/DTensor is on the meta device.
    """

    device = getattr(value, "device", None)
    # DCP does the initialization if it is a meta tensor/DTensor.
    if device == torch.device("meta"):
        device_type = dist.distributed_c10d._get_pg_default_device().type
        device = cast(torch.device, _get_device_module(device_type).current_device())
        if isinstance(value, DTensor):
            new_local_tensor = torch.empty_like(value.to_local(), device=device)
            # We need to pass shape and stride explicitly, since DTensor might be
            # sharded unevenly.
            dtensor = DTensor.from_local(
                new_local_tensor,
                device_mesh=value.device_mesh,
                placements=value.placements,
                shape=value.size(),
                stride=value.stride(),
            )
            return dtensor
        elif isinstance(value, torch.Tensor):
            tensor = torch.empty_like(value, device=device)
            return tensor
        else:
            raise RuntimeError(
                f"Found unsupported type {type(value)} for meta device loading."
            )
    else:
        return value