# mypy: allow-untyped-defs
import warnings
from contextlib import contextmanager
from typing import Any, Iterator

import torch._C

# These are imported so users can access them from the `torch.jit` module.
# Several of them are not referenced again in this file; they exist purely as re-exports.
from torch._jit_internal import (
    _Await,
    _drop,
    _IgnoreContextManager,
    _isinstance,
    _overload,
    _overload_method,
    export,
    Final,
    Future,
    ignore,
    is_scripting,
    unused,
)
from torch.jit._async import fork, wait
from torch.jit._await import _awaitable, _awaitable_nowait, _awaitable_wait
from torch.jit._decomposition_utils import _register_decomposition
from torch.jit._freeze import freeze, optimize_for_inference, run_frozen_optimizations
from torch.jit._fuser import (
    fuser,
    last_executed_optimized_graph,
    optimized_execution,
    set_fusion_strategy,
)
from torch.jit._ir_utils import _InsertPoint
from torch.jit._script import (
    _ScriptProfile,
    _unwrap_optional,
    Attribute,
    CompilationUnit,
    interface,
    RecursiveScriptClass,
    RecursiveScriptModule,
    script,
    script_method,
    ScriptFunction,
    ScriptModule,
    ScriptWarning,
)
from torch.jit._serialization import (
    jit_module_from_flatbuffer,
    load,
    save,
    save_jit_module_to_flatbuffer,
)
from torch.jit._trace import (
    _flatten,
    _get_trace_graph,
    _script_if_tracing,
    _unique_state_dict,
    is_tracing,
    ONNXTracedModule,
    TopLevelTracedModule,
    trace,
    trace_module,
    TracedModule,
    TracerWarning,
    TracingCheckError,
)
from torch.utils import set_module


# Names exported by `from torch.jit import *`. Some entries (e.g. "enable_onednn_fusion",
# "onednn_fusion_enabled") are defined further down in this file, after this list.
__all__ = [
    "Attribute",
    "CompilationUnit",
    "Error",
    "Future",
    "ScriptFunction",
    "ScriptModule",
    "annotate",
    "enable_onednn_fusion",
    "export",
    "export_opnames",
    "fork",
    "freeze",
    "interface",
    "ignore",
    "isinstance",
    "load",
    "onednn_fusion_enabled",
    "optimize_for_inference",
    "save",
    "script",
    "script_if_tracing",
    "set_fusion_strategy",
    "strict_fusion",
    "trace",
    "trace_module",
    "unused",
    "wait",
]

# Underscore-prefixed aliases of the public functions, kept for backwards compatibility.
_fork = fork
_wait = wait
_set_fusion_strategy = set_fusion_strategy


def export_opnames(m):
    r"""
    Generate new bytecode for a Script module.

    Returns what the op list would be for a Script Module based off the current code base.

    If you have a LiteScriptModule and want to get the currently present
    list of ops call _export_operator_list instead.

    Args:
        m: Script module to inspect. Its ``_c`` attribute is what gets handed to the
            underlying ``torch._C._export_opnames`` binding.

    Returns:
        Whatever ``torch._C._export_opnames`` returns for ``m._c`` (the op list described above).
    """
    return torch._C._export_opnames(m._c)


# torch.jit.Error: expose the C++-side JIT exception class under a public name.
Error = torch._C.JITException
# NOTE(review): presumably `set_module` rewrites `Error.__module__` so the class
# reports itself as living in `torch.jit` -- confirm against `torch.utils.set_module`.
set_module(Error, "torch.jit")
# This is not perfect but works in common cases.
# The class object itself is mutated here, so `torch._C.JITException` is renamed as well.
Error.__name__ = "Error"
Error.__qualname__ = "Error"


# for use in python if using annotate
def annotate(the_type, the_value):
    """Use to give type of `the_value` in TorchScript compiler.

    This method is a pass-through function that returns `the_value`, used to hint the TorchScript
    compiler about the type of `the_value`. It is a no-op when running outside of TorchScript.

    Though TorchScript can infer the correct type for most Python expressions, there are some cases
    where type inference can be wrong, including:

    - Empty containers like `[]` and `{}`, which TorchScript assumes to be containers of `Tensor`
    - Optional types like `Optional[T]` that are assigned a valid value of type `T`; TorchScript
      would assume the type is `T` rather than `Optional[T]`

    Note that `annotate()` does not help in the `__init__` method of `torch.nn.Module` subclasses
    because it is executed in eager mode. To annotate types of `torch.nn.Module` attributes,
    use :meth:`~torch.jit.Attribute` instead.

    Example:

    .. testcode::

        import torch
        from typing import Dict

        @torch.jit.script
        def fn():
            # Telling TorchScript that this empty dictionary is a (str -> int) dictionary
            # instead of default dictionary type of (str -> Tensor).
            d = torch.jit.annotate(Dict[str, int], {})

            # Without `torch.jit.annotate` above, following statement would fail because of
            # type mismatch.
            d["name"] = 20

    .. testcleanup::

        del fn

    Args:
        the_type: Python type that should be passed to the TorchScript compiler as the type hint
            for `the_value`. It is not inspected when this function runs as plain Python.
        the_value: Value or expression to hint the type for.

    Returns:
        `the_value`, passed back unchanged.
    """
    return the_value


