# mypy: allow-untyped-defs
"""JIT-related state.

This module stores various pieces of Python-global state relating to the JIT.

This is not intended to be imported directly; please use the exposed
functionalities in `torch.jit`.
"""
import os
import weakref
from typing import Any, Dict, Type

import torch


class EnabledProxy:
    """Stores whether the JIT is enabled or not.

    This is just a wrapper for a bool, so that we get reference semantics
    """

    def __init__(self) -> None:
        self.enabled = self.parse_env(
            "PYTORCH_JIT", True, "> Using PyTorch JIT", "> PyTorch JIT DISABLED"
        )

    def parse_env(self, name, default, true_message, false_message):
        """Parse a boolean-ish environment variable into a bool.

        Accepted values (case-insensitive): "1"/"true"/"yes" -> True,
        "0"/"false"/"no" -> False. The case-sensitive "verbose" spellings
        "1v"/"0v" behave the same but also print the corresponding message.
        An unset variable yields ``default``; any other value raises
        ``ValueError``.
        """
        value = os.environ.get(name)
        if value is None:
            return default
        # Hoisted: compute the lowercased form once instead of per-branch.
        lowered = value.lower()
        if lowered in {"1", "true", "yes"}:
            return True
        elif lowered in {"0", "false", "no"}:
            return False
        # NOTE: the "1v"/"0v" checks are deliberately case-sensitive on the
        # raw value, matching the existing behavior ("1V" is rejected below).
        if value == "1v":
            print(true_message)
            return True
        elif value == "0v":
            print(false_message)
            return False
        raise ValueError(f"Unknown setting of {name}. Try using 0 or 1.")

    def __bool__(self):
        return self.enabled


# Global flag with reference semantics; mutated by enable()/disable() below.
_enabled = EnabledProxy()


def disable():
    """Globally disable the JIT."""
    _enabled.enabled = False


def enable():
    """Globally enable the JIT."""
    _enabled.enabled = True


# The Python CompilationUnit. All functions and modules defined in Python will
# live in here. It's defined in Python because doing in cpp creates static
# destruction order issues.
_python_cu = torch._C.CompilationUnit()


# Two-way registry between Python classes and their scripted counterparts:
# python class -> ScriptClass, plus qualified name -> python class.
_script_classes: Dict[Type[Any], Type[Any]] = {}
_name_to_pyclass: Dict[str, Type[Any]] = {}


def _add_script_class(python_class, script_class):
    """Register ``script_class`` as the scripted counterpart of ``python_class``."""
    _script_classes[python_class] = script_class
    _name_to_pyclass[script_class.qualified_name()] = python_class


def _get_script_class(python_class):
    """Return the registered ScriptClass for ``python_class``, or None.

    A class may redirect the lookup by carrying a ``_jit_override_qualname``
    attribute; in that case the class registered under that qualified name
    is consulted instead.
    """
    override = getattr(python_class, "_jit_override_qualname", None)
    lookup_key = python_class if override is None else _get_python_class(override)
    return _script_classes.get(lookup_key)


def _get_python_class(qualified_name):
    """Return the Python class registered under ``qualified_name``, or None."""
    return _name_to_pyclass.get(qualified_name)


def _clear_class_state():
    """Forget every registered class mapping (both directions)."""
    for registry in (_script_classes, _name_to_pyclass):
        registry.clear()


# Caching: we currently cache compilation of free functions and overloaded functions.
# To cache free functions we hold a weak ref to the function object and
# map to the compiled fn's qualified name.
# To cache overloaded functions we hold a weak ref to the function obj and
# map to all of its overloaded compiled fns.
# In the future we could consider caching more types of objects so that
# aliasing is preserved across separate compilations of the same object.
_jit_caching_layer: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
_jit_function_overload_caching: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()


def _try_get_jit_cached_overloads(key):
    """Return the compiled overloads cached for ``key``, or None on a miss."""
    cached_names = _jit_function_overload_caching.get(key)
    if not cached_names:
        return None
    return [_python_cu.find_function(name) for name in cached_names]


def _set_jit_overload_cache(key, compiled_fns):
    """Record the qualified names of all compiled overloads for ``key``."""
    _jit_function_overload_caching[key] = [
        compiled.qualified_name for compiled in compiled_fns
    ]


def _try_get_jit_cached_function(key):
    """Return the compiled function cached for ``key``, or None on a miss.

    Objects may opt out of caching entirely by setting
    ``__disable_jit_function_caching__`` to True.
    """
    if getattr(key, "__disable_jit_function_caching__", False) is True:
        return None
    cached_name = _jit_caching_layer.get(key)
    if not cached_name:
        return None
    return _python_cu.find_function(cached_name)


def _set_jit_function_cache(key, value):
    """Record the qualified name of a compiled free function under ``key``."""
    # only free functions currently supported
    assert isinstance(value, torch.jit.ScriptFunction)
    _jit_caching_layer[key] = value.qualified_name