# xref: /aosp_15_r20/external/pytorch/torch/jit/_script.pyi (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2# mypy: disable-error-code="type-arg"
3from typing import Any, Callable, NamedTuple, overload, TypeVar
4from typing_extensions import Never, TypeAlias
5
6from _typeshed import Incomplete
7
8import torch
9from torch._classes import classes as classes
10from torch._jit_internal import _qualified_name as _qualified_name
11from torch.jit._builtins import _register_builtin as _register_builtin
12from torch.jit._fuser import (
13    _graph_for as _graph_for,
14    _script_method_graph_for as _script_method_graph_for,
15)
16from torch.jit._monkeytype_config import (
17    JitTypeTraceConfig as JitTypeTraceConfig,
18    JitTypeTraceStore as JitTypeTraceStore,
19    monkeytype_trace as monkeytype_trace,
20)
21from torch.jit._recursive import (
22    _compile_and_register_class as _compile_and_register_class,
23    infer_methods_to_compile as infer_methods_to_compile,
24    ScriptMethodStub as ScriptMethodStub,
25    wrap_cpp_module as wrap_cpp_module,
26)
27from torch.jit._serialization import validate_map_location as validate_map_location
28from torch.jit._state import (
29    _enabled as _enabled,
30    _set_jit_function_cache as _set_jit_function_cache,
31    _set_jit_overload_cache as _set_jit_overload_cache,
32    _try_get_jit_cached_function as _try_get_jit_cached_function,
33    _try_get_jit_cached_overloads as _try_get_jit_cached_overloads,
34)
35from torch.jit.frontend import (
36    get_default_args as get_default_args,
37    get_jit_class_def as get_jit_class_def,
38    get_jit_def as get_jit_def,
39)
40from torch.nn import Module as Module
41from torch.overrides import (
42    has_torch_function as has_torch_function,
43    has_torch_function_unary as has_torch_function_unary,
44    has_torch_function_variadic as has_torch_function_variadic,
45)
46from torch.package import (
47    PackageExporter as PackageExporter,
48    PackageImporter as PackageImporter,
49)
50from torch.utils import set_module as set_module
51
52ScriptFunction = torch._C.ScriptFunction
53
54type_trace_db: JitTypeTraceStore
55
56# Defined in torch/csrc/jit/python/script_init.cpp
57ResolutionCallback: TypeAlias = Callable[[str], Callable[..., Any]]
58_ClassVar = TypeVar("_ClassVar", bound=type)
59
60def _reduce(cls) -> None: ...
61
# Two-field named tuple pairing a ``value`` with a ``type``. Both fields are
# left as Incomplete (untyped) in this stub.
class Attribute(NamedTuple):
    value: Incomplete
    type: Incomplete
65
66def _get_type_trace_db(): ...
67def _get_function_from_type(cls, name): ...
68def _is_new_style_class(cls): ...
69
# Mapping-like wrapper around an object held in ``_c``. Only the dict-protocol
# surface (keys/values/items, len, get/set/del/contains) is declared here.
class OrderedDictWrapper:
    _c: Incomplete
    def __init__(self, _c) -> None: ...
    def keys(self): ...
    def values(self): ...
    def __len__(self) -> int: ...
    def __delitem__(self, k) -> None: ...
    def items(self): ...
    def __setitem__(self, k, v) -> None: ...
    def __contains__(self, k) -> bool: ...
    def __getitem__(self, k): ...
81
# OrderedDictWrapper specialization built from a ``module`` and a
# ``python_dict``. It also keeps ``_python_modules`` and overrides
# items/contains/set/get.
class OrderedModuleDict(OrderedDictWrapper):
    _python_modules: Incomplete
    def __init__(self, module, python_dict) -> None: ...
    def items(self): ...
    def __contains__(self, k) -> bool: ...
    def __setitem__(self, k, v) -> None: ...
    def __getitem__(self, k): ...
