1# ${generated_comment} 2# mypy: disable-error-code="type-arg" 3# mypy: allow-untyped-defs 4 5import builtins 6from enum import Enum, IntEnum 7from pathlib import Path 8from typing import ( 9 Any, 10 AnyStr, 11 BinaryIO, 12 Callable, 13 ContextManager, 14 Dict, 15 Generic, 16 Iterable, 17 Iterator, 18 List, 19 Literal, 20 NamedTuple, 21 Optional, 22 Protocol, 23 Sequence, 24 Set, 25 SupportsIndex, 26 Tuple, 27 Type, 28 TypeVar, 29 Union, 30 overload, 31 runtime_checkable, 32) 33from typing_extensions import ParamSpec, Self 34 35import numpy 36 37import torch 38from torch import inf, SymInt, Tensor 39from torch.autograd.graph import Node as _Node 40from torch.package import PackageExporter 41from torch.storage import UntypedStorage, TypedStorage 42from torch.types import ( 43 _bool, 44 _complex, 45 _device, 46 _dispatchkey, 47 _dtype, 48 _float, 49 _int, 50 _layout, 51 _qscheme, 52 _size, 53 Device, 54 Number, 55 Storage, 56) 57 58from torch._prims_common import DeviceLikeType 59from torch.utils._python_dispatch import TorchDispatchMode 60 61# This module is defined in torch/csrc/Module.cpp 62 63from . import _functorch, _lazy, _lazy_ts_backend, _nn, _onnx, _VariableFunctions, _cpu, _aoti, _verbose 64 65K = TypeVar("K") 66T = TypeVar("T") 67S = TypeVar("S", bound="torch.Tensor") 68P = ParamSpec("P") 69ReturnVal = TypeVar("ReturnVal", covariant=True) # return value (always covariant) 70_T_co = TypeVar("_T_co", covariant=True) 71 72 73@runtime_checkable 74class _NestedSequence(Protocol[_T_co]): 75 """A protocol for representing nested sequences. 76 77 References:: 78 `numpy._typing._NestedSequence` 79 <https://github.com/numpy/numpy/blob/main/numpy/_typing/_nested_sequence.py> 80 """ 81 82 def __len__(self, /) -> builtins.int: ... 83 def __getitem__(self, index: builtins.int, /) -> _T_co | _NestedSequence[_T_co]: ... 84 def __contains__(self, x: builtins.object, /) -> builtins.bool: ... 85 def __iter__(self, /) -> Iterator[_T_co | _NestedSequence[_T_co]]: ... 
86 def __reversed__(self, /) -> Iterator[_T_co | _NestedSequence[_T_co]]: ... 87 def count(self, value: Any, /) -> builtins.int: ... 88 def index(self, value: Any, /) -> builtins.int: ... 89 90 91# Defined in torch/csrc/Device.cpp 92class device: 93 type: str # THPDevice_type 94 index: _int # THPDevice_index 95 96 def __get__(self, instance, owner=None) -> device: ... 97 98 # THPDevice_pynew 99 @overload 100 def __init__(self, device: DeviceLikeType) -> None: ... 101 @overload 102 def __init__(self, type: str, index: _int) -> None: ... 103 104 # Uncomment if we ever make torch.device a decorator 105 # def __call__(self, func: T) -> T: ... 106 107 def __enter__(self) -> device: ... 108 def __exit__(self, exc_type, exc_val, exc_tb) -> None: ... 109 def __reduce__(self) -> Tuple[Any, ...]: ... # THPDevice_reduce 110 111# Defined in torch/csrc/Stream.cpp 112class Stream: 113 stream_id: _int # Stream id 114 device_index: _int 115 device_type: _int 116 117 device: _device # The device of the stream 118 119 @overload 120 def __new__(self, device: Optional[DeviceLikeType] = None, *, priority: _int = 0) -> Stream: ... 121 @overload 122 def __new__(self, stream_id: _int, device_index: _int, device_type: _int, *, priority: _int = 0) -> Stream: ... 123 def query(self) -> _bool: ... 124 def synchronize(self) -> None: ... 125 def wait_event(self, event: Event) -> None: ... 126 def wait_stream(self, other: Stream) -> None: ... 127 def record_event(self, event: Optional[Event] = None) -> Event: ... 128 def __hash__(self) -> _int: ... 129 def __repr__(self) -> str: ... 130 def __eq__(self, other: object) -> _bool: ... 131 132 133# Defined in torch/csrc/Event.cpp 134class Event: 135 136 device: _device # The device of the Event 137 event_id: _int # The raw event created by device backend 138 139 def __new__(self, 140 device: Optional[DeviceLikeType] = None, 141 *, 142 enable_timing: _bool = False, 143 blocking: _bool = False, 144 interprocess: _bool = False) -> Event: ... 
145 @classmethod 146 def from_ipc_handle(self, device: _device, ipc_handle: bytes) -> Event: ... 147 def record(self, stream: Optional[Stream] = None) -> None: ... 148 def wait(self, stream: Optional[Stream] = None) -> None: ... 149 def query(self) -> _bool: ... 150 def elapsed_time(self, other: Event) -> _float: ... 151 def synchronize(self) -> None: ... 152 def ipc_handle(self) -> bytes: ... 153 def __repr__(self) -> str: ... 154 155 156# Defined in torch/csrc/Size.cpp 157class Size(Tuple[_int, ...]): 158 # TODO: __reduce__ 159 160 @overload # type: ignore[override] 161 def __getitem__(self: Size, key: _int) -> _int: ... 162 @overload 163 def __getitem__(self: Size, key: slice) -> Size: ... 164 def numel(self: Size) -> _int: ... 165 166# Defined in torch/csrc/Dtype.cpp 167class dtype: 168 # TODO: __reduce__ 169 is_floating_point: _bool 170 is_complex: _bool 171 is_signed: _bool 172 itemsize: _int 173 def to_real(self) -> dtype: ... 174 def to_complex(self) -> dtype: ... 175 176# Defined in torch/csrc/TypeInfo.cpp 177class iinfo: 178 bits: _int 179 min: _int 180 max: _int 181 dtype: str 182 183 def __init__(self, dtype: _dtype) -> None: ... 184 185class finfo: 186 bits: _int 187 min: _float 188 max: _float 189 eps: _float 190 tiny: _float 191 smallest_normal: _float 192 resolution: _float 193 dtype: str 194 195 @overload 196 def __init__(self, dtype: _dtype) -> None: ... 197 @overload 198 def __init__(self) -> None: ... 199 200${dtype_class_hints} 201 202# Defined in torch/csrc/Layout.cpp 203class layout: ... 204 205# Defined in torch/csrc/utils/disable_torch_function.cpp 206def DisableTorchFunction(): ... 207def DisableTorchFunctionSubclass(): ... 208 209# Defined in torch/csrc/utils/tensor_layouts.cpp 210strided: layout = ... 211sparse_coo: layout = ... 212sparse_csr: layout = ... 213sparse_csc: layout = ... 214sparse_bsr: layout = ... 215sparse_bsc: layout = ... 216_mkldnn: layout = ... 217jagged: layout = ... 
# Defined in torch/csrc/MemoryFormat.cpp
class memory_format: ...

# Defined in torch/csrc/utils/tensor_memoryformats.cpp
contiguous_format: memory_format = ...
channels_last: memory_format = ...
channels_last_3d: memory_format = ...
preserve_format: memory_format = ...

# Defined in torch/csrc/QScheme.cpp
class qscheme: ...

# Defined in torch/csrc/utils/tensor_qschemes.h
per_tensor_affine: qscheme = ...
per_channel_affine: qscheme = ...
per_tensor_symmetric: qscheme = ...
per_channel_symmetric: qscheme = ...
per_channel_affine_float_qparams: qscheme = ...

# Defined in torch/csrc/autograd/python_function.cpp
class _FunctionBase:
    # These are variable-length tuples. ``Tuple[X]`` means "exactly one
    # element", which makes e.g. ``a, b = ctx.saved_tensors`` a false type
    # error; ``Tuple[X, ...]`` (as already used by ``next_functions``) is
    # the correct spelling.
    saved_tensors: Tuple[Tensor, ...]
    _raw_saved_tensors: Tuple[Any, ...]
    next_functions: Tuple[Tuple[Any, _int], ...]
    needs_input_grad: Tuple[_bool, ...]
    metadata: dict
    _materialize_non_diff_grads: _bool
    # skip adding type hints for the fields that have wrappers defined
    # in torch/autograd/function.py

# Defined in torch/csrc/autograd/python_legacy_variable.cpp
class _LegacyVariableBase(Tensor):  # inherits from Tensor to appease mypy
    def __init__(
        self,
        data: Optional[Tensor] = ...,
        requires_grad: Optional[_bool] = ...,
        volatile: Optional[_bool] = ...,
        _grad_fn: Optional[_FunctionBase] = ...,
    ) -> None: ...

# Defined in torch/csrc/jit/python/init.cpp
class IODescriptor: ...
class JITException: ...

class Future(Generic[T]):
    def __init__(self, devices: List[device]) -> None: ...
    def done(self) -> _bool: ...
    def value(self) -> T: ...
    def wait(self) -> T: ...
    def add_done_callback(self, callback: Callable) -> None: ...
    def then(self, callback: Callable) -> Future[T]: ...
    def set_result(self, result: T) -> None: ...
    def _set_unwrap_func(self, callback: Callable) -> None: ...

class _Await:
    def __init__(self) -> None: ...
    def fn(self) -> Callable: ...
    def args(self) -> Tuple[Any, ...]: ...
    def is_nowait(self) -> _bool: ...

# Takes and returns a plain integer run count. ``_size`` is the tensor-shape
# alias (Size / sequence of ints), not C's ``size_t``, so it was the wrong type.
def _jit_set_num_profiled_runs(num: _int) -> _int: ...

# Defined in torch/csrc/jit/passes/mobile_optimizer_type.h
class _MobileOptimizerType: ...

CONV_BN_FUSION: _MobileOptimizerType
INSERT_FOLD_PREPACK_OPS: _MobileOptimizerType
REMOVE_DROPOUT: _MobileOptimizerType
FUSE_ADD_RELU: _MobileOptimizerType
HOIST_CONV_PACKED_PARAMS: _MobileOptimizerType
VULKAN_AUTOMATIC_GPU_TRANSFER: _MobileOptimizerType

def fork(*args: Any, **kwargs: Any) -> Future: ...
def wait(fut: Future) -> Any: ...
def _awaitable(*args: Any, **kwargs: Any) -> _Await: ...
def _awaitable_wait(aw: _Await) -> Any: ...
def _awaitable_nowait(x: Any) -> _Await: ...
def _collect_all(futures: List[Future]) -> Future: ...
def _set_print_stack_traces_on_fatal_signal(print: _bool) -> None: ...
def unify_type_list(types: List[JitType]) -> JitType: ...
def _freeze_module(
    module: ScriptModule,
    preserved_attrs: List[str] = [],
    freeze_interfaces: _bool = True,
    preserveParameters: _bool = True,
) -> ScriptModule: ...
# The first parameter was previously an unannotated name ``Graph`` that shadowed
# the ``Graph`` class; it is now a properly typed ``graph: Graph``.
def _jit_pass_optimize_frozen_graph(graph: Graph, optimize_numerics: _bool = True) -> None: ...
def _jit_pass_optimize_for_inference(
    module: torch.jit.ScriptModule,
    other_methods: List[str] = [],
) -> None: ...
def _jit_pass_fold_frozen_conv_bn(graph: Graph): ...
def _jit_pass_fold_frozen_conv_add_or_sub(graph: Graph): ...
def _jit_pass_fold_frozen_conv_mul_or_div(graph: Graph): ...
def _jit_pass_fuse_frozen_conv_add_relu(graph: Graph): ...
def _jit_pass_concat_frozen_linear(graph: Graph): ...
def _jit_pass_convert_frozen_ops_to_mkldnn(graph: Graph): ...
def _jit_pass_transpose_frozen_linear(graph: Graph): ...
def _jit_pass_remove_dropout(module: torch.jit.ScriptModule): ...
def _is_tracing() -> _bool: ...
def _jit_init() -> _bool: ...
def _jit_flatten(arg: Any) -> Tuple[List[Tensor], IODescriptor]: ...
def _jit_unflatten(vars: List[Tensor], desc: IODescriptor) -> Any: ...
def _jit_get_operation(op_name: str) -> Tuple[Callable, List[str]]: ...
def _get_operation_overload(
    op_name: str,
    op_overload_name: str,
) -> Tuple[Callable, Callable, List[Any]]: ...
def _get_schema(op_name: str, overload_name: str) -> FunctionSchema: ...
def _jit_pass_optimize_for_mobile(
    module: torch.jit.ScriptModule,
    optimization_blocklist: Set[_MobileOptimizerType],
    preserved_methods: List[AnyStr],
) -> torch.jit.ScriptModule: ...
def _clone_module_with_class(
    module: torch.jit.ScriptModule,
    ignored_methods: List[AnyStr],
    ignored_attributes: List[AnyStr],
) -> torch.jit.ScriptModule: ...
def _jit_pass_vulkan_optimize_for_mobile(
    module: torch.jit.ScriptModule,
    optimization_blocklist: Set[_MobileOptimizerType],
    preserved_methods: List[AnyStr],
) -> torch.jit.ScriptModule: ...
def _jit_pass_metal_optimize_for_mobile(
    module: torch.jit.ScriptModule,
    preserved_methods: List[AnyStr],
) -> torch.jit.ScriptModule: ...
# The graph parameter of the following passes was previously declared as a bare,
# unannotated name ``Graph`` (shadowing the class). It is now typed like every
# other graph pass in this stub, e.g. ``_jit_pass_lower_all_tuples(graph: Graph)``.
def _jit_pass_inline(graph: Graph) -> None: ...
def _jit_pass_constant_propagation(graph: Graph) -> None: ...
def _jit_pass_propagate_shapes_on_graph(graph: Graph) -> None: ...
def _jit_register_decomposition_for_schema(schema: FunctionSchema, graph: Graph) -> None: ...
def _jit_erase_non_input_shape_information(graph: Graph) -> None: ...
def _jit_get_schemas_for_operator(name: str) -> List[FunctionSchema]: ...
def _jit_get_all_schemas() -> List[FunctionSchema]: ...
def _jit_check_alias_annotation(
    g: Graph,
    args: Tuple[Any, ...],
    unqualified_op_name: str,
): ...
def _jit_can_fuse_on_cpu() -> _bool: ...
def _jit_can_fuse_on_gpu() -> _bool: ...
def _jit_can_fuse_on_cpu_legacy() -> _bool: ...
def _debug_get_fusion_group_inlining() -> _bool: ...
def _debug_set_fusion_group_inlining(enable: _bool): ...
# Boolean getters/setters whose names mention fusers, LLGA, LLVM or
# symbolic-shape test mode. Several setters declare no return type.
def _jit_texpr_fuser_enabled() -> _bool: ...
def _jit_nvfuser_enabled() -> _bool: ...
def _jit_llga_enabled() -> _bool: ...
def _jit_set_llga_enabled(enable: _bool): ...
def _llvm_enabled() -> _bool: ...
def _jit_override_can_fuse_on_cpu(override: _bool): ...
def _jit_override_can_fuse_on_gpu(override: _bool): ...
def _jit_override_can_fuse_on_cpu_legacy(override: _bool): ...
def _jit_set_symbolic_shapes_test_mode(override: _bool): ...
def _jit_symbolic_shapes_test_mode_enabled() -> _bool: ...
def _jit_set_texpr_fuser_enabled(enable: _bool): ...
def _jit_set_te_must_use_llvm_cpu(use_llvm: _bool): ...
def _jit_set_nvfuser_enabled(enable: _bool) -> _bool: ...
def _jit_cat_wo_conditionals(optimize_cat: _bool): ...
def _jit_opt_conditionals(opt_conds: _bool): ...

