# mypy: allow-untyped-defs
import contextlib
from typing import Callable, List, TYPE_CHECKING


if TYPE_CHECKING:
    import torch

# Registry of intermediate hooks. ``run_intermediate_hooks`` calls them in the
# order they were registered, passing ``(name, value)`` to each.
INTERMEDIATE_HOOKS: List[Callable[[str, "torch.Tensor"], None]] = []


@contextlib.contextmanager
def intermediate_hook(fn):
    """Register ``fn`` as an intermediate hook for the lifetime of a ``with`` block.

    On entry ``fn`` is appended to the module-level registry. On exit (normal
    or via an exception) the most recently appended entry of the registry is
    removed. This relies on nested ``with`` blocks unwinding in LIFO order.

    Args:
        fn: callable invoked as ``fn(name, val)`` by ``run_intermediate_hooks``.
    """
    INTERMEDIATE_HOOKS.append(fn)
    try:
        yield
    finally:
        # Look the global up again at exit time rather than caching it at entry.
        INTERMEDIATE_HOOKS.pop()


def run_intermediate_hooks(name, val):
    """Call every registered hook as ``hook(name, val)``, in registration order.

    While the hooks execute, the global registry is temporarily replaced by a
    fresh empty list. As a result, a ``run_intermediate_hooks`` call issued from
    inside a hook sees no hooks and cannot recurse back into them. The original
    list is put back afterwards, including when a hook raises; the exception
    then propagates to the caller.
    """
    global INTERMEDIATE_HOOKS
    # Stash the live registry and expose an empty one for the duration of the run.
    active, INTERMEDIATE_HOOKS = INTERMEDIATE_HOOKS, []
    try:
        for callback in active:
            callback(name, val)
    finally:
        INTERMEDIATE_HOOKS = active