89
90class ScriptMeta(type):
91    def __init__(cls, name, bases, attrs) -> None: ...
92
93class _CachedForward:
94    def __get__(self, obj, cls): ...
95
96class ScriptWarning(Warning): ...
97
98def script_method(fn): ...
99
# Attribute-style accessor over ``const_mapping``. Returned as the second
# element of ScriptModule.code_with_constants.
class ConstMap:
    const_mapping: Incomplete
    def __init__(self, const_mapping) -> None: ...
    def __getattr__(self, attr): ...
104
# Given a PackageImporter and a script-module id, declared to return a
# torch.nn.Module.
def unpackage_script_module(
    importer: PackageImporter,
    script_module_id: str,
) -> torch.nn.Module: ...
109
# Collection of magic-method names; its exact type is left Incomplete here.
_magic_methods: Incomplete

# Wrapper around a C++ class object (``cpp_class``). It exposes attribute
# get/set and forwards magic methods by name via forward_magic_method.
class RecursiveScriptClass:
    _c: Incomplete
    _props: Incomplete
    def __init__(self, cpp_class) -> None: ...
    def __getattr__(self, attr): ...
    def __setattr__(self, attr, value) -> None: ...
    def forward_magic_method(self, method_name, *args, **kwargs): ...
    def __getstate__(self) -> None: ...
    def __iadd__(self, other): ...
121
# Generic method signature that accepts ``self`` plus arbitrary arguments.
def method_template(self, *args, **kwargs): ...
123
# nn.Module subclass created through the ScriptMeta metaclass. It declares
# ``forward`` as an arbitrary callable plus read-only introspection properties.
class ScriptModule(Module, metaclass=ScriptMeta):
    __jit_unused_properties__: Incomplete
    def __init__(self) -> None: ...
    forward: Callable[..., Any]
    def __getattr__(self, attr): ...
    def __setattr__(self, attr, value) -> None: ...
    def define(self, src): ...
    def _replicate_for_data_parallel(self): ...
    def __reduce_package__(self, exporter: PackageExporter): ...
    # add __jit_unused_properties__
    # NOTE(review): the note above is kept from the original; presumably the
    # properties below are meant to be listed in __jit_unused_properties__ —
    # confirm.
    @property
    def code(self) -> str: ...
    @property
    def code_with_constants(self) -> tuple[str, ConstMap]: ...
    @property
    def graph(self) -> torch.Graph: ...
    @property
    def inlined_graph(self) -> torch.Graph: ...
    @property
    def original_name(self) -> str: ...
144
# ScriptModule built around a C++ module (``cpp_module``). It is also the
# declared return type of the ``Module``-instance overload of ``script``.
class RecursiveScriptModule(ScriptModule):
    _disable_script_meta: bool
    _c: Incomplete
    def __init__(self, cpp_module) -> None: ...
    # Construction helpers taking the C++ module (and an init function).
    @staticmethod
    def _construct(cpp_module, init_fn): ...
    @staticmethod
    def _finalize_scriptmodule(script_module) -> None: ...
    _concrete_type: Incomplete
    _modules: Incomplete
    _parameters: Incomplete
    _buffers: Incomplete
    __dict__: Incomplete
    def _reconstruct(self, cpp_module) -> None: ...
    # Saving / buffer-export entry points.
    def save(self, f, **kwargs): ...
    def _save_for_lite_interpreter(self, *args, **kwargs): ...
    def _save_to_buffer_for_lite_interpreter(self, *args, **kwargs): ...
    def save_to_buffer(self, *args, **kwargs): ...
    def get_debug_state(self, *args, **kwargs): ...
    def extra_repr(self): ...
    def graph_for(self, *args, **kwargs): ...
    def define(self, src) -> None: ...
    def __getattr__(self, attr): ...
    def __setattr__(self, attr, value) -> None: ...
    def __copy__(self): ...
    def __deepcopy__(self, memo): ...
    def forward_magic_method(self, method_name, *args, **kwargs): ...
    # Container-protocol dunders; presumably forwarded through
    # forward_magic_method at runtime — confirm in the implementation.
    def __iter__(self): ...
    def __getitem__(self, idx): ...
    def __len__(self) -> int: ...
    def __contains__(self, key) -> bool: ...
    def __dir__(self): ...
    def __bool__(self) -> bool: ...
    def _replicate_for_data_parallel(self): ...