# Passes taking a Graph or a ScriptModule (return types left undeclared).
def _jit_pass_canonicalize(graph: Graph, keep_unique_names: _bool = True): ...
def _jit_pass_erase_shape_information(graph: Graph): ...
def _jit_pass_fold_convbn(module: torch.jit.ScriptModule): ...

# Quantization passes: each takes the ScriptModule first and an integer quant_type.
def _jit_pass_insert_observers(
    module: torch.jit.ScriptModule,
    method_name: str,
    qconfig_dict: Dict[str, Any],
    inplace: _bool,
    quant_type: _int,
): ...
def _jit_pass_insert_quant_dequant(
    module: torch.jit.ScriptModule,
    method_name: str,
    inplace: _bool,
    debug: _bool,
    quant_type: _int,
): ...
def _jit_pass_insert_quant_dequant_for_ondevice_ptq(
    module: torch.jit.ScriptModule,
    method_name: str,
    inplace: _bool,
    debug: _bool,
    quant_type: _int,
): ...
def _jit_pass_quant_finalize(
    module: torch.jit.ScriptModule, quant_type: _int, preserved_attrs: Sequence[str]
): ...
def _jit_pass_quant_finalize_for_ondevice_ptq(
    module: torch.jit.ScriptModule, quant_type: _int, method_name: str
): ...
def _jit_pass_insert_observer_method_for_ondevice_ptq(
    module: torch.jit.ScriptModule,
    method_name: str,
    qconfig_dict: Dict[str, Any],
    inplace: _bool,
    quant_type: _int,
): ...
def _jit_set_profiling_executor(profiling_flag: _bool) -> _bool: ...
def _jit_set_profiling_mode(profiling_flag: _bool) -> _bool: ...
def _jit_set_fusion_strategy(strategy: List[Tuple[str, _int]]) -> List[Tuple[str, _int]]: ...
def _jit_try_infer_type(obj: Any) -> InferredType: ...
def _jit_get_trigger_value(trigger_name: str) -> _int: ...

# Defined in torch/csrc/jit/python/script_init.cpp
# Type alias: a callable mapping a name (str) to some callable.
ResolutionCallback = Callable[[str], Callable[..., Any]]

# Defined in torch/csrc/jit/python/script_init.cpp
# and torch/csrc/jit/python/init.cpp
def _maybe_call_torch_function_for_op_packet(
    op_overload_packet: Any, args: Any, kwargs: Any
) -> Any: ...
def _check_schema_allow_fake_script_object(
    schema: FunctionSchema, args: Any, kwargs: Any
) -> _bool: ...
def _create_function_from_graph(qualname: str, graph: Graph) -> ScriptFunction: ...
def _debug_set_autodiff_subgraph_inlining(disabled: _bool) -> None: ...
def _ivalue_tags_match(lhs: ScriptModule, rhs: ScriptModule) -> _bool: ...
def _jit_assert_is_instance(obj: Any, type: JitType): ...
def _jit_clear_class_registry() -> None: ...
def _jit_set_emit_hooks(
    ModuleHook: Optional[Callable], FunctionHook: Optional[Callable]
) -> None: ...
def _jit_get_emit_hooks() -> Tuple[Callable, Callable]: ...

# Lite-interpreter loaders: a path or a binary buffer, plus an optional device.
def _load_for_lite_interpreter(
    filename: Union[str, Path], map_location: Optional[DeviceLikeType]
): ...
def _load_for_lite_interpreter_from_buffer(
    buffer: BinaryIO, map_location: Optional[DeviceLikeType]
): ...
def _export_operator_list(module: LiteScriptModule): ...
def _quantize_ondevice_ptq_dynamic(module: LiteScriptModule, method_name: str): ...
# Model-file inspection / backport helpers. Each has a path-based variant and a
# variant reading from a BinaryIO buffer.
def _get_model_bytecode_version(filename: Union[str, Path]) -> _int: ...
def _get_model_bytecode_version_from_buffer(buffer: BinaryIO) -> _int: ...
def _backport_for_mobile(
    filename_input: Union[str, Path], filename_output: Union[str, Path], to_version: _int
) -> None: ...
def _backport_for_mobile_from_buffer(
    buffer: BinaryIO, filename_output: Union[str, Path], to_version: _int
) -> None: ...
def _backport_for_mobile_to_buffer(
    filename_input: Union[str, Path], to_version: _int
) -> bytes: ...
def _backport_for_mobile_from_buffer_to_buffer(buffer: BinaryIO, to_version: _int) -> bytes: ...
def _get_model_ops_and_info(filename: Union[str, Path]): ...
def _get_model_ops_and_info_from_buffer(buffer: BinaryIO): ...
def _get_mobile_model_contained_types(filename: Union[str, Path]): ...
def _get_mobile_model_contained_types_from_buffer(buffer: BinaryIO): ...

def _logging_set_logger(logger: LoggerBase) -> LoggerBase: ...
def _get_graph_executor_optimize(optimize: Optional[_bool] = None) -> _bool: ...
def _set_graph_executor_optimize(optimize: _bool): ...
def _export_opnames(module: ScriptModule) -> List[str]: ...

# Tracing entry points; they differ only in how inputs are passed
# (positional tuple vs. name -> value dict).
def _create_function_from_trace(
    qualname: str,
    func: Callable[..., Any],
    input_tuple: Tuple[Any, ...],
    var_lookup_fn: Callable[[Tensor], str],
    strict: _bool,
    force_outplace: _bool,
    argument_names: List[str],
) -> Tuple[Graph, Stack]: ...
def _create_function_from_trace_with_dict(
    qualname: str,
    func: Callable[..., Any],
    input_dict: Dict[str, Any],
    var_lookup_fn: Callable[[Tensor], str],
    strict: _bool,
    force_outplace: _bool,
    argument_names: List[str],
) -> Tuple[Graph, Stack]: ...
def _jit_is_script_object(obj: Any) -> _bool: ...
def _last_executed_optimized_graph() -> Graph: ...
def parse_type_comment(comment: str) -> Decl: ...
def _get_upgraders_map_size() -> _int: ...
# Upgrader-map accessors (str -> str dictionaries).
def _get_upgraders_entry_map() -> Dict[str, str]: ...
def _dump_upgraders_map() -> Dict[str, str]: ...
def _test_only_populate_upgraders(content: Dict[str, str]) -> None: ...
def _test_only_remove_upgraders(content: Dict[str, str]) -> None: ...

def merge_type_from_type_comment(
    decl: Decl, type_annotation_decl: Decl, is_method: _bool
) -> Decl: ...
def parse_ir(input: str, parse_tensor_constants: _bool = False) -> Graph: ...
def parse_schema(schema: str) -> FunctionSchema: ...
def get_device(input: Tensor) -> _int: ...
def _resolve_type_from_object(
    obj: Any, range: SourceRange, rcb: ResolutionCallback
) -> JitType: ...
def _create_module_with_type(ty: JitType) -> ScriptModule: ...
def _create_object_with_type(ty: ClassType) -> ScriptObject: ...
def _run_emit_module_hook(m: ScriptModule): ...
def _replace_overloaded_method_decl(
    overload_decl: Decl, implementation_def: Def, new_name: str
) -> Def: ...

# Graph passes (all return None).
def _jit_pass_lower_all_tuples(graph: Graph) -> None: ...
def _jit_pass_onnx_set_dynamic_input_shape(
    graph: Graph, dynamic_axes: Dict[str, Dict[_int, str]], input_names: List[str]
) -> None: ...
def _jit_pass_onnx_graph_shape_type_inference(
    graph: Graph, params_dict: Dict[str, IValue], opset_version: _int
) -> None: ...
def _jit_pass_onnx_assign_output_shape(
    graph: Graph,
    tensors: List[Tensor],
    desc: IODescriptor,
    onnx_shape_inference: _bool,
    is_script: _bool,
    opset_version: _int,
) -> None: ...
def _jit_pass_onnx_remove_inplace_ops_for_onnx(
    graph: Graph, module: Optional[ScriptModule] = None
) -> None: ...
def _jit_pass_remove_inplace_ops(graph: Graph) -> None: ...
def _jit_pass_canonicalize_graph_fuser_ops(graph: Graph) -> None: ...
def _jit_pass_peephole(graph: Graph, disable_shape_peepholes: _bool = False) -> None: ...
def _jit_pass_onnx_autograd_function_process(graph: Graph) -> None: ...
def _jit_pass_fuse_addmm(graph: Graph) -> None: ...
def _jit_pass_onnx_preprocess(graph: Graph) -> None: ...
def _jit_pass_prepare_division_for_onnx(graph: Graph) -> None: ...
def _jit_pass_onnx_remove_print(graph: Graph) -> None: ...
def _jit_pass_onnx_preprocess_caffe2(graph: Graph) -> None: ...
def _jit_pass_onnx_unpack_quantized_weights(
    graph: Graph,
    paramsDict: Dict[str, IValue],
    caffe2: _bool,
) -> Dict[str, IValue]: ...
def _jit_pass_onnx_quantization_insert_permutes(
    graph: Graph,
    paramsDict: Dict[str, IValue],
) -> Dict[str, IValue]: ...
def _jit_pass_custom_pattern_based_rewrite_graph(
    pattern: str,
    fused_node_name: str,
    graph: Graph,
) -> None: ...
def _jit_onnx_list_model_parameters(
    module: ScriptModule,
) -> Tuple[ScriptModule, List[IValue]]: ...
def _jit_pass_erase_number_types(graph: Graph) -> None: ...
def _jit_pass_onnx_lint(graph: Graph) -> None: ...
# The export-type parameter was previously named ``_jit_pass_onnx`` (the same
# name as the function itself). It is renamed to match ``_jit_pass_onnx_block``'s
# ``operator_export_type`` parameter of the same type.
def _jit_pass_onnx(
    graph: Graph,
    operator_export_type: _onnx.OperatorExportTypes,
) -> Graph: ...
def _jit_pass_onnx_scalar_type_analysis(
    graph: Graph,
    lowprecision_cast: _bool,
    opset_version: _int,
) -> None: ...
def _jit_pass_onnx_peephole(
    graph: Graph,
    opset_version: _int,
    fixed_batch_size: _bool,
) -> None: ...
def _jit_pass_dce_allow_deleting_nodes_with_side_effects(graph: Graph) -> None: ...
def _jit_pass_onnx_function_substitution(graph: Graph) -> None: ...
def _jit_pass_onnx_function_extraction(
    graph: Graph,
    module_names: Set[str],
    param_names: List[str],
) -> Dict[Node, Dict[str, str]]: ...
def _jit_pass_onnx_clear_scope_records() -> None: ...
def _jit_pass_onnx_track_scope_attributes(
    graph: Graph,
    onnx_attrs: Dict[str, Any],
) -> None: ...
def _jit_is_onnx_log_enabled() -> _bool: ...
def _jit_set_onnx_log_enabled(enabled: _bool) -> None: ...
def _jit_set_onnx_log_output_stream(stream_name: str) -> None: ...
def _jit_onnx_log(*args: Any) -> None: ...
def _jit_pass_lower_graph(graph: Graph, m: Module) -> Tuple[Graph, List[IValue]]: ...
def _jit_pass_inline_fork_wait(graph: Graph) -> None: ...

# ONNX passes that take a Graph plus a name -> IValue parameter dict and return
# a (possibly updated) dict of the same type.
def _jit_pass_onnx_deduplicate_initializers(
    graph: Graph, params_dict: Dict[str, IValue], is_train: _bool
) -> Dict[str, IValue]: ...
def _jit_pass_onnx_eval_peephole(
    graph: Graph, paramsDict: Dict[str, IValue]
) -> Dict[str, IValue]: ...
def _jit_pass_onnx_constant_fold(
    graph: Graph, paramsDict: Dict[str, IValue], opset_version: _int
) -> Dict[str, IValue]: ...
def _jit_pass_onnx_eliminate_unused_items(
    graph: Graph, paramsDict: Dict[str, IValue]
) -> Dict[str, IValue]: ...

