xref: /aosp_15_r20/external/pytorch/torch/_dynamo/types.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import dataclasses
2import sys
3import types
4from typing import Any, Callable, Dict, List, NamedTuple, Optional, Protocol, Union
5
6# CacheEntry has a `check_fn` field for the guard, and a `code` field for the code object.
7from torch._C._dynamo.eval_frame import (
8    _CacheEntry as CacheEntry,
9    _ExtraState as ExtraState,
10)
11from torch._guards import CompileId
12
13
14if sys.version_info >= (3, 11):
15    from torch._C._dynamo.eval_frame import _PyInterpreterFrame as DynamoFrameType
16else:
17    from types import FrameType as DynamoFrameType
18
19
# Per-frame scratch space: an unconstrained mapping (any key, any value)
# used to carry additional data alongside a single frame.
FrameState = Dict[Any, Any]
22
23
class GuardFail(NamedTuple):
    """Record describing one guard that failed.

    Attributes:
        reason: string rendering of the piece of guard code that was
            evaluated and failed.
        orig_code: the code object whose guard failed.
    """

    reason: str
    orig_code: types.CodeType
29
30
class GuardFn(Protocol):
    """Structural interface for a guard-check function.

    Implementations are callable on the user function's locals (see
    ``__call__``) and also expose the attributes annotated below.
    """

    # name -> object mapping; presumably the closure variables the guard
    # code refers to (TODO confirm against the guard builder).
    closure_vars: Dict[str, object]
    # List of names; presumably the argument names of the guard function
    # (verify against the guard builder).
    args: List[str]
    # The guard code, split into string pieces.
    code_parts: List[str]
    # Verbose string rendering of the guard code pieces (presumably for
    # diagnostics; confirm).
    verbose_code_parts: List[str]
    # name -> object mapping; presumably the globals the guard code sees.
    global_scope: Dict[str, object]
    # Optional callback that is handed a GuardFail record.
    guard_fail_fn: Optional[Callable[[GuardFail], None]]
    # Optional handles typed by the C eval_frame extension.
    cache_entry: Optional[CacheEntry]
    extra_state: Optional[ExtraState]

    def __call__(self, f_locals: Dict[str, object]) -> bool:
        """Evaluate the guard against ``f_locals`` (the user function's
        locals) and return a bool."""
44
45
@dataclasses.dataclass()
class GuardedCode:
    """A code object bundled with its guard function and compile id."""

    # The code object this entry carries.
    code: types.CodeType
    # Guard function associated with ``code`` (presumably consulted before
    # ``code`` is reused; confirm at the call site).
    check_fn: GuardFn
    # Identifier of the compilation that produced this entry.
    compile_id: CompileId
51
52
class DynamoCallbackFn(Protocol):
    """Callable shape for the Dynamo frame callback."""

    def __call__(
        self,
        frame: DynamoFrameType,
        cache_entry: Optional[CacheEntry],
        frame_state: FrameState,
    ) -> Optional[GuardedCode]:
        """Handle ``frame``, given its optional ``cache_entry`` and the
        per-frame ``frame_state`` dict.

        Returns a GuardedCode, or None. NOTE(review): what a None result
        means is decided by the caller of this callback, not here.
        """
61
62
# A callback slot holds a DynamoCallbackFn, None, or a bool.
# NOTE(review): the meaning of None / True / False is defined by the code
# that consumes this alias, not by this module.
DynamoCallback = Union[
    DynamoCallbackFn,
    None,
    bool,
]
64
65
class DynamoGuardHook(Protocol):
    """Callable shape for a hook that observes guard evaluation."""

    def __call__(
        self,
        guard_fn: GuardFn,
        code: types.CodeType,
        f_locals: Dict[str, object],
        index: int,
        last: bool,
    ) -> None:
        """Receive a guard function, its code object, the frame locals,
        an ``index`` and a ``last`` flag; the return value is ignored.

        NOTE(review): ``index`` / ``last`` presumably give the position of
        this guard among those checked and whether it is the final one;
        confirm at the call site.
        """
76
77
class ProfilerStartHook(Protocol):
    """Callable shape for a hook that opens a named profiler region."""

    # TODO(whc) how do I annotate a _RecordFunction here?
    def __call__(self, name: str) -> Any:
        """Start a region called ``name`` and return an opaque record
        object (typed ``Any``)."""
85
86
class ProfilerEndHook(Protocol):
    """Callable shape for a hook that closes a profiler region."""

    def __call__(self, record: Any) -> None:
        """Close the region identified by ``record`` (presumably the
        object returned by the matching start hook; confirm)."""
90
91
class BytecodeHook(Protocol):
    """Callable shape for a hook that sees original and new bytecode."""

    def __call__(
        self, code: types.CodeType, new_code: types.CodeType
    ) -> Optional[types.CodeType]:
        """Receive the original ``code`` and ``new_code``; return a code
        object or None.

        NOTE(review): presumably a returned code object replaces
        ``new_code`` and None keeps it; confirm at the call site.
        """
96        ...
97