179
180def _get_methods(cls): ...
181
182_compiled_methods_allowlist: Incomplete
183
184def _make_fail(name): ...
185def call_prepare_scriptable_func_impl(obj, memo): ...
186def call_prepare_scriptable_func(obj): ...
187def create_script_dict(obj): ...
188def create_script_list(obj, type_hint: Incomplete | None = ...): ...
189@overload
190def script(
191    obj: type[Module],
192    optimize: bool | None = None,
193    _frames_up: int = 0,
194    _rcb: ResolutionCallback | None = None,
195    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
196) -> Never: ...
197@overload
198def script(  # type: ignore[misc]
199    obj: dict,
200    optimize: bool | None = None,
201    _frames_up: int = 0,
202    _rcb: ResolutionCallback | None = None,
203    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
204) -> torch.ScriptDict: ...
205@overload
206def script(  # type: ignore[misc]
207    obj: list,
208    optimize: bool | None = None,
209    _frames_up: int = 0,
210    _rcb: ResolutionCallback | None = None,
211    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
212) -> torch.ScriptList: ...
213@overload
214def script(  # type: ignore[misc]
215    obj: Module,
216    optimize: bool | None = None,
217    _frames_up: int = 0,
218    _rcb: ResolutionCallback | None = None,
219    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
220) -> RecursiveScriptModule: ...
221@overload
222def script(  # type: ignore[misc]
223    obj: _ClassVar,
224    optimize: bool | None = None,
225    _frames_up: int = 0,
226    _rcb: ResolutionCallback | None = None,
227    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
228) -> _ClassVar: ...
229@overload
230def script(  # type: ignore[misc]
231    obj: Callable,
232    optimize: bool | None = None,
233    _frames_up: int = 0,
234    _rcb: ResolutionCallback | None = None,
235    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
236) -> ScriptFunction: ...
237@overload
238def script(
239    obj: Any,
240    optimize: bool | None = None,
241    _frames_up: int = 0,
242    _rcb: ResolutionCallback | None = None,
243    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
244) -> RecursiveScriptClass: ...
245@overload
246def script(
247    obj,
248    optimize: Incomplete | None = ...,
249    _frames_up: int = ...,
250    _rcb: Incomplete | None = ...,
251    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = ...,
252): ...
253def _check_overload_defaults(impl_defaults, overload_defaults, loc) -> None: ...
254def _compile_function_with_overload(overload_fn, qual_name, impl_fn): ...
255def _get_overloads(obj): ...
256def _check_directly_compile_overloaded(obj) -> None: ...
257def interface(obj): ...
258def _recursive_compile_class(obj, loc): ...
259
# Name exported by this module; its exact type is left Incomplete.
CompilationUnit: Incomplete

# Takes a string plus padding/offset amounts and a fill ``char``. The defaults
# of ``offset`` and ``char`` are elided (``...``) in this stub.
def pad(s: str, padding: int, offset: int = ..., char: str = ...): ...
263
264class _ScriptProfileColumn:
265    header: Incomplete
266    alignment: Incomplete
267    offset: Incomplete
268    rows: Incomplete
269    def __init__(
270        self,
271        header: str,
272        alignment: int = ...,
273        offset: int = ...,
274    ) -> None: ...
275    def add_row(self, lineno: int, value: Any): ...
276    def materialize(self): ...
277
278class _ScriptProfileTable:
279    cols: Incomplete
280    source_range: Incomplete
281    def __init__(
282        self,
283        cols: list[_ScriptProfileColumn],
284        source_range: list[int],
285    ) -> None: ...
286    def dump_string(self): ...
287
288class _ScriptProfile:
289    profile: Incomplete
290    def __init__(self) -> None: ...
291    def enable(self) -> None: ...
292    def disable(self) -> None: ...
293    def dump_string(self) -> str: ...
294    def dump(self) -> None: ...
295
296def _unwrap_optional(x): ...
297