xref: /aosp_15_r20/external/pytorch/torch/jit/_state.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2"""JIT-related state.
3
4This module stores various pieces of Python-global state relating to the JIT.
5
This is not intended to be imported directly; please use the exposed
functionalities in `torch.jit`.
8"""
9import os
10import weakref
11from typing import Any, Dict, Type
12
13import torch
14
15
class EnabledProxy:
    """Mutable holder for the "is the JIT enabled?" flag.

    Wrapping the bool in an object gives reference semantics: everyone who
    holds this instance observes later changes to ``enabled``. The initial
    value is read from the ``PYTORCH_JIT`` environment variable.
    """

    def __init__(self) -> None:
        # The JIT is enabled by default when PYTORCH_JIT is not set.
        self.enabled = self.parse_env(
            "PYTORCH_JIT", True, "> Using PyTorch JIT", "> PyTorch JIT DISABLED"
        )

    def parse_env(self, name, default, true_message, false_message):
        """Interpret the environment variable ``name`` as an on/off switch.

        Args:
            name: Name of the environment variable to read.
            default: Value returned when the variable is not set.
            true_message: Text printed for the verbose "on" setting (``1v``).
            false_message: Text printed for the verbose "off" setting (``0v``).

        Returns:
            ``default`` if the variable is unset, otherwise ``True`` or
            ``False`` according to its value.

        Raises:
            ValueError: If the variable holds an unrecognised value.
        """
        raw = os.environ.get(name)
        if raw is None:
            return default

        # Plain settings are matched case-insensitively.
        lowered = raw.lower()
        if lowered in {"1", "true", "yes"}:
            return True
        if lowered in {"0", "false", "no"}:
            return False

        # Verbose settings are matched exactly and announce the outcome.
        verbose_settings = {
            "1v": (true_message, True),
            "0v": (false_message, False),
        }
        if raw in verbose_settings:
            message, flag = verbose_settings[raw]
            print(message)
            return flag

        raise ValueError(f"Unknown setting of {name}. Try using 0 or 1.")

    def __bool__(self):
        """Return the current value of the flag."""
        return self.enabled
45
46
# Module-level JIT on/off flag, initialised from the environment when this
# module is imported and toggled by `disable()` / `enable()` below.
_enabled = EnabledProxy()
48
49
def disable():
    """Turn the JIT off by setting the shared ``_enabled`` flag to False."""
    _enabled.enabled = False
52
53
def enable():
    """Turn the JIT on by setting the shared ``_enabled`` flag to True."""
    _enabled.enabled = True
56
57
# The Python CompilationUnit. All functions and modules defined in Python will
# live in here. It is defined in Python because defining it in C++ creates
# static destruction order issues.
_python_cu = torch._C.CompilationUnit()
62
63
# Python class => ScriptClass mapping.
_script_classes: Dict[Type[Any], Type[Any]] = {}
# Qualified name of a registered ScriptClass => the Python class it was
# registered for.
_name_to_pyclass: Dict[str, Type[Any]] = {}
67
68
def _add_script_class(python_class, script_class):
    """Register ``script_class`` as the scripted form of ``python_class``.

    Both lookup tables are updated: the class-to-ScriptClass map, and the map
    from the ScriptClass's qualified name back to the Python class.
    """
    _script_classes[python_class] = script_class
    qualified_name = script_class.qualified_name()
    _name_to_pyclass[qualified_name] = python_class
72
73
def _get_script_class(python_class):
    """Return the ScriptClass registered for ``python_class``, or ``None``.

    If ``python_class`` has a non-None ``_jit_override_qualname`` attribute,
    the lookup is redirected to the Python class registered under that
    qualified name.
    """
    qualname_override = getattr(python_class, "_jit_override_qualname", None)
    lookup_key = (
        python_class
        if qualname_override is None
        else _get_python_class(qualname_override)
    )
    return _script_classes.get(lookup_key)
79
80
def _get_python_class(qualified_name):
    """Return the Python class registered under ``qualified_name``, or ``None``."""
    return _name_to_pyclass.get(qualified_name)
83
84
def _clear_class_state():
    """Empty both script-class lookup tables in place."""
    for registry in (_script_classes, _name_to_pyclass):
        registry.clear()
88
89
# Caching: we currently cache compilation of free functions and overloaded functions.
# To cache free functions we hold a weak ref to the function object and
# map to the compiled fn's qualified name.
# To cache overloaded functions we hold a weak ref to the function obj and
# map to the qualified names of all of its overloaded compiled fns.
# In the future we could consider caching more types of objects so that
# aliasing is preserved across separate compilations of the same object.

# function object => qualified name of its compiled function
_jit_caching_layer: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
# function object => list of qualified names of its compiled overloads
_jit_function_overload_caching: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
100
101
def _try_get_jit_cached_overloads(key):
    """Return the cached compiled overloads for ``key``, or ``None``.

    ``None`` is returned when nothing is cached for ``key`` or the cached list
    of names is empty. Otherwise each cached qualified name is resolved
    through the Python compilation unit.
    """
    cached_names = _jit_function_overload_caching.get(key)
    if not cached_names:
        return None
    return [_python_cu.find_function(name) for name in cached_names]
108
109
def _set_jit_overload_cache(key, compiled_fns):
    """Cache the qualified names of ``compiled_fns`` as the overloads of ``key``."""
    overload_names = [compiled.qualified_name for compiled in compiled_fns]
    _jit_function_overload_caching[key] = overload_names
112
113
def _try_get_jit_cached_function(key):
    """Return the cached compiled function for ``key``, or ``None``.

    The cache is bypassed when ``key`` has the attribute
    ``__disable_jit_function_caching__`` set to exactly ``True``. Otherwise a
    cached qualified name, if any, is resolved through the Python compilation
    unit.
    """
    caching_disabled = (
        getattr(key, "__disable_jit_function_caching__", False) is True
    )
    if caching_disabled:
        return None

    cached_name = _jit_caching_layer.get(key)
    return _python_cu.find_function(cached_name) if cached_name else None
122
123
def _set_jit_function_cache(key, value):
    """Cache the qualified name of compiled function ``value`` under ``key``."""
    # Only free functions (ScriptFunction instances) are supported for now.
    assert isinstance(value, torch.jit.ScriptFunction)
    compiled_name = value.qualified_name
    _jit_caching_layer[key] = compiled_name
128