# mypy: allow-untyped-defs
import cProfile
import inspect
import io
import itertools
import os
import warnings
from contextlib import contextmanager
from functools import wraps
from pstats import Stats
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, TypeVar, Union

import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._shard.sharded_tensor.shard import Shard

from .api import (
    _is_wrapped_exception,
    _wrap_exception,
    CheckpointException,
    WRAPPED_EXCEPTION,
)
from .metadata import MetadataIndex, STATE_DICT_TYPE


__all__ = ["find_tensor_shard", "find_state_dict_object"]

T = TypeVar("T")
R = TypeVar("R")


def _get_failure_dict(
    results: List[Union[T, WRAPPED_EXCEPTION]]
) -> Dict[int, WRAPPED_EXCEPTION]:
    return cast(
        Dict[int, WRAPPED_EXCEPTION],
        {i: err for i, err in enumerate(results) if _is_wrapped_exception(err)},
    )


def _all_gather_keys(
    local_dict: Dict[Any, Any], group: Optional[dist.ProcessGroup] = None
) -> List[Any]:
    """Gathers all keys, and returns them sorted."""
    keys = list(local_dict.keys())
    gathered_keys: List[List[Any]] = [None] * dist.get_world_size(group)  # type: ignore[list-item]

    dist.all_gather_object(gathered_keys, keys, group=group)
    return sorted(set(itertools.chain.from_iterable(gathered_keys)))
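# Usage sketch for ``_all_gather_keys`` (assumes torch.distributed is
# initialized; ``local_state`` is a hypothetical per-rank dict):
#
#   local_state = {"shared.weight": 0, f"rank{dist.get_rank()}.bias": 1}
#   all_keys = _all_gather_keys(local_state)
#   # every rank now sees the same sorted union of keys from all ranks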


class _DistWrapper:
    """
    This is a wrapper around a process group (PG) that provides a series of features around object collectives.

    It also works when distributed is not initialized, in which case most collectives turn into no-ops.

    All variants that take functions are exception robust, meaning that if one or more
    ranks raise errors, all ranks will observe those errors.
    """
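    # A minimal usage sketch (assumptions: the default process group may or may
    # not be initialized; ``compute_local_plan`` is a hypothetical callable):
    #
    #   dist_wrapper = _DistWrapper(
    #       group=None,
    #       use_dist=dist.is_available() and dist.is_initialized(),
    #       coordinator_rank=0,
    #   )
    #   plans = dist_wrapper.all_gather("plan", compute_local_plan)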

    def __init__(
        self,
        group: Optional[dist.ProcessGroup],
        use_dist: bool,
        coordinator_rank: int,
    ):
        self.group = group
        self.use_dist = use_dist
        self.coordinator_rank = coordinator_rank
        if self.use_dist:
            self.rank = dist.get_rank(group)
            self.is_coordinator = self.rank == coordinator_rank
        else:
            self.rank = 0
            self.is_coordinator = True

    def get_rank(self) -> int:
        return self.rank

    def get_world_size(self) -> int:
        if self.use_dist:
            return dist.get_world_size(self.group)
        return 1

    def broadcast_object(self, object: Optional[T]) -> T:
        """Implement functionality similar to c10d::broadcast_object_list but it works without distributed being enabled."""
        object_list = [object]
        if self.use_dist:
            dist.broadcast_object_list(
                object_list=object_list,
                group=self.group,
                src=self.coordinator_rank,
            )
        return cast(T, object_list[0])

    def gather_object(self, object: T) -> Optional[List[T]]:
        """Implement functionality similar to c10d::gather_object but it works without distributed being enabled."""
        if self.use_dist:
            gather_objs = (
                cast(List[T], [None] * dist.get_world_size(self.group))
                if self.is_coordinator
                else None
            )

            dist.gather_object(
                obj=object,
                object_gather_list=gather_objs if self.is_coordinator else None,
                dst=self.coordinator_rank,
                group=self.group,
            )
            result = gather_objs
        else:
            result = [object]
        return result

    def all_gather_object(self, object: T) -> List[T]:
        """Implement functionality similar to c10d::all_gather_object but it works without distributed being enabled."""
        if self.use_dist:
            gather_objs = cast(List[T], [None] * dist.get_world_size(self.group))

            dist.all_gather_object(
                object_list=gather_objs, obj=object, group=self.group
            )
        else:
            gather_objs = [object]
        return gather_objs

    def scatter_object(self, object_list: Optional[List[T]]) -> T:
        """Implement functionality similar to c10d::scatter_object but it works without distributed being enabled."""
        if self.use_dist:
            gather_result = cast(List[T], [None])
            dist.scatter_object_list(
                scatter_object_output_list=gather_result,
                scatter_object_input_list=object_list if self.is_coordinator else None,
                src=self.coordinator_rank,
                group=self.group,
            )

            local_reply = gather_result[0]
        else:
            assert object_list is not None
            local_reply = object_list[0]
        return local_reply

    def reduce_scatter(
        self,
        step: str,
        map_fun: Callable[[], T],
        reduce_fun: Callable[[List[T]], List[R]],
    ) -> R:
        """
        Compute a value on each rank, then do a centralized reduce on a single rank, followed by a scatter.

        This method operates in the following way:
            Run ``map_fun`` on all ranks
            Gather results on the coordinator rank
            Call ``reduce_fun`` on all those values
            Scatter the corresponding part of the result to each rank.
        """
        local_data: Union[WRAPPED_EXCEPTION, T]
        try:
            local_data = map_fun()
        except BaseException as e:
            local_data = _wrap_exception(e)

        all_data = self.gather_object(local_data)
        all_results: Optional[List[Union[R, CheckpointException]]] = None
        if self.is_coordinator:
            assert all_data is not None
            node_failures = _get_failure_dict(all_data)

            if len(node_failures) == 0:
                try:
                    # N.B. why can't mypy cast List[R] to List[Union[R, WRAPPED_EXCEPTION]]?
                    all_results = cast(
                        List[Union[R, CheckpointException]],
                        reduce_fun(cast(List[T], all_data)),
                    )
                except BaseException as e:
                    node_failures[self.rank] = _wrap_exception(e)

            if len(node_failures) > 0:
                all_results = [
                    CheckpointException(step, node_failures)
                ] * self.get_world_size()

        result = self.scatter_object(all_results)
        if isinstance(result, CheckpointException):
            raise result
        return result
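    # Sketch of the reduce_scatter pattern (``make_local_plan`` and
    # ``create_global_plan`` are hypothetical callables):
    #
    #   local_plan = dist_wrapper.reduce_scatter(
    #       "plan", make_local_plan, create_global_plan
    #   )
    #   # each rank gets back only its slice of the globally reduced result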

    def all_reduce(
        self,
        step: str,
        map_fun: Callable[[], T],
        reduce_fun: Callable[[List[T]], R],
    ) -> R:
        """
        Compute a value on each rank, then do a centralized reduce on a single rank, followed by a broadcast.

        This method operates in the following way:
            Run ``map_fun`` on all ranks
            Gather results on the coordinator rank
            Call ``reduce_fun`` on all those values
            Broadcast the reduced value to all ranks.
        """
        local_data: Union[T, WRAPPED_EXCEPTION]
        try:
            local_data = map_fun()
        except BaseException as e:
            local_data = _wrap_exception(e)

        all_data = self.gather_object(local_data)
        result: Optional[Union[R, CheckpointException]] = None
        if self.is_coordinator:
            assert all_data is not None
            node_failures = _get_failure_dict(all_data)
            if len(node_failures) == 0:
                try:
                    result = reduce_fun(cast(List[T], all_data))
                except BaseException as e:
                    node_failures[self.rank] = _wrap_exception(e)

            if len(node_failures) > 0:
                result = CheckpointException(step, node_failures)

        final_result = self.broadcast_object(result)
        if isinstance(final_result, CheckpointException):
            raise final_result
        return cast(R, final_result)
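    # Sketch of the all_reduce pattern (``count_local_bytes`` and
    # ``summarize_global_stats`` are hypothetical callables):
    #
    #   global_stats = dist_wrapper.all_reduce(
    #       "stats", count_local_bytes, summarize_global_stats
    #   )
    #   # every rank receives the same reduced value (or the same exception)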

    def all_gather(
        self,
        step: str,
        map_fun: Callable[[], T],
    ) -> List[T]:
        """
        Compute a value on each rank, then all_gather them.

        This method operates in the following way:
            Run ``map_fun`` on all ranks
            all_gather the values to all ranks
        """
        result: Union[T, WRAPPED_EXCEPTION]
        try:
            result = map_fun()
        except BaseException as e:
            result = _wrap_exception(e)

        all_results = self.all_gather_object(result)

        node_failures = _get_failure_dict(all_results)
        if len(node_failures) > 0:
            raise CheckpointException(step, node_failures)
        return cast(List[T], all_results)
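    # Sketch of the all_gather pattern (``collect_local_metadata`` is a
    # hypothetical callable):
    #
    #   all_metadata = dist_wrapper.all_gather("metadata", collect_local_metadata)
    #   # all_metadata has one entry per rank and is identical on every rank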

    def broadcast(
        self,
        step: str,
        map_fun: Callable[[], T],
    ) -> T:
        """
        Compute a value on the coordinator rank and broadcast it.

        This method operates in the following way:
            Run ``map_fun`` on the coordinator rank
            broadcast the value
        """
        result: Optional[Union[T, CheckpointException]] = None
        if self.is_coordinator:
            try:
                result = map_fun()
            except BaseException as e:
                result = CheckpointException(step, {self.rank: _wrap_exception(e)})
        final_result = self.broadcast_object(result)
        if isinstance(final_result, CheckpointException):
            raise final_result
        return cast(T, final_result)
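    # Sketch of the broadcast pattern (``finalize_checkpoint`` is a hypothetical
    # callable that only the coordinator runs):
    #
    #   metadata = dist_wrapper.broadcast("write_metadata", finalize_checkpoint)
    #   # non-coordinator ranks simply receive the value (or the wrapped error)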


def _find_shard(tensor: ShardedTensor, index: MetadataIndex) -> Shard:
    if index.offset is None:
        raise ValueError(
            f"Cannot lookup {index.fqn} since it's a ShardedTensor and no offset was provided"
        )

    shards = tensor.local_shards()
    # index fast path
    if index.index is not None:
        if (
            len(shards) > index.index
            and torch.Size(shards[index.index].metadata.shard_offsets) == index.offset
        ):
            return shards[index.index]

    for shard in shards:
        if torch.Size(shard.metadata.shard_offsets) == index.offset:
            return shard
    raise ValueError(f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'")


def find_tensor_shard(tensor: torch.Tensor, index: MetadataIndex) -> torch.Tensor:
    if hasattr(tensor, "__get_tensor_shard__"):
        # DTensor implements _Checkpointable
        return tensor.__get_tensor_shard__(index)  # type: ignore[attr-defined]
    if isinstance(tensor, ShardedTensor):
        return _find_shard(tensor, index).tensor
    if index.offset is not None:
        # special case looking up a tensor by origin
        if index.offset == torch.Size([0] * len(tensor.size())):
            return tensor
        raise ValueError(
            f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'"
        )
    return tensor
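# Usage sketch (assumes a plain, non-sharded tensor; the all-zero offset selects
# the tensor itself):
#
#   t = torch.zeros(4, 4)
#   idx = MetadataIndex(fqn="w", offset=torch.Size([0, 0]))
#   shard = find_tensor_shard(t, idx)  # returns t unchanged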


def find_state_dict_object(state_dict: STATE_DICT_TYPE, index: MetadataIndex) -> Any:
    if index.fqn not in state_dict:
        raise ValueError(f"Could not find FQN: '{index.fqn}'")
    obj = state_dict[index.fqn]

    if isinstance(obj, torch.Tensor):
        return find_tensor_shard(obj, index)
    elif index.offset is not None:
        raise ValueError(
            f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'"
        )
    return obj
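# Usage sketch (``model`` is a hypothetical nn.Module):
#
#   sd = model.state_dict()
#   obj = find_state_dict_object(sd, MetadataIndex(fqn="layer.weight"))
#   # returns the tensor itself (or the matching local shard for sharded tensors)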


def _element_wise_add(a: Sequence[int], b: Sequence[int]) -> List[int]:
    return [i_a + i_b for i_a, i_b in zip(a, b)]


def _element_wise_sub(a: Sequence[int], b: Sequence[int]) -> List[int]:
    return [i_a - i_b for i_a, i_b in zip(a, b)]
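# For example: _element_wise_add([1, 2], [10, 20]) == [11, 22] and
# _element_wise_sub([10, 20], [1, 2]) == [9, 18].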


class _ReaderView(io.IOBase):
    def __init__(self, base_stream: io.IOBase, offset: int, len: int):
        super().__init__()
        self.offset = offset
        self.len = len
        self.base_stream = base_stream
        self.seek(0)

    def seek(self, __offset: int, __whence: int = os.SEEK_SET) -> int:
        if __whence == os.SEEK_SET:
            __offset = self.offset + __offset
        elif __whence == os.SEEK_END:
            __whence = os.SEEK_SET
            __offset = (self.offset + self.len) - __offset
        return self.base_stream.seek(__offset, __whence)

    def tell(self) -> int:
        return self.base_stream.tell() - self.offset

    def readable(self) -> bool:
        return self.base_stream.readable()

    def seekable(self) -> bool:
        return self.base_stream.seekable()

    def readinto(self, b):
        return self.base_stream.readinto(b)  # type: ignore[attr-defined]

    def read(self, size=-1):
        return self.base_stream.read(size)


def _create_file_view(file: io.IOBase, offset: int, length: int) -> io.IOBase:
    # FIXME (kumpera) torch.load fails if we wrap with io.BufferedReader
    return _ReaderView(file, offset, length)
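# Usage sketch: expose bytes [offset, offset + length) of a checkpoint file as a
# stream whose position 0 is the start of the view. Note that read() and
# readinto() delegate directly to the base stream and do not clamp to the view
# length, so callers are expected to stay within bounds.
#
#   with open("checkpoint.distcp", "rb") as f:  # hypothetical file name
#       view = _create_file_view(f, offset=128, length=1024)
#       payload = view.read(1024)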


def _normalize_device_info(device_type: str, device_id: int) -> str:
    """Device info normalization."""
    if device_type == "cpu":
        return "cpu"
    return f"{device_type}:{device_id}"
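# For example: _normalize_device_info("cuda", 1) == "cuda:1" and
# _normalize_device_info("cpu", 3) == "cpu".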


# TODO: integrate with distributed logging flag
ENABLE_PROFILE = False


@contextmanager
def _profile():
    # Only log the profiling results when profiling is enabled and this is
    # rank 0, or when dist is not available.
    if ENABLE_PROFILE and (not dist.is_available() or dist.get_rank() == 0):
        profiler = cProfile.Profile()
        profiler.enable()
        try:
            yield
        finally:
            profiler.disable()
            stats = Stats(profiler)
            stats.sort_stats("time").print_stats(10)
    else:
        yield
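# Usage sketch (set ENABLE_PROFILE = True to actually collect a profile):
#
#   with _profile():
#       run_checkpoint_save()  # hypothetical workload
#   # the top 10 entries by internal time are printed on rank 0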


def _api_bc_check(func):
    @wraps(func)
    def inner_func(*args, **kwargs) -> Any:
        if len(args) == 2:
            warnings.warn(
                f"The argument order of {func.__name__} has been changed. "
                "Please check the documentation to avoid future breakages."
            )
            sig = inspect.signature(func)
            kwonlyargs = [
                p.name for p in sig.parameters.values() if p.kind == p.KEYWORD_ONLY
            ]
            if "storage_writer" in kwonlyargs:
                assert "storage_writer" not in kwargs, (args, kwargs)
                kwargs["storage_writer"] = args[1]
            elif "storage_reader" in kwonlyargs:
                assert "storage_reader" not in kwargs, (args, kwargs)
                kwargs["storage_reader"] = args[1]
            else:
                raise RuntimeError(f"Unexpected kwonlyargs = {kwonlyargs}")
            return func(args[0], **kwargs)
        else:
            return func(*args, **kwargs)

    return inner_func
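# Usage sketch: keep backward compatibility for a public API whose second
# positional argument became keyword-only (``save_state_dict`` and
# ``storage_writer`` here stand in for the decorated function's real signature):
#
#   @_api_bc_check
#   def save_state_dict(state_dict, *, storage_writer, planner=None): ...
#
#   save_state_dict(sd, writer)  # old call style: warns, then forwards
#                                # writer as storage_writer=writer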