def script_if_tracing(fn):
    """
    Compiles ``fn`` when it is first called during tracing.

    ``torch.jit.script`` has a non-negligible start up time when it is first called due to
    lazy-initializations of many compiler builtins. Therefore you should not use
    it in library code. However, you may want to have parts of your library work
    in tracing even if they use control flow. In these cases, you should use
    ``@torch.jit.script_if_tracing`` to substitute for
    ``torch.jit.script``.

    This is the public entry point; the work is delegated to ``_script_if_tracing``
    imported from ``torch.jit._trace``.

    Args:
        fn: A function to compile.

    Returns:
        If called during tracing, a :class:`ScriptFunction` created by `torch.jit.script` is returned.
        Otherwise, the original function `fn` is returned.
    """
    return _script_if_tracing(fn)


# for torch.jit.isinstance
# NOTE: this definition rebinds the name `isinstance` inside this module, so any code
# placed below it in this file gets this function rather than the Python builtin.
def isinstance(obj, target_type):
    """
    Provide container type refinement in TorchScript.

    It can refine parameterized containers of the List, Dict, Tuple, and Optional types. E.g. ``List[str]``,
    ``Dict[str, List[torch.Tensor]]``, ``Optional[Tuple[int,str,int]]``. It can also
    refine basic types such as bools and ints that are available in TorchScript.

    When called as plain Python, the check is delegated to ``_isinstance`` imported
    from ``torch._jit_internal``.

    Args:
        obj: object to refine the type of
        target_type: type to try to refine obj to

    Returns:
        ``bool``: True if obj was successfully refined to the type of target_type,
            False otherwise with no new type refinement


    Example (using ``torch.jit.isinstance`` for type refinement):
    .. testcode::

        import torch
        from typing import Any, Dict, List

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, input: Any): # note the Any type
                if torch.jit.isinstance(input, List[torch.Tensor]):
                    for t in input:
                        y = t.clamp(0, 0.5)
                elif torch.jit.isinstance(input, Dict[str, str]):
                    for val in input.values():
                        print(val)

        m = torch.jit.script(MyModule())
        x = [torch.rand(3,3), torch.rand(4,3)]
        m(x)
        y = {"key1":"val1","key2":"val2"}
        m(y)
    """
    return _isinstance(obj, target_type)


class strict_fusion:
    """
    Give errors if not all nodes have been fused in inference, or symbolically differentiated in training.

    The Python-level class below is only a placeholder context manager: entering and
    exiting it does nothing, and constructing it outside of scripting emits a warning.

    Example:
        Forcing fusion of additions.

        .. code-block:: python

            @torch.jit.script
            def foo(x):
                with torch.jit.strict_fusion():
                    return x + x + x

    """

    def __init__(self) -> None:
        # Constructed from ordinary (non-scripted) Python: tell the user the
        # context manager has no effect in that situation.
        if not torch._jit_internal.is_scripting():
            warnings.warn("Only works in script mode")

    def __enter__(self):
        # Nothing to set up on the Python side.
        pass

    def __exit__(self, type: Any, value: Any, tb: Any) -> None:
        # Nothing to tear down; returns None, so an exception raised inside the
        # `with` body is not suppressed.
        pass


# Context manager for globally hiding source ranges when printing graphs.
# Note that these functions are exposed to Python as static members of the
# Graph class, so mypy checks need to be skipped.
@contextmanager
def _hide_source_ranges() -> Iterator[None]:
    """Globally suppress source ranges in printed graphs for the duration of a ``with`` body.

    The global setting found on entry is remembered and written back on exit,
    including when the body raises.
    """
    graph_cls = torch._C.Graph
    # The flag and its setter are static members of the Graph class, so mypy
    # cannot see them and the checks have to be skipped.
    previous_setting = graph_cls.global_print_source_ranges  # type: ignore[attr-defined]
    try:
        graph_cls.set_global_print_source_ranges(False)  # type: ignore[attr-defined]
        yield
    finally:
        graph_cls.set_global_print_source_ranges(previous_setting)  # type: ignore[attr-defined]


def enable_onednn_fusion(enabled: bool):
    """Switch onednn JIT fusion on or off.

    Args:
        enabled: ``True`` to enable the fusion, ``False`` to disable it.
    """
    set_llga_enabled = torch._C._jit_set_llga_enabled
    set_llga_enabled(enabled)


def onednn_fusion_enabled():
    """Report whether onednn JIT fusion is currently enabled."""
    fusion_is_on = torch._C._jit_llga_enabled()
    return fusion_is_on


# Remove the typing helper from the module namespace so it is not exposed as
# an attribute of this module.
del Any

# Fail the import loudly if the JIT bindings report that initialization did not succeed.
if not torch._C._jit_init():
    raise RuntimeError("JIT initialization failed")