def _jit_pass_onnx_cast_all_constant_to_floating(graph: Graph) -> None: ...
def _jit_pass_filter_non_tensor_arguments(params: Dict[str, IValue]) -> Dict[str, Tensor]: ...
def _jit_decay_packed_param_input_types(graph: Graph) -> None: ...
def _jit_pass_onnx_node_shape_type_inference(
    n: Node, paramsDict: Dict[str, IValue], opset_version: _int
) -> None: ...

# Block-level conversion helpers operating on Value -> Value environments.
def _jit_onnx_convert_pattern_from_subblock(
    block: Block, n: Node, env: Dict[Value, Value], values_in_env: Set[Value]
) -> List[Value]: ...
def _jit_pass_onnx_block(
    old_block: Block,
    new_block: Block,
    operator_export_type: _onnx.OperatorExportTypes,
    env: Dict[Value, Value],
    values_in_env: Set[Value],
    is_sub_block: _bool,
) -> Dict[Value, Value]: ...
def _jit_pass_onnx_assign_scoped_names_for_node_and_value(graph: Graph) -> None: ...
def _jit_pass_fixup_onnx_controlflow_node(n: Node, opset_version: _int) -> List[Value]: ...
def _jit_onnx_create_full_scope_name(class_name: str, variable_name: str) -> str: ...
def _compile_graph_to_code_table(name: str, graph: Graph) -> IValue: ...
def _generate_upgraders_graph() -> Dict[str, Graph]: ...
def _calculate_package_version_based_on_upgraders(val: _bool): ...
def _get_version_calculator_flag() -> _bool: ...

# Script compilation entry points; each receives a ResolutionCallback ``rcb``.
def _jit_script_interface_compile(
    name: str, class_def: ClassDef, rcb: ResolutionCallback, is_module: _bool
): ...
def _jit_script_compile_overload(
    qualname: str,
    overload_decl: Decl,
    implementation_def: Def,
    rcb: ResolutionCallback,
    implementation_defaults: Dict[str, Any],
    signature: Any,
): ...
def _jit_script_compile(
    qual_name: str, definition: Def, rcb: ResolutionCallback, defaults: Dict[str, Any]
): ...
def _jit_script_class_compile(
    qual_name: str,
    definition: ClassDef,
    defaults: Dict[str, Dict[str, Any]],
    rcb: ResolutionCallback,
): ...
def _parse_source_def(src: str) -> Def: ...

# ScriptModule importers: from a path, a binary buffer, or a package reader.
def import_ir_module(
    cu: CompilationUnit,
    filename: Union[str, Path],
    map_location: Optional[DeviceLikeType],
    extra_files: Dict[str, Any],
) -> ScriptModule: ...
def import_ir_module_from_buffer(
    cu: CompilationUnit,
    buffer: BinaryIO,
    map_location: Optional[DeviceLikeType],
    extra_files: Dict[str, Any],
) -> ScriptModule: ...
def _import_ir_module_from_package(
    cu: CompilationUnit,
    reader: PyTorchFileReader,
    storage_context: DeserializationStorageContext,
    map_location: Optional[DeviceLikeType],
    ts_id: str,
) -> ScriptModule: ...

def _assign_output_shapes(graph: Graph, inputs: List[Tensor]) -> Graph: ...
def _check_onnx_proto(proto: str) -> None: ...
def _propagate_and_assign_input_shapes(
    graph: Graph,
    inputs: Tuple[Tensor, ...],
    param_count_list: List[_int],
    with_grad: _bool,
    propagate: _bool,
) -> Graph: ...

# Defined in torch/csrc/jit/runtime/graph_executor.h
class GraphExecutorState: ...
# Defined in torch/csrc/jit/ir/alias_analysis.h
class AliasDb:
    def __str__(self) -> str: ...

# Context-manager object returned by Graph.insert_point_guard().
class _InsertPoint:
    def __enter__(self) -> None: ...
    def __exit__(self, *args) -> None: ...

# Defined in torch/csrc/jit/ir/ir.h
class Use:
    @property
    def user(self) -> Node: ...
    @property
    def offset(self) -> _int: ...
    def isAfter(self, other: Use) -> _bool: ...

# Defined in torch/csrc/jit/ir/ir.h
class Value:
    def type(self) -> JitType: ...
    def setType(self, t: JitType) -> Value: ...
    def setTypeAs(self, other: Value) -> Value: ...
    def inferTypeFrom(self, t: Tensor) -> None: ...
    def debugName(self) -> str: ...
    def setDebugName(self, name: str) -> None: ...
    def unique(self) -> _int: ...
    def offset(self) -> _int: ...
    def node(self) -> Node: ...
    def uses(self) -> List[Use]: ...
    def replaceAllUsesWith(self, val: Value) -> None: ...
    def replaceAllUsesAfterNodeWith(self, node: Node, val: Value) -> None: ...
    def requires_grad(self) -> _bool: ...
    def requiresGrad(self) -> _bool: ...
    def copyMetadata(self, other: Value) -> Value: ...
    def isCompleteTensor(self) -> _bool: ...
    def toIValue(self) -> IValue: ...

# Defined in torch/csrc/jit/ir/ir.h
class Block:
    def inputs(self) -> Iterator[Value]: ...
    def outputs(self) -> Iterator[Value]: ...
    def nodes(self) -> Iterator[Node]: ...
    def paramNode(self) -> Node: ...
    def returnNode(self) -> Node: ...
    def owningNode(self) -> Node: ...
    def registerOutput(self, n: Value) -> _int: ...
    def addNode(self, name: str, inputs: Sequence[Value]) -> Node: ...

# Defined in torch/csrc/jit/ir/ir.h
class Node:
    def __getitem__(self, key: str) -> Any: ...
    def schema(self) -> str: ...
    def input(self) -> Value: ...
    def inputs(self) -> Iterator[Value]: ...
    def inputsAt(self, idx: _int) -> Value: ...
    def inputsSize(self) -> _int: ...
    def output(self) -> Value: ...
    def outputs(self) -> Iterator[Value]: ...
    def outputsAt(self, idx: _int) -> Value: ...
    def outputsSize(self) -> _int: ...
    def hasMultipleOutputs(self) -> _bool: ...
    def blocks(self) -> List[Block]: ...
    def addBlock(self) -> Block: ...
    def mustBeNone(self) -> _bool: ...
    def matches(self, pattern: str) -> _bool: ...
    def kind(self) -> str: ...
    def kindOf(self, name: str) -> str: ...
    # NOTE(review): the C++ Node::addInput takes a Value* (unlike
    # Graph::addInput, which takes a name); confirm whether this should be
    # ``value: Value``.
    def addInput(self, name: str) -> Value: ...
    def replaceInput(self, i: _int, newValue: Value) -> Value: ...
    def replaceInputWith(self, from_: Value, to: Value) -> None: ...
    def replaceAllUsesWith(self, n: Node) -> None: ...
    def insertBefore(self, n: Node) -> Node: ...
    def insertAfter(self, n: Node) -> Node: ...
    def isBefore(self, n: Node) -> _bool: ...
    def isAfter(self, n: Node) -> _bool: ...
    def moveBefore(self, n: Node) -> None: ...
    def moveAfter(self, n: Node) -> None: ...
    def removeInput(self, i: _int) -> None: ...
    # Takes no index: every input is removed (C++ ``Node::removeAllInputs()``).
    def removeAllInputs(self) -> None: ...
    def hasUses(self) -> _bool: ...
    def eraseOutput(self, i: _int) -> None: ...
    def addOutput(self) -> Value: ...
    def scopeName(self) -> str: ...
    def isNondeterministic(self) -> _bool: ...
    def copyAttributes(self, rhs: Node) -> Node: ...
    def copyMetadata(self, rhs: Node) -> Node: ...
    def hasAttributes(self) -> _bool: ...
    def hasAttribute(self, name: str) -> _bool: ...
    def removeAttribute(self, attr: str) -> Node: ...
    def namedInput(self, name: str) -> Value: ...
    def sourceRange(self) -> SourceRange: ...
    def owningBlock(self) -> Block: ...
    def findNode(self, kind: str, recurse: _bool = True) -> Node: ...
    def findAllNodes(self, kind: str, recurse: _bool = True) -> List[Node]: ...
    def getModuleHierarchy(self) -> str: ...
    def prev(self) -> Node: ...
    def destroy(self) -> None: ...
    def attributeNames(self) -> List[str]: ...

    # Accessors for attributes as types.
    # Each getter ``x(name)`` has a setter ``x_(name, val)`` returning the Node.
    def f(self, name: str) -> _float: ...
    def f_(self, name: str, val: _float) -> Node: ...
    def fs(self, name: str) -> List[_float]: ...
    def fs_(self, name: str, val: List[_float]) -> Node: ...
    def c(self, name: str) -> complex: ...
    def c_(self, name: str, val: complex) -> Node: ...
    def s(self, name: str) -> str: ...
    def s_(self, name: str, val: str) -> Node: ...
    def ss(self, name: str) -> List[str]: ...
    def ss_(self, name: str, val: List[str]) -> Node: ...
    def i(self, name: str) -> _int: ...
    def i_(self, name: str, val: _int) -> Node: ...
    # Cannot define "is" like this because it's a reserved keyword in python.
    # def is(self, name: str) -> List[_int]: ...
    # def is_(self, name: str, val: List[_int]) -> Node: ...
    def g(self, name: str) -> Graph: ...
    def g_(self, name: str, val: Graph) -> Node: ...
    def gs(self, name: str) -> List[Graph]: ...
    def gs_(self, name: str, val: List[Graph]) -> Node: ...
    def ival(self, name: str) -> IValue: ...
    def ival_(self, name: str, val: IValue) -> Node: ...
    def t(self, name: str) -> Tensor: ...
    def t_(self, name: str, val: Tensor) -> Node: ...
    def ts(self, name: str) -> List[Tensor]: ...
    def ts_(self, name: str, val: List[Tensor]) -> Node: ...
    def ty(self, name: str) -> JitType: ...
    def ty_(self, name: str, val: JitType) -> Node: ...
    def tys(self, name: str) -> List[JitType]: ...
    def tys_(self, name: str, val: List[JitType]) -> Node: ...

# Defined in torch/csrc/jit/ir/ir.h
class Graph:
    def inputs(self) -> Iterator[Value]: ...
    def outputs(self) -> Iterator[Value]: ...
    def nodes(self) -> Iterator[Node]: ...
    def param_node(self) -> Node: ...
    def return_node(self) -> Node: ...
    def addInput(self, name: str = "") -> Value: ...
    def eraseInput(self, i: _int) -> None: ...
    def registerOutput(self, n: Value) -> _int: ...
    def eraseOutput(self, i: _int) -> None: ...
    def create(self, name: str, args, num_outputs: _int) -> Node: ...
    def appendNode(self, n: Node) -> Node: ...
    def prependNode(self, n: Node) -> Node: ...
    def insertNode(self, n: Node) -> Node: ...
    def block(self) -> Block: ...
    def lint(self) -> None: ...
    def alias_db(self) -> AliasDb: ...
    def setInsertPoint(self, n: Union[Block, Node]) -> None: ...
    def insert_point_guard(self, n: Union[Block, Node]) -> _InsertPoint: ...
    def insertPoint(self) -> Node: ...
    def insertGraph(self, callee: Graph, inputs: List[Value]) -> List[Value]: ...
    def makeMultiOutputIntoTuple(self) -> None: ...
    def copy(self) -> Graph: ...

# Defined in aten/src/ATen/core/alias_info.h
class AliasInfo:
    is_write: _bool
    before_set: Set[str]
    after_set: Set[str]

# Defined in aten/src/ATen/core/function_schema.h
class Argument:
    name: str
    type: JitType
    default_value: Optional[Any]
    kwarg_only: _bool
    is_out: _bool
    alias_info: Optional[AliasInfo]
    def has_default_value(self) -> _bool: ...

class FunctionSchema:
    arguments: List[Argument]
    returns: List[Argument]
    name: str
    overload_name: str
    is_mutable: _bool

class _UpgraderEntry:
    bumped_at_version: _int
    upgrader_name: str
    old_schema: str
    def __init__(
        self,
        bumped_at_version: _int,
        upgrader_name: str,
        old_schema: str,
    ) -> None: ...

class _UpgraderRange:
    min_version: _int
    max_version: _int

def _get_max_operator_version() -> _int: ...
def _get_operator_version_map() -> Dict[str, List[_UpgraderEntry]]: ...
def _get_upgrader_ranges(name: str) -> List[_UpgraderRange]: ...
def _test_only_add_entry_to_op_version(op_name: str, entry: _UpgraderEntry) -> None: ...
def _test_only_remove_entry_to_op_version(op_name: str) -> None: ...
# Defined in torch/csrc/jit/python/script_init.cpp
# Serializes ScriptModules through the PyTorchFileWriter passed at construction
# (see serialize / write_files).
class ScriptModuleSerializer:
    def __init__(self, export_writer: PyTorchFileWriter) -> None: ...
    def serialize(self, model: ScriptModule, script_module_id: _int) -> None: ...
    def write_files(self) -> None: ...
    def storage_context(self) -> SerializationStorageContext: ...

# Defined in torch/csrc/jit/python/script_init.cpp
# Storage bookkeeping during serialization; get_or_add_storage returns an int
# (presumably the storage's key in the archive — TODO confirm).
class SerializationStorageContext:
    def __init__(self) -> None: ...
    def has_storage(self, storage: Storage) -> _bool: ...
    def get_or_add_storage(self, storage: Storage) -> _int: ...

# Defined in torch/csrc/jit/python/script_init.cpp
# Name-keyed storage lookup used on the deserialization side.
class DeserializationStorageContext:
    def __init__(self) -> None: ...
    def get_storage(self, name: str, dtype: _dtype) -> Tensor: ...
    def has_storage(self, name: str) -> _bool: ...
    def add_storage(self, name: str, tensor: Tensor) -> _int: ...

