xref: /aosp_15_r20/external/pytorch/torchgen/local.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3import threading
4from contextlib import contextmanager
5from typing import Iterator
6
7
8# Simple dynamic scoping implementation.  The name "parametrize" comes
9# from Racket.
10#
11# WARNING WARNING: LOOKING TO EDIT THIS FILE?  Think carefully about
12# why you need to add a toggle to the global behavior of code
13# generation.  The parameters here should really only be used
14# for "temporary" situations, where we need to temporarily change
15# the codegen in some cases because we cannot conveniently update
16# all call sites, and are slated to be eliminated once all call
17# sites are eliminated.  If you don't have a plan for how to get there,
18# DON'T add a new entry here.
19
20
class Locals(threading.local):
    """Per-thread storage for the code generation toggles.

    Both toggles are declared at class level as ``None``; the accessor
    functions in this module treat ``None`` as "not configured yet".
    Because this derives from ``threading.local``, a value assigned on an
    instance is visible only to the thread that assigned it.
    """

    # None until a value is assigned on the current thread.
    use_const_ref_for_mutable_tensors: bool | None = None
    use_ilistref_for_tensor_lists: bool | None = None
24
25
# Module-level instance holding the toggle values that the functions below
# read and write.
_locals = Locals()
27
28
def use_const_ref_for_mutable_tensors() -> bool:
    """Return this thread's ``use_const_ref_for_mutable_tensors`` setting.

    Raises:
        AssertionError: If the setting is still ``None``, i.e. no value has
            been installed on this thread (see ``local.parametrize``).
    """
    value = _locals.use_const_ref_for_mutable_tensors
    # Raise explicitly instead of using an ``assert`` statement: asserts are
    # removed under ``python -O``, which would let ``None`` escape from a
    # function declared to return ``bool``.
    if value is None:
        raise AssertionError(
            "need to initialize local.use_const_ref_for_mutable_tensors with "
            "local.parametrize"
        )
    return value
35
36
def use_ilistref_for_tensor_lists() -> bool:
    """Return this thread's ``use_ilistref_for_tensor_lists`` setting.

    Raises:
        AssertionError: If the setting is still ``None``, i.e. no value has
            been installed on this thread (see ``local.parametrize``).
    """
    value = _locals.use_ilistref_for_tensor_lists
    # Raise explicitly instead of using an ``assert`` statement: asserts are
    # removed under ``python -O``, which would let ``None`` escape from a
    # function declared to return ``bool``.
    if value is None:
        raise AssertionError(
            "need to initialize local.use_ilistref_for_tensor_lists with "
            "local.parametrize"
        )
    return value
43
44
@contextmanager
def parametrize(
    *, use_const_ref_for_mutable_tensors: bool, use_ilistref_for_tensor_lists: bool
) -> Iterator[None]:
    """Install both toggles on the current thread for the ``with`` body.

    The values in effect on entry are captured first and written back on
    exit, whether the body completes normally or raises, so uses of this
    context manager can be nested.
    """
    # Snapshot whatever is active right now (``None`` if nothing was set).
    saved = (
        _locals.use_const_ref_for_mutable_tensors,
        _locals.use_ilistref_for_tensor_lists,
    )
    try:
        _locals.use_const_ref_for_mutable_tensors = use_const_ref_for_mutable_tensors
        _locals.use_ilistref_for_tensor_lists = use_ilistref_for_tensor_lists
        yield
    finally:
        # Put the snapshot back, in the same order the toggles were captured.
        (
            _locals.use_const_ref_for_mutable_tensors,
            _locals.use_ilistref_for_tensor_lists,
        ) = saved
60