1# mypy: allow-untyped-defs 2# mypy: disable-error-code="type-arg" 3from typing import Any, Callable, NamedTuple, overload, TypeVar 4from typing_extensions import Never, TypeAlias 5 6from _typeshed import Incomplete 7 8import torch 9from torch._classes import classes as classes 10from torch._jit_internal import _qualified_name as _qualified_name 11from torch.jit._builtins import _register_builtin as _register_builtin 12from torch.jit._fuser import ( 13 _graph_for as _graph_for, 14 _script_method_graph_for as _script_method_graph_for, 15) 16from torch.jit._monkeytype_config import ( 17 JitTypeTraceConfig as JitTypeTraceConfig, 18 JitTypeTraceStore as JitTypeTraceStore, 19 monkeytype_trace as monkeytype_trace, 20) 21from torch.jit._recursive import ( 22 _compile_and_register_class as _compile_and_register_class, 23 infer_methods_to_compile as infer_methods_to_compile, 24 ScriptMethodStub as ScriptMethodStub, 25 wrap_cpp_module as wrap_cpp_module, 26) 27from torch.jit._serialization import validate_map_location as validate_map_location 28from torch.jit._state import ( 29 _enabled as _enabled, 30 _set_jit_function_cache as _set_jit_function_cache, 31 _set_jit_overload_cache as _set_jit_overload_cache, 32 _try_get_jit_cached_function as _try_get_jit_cached_function, 33 _try_get_jit_cached_overloads as _try_get_jit_cached_overloads, 34) 35from torch.jit.frontend import ( 36 get_default_args as get_default_args, 37 get_jit_class_def as get_jit_class_def, 38 get_jit_def as get_jit_def, 39) 40from torch.nn import Module as Module 41from torch.overrides import ( 42 has_torch_function as has_torch_function, 43 has_torch_function_unary as has_torch_function_unary, 44 has_torch_function_variadic as has_torch_function_variadic, 45) 46from torch.package import ( 47 PackageExporter as PackageExporter, 48 PackageImporter as PackageImporter, 49) 50from torch.utils import set_module as set_module 51 52ScriptFunction = torch._C.ScriptFunction 53 54type_trace_db: JitTypeTraceStore 55 
56# Defined in torch/csrc/jit/python/script_init.cpp 57ResolutionCallback: TypeAlias = Callable[[str], Callable[..., Any]] 58_ClassVar = TypeVar("_ClassVar", bound=type) 59 60def _reduce(cls) -> None: ... 61 62class Attribute(NamedTuple): 63 value: Incomplete 64 type: Incomplete 65 66def _get_type_trace_db(): ... 67def _get_function_from_type(cls, name): ... 68def _is_new_style_class(cls): ... 69 70class OrderedDictWrapper: 71 _c: Incomplete 72 def __init__(self, _c) -> None: ... 73 def keys(self): ... 74 def values(self): ... 75 def __len__(self) -> int: ... 76 def __delitem__(self, k) -> None: ... 77 def items(self): ... 78 def __setitem__(self, k, v) -> None: ... 79 def __contains__(self, k) -> bool: ... 80 def __getitem__(self, k): ... 81 82class OrderedModuleDict(OrderedDictWrapper): 83 _python_modules: Incomplete 84 def __init__(self, module, python_dict) -> None: ... 85 def items(self): ... 86 def __contains__(self, k) -> bool: ... 87 def __setitem__(self, k, v) -> None: ... 88 def __getitem__(self, k): ... 89 90class ScriptMeta(type): 91 def __init__(cls, name, bases, attrs) -> None: ... 92 93class _CachedForward: 94 def __get__(self, obj, cls): ... 95 96class ScriptWarning(Warning): ... 97 98def script_method(fn): ... 99 100class ConstMap: 101 const_mapping: Incomplete 102 def __init__(self, const_mapping) -> None: ... 103 def __getattr__(self, attr): ... 104 105def unpackage_script_module( 106 importer: PackageImporter, 107 script_module_id: str, 108) -> torch.nn.Module: ... 109 110_magic_methods: Incomplete 111 112class RecursiveScriptClass: 113 _c: Incomplete 114 _props: Incomplete 115 def __init__(self, cpp_class) -> None: ... 116 def __getattr__(self, attr): ... 117 def __setattr__(self, attr, value) -> None: ... 118 def forward_magic_method(self, method_name, *args, **kwargs): ... 119 def __getstate__(self) -> None: ... 120 def __iadd__(self, other): ... 121 122def method_template(self, *args, **kwargs): ... 
# Base class for scripted modules; an nn.Module whose class is created by
# ScriptMeta.
class ScriptModule(Module, metaclass=ScriptMeta):
    __jit_unused_properties__: Incomplete
    def __init__(self) -> None: ...
    # Declared as an attribute of type Callable[..., Any], so calls to
    # forward with any arguments type-check.
    forward: Callable[..., Any]
    # Unannotated __getattr__: arbitrary attribute access type-checks and
    # yields an untyped value.
    def __getattr__(self, attr): ...
    def __setattr__(self, attr, value) -> None: ...
    def define(self, src): ...
    def _replicate_for_data_parallel(self): ...
    def __reduce_package__(self, exporter: PackageExporter): ...
    # add __jit_unused_properties__
    # (presumably: the properties below are listed in __jit_unused_properties__
    # so the compiler skips them; confirm)
    @property
    def code(self) -> str: ...
    @property
    def code_with_constants(self) -> tuple[str, ConstMap]: ...
    @property
    def graph(self) -> torch.Graph: ...
    @property
    def inlined_graph(self) -> torch.Graph: ...
    @property
    def original_name(self) -> str: ...