# Defined in torch/csrc/jit/python/script_init.cpp
# Incrementally describes a module's concrete type: container kind,
# attributes, submodules, constants, overloads, builtins, ignored/failed
# attributes and forward hooks. Mutator return types are not declared here.
class ConcreteModuleTypeBuilder:
    def __init__(self, obj: Any) -> None: ...
    def set_module_dict(self): ...
    def set_module_list(self): ...
    def set_parameter_list(self): ...
    def set_parameter_dict(self): ...
    def add_attribute(
        self,
        name: str,
        ty: JitType,
        is_param: _bool,
        is_buffer: _bool,
    ): ...
    def add_module(self, name: str, meta: ConcreteModuleType): ...
    def add_constant(self, name: str, value: Any): ...
    def add_overload(self, method_name: str, overloaded_method_names: List[str]): ...
    def add_builtin_function(self, name: str, symbol_name: str): ...
    def add_failed_attribute(self, name: str, failure_reason: str): ...
    def add_function_attribute(
        self,
        name: str,
        ty: JitType,
        func: Callable[..., Any],
    ): ...
    def add_ignored_attribute(self, name: str): ...
    def add_ignored_attributes(self, names: List[str]): ...
    def add_forward_hook(self, hook: Callable[..., Any]): ...
    def add_forward_pre_hook(self, pre_hook: Callable[..., Any]): ...

class ConcreteModuleType:
    def get_constants(self) -> Dict[str, Any]: ...
    def equals(self, other: ConcreteModuleType) -> _bool: ...
    @staticmethod
    def from_jit_type(ty: JitType) -> ConcreteModuleType: ...

class CallStack:
    def __init__(self, name: str, range: SourceRange) -> None: ...

class ErrorReport:
    def __init__(self, range: SourceRange) -> None: ...
    def what(self) -> str: ...
    @staticmethod
    def call_stack() -> str: ...

# Container for compiled TorchScript functions/classes; attribute access
# (__getattr__) resolves to a ScriptFunction by name.
class CompilationUnit:
    def __init__(self, lang: str = ..., _frames_up: _int = ...) -> None: ...
    def find_function(self, name: str) -> ScriptFunction: ...
    def __getattr__(self, name: str) -> ScriptFunction: ...
    def define(
        self,
        script: str,
        rcb: ResolutionCallback = ...,
        _frames_up: _int = ...,
    ): ...
    def get_interface(self, name: str) -> InterfaceType: ...
    def get_functions(self) -> List[ScriptFunction]: ...
    def create_function(
        self,
        name: str,
        graph: Graph,
        shouldMangle: _bool = ...,
    ) -> ScriptFunction: ...
    def get_class(self, name: str) -> ClassType: ...

class ScriptObject:
    def setattr(self, name: str, value: Any): ...

class ScriptModule(ScriptObject):
    def _method_names(self) -> List[str]: ...
    def _get_method(self, name: str) -> ScriptMethod: ...

class LiteScriptModule:
    def __call__(self, *input): ...
    def find_method(self, method_name: str): ...
    def forward(self, *input) -> List[str]: ...
    def run_method(self, method_name: str, *input): ...

# NOTE: switch to collections.abc.Callable in python 3.9
# Generic over the call signature (P) and return type (ReturnVal), so calling
# a ScriptFunction type-checks like calling the wrapped Python function.
class ScriptFunction(Generic[P, ReturnVal]):
    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> ReturnVal: ...
    def save(self, filename: str, _extra_files: Dict[str, bytes]) -> None: ...
    def save_to_buffer(self, _extra_files: Dict[str, bytes]) -> bytes: ...
    @property
    def graph(self) -> Graph: ...
    def inlined_graph(self) -> Graph: ...
    def schema(self) -> FunctionSchema: ...
    def code(self) -> str: ...
    def name(self) -> str: ...
    @property
    def qualified_name(self) -> str: ...

# NOTE: switch to collections.abc.Callable in python 3.9
class ScriptMethod(Generic[P, ReturnVal]):
    graph: Graph
    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> ReturnVal: ...
    @property
    def owner(self) -> ScriptModule: ...
    @property
    def name(self) -> str: ...

# Mapping-like view with keys of type K and values of type T.
class ScriptDict(Generic[K, T]):
    def __init__(self, dict: Dict[K, T]) -> None: ...
    def __len__(self) -> _int: ...
    def __contains__(self, key: K) -> _bool: ...
    def __getitem__(self, key: K) -> T: ...
    def __setitem__(self, key: K, value: T) -> None: ...
    def __delitem__(self, key: K) -> None: ...
    def __iter__(self) -> Iterator[K]: ...
    # Tuple (not builtin tuple) to match the typing aliases used elsewhere.
    def items(self) -> Iterator[Tuple[K, T]]: ...
    def keys(self) -> Iterator[K]: ...

# List-like view with elements of type T; slicing yields another ScriptList.
class ScriptList(Generic[T]):
    def __init__(self, list: List[T]) -> None: ...
    def __len__(self) -> _int: ...
    def __contains__(self, item: T) -> _bool: ...
    @overload
    def __getitem__(self, idx: _int) -> T: ...
    @overload
    def __getitem__(self, idx: slice) -> ScriptList[T]: ...
    @overload
    def __setitem__(self, idx: _int, value: T) -> None: ...
    @overload
    def __setitem__(self, idx: slice, value: List[T]) -> None: ...
    def __delitem__(self, idx: _int) -> None: ...
    def __iter__(self) -> Iterator[T]: ...
    def count(self, value: T) -> _int: ...
    def remove(self, value: T) -> None: ...
    def append(self, value: T) -> None: ...
    def clear(self) -> None: ...
    @overload
    def extend(self, values: List[T]) -> None: ...
    @overload
    def extend(self, values: Iterable[T]) -> None: ...
    @overload
    def pop(self) -> T: ...
    @overload
    def pop(self, idx: _int) -> T: ...

class ModuleDict:
    def __init__(self, mod: ScriptModule) -> None: ...
    def items(self) -> List[Tuple[str, Any]]: ...

class ParameterDict:
    def __init__(self, mod: ScriptModule) -> None: ...

class BufferDict:
    def __init__(self, mod: ScriptModule) -> None: ...

# Defined in torch/csrc/jit/api/module.h
class Module: ...

# Defined in torch/csrc/Module.cpp
# The trailing `# THPModule_*` comment on each binding names its C++ entry point.
def _initExtension(shm_manager_path: str) -> None: ...  # THPModule_initExtension
def _autograd_init() -> _bool: ...  # THPAutograd_initExtension
def _add_docstr(obj: T, doc_obj: str) -> T: ...  # THPModule_addDocStr
def _init_names(arg: Sequence[Type]) -> None: ...  # THPModule_initNames
def _has_distributed() -> _bool: ...  # THPModule_hasDistributed
def _set_default_tensor_type(type) -> None: ...  # THPModule_setDefaultTensorType
def _set_default_dtype(d: _dtype) -> None: ...  # THPModule_setDefaultDtype
def _infer_size(arg1: Size, arg2: Size) -> Size: ...  # THPModule_inferSize
def _crash_if_csrc_asan() -> _int: ...  # THPModule_crashIfCsrcASAN
def _crash_if_csrc_ubsan() -> _int: ...  # THPModule_crashIfCsrcUBSAN
def _crash_if_aten_asan() -> _int: ...  # THPModule_crashIfATenASAN
def _show_config() -> str: ...  # THPModule_showConfig
def _cxx_flags() -> str: ...  # THPModule_cxxFlags
def _parallel_info() -> str: ...  # THPModule_parallelInfo
def _get_cpu_capability() -> str: ...  # THPModule_getCpuCapability
def _set_backcompat_broadcast_warn(
    arg: _bool,
) -> None: ...  # THPModule_setBackcompatBroadcastWarn
def _get_backcompat_broadcast_warn() -> _bool: ...  # THPModule_getBackcompatBroadcastWarn
def _set_backcompat_keepdim_warn(
    arg: _bool,
) -> None: ...  # THPModule_setBackcompatKeepdimWarn
def _get_backcompat_keepdim_warn() -> _bool: ...  # THPModule_getBackcompatKeepdimWarn
# Plural name, matching set_num_threads / get_num_interop_threads and the
# THPModule_getNumThreads entry point (the old stub spelled it `get_num_thread`).
def get_num_threads() -> _int: ...  # THPModule_getNumThreads
def set_num_threads(nthreads: _int) -> None: ...  # THPModule_setNumThreads
def get_num_interop_threads() -> _int: ...  # THPModule_getNumInteropThreads
def set_num_interop_threads(
    nthreads: _int,
) -> None: ...  # THPModule_setNumInteropThreads
def _get_cudnn_enabled() -> _bool: ...  # THPModule_userEnabledCuDNN
def _set_cudnn_enabled(arg: _bool) -> None: ...  # THPModule_setUserEnabledCuDNN
# Scaled-dot-product-attention backend toggles: one getter/setter pair per backend.
def _get_flash_sdp_enabled() -> _bool: ...  # THPModule_userEnabledFusedSDP
def _set_sdp_use_flash(arg: _bool) -> None: ...  # THPModule_setSDPUseFlash
def _get_mem_efficient_sdp_enabled() -> _bool: ...  # THPModule_userEnabledMemEfficientSDP
def _set_sdp_use_mem_efficient(
    arg: _bool,
) -> None: ...  # THPModule_setSDPUseMemEfficient
def _get_math_sdp_enabled() -> _bool: ...  # THPModule_userEnabledMathSDP
def _set_sdp_use_math(arg: _bool) -> None: ...  # THPModule_setSDPUseMath
def _get_cudnn_sdp_enabled() -> _bool: ...  # THPModule_userEnabledCuDNNSDP
def _set_sdp_use_cudnn(arg: _bool) -> None: ...  # THPModule_setSDPUseCuDNN
def _get_mkldnn_enabled() -> _bool: ...  # THPModule_userEnabledMkldnn
def _set_mkldnn_enabled(arg: _bool) -> None: ...  # THPModule_setUserEnabledMkldnn
def _get_cudnn_benchmark() -> _bool: ...  # THPModule_benchmarkCuDNN
def _set_cudnn_benchmark(arg: _bool) -> None: ...  # THPModule_setBenchmarkCuDNN
def _get_cudnn_deterministic() -> _bool: ...  # THPModule_deterministicCuDNN
def _set_cudnn_deterministic(arg: _bool) -> None: ...  # THPModule_setDeterministicCuDNN
def _get_deterministic_algorithms() -> _bool: ...  # THPModule_deterministicAlgorithms
def _get_deterministic_algorithms_warn_only() -> _bool: ...  # THPModule_deterministicAlgorithmsWarnOnly
def _set_deterministic_algorithms(
    mode: _bool,
    *,
    warn_only: _bool = ...,
) -> None: ...  # THPModule_setDeterministicAlgorithms
def _get_deterministic_fill_uninitialized_memory() -> _bool: ...  # THPModule_deterministicFillUninitializedMemory
def _set_deterministic_fill_uninitialized_memory(arg: _bool) -> None: ...  # THPModule_setDeterministicFillUninitializedMemory
def _get_nnpack_enabled() -> _bool: ...  # THPModule_userEnabledNNPACK
def _set_nnpack_enabled(arg: _bool) -> None: ...  # THPModule_setUserEnabledNNPACK
def _get_warnAlways() -> _bool: ...  # THPModule_warnAlways
def _set_warnAlways(arg: _bool) -> None: ...  # THPModule_setWarnAlways
def _get_cudnn_allow_tf32() -> _bool: ...  # THPModule_allowTF32CuDNN
def _set_cudnn_allow_tf32(arg: _bool) -> None: ...  # THPModule_setAllowTF32CuDNN
def _get_cublas_allow_tf32() -> _bool: ...  # THPModule_allowTF32CuBLAS
def _set_cublas_allow_tf32(arg: _bool) -> None: ...  # THPModule_setAllowTF32CuBLAS
def _get_float32_matmul_precision() -> str: ...  # THPModule_float32MatmulPrecision
def _set_float32_matmul_precision(
    arg: str,
) -> None: ...  # THPModule_setFloat32MatmulPrecision
def _get_cublas_allow_fp16_reduced_precision_reduction() -> _bool: ...  # THPModule_allowFP16ReductionCuBLAS
def _set_cublas_allow_fp16_reduced_precision_reduction(
    arg: _bool,
) -> None: ...  # THPModule_setAllowFP16ReductionCuBLAS
def _get_cublas_allow_bf16_reduced_precision_reduction() -> _bool: ...  # THPModule_allowBF16ReductionCuBLAS
def _set_cublas_allow_bf16_reduced_precision_reduction(
    arg: _bool,
) -> None: ...  # THPModule_setAllowBF16ReductionCuBLAS
def _set_conj(x: Tensor, conj: _bool) -> None: ...
def _set_neg(x: Tensor, neg: _bool) -> None: ...
def _set_meta_in_tls_dispatch_include(meta_in_tls: _bool) -> None: ...
def _meta_in_tls_dispatch_include() -> _bool: ...
# String-keyed stash of arbitrary objects in thread-local storage.
def _stash_obj_in_tls(key: str, arg: Any) -> None: ...
def _get_obj_in_tls(key: str) -> Any: ...
def _is_key_in_tls(key: str) -> _bool: ...
# Backend selection for batch norm / convolution; the selectors' argument
# lists are not spelled out in this stub (*args, **kwargs).
def _select_batch_norm_backend(*args, **kwargs) -> BatchNormBackend: ...
def _select_conv_backend(*args, **kwargs) -> ConvBackend: ...
def _conv_determine_backend_memory_format(
    input: Tensor,
    weight: Tensor,
    backend: ConvBackend,
) -> memory_format: ...
def _has_storage(x: Tensor) -> _bool: ...
def _construct_storage_from_data_pointer(data_ptr: _int, device: torch.device, size: _int) -> Storage: ...
def _should_allow_numbers_as_tensors(func_name: str) -> _bool: ...
# Groups nested tensor lists by (device, dtype). Each value pairs the grouped
# lists with a list of ints (presumably original indices when
# `with_indices=True` — TODO confirm).
def _group_tensors_by_device_and_dtype(nested_tensorlists: List[List[Optional[Tensor]]], with_indices: _bool = False) -> Dict[Tuple[torch.device, torch.dtype], Tuple[List[List[Optional[Tensor]]], List[_int]]]: ...

