xref: /aosp_15_r20/external/pytorch/torch/_inductor/hooks.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import contextlib
3from typing import Callable, List, TYPE_CHECKING
4
5
6if TYPE_CHECKING:
7    import torch
8
# Registered intermediate hooks. Each one is called as hook(name, val) by
# run_intermediate_hooks(), and they are executed in the order they're
# registered (list order).
INTERMEDIATE_HOOKS: List[Callable[[str, "torch.Tensor"], None]] = []
11
12
@contextlib.contextmanager
def intermediate_hook(fn: Callable[[str, "torch.Tensor"], None]):
    """Register ``fn`` as an intermediate hook for the duration of a ``with`` block.

    ``fn`` is appended to the end of ``INTERMEDIATE_HOOKS`` on entry. On exit,
    whether the block finishes normally or raises, the last entry of the list
    is popped. That entry is ``fn`` as long as registrations are unwound in
    the reverse of the order in which they were made (ordinary ``with``
    nesting).

    Args:
        fn: Callable invoked as ``fn(name, val)`` while it is registered.

    Yields:
        None. The context manager exposes no value to the ``with`` block.
    """
    INTERMEDIATE_HOOKS.append(fn)
    try:
        yield
    finally:
        # Always unregister, even when the body of the ``with`` block raises.
        INTERMEDIATE_HOOKS.pop()
20
21
def run_intermediate_hooks(name: str, val: "torch.Tensor") -> None:
    """Call every registered intermediate hook with ``(name, val)``.

    Hooks run in list order, i.e. the order in which they were registered.
    While they run, the module-level ``INTERMEDIATE_HOOKS`` is rebound to a
    fresh empty list, so a nested call to this function made from inside a
    hook finds no hooks to run. The original list object is put back
    afterwards, even if a hook raises.

    Anything appended to ``INTERMEDIATE_HOOKS`` while the hooks are running
    goes to the temporary list and is dropped when the original is restored.

    Args:
        name: Passed through to each hook as its first argument.
        val: Passed through to each hook as its second argument.

    Raises:
        Any exception raised by a hook propagates to the caller; hooks after
        the failing one are not called.
    """
    global INTERMEDIATE_HOOKS
    hooks = INTERMEDIATE_HOOKS
    # Swap in an empty list so hooks do not re-trigger themselves if they
    # end up calling back into this function.
    INTERMEDIATE_HOOKS = []
    try:
        for hook in hooks:
            hook(name, val)
    finally:
        # Restore the very same list object that was registered before.
        INTERMEDIATE_HOOKS = hooks
31