# ScriptModule subclass wrapping a compiled module object passed as
# `cpp_module` (presumably kept in `_c`).
class RecursiveScriptModule(ScriptModule):
    _disable_script_meta: bool
    _c: Incomplete
    def __init__(self, cpp_module) -> None: ...
    # Alternate construction path taking the wrapped object and an init hook.
    @staticmethod
    def _construct(cpp_module, init_fn): ...
    @staticmethod
    def _finalize_scriptmodule(script_module) -> None: ...
    _concrete_type: Incomplete
    # The next four re-declare attributes inherited from nn.Module / object
    # as untyped (Incomplete).
    _modules: Incomplete
    _parameters: Incomplete
    _buffers: Incomplete
    __dict__: Incomplete
    def _reconstruct(self, cpp_module) -> None: ...
    # Serialization entry points: to a file-like/path `f`, or to an in-memory
    # buffer. The *_lite_interpreter variants presumably target the mobile
    # runtime.
    def save(self, f, **kwargs): ...
    def _save_for_lite_interpreter(self, *args, **kwargs): ...
    def _save_to_buffer_for_lite_interpreter(self, *args, **kwargs): ...
    def save_to_buffer(self, *args, **kwargs): ...
    def get_debug_state(self, *args, **kwargs): ...
    def extra_repr(self): ...
    def graph_for(self, *args, **kwargs): ...
    def define(self, src) -> None: ...
    def __getattr__(self, attr): ...
    def __setattr__(self, attr, value) -> None: ...
    def __copy__(self): ...
    def __deepcopy__(self, memo): ...
    def forward_magic_method(self, method_name, *args, **kwargs): ...
    # Container protocol (iteration, indexing, len(), `in`); presumably for
    # container-like scripted modules and forwarded to the compiled object.
    def __iter__(self): ...
    def __getitem__(self, idx): ...
    def __len__(self) -> int: ...
    def __contains__(self, key) -> bool: ...
    def __dir__(self): ...
    def __bool__(self) -> bool: ...
    def _replicate_for_data_parallel(self): ...

def _get_methods(cls): ...

# Untyped module-level collection (presumably method names allowed on
# compiled modules; confirm).
_compiled_methods_allowlist: Incomplete

def _make_fail(name): ...
# Presumably apply a `__prepare_scriptable__`-style hook to `obj` recursively,
# with `memo` tracking already-visited objects (confirm).
def call_prepare_scriptable_func_impl(obj, memo): ...
def call_prepare_scriptable_func(obj): ...
# Presumably build the torch.ScriptDict / torch.ScriptList values that the
# dict / list overloads of script() return (confirm).
def create_script_dict(obj): ...
def create_script_list(obj, type_hint: Incomplete | None = ...): ...

# --- script() overloads ---------------------------------------------------
# Type checkers try overloads top to bottom and pick the first match, so the
# order below is significant. All overloads share the same optional
# parameters:
#   optimize: bool | None
#   _frames_up: int
#   _rcb: ResolutionCallback | None
#   example_inputs: a list of tuples, or a dict mapping callables to such
#     lists (presumably sample positional arguments for type tracing)
# The `# type: ignore[misc]` markers presumably silence mypy's
# overlapping-overload diagnostics; keep them on their exact lines.