# NB: There is no Capsule type in typing, see
# https://code.activestate.com/lists/python-dev/139675/
# so DLPack capsules are typed as Any.
def _to_dlpack(data: Tensor) -> Any: ...  # THPModule_toDLPack
def _from_dlpack(data: Any) -> Tensor: ...  # THPModule_fromDLPack
def _get_cpp_backtrace(
    frames_to_skip: _int,
    maximum_number_of_frames: _int,
) -> str: ...  # THPModule_getCppBacktrace
def set_flush_denormal(arg: _bool) -> _bool: ...  # THPModule_setFlushDenormal
def get_default_dtype() -> _dtype: ...  # THPModule_getDefaultDtype
def _get_default_device() -> str: ...  # THPModule_getDefaultDevice
def _get_qengine() -> _int: ...  # THPModule_qEngine
def _set_qengine(qengine: _int) -> None: ...  # THPModule_setQEngine
def _supported_qengines() -> List[_int]: ...  # THPModule_supportedQEngines
def _is_xnnpack_enabled() -> _bool: ...  # THPModule_isEnabledXNNPACK
def _check_sparse_tensor_invariants() -> _bool: ...  # THPModule_checkSparseTensorInvariants
def _set_check_sparse_tensor_invariants(
    arg: _bool,
) -> None: ...  # THPModule_setCheckSparseTensorInvariants
def _set_default_mobile_cpu_allocator() -> None: ...  # THPModule_setDefaultMobileCPUAllocator
def _unset_default_mobile_cpu_allocator() -> None: ...  # THPModule_unsetDefaultMobileCPUAllocator
def _is_torch_function_enabled() -> _bool: ...  # THPModule_isEnabledTorchFunction
def _has_torch_function(
    args: Iterable[Any],
) -> _bool: ...  # THPModule_has_torch_function
# Single positional argument. It is declared positional-only (`/`) with a real
# name, so it neither shadows `typing.Any` nor exposes a keyword name.
def _has_torch_function_unary(arg: Any, /) -> _bool: ...  # THPModule_has_torch_function_unary
def _has_torch_function_variadic(
    *args: Any,
) -> _bool: ...  # THPModule_has_torch_function_variadic
def _vmapmode_increment_nesting() -> _int: ...  # THPModule_vmapmode_increment_nesting
def _vmapmode_decrement_nesting() -> _int: ...  # THPModule_vmapmode_decrement_nesting
# The parameters of _log_api_usage_once / _demangle were previously *named*
# `str`, which shadowed the builtin and left them untyped.
def _log_api_usage_once(event: str, /) -> None: ...  # LogAPIUsageOnceFromPython
def _log_api_usage_metadata(event: str, metadata_map: Dict[str, str]) -> None: ...  # LogAPIUsageMetadataFromPython
def _demangle(name: str, /) -> str: ...  # c10::demangle
def _disabled_torch_function_impl(
    func: Callable,
    types: Iterable[Type],
    args: Tuple,
    kwargs: Dict,
) -> Any: ...  # THPModule_disable_torch_function
def _disabled_torch_dispatch_impl(
    func: Callable,
    types: Iterable[Type],
    args: Tuple,
    kwargs: Dict,
) -> Any: ...  # THPModule_disable_dispatch_function
# Preferred-backend getter/setter pairs; the setters return nothing.
def _get_linalg_preferred_backend() -> torch._C._LinalgBackend: ...
def _set_linalg_preferred_backend(arg: torch._C._LinalgBackend) -> None: ...

# Enum-like: each member is an instance of the class itself.
class _LinalgBackend:
    Default: _LinalgBackend
    Cusolver: _LinalgBackend
    Magma: _LinalgBackend

class BatchNormBackend(Enum): ...

def _get_blas_preferred_backend() -> torch._C._BlasBackend: ...
def _set_blas_preferred_backend(arg: torch._C._BlasBackend) -> None: ...

# Enum-like: each member is an instance of the class itself.
class _BlasBackend:
    Cublas: _BlasBackend
    Cublaslt: _BlasBackend

class ConvBackend(Enum): ...
# Operator tags; the members are filled in by template substitution.
class Tag(Enum):
    ${tag_attributes}

# Defined in `valgrind.h` and `callgrind.h` respectively.
# The trailing comment on each binding names the valgrind/callgrind macro it relates to.
def _valgrind_supported_platform() -> _bool: ...  # NVALGRIND
def _valgrind_toggle() -> None: ...  # CALLGRIND_TOGGLE_COLLECT
def _valgrind_toggle_and_dump_stats() -> None: ...  # CALLGRIND_TOGGLE_COLLECT and CALLGRIND_DUMP_STATS

# Module-level capability flags (presumably fixed when torch is built — TODO confirm).
has_openmp: _bool
has_mkl: _bool
_has_mps: _bool
has_lapack: _bool
_has_cuda: _bool
_has_magma: _bool
_has_xpu: _bool
_has_mkldnn: _bool
_has_cudnn: _bool
has_spectral: _bool
_GLIBCXX_USE_CXX11_ABI: _bool
# Module-level default Generator instance.
default_generator: Generator

# Defined in torch/csrc/autograd/init.cpp
# Grad-mode, inference-mode and autocast state accessors.
def _set_grad_enabled(enabled: _bool) -> None: ...
def is_grad_enabled() -> _bool: ...
def _set_fwd_grad_enabled(enabled: _bool) -> None: ...
def _is_fwd_grad_enabled() -> _bool: ...
def is_inference_mode_enabled() -> _bool: ...
# set_autocast_enabled / is_autocast_enabled accept an optional leading device_type.
@overload
def set_autocast_enabled(device_type: str, enabled: _bool) -> None: ...
@overload
def set_autocast_enabled(enabled: _bool) -> None: ...
@overload
def is_autocast_enabled(device_type: str) -> _bool: ...
@overload
def is_autocast_enabled() -> _bool: ...
def set_autocast_dtype(device_type: str, dtype: _dtype) -> None: ...
def get_autocast_dtype(device_type: str) -> _dtype: ...
def clear_autocast_cache() -> None: ...
# Device-specific (cpu/gpu) autocast variants.
def set_autocast_cpu_enabled(enabled: _bool) -> None: ...
def is_autocast_cpu_enabled() -> _bool: ...
def _is_any_autocast_enabled() -> _bool: ...
def _is_autocast_available(device_type: str) -> _bool: ...
def set_autocast_cpu_dtype(dtype: _dtype) -> None: ...
def set_autocast_gpu_dtype(dtype: _dtype) -> None: ...
def get_autocast_cpu_dtype() -> _dtype: ...
def get_autocast_gpu_dtype() -> _dtype: ...
# Nesting counters; both return an int (presumably the new depth — TODO confirm).
def autocast_increment_nesting() -> _int: ...
def autocast_decrement_nesting() -> _int: ...
def is_autocast_cache_enabled() -> _bool: ...
def set_autocast_cache_enabled(enabled: _bool) -> None: ...
def _increment_version(tensor: Tensor) -> None: ...
# Anomaly detection; NaN checking is a separate flag that defaults to on.
def set_anomaly_enabled(enabled: _bool, check_nan: _bool = True) -> None: ...
def is_anomaly_enabled() -> _bool: ...
def is_anomaly_check_nan_enabled() -> _bool: ...
def _is_multithreading_enabled() -> _bool: ...
def _set_multithreading_enabled(enabled: _bool) -> None: ...
def _set_view_replay_enabled(enabled: _bool) -> None: ...
def _is_view_replay_enabled() -> _bool: ...
# Forward-mode AD: dual levels are int handles obtained from _enter_dual_level.
def _enter_dual_level() -> _int: ...
def _exit_dual_level(level: _int) -> None: ...
def _make_dual(tensor: Tensor, tangent: Tensor, level: _int) -> Tensor: ...
def _unpack_dual(tensor: Tensor, level: _int) -> Tensor: ...
def __set_forward_AD_enabled(enabled: _bool) -> None: ...
def __is_forward_AD_enabled() -> _bool: ...
# Default saved-tensor hooks: register a pack/unpack pair, or reset them.
def _register_default_hooks(pack_hook: Callable, unpack_hook: Callable) -> None: ...
def _reset_default_hooks() -> None: ...
# Torch-function mode stack: push/pop, indexed access and length.
def _is_torch_function_mode_enabled() -> _bool: ...
def _set_torch_function_mode(cls: Any) -> None: ...
def _push_on_torch_function_stack(cls: Any) -> None: ...
def _pop_torch_function_stack() -> Any: ...
def _get_function_stack_at(idx: _int) -> Any: ...
def _len_torch_function_stack() -> _int: ...
# Torch-dispatch mode stack; some entries are addressed by _TorchDispatchModeKey.
def _set_torch_dispatch_mode(cls: Any) -> None: ...
def _push_on_torch_dispatch_stack(cls: TorchDispatchMode) -> None: ...
def _pop_torch_dispatch_stack(mode_key: Optional[torch._C._TorchDispatchModeKey] = None) -> Any: ...
def _get_dispatch_mode(mode_key: Optional[torch._C._TorchDispatchModeKey]) -> Any: ...
def _unset_dispatch_mode(mode: torch._C._TorchDispatchModeKey) -> Optional[TorchDispatchMode]: ...
def _set_dispatch_mode(mode: TorchDispatchMode) -> None: ...
def _get_dispatch_stack_at(idx: _int) -> Any: ...
def _len_torch_dispatch_stack() -> _int: ...
def _activate_gpu_trace() -> None: ...
# Context-manager guards, used as `with _Guard(...): ...`. Their __enter__ /
# __exit__ return types are not declared in these stubs.
class _DisableTorchDispatch:
    def __init__(self) -> None: ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

class _EnableTorchFunction:
    def __init__(self) -> None: ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

class _EnablePythonDispatcher:
    def __init__(self) -> None: ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

class _DisablePythonDispatcher:
    def __init__(self) -> None: ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

class _EnablePreDispatch:
    def __init__(self) -> None: ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

class _DisableFuncTorch:
    def __init__(self) -> None: ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

class _DisableAutocast:
    def __init__(self) -> None: ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

class _InferenceMode:
    def __init__(self, enabled: _bool) -> None: ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

def _set_autograd_fallback_mode(mode: str) -> None: ...
def _get_autograd_fallback_mode() -> str: ...

# Defined in torch/csrc/jit/python/script_init.cpp
class LoggerBase: ...
class NoopLogger(LoggerBase): ...
class LockingLogger(LoggerBase): ...

class AggregationType(Enum):
    SUM = 0
    AVG = 1

# Pattern checker over a test string; the check_* builders return FileCheck
# so calls can be chained before run().
class FileCheck:
    def run(self, test_string: str) -> None: ...
    def check(self, test_string: str) -> FileCheck: ...
    def check_not(self, test_string: str) -> FileCheck: ...
    def check_same(self, test_string: str) -> FileCheck: ...
    def check_next(self, test_string: str) -> FileCheck: ...
    def check_count(
        self,
        test_string: str,
        count: _int,
        exactly: _bool = False,
    ) -> FileCheck: ...
    def check_dag(self, test_string: str) -> FileCheck: ...
    def check_source_highlighted(self, test_string: str) -> FileCheck: ...
    def check_regex(self, test_string: str) -> FileCheck: ...

# Defined in torch/csrc/jit/python/init.cpp
# Archive reader/writer; both can be opened from a file name or a binary buffer.
class PyTorchFileReader:
    @overload
    def __init__(self, name: str) -> None: ...
    @overload
    def __init__(self, buffer: BinaryIO) -> None: ...
    def get_record(self, name: str) -> bytes: ...
    def serialization_id(self) -> str: ...

class PyTorchFileWriter:
    @overload
    def __init__(self, name: str) -> None: ...
    @overload
    def __init__(self, buffer: BinaryIO) -> None: ...
    def write_record(self, name: str, data: Union[Storage, bytes, _int], size: _int) -> None: ...
    def write_end_of_file(self) -> None: ...
    def set_min_version(self, version: _int) -> None: ...
    def get_all_written_records(self) -> List[str]: ...
    def archive_name(self) -> str: ...
    def serialization_id(self) -> str: ...

def _jit_get_inline_everything_mode() -> _bool: ...
def _jit_set_inline_everything_mode(enabled: _bool) -> None: ...
def _jit_get_logging_option() -> str: ...
def _jit_set_logging_option(option: str) -> None: ...
def _jit_set_logging_stream(stream_name: str) -> None: ...
# The graph parameter is positional-only and typed. It used to be *named*
# `Graph`, which shadowed the Graph class and left the argument untyped.
def _jit_pass_cse(graph: Graph, /) -> _bool: ...
def _jit_pass_dce(graph: Graph, /) -> None: ...
def _jit_pass_lint(graph: Graph, /) -> None: ...

# Defined in torch/csrc/jit/python/python_custom_class.cpp
def _get_custom_class_python_wrapper(name: str, attr: str) -> Any: ...

# Defined in torch/csrc/Module.cpp
def _rename_privateuse1_backend(backend: str) -> None: ...
def _get_privateuse1_backend_name() -> str: ...
# Defined in torch/csrc/Generator.cpp
class Generator:
    device: _device
    def __init__(self, device: Optional[DeviceLikeType] = None) -> None: ...
    # Pickling: __reduce__ returns (class, (device,), state), and its state
    # element has the same type that __setstate__ accepts.
    def __reduce__(self) -> Tuple[Type[Generator], Tuple[_device], Tuple[_int, Optional[_int], Tensor]]: ...
    def __setstate__(self, state: Tuple[_int, Optional[_int], Tensor]) -> None: ...
    # Setters below return a Generator (presumably self, for chaining — TODO confirm).
    def get_state(self) -> Tensor: ...
    def set_state(self, _new_state: Tensor) -> Generator: ...
    def clone_state(self) -> Generator: ...
    # The graphsafe_* variants exchange state as a Generator rather than a Tensor.
    def graphsafe_get_state(self) -> Generator: ...
    def graphsafe_set_state(self, _new_state: Generator) -> Generator: ...
    def set_offset(self, offset: _int) -> Generator: ...
    def get_offset(self) -> _int: ...
    def manual_seed(self, seed: _int) -> Generator: ...
    def seed(self) -> _int: ...
    def initial_seed(self) -> _int: ...

