1import dataclasses 2import sys 3import types 4from typing import Any, Callable, Dict, List, NamedTuple, Optional, Protocol, Union 5 6# CacheEntry has a `check_fn` field for the guard, and a `code` field for the code object. 7from torch._C._dynamo.eval_frame import ( 8 _CacheEntry as CacheEntry, 9 _ExtraState as ExtraState, 10) 11from torch._guards import CompileId 12 13 14if sys.version_info >= (3, 11): 15 from torch._C._dynamo.eval_frame import _PyInterpreterFrame as DynamoFrameType 16else: 17 from types import FrameType as DynamoFrameType 18 19 20# We use a dict to store additional data per frame. 21FrameState = Dict[Any, Any] 22 23 24class GuardFail(NamedTuple): 25 # A string repr of the piece of failed guard code we eval-ed 26 reason: str 27 # A code object where we failed a guard 28 orig_code: types.CodeType 29 30 31class GuardFn(Protocol): 32 closure_vars: Dict[str, object] 33 args: List[str] 34 code_parts: List[str] 35 verbose_code_parts: List[str] 36 global_scope: Dict[str, object] 37 guard_fail_fn: Optional[Callable[[GuardFail], None]] 38 cache_entry: Optional[CacheEntry] 39 extra_state: Optional[ExtraState] 40 41 # maps locals of user function to bool 42 def __call__(self, f_locals: Dict[str, object]) -> bool: 43 ... 44 45 46@dataclasses.dataclass 47class GuardedCode: 48 code: types.CodeType 49 check_fn: GuardFn 50 compile_id: CompileId 51 52 53class DynamoCallbackFn(Protocol): 54 def __call__( 55 self, 56 frame: DynamoFrameType, 57 cache_entry: Optional[CacheEntry], 58 frame_state: FrameState, 59 ) -> Optional[GuardedCode]: 60 ... 61 62 63DynamoCallback = Union[DynamoCallbackFn, None, bool] 64 65 66class DynamoGuardHook(Protocol): 67 def __call__( 68 self, 69 guard_fn: GuardFn, 70 code: types.CodeType, 71 f_locals: Dict[str, object], 72 index: int, 73 last: bool, 74 ) -> None: 75 ... 76 77 78class ProfilerStartHook(Protocol): 79 def __call__( 80 self, 81 name: str, 82 # TODO(whc) how do I annotate a _RecordFunction here? 83 ) -> Any: 84 ... 
class ProfilerEndHook(Protocol):
    """Signature of a profiler end hook.

    It receives a ``record`` (presumably the object a ProfilerStartHook
    returned; confirm with callers) and returns nothing.
    """

    def __call__(self, record: Any) -> None:
        """Finish ``record``; returns None."""


class BytecodeHook(Protocol):
    """Signature of a hook called with an original and a new code object.

    It may return a code object or None. Presumably None means "keep
    ``new_code`` as is"; confirm against the code that invokes the hook.
    """

    def __call__(
        self, code: types.CodeType, new_code: types.CodeType
    ) -> Optional[types.CodeType]:
        """Inspect ``code``/``new_code``; optionally return a replacement."""