# A Module *class* (not an instance) is typed as returning Never, so type
# checkers treat any use of the result as unreachable. Presumably the runtime
# rejects this input; confirm.
@overload
def script(
    obj: type[Module],
    optimize: bool | None = None,
    _frames_up: int = 0,
    _rcb: ResolutionCallback | None = None,
    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
) -> Never: ...
# dict -> torch.ScriptDict
@overload
def script(  # type: ignore[misc]
    obj: dict,
    optimize: bool | None = None,
    _frames_up: int = 0,
    _rcb: ResolutionCallback | None = None,
    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
) -> torch.ScriptDict: ...
# list -> torch.ScriptList
@overload
def script(  # type: ignore[misc]
    obj: list,
    optimize: bool | None = None,
    _frames_up: int = 0,
    _rcb: ResolutionCallback | None = None,
    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
) -> torch.ScriptList: ...
# Module instance -> RecursiveScriptModule. This must precede the Callable
# overload: nn.Module instances are callable and would otherwise match there.
@overload
def script(  # type: ignore[misc]
    obj: Module,
    optimize: bool | None = None,
    _frames_up: int = 0,
    _rcb: ResolutionCallback | None = None,
    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
) -> RecursiveScriptModule: ...
# Class object -> the same class type. This must precede the Callable
# overload, because classes are callable too.
@overload
def script(  # type: ignore[misc]
    obj: _ClassVar,
    optimize: bool | None = None,
    _frames_up: int = 0,
    _rcb: ResolutionCallback | None = None,
    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
) -> _ClassVar: ...
# Other callables (e.g. plain functions) -> ScriptFunction
@overload
def script(  # type: ignore[misc]
    obj: Callable,
    optimize: bool | None = None,
    _frames_up: int = 0,
    _rcb: ResolutionCallback | None = None,
    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
) -> ScriptFunction: ...
# Anything else whose optional arguments are well-typed -> RecursiveScriptClass
@overload
def script(
    obj: Any,
    optimize: bool | None = None,
    _frames_up: int = 0,
    _rcb: ResolutionCallback | None = None,
    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
) -> RecursiveScriptClass: ...
# Untyped fallback. Incomplete is typeshed's alias for Any, so this overload
# still matches calls whose `optimize` / `_rcb` fail the typed overloads above.
# Its return is unannotated (implicitly Any).
@overload
def script(
    obj,
    optimize: Incomplete | None = ...,
    _frames_up: int = ...,
    _rcb: Incomplete | None = ...,
    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = ...,
): ...
# Helpers for compiling @overload-decorated functions (by their names);
# _check_overload_defaults presumably checks that overload defaults agree with
# the implementation's.
def _check_overload_defaults(impl_defaults, overload_defaults, loc) -> None: ...
def _compile_function_with_overload(overload_fn, qual_name, impl_fn): ...
def _get_overloads(obj): ...
def _check_directly_compile_overloaded(obj) -> None: ...
# Presumably a decorator declaring a TorchScript interface type (confirm).
def interface(obj): ...
def _recursive_compile_class(obj, loc): ...

# Untyped module attribute (presumably re-exported from torch._C; confirm).
CompilationUnit: Incomplete

# Return type is unannotated; presumably returns `s` padded to `padding`
# characters with `char`, shifted by `offset` (confirm).
def pad(s: str, padding: int, offset: int = ..., char: str = ...): ...

# Profiling report helpers, inferred from names and signatures:
#   - a column (header, alignment, offset, rows added per line number);
#   - a table assembling columns over the `source_range` line numbers;
#   - a profiler with enable/disable/dump.
class _ScriptProfileColumn:
    header: Incomplete
    alignment: Incomplete
    offset: Incomplete
    rows: Incomplete
    def __init__(
        self,
        header: str,
        alignment: int = ...,
        offset: int = ...,
    ) -> None: ...
    def add_row(self, lineno: int, value: Any): ...
    def materialize(self): ...

class _ScriptProfileTable:
    cols: Incomplete
    source_range: Incomplete
    def __init__(
        self,
        cols: list[_ScriptProfileColumn],
        source_range: list[int],
    ) -> None: ...
    def dump_string(self): ...

class _ScriptProfile:
    profile: Incomplete
    def __init__(self) -> None: ...
    def enable(self) -> None: ...
    def disable(self) -> None: ...
    # dump_string returns the report as str; dump returns None (presumably it
    # prints the report; confirm).
    def dump_string(self) -> str: ...
    def dump(self) -> None: ...
# Untyped stub taking a single value `x`. Presumably narrows an Optional value
# to its non-None type; confirm against the implementation.
def _unwrap_optional(x): ...