# Defined in torch/csrc/utils/python_dispatch.cpp

class _DispatchOperatorHandle:
    def schema(self) -> FunctionSchema: ...
    def debug(self) -> str: ...

# Operator-library registration handle. Every registration method returns a
# _DispatchModule (presumably self, to allow chaining — TODO confirm).
class _DispatchModule:
    def def_(self, schema: str, alias: str = "") -> _DispatchModule: ...
    def def_legacy(self, schema: str) -> _DispatchModule: ...
    def def_name_t_t(
        self,
        name: str,
        dispatch: str,
        debug: str = "default_def_name_t_t",
    ) -> _DispatchModule: ...
    def def_schema_t_t(
        self,
        schema: str,
        dispatch: str,
        alias: str,
        debug: str = "default_def_schema_t_t",
    ) -> _DispatchModule: ...
    def impl_t_t(
        self,
        name: str,
        dispatch: str,
        debug: str = "impl_t_t",
    ) -> _DispatchModule: ...
    def impl(self, name: str, dispatch: str, func: Callable) -> _DispatchModule: ...
    def define(self, schema: str, alias: str = "") -> _DispatchModule: ...
    def fallback_fallthrough(self, dispatch: str = "") -> _DispatchModule: ...
# Precomputed dispatch key sets exposed as module-level constants.
_after_ADInplaceOrView_keyset: DispatchKeySet
_after_autograd_keyset: DispatchKeySet

# Opens an operator library; `kind`, `name` and `dispatch` select what is
# being registered, and file/linenum record the registration site.
def _dispatch_library(
    kind: str,
    name: str,
    dispatch: str,
    file: str = "",
    linenum: Any = 0,
) -> _DispatchModule: ...
# Dispatcher introspection by operator name.
def _dispatch_dump(name: str) -> str: ...
def _dispatch_dump_table(name: str) -> str: ...
def _dispatch_check_invariants(name: str) -> None: ...
def _dispatch_check_all_invariants() -> None: ...
def _dispatch_call_boxed(handle: _DispatchOperatorHandle, *args, **kwargs) -> Any: ...
def _dispatch_find_schema_or_throw(name: str, overload_name: str) -> _DispatchOperatorHandle: ...
def _dispatch_set_report_error_callback(handle: _DispatchOperatorHandle, callback: Callable) -> None: ...
# Kernel-registration queries for one operator.
def _dispatch_has_kernel(name: str) -> _bool: ...
def _dispatch_has_kernel_for_dispatch_key(
    name: str,
    dispatch: _dispatchkey,
) -> _bool: ...
def _dispatch_has_kernel_for_any_dispatch_key(
    name: str,
    dispatch_key_set: DispatchKeySet,
) -> _bool: ...
def _dispatch_kernel_for_dispatch_key_is_fallthrough(
    name: str,
    dispatch: _dispatchkey,
) -> _bool: ...
def _dispatch_has_computed_kernel_for_dispatch_key(
    name: str,
    dispatch: _dispatchkey,
) -> _bool: ...
def _dispatch_find_dangling_impls() -> List[str]: ...
def _dispatch_get_all_op_names() -> List[str]: ...
# Thread-local include/exclude toggles for a single dispatch key.
def _dispatch_tls_set_dispatch_key_excluded(
    dispatch: _dispatchkey,
    val: _bool,
) -> None: ...
def _dispatch_tls_is_dispatch_key_excluded(dispatch: _dispatchkey) -> _bool: ...
def _dispatch_tls_set_dispatch_key_included(
    dispatch: _dispatchkey,
    val: _bool,
) -> None: ...
def _dispatch_tls_is_dispatch_key_included(dispatch: _dispatchkey) -> _bool: ...
def _dispatch_isTensorSubclassLike(tensor: Tensor) -> _bool: ...
def _dispatch_key_name(dispatch: _dispatchkey) -> str: ...
def _dispatch_key_for_device(device_type: str) -> str: ...
# Returns None when the string does not name a dispatch key (per the Optional
# return type; exact conditions not visible here).
def _parse_dispatch_key(key: str) -> Optional[DispatchKey]: ...
def _dispatch_key_parse(dispatch: _dispatchkey) -> DispatchKey: ...
def _dispatch_num_backends() -> _int: ...
def _dispatch_pystub(name: str, overload: str) -> Optional[Tuple[str, str]]: ...
def _dispatch_is_alias_key(dispatch: _dispatchkey) -> _bool: ...
def _functionality_to_backend_keys(dispatch: _dispatchkey) -> List[DispatchKey]: ...
def _functionalization_reapply_views_tls() -> _bool: ...
def _only_lift_cpu_tensors() -> _bool: ...
def _set_only_lift_cpu_tensors(value: _bool) -> None: ...
def _set_throw_on_mutable_data_ptr(tensor: Tensor) -> None: ...
def _set_warn_deprecated_on_mutable_data_ptr(tensor: Tensor) -> None: ...

# Members are filled in by template substitution.
class DispatchKey(Enum):
    ${dispatch_key_hints}

# Set of dispatch keys supporting union (|), difference (-) and intersection (&).
class DispatchKeySet:
    def __init__(self, key: DispatchKey) -> None: ...
    def __or__(self, other: DispatchKeySet) -> DispatchKeySet: ...
    def __sub__(self, other: DispatchKeySet) -> DispatchKeySet: ...
    def __and__(self, other: DispatchKeySet) -> DispatchKeySet: ...
    def highestPriorityTypeId(self) -> DispatchKey: ...
    def has(self, k: _dispatchkey) -> _bool: ...
    # add/remove return a DispatchKeySet (whether in place or a new set is
    # not visible here — TODO confirm).
    def add(self, k: _dispatchkey) -> DispatchKeySet: ...
    def remove(self, k: _dispatchkey) -> DispatchKeySet: ...
    def __repr__(self) -> str: ...

_dispatch_autogradother_backends: DispatchKeySet
_additional_keys_to_prop_for_wrapper_tensors: DispatchKeySet

def _dispatch_has_backend_fallback(dispatch: _dispatchkey) -> _bool: ...
def _dispatch_keyset_full_after(t: _dispatchkey) -> DispatchKeySet: ...
def _dispatch_keyset_full() -> DispatchKeySet: ...
def _dispatch_keyset_to_string(keyset: DispatchKeySet) -> str: ...
def _dispatch_get_backend_keyset_from_autograd(
    dispatch: _dispatchkey,
) -> DispatchKeySet: ...
def _dispatch_keys(tensor: Tensor) -> DispatchKeySet: ...
# Current thread-local exclude / include key sets.
def _dispatch_tls_local_exclude_set() -> DispatchKeySet: ...
def _dispatch_tls_local_include_set() -> DispatchKeySet: ...
def _dispatch_is_included_in_alias(
    dispatch_a: _dispatchkey,
    dispatch_b: _dispatchkey,
) -> _bool: ...
def _propagate_xla_data(a: Tensor, b: Tensor) -> None: ...
def _replace_(a: Tensor, b: Tensor) -> None: ...
def _commit_update(a: Tensor) -> None: ...

# Context-manager guards over dispatch key sets, used as `with Guard(...): ...`.
class _ExcludeDispatchKeyGuard:
    def __init__(self, keyset: DispatchKeySet) -> None: ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

class _IncludeDispatchKeyGuard:
    def __init__(self, k: DispatchKey) -> None: ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

class _ForceDispatchKeyGuard:
    def __init__(self, include: DispatchKeySet, exclude: DispatchKeySet) -> None: ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

class _PreserveDispatchKeyGuard:
    def __init__(self) -> None: ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

class _AutoDispatchBelowAutograd:
    def __init__(self) -> None: ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

class _AutoDispatchBelowADInplaceOrView:
    def __init__(self) -> None: ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

# An empty dispatch_key (the default) presumably means "all keys" — TODO confirm.
def _dispatch_print_registrations_for_dispatch_key(dispatch_key: str = "") -> None: ...
def _dispatch_get_registrations_for_dispatch_key(
    dispatch_key: str = "",
) -> List[str]: ...
def _are_functorch_transforms_active() -> _bool: ...

# Defined in torch/csrc/autograd/init.cpp
def _set_python_dispatcher(dispatcher: object) -> None: ...

def _get_nested_int(id: _int, coeff: _int) -> SymInt: ...
def _get_constant_bool_symnode(val: _bool) -> Any: ...

# Members are filled in by template substitution.
class _TorchDispatchModeKey(Enum):
    ${torch_dispatch_mode_key_hints}

class _SetExcludeDispatchKeyGuard:
    def __init__(self, k: DispatchKey, enabled: _bool) -> None: ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

# Defined in torch/csrc/utils/init.cpp
# Settings consumed by ThroughputBenchmark.benchmark().
class BenchmarkConfig:
    num_calling_threads: _int
    num_worker_threads: _int
    num_warmup_iters: _int
    num_iters: _int
    profiler_output_path: str

# Result of ThroughputBenchmark.benchmark().
class BenchmarkExecutionStats:
    latency_avg_ms: _float
    num_iters: _int

# Wraps a module; inputs are queued with add_input, and a run is either a
# single call (run_once) or a configured benchmark.
class ThroughputBenchmark:
    def __init__(self, module: Any) -> None: ...
    def add_input(self, *args: Any, **kwargs: Any) -> None: ...
    def run_once(self, *args: Any, **kwargs: Any) -> Any: ...
    def benchmark(self, config: BenchmarkConfig) -> BenchmarkExecutionStats: ...

# Defined in torch/csrc/Storage.cpp
${legacy_storage_base_hints}

# TODO: document where the legacy class hints below are defined.
${legacy_class_hints}

# Defined in torch/csrc/autograd/python_engine.cpp
class _ImperativeEngine:
    def queue_callback(self, callback: Callable[[], None]) -> None: ...
    def run_backward(self, *args: Any, **kwargs: Any) -> Tuple[Tensor, ...]: ...
    def is_checkpoint_valid(self) -> _bool: ...

# Defined in torch/csrc/autograd/python_variable.cpp
# Metaclass of TensorBase.
class _TensorMeta(type): ...
# Defined in torch/csrc/autograd/python_variable.cpp
# C-extension tensor base type. Data attributes are declared explicitly below;
# the stub generator substitutes the method declarations from the
# tensor_method_hints template placeholder at the end of the class body.
class TensorBase(metaclass=_TensorMeta):
    requires_grad: _bool
    retains_grad: _bool
    shape: Size
    data: Tensor
    names: List[str]
    device: _device
    dtype: _dtype
    layout: _layout
    real: Tensor
    imag: Tensor
    T: Tensor
    H: Tensor
    mT: Tensor
    mH: Tensor
    ndim: _int
    output_nr: _int
    _version: _int
    _base: Optional[Tensor]
    _cdata: _int
    grad_fn: Optional[_Node]
    _grad_fn: Any
    _grad: Optional[Tensor]
    grad: Optional[Tensor]
    # Maps int keys to hooks that take a Tensor and may return a replacement Tensor.
    _backward_hooks: Optional[Dict[_int, Callable[[Tensor], Optional[Tensor]]]]
    nbytes: _int
    itemsize: _int
    _has_symbolic_sizes_strides: _bool

    # Both visitor callbacks are optional (default None); the return type is
    # left unannotated in this stub.
    def _view_func_unsafe(
        self,
        new_base: Tensor,
        symint_visitor_fn: Optional[Callable[[_int], _int]] = None,
        tensor_visitor_fn: Optional[Callable[[Tensor], Tensor]] = None
    ):
        ...

    ${tensor_method_hints}

# Underscore-prefixed alias of TensorBase; presumably kept so older code that
# imports the private name keeps working — confirm before removing.
_TensorBase = TensorBase

# Defined in torch/csrc/multiprocessing/init.cpp
def _multiprocessing_init() -> None: ...

# Defined in torch/csrc/Module.cpp
# Generic accelerator hooks: device count plus get/set/exchange of the current
# device index.
def _accelerator_hooks_device_count() -> _int: ...
def _accelerator_hooks_set_current_device(device_index: _int) -> None: ...
def _accelerator_hooks_get_current_device() -> _int: ...
def _accelerator_hooks_exchange_device(device_index: _int) -> _int: ...
def _accelerator_hooks_maybe_exchange_device(device_index: _int) -> _int: ...
def _get_accelerator(check: _bool = False) -> _device: ...

# Defined in torch/csrc/mtia/Module.cpp
def _mtia_init() -> None: ...
def _mtia_isBuilt() -> _bool: ...
def _mtia_isInBadFork() -> _bool: ...
def _mtia_deviceSynchronize() -> None: ...
def _mtia_getCurrentStream(device: _int) -> Stream: ...
def _mtia_setCurrentStream(stream: Stream) -> None: ...
def _mtia_getDefaultStream(device: _int) -> Stream: ...
# Defined in torch/csrc/mps/Module.cpp
def _mps_deviceSynchronize() -> None: ...
def _mps_get_default_generator() -> Generator: ...
def _mps_emptyCache() -> None: ...
def _mps_setMemoryFraction(fraction: _float) -> None: ...
def _mps_currentAllocatedMemory() -> _int: ...
def _mps_driverAllocatedMemory() -> _int: ...
def _mps_is_available() -> _bool: ...
def _mps_is_on_macos_or_newer(major: _int, minor: _int) -> _bool: ...
def _mps_profilerStartTrace(mode: str, wait_until_completed: _bool) -> None: ...
def _mps_profilerStopTrace() -> None: ...
# MPS events are handled through integer ids: _mps_acquireEvent returns one,
# and the other event functions below take it back.
def _mps_acquireEvent(enable_timing: _bool) -> _int: ...
def _mps_releaseEvent(event_id: _int) -> None: ...
def _mps_recordEvent(event_id: _int) -> None: ...
def _mps_waitForEvent(event_id: _int) -> None: ...
def _mps_synchronizeEvent(event_id: _int) -> None: ...
def _mps_queryEvent(event_id: _int) -> _bool: ...
def _mps_elapsedTimeOfEvents(start_event_id: _int, end_event_id: _int) -> _float: ...


