xref: /aosp_15_r20/external/pytorch/torch/_C/__init__.pyi.in (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# ${generated_comment}
2# mypy: disable-error-code="type-arg"
3# mypy: allow-untyped-defs
4
5import builtins
6from enum import Enum, IntEnum
7from pathlib import Path
8from typing import (
9    Any,
10    AnyStr,
11    BinaryIO,
12    Callable,
13    ContextManager,
14    Dict,
15    Generic,
16    Iterable,
17    Iterator,
18    List,
19    Literal,
20    NamedTuple,
21    Optional,
22    Protocol,
23    Sequence,
24    Set,
25    SupportsIndex,
26    Tuple,
27    Type,
28    TypeVar,
29    Union,
30    overload,
31    runtime_checkable,
32)
33from typing_extensions import ParamSpec, Self
34
35import numpy
36
37import torch
38from torch import inf, SymInt, Tensor
39from torch.autograd.graph import Node as _Node
40from torch.package import PackageExporter
41from torch.storage import UntypedStorage, TypedStorage
42from torch.types import (
43    _bool,
44    _complex,
45    _device,
46    _dispatchkey,
47    _dtype,
48    _float,
49    _int,
50    _layout,
51    _qscheme,
52    _size,
53    Device,
54    Number,
55    Storage,
56)
57
58from torch._prims_common import DeviceLikeType
59from torch.utils._python_dispatch import TorchDispatchMode
60
61# This module is defined in torch/csrc/Module.cpp
62
63from . import _functorch, _lazy, _lazy_ts_backend, _nn, _onnx, _VariableFunctions, _cpu, _aoti, _verbose
64
65K = TypeVar("K")
66T = TypeVar("T")
67S = TypeVar("S", bound="torch.Tensor")
68P = ParamSpec("P")
69ReturnVal = TypeVar("ReturnVal", covariant=True)  # return value (always covariant)
70_T_co = TypeVar("_T_co", covariant=True)
71
72
@runtime_checkable
class _NestedSequence(Protocol[_T_co]):
    """A protocol for representing nested sequences.

    Items are either leaf values of type ``_T_co`` or further nested
    sequences of the same shape.

    References::
        `numpy._typing._NestedSequence`
        <https://github.com/numpy/numpy/blob/main/numpy/_typing/_nested_sequence.py>
    """

    # All parameters are positional-only (note the trailing ``/``).
    def __len__(self, /) -> builtins.int: ...
    def __getitem__(self, index: builtins.int, /) -> _T_co | _NestedSequence[_T_co]: ...
    def __contains__(self, x: builtins.object, /) -> builtins.bool: ...
    def __iter__(self, /) -> Iterator[_T_co | _NestedSequence[_T_co]]: ...
    def __reversed__(self, /) -> Iterator[_T_co | _NestedSequence[_T_co]]: ...
    def count(self, value: Any, /) -> builtins.int: ...
    def index(self, value: Any, /) -> builtins.int: ...
89
90
# Defined in torch/csrc/Device.cpp
class device:
    """Stub for ``torch.device`` (implemented in torch/csrc/Device.cpp)."""

    type: str  # THPDevice_type
    index: _int  # THPDevice_index

    def __get__(self, instance, owner=None) -> device: ...

    # THPDevice_pynew: accepts either a device-like value, or a (type, index) pair.
    @overload
    def __init__(self, device: DeviceLikeType) -> None: ...
    @overload
    def __init__(self, type: str, index: _int) -> None: ...

    # Uncomment if we ever make torch.device a decorator
    # def __call__(self, func: T) -> T: ...

    # Context-manager protocol.
    def __enter__(self) -> device: ...
    def __exit__(self, exc_type, exc_val, exc_tb) -> None: ...
    def __reduce__(self) -> Tuple[Any, ...]: ...  # THPDevice_reduce
110
# Defined in torch/csrc/Stream.cpp
class Stream:
    """Stub for the stream type implemented in torch/csrc/Stream.cpp."""

    stream_id: _int  # Stream id
    device_index: _int
    device_type: _int

    device: _device  # The device of the stream

    # Two construction forms: from an optional device, or from the raw
    # (stream_id, device_index, device_type) triple. ``__new__`` receives the
    # class, hence ``cls``.
    @overload
    def __new__(cls, device: Optional[DeviceLikeType] = None, *, priority: _int = 0) -> Stream: ...
    @overload
    def __new__(
        cls,
        stream_id: _int,
        device_index: _int,
        device_type: _int,
        *,
        priority: _int = 0,
    ) -> Stream: ...
    def query(self) -> _bool: ...
    def synchronize(self) -> None: ...
    def wait_event(self, event: Event) -> None: ...
    def wait_stream(self, other: Stream) -> None: ...
    def record_event(self, event: Optional[Event] = None) -> Event: ...
    def __hash__(self) -> _int: ...
    def __repr__(self) -> str: ...
    def __eq__(self, other: object) -> _bool: ...
131
132
# Defined in torch/csrc/Event.cpp
class Event:
    """Stub for the event type implemented in torch/csrc/Event.cpp."""

    device: _device  # The device of the Event
    event_id: _int  # The raw event created by device backend

    # ``__new__`` and the classmethod below receive the class, hence ``cls``.
    def __new__(
        cls,
        device: Optional[DeviceLikeType] = None,
        *,
        enable_timing: _bool = False,
        blocking: _bool = False,
        interprocess: _bool = False,
    ) -> Event: ...
    @classmethod
    def from_ipc_handle(cls, device: _device, ipc_handle: bytes) -> Event: ...
    def record(self, stream: Optional[Stream] = None) -> None: ...
    def wait(self, stream: Optional[Stream] = None) -> None: ...
    def query(self) -> _bool: ...
    def elapsed_time(self, other: Event) -> _float: ...
    def synchronize(self) -> None: ...
    def ipc_handle(self) -> bytes: ...
    def __repr__(self) -> str: ...
154
155
# Defined in torch/csrc/Size.cpp
class Size(Tuple[_int, ...]):
    """Stub for ``torch.Size``: a tuple of ints."""

    # TODO: __reduce__

    # Integer indexing yields an int; slicing yields another Size.
    @overload  # type: ignore[override]
    def __getitem__(self: Size, key: _int) -> _int: ...
    @overload
    def __getitem__(self: Size, key: slice) -> Size: ...
    def numel(self: Size) -> _int: ...
165
# Defined in torch/csrc/Dtype.cpp
class dtype:
    """Stub for the dtype type implemented in torch/csrc/Dtype.cpp."""

    # TODO: __reduce__
    is_floating_point: _bool
    is_complex: _bool
    is_signed: _bool
    itemsize: _int

    def to_real(self) -> dtype: ...
    def to_complex(self) -> dtype: ...
175
# Defined in torch/csrc/TypeInfo.cpp
class iinfo:
    """Stub for integer dtype limits; constructed from a dtype."""

    bits: _int
    min: _int
    max: _int
    dtype: str

    def __init__(self, dtype: _dtype) -> None: ...
184
class finfo:
    """Stub for floating-point dtype limits (also defined in torch/csrc/TypeInfo.cpp)."""

    bits: _int
    min: _float
    max: _float
    eps: _float
    tiny: _float
    smallest_normal: _float
    resolution: _float
    dtype: str

    # Constructible from a dtype, or with no argument at all.
    @overload
    def __init__(self, dtype: _dtype) -> None: ...
    @overload
    def __init__(self) -> None: ...
199
200${dtype_class_hints}
201
# Defined in torch/csrc/Layout.cpp
class layout: ...

# Defined in torch/csrc/utils/disable_torch_function.cpp
def DisableTorchFunction(): ...
def DisableTorchFunctionSubclass(): ...

# Defined in torch/csrc/utils/tensor_layouts.cpp
# Module-level layout singletons (values are opaque here, hence ``...``).
strided: layout = ...
sparse_coo: layout = ...
sparse_csr: layout = ...
sparse_csc: layout = ...
sparse_bsr: layout = ...
sparse_bsc: layout = ...
_mkldnn: layout = ...
jagged: layout = ...
218
# Defined in torch/csrc/MemoryFormat.cpp
class memory_format: ...

# Defined in torch/csrc/utils/tensor_memoryformats.cpp
# Module-level memory-format singletons.
contiguous_format: memory_format = ...
channels_last: memory_format = ...
channels_last_3d: memory_format = ...
preserve_format: memory_format = ...
227
# Defined in torch/csrc/QScheme.cpp
class qscheme: ...

# Defined in torch/csrc/utils/tensor_qschemes.h
# Module-level quantization-scheme singletons.
per_tensor_affine: qscheme = ...
per_channel_affine: qscheme = ...
per_tensor_symmetric: qscheme = ...
per_channel_symmetric: qscheme = ...
per_channel_affine_float_qparams: qscheme = ...
237
# Defined in torch/csrc/autograd/python_function.cpp
class _FunctionBase:
    # These tuples do not have a fixed length of one (presumably one entry per
    # saved tensor / per forward input), so they are annotated as variadic
    # ``Tuple[X, ...]``; ``Tuple[X]`` would declare exactly one element.
    saved_tensors: Tuple[Tensor, ...]
    _raw_saved_tensors: Tuple[Any, ...]
    next_functions: Tuple[Tuple[Any, _int], ...]
    needs_input_grad: Tuple[_bool, ...]
    metadata: dict
    _materialize_non_diff_grads: _bool
    # skip adding type hints for the fields that have wrappers defined
    # in torch/autograd/function.py
248
# Defined in torch/csrc/autograd/python_legacy_variable.cpp
class _LegacyVariableBase(Tensor):  # inherits from Tensor to appease mypy
    # All constructor arguments are optional (stub defaults are ``...``).
    def __init__(
        self,
        data: Optional[Tensor] = ...,
        requires_grad: Optional[_bool] = ...,
        volatile: Optional[_bool] = ...,
        _grad_fn: Optional[_FunctionBase] = ...,
    ) -> None: ...
258
# Defined in torch/csrc/jit/python/init.cpp
class IODescriptor: ...
class JITException: ...

class Future(Generic[T]):
    """Stub for the JIT future type, generic over its result type ``T``."""

    def __init__(self, devices: List[device]) -> None: ...
    def done(self) -> _bool: ...
    def value(self) -> T: ...
    def wait(self) -> T: ...
    def add_done_callback(self, callback: Callable) -> None: ...
    def then(self, callback: Callable) -> Future[T]: ...
    def set_result(self, result: T) -> None: ...
    def _set_unwrap_func(self, callback: Callable) -> None: ...

class _Await:
    def __init__(self) -> None: ...
    def fn(self) -> Callable: ...
    def args(self) -> Tuple[Any, ...]: ...
    def is_nowait(self) -> _bool: ...

# The profiled-run count is a plain integer, and the call returns the previous
# count. It was previously annotated with ``_size``, the tensor-shape alias,
# which rejected ordinary int arguments under type checking.
# NOTE(review): confirm against the binding in torch/csrc/jit/python/init.cpp.
def _jit_set_num_profiled_runs(num: _int) -> _int: ...
280
# Defined in torch/csrc/jit/passes/mobile_optimizer_type.h
class _MobileOptimizerType: ...

# Declared (annotation-only) members of _MobileOptimizerType; used in the
# ``optimization_blocklist`` sets of the mobile optimization passes.
CONV_BN_FUSION: _MobileOptimizerType
INSERT_FOLD_PREPACK_OPS: _MobileOptimizerType
REMOVE_DROPOUT: _MobileOptimizerType
FUSE_ADD_RELU: _MobileOptimizerType
HOIST_CONV_PACKED_PARAMS: _MobileOptimizerType
VULKAN_AUTOMATIC_GPU_TRANSFER: _MobileOptimizerType
290
# Fork/wait and await primitives.
def fork(*args: Any, **kwargs: Any) -> Future: ...
def wait(fut: Future) -> Any: ...
def _awaitable(*args: Any, **kwargs: Any) -> _Await: ...
def _awaitable_wait(aw: _Await) -> Any: ...
def _awaitable_nowait(x: Any) -> _Await: ...
def _collect_all(futures: List[Future]) -> Future: ...
def _set_print_stack_traces_on_fatal_signal(print: _bool) -> None: ...
def unify_type_list(types: List[JitType]) -> JitType: ...

# Module freezing and frozen-graph passes.
# Stub defaults are written as ``...``; the previous mutable ``[]`` literal
# only signalled "has a default" (the runtime default is an empty list).
def _freeze_module(
    module: ScriptModule,
    preserved_attrs: List[str] = ...,
    freeze_interfaces: _bool = True,
    preserveParameters: _bool = True,
) -> ScriptModule: ...
# NOTE(review): the first parameter is named ``Graph`` (shadowing the class)
# and is unannotated; presumably it is a Graph — confirm before annotating.
def _jit_pass_optimize_frozen_graph(Graph, optimize_numerics: _bool = True) -> None: ...
def _jit_pass_optimize_for_inference(
    module: torch.jit.ScriptModule,
    other_methods: List[str] = ...,
) -> None: ...
def _jit_pass_fold_frozen_conv_bn(graph: Graph): ...
def _jit_pass_fold_frozen_conv_add_or_sub(graph: Graph): ...
def _jit_pass_fold_frozen_conv_mul_or_div(graph: Graph): ...
def _jit_pass_fuse_frozen_conv_add_relu(graph: Graph): ...
def _jit_pass_concat_frozen_linear(graph: Graph): ...
def _jit_pass_convert_frozen_ops_to_mkldnn(graph: Graph): ...
def _jit_pass_transpose_frozen_linear(graph: Graph): ...
def _jit_pass_remove_dropout(module: torch.jit.ScriptModule): ...
# Tracing state, operator/schema lookup, and mobile optimization passes.
def _is_tracing() -> _bool: ...
def _jit_init() -> _bool: ...
def _jit_flatten(arg: Any) -> Tuple[List[Tensor], IODescriptor]: ...
def _jit_unflatten(vars: List[Tensor], desc: IODescriptor) -> Any: ...
def _jit_get_operation(op_name: str) -> Tuple[Callable, List[str]]: ...
def _get_operation_overload(
    op_name: str,
    op_overload_name: str,
) -> Tuple[Callable, Callable, List[Any]]: ...
def _get_schema(op_name: str, overload_name: str) -> FunctionSchema: ...
def _jit_pass_optimize_for_mobile(
    module: torch.jit.ScriptModule,
    optimization_blocklist: Set[_MobileOptimizerType],
    preserved_methods: List[AnyStr],
) -> torch.jit.ScriptModule: ...
def _clone_module_with_class(
    module: torch.jit.ScriptModule,
    ignored_methods: List[AnyStr],
    ignored_attributes: List[AnyStr],
) -> torch.jit.ScriptModule: ...
def _jit_pass_vulkan_optimize_for_mobile(
    module: torch.jit.ScriptModule,
    optimization_blocklist: Set[_MobileOptimizerType],
    preserved_methods: List[AnyStr],
) -> torch.jit.ScriptModule: ...
def _jit_pass_metal_optimize_for_mobile(
    module: torch.jit.ScriptModule,
    preserved_methods: List[AnyStr],
) -> torch.jit.ScriptModule: ...

# NOTE(review): the passes below name a parameter ``Graph`` (shadowing the
# class) and leave it unannotated; presumably it is a Graph — confirm before
# renaming or annotating.
def _jit_pass_inline(Graph) -> None: ...
def _jit_pass_constant_propagation(Graph) -> None: ...
def _jit_pass_propagate_shapes_on_graph(Graph) -> None: ...
def _jit_register_decomposition_for_schema(schema: FunctionSchema, Graph) -> None: ...
def _jit_erase_non_input_shape_information(Graph) -> None: ...
def _jit_get_schemas_for_operator(name: str) -> List[FunctionSchema]: ...
def _jit_get_all_schemas() -> List[FunctionSchema]: ...
def _jit_check_alias_annotation(
    g: Graph,
    args: Tuple[Any, ...],
    unqualified_op_name: str,
): ...

# Fuser capability queries and enable/override switches.
def _jit_can_fuse_on_cpu() -> _bool: ...
def _jit_can_fuse_on_gpu() -> _bool: ...
def _jit_can_fuse_on_cpu_legacy() -> _bool: ...
def _debug_get_fusion_group_inlining() -> _bool: ...
def _debug_set_fusion_group_inlining(enable: _bool): ...
def _jit_texpr_fuser_enabled() -> _bool: ...
def _jit_nvfuser_enabled() -> _bool: ...
def _jit_llga_enabled() -> _bool: ...
def _jit_set_llga_enabled(enable: _bool): ...
def _llvm_enabled() -> _bool: ...
def _jit_override_can_fuse_on_cpu(override: _bool): ...
def _jit_override_can_fuse_on_gpu(override: _bool): ...
def _jit_override_can_fuse_on_cpu_legacy(override: _bool): ...
def _jit_set_symbolic_shapes_test_mode(override: _bool): ...
def _jit_symbolic_shapes_test_mode_enabled() -> _bool: ...
def _jit_set_texpr_fuser_enabled(enable: _bool): ...
def _jit_set_te_must_use_llvm_cpu(use_llvm: _bool): ...
def _jit_set_nvfuser_enabled(enable: _bool) -> _bool: ...
def _jit_cat_wo_conditionals(optimize_cat: _bool): ...
def _jit_opt_conditionals(opt_conds: _bool): ...
def _jit_pass_canonicalize(graph: Graph, keep_unique_names: _bool = True): ...
def _jit_pass_erase_shape_information(graph: Graph): ...
def _jit_pass_fold_convbn(module: torch.jit.ScriptModule): ...
# Quantization passes (including on-device PTQ variants).
def _jit_pass_insert_observers(
    module: torch.jit.ScriptModule,
    method_name: str,
    qconfig_dict: Dict[str, Any],
    inplace: _bool,
    quant_type: _int,
): ...
def _jit_pass_insert_quant_dequant(
    module: torch.jit.ScriptModule,
    method_name: str,
    inplace: _bool,
    debug: _bool,
    quant_type: _int,
): ...
def _jit_pass_insert_quant_dequant_for_ondevice_ptq(
    module: torch.jit.ScriptModule,
    method_name: str,
    inplace: _bool,
    debug: _bool,
    quant_type: _int,
): ...
def _jit_pass_quant_finalize(
    module: torch.jit.ScriptModule,
    quant_type: _int,
    preserved_attrs: Sequence[str],
): ...
def _jit_pass_quant_finalize_for_ondevice_ptq(
    module: torch.jit.ScriptModule,
    quant_type: _int,
    method_name: str,
): ...
def _jit_pass_insert_observer_method_for_ondevice_ptq(
    module: torch.jit.ScriptModule,
    method_name: str,
    qconfig_dict: Dict[str, Any],
    inplace: _bool,
    quant_type: _int,
): ...

# Executor / profiling / fusion-strategy settings.
def _jit_set_profiling_executor(profiling_flag: _bool) -> _bool: ...
def _jit_set_profiling_mode(profiling_flag: _bool) -> _bool: ...
def _jit_set_fusion_strategy(
    strategy: List[Tuple[str, _int]],
) -> List[Tuple[str, _int]]: ...
def _jit_try_infer_type(obj: Any) -> InferredType: ...
def _jit_get_trigger_value(trigger_name: str) -> _int: ...
427
# Defined in torch/csrc/jit/python/script_init.cpp
# Callback type: takes a name (str) and returns a callable.
ResolutionCallback = Callable[[str], Callable[..., Any]]
430
# Defined in torch/csrc/jit/python/script_init.cpp
#        and torch/csrc/jit/python/init.cpp
def _maybe_call_torch_function_for_op_packet(
    op_overload_packet: Any,
    args: Any,
    kwargs: Any,
) -> Any: ...
def _check_schema_allow_fake_script_object(
    schema: FunctionSchema,
    args: Any,
    kwargs: Any,
) -> _bool: ...
def _create_function_from_graph(qualname: str, graph: Graph) -> ScriptFunction: ...
def _debug_set_autodiff_subgraph_inlining(disabled: _bool) -> None: ...
def _ivalue_tags_match(lhs: ScriptModule, rhs: ScriptModule) -> _bool: ...
def _jit_assert_is_instance(obj: Any, type: JitType): ...
def _jit_clear_class_registry() -> None: ...
def _jit_set_emit_hooks(
    ModuleHook: Optional[Callable],
    FunctionHook: Optional[Callable],
) -> None: ...
def _jit_get_emit_hooks() -> Tuple[Callable, Callable]: ...

# Lite-interpreter (mobile) loading, inspection, and bytecode backporting.
# Each file-based entry point has a ``*_from_buffer`` twin taking a BinaryIO.
def _load_for_lite_interpreter(
    filename: Union[str, Path],
    map_location: Optional[DeviceLikeType],
): ...
def _load_for_lite_interpreter_from_buffer(
    buffer: BinaryIO,
    map_location: Optional[DeviceLikeType],
): ...
def _export_operator_list(module: LiteScriptModule): ...
def _quantize_ondevice_ptq_dynamic(module: LiteScriptModule, method_name: str): ...
def _get_model_bytecode_version(filename: Union[str, Path]) -> _int: ...
def _get_model_bytecode_version_from_buffer(buffer: BinaryIO) -> _int: ...
def _backport_for_mobile(
    filename_input: Union[str, Path],
    filename_output: Union[str, Path],
    to_version: _int,
) -> None: ...
def _backport_for_mobile_from_buffer(
    buffer: BinaryIO,
    filename_output: Union[str, Path],
    to_version: _int,
) -> None: ...
def _backport_for_mobile_to_buffer(
    filename_input: Union[str, Path],
    to_version: _int,
) -> bytes: ...
def _backport_for_mobile_from_buffer_to_buffer(
    buffer: BinaryIO,
    to_version: _int,
) -> bytes: ...
def _get_model_ops_and_info(filename: Union[str, Path]): ...
def _get_model_ops_and_info_from_buffer(buffer: BinaryIO): ...
def _get_mobile_model_contained_types(filename: Union[str, Path]): ...
def _get_mobile_model_contained_types_from_buffer(buffer: BinaryIO): ...
def _logging_set_logger(logger: LoggerBase) -> LoggerBase: ...
def _get_graph_executor_optimize(optimize: Optional[_bool] = None) -> _bool: ...
def _set_graph_executor_optimize(optimize: _bool): ...
def _export_opnames(module: ScriptModule) -> List[str]: ...
# Trace-to-graph entry points (positional-tuple and keyword-dict inputs).
def _create_function_from_trace(
    qualname: str,
    func: Callable[..., Any],
    input_tuple: Tuple[Any, ...],
    var_lookup_fn: Callable[[Tensor], str],
    strict: _bool,
    force_outplace: _bool,
    argument_names: List[str],
) -> Tuple[Graph, Stack]: ...
def _create_function_from_trace_with_dict(
    qualname: str,
    func: Callable[..., Any],
    input_dict: Dict[str, Any],
    var_lookup_fn: Callable[[Tensor], str],
    strict: _bool,
    force_outplace: _bool,
    argument_names: List[str],
) -> Tuple[Graph, Stack]: ...
def _jit_is_script_object(obj: Any) -> _bool: ...
def _last_executed_optimized_graph() -> Graph: ...
def parse_type_comment(comment: str) -> Decl: ...

# Upgrader map accessors (the ``_test_only_*`` helpers mutate it for tests).
def _get_upgraders_map_size() -> _int: ...
def _get_upgraders_entry_map() -> Dict[str, str]: ...
def _dump_upgraders_map() -> Dict[str, str]: ...
def _test_only_populate_upgraders(content: Dict[str, str]) -> None: ...
def _test_only_remove_upgraders(content: Dict[str, str]) -> None: ...
def merge_type_from_type_comment(
    decl: Decl,
    type_annotation_decl: Decl,
    is_method: _bool,
) -> Decl: ...
def parse_ir(input: str, parse_tensor_constants: _bool = False) -> Graph: ...
def parse_schema(schema: str) -> FunctionSchema: ...
def get_device(input: Tensor) -> _int: ...
def _resolve_type_from_object(
    obj: Any,
    range: SourceRange,
    rcb: ResolutionCallback,
) -> JitType: ...
def _create_module_with_type(ty: JitType) -> ScriptModule: ...
def _create_object_with_type(ty: ClassType) -> ScriptObject: ...
def _run_emit_module_hook(m: ScriptModule): ...
def _replace_overloaded_method_decl(
    overload_decl: Decl,
    implementation_def: Def,
    new_name: str,
) -> Def: ...
# Graph passes used by the ONNX exporter, plus a few general rewrites.
def _jit_pass_lower_all_tuples(graph: Graph) -> None: ...
def _jit_pass_onnx_set_dynamic_input_shape(
    graph: Graph,
    dynamic_axes: Dict[str, Dict[_int, str]],
    input_names: List[str],
) -> None: ...
def _jit_pass_onnx_graph_shape_type_inference(
    graph: Graph,
    params_dict: Dict[str, IValue],
    opset_version: _int,
) -> None: ...
def _jit_pass_onnx_assign_output_shape(
    graph: Graph,
    tensors: List[Tensor],
    desc: IODescriptor,
    onnx_shape_inference: _bool,
    is_script: _bool,
    opset_version: _int,
) -> None: ...
def _jit_pass_onnx_remove_inplace_ops_for_onnx(
    graph: Graph,
    module: Optional[ScriptModule] = None,
) -> None: ...
def _jit_pass_remove_inplace_ops(graph: Graph) -> None: ...
def _jit_pass_canonicalize_graph_fuser_ops(graph: Graph) -> None: ...
def _jit_pass_peephole(
    graph: Graph,
    disable_shape_peepholes: _bool = False,
) -> None: ...
def _jit_pass_onnx_autograd_function_process(graph: Graph) -> None: ...
def _jit_pass_fuse_addmm(graph: Graph) -> None: ...
def _jit_pass_onnx_preprocess(graph: Graph) -> None: ...
def _jit_pass_prepare_division_for_onnx(graph: Graph) -> None: ...
def _jit_pass_onnx_remove_print(graph: Graph) -> None: ...
def _jit_pass_onnx_preprocess_caffe2(graph: Graph) -> None: ...
def _jit_pass_onnx_unpack_quantized_weights(
    graph: Graph,
    paramsDict: Dict[str, IValue],
    caffe2: _bool,
) -> Dict[str, IValue]: ...
def _jit_pass_onnx_quantization_insert_permutes(
    graph: Graph,
    paramsDict: Dict[str, IValue],
) -> Dict[str, IValue]: ...
def _jit_pass_custom_pattern_based_rewrite_graph(
    pattern: str,
    fused_node_name: str,
    graph: Graph,
) -> None: ...
def _jit_onnx_list_model_parameters(
    module: ScriptModule,
) -> Tuple[ScriptModule, List[IValue]]: ...
def _jit_pass_erase_number_types(graph: Graph) -> None: ...
def _jit_pass_onnx_lint(graph: Graph) -> None: ...
# NOTE(review): the second parameter reuses the function's own name; by
# analogy with _jit_pass_onnx_block it is presumably the operator export type.
def _jit_pass_onnx(
    graph: Graph,
    _jit_pass_onnx: _onnx.OperatorExportTypes,
) -> Graph: ...
def _jit_pass_onnx_scalar_type_analysis(
    graph: Graph,
    lowprecision_cast: _bool,
    opset_version: _int,
) -> None: ...
def _jit_pass_onnx_peephole(
    graph: Graph,
    opset_version: _int,
    fixed_batch_size: _bool,
) -> None: ...
def _jit_pass_dce_allow_deleting_nodes_with_side_effects(graph: Graph) -> None: ...
# ONNX function extraction, scope tracking, logging, and block conversion.
def _jit_pass_onnx_function_substitution(graph: Graph) -> None: ...
def _jit_pass_onnx_function_extraction(
    graph: Graph,
    module_names: Set[str],
    param_names: List[str],
) -> Dict[Node, Dict[str, str]]: ...
def _jit_pass_onnx_clear_scope_records() -> None: ...
def _jit_pass_onnx_track_scope_attributes(
    graph: Graph,
    onnx_attrs: Dict[str, Any],
) -> None: ...
def _jit_is_onnx_log_enabled() -> _bool: ...
def _jit_set_onnx_log_enabled(enabled: _bool) -> None: ...
def _jit_set_onnx_log_output_stream(stream_name: str) -> None: ...
def _jit_onnx_log(*args: Any) -> None: ...
def _jit_pass_lower_graph(graph: Graph, m: Module) -> Tuple[Graph, List[IValue]]: ...
def _jit_pass_inline_fork_wait(graph: Graph) -> None: ...
def _jit_pass_onnx_deduplicate_initializers(
    graph: Graph,
    params_dict: Dict[str, IValue],
    is_train: _bool,
) -> Dict[str, IValue]: ...
def _jit_pass_onnx_eval_peephole(
    graph: Graph,
    paramsDict: Dict[str, IValue],
) -> Dict[str, IValue]: ...
def _jit_pass_onnx_constant_fold(
    graph: Graph,
    paramsDict: Dict[str, IValue],
    opset_version: _int,
) -> Dict[str, IValue]: ...
def _jit_pass_onnx_eliminate_unused_items(
    graph: Graph,
    paramsDict: Dict[str, IValue],
) -> Dict[str, IValue]: ...
def _jit_pass_onnx_cast_all_constant_to_floating(graph: Graph) -> None: ...
def _jit_pass_filter_non_tensor_arguments(
    params: Dict[str, IValue],
) -> Dict[str, Tensor]: ...
def _jit_decay_packed_param_input_types(graph: Graph) -> None: ...
def _jit_pass_onnx_node_shape_type_inference(
    n: Node,
    paramsDict: Dict[str, IValue],
    opset_version: _int,
) -> None: ...
def _jit_onnx_convert_pattern_from_subblock(
    block: Block,
    n: Node,
    env: Dict[Value, Value],
    values_in_env: Set[Value],
) -> List[Value]: ...
def _jit_pass_onnx_block(
    old_block: Block,
    new_block: Block,
    operator_export_type: _onnx.OperatorExportTypes,
    env: Dict[Value, Value],
    values_in_env: Set[Value],
    is_sub_block: _bool,
) -> Dict[Value, Value]: ...
def _jit_pass_onnx_assign_scoped_names_for_node_and_value(graph: Graph) -> None: ...
def _jit_pass_fixup_onnx_controlflow_node(
    n: Node,
    opset_version: _int,
) -> List[Value]: ...
def _jit_onnx_create_full_scope_name(class_name: str, variable_name: str) -> str: ...
def _compile_graph_to_code_table(name: str, graph: Graph) -> IValue: ...
def _generate_upgraders_graph() -> Dict[str, Graph]: ...
def _calculate_package_version_based_on_upgraders(val: _bool): ...
def _get_version_calculator_flag() -> _bool: ...

# TorchScript compilation entry points (interfaces, overloads, functions, classes).
def _jit_script_interface_compile(
    name: str,
    class_def: ClassDef,
    rcb: ResolutionCallback,
    is_module: _bool,
): ...
def _jit_script_compile_overload(
    qualname: str,
    overload_decl: Decl,
    implementation_def: Def,
    rcb: ResolutionCallback,
    implementation_defaults: Dict[str, Any],
    signature: Any,
): ...
def _jit_script_compile(
    qual_name: str,
    definition: Def,
    rcb: ResolutionCallback,
    defaults: Dict[str, Any],
): ...
def _jit_script_class_compile(
    qual_name: str,
    definition: ClassDef,
    defaults: Dict[str, Dict[str, Any]],
    rcb: ResolutionCallback,
): ...
def _parse_source_def(src: str) -> Def: ...

# Loading serialized IR modules into a CompilationUnit (file, buffer, package).
def import_ir_module(
    cu: CompilationUnit,
    filename: Union[str, Path],
    map_location: Optional[DeviceLikeType],
    extra_files: Dict[str, Any],
) -> ScriptModule: ...
def import_ir_module_from_buffer(
    cu: CompilationUnit,
    buffer: BinaryIO,
    map_location: Optional[DeviceLikeType],
    extra_files: Dict[str, Any],
) -> ScriptModule: ...
def _import_ir_module_from_package(
    cu: CompilationUnit,
    reader: PyTorchFileReader,
    storage_context: DeserializationStorageContext,
    map_location: Optional[DeviceLikeType],
    ts_id: str,
) -> ScriptModule: ...
def _assign_output_shapes(graph: Graph, inputs: List[Tensor]) -> Graph: ...
def _check_onnx_proto(proto: str) -> None: ...
def _propagate_and_assign_input_shapes(
    graph: Graph,
    inputs: Tuple[Tensor, ...],
    param_count_list: List[_int],
    with_grad: _bool,
    propagate: _bool,
) -> Graph: ...
731
# Defined in torch/csrc/jit/runtime/graph_executor.h
class GraphExecutorState: ...

# Defined in torch/csrc/jit/ir/alias_analysis.h
class AliasDb:
    def __str__(self) -> str: ...

# Context manager returned by ``Graph.insert_point_guard``.
class _InsertPoint:
    def __enter__(self) -> None: ...
    def __exit__(self, *args) -> None: ...
742
# Defined in torch/csrc/jit/ir/ir.h
class Use:
    """A use of a Value: the consuming node and the input offset on it."""

    @property
    def user(self) -> Node: ...
    @property
    def offset(self) -> _int: ...
    def isAfter(self, other: Use) -> _bool: ...
750
# Defined in torch/csrc/jit/ir/ir.h
class Value:
    """Stub for an IR value (see torch/csrc/jit/ir/ir.h)."""

    # Type information.
    def type(self) -> JitType: ...
    def setType(self, t: JitType) -> Value: ...
    def setTypeAs(self, other: Value) -> Value: ...
    def inferTypeFrom(self, t: Tensor) -> None: ...
    # Naming and identity.
    def debugName(self) -> str: ...
    def setDebugName(self, name: str) -> None: ...
    def unique(self) -> _int: ...
    def offset(self) -> _int: ...
    # Producer and consumers.
    def node(self) -> Node: ...
    def uses(self) -> List[Use]: ...
    def replaceAllUsesWith(self, val: Value) -> None: ...
    def replaceAllUsesAfterNodeWith(self, node: Node, val: Value) -> None: ...
    def requires_grad(self) -> _bool: ...
    def requiresGrad(self) -> _bool: ...
    def copyMetadata(self, other: Value) -> Value: ...
    def isCompleteTensor(self) -> _bool: ...
    def toIValue(self) -> IValue: ...
770
# Defined in torch/csrc/jit/ir/ir.h
class Block:
    """Stub for an IR block: a node list with its own inputs and outputs."""

    def inputs(self) -> Iterator[Value]: ...
    def outputs(self) -> Iterator[Value]: ...
    def nodes(self) -> Iterator[Node]: ...
    def paramNode(self) -> Node: ...
    def returnNode(self) -> Node: ...
    def owningNode(self) -> Node: ...
    def registerOutput(self, n: Value) -> _int: ...
    def addNode(self, name: str, inputs: Sequence[Value]) -> Node: ...
781
# Defined in torch/csrc/jit/ir/ir.h
class Node:
    """Stub for an IR node (see torch/csrc/jit/ir/ir.h)."""

    def __getitem__(self, key: str) -> Any: ...
    def schema(self) -> str: ...

    # Inputs / outputs / sub-blocks.
    def input(self) -> Value: ...
    def inputs(self) -> Iterator[Value]: ...
    def inputsAt(self, idx: _int) -> Value: ...
    def inputsSize(self) -> _int: ...
    def output(self) -> Value: ...
    def outputs(self) -> Iterator[Value]: ...
    def outputsAt(self, idx: _int) -> Value: ...
    def outputsSize(self) -> _int: ...
    def hasMultipleOutputs(self) -> _bool: ...
    def blocks(self) -> List[Block]: ...
    def addBlock(self) -> Block: ...
    def mustBeNone(self) -> _bool: ...
    def matches(self, pattern: str) -> _bool: ...
    def kind(self) -> str: ...
    def kindOf(self, name: str) -> str: ...

    # Graph mutation.
    # NOTE(review): the C++ Node::addInput in ir.h appears to take a Value*,
    # not a name string; confirm before changing this annotation.
    def addInput(self, name: str) -> Value: ...
    def replaceInput(self, i: _int, newValue: Value) -> Value: ...
    def replaceInputWith(self, from_: Value, to: Value) -> None: ...
    def replaceAllUsesWith(self, n: Node) -> None: ...
    def insertBefore(self, n: Node) -> Node: ...
    def insertAfter(self, n: Node) -> Node: ...
    def isBefore(self, n: Node) -> _bool: ...
    def isAfter(self, n: Node) -> _bool: ...
    def moveBefore(self, n: Node) -> None: ...
    def moveAfter(self, n: Node) -> None: ...
    def removeInput(self, i: _int) -> None: ...
    # Removes every input, so there is no index argument; the previous stub
    # declared a spurious ``i: _int`` parameter.
    def removeAllInputs(self) -> None: ...
    def hasUses(self) -> _bool: ...
    def eraseOutput(self, i: _int) -> None: ...
    def addOutput(self) -> Value: ...
    def scopeName(self) -> str: ...
    def isNondeterministic(self) -> _bool: ...
    def copyAttributes(self, rhs: Node) -> Node: ...
    def copyMetadata(self, rhs: Node) -> Node: ...
    def hasAttributes(self) -> _bool: ...
    def hasAttribute(self, name: str) -> _bool: ...
    def removeAttribute(self, attr: str) -> Node: ...
    def namedInput(self, name: str) -> Value: ...
    def sourceRange(self) -> SourceRange: ...
    def owningBlock(self) -> Block: ...
    def findNode(self, kind: str, recurse: _bool = True) -> Node: ...
    def findAllNodes(self, kind: str, recurse: _bool = True) -> List[Node]: ...
    def getModuleHierarchy(self) -> str: ...
    def prev(self) -> Node: ...
    def destroy(self) -> None: ...
    def attributeNames(self) -> List[str]: ...

    # Accessors for attributes as types.
    # Each getter ``x(name)`` has a setter ``x_(name, val)`` that returns the node.
    def f(self, name: str) -> _float: ...
    def f_(self, name: str, val: _float) -> Node: ...
    def fs(self, name: str) -> List[_float]: ...
    def fs_(self, name: str, val: List[_float]) -> Node: ...
    def c(self, name: str) -> complex: ...
    def c_(self, name: str, val: complex) -> Node: ...
    def s(self, name: str) -> str: ...
    def s_(self, name: str, val: str) -> Node: ...
    def ss(self, name: str) -> List[str]: ...
    def ss_(self, name: str, val: List[str]) -> Node: ...
    def i(self, name: str) -> _int: ...
    def i_(self, name: str, val: _int) -> Node: ...
    # Cannot define "is" like this because it's a reserved keyword in python.
    # def is(self, name: str) -> List[_int]: ...
    # def is_(self, name: str, val: List[_int]) -> Node: ...
    def g(self, name: str) -> Graph: ...
    def g_(self, name: str, val: Graph) -> Node: ...
    def gs(self, name: str) -> List[Graph]: ...
    def gs_(self, name: str, val: List[Graph]) -> Node: ...
    def ival(self, name: str) -> IValue: ...
    def ival_(self, name: str, val: IValue) -> Node: ...
    def t(self, name: str) -> Tensor: ...
    def t_(self, name: str, val: Tensor) -> Node: ...
    def ts(self, name: str) -> List[Tensor]: ...
    def ts_(self, name: str, val: List[Tensor]) -> Node: ...
    def ty(self, name: str) -> JitType: ...
    def ty_(self, name: str, val: JitType) -> Node: ...
    def tys(self, name: str) -> List[JitType]: ...
    def tys_(self, name: str, val: List[JitType]) -> Node: ...
863
864# Defined in torch/torch/csrc/jit/ir/ir.h
865class Graph:
866    def inputs(self) -> Iterator[Value]: ...
867    def outputs(self) -> Iterator[Value]: ...
868    def nodes(self) -> Iterator[Node]: ...
869    def param_node(self) -> Node: ...
870    def return_node(self) -> Node: ...
871    def addInput(self, name: str = "") -> Value: ...
872    def eraseInput(self, i: _int) -> None: ...
873    def registerOutput(self, n: Value) -> _int: ...
874    def eraseOutput(self, i: _int) -> None: ...
875    def create(self, name: str, args, num_outputs: _int) -> Node: ...
876    def appendNode(self, n: Node) -> Node: ...
877    def prependNode(self, n: Node) -> Node: ...
878    def insertNode(self, n: Node) -> Node: ...
879    def block(self) -> Block: ...
880    def lint(self) -> None: ...
881    def alias_db(self) -> AliasDb: ...
882    def setInsertPoint(self, n: Union[Block, Node]) -> None: ...
883    def insert_point_guard(self, n: Union[Block, Node]) -> _InsertPoint: ...
884    def insertPoint(self) -> Node: ...
885    def insertGraph(self, callee: Graph, inputs: List[Value]) -> List[Value]: ...
886    def makeMultiOutputIntoTuple(self) -> None: ...
887    def copy(self) -> Graph: ...
888
889# Defined in torch/aten/src/ATen/core/alias_info.h
890class AliasInfo:
891    is_write: _bool
892    before_set: Set[str]
893    after_set: Set[str]
894
895# Defined in torch/aten/src/ATen/core/function_schema.h
896class Argument:
897    name: str
898    type: JitType
899    default_value: Optional[Any]
900    def has_default_value(self) -> _bool: ...
901    kwarg_only: _bool
902    is_out: _bool
903    alias_info: Optional[AliasInfo]
904
905class FunctionSchema:
906    arguments: List[Argument]
907    returns: List[Argument]
908    name: str
909    overload_name: str
910    is_mutable: _bool
911
912class _UpgraderEntry:
913    bumped_at_version: _int
914    upgrader_name: str
915    old_schema: str
916    def __init__(
917        self,
918        bumped_at_version: _int,
919        upgrader_name: str,
920        old_schema: str,
921    ) -> None: ...
922
923class _UpgraderRange:
924    min_version: _int
925    max_version: _int
926
927def _get_max_operator_version() -> _int: ...
928def _get_operator_version_map() -> Dict[str, List[_UpgraderEntry]]: ...
929def _get_upgrader_ranges(name: str) -> List[_UpgraderRange]: ...
930def _test_only_add_entry_to_op_version(op_name: str, entry: _UpgraderEntry) -> None: ...
931def _test_only_remove_entry_to_op_version(op_name: str) -> None: ...
932
933# Defined in torch/csrc/jit/python/script_init.cpp
934class ScriptModuleSerializer:
935    def __init__(self, export_writer: PyTorchFileWriter) -> None: ...
936    def serialize(self, model: ScriptModule, script_module_id: _int) -> None: ...
937    def write_files(self) -> None: ...
938    def storage_context(self) -> SerializationStorageContext: ...
939
940# Defined in torch/csrc/jit/python/script_init.cpp
941class SerializationStorageContext:
942    def __init__(self) -> None: ...
943    def has_storage(self, storage: Storage) -> _bool: ...
944    def get_or_add_storage(self, storage: Storage) -> _int: ...
945
946# Defined in torch/csrc/jit/python/script_init.cpp
947class DeserializationStorageContext:
948    def __init__(self) -> None: ...
949    def get_storage(self, name: str, dtype: _dtype) -> Tensor: ...
950    def has_storage(self, name: str) -> _bool: ...
951    def add_storage(self, name: str, tensor: Tensor) -> _int: ...
952
953# Defined in torch/csrc/jit/python/script_init.cpp
954class ConcreteModuleTypeBuilder:
955    def __init__(self, obj: Any) -> None: ...
956    def set_module_dict(self): ...
957    def set_module_list(self): ...
958    def set_parameter_list(self): ...
959    def set_parameter_dict(self): ...
960    def add_attribute(
961        self,
962        name: str,
963        ty: JitType,
964        is_param: _bool,
965        is_buffer: _bool,
966    ): ...
967    def add_module(self, name: str, meta: ConcreteModuleType): ...
968    def add_constant(self, name: str, value: Any): ...
969    def add_overload(self, method_name: str, overloaded_method_names: List[str]): ...
970    def add_builtin_function(self, name: str, symbol_name: str): ...
971    def add_failed_attribute(self, name: str, failure_reason: str): ...
972    def add_function_attribute(
973        self,
974        name: str,
975        ty: JitType,
976        func: Callable[..., Any],
977    ): ...
978    def add_ignored_attribute(self, name: str): ...
979    def add_ignored_attributes(self, names: List[str]): ...
980    def add_forward_hook(self, hook: Callable[..., Any]): ...
981    def add_forward_pre_hook(self, pre_hook: Callable[..., Any]): ...
982
983class ConcreteModuleType:
984    def get_constants(self) -> Dict[str, Any]: ...
985    def equals(self, other: ConcreteModuleType) -> _bool: ...
986    @staticmethod
987    def from_jit_type(ty: JitType) -> ConcreteModuleType: ...
988
989class CallStack:
990    def __init__(self, name: str, range: SourceRange): ...
991
992class ErrorReport:
993    def __init__(self, range: SourceRange) -> None: ...
994    def what(self) -> str: ...
995    @staticmethod
996    def call_stack() -> str: ...
997
998class CompilationUnit:
999    def __init__(self, lang: str = ..., _frames_up: _int = ...) -> None: ...
1000    def find_function(self, name: str) -> ScriptFunction: ...
1001    def __getattr__(self, name: str) -> ScriptFunction: ...
1002    def define(
1003        self,
1004        script: str,
1005        rcb: ResolutionCallback = ...,
1006        _frames_up: _int = ...,
1007    ): ...
1008    def get_interface(self, name: str) -> InterfaceType: ...
1009    def get_functions(self) -> List[ScriptFunction]: ...
1010    def create_function(
1011        self,
1012        name: str,
1013        graph: Graph,
1014        shouldMangle: _bool = ...,
1015    ) -> ScriptFunction: ...
1016    def get_class(self, name: str) -> ClassType: ...
1017
1018class ScriptObject:
1019    def setattr(self, name: str, value: Any): ...
1020
1021class ScriptModule(ScriptObject):
1022    def _method_names(self) -> List[str]: ...
1023    def _get_method(self, name: str) -> ScriptMethod: ...
1024
1025class LiteScriptModule:
1026    def __call__(self, *input): ...
1027    def find_method(self, method_name: str): ...
1028    def forward(self, *input) -> List[str]: ...
1029    def run_method(self, method_name: str, *input): ...
1030
1031# NOTE: switch to collections.abc.Callable in python 3.9
1032class ScriptFunction(Generic[P, ReturnVal]):
1033    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> ReturnVal: ...
1034    def save(self, filename: str, _extra_files: Dict[str, bytes]) -> None: ...
1035    def save_to_buffer(self, _extra_files: Dict[str, bytes]) -> bytes: ...
1036    @property
1037    def graph(self) -> Graph: ...
1038    def inlined_graph(self) -> Graph: ...
1039    def schema(self) -> FunctionSchema: ...
1040    def code(self) -> str: ...
1041    def name(self) -> str: ...
1042    @property
1043    def qualified_name(self) -> str: ...
1044
1045# NOTE: switch to collections.abc.Callable in python 3.9
1046class ScriptMethod(Generic[P, ReturnVal]):
1047    graph: Graph
1048    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> ReturnVal: ...
1049    @property
1050    def owner(self) -> ScriptModule: ...
1051    @property
1052    def name(self) -> str: ...
1053
1054class ScriptDict(Generic[K, T]):
1055    def __init__(self, dict: Dict[K, T]) -> None: ...
1056    def __len__(self) -> _int: ...
1057    def __contains__(self, key: K) -> _bool: ...
1058    def __getitem__(self, key: K) -> T: ...
1059    def __setitem__(self, key: K, value: T) -> None: ...
1060    def __delitem__(self, key: K) -> None: ...
1061    def __iter__(self) -> Iterator[K]: ...
1062    def items(self) -> Iterator[tuple[K, T]]: ...
1063    def keys(self) -> Iterator[K]: ...
1064
1065class ScriptList(Generic[T]):
1066    def __init__(self, list: List[T]) -> None: ...
1067    def __len__(self) -> _int: ...
1068    def __contains__(self, item: T) -> _bool: ...
1069    @overload
1070    def __getitem__(self, idx: _int) -> T: ...
1071    @overload
1072    def __getitem__(self, idx: slice) -> ScriptList[T]: ...
1073    @overload
1074    def __setitem__(self, idx: _int, value: T) -> None: ...
1075    @overload
1076    def __setitem__(self, idx: slice, value: List[T]) -> None: ...
1077    def __delitem__(self, idx: _int) -> None: ...
1078    def __iter__(self) -> Iterator[T]: ...
1079    def count(self, value: T) -> _int: ...
1080    def remove(self, value: T) -> None: ...
1081    def append(self, value: T) -> None: ...
1082    def clear(self) -> None: ...
1083    @overload
1084    def extend(self, values: List[T]) -> None: ...
1085    @overload
1086    def extend(self, values: Iterable[T]) -> None: ...
1087    @overload
1088    def pop(self) -> T: ...
1089    @overload
1090    def pop(self, idx: _int) -> T: ...
1091
1092class ModuleDict:
1093    def __init__(self, mod: ScriptModule) -> None: ...
1094    def items(self) -> List[Tuple[str, Any]]: ...
1095
1096class ParameterDict:
1097    def __init__(self, mod: ScriptModule) -> None: ...
1098
1099class BufferDict:
1100    def __init__(self, mod: ScriptModule) -> None: ...
1101
1102# Defined in torch/csrc/jit/api/module.h
1103class Module: ...
1104
1105# Defined in torch/csrc/Module.cpp
1106def _initExtension(shm_manager_path: str) -> None: ...  # THPModule_initExtension
1107def _autograd_init() -> _bool: ...  # THPAutograd_initExtension
1108def _add_docstr(obj: T, doc_obj: str) -> T: ...  # THPModule_addDocStr
1109def _init_names(arg: Sequence[Type]) -> None: ...  # THPModule_initNames
1110def _has_distributed() -> _bool: ...  # THPModule_hasDistributed
1111def _set_default_tensor_type(type) -> None: ...  # THPModule_setDefaultTensorType
1112def _set_default_dtype(d: _dtype) -> None: ...  # THPModule_setDefaultDtype
1113def _infer_size(arg1: Size, arg2: Size) -> Size: ...  # THPModule_inferSize
1114def _crash_if_csrc_asan() -> _int: ...  # THPModule_crashIfCsrcASAN
1115def _crash_if_csrc_ubsan() -> _int: ...  # THPModule_crashIfCsrcUBSAN
1116def _crash_if_aten_asan() -> _int: ...  # THPModule_crashIfATenASAN
1117def _show_config() -> str: ...  # THPModule_showConfig
1118def _cxx_flags() -> str: ...  # THPModule_cxxFlags
1119def _parallel_info() -> str: ...  # THPModule_parallelInfo
1120def _get_cpu_capability() -> str: ...  # THPModule_getCpuCapability
1121def _set_backcompat_broadcast_warn(
1122    arg: _bool,
1123) -> None: ...  # THPModule_setBackcompatBroadcastWarn
1124def _get_backcompat_broadcast_warn() -> _bool: ...  # THPModule_getBackcompatBroadcastWarn
1125def _set_backcompat_keepdim_warn(
1126    arg: _bool,
1127) -> None: ...  # THPModule_setBackcompatKeepdimWarn
1128def _get_backcompat_keepdim_warn() -> _bool: ...  # THPModule_getBackcompatKeepdimWarn
1129def get_num_thread() -> _int: ...  # THPModule_getNumThreads
1130def set_num_threads(nthreads: _int) -> None: ...  # THPModule_setNumThreads
1131def get_num_interop_threads() -> _int: ...  # THPModule_getNumInteropThreads
1132def set_num_interop_threads(
1133    nthreads: _int,
1134) -> None: ...  # THPModule_setNumInteropThreads
1135def _get_cudnn_enabled() -> _bool: ...  # THPModule_userEnabledCuDNN
1136def _set_cudnn_enabled(arg: _bool) -> None: ...  # THPModule_setUserEnabledCuDNN
1137def _get_flash_sdp_enabled() -> _bool: ...  # THPModule_userEnabledFusedSDP
1138def _set_sdp_use_flash(arg: _bool) -> None: ...  # THPModule_setSDPUseFlash
1139def _get_mem_efficient_sdp_enabled() -> _bool: ...  # THPModule_userEnabledMathSDP
1140def _set_sdp_use_mem_efficient(
1141    arg: _bool,
1142) -> None: ...  # THPModule_setSDPUseMemEfficient
1143def _get_math_sdp_enabled() -> _bool: ...  # THPModule_userEnabledMathSDP
1144def _set_sdp_use_math(arg: _bool) -> None: ...  # THPModule_setSDPUseMath
1145def _get_cudnn_sdp_enabled() -> _bool: ...  # THPModule_userEnabledMathSDP
1146def _set_sdp_use_cudnn(arg: _bool) -> None: ...  # THPModule_setSDPUseMath
1147def _get_mkldnn_enabled() -> _bool: ...  # THPModule_userEnabledMkldnn
1148def _set_mkldnn_enabled(arg: _bool) -> None: ...  # THPModule_setUserEnabledMkldnn
1149def _get_cudnn_benchmark() -> _bool: ...  # THPModule_benchmarkCuDNN
1150def _set_cudnn_benchmark(arg: _bool) -> None: ...  # THPModule_setBenchmarkCuDNN
1151def _get_cudnn_deterministic() -> _bool: ...  # THPModule_deterministicCuDNN
1152def _set_cudnn_deterministic(arg: _bool) -> None: ...  # THPModule_setDeterministicCuDNN
1153def _get_deterministic_algorithms() -> _bool: ...  # THPModule_deterministicAlgorithms
1154def _get_deterministic_algorithms_warn_only() -> _bool: ...  # THPModule_deterministicAlgorithmsWarnOnly
1155def _set_deterministic_algorithms(
1156    mode: _bool,
1157    *,
1158    warn_only: _bool = ...,
1159) -> None: ...  # THPModule_setDeterministicAlgorithms
1160def _get_deterministic_fill_uninitialized_memory() -> _bool: ...  # THPModule_deterministicFillUninitializedMemory
1161def _set_deterministic_fill_uninitialized_memory(arg: _bool) -> None: ...  # THPModule_setDeterministicFillUninitializedMemory
1162def _get_nnpack_enabled() -> _bool: ...  # THPModule_userEnabledNNPACK
1163def _set_nnpack_enabled(arg: _bool) -> None: ...  # THPModule_setUserEnabledNNPACK
1164def _get_warnAlways() -> _bool: ...  # THPModule_warnAlways
1165def _set_warnAlways(arg: _bool) -> None: ...  # THPModule_setWarnAlways
1166def _get_cudnn_allow_tf32() -> _bool: ...  # THPModule_allowTF32CuDNN
1167def _set_cudnn_allow_tf32(arg: _bool) -> None: ...  # THPModule_setAllowTF32CuDNN
1168def _get_cublas_allow_tf32() -> _bool: ...  # THPModule_allowTF32CuBLAS
1169def _set_cublas_allow_tf32(arg: _bool) -> None: ...  # THPModule_setAllowTF32CuBLAS
1170def _get_float32_matmul_precision() -> str: ...  # THPModule_float32MatmulPrecision
1171def _set_float32_matmul_precision(
1172    arg: str,
1173) -> None: ...  # THPModule_setFloat32MatmulPrecision
1174def _get_cublas_allow_fp16_reduced_precision_reduction() -> _bool: ...  # THPModule_allowFP16ReductionCuBLAS
1175def _set_cublas_allow_fp16_reduced_precision_reduction(
1176    arg: _bool,
1177) -> None: ...  # THPModule_setAllowFP16ReductionCuBLAS
1178def _get_cublas_allow_bf16_reduced_precision_reduction() -> _bool: ...  # THPModule_allowBF16ReductionCuBLAS
1179def _set_cublas_allow_bf16_reduced_precision_reduction(
1180    arg: _bool,
1181) -> None: ...  # THPModule_setAllowBF16ReductionCuBLAS
1182def _set_conj(x: Tensor, conj: _bool) -> None: ...
1183def _set_neg(x: Tensor, neg: _bool) -> None: ...
1184def _set_meta_in_tls_dispatch_include(meta_in_tls: _bool) -> None: ...
1185def _meta_in_tls_dispatch_include() -> _bool: ...
1186def _stash_obj_in_tls(key: str, arg: Any) -> None: ...
1187def _get_obj_in_tls(key: str) -> Any: ...
1188def _is_key_in_tls(key: str) -> _bool: ...
1189def _select_batch_norm_backend(*args, **kwargs) -> BatchNormBackend: ...
1190def _select_conv_backend(*args, **kwargs) -> ConvBackend: ...
1191def _conv_determine_backend_memory_format(
1192    input: Tensor,
1193    weight: Tensor,
1194    backend: ConvBackend,
1195) -> memory_format: ...
1196def _has_storage(x: Tensor) -> _bool: ...
1197def _construct_storage_from_data_pointer(data_ptr: _int, device: torch.device, size: _int) -> Storage: ...
1198def _should_allow_numbers_as_tensors(func_name: str) -> _bool: ...
1199def _group_tensors_by_device_and_dtype(nested_tensorlists: List[List[Optional[Tensor]]], with_indices: _bool = False) -> Dict[Tuple[torch.device, torch.dtype], Tuple[List[List[Optional[Tensor]]], List[_int]]]: ...
1200
1201# NB: There is no Capsule type in typing, see
1202# https://code.activestate.com/lists/python-dev/139675/
1203def _to_dlpack(data: Tensor) -> Any: ...  # THPModule_toDLPack
1204def _from_dlpack(data: Any) -> Tensor: ...  # THPModule_fromDLPack
1205def _get_cpp_backtrace(
1206    frames_to_skip: _int,
1207    maximum_number_of_frames: _int,
1208) -> str: ...  # THPModule_getCppBacktrace
1209def set_flush_denormal(arg: _bool) -> _bool: ...  # THPModule_setFlushDenormal
1210def get_default_dtype() -> _dtype: ...  # THPModule_getDefaultDtype
1211def _get_default_device() -> str: ...  # THPModule_getDefaultDevice
1212def _get_qengine() -> _int: ...  # THPModule_qEngine
1213def _set_qengine(qengine: _int) -> None: ...  # THPModule_setQEngine
1214def _supported_qengines() -> List[_int]: ...  # THPModule_supportedQEngines
1215def _is_xnnpack_enabled() -> _bool: ...  # THPModule_isEnabledXNNPACK
1216def _check_sparse_tensor_invariants() -> _bool: ...  # THPModule_checkSparseTensorInvariants
1217def _set_check_sparse_tensor_invariants(
1218    arg: _bool,
1219) -> None: ...  # THPModule_setCheckSparseTensorInvariants
1220def _set_default_mobile_cpu_allocator() -> None: ...  # THPModule_setDefaultMobileCPUAllocator
1221def _unset_default_mobile_cpu_allocator() -> None: ...  # THPModule_unsetDefaultMobileCPUAllocator
1222def _is_torch_function_enabled() -> _bool: ...  # THPModule_isEnabledTorchFunction
1223def _has_torch_function(
1224    args: Iterable[Any],
1225) -> _bool: ...  # THPModule_has_torch_function
1226def _has_torch_function_unary(Any) -> _bool: ...  # THPModule_has_torch_function_unary
1227def _has_torch_function_variadic(
1228    *args: Any,
1229) -> _bool: ...  # THPModule_has_torch_function_variadic
1230def _vmapmode_increment_nesting() -> _int: ...  # THPModule_vmapmode_increment_nesting
1231def _vmapmode_decrement_nesting() -> _int: ...  # THPModule_vmapmode_decrement_nesting
1232def _log_api_usage_once(str) -> None: ...  # LogAPIUsageOnceFromPython
1233def _log_api_usage_metadata(event: str, metadata_map: Dict[str, str]) -> None: ...  # LogAPIUsageMetadataFromPython
1234def _demangle(str) -> str: ...  # c10::demangle
1235def _disabled_torch_function_impl(
1236    func: Callable,
1237    types: Iterable[Type],
1238    args: Tuple,
1239    kwargs: Dict,
1240) -> Any: ...  # THPModule_disable_torch_function
1241def _disabled_torch_dispatch_impl(
1242    func: Callable,
1243    types: Iterable[Type],
1244    args: Tuple,
1245    kwargs: Dict,
1246) -> Any: ...  # THPModule_disable_dispatch_function
1247def _get_linalg_preferred_backend() -> torch._C._LinalgBackend: ...
1248def _set_linalg_preferred_backend(arg: torch._C._LinalgBackend): ...
1249
1250class _LinalgBackend:
1251    Default: _LinalgBackend
1252    Cusolver: _LinalgBackend
1253    Magma: _LinalgBackend
1254
1255class BatchNormBackend(Enum): ...
1256
1257def _get_blas_preferred_backend() -> torch._C._BlasBackend: ...
1258def _set_blas_preferred_backend(arg: torch._C._BlasBackend): ...
1259
1260class _BlasBackend:
1261    Cublas: _BlasBackend
1262    Cublaslt: _BlasBackend
1263
1264class ConvBackend(Enum): ...
1265
1266class Tag(Enum):
1267    ${tag_attributes}
1268
1269# Defined in `valgrind.h` and `callgrind.h` respectively.
1270def _valgrind_supported_platform() -> _bool: ...  # NVALGRIND
1271def _valgrind_toggle() -> None: ...  # CALLGRIND_TOGGLE_COLLECT
1272def _valgrind_toggle_and_dump_stats() -> None: ...  # CALLGRIND_TOGGLE_COLLECT and CALLGRIND_DUMP_STATS
1273
1274has_openmp: _bool
1275has_mkl: _bool
1276_has_mps: _bool
1277has_lapack: _bool
1278_has_cuda: _bool
1279_has_magma: _bool
1280_has_xpu: _bool
1281_has_mkldnn: _bool
1282_has_cudnn: _bool
1283has_spectral: _bool
1284_GLIBCXX_USE_CXX11_ABI: _bool
1285default_generator: Generator
1286
1287# Defined in torch/csrc/autograd/init.cpp
1288def _set_grad_enabled(enabled: _bool) -> None: ...
1289def is_grad_enabled() -> _bool: ...
1290def _set_fwd_grad_enabled(enabled: _bool) -> None: ...
1291def _is_fwd_grad_enabled() -> _bool: ...
1292def is_inference_mode_enabled() -> _bool: ...
1293@overload
1294def set_autocast_enabled(device_type: str, enabled: _bool) -> None: ...
1295@overload
1296def set_autocast_enabled(enabled: _bool) -> None: ...
1297@overload
1298def is_autocast_enabled(device_type: str) -> _bool: ...
1299@overload
1300def is_autocast_enabled() -> _bool: ...
1301def set_autocast_dtype(device_type: str, dtype: _dtype) -> None: ...
1302def get_autocast_dtype(device_type: str) -> _dtype: ...
1303def clear_autocast_cache() -> None: ...
1304def set_autocast_cpu_enabled(enabled: _bool) -> None: ...
1305def is_autocast_cpu_enabled() -> _bool: ...
1306def _is_any_autocast_enabled() -> _bool: ...
1307def _is_autocast_available(device_type: str) -> _bool: ...
1308def set_autocast_cpu_dtype(dtype: _dtype) -> None: ...
1309def set_autocast_gpu_dtype(dtype: _dtype) -> None: ...
1310def get_autocast_cpu_dtype() -> _dtype: ...
1311def get_autocast_gpu_dtype() -> _dtype: ...
1312def autocast_increment_nesting() -> _int: ...
1313def autocast_decrement_nesting() -> _int: ...
1314def is_autocast_cache_enabled() -> _bool: ...
1315def set_autocast_cache_enabled(enabled: _bool) -> None: ...
1316def _increment_version(tensor: Tensor) -> None: ...
1317def set_anomaly_enabled(enabled: _bool, check_nan: _bool = True) -> None: ...
1318def is_anomaly_enabled() -> _bool: ...
1319def is_anomaly_check_nan_enabled() -> _bool: ...
1320def _is_multithreading_enabled() -> _bool: ...
1321def _set_multithreading_enabled(enabled: _bool) -> None: ...
1322def _set_view_replay_enabled(enabled: _bool) -> None: ...
1323def _is_view_replay_enabled() -> _bool: ...
1324def _enter_dual_level() -> _int: ...
1325def _exit_dual_level(level: _int) -> None: ...
1326def _make_dual(tensor: Tensor, tangent: Tensor, level: _int) -> Tensor: ...
1327def _unpack_dual(tensor: Tensor, level: _int) -> Tensor: ...
1328def __set_forward_AD_enabled(enabled: _bool) -> None: ...
1329def __is_forward_AD_enabled() -> _bool: ...
1330def _register_default_hooks(pack_hook: Callable, unpack_hook: Callable) -> None: ...
1331def _reset_default_hooks() -> None: ...
1332def _is_torch_function_mode_enabled() -> _bool: ...
1333def _set_torch_function_mode(cls: Any) -> None: ...
1334def _push_on_torch_function_stack(cls: Any) -> None: ...
1335def _pop_torch_function_stack() -> Any: ...
1336def _get_function_stack_at(idx: _int) -> Any: ...
1337def _len_torch_function_stack() -> _int: ...
1338def _set_torch_dispatch_mode(cls: Any) -> None: ...
1339def _push_on_torch_dispatch_stack(cls: TorchDispatchMode) -> None: ...
1340def _pop_torch_dispatch_stack(mode_key: Optional[torch._C._TorchDispatchModeKey] = None) -> Any: ...
1341def _get_dispatch_mode(mode_key: Optional[torch._C._TorchDispatchModeKey]) -> Any: ...
1342def _unset_dispatch_mode(mode: torch._C._TorchDispatchModeKey) -> Optional[TorchDispatchMode]: ...
1343def _set_dispatch_mode(mode: TorchDispatchMode) -> None: ...
1344def _get_dispatch_stack_at(idx: _int) -> Any: ...
1345def _len_torch_dispatch_stack() -> _int: ...
1346def _activate_gpu_trace() -> None: ...
1347
1348class _DisableTorchDispatch:
1349    def __init__(self): ...
1350    def __enter__(self): ...
1351    def __exit__(self, exc_type, exc_value, traceback): ...
1352
1353class _EnableTorchFunction:
1354    def __init__(self): ...
1355    def __enter__(self): ...
1356    def __exit__(self, exc_type, exc_value, traceback): ...
1357
1358class _EnablePythonDispatcher:
1359    def __init__(self): ...
1360    def __enter__(self): ...
1361    def __exit__(self, exc_type, exc_value, traceback): ...
1362
1363class _DisablePythonDispatcher:
1364    def __init__(self): ...
1365    def __enter__(self): ...
1366    def __exit__(self, exc_type, exc_value, traceback): ...
1367
1368class _EnablePreDispatch:
1369    def __init__(self): ...
1370    def __enter__(self): ...
1371    def __exit__(self, exc_type, exc_value, traceback): ...
1372
1373class _DisableFuncTorch:
1374    def __init__(self): ...
1375    def __enter__(self): ...
1376    def __exit__(self, exc_type, exc_value, traceback): ...
1377
1378class _DisableAutocast:
1379    def __init__(self): ...
1380    def __enter__(self): ...
1381    def __exit__(self, exc_type, exc_value, traceback): ...
1382
1383class _InferenceMode:
1384    def __init__(self, enabled: _bool): ...
1385    def __enter__(self): ...
1386    def __exit__(self, exc_type, exc_value, traceback): ...
1387
1388def _set_autograd_fallback_mode(mode: str) -> None: ...
1389def _get_autograd_fallback_mode() -> str: ...
1390
1391# Defined in torch/csrc/jit/python/script_init.cpp
1392class LoggerBase: ...
1393class NoopLogger(LoggerBase): ...
1394class LockingLogger(LoggerBase): ...
1395
1396class AggregationType(Enum):
1397    SUM = 0
1398    AVG = 1
1399
1400class FileCheck:
1401    def run(self, test_string: str) -> None: ...
1402    def check(self, test_string: str) -> FileCheck: ...
1403    def check_not(self, test_string: str) -> FileCheck: ...
1404    def check_same(self, test_string: str) -> FileCheck: ...
1405    def check_next(self, test_string: str) -> FileCheck: ...
1406    def check_count(
1407        self,
1408        test_string: str,
1409        count: _int,
1410        exactly: _bool = False,
1411    ) -> FileCheck: ...
1412    def check_dag(self, test_string: str) -> FileCheck: ...
1413    def check_source_highlighted(self, test_string: str) -> FileCheck: ...
1414    def check_regex(self, test_string: str) -> FileCheck: ...
1415
1416# Defined in torch/csrc/jit/python/init.cpp
1417class PyTorchFileReader:
1418    @overload
1419    def __init__(self, name: str) -> None: ...
1420    @overload
1421    def __init__(self, buffer: BinaryIO) -> None: ...
1422    def get_record(self, name: str) -> bytes: ...
1423    def serialization_id(self) -> str: ...
1424
1425class PyTorchFileWriter:
1426    @overload
1427    def __init__(self, name: str) -> None: ...
1428    @overload
1429    def __init__(self, buffer: BinaryIO) -> None: ...
1430    def write_record(self, name: str, data: Union[Storage, bytes, _int], size: _int) -> None: ...
1431    def write_end_of_file(self) -> None: ...
1432    def set_min_version(self, version: _int) -> None: ...
1433    def get_all_written_records(self) -> List[str]: ...
1434    def archive_name(self) -> str: ...
1435    def serialization_id(self) -> str: ...
1436
1437def _jit_get_inline_everything_mode() -> _bool: ...
1438def _jit_set_inline_everything_mode(enabled: _bool) -> None: ...
1439def _jit_get_logging_option() -> str: ...
1440def _jit_set_logging_option(option: str) -> None: ...
1441def _jit_set_logging_stream(stream_name: str) -> None: ...
1442def _jit_pass_cse(Graph) -> _bool: ...
1443def _jit_pass_dce(Graph) -> None: ...
1444def _jit_pass_lint(Graph) -> None: ...
1445
1446# Defined in torch/csrc/jit/python/python_custom_class.cpp
1447def _get_custom_class_python_wrapper(name: str, attr: str) -> Any: ...
1448
1449# Defined in torch/csrc/Module.cpp
1450def _rename_privateuse1_backend(backend: str) -> None: ...
1451def _get_privateuse1_backend_name() -> str: ...
1452
1453# Defined in torch/csrc/Generator.cpp
1454class Generator:
1455    device: _device
1456    def __init__(self, device: Optional[DeviceLikeType] = None) -> None: ...
1457    def __reduce__(self) -> Tuple[Type[Generator], Tuple[_device], Tuple[_int, Optional[_int], Tensor]]: ...
1458    def __setstate__(self, state: Tuple[_int, Optional[_int], Tensor]) -> None: ...
1459    def get_state(self) -> Tensor: ...
1460    def set_state(self, _new_state: Tensor) -> Generator: ...
1461    def clone_state(self) -> Generator: ...
1462    def graphsafe_get_state(self) -> Generator: ...
1463    def graphsafe_set_state(self, _new_state: Generator) -> Generator: ...
1464    def set_offset(self, offset: _int) -> Generator: ...
1465    def get_offset(self) -> _int: ...
1466    def manual_seed(self, seed: _int) -> Generator: ...
1467    def seed(self) -> _int: ...
1468    def initial_seed(self) -> _int: ...
1469
1470# Defined in torch/csrc/utils/python_dispatch.cpp
1471
1472class _DispatchOperatorHandle:
1473    def schema(self) -> FunctionSchema: ...
1474    def debug(self) -> str: ...
1475
1476class _DispatchModule:
1477    def def_(self, schema: str, alias: str = "") -> _DispatchModule: ...
1478    def def_legacy(self, schema: str) -> _DispatchModule: ...
1479    def def_name_t_t(
1480        self,
1481        name: str,
1482        dispatch: str,
1483        debug: str = "default_def_name_t_t",
1484    ) -> _DispatchModule: ...
1485    def def_schema_t_t(
1486        self,
1487        schema: str,
1488        dispatch: str,
1489        alias: str,
1490        debug: str = "default_def_schema_t_t",
1491    ) -> _DispatchModule: ...
1492    def impl_t_t(
1493        self,
1494        name: str,
1495        dispatch: str,
1496        debug: str = "impl_t_t",
1497    ) -> _DispatchModule: ...
1498    def impl(self, name: str, dispatch: str, func: Callable) -> _DispatchModule: ...
1499    def define(self, schema: str, alias: str = "") -> _DispatchModule: ...
1500    def fallback_fallthrough(self, dispatch: str = "") -> _DispatchModule: ...
1501
1502_after_ADInplaceOrView_keyset: DispatchKeySet
1503_after_autograd_keyset: DispatchKeySet
1504
1505def _dispatch_library(
1506    kind: str,
1507    name: str,
1508    dispatch: str,
1509    file: str = "",
1510    linenum: Any = 0,
1511) -> _DispatchModule: ...
1512def _dispatch_dump(name: str) -> str: ...
1513def _dispatch_dump_table(name: str) -> str: ...
1514def _dispatch_check_invariants(name: str) -> None: ...
1515def _dispatch_check_all_invariants() -> None: ...
1516def _dispatch_call_boxed(handle: _DispatchOperatorHandle, *args, **kwargs) -> Any: ...
1517def _dispatch_find_schema_or_throw(name: str, overload_name: str) -> _DispatchOperatorHandle: ...
1518def _dispatch_set_report_error_callback(handle: _DispatchOperatorHandle, callback: Callable) -> None: ...
1519def _dispatch_has_kernel(name: str) -> _bool: ...
1520def _dispatch_has_kernel_for_dispatch_key(
1521    name: str,
1522    dispatch: _dispatchkey,
1523) -> _bool: ...
1524def _dispatch_has_kernel_for_any_dispatch_key(
1525    name: str,
1526    dispatch_key_set: DispatchKeySet,
1527) -> _bool: ...
1528def _dispatch_kernel_for_dispatch_key_is_fallthrough(
1529    name: str,
1530    dispatch: _dispatchkey,
1531) -> _bool: ...
1532def _dispatch_has_computed_kernel_for_dispatch_key(
1533    name: str,
1534    dispatch: _dispatchkey,
1535) -> _bool: ...
1536def _dispatch_find_dangling_impls() -> List[str]: ...
1537def _dispatch_get_all_op_names() -> List[str]: ...
1538def _dispatch_tls_set_dispatch_key_excluded(
1539    dispatch: _dispatchkey,
1540    val: _bool,
1541) -> None: ...
1542def _dispatch_tls_is_dispatch_key_excluded(dispatch: _dispatchkey) -> _bool: ...
1543def _dispatch_tls_set_dispatch_key_included(
1544    dispatch: _dispatchkey,
1545    val: _bool,
1546) -> None: ...
1547def _dispatch_tls_is_dispatch_key_included(dispatch: _dispatchkey) -> _bool: ...
1548def _dispatch_isTensorSubclassLike(tensor: Tensor) -> _bool: ...
1549def _dispatch_key_name(dispatch: _dispatchkey) -> str: ...
1550def _dispatch_key_for_device(device_type: str) -> str: ...
1551def _parse_dispatch_key(key: str) -> Optional[DispatchKey]: ...
1552def _dispatch_key_parse(dispatch: _dispatchkey) -> DispatchKey: ...
1553def _dispatch_num_backends() -> _int: ...
1554def _dispatch_pystub(name: str, overload: str) -> Optional[Tuple[str, str]]: ...
1555def _dispatch_is_alias_key(dispatch: _dispatchkey) -> _bool: ...
1556def _functionality_to_backend_keys(dispatch: _dispatchkey) -> List[DispatchKey]: ...
1557def _functionalization_reapply_views_tls() -> _bool: ...
1558def _only_lift_cpu_tensors() -> _bool: ...
1559def _set_only_lift_cpu_tensors(value: _bool) -> None: ...
1560def _set_throw_on_mutable_data_ptr(tensor: Tensor) -> None: ...
1561def _set_warn_deprecated_on_mutable_data_ptr(tensor: Tensor) -> None: ...
1562
# Enum of dispatch keys. Its members are not written by hand: the template
# placeholder in the class body is replaced with generated member hints when
# this .pyi.in template is rendered.
class DispatchKey(Enum):
    ${dispatch_key_hints}
1565
class DispatchKeySet:
    # Built from a single DispatchKey.
    def __init__(self, key: DispatchKey) -> None: ...

    # Set algebra between two key sets; each operator returns a DispatchKeySet.
    def __or__(self, other: DispatchKeySet) -> DispatchKeySet: ...
    def __sub__(self, other: DispatchKeySet) -> DispatchKeySet: ...
    def __and__(self, other: DispatchKeySet) -> DispatchKeySet: ...

    # Per-key queries and updates.
    def highestPriorityTypeId(self) -> DispatchKey: ...
    def has(self, k: _dispatchkey) -> _bool: ...
    def add(self, k: _dispatchkey) -> DispatchKeySet: ...
    def remove(self, k: _dispatchkey) -> DispatchKeySet: ...

    def __repr__(self) -> str: ...
1576
# Module-level key sets, declared here and provided by the C extension.
_dispatch_autogradother_backends: DispatchKeySet
_additional_keys_to_prop_for_wrapper_tensors: DispatchKeySet

def _dispatch_has_backend_fallback(dispatch: _dispatchkey) -> _bool: ...

# Helpers that build, render or query DispatchKeySet values.
def _dispatch_keyset_full_after(t: _dispatchkey) -> DispatchKeySet: ...
def _dispatch_keyset_full() -> DispatchKeySet: ...
def _dispatch_keyset_to_string(keyset: DispatchKeySet) -> str: ...
def _dispatch_get_backend_keyset_from_autograd(dispatch: _dispatchkey) -> DispatchKeySet: ...
def _dispatch_keys(tensor: Tensor) -> DispatchKeySet: ...
def _dispatch_tls_local_exclude_set() -> DispatchKeySet: ...
def _dispatch_tls_local_include_set() -> DispatchKeySet: ...
def _dispatch_is_included_in_alias(dispatch_a: _dispatchkey, dispatch_b: _dispatchkey) -> _bool: ...

# Tensor-pair helpers (all return None).
def _propagate_xla_data(a: Tensor, b: Tensor) -> None: ...
def _replace_(a: Tensor, b: Tensor) -> None: ...
def _commit_update(a: Tensor) -> None: ...
1597
# Context-manager guards (each defines __enter__/__exit__).
# `__init__` is annotated `-> None` as PEP 484 recommends, so type checkers
# treat these constructors as typed. This matters most for the no-argument
# ones, which would otherwise be fully unannotated. `__enter__`/`__exit__` are
# left unannotated because their return types are not visible in this stub.
class _ExcludeDispatchKeyGuard:
    def __init__(self, keyset: DispatchKeySet) -> None: ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

class _IncludeDispatchKeyGuard:
    def __init__(self, k: DispatchKey) -> None: ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

class _ForceDispatchKeyGuard:
    def __init__(self, include: DispatchKeySet, exclude: DispatchKeySet) -> None: ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

class _PreserveDispatchKeyGuard:
    def __init__(self) -> None: ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

class _AutoDispatchBelowAutograd:
    def __init__(self) -> None: ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

class _AutoDispatchBelowADInplaceOrView:
    def __init__(self) -> None: ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...
1627
# Registration listings; an empty `dispatch_key` is the default argument.
def _dispatch_print_registrations_for_dispatch_key(dispatch_key: str = "") -> None: ...
def _dispatch_get_registrations_for_dispatch_key(dispatch_key: str = "") -> List[str]: ...
def _are_functorch_transforms_active() -> _bool: ...

# Defined in torch/csrc/autograd/init.cpp
def _set_python_dispatcher(dispatcher: object) -> None: ...

# Symbolic helpers returning a SymInt / an untyped symnode.
def _get_nested_int(id: _int, coeff: _int) -> SymInt: ...
def _get_constant_bool_symnode(val: _bool) -> Any: ...
1640
# Enum of torch-dispatch-mode keys. Like DispatchKey, its members come from a
# template placeholder that is replaced with generated member hints when this
# .pyi.in template is rendered.
class _TorchDispatchModeKey(Enum):
    ${torch_dispatch_mode_key_hints}
1643
# Context-manager guard taking a key and an on/off flag. `__init__` is
# annotated `-> None` (PEP 484) so type checkers treat the constructor as
# typed; `__enter__`/`__exit__` return types are not visible here.
class _SetExcludeDispatchKeyGuard:
    def __init__(self, k: DispatchKey, enabled: _bool) -> None: ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...
1648
# Defined in torch/csrc/utils/init.cpp

# Argument type of ThroughputBenchmark.benchmark().
class BenchmarkConfig:
    num_calling_threads: _int
    num_worker_threads: _int
    num_warmup_iters: _int
    num_iters: _int
    profiler_output_path: str

# Return type of ThroughputBenchmark.benchmark().
class BenchmarkExecutionStats:
    latency_avg_ms: _float
    num_iters: _int

class ThroughputBenchmark:
    # `module` is untyped in this stub.
    def __init__(self, module: Any) -> None: ...
    # Inputs and single runs accept arbitrary positional/keyword arguments.
    def add_input(self, *args: Any, **kwargs: Any) -> None: ...
    def run_once(self, *args: Any, **kwargs: Any) -> Any: ...
    def benchmark(self, config: BenchmarkConfig) -> BenchmarkExecutionStats: ...
1666
1667# Defined in torch/csrc/Storage.cpp
1668${legacy_storage_base_hints}
1669
# TODO: document where the legacy classes below are defined
1671${legacy_class_hints}
1672
# Defined in torch/csrc/autograd/python_engine.cpp
class _ImperativeEngine:
    # Takes a zero-argument callback returning None.
    def queue_callback(self, callback: Callable[[], None]) -> None: ...
    # Arguments are untyped here; the result is a tuple of Tensors.
    def run_backward(self, *args: Any, **kwargs: Any) -> Tuple[Tensor, ...]: ...
    def is_checkpoint_valid(self) -> _bool: ...
1678
# Defined in torch/csrc/autograd/python_variable.cpp
# Metaclass of TensorBase; this stub declares no members for it.
class _TensorMeta(type): ...

# Defined in torch/csrc/autograd/python_variable.cpp
# Base class for tensors, declared here as a set of typed attributes. Its method
# stubs are not written by hand: the template placeholder at the end of the
# class body is replaced with generated method hints when this .pyi.in
# template is rendered.
class TensorBase(metaclass=_TensorMeta):
    requires_grad: _bool
    retains_grad: _bool
    shape: Size
    data: Tensor
    names: List[str]
    device: _device
    dtype: _dtype
    layout: _layout
    real: Tensor
    imag: Tensor
    T: Tensor
    H: Tensor
    mT: Tensor
    mH: Tensor
    ndim: _int
    output_nr: _int
    _version: _int
    _base: Optional[Tensor]
    _cdata: _int
    # Typed as torch.autograd.graph.Node (imported as _Node); `_grad_fn` below
    # is an untyped (Any) counterpart.
    grad_fn: Optional[_Node]
    _grad_fn: Any
    _grad: Optional[Tensor]
    grad: Optional[Tensor]
    # Optional mapping from int keys to hooks of type Tensor -> Optional[Tensor].
    _backward_hooks: Optional[Dict[_int, Callable[[Tensor], Optional[Tensor]]]]
    nbytes: _int
    itemsize: _int
    _has_symbolic_sizes_strides: _bool

    # Optional visitor callbacks: `symint_visitor_fn` maps int -> int and
    # `tensor_visitor_fn` maps Tensor -> Tensor; both default to None.
    # NOTE(review): no return annotation is given; presumably this returns a
    # Tensor view built on `new_base` -- confirm against python_variable.cpp
    # before annotating it.
    def _view_func_unsafe(
        self,
        new_base: Tensor,
        symint_visitor_fn: Optional[Callable[[_int], _int]] = None,
        tensor_visitor_fn: Optional[Callable[[Tensor], Tensor]] = None
    ):
        ...

    ${tensor_method_hints}

# Private alias bound to the same class object as TensorBase.
_TensorBase = TensorBase
1723
# Defined in torch/csrc/multiprocessing/init.cpp
def _multiprocessing_init() -> None: ...

# Defined in torch/csrc/Module.cpp
# Accelerator hooks; devices are addressed by integer index.
def _accelerator_hooks_device_count() -> _int: ...
def _accelerator_hooks_set_current_device(device_index: _int) -> None: ...
def _accelerator_hooks_get_current_device() -> _int: ...
def _accelerator_hooks_exchange_device(device_index: _int) -> _int: ...
def _accelerator_hooks_maybe_exchange_device(device_index: _int) -> _int: ...
# `check` defaults to False.
def _get_accelerator(check: _bool = False) -> _device: ...
1734
# Defined in torch/csrc/mtia/Module.cpp
def _mtia_init() -> None: ...
def _mtia_isBuilt() -> _bool: ...
def _mtia_isInBadFork() -> _bool: ...
def _mtia_deviceSynchronize() -> None: ...
# Stream accessors; getters take an integer device index.
def _mtia_getCurrentStream(device: _int) -> Stream: ...
def _mtia_setCurrentStream(stream: Stream) -> None: ...
def _mtia_getDefaultStream(device: _int) -> Stream: ...
1743
1744
# Defined in torch/csrc/mps/Module.cpp

# Synchronization, default generator and memory accounting.
def _mps_deviceSynchronize() -> None: ...
def _mps_get_default_generator() -> Generator: ...
def _mps_emptyCache() -> None: ...
def _mps_setMemoryFraction(fraction: _float) -> None: ...
def _mps_currentAllocatedMemory() -> _int: ...
def _mps_driverAllocatedMemory() -> _int: ...

# Availability and OS-version checks.
def _mps_is_available() -> _bool: ...
def _mps_is_on_macos_or_newer(major: _int, minor: _int) -> _bool: ...

# Profiler tracing.
def _mps_profilerStartTrace(mode: str, wait_until_completed: _bool) -> None: ...
def _mps_profilerStopTrace() -> None: ...

# Events are referred to by the integer id returned from _mps_acquireEvent.
def _mps_acquireEvent(enable_timing: _bool) -> _int: ...
def _mps_releaseEvent(event_id: _int) -> None: ...
def _mps_recordEvent(event_id: _int) -> None: ...
def _mps_waitForEvent(event_id: _int) -> None: ...
def _mps_synchronizeEvent(event_id: _int) -> None: ...
def _mps_queryEvent(event_id: _int) -> _bool: ...
def _mps_elapsedTimeOfEvents(start_event_id: _int, end_event_id: _int) -> _float: ...
1763
1764
# Defined in torch/csrc/cuda/Module.cpp

# Streams, BLAS handle, device selection and synchronization.
def _cuda_getCurrentStream(device: _int) -> Tuple: ...
def _cuda_getCurrentRawStream(device: _int) -> _int: ...
def _cuda_getDefaultStream(device: _int) -> Tuple: ...
def _cuda_getCurrentBlasHandle() -> _int: ...
def _cuda_clearCublasWorkspaces() -> None: ...
def _cuda_setDevice(device: _int) -> None: ...
def _cuda_exchangeDevice(device: _int) -> _int: ...
def _cuda_maybeExchangeDevice(device: _int) -> _int: ...
def _cuda_getDevice() -> _int: ...
def _cuda_getDeviceCount() -> _int: ...
def _cuda_set_sync_debug_mode(warn_level: Union[_int, str]) -> None: ...
def _cuda_get_sync_debug_mode() -> _int: ...
def _cuda_sleep(cycles: _int) -> None: ...
def _cuda_synchronize() -> None: ...
def _cuda_ipc_collect() -> None: ...
def _cuda_getArchFlags() -> Optional[str]: ...
def _cuda_init() -> None: ...
def _cuda_setStream(
    stream_id: _int,
    device_index: _int,
    device_type: _int,
) -> None: ...
def _cuda_getCompiledVersion() -> _int: ...

# Host / caching allocators and memory pools (pools are identified by an
# (int, int) tuple).
def _cuda_cudaHostAllocator() -> _int: ...
def _cuda_cudaCachingAllocator_raw_alloc(size: _int, cuda_stream: _int) -> _int: ...
def _cuda_cudaCachingAllocator_raw_delete(ptr: _int) -> None: ...
def _cuda_cudaCachingAllocator_set_allocator_settings(env: str) -> None: ...
def _cuda_beginAllocateCurrentStreamToPool(
    device: _int,
    mempool_id: Tuple[_int, _int],
) -> None: ...
def _cuda_endAllocateCurrentStreamToPool(
    device: _int,
    mempool_id: Tuple[_int, _int],
) -> None: ...
def _cuda_releasePool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ...
def _cuda_checkPoolLiveAllocations(
    device: _int,
    mempool_id: Tuple[_int, _int],
    expected_live_allocations: Set,
) -> _bool: ...
def _cuda_setCheckpointPoolState(
    device: _int,
    state: _cuda_CUDAAllocator_AllocatorState,
    stale_storages: List[_int],
    storages_to_add_deleters_to: List[_int],
) -> None: ...
def _cuda_setMemoryFraction(fraction: _float, device: _int) -> None: ...
def _cuda_emptyCache() -> None: ...

# Memory statistics and snapshots (returned as str-keyed dicts).
def _cuda_memoryStats(device: _int) -> Dict[str, Any]: ...
def _cuda_resetAccumulatedMemoryStats(device: _int) -> None: ...
def _cuda_resetPeakMemoryStats(device: _int) -> None: ...
def _cuda_memorySnapshot() -> Dict[str, Any]: ...
# Allocator history recording. The `_legacy` variant takes boolean switches;
# the newer variant takes optional string modes for `enabled`/`context` and a
# string for `stacks`.
def _cuda_record_memory_history_legacy(
    enabled: _bool,
    record_context: _bool,
    record_context_cpp: _bool,
    alloc_trace_max_entries: _int,
    alloc_trace_record_context: _bool,
) -> None: ...
def _cuda_record_memory_history(
    enabled: Optional[str],
    context: Optional[str],
    stacks: str,
    # Annotated as an int (an unannotated parameter is an implicit Any);
    # presumably the same limit as `alloc_trace_max_entries` above.
    max_entries: _int,
) -> None: ...
def _cuda_isHistoryEnabled() -> _bool: ...
1814
def _cuda_getAllocatorBackend() -> str: ...

# Allocator checkpoint state; this stub exposes no members.
class _cuda_CUDAAllocator_AllocatorState:
    pass

def _cuda_getCheckpointState(
    device: _int,
    mempool: Tuple[_int, _int],
) -> _cuda_CUDAAllocator_AllocatorState: ...

# Cached-tensor registration.
def _set_cached_tensors_enabled(enabled: _bool) -> None: ...
def _add_cached_tensor(t: Tensor) -> None: ...
def _remove_cached_tensor(t: Tensor) -> None: ...
def _tensors_data_ptrs_at_indices_equal(
    tensors: List[Tensor],
    ptrs: List[Optional[_int]],
    indices: List[_int],
) -> _bool: ...

# Storage helpers; storages are passed as integer pointers where applicable.
def _construct_CUDA_Tensor_From_Storage_And_Metadata(
    metadata: dict,
    storage: Storage,
) -> Tensor: ...
def _storage_Use_Count(storage_ptr: _int) -> _int: ...
def _set_storage_access_error_msg(t: Tensor, s: str) -> None: ...
def _free_And_Remove_DeleterFn(storage_ptr: _int) -> None: ...
def _has_Standard_Deleter(storage_ptr: _int) -> _bool: ...
1828
# Allocator handle; this stub exposes no members.
class _cuda_CUDAAllocator: ...

# Custom allocator installation (alloc/free functions passed as ints).
def _cuda_customAllocator(alloc_fn: _int, free_fn: _int) -> _cuda_CUDAAllocator: ...
def _cuda_changeCurrentAllocator(allocator: _cuda_CUDAAllocator) -> None: ...
def _cuda_getAllocator() -> _cuda_CUDAAllocator: ...
def _cuda_lock_mutex() -> None: ...
def _cuda_unlock_mutex() -> None: ...
def _cuda_canDeviceAccessPeer(device: _int, peer_device: _int) -> _bool: ...

# Jiterator entry point; `kwargs` values are ints, floats or bools.
def _cuda_jiterator_compile_and_launch_kernel(
    code_string: str,
    kernel_name: str,
    return_by_ref: _bool,
    num_outputs: _int,
    tensors: Tuple,
    kwargs: Dict[str, Union[_int, _float, _bool]],
) -> Tensor: ...

# cuDNN / convolution benchmarking knobs.
def _cuda_get_cudnn_benchmark_limit() -> _int: ...
def _cuda_set_cudnn_benchmark_limit(arg: _int) -> None: ...
def _cuda_get_conv_benchmark_empty_cache() -> _bool: ...
def _cudnn_set_conv_benchmark_empty_cache(enable: _bool) -> None: ...
# NCCL version / communicator setup.
def _nccl_version() -> _int: ...
def _nccl_version_suffix() -> bytes: ...
def _nccl_unique_id() -> bytes: ...
def _nccl_init_rank(nranks: _int, comm_id: bytes, rank: _int) -> object: ...

# Collectives. Each takes Optional `streams` and `comms` sequences as its last
# two arguments.
def _nccl_reduce(
    input: Sequence[Tensor],
    output: Tensor,
    root: _int,
    op: _int,
    streams: Optional[Sequence[_CudaStreamBase]],
    comms: Optional[Sequence[object]],
) -> None: ...

def _nccl_all_reduce(
    input: Sequence[Tensor],
    output: Sequence[Tensor],
    op: _int,
    streams: Optional[Sequence[_CudaStreamBase]],
    comms: Optional[Sequence[object]],
) -> None: ...

def _nccl_broadcast(
    input: Sequence[Tensor],
    root: _int,
    streams: Optional[Sequence[_CudaStreamBase]],
    comms: Optional[Sequence[object]],
) -> None: ...

def _nccl_all_gather(
    input: Sequence[Tensor],
    output: Sequence[Tensor],
    streams: Optional[Sequence[_CudaStreamBase]],
    comms: Optional[Sequence[object]],
) -> None: ...

def _nccl_reduce_scatter(
    input: Sequence[Tensor],
    output: Sequence[Tensor],
    op: _int,
    streams: Optional[Sequence[_CudaStreamBase]],
    comms: Optional[Sequence[object]],
) -> None: ...
def _rocm_is_backward_pass() -> _bool: ...

# TunableOp: enable flags.
def _cuda_tunableop_enable(val: _bool) -> None: ...
def _cuda_tunableop_is_enabled() -> _bool: ...
def _cuda_tunableop_tuning_enable(val: _bool) -> None: ...
def _cuda_tunableop_tuning_is_enabled() -> _bool: ...

# TunableOp: tuning budgets (duration and iteration count).
def _cuda_tunableop_set_max_tuning_duration(duration: _int) -> None: ...
def _cuda_tunableop_get_max_tuning_duration() -> _int: ...
def _cuda_tunableop_set_max_tuning_iterations(iterations: _int) -> None: ...
def _cuda_tunableop_get_max_tuning_iterations() -> _int: ...

# TunableOp: results file and results/validator queries.
def _cuda_tunableop_set_filename(
    filename: str,
    insert_device_ordinal: Optional[_bool],
) -> None: ...
def _cuda_tunableop_get_filename() -> str: ...
def _cuda_tunableop_write_file(filename: Optional[str]) -> _bool: ...
def _cuda_tunableop_read_file(filename: Optional[str]) -> _bool: ...
def _cuda_tunableop_write_file_on_exit(val: _bool) -> None: ...
def _cuda_tunableop_get_results() -> Tuple[str, str, str, _float]: ...
def _cuda_tunableop_get_validators() -> Tuple[str, str]: ...
1903
class _CudaDeviceProperties:
    # Device name.
    name: str
    # Compute capability (major.minor).
    major: _int
    minor: _int
    # Integer sizes / counts / flags.
    multi_processor_count: _int
    total_memory: _int
    is_integrated: _int
    is_multi_gpu_board: _int
    max_threads_per_multi_processor: _int
    # Architecture name string.
    gcnArchName: str
1914
# SDPA-related bindings.
# Parameter bundle consumed by the `_can_use_*` checks declared below.
class _SDPAParams:
    query: Tensor
    key: Tensor
    value: Tensor
    attn_mask: Optional[Tensor]
    dropout: _float
    is_causal: _bool

    def __init__(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        attn_mask: Optional[Tensor],
        dropout: _float,
        is_causal: _bool,
    ) -> None: ...
1931
# SDPA backend identifiers; ERROR is -1 and the backends are numbered 0..3.
class _SDPBackend(Enum):
    ERROR = -1
    MATH = 0
    FLASH_ATTENTION = 1
    EFFICIENT_ATTENTION = 2
    CUDNN_ATTENTION = 3

# Backend-eligibility checks over an _SDPAParams bundle.
def _can_use_flash_attention(params: _SDPAParams, debug: _bool) -> _bool: ...
def _can_use_mem_efficient_attention(params: _SDPAParams, debug: _bool) -> _bool: ...
1941
# Defined in torch/csrc/cuda/python_comm.cpp

# Broadcast variants.
def _broadcast(tensor: Tensor, devices: List[_int]) -> List[Tensor]: ...
def _broadcast_out(tensor: Tensor, out_tensors: List[Tensor]) -> List[Tensor]: ...
def _broadcast_coalesced(
    tensors: List[Tensor],
    devices: List[_int],
    buffer_size: _int,
) -> List[List[Tensor]]: ...

# Scatter variants.
def _scatter(
    tensor: Tensor,
    devices: List[_int],
    chunk_sizes: Optional[List[_int]],
    dim: _int,
    streams: Optional[List[Stream]],
) -> List[Tensor]: ...
def _scatter_out(
    tensor: Tensor,
    out_tensors: List[Tensor],
    dim: _int,
    streams: Optional[List[Stream]],
) -> List[Tensor]: ...

# Gather variants.
def _gather(tensors: List[Tensor], dim: _int, destination_index: Optional[_int]) -> Tensor: ...
def _gather_out(tensors: List[Tensor], out_tensor: Tensor, dim: _int) -> Tensor: ...
1969
# Defined in torch/csrc/cuda/Stream.cpp
class _CudaStreamBase(Stream):
    stream_id: _int
    device_index: _int
    device_type: _int

    device: _device
    cuda_stream: _int
    priority: _int

    # `__new__` is implicitly static and receives the class being instantiated,
    # so its first parameter is named `cls` (as in _CudaEventBase.__new__).
    def __new__(
        cls,
        priority: _int = 0,
        stream_id: _int = 0,
        device_index: _int = 0,
        stream_ptr: _int = 0,
    ) -> _CudaStreamBase: ...
    def query(self) -> _bool: ...
    def synchronize(self) -> None: ...
    # Declared static, like _XpuStreamBase.priority_range, so it can be called
    # on the class itself; calling it through an instance still type-checks.
    @staticmethod
    def priority_range() -> Tuple[_int, _int]: ...
1990
# Defined in torch/csrc/cuda/Event.cpp
class _CudaEventBase:
    device: _device
    cuda_event: _int

    # Construction: directly (all flags default to False) or from an IPC handle.
    def __new__(
        cls,
        enable_timing: _bool = False,
        blocking: _bool = False,
        interprocess: _bool = False,
    ) -> _CudaEventBase: ...
    @classmethod
    def from_ipc_handle(cls, device: _device, ipc_handle: bytes) -> _CudaEventBase: ...
    def ipc_handle(self) -> bytes: ...

    # Interaction with a CUDA stream.
    def record(self, stream: _CudaStreamBase) -> None: ...
    def wait(self, stream: _CudaStreamBase) -> None: ...

    # Completion and timing.
    def query(self) -> _bool: ...
    def elapsed_time(self, other: _CudaEventBase) -> _float: ...
    def synchronize(self) -> None: ...
2010
# Defined in torch/csrc/cuda/Graph.cpp
class _CUDAGraph:
    def capture_begin(
        self,
        pool: Optional[Tuple[_int, _int]] = ...,
        capture_error_mode: str = "global",
    ) -> None: ...
    def capture_end(self) -> None: ...
    # The parameter is named `generator` so it does not shadow the `Generator`
    # type, and it carries that type as its annotation instead of being untyped.
    def register_generator_state(self, generator: Generator) -> None: ...
    def replay(self) -> None: ...
    def reset(self) -> None: ...
    def pool(self) -> Tuple[_int, _int]: ...
    def enable_debug_mode(self) -> None: ...
    def debug_dump(self, debug_path: str) -> None: ...

def _cuda_isCurrentStreamCapturing() -> _bool: ...
def _graph_pool_handle() -> Tuple[_int, _int]: ...
2024
# Defined in torch/csrc/xpu/Module.cpp

# Device selection (integer device indices) and initialization.
def _xpu_setDevice(device: _int) -> None: ...
def _xpu_exchangeDevice(device: _int) -> _int: ...
def _xpu_maybeExchangeDevice(device: _int) -> _int: ...
def _xpu_getDevice() -> _int: ...
def _xpu_getDeviceCount() -> _int: ...
def _xpu_init() -> None: ...

# Streams and synchronization.
def _xpu_setStream(
    stream_id: _int,
    device_index: _int,
    device_type: _int,
) -> None: ...
def _xpu_getCurrentStream(device: _int) -> Tuple: ...
def _xpu_getCurrentRawStream(device: _int) -> _int: ...
def _xpu_synchronize(device: _int) -> None: ...

# Caching allocator.
def _xpu_emptyCache() -> None: ...
2037
class _XpuDeviceProperties:
    # Identification strings.
    name: str
    platform_name: str
    vendor: str
    driver_version: str
    version: str
    # Integer sizes and counts.
    total_memory: _int
    max_compute_units: _int
    gpu_eu_count: _int
    gpu_subslice_count: _int
    max_work_group_size: _int
    max_num_sub_groups: _int
    sub_group_sizes: List[_int]
    # Capability flags.
    has_fp16: _bool
    has_fp64: _bool
    has_atomic64: _bool
    # Device type string.
    type: str
2055
# Defined in torch/csrc/xpu/Stream.cpp
class _XpuStreamBase(Stream):
    stream_id: _int
    device_index: _int
    device_type: _int

    device: _device
    sycl_queue: _int
    priority: _int

    def __new__(
        cls,
        priority: _int = 0,
        stream_id: _int = 0,
        device_index: _int = 0,
        device_type: _int = 0,
    ) -> _XpuStreamBase: ...
    def query(self) -> _bool: ...
    def synchronize(self) -> None: ...
    # Typed as a pair of ints, the same as _CudaStreamBase.priority_range; a
    # bare `Tuple` would be Tuple[Any, ...] to type checkers.
    @staticmethod
    def priority_range() -> Tuple[_int, _int]: ...
2077
# Defined in torch/csrc/xpu/Event.cpp
class _XpuEventBase:
    device: _device
    sycl_event: _int

    def __new__(cls, enable_timing: _bool = False) -> _XpuEventBase: ...
    # `record` and `wait` both take a stream (the parameter is named `stream`),
    # mirroring _CudaEventBase.record(stream: _CudaStreamBase). The `record`
    # annotation therefore uses the stream type, not the event type.
    def record(self, stream: _XpuStreamBase) -> None: ...
    def wait(self, stream: _XpuStreamBase) -> None: ...
    def query(self) -> _bool: ...
    def elapsed_time(self, other: _XpuEventBase) -> _float: ...
    def synchronize(self) -> None: ...
2089
# Defined in torch/csrc/DataLoader.cpp
# Each stub is preceded by the name of the C function implementing it.

# THPModule_setWorkerSignalHandlers
def _set_worker_signal_handlers(*arg: Any) -> None: ...
# THPModule_setWorkerPIDs
def _set_worker_pids(key: _int, child_pids: Tuple[_int, ...]) -> None: ...
# THPModule_removeWorkerPIDs
def _remove_worker_pids(loader_id: _int) -> None: ...
# THPModule_errorIfAnyWorkerFails
def _error_if_any_worker_fails() -> None: ...
2100
# Defined in torch/csrc/jit/python/python_tracer.cpp
class TracingState:
    def push_scope(self, scope_name: str) -> None: ...
    def pop_scope(self) -> None: ...
    def current_scope(self) -> str: ...
    def set_graph(self, graph: Graph) -> None: ...
    def graph(self) -> Graph: ...

def _create_graph_by_tracing(
    func: Callable[..., Any],
    inputs: Any,
    var_name_lookup_fn: Callable[[Tensor], str],
    strict: Any,
    force_outplace: Any,
    self: Any = None,
    # The runtime default is an empty list. Stubs spell non-trivial defaults
    # as `...` rather than writing a mutable `[]` literal.
    argument_names: List[str] = ...,
) -> Tuple[Graph, Stack]: ...
def _tracer_warn_use_python() -> None: ...
# NOTE(review): may return None when no trace is active -- confirm before
# widening this to Optional[TracingState].
def _get_tracing_state() -> TracingState: ...
2120
# Defined in torch/csrc/jit/python/python_ir.cpp
# NOTE: the following are not actually defined in python_ir.cpp; where they are
# defined has not been tracked down.
class IValue: ...

# `Stack` is an alias for a list of IValues.
Stack = List[IValue]

# Base class of the JIT type stubs that follow.
class JitType:
    annotation_str: str
    def isSubtypeOf(self, other: JitType) -> _bool: ...
    def with_dtype(self, dtype: _dtype) -> JitType: ...
    def with_sizes(self, sizes: List[Optional[_int]]) -> JitType: ...
    def kind(self) -> str: ...
    def scalarType(self) -> Optional[str]: ...
    def getElementType(self) -> JitType: ...
    def dtype(self) -> Optional[_dtype]: ...

# Inference outcome, constructed from a JitType or a str (presumably a failure
# reason, given the success()/reason() accessors).
class InferredType:
    # Annotated `-> None` (PEP 484) so type checkers treat it as typed.
    def __init__(self, arg: Union[JitType, str]) -> None: ...
    def type(self) -> JitType: ...
    def success(self) -> _bool: ...
    def reason(self) -> str: ...

# Type variable bounded by JitType; OptionalType is generic over it.
R = TypeVar("R", bound=JitType)
2144
# Singleton JIT types. Each has a static `get()` annotated to return the class
# itself.

# Top / bottom types.
class AnyType(JitType):
    @staticmethod
    def get() -> AnyType: ...

class NoneType(JitType):
    @staticmethod
    def get() -> NoneType: ...

# Scalar types.
class BoolType(JitType):
    @staticmethod
    def get() -> BoolType: ...

class FloatType(JitType):
    @staticmethod
    def get() -> FloatType: ...

class ComplexType(JitType):
    @staticmethod
    def get() -> ComplexType: ...

class IntType(JitType):
    @staticmethod
    def get() -> IntType: ...

# Symbolic scalar types.
class SymIntType(JitType):
    @staticmethod
    def get() -> SymIntType: ...

class SymBoolType(JitType):
    @staticmethod
    def get() -> SymBoolType: ...

class NumberType(JitType):
    @staticmethod
    def get() -> NumberType: ...

# String and runtime-object types.
class StringType(JitType):
    @staticmethod
    def get() -> StringType: ...

class DeviceObjType(JitType):
    @staticmethod
    def get() -> DeviceObjType: ...

class _GeneratorType(JitType):
    @staticmethod
    def get() -> _GeneratorType: ...

class StreamObjType(JitType):
    @staticmethod
    def get() -> StreamObjType: ...
2196
class ListType(JitType):
    def __init__(self, a: JitType) -> None: ...
    def getElementType(self) -> JitType: ...
    # Static constructors for lists of common element types.
    @staticmethod
    def ofInts() -> ListType: ...
    @staticmethod
    def ofTensors() -> ListType: ...
    @staticmethod
    def ofFloats() -> ListType: ...
    @staticmethod
    def ofComplexDoubles() -> ListType: ...
    @staticmethod
    def ofBools() -> ListType: ...
    @staticmethod
    def ofStrings() -> ListType: ...

class DictType(JitType):
    def __init__(self, key: JitType, value: JitType) -> None: ...
    def getKeyType(self) -> JitType: ...
    def getValueType(self) -> JitType: ...

class TupleType(JitType):
    # Constructor entries may be None; elements() is typed without Optional.
    def __init__(self, a: List[Optional[JitType]]) -> None: ...
    def elements(self) -> List[JitType]: ...

class UnionType(JitType):
    def __init__(self, a: List[JitType]) -> None: ...

# Named types, constructed from a qualified name.
class ClassType(JitType):
    def __init__(self, qualified_name: str) -> None: ...

class InterfaceType(JitType):
    def __init__(self, qualified_name: str) -> None: ...
    def getMethod(self, name: str) -> Optional[FunctionSchema]: ...
    def getMethodNames(self) -> List[str]: ...
2232
class OptionalType(JitType, Generic[R]):
    def __init__(self, a: JitType) -> None: ...
    def getElementType(self) -> JitType: ...
    @staticmethod
    def ofTensor() -> OptionalType: ...

# Wrapper types built from a single element JitType.
class FutureType(JitType):
    def __init__(self, a: JitType) -> None: ...
    def getElementType(self) -> JitType: ...

class AwaitType(JitType):
    def __init__(self, a: JitType) -> None: ...
    def getElementType(self) -> JitType: ...

class RRefType(JitType):
    def __init__(self, a: JitType) -> None: ...

class EnumType(JitType):
    def __init__(
        self,
        qualified_name: str,
        value_type: JitType,
        enum_names_values: List[Any],
    ) -> None: ...

class TensorType(JitType):
    # Alternate constructors.
    @classmethod
    def get(cls) -> TensorType: ...
    @classmethod
    def getInferred(cls) -> TensorType: ...
    @staticmethod
    def create_from_tensor(t: Tensor) -> TensorType: ...

    # Shape / stride / device / dtype accessors.
    def with_sizes(self, other: Optional[List[Optional[_int]]]) -> TensorType: ...
    def sizes(self) -> Optional[List[_int]]: ...
    def varyingSizes(self) -> Optional[List[Optional[_int]]]: ...
    def strides(self) -> Optional[List[_int]]: ...
    def device(self) -> Optional[_device]: ...
    def dim(self) -> _int: ...
    def dtype(self) -> Optional[_dtype]: ...
2272
# Defined in torch/csrc/jit/python/python_tree_views.cpp
class SourceRange: ...

# Base class of the tree-view wrappers below.
class TreeView: ...

class Ident(TreeView):
    # A read-only property here, whereas Def.name below is a method.
    @property
    def name(self) -> str: ...

class ClassDef(TreeView): ...

class Def(TreeView):
    def name(self) -> Ident: ...

class Decl(TreeView): ...
2287
# Distributed initialization entry points; each returns a bool.

# Defined in torch/csrc/distributed/rpc/init.cpp
def _rpc_init() -> _bool: ...
# Defined in torch/csrc/distributed/autograd/init.cpp
def _dist_autograd_init() -> _bool: ...
# Defined in torch/csrc/distributed/c10d/init.cpp
def _c10d_init() -> _bool: ...
# Defined in torch/csrc/distributed/rpc/testing/init.cpp
def _faulty_agent_init() -> _bool: ...
def _register_py_class_for_device(device: str, cls: Any) -> None: ...
2300
# Defined in torch/csrc/Module.cpp
def _current_graph_task_id() -> _int: ...
def _current_autograd_node() -> _Node: ...
# The argument is named `tensor` (not `Tensor`) so it does not shadow the
# `Tensor` type, and it is annotated with that type instead of being untyped.
def _dispatch_key_set(tensor: Tensor) -> str: ...
2305
# Defined in torch/csrc/Exceptions.cpp
# As declared in this stub, every exception below derives from RuntimeError.
class OutOfMemoryError(RuntimeError): ...

# Distributed error types.
class _DistError(RuntimeError): ...
class _DistBackendError(RuntimeError): ...
class _DistStoreError(RuntimeError): ...
class _DistNetworkError(RuntimeError): ...
2312
# Defined in torch/csrc/profiler/init.cpp
# Captured-traceback handle; this stub exposes no members.
class CapturedTraceback:
    pass

# Flags choose which frame kinds (python / script / cpp) to capture.
def gather_traceback(python: _bool, script: _bool, cpp: _bool) -> CapturedTraceback: ...
# Returns a list of str-keyed dicts (presumably one per input traceback).
def symbolize_tracebacks(tracebacks: List[CapturedTraceback]) -> List[Dict[str, Any]]: ...
2318
# Loading of mobile (lite) and JIT modules from a file or from bytes.
def _load_mobile_module_from_file(filename: str): ...
def _load_mobile_module_from_bytes(bytes_: bytes): ...
def _load_jit_module_from_file(filename: str): ...
def _load_jit_module_from_bytes(bytes_: bytes): ...

# Saving to a file, or to bytes.
def _save_mobile_module(m: LiteScriptModule, filename: str): ...
def _save_jit_module(m: ScriptModule, filename: str, extra_files: Dict[str, Any]): ...
def _save_mobile_module_to_bytes(m: LiteScriptModule) -> bytes: ...
def _save_jit_module_to_bytes(m: ScriptModule, extra_files: Dict[str, Any]) -> bytes: ...
def _get_module_info_from_flatbuffer(data: bytes): ...

# Miscellaneous helpers.
def _jit_resolve_packet(op_name: str, *args, **kwargs) -> str: ...
def _swap_tensor_impl(t1: Tensor, t2: Tensor): ...
def _save_pickle(obj: Any) -> bytes: ...
2331
# Defined in torch/csrc/jit/runtime/static/init.cpp
# Both accept either a Graph or a ScriptModule; results are untyped.
def _jit_to_static_module(graph_or_module: Union[Graph, ScriptModule]) -> Any: ...
def _fuse_to_static_module(
    graph_or_module: Union[Graph, ScriptModule],
    min_size: _int,
) -> Any: ...
2335
# Defined in torch/csrc/fx/node.cpp
class _NodeBase:
    # Erased flag plus typed `_prev` / `_next` references to other nodes.
    _erased: _bool
    _prev: "_NodeBase"
    _next: "_NodeBase"

# Iterator yielding _NodeBase objects starting from `root`; `reversed`
# presumably selects the walk direction.
class _NodeIter(Iterator):
    def __init__(self, root: _NodeBase, reversed: _bool) -> None: ...
    def __iter__(self) -> Iterator[_NodeBase]: ...
    def __next__(self) -> _NodeBase: ...
2346