"""Structural interface for objects that can be checkpointed and restored."""

from typing import Any, Dict, runtime_checkable, TypeVar

try:
    from typing_extensions import Protocol
except ImportError:
    # ``typing.Protocol`` was added in Python 3.8 together with
    # ``runtime_checkable`` (imported from ``typing`` above), so it is a safe
    # fallback when ``typing_extensions`` is not installed.
    from typing import Protocol


__all__ = ["Stateful", "StatefulT"]


@runtime_checkable
class Stateful(Protocol):
    """
    Stateful protocol for objects that can be checkpointed and restored.

    The protocol is structural: any object that provides both
    ``state_dict()`` and ``load_state_dict()`` satisfies it, without having
    to inherit from this class.

    .. note::
        Because the protocol is ``runtime_checkable``, it can be used with
        ``isinstance()``. Such a check only verifies that the two methods
        exist; it does not validate their signatures or return types.
    """

    def state_dict(self) -> Dict[str, Any]:
        """
        Objects should return their state_dict representation as a dictionary.
        The output of this function will be checkpointed, and later restored in
        `load_state_dict()`.

        .. warning::
            Because of the inplace nature of restoring a checkpoint, this function
            is also called during `torch.distributed.checkpoint.load`.


        Returns:
            Dict: The object's state dict
        """

        ...

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """
        Restore the object's state from the provided state_dict.

        Args:
            state_dict: The state dict to restore from
        """

        ...


# Type variable for generic code that accepts and returns the same concrete
# ``Stateful`` implementation.
StatefulT = TypeVar("StatefulT", bound=Stateful)