# Defined in torch/csrc/cuda/Module.cpp
# The current/default stream getters return a fixed 3-tuple
# (stream_id, device_index, device_type), the same triple _cuda_setStream
# accepts, rather than an arbitrary-length tuple.
def _cuda_getCurrentStream(device: _int) -> Tuple[_int, _int, _int]: ...
def _cuda_getCurrentRawStream(device: _int) -> _int: ...
def _cuda_getDefaultStream(device: _int) -> Tuple[_int, _int, _int]: ...
def _cuda_getCurrentBlasHandle() -> _int: ...
def _cuda_clearCublasWorkspaces() -> None: ...
def _cuda_setDevice(device: _int) -> None: ...
def _cuda_exchangeDevice(device: _int) -> _int: ...
def _cuda_maybeExchangeDevice(device: _int) -> _int: ...
def _cuda_getDevice() -> _int: ...
def _cuda_getDeviceCount() -> _int: ...
def _cuda_set_sync_debug_mode(warn_level: Union[_int, str]) -> None: ...
def _cuda_get_sync_debug_mode() -> _int: ...
def _cuda_sleep(cycles: _int) -> None: ...
def _cuda_synchronize() -> None: ...
def _cuda_ipc_collect() -> None: ...
def _cuda_getArchFlags() -> Optional[str]: ...
def _cuda_init() -> None: ...
def _cuda_setStream(stream_id: _int, device_index: _int, device_type: _int) -> None: ...
def _cuda_getCompiledVersion() -> _int: ...
# Raw host/caching-allocator hooks; sizes, pointers and streams are plain ints.
def _cuda_cudaHostAllocator() -> _int: ...
def _cuda_cudaCachingAllocator_raw_alloc(size: _int, cuda_stream: _int) -> _int: ...
def _cuda_cudaCachingAllocator_raw_delete(ptr: _int) -> None: ...
def _cuda_cudaCachingAllocator_set_allocator_settings(env: str) -> None: ...
# Memory-pool management; a pool is identified by a (_int, _int) id pair.
def _cuda_beginAllocateCurrentStreamToPool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ...
def _cuda_endAllocateCurrentStreamToPool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ...
def _cuda_releasePool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ...
def _cuda_checkPoolLiveAllocations(
    device: _int,
    mempool_id: Tuple[_int, _int],
    expected_live_allocations: Set,
) -> _bool: ...
def _cuda_setCheckpointPoolState(
    device: _int,
    state: _cuda_CUDAAllocator_AllocatorState,
    stale_storages: List[_int],
    storages_to_add_deleters_to: List[_int],
) -> None: ...
def _cuda_setMemoryFraction(fraction: _float, device: _int) -> None: ...
def _cuda_emptyCache() -> None: ...
def _cuda_memoryStats(device: _int) -> Dict[str, Any]: ...
def _cuda_resetAccumulatedMemoryStats(device: _int) -> None: ...
def _cuda_resetPeakMemoryStats(device: _int) -> None: ...
def _cuda_memorySnapshot() -> Dict[str, Any]: ...
def _cuda_record_memory_history_legacy(
    enabled: _bool,
    record_context: _bool,
    record_context_cpp: _bool,
    alloc_trace_max_entries: _int,
    alloc_trace_record_context: _bool,
) -> None: ...
# max_entries is an entry count, typed _int like alloc_trace_max_entries in the
# legacy variant above (left unannotated it would silently accept Any).
def _cuda_record_memory_history(
    enabled: Optional[str],
    context: Optional[str],
    stacks: str,
    max_entries: _int,
) -> None: ...
def _cuda_isHistoryEnabled() -> _bool: ...

def _cuda_getAllocatorBackend() -> str: ...
# Opaque allocator-state snapshot: returned by _cuda_getCheckpointState and
# accepted by _cuda_setCheckpointPoolState.
class _cuda_CUDAAllocator_AllocatorState: ...

def _cuda_getCheckpointState(
    device: _int,
    mempool: Tuple[_int, _int],
) -> _cuda_CUDAAllocator_AllocatorState: ...

# Cached-tensor bookkeeping.
def _set_cached_tensors_enabled(enabled: _bool) -> None: ...
def _add_cached_tensor(t: Tensor) -> None: ...
def _remove_cached_tensor(t: Tensor) -> None: ...
def _tensors_data_ptrs_at_indices_equal(
    tensors: List[Tensor],
    ptrs: List[Optional[_int]],
    indices: List[_int],
) -> _bool: ...
def _construct_CUDA_Tensor_From_Storage_And_Metadata(
    metadata: dict,
    storage: Storage,
) -> Tensor: ...

# Storage helpers; most take a raw storage pointer as an int.
def _storage_Use_Count(storage_ptr: _int) -> _int: ...
def _set_storage_access_error_msg(t: Tensor, s: str) -> None: ...
def _free_And_Remove_DeleterFn(storage_ptr: _int) -> None: ...
def _has_Standard_Deleter(storage_ptr: _int) -> _bool: ...

# Pluggable allocator handle. _cuda_customAllocator builds one from two ints
# (alloc_fn / free_fn — presumably raw function addresses; confirm).
class _cuda_CUDAAllocator: ...

def _cuda_customAllocator(alloc_fn: _int, free_fn: _int) -> _cuda_CUDAAllocator: ...
def _cuda_changeCurrentAllocator(allocator: _cuda_CUDAAllocator) -> None: ...
def _cuda_getAllocator() -> _cuda_CUDAAllocator: ...
def _cuda_lock_mutex() -> None: ...
def _cuda_unlock_mutex() -> None: ...
def _cuda_canDeviceAccessPeer(device: _int, peer_device: _int) -> _bool: ...
def _cuda_jiterator_compile_and_launch_kernel(
    code_string: str,
    kernel_name: str,
    return_by_ref: _bool,
    num_outputs: _int,
    tensors: Tuple,
    kwargs: Dict[str, Union[_int, _float, _bool]],
) -> Tensor: ...

# cuDNN / conv benchmark knobs.
def _cuda_get_cudnn_benchmark_limit() -> _int: ...
def _cuda_set_cudnn_benchmark_limit(arg: _int) -> None: ...
def _cuda_get_conv_benchmark_empty_cache() -> _bool: ...
def _cudnn_set_conv_benchmark_empty_cache(enable: _bool) -> None: ...

# NCCL version and communicator setup.
def _nccl_version() -> _int: ...
def _nccl_version_suffix() -> bytes: ...
def _nccl_unique_id() -> bytes: ...
def _nccl_init_rank(nranks: _int, comm_id: bytes, rank: _int) -> object: ...
# NCCL collective bindings. Each takes a sequence of input tensors plus optional
# sequences of streams and communicators. Communicators are typed as opaque
# objects (presumably those returned by _nccl_init_rank — confirm).
def _nccl_reduce(
    input: Sequence[Tensor],
    output: Tensor,
    root: _int,
    op: _int,  # reduction op, passed as an int
    streams: Optional[Sequence[_CudaStreamBase]],
    comms: Optional[Sequence[object]],
) -> None: ...
def _nccl_all_reduce(
    input: Sequence[Tensor],
    output: Sequence[Tensor],
    op: _int,
    streams: Optional[Sequence[_CudaStreamBase]],
    comms: Optional[Sequence[object]],
) -> None: ...
def _nccl_broadcast(
    input: Sequence[Tensor],
    root: _int,
    streams: Optional[Sequence[_CudaStreamBase]],
    comms: Optional[Sequence[object]],
) -> None: ...
def _nccl_all_gather(
    input: Sequence[Tensor],
    output: Sequence[Tensor],
    streams: Optional[Sequence[_CudaStreamBase]],
    comms: Optional[Sequence[object]],
) -> None: ...
def _nccl_reduce_scatter(
    input: Sequence[Tensor],
    output: Sequence[Tensor],
    op: _int,
    streams: Optional[Sequence[_CudaStreamBase]],
    comms: Optional[Sequence[object]],
) -> None: ...
def _rocm_is_backward_pass() -> _bool: ...
# TunableOp controls: enable flags, tuning limits (duration / iterations), and
# the results file name plus read/write helpers.
def _cuda_tunableop_enable(val: _bool) -> None: ...
def _cuda_tunableop_is_enabled() -> _bool: ...
def _cuda_tunableop_tuning_enable(val: _bool) -> None: ...
def _cuda_tunableop_tuning_is_enabled() -> _bool: ...
def _cuda_tunableop_set_max_tuning_duration(duration: _int) -> None: ...
def _cuda_tunableop_get_max_tuning_duration() -> _int: ...
def _cuda_tunableop_set_max_tuning_iterations(iterations: _int) -> None: ...
def _cuda_tunableop_get_max_tuning_iterations() -> _int: ...
def _cuda_tunableop_set_filename(filename: str, insert_device_ordinal: Optional[_bool]) -> None: ...
def _cuda_tunableop_get_filename() -> str: ...
# filename may be None; presumably that falls back to the configured filename — confirm.
def _cuda_tunableop_write_file(filename: Optional[str]) -> _bool: ...
def _cuda_tunableop_read_file(filename: Optional[str]) -> _bool: ...
def _cuda_tunableop_write_file_on_exit(val: _bool) -> None: ...
# NOTE(review): the C++ binding appears to return one 4-tuple per tuned entry
# (i.e. a tuple of Tuple[str, str, str, _float]), and get_validators likewise a
# tuple of 2-tuples. Confirm before tightening these annotations, because the
# Python wrappers in torch.cuda.tunable currently mirror the flat types below.
def _cuda_tunableop_get_results() -> Tuple[str, str, str, _float]: ...
def _cuda_tunableop_get_validators() -> Tuple[str, str]: ...

# Per-device CUDA properties record. is_integrated / is_multi_gpu_board are typed
# _int (presumably 0/1 flags — confirm).
class _CudaDeviceProperties:
    name: str
    major: _int
    minor: _int
    multi_processor_count: _int
    total_memory: _int
    is_integrated: _int
    is_multi_gpu_board: _int
    max_threads_per_multi_processor: _int
    gcnArchName: str

# Functions related to SDPA (scaled dot-product attention)
# Bundle of SDPA inputs used by the _can_use_* backend checks below.
class _SDPAParams:
    query: Tensor
    key: Tensor
    value: Tensor
    attn_mask: Optional[Tensor]
    dropout: _float
    is_causal: _bool
    def __init__(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        attn_mask: Optional[Tensor],
        dropout: _float,
        is_causal: _bool) -> None: ...

# SDPA backend identifiers. The integer values are declared here explicitly and
# presumably mirror the C++ backend enum — keep them in sync.
class _SDPBackend(Enum):
    ERROR = -1
    MATH = 0
    FLASH_ATTENTION = 1
    EFFICIENT_ATTENTION = 2
    CUDNN_ATTENTION = 3

# Report whether the given backend can handle `params`; `debug` is a flag whose
# effect (presumably emitting the rejection reasons) should be confirmed.
def _can_use_flash_attention(params: _SDPAParams, debug: _bool) -> _bool: ...
def _can_use_mem_efficient_attention(params: _SDPAParams, debug: _bool) -> _bool: ...

# Defined in torch/csrc/cuda/python_comm.cpp
def _broadcast(tensor: Tensor, devices: List[_int]) -> List[Tensor]: ...
def _broadcast_out(tensor: Tensor, out_tensors: List[Tensor]) -> List[Tensor]: ...
def _broadcast_coalesced(
    tensors: List[Tensor],
    devices: List[_int],
    buffer_size: _int,
) -> List[List[Tensor]]: ...
def _scatter(
    tensor: Tensor,
    devices: List[_int],
    chunk_sizes: Optional[List[_int]],
    dim: _int,
    streams: Optional[List[Stream]],
) -> List[Tensor]: ...
def _scatter_out(
    tensor: Tensor,
    out_tensors: List[Tensor],
    dim: _int,
    streams: Optional[List[Stream]],
) -> List[Tensor]: ...
def _gather(
    tensors: List[Tensor],
    dim: _int,
    destination_index: Optional[_int],
) -> Tensor: ...
def _gather_out(tensors: List[Tensor], out_tensor: Tensor, dim: _int) -> Tensor: ...

# Defined in torch/csrc/cuda/Stream.cpp
class _CudaStreamBase(Stream):
    stream_id: _int
    device_index: _int
    device_type: _int

    device: _device
    cuda_stream: _int  # presumably the raw stream handle as an integer — confirm
    priority: _int

    # __new__ receives the class rather than an instance, hence `cls`.
    def __new__(
        cls,
        priority: _int = 0,
        stream_id: _int = 0,
        device_index: _int = 0,
        stream_ptr: _int = 0,
    ) -> _CudaStreamBase: ...
    def query(self) -> _bool: ...
    def synchronize(self) -> None: ...
    # Static, matching _XpuStreamBase.priority_range, so that it can be called
    # on the class itself; calls through an instance still type-check.
    @staticmethod
    def priority_range() -> Tuple[_int, _int]: ...

# Defined in torch/csrc/cuda/Event.cpp
class _CudaEventBase:
    device: _device
    cuda_event: _int

    def __new__(
        cls,
        enable_timing: _bool = False,
        blocking: _bool = False,
        interprocess: _bool = False,
    ) -> _CudaEventBase: ...
    @classmethod
    def from_ipc_handle(cls, device: _device, ipc_handle: bytes) -> _CudaEventBase: ...
    def record(self, stream: _CudaStreamBase) -> None: ...
    def wait(self, stream: _CudaStreamBase) -> None: ...
    def query(self) -> _bool: ...
    def elapsed_time(self, other: _CudaEventBase) -> _float: ...
    def synchronize(self) -> None: ...
    def ipc_handle(self) -> bytes: ...

