# xref: /aosp_15_r20/external/pytorch/torch/utils/_thunk.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
from typing import Callable, Generic, Optional, TypeVar


# Type variable for the value produced by a Thunk's deferred callable.
R = TypeVar("R")


class Thunk(Generic[R]):
    """
    A simple lazy evaluation implementation that lets you delay
    execution of a function.  It properly handles releasing the
    function once it is forced.

    The first successful call to :meth:`force` runs the wrapped callable,
    caches its result in ``r`` and sets ``f`` to ``None``, so the thunk no
    longer holds a reference to the callable (or anything it closes over).
    Later calls return the cached result, even when that result is ``None``.

    If the callable raises, the exception propagates, ``f`` is kept, and a
    later :meth:`force` call invokes the callable again.

    NOTE: ``force`` takes no lock; concurrent callers may each run the
    callable before ``f`` is cleared.
    """

    # The pending callable; ``None`` once the thunk has been forced.
    f: Optional[Callable[[], R]]
    # The cached result; only meaningful after ``f`` has been cleared.
    r: Optional[R]

    # Only these two attributes exist on instances (no per-instance __dict__).
    __slots__ = ["f", "r"]

    def __init__(self, f: Callable[[], R]):
        """Wrap ``f`` without calling it."""
        self.f = f
        self.r = None

    def force(self) -> R:
        """
        Return the wrapped callable's result, calling it first if needed.

        Returns:
            The value returned by the wrapped callable.  The callable is
            invoked only until it first returns; afterwards the cached
            value is returned.

        Raises:
            Whatever the wrapped callable raises.  In that case the thunk
            stays unforced and the next call retries the callable.
        """
        if self.f is None:
            # Already forced.  ``f`` being None (not ``r``) is the "done"
            # flag, so a cached ``None`` result is still returned as-is.
            return self.r  # type: ignore[return-value]
        # Store the result before clearing ``f``: if the call raises, ``f``
        # is left in place so the thunk can be forced again.
        self.r = self.f()
        self.f = None
        return self.r