# mypy: allow-untyped-defs
import warnings
from contextlib import contextmanager
from typing import Any, Iterator

import torch._C

# These are imported so users can access them from the `torch.jit` module
from torch._jit_internal import (
    _Await,
    _drop,
    _IgnoreContextManager,
    _isinstance,
    _overload,
    _overload_method,
    export,
    Final,
    Future,
    ignore,
    is_scripting,
    unused,
)
from torch.jit._async import fork, wait
from torch.jit._await import _awaitable, _awaitable_nowait, _awaitable_wait
from torch.jit._decomposition_utils import _register_decomposition
from torch.jit._freeze import freeze, optimize_for_inference, run_frozen_optimizations
from torch.jit._fuser import (
    fuser,
    last_executed_optimized_graph,
    optimized_execution,
    set_fusion_strategy,
)
from torch.jit._ir_utils import _InsertPoint
from torch.jit._script import (
    _ScriptProfile,
    _unwrap_optional,
    Attribute,
    CompilationUnit,
    interface,
    RecursiveScriptClass,
    RecursiveScriptModule,
    script,
    script_method,
    ScriptFunction,
    ScriptModule,
    ScriptWarning,
)
from torch.jit._serialization import (
    jit_module_from_flatbuffer,
    load,
    save,
    save_jit_module_to_flatbuffer,
)
from torch.jit._trace import (
    _flatten,
    _get_trace_graph,
    _script_if_tracing,
    _unique_state_dict,
    is_tracing,
    ONNXTracedModule,
    TopLevelTracedModule,
    trace,
    trace_module,
    TracedModule,
    TracerWarning,
    TracingCheckError,
)
from torch.utils import set_module


__all__ = [
    "Attribute",
    "CompilationUnit",
    "Error",
    "Future",
    "ScriptFunction",
    "ScriptModule",
    "annotate",
    "enable_onednn_fusion",
    "export",
    "export_opnames",
    "fork",
    "freeze",
    "interface",
    "ignore",
    "isinstance",
    "load",
    "onednn_fusion_enabled",
    "optimize_for_inference",
    "save",
    "script",
    "script_if_tracing",
    "set_fusion_strategy",
    "strict_fusion",
    "trace",
    "trace_module",
    "unused",
    "wait",
]

# For backwards compatibility
_fork = fork
_wait = wait
_set_fusion_strategy = set_fusion_strategy


def export_opnames(m):
108    r"""
109    Generate new bytecode for a Script module.
110
111    Returns what the op list would be for a Script Module based off the current code base.
112
113    If you have a LiteScriptModule and want to get the currently present
114    list of ops call _export_operator_list instead.
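
    Example (a minimal sketch; ``MyModule`` is a hypothetical ``nn.Module``
    subclass, and the return value is a list of operator name strings):

    .. code-block:: python

        scripted = torch.jit.script(MyModule())
        ops = torch.jit.export_opnames(scripted)
        # e.g. ["aten::add.Tensor", "prim::TupleConstruct", ...]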
115    """
116    return torch._C._export_opnames(m._c)
117
118
119# torch.jit.Error
120Error = torch._C.JITException
121set_module(Error, "torch.jit")
122# This is not perfect but works in common cases
123Error.__name__ = "Error"
124Error.__qualname__ = "Error"
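
# Example (a minimal sketch): ``torch.jit.Error`` is raised when a scripted
# function fails at runtime, and can be caught like any Python exception.
# ``scripted_fn`` below is a hypothetical scripted function:
#
#     try:
#         scripted_fn(inputs)
#     except torch.jit.Error:
#         ...  # handle the TorchScript runtime error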


# for use in python if using annotate
def annotate(the_type, the_value):
    """Use to give the type of `the_value` in the TorchScript compiler.

    This method is a pass-through function that returns `the_value`, used to hint the TorchScript
    compiler about the type of `the_value`. It is a no-op when running outside of TorchScript.

    Though TorchScript can infer correct types for most Python expressions, there are some cases where
    type inference can be wrong, including:

    - Empty containers like `[]` and `{}`, which TorchScript assumes to be containers of `Tensor`
    - Optional types like `Optional[T]` that are assigned a valid value of type `T`, which TorchScript
      would assume to be of type `T` rather than `Optional[T]`

    Note that `annotate()` does not help in the `__init__` method of `torch.nn.Module` subclasses because it
    is executed in eager mode. To annotate types of `torch.nn.Module` attributes,
    use :meth:`~torch.jit.Attribute` instead.

    Example:

    .. testcode::

        import torch
        from typing import Dict

        @torch.jit.script
        def fn():
            # Telling TorchScript that this empty dictionary is a (str -> int) dictionary
            # instead of the default dictionary type of (str -> Tensor).
            d = torch.jit.annotate(Dict[str, int], {})

            # Without `torch.jit.annotate` above, the following statement would fail
            # because of a type mismatch.
            d["name"] = 20

    .. testcleanup::

        del fn

    Args:
        the_type: Python type that should be passed to the TorchScript compiler as a type hint for `the_value`
        the_value: Value or expression to hint the type for.

    Returns:
        `the_value` is passed back as the return value.
    """
    return the_value


def script_if_tracing(fn):
    """
    Compiles ``fn`` when it is first called during tracing.

    ``torch.jit.script`` has a non-negligible start-up time when it is first called, due to
    lazy initialization of many compiler builtins. Therefore you should not use
    it in library code. However, you may want to have parts of your library work
    in tracing even if they use control flow. In these cases, you should use
    ``@torch.jit.script_if_tracing`` to substitute for
    ``torch.jit.script``.

    Args:
        fn: A function to compile.

    Returns:
        If called during tracing, a :class:`ScriptFunction` created by `torch.jit.script` is returned.
        Otherwise, the original function `fn` is returned.
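
    Example (a minimal sketch; ``loop_helper`` is a hypothetical library
    function with data-dependent control flow):

    .. code-block:: python

        import torch

        @torch.jit.script_if_tracing
        def loop_helper(x: torch.Tensor, n: int) -> torch.Tensor:
            # A data-dependent loop: plain tracing would unroll it for one
            # specific ``n``, while scripting preserves the loop.
            for _ in range(n):
                x = x + 1
            return x

        def fn(x: torch.Tensor):
            return loop_helper(x, int(x.size(0)))

        traced = torch.jit.trace(fn, torch.zeros(3))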
    """
    return _script_if_tracing(fn)


# for torch.jit.isinstance
def isinstance(obj, target_type):
    """
    Provide container type refinement in TorchScript.

    It can refine parameterized containers of the List, Dict, Tuple, and Optional types. E.g. ``List[str]``,
    ``Dict[str, List[torch.Tensor]]``, ``Optional[Tuple[int, str, int]]``. It can also
    refine basic types such as bools and ints that are available in TorchScript.

    Args:
        obj: object to refine the type of
        target_type: type to try to refine obj to
    Returns:
        ``bool``: True if obj was successfully refined to the type of target_type,
            False otherwise with no new type refinement

    Example (using ``torch.jit.isinstance`` for type refinement):

    .. testcode::

        import torch
        from typing import Any, Dict, List

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, input: Any):  # note the Any type
                if torch.jit.isinstance(input, List[torch.Tensor]):
                    for t in input:
                        y = t.clamp(0, 0.5)
                elif torch.jit.isinstance(input, Dict[str, str]):
                    for val in input.values():
                        print(val)

        m = torch.jit.script(MyModule())
        x = [torch.rand(3, 3), torch.rand(4, 3)]
        m(x)
        y = {"key1": "val1", "key2": "val2"}
        m(y)
    """
    return _isinstance(obj, target_type)


class strict_fusion:
    """
    Raise errors if not all nodes have been fused in inference, or symbolically differentiated in training.

    Example:
        Forcing fusion of additions.

    .. code-block:: python

        @torch.jit.script
        def foo(x):
            with torch.jit.strict_fusion():
                return x + x + x

    """

    def __init__(self) -> None:
        if not torch._jit_internal.is_scripting():
            warnings.warn("Only works in script mode")

    def __enter__(self):
        pass

    def __exit__(self, type: Any, value: Any, tb: Any) -> None:
        pass


# Context manager for globally hiding source ranges when printing graphs.
# Note that these functions are exposed to Python as static members of the
# Graph class, so mypy checks need to be skipped.
@contextmanager
def _hide_source_ranges() -> Iterator[None]:
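    """Temporarily disable printing of source ranges in graph dumps.

    A minimal usage sketch (``fn`` is assumed to be a scripted function):

    .. code-block:: python

        with torch.jit._hide_source_ranges():
            print(fn.graph)  # printed without source range annotations
    """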
    old_enable_source_ranges = torch._C.Graph.global_print_source_ranges  # type: ignore[attr-defined]
    try:
        torch._C.Graph.set_global_print_source_ranges(False)  # type: ignore[attr-defined]
        yield
    finally:
        torch._C.Graph.set_global_print_source_ranges(old_enable_source_ranges)  # type: ignore[attr-defined]


def enable_onednn_fusion(enabled: bool):
    """Enable or disable onednn JIT fusion based on the parameter ``enabled``."""
    torch._C._jit_set_llga_enabled(enabled)


def onednn_fusion_enabled():
    """Return whether onednn JIT fusion is enabled."""
    return torch._C._jit_llga_enabled()
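
# Example usage (a minimal sketch): enable onednn fusion for CPU inference,
# restoring the previous setting afterwards. ``model`` is a hypothetical
# nn.Module:
#
#     previous = torch.jit.onednn_fusion_enabled()
#     torch.jit.enable_onednn_fusion(True)
#     scripted = torch.jit.script(model)
#     ...
#     torch.jit.enable_onednn_fusion(previous)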


del Any

if not torch._C._jit_init():
    raise RuntimeError("JIT initialization failed")
295