# Defined in torch/csrc/cuda/Graph.cpp
class _CUDAGraph:
    def capture_begin(
        self,
        pool: Optional[Tuple[_int, _int]] = ...,
        capture_error_mode: str = "global",
    ) -> None: ...
    def capture_end(self) -> None: ...
    # Takes a Generator instance. The parameter is given a lowercase name so it
    # does not shadow the Generator type it is annotated with.
    def register_generator_state(self, generator: Generator) -> None: ...
    def replay(self) -> None: ...
    def reset(self) -> None: ...
    def pool(self) -> Tuple[_int, _int]: ...
    def enable_debug_mode(self) -> None: ...
    def debug_dump(self, debug_path: str) -> None: ...

def _cuda_isCurrentStreamCapturing() -> _bool: ...
def _graph_pool_handle() -> Tuple[_int, _int]: ...
# Defined in torch/csrc/xpu/Module.cpp
def _xpu_setDevice(device: _int) -> None: ...
def _xpu_exchangeDevice(device: _int) -> _int: ...
def _xpu_maybeExchangeDevice(device: _int) -> _int: ...
def _xpu_getDevice() -> _int: ...
def _xpu_getDeviceCount() -> _int: ...
def _xpu_init() -> None: ...
def _xpu_setStream(stream_id: _int, device_index: _int, device_type: _int) -> None: ...
# Returns the fixed (stream_id, device_index, device_type) triple, the same
# triple that _xpu_setStream accepts.
def _xpu_getCurrentStream(device: _int) -> Tuple[_int, _int, _int]: ...
def _xpu_getCurrentRawStream(device: _int) -> _int: ...
def _xpu_synchronize(device: _int) -> None: ...
def _xpu_emptyCache() -> None: ...

class _XpuDeviceProperties:
    name: str
    platform_name: str
    vendor: str
    driver_version: str
    version: str
    total_memory: _int
    max_compute_units: _int
    gpu_eu_count: _int
    gpu_subslice_count: _int
    max_work_group_size: _int
    max_num_sub_groups: _int
    sub_group_sizes: List[_int]
    has_fp16: _bool
    has_fp64: _bool
    has_atomic64: _bool
    type: str

# Defined in torch/csrc/xpu/Stream.cpp
class _XpuStreamBase(Stream):
    stream_id: _int
    device_index: _int
    device_type: _int

    device: _device
    sycl_queue: _int
    priority: _int

    def __new__(
        cls,
        priority: _int = 0,
        stream_id: _int = 0,
        device_index: _int = 0,
        device_type: _int = 0,
    ) -> _XpuStreamBase: ...
    def query(self) -> _bool: ...
    def synchronize(self) -> None: ...
    # A (least, greatest) priority pair, like _CudaStreamBase.priority_range.
    @staticmethod
    def priority_range() -> Tuple[_int, _int]: ...

# Defined in torch/csrc/xpu/Event.cpp
class _XpuEventBase:
    device: _device
    sycl_event: _int

    def __new__(cls, enable_timing: _bool = False) -> _XpuEventBase: ...
    # record and wait both operate against a stream, mirroring
    # _CudaEventBase.record/wait, which take a _CudaStreamBase.
    def record(self, stream: _XpuStreamBase) -> None: ...
    def wait(self, stream: _XpuStreamBase) -> None: ...
    def query(self) -> _bool: ...
    def elapsed_time(self, other: _XpuEventBase) -> _float: ...
    def synchronize(self) -> None: ...

# Defined in torch/csrc/DataLoader.cpp
def _set_worker_signal_handlers(
    *arg: Any,
) -> None: ...  # THPModule_setWorkerSignalHandlers
def _set_worker_pids(
    key: _int,
    child_pids: Tuple[_int, ...],
) -> None: ...  # THPModule_setWorkerPIDs
def _remove_worker_pids(loader_id: _int) -> None: ...  # THPModule_removeWorkerPIDs
def _error_if_any_worker_fails() -> None: ...  # THPModule_errorIfAnyWorkerFails

# Defined in torch/csrc/jit/python/python_tracer.cpp
class TracingState:
    def push_scope(self, scope_name: str) -> None: ...
    def pop_scope(self) -> None: ...
    def current_scope(self) -> str: ...
    def set_graph(self, graph: Graph) -> None: ...
    def graph(self) -> Graph: ...

def _create_graph_by_tracing(
    func: Callable[..., Any],
    inputs: Any,
    var_name_lookup_fn: Callable[[Tensor], str],
    strict: Any,
    force_outplace: Any,
    self: Any = None,
    argument_names: List[str] = [],
) -> Tuple[Graph, Stack]: ...
def _tracer_warn_use_python(): ...
def _get_tracing_state() -> TracingState: ...

# Defined in torch/csrc/jit/python/python_ir.cpp
# Not actually defined in python_ir.cpp, not sure where they are.
class IValue: ...

# A JIT stack is a list of IValues.
Stack = List[IValue]

class JitType:
    annotation_str: str
    def isSubtypeOf(self, other: JitType) -> _bool: ...
    def with_dtype(self, dtype: _dtype) -> JitType: ...
    def with_sizes(self, sizes: List[Optional[_int]]) -> JitType: ...
    def kind(self) -> str: ...
    def scalarType(self) -> Optional[str]: ...
    def getElementType(self) -> JitType: ...
    def dtype(self) -> Optional[_dtype]: ...

# Result of a type inference: either a JitType or a failure-reason string.
class InferredType:
    def __init__(self, arg: Union[JitType, str]) -> None: ...
    def type(self) -> JitType: ...
    def success(self) -> _bool: ...
    def reason(self) -> str: ...
# Type variable bounded by JitType; used to parameterize OptionalType below.
R = TypeVar("R", bound=JitType)

# Singleton-style JIT types: each exposes a static get() returning an instance
# of that class.
class AnyType(JitType):
    @staticmethod
    def get() -> AnyType: ...

class NoneType(JitType):
    @staticmethod
    def get() -> NoneType: ...

class BoolType(JitType):
    @staticmethod
    def get() -> BoolType: ...

class FloatType(JitType):
    @staticmethod
    def get() -> FloatType: ...

class ComplexType(JitType):
    @staticmethod
    def get() -> ComplexType: ...

class IntType(JitType):
    @staticmethod
    def get() -> IntType: ...

class SymIntType(JitType):
    @staticmethod
    def get() -> SymIntType: ...

class SymBoolType(JitType):
    @staticmethod
    def get() -> SymBoolType: ...

class NumberType(JitType):
    @staticmethod
    def get() -> NumberType: ...

class StringType(JitType):
    @staticmethod
    def get() -> StringType: ...

class DeviceObjType(JitType):
    @staticmethod
    def get() -> DeviceObjType: ...

class _GeneratorType(JitType):
    @staticmethod
    def get() -> _GeneratorType: ...

class StreamObjType(JitType):
    @staticmethod
    def get() -> StreamObjType: ...

# Container types are built from an element JitType; ListType also has static
# constructors for common element types.
class ListType(JitType):
    def __init__(self, a: JitType) -> None: ...
    def getElementType(self) -> JitType: ...
    @staticmethod
    def ofInts() -> ListType: ...
    @staticmethod
    def ofTensors() -> ListType: ...
    @staticmethod
    def ofFloats() -> ListType: ...
    @staticmethod
    def ofComplexDoubles() -> ListType: ...
    @staticmethod
    def ofBools() -> ListType: ...
    @staticmethod
    def ofStrings() -> ListType: ...

class DictType(JitType):
    def __init__(self, key: JitType, value: JitType) -> None: ...
    def getKeyType(self) -> JitType: ...
    def getValueType(self) -> JitType: ...

# The constructor accepts None entries in the element list, but elements() is
# typed as returning only JitTypes.
class TupleType(JitType):
    def __init__(self, a: List[Optional[JitType]]) -> None: ...
    def elements(self) -> List[JitType]: ...

class UnionType(JitType):
    def __init__(self, a: List[JitType]) -> None: ...

class ClassType(JitType):
    def __init__(self, qualified_name: str) -> None: ...

class InterfaceType(JitType):
    def __init__(self, qualified_name: str) -> None: ...
    def getMethod(self, name: str) -> Optional[FunctionSchema]: ...
    def getMethodNames(self) -> List[str]: ...

class OptionalType(JitType, Generic[R]):
    def __init__(self, a: JitType) -> None: ...
    def getElementType(self) -> JitType: ...
    @staticmethod
    def ofTensor() -> OptionalType: ...

class FutureType(JitType):
    def __init__(self, a: JitType) -> None: ...
    def getElementType(self) -> JitType: ...

class AwaitType(JitType):
    def __init__(self, a: JitType) -> None: ...
    def getElementType(self) -> JitType: ...

class RRefType(JitType):
    def __init__(self, a: JitType) -> None: ...

class EnumType(JitType):
    def __init__(
        self,
        qualified_name: str,
        value_type: JitType,
        enum_names_values: List[Any],
    ) -> None: ...

# Unlike the singleton types above, get()/getInferred() are classmethods here.
class TensorType(JitType):
    @classmethod
    def get(cls) -> TensorType: ...
    @classmethod
    def getInferred(cls) -> TensorType: ...
    # NOTE(review): this override names its parameter `other`, while
    # JitType.with_sizes names it `sizes`; keyword callers see different names.
    def with_sizes(self, other: Optional[List[Optional[_int]]]) -> TensorType: ...
    def sizes(self) -> Optional[List[_int]]: ...
    def varyingSizes(self) -> Optional[List[Optional[_int]]]: ...
    def strides(self) -> Optional[List[_int]]: ...
    def device(self) -> Optional[_device]: ...
    def dim(self) -> _int: ...
    def dtype(self) -> Optional[_dtype]: ...
    @staticmethod
    def create_from_tensor(t: Tensor) -> TensorType: ...

# Defined in torch/csrc/jit/python/python_tree_views.cpp
class SourceRange: ...
class TreeView: ...

# Identifier node; `name` is a read-only property here (Def.name is a method).
class Ident(TreeView):
    @property
    def name(self) -> str: ...
class ClassDef(TreeView): ...

class Def(TreeView):
    def name(self) -> Ident: ...

class Decl(TreeView): ...

# Defined in torch/csrc/distributed/rpc/init.cpp
def _rpc_init() -> _bool: ...

# Defined in torch/csrc/distributed/autograd/init.cpp
def _dist_autograd_init() -> _bool: ...

# Defined in torch/csrc/distributed/c10d/init.cpp
def _c10d_init() -> _bool: ...

# Defined in torch/csrc/distributed/rpc/testing/init.cpp
def _faulty_agent_init() -> _bool: ...
def _register_py_class_for_device(device: str, cls: Any) -> None: ...

# Defined in torch/csrc/Module.cpp
def _current_graph_task_id() -> _int: ...
def _current_autograd_node() -> _Node: ...
# Takes one Tensor and returns its dispatch key set rendered as a str. The
# parameter has a lowercase name and a real annotation, so it no longer shadows
# the Tensor type as an untyped parameter.
def _dispatch_key_set(tensor: Tensor) -> str: ...

# Defined in torch/csrc/Exceptions.cpp
class OutOfMemoryError(RuntimeError): ...
class _DistError(RuntimeError): ...
class _DistBackendError(RuntimeError): ...
class _DistStoreError(RuntimeError): ...
class _DistNetworkError(RuntimeError): ...

# Defined in torch/csrc/profiler/init.cpp
# Opaque captured traceback: produced by gather_traceback and consumed, as a
# list, by symbolize_tracebacks.
class CapturedTraceback:
    pass
def gather_traceback(python: _bool, script: _bool, cpp: _bool) -> CapturedTraceback: ...
def symbolize_tracebacks(tracebacks: List[CapturedTraceback]) -> List[Dict[str, Any]]: ...

# Mobile / JIT module (de)serialization helpers. Several return types are left
# unannotated in this stub.
def _load_mobile_module_from_file(filename: str): ...
def _load_mobile_module_from_bytes(bytes_: bytes): ...
def _load_jit_module_from_file(filename: str): ...
def _load_jit_module_from_bytes(bytes_: bytes): ...
def _save_mobile_module(m: LiteScriptModule, filename: str): ...
def _save_jit_module(m: ScriptModule, filename: str, extra_files: Dict[str, Any]): ...
def _save_mobile_module_to_bytes(m: LiteScriptModule) -> bytes: ...
def _save_jit_module_to_bytes(m: ScriptModule, extra_files: Dict[str, Any]) -> bytes: ...
def _get_module_info_from_flatbuffer(data: bytes): ...
def _jit_resolve_packet(op_name: str, *args, **kwargs) -> str: ...
def _swap_tensor_impl(t1: Tensor, t2: Tensor): ...
def _save_pickle(obj: Any) -> bytes: ...

# Defined in torch/csrc/jit/runtime/static/init.cpp
# Both entry points accept either a Graph or a ScriptModule.
def _jit_to_static_module(graph_or_module: Union[Graph, ScriptModule]) -> Any: ...
def _fuse_to_static_module(
    graph_or_module: Union[Graph, ScriptModule],
    min_size: _int,
) -> Any: ...

# Defined in torch/csrc/fx/node.cpp
# Node base with links to its neighbours (_prev / _next) and an erased flag.
class _NodeBase:
    _erased: _bool
    _prev: _NodeBase
    _next: _NodeBase

# Iterator yielding _NodeBase objects, starting from `root`; `reversed`
# presumably selects the direction of travel — confirm.
class _NodeIter(Iterator):
    def __init__(self, root: _NodeBase, reversed: _bool) -> None: ...
    def __iter__(self) -> Iterator[_NodeBase]: ...
    def __next__(self) -> _NodeBase: ...