xref: /aosp_15_r20/external/pytorch/torch/distributed/checkpoint/stateful.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from typing import Any, Dict, runtime_checkable, TypeVar
2from typing_extensions import Protocol
3
4
# Public names exported by this module: the ``Stateful`` protocol and the
# ``StatefulT`` type variable bounded by it.
__all__ = [
    "Stateful",
    "StatefulT",
]
6
7
@runtime_checkable
class Stateful(Protocol):
    """
    Structural protocol for objects whose state can be saved to and restored
    from a checkpoint.

    Any object that defines both ``state_dict`` and ``load_state_dict``
    conforms; explicit inheritance is not required. Because the protocol is
    decorated with ``runtime_checkable``, ``isinstance(obj, Stateful)`` works
    at runtime. That check only verifies that both methods are present; it
    does not validate their signatures or return types.
    """

    def state_dict(self) -> Dict[str, Any]:
        """
        Return the object's current state as a dictionary.

        The returned dictionary is what gets written to the checkpoint. The
        same structure is later handed back to `load_state_dict()` when the
        checkpoint is restored.

        .. warning::
            Restoring a checkpoint happens in place, so this method is also
            invoked during `torch.distributed.checkpoint.load`, not only
            when saving.

        Returns:
            Dict: A dictionary describing the object's state.
        """
        ...

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """
        Restore the object's state from a previously produced state dict.

        Args:
            state_dict: The state dict to restore from, in the form returned
                by `state_dict()`.
        """
        ...
40
41
# Type variable upper-bounded by ``Stateful``. Annotating with it (instead of
# ``Stateful`` directly) lets a signature preserve the caller's concrete type.
StatefulT = TypeVar(
    "StatefulT",
    bound=Stateful,
)
43