# mypy: allow-untyped-defs
from enum import Enum
from typing import Any, Callable

import torch
from torch._C._profiler import (
    _ProfilerEvent,
    ActiveProfilerType,
    ProfilerActivity,
    ProfilerConfig,
)

# Defined in torch/csrc/autograd/init.cpp

# Device kinds; this is the return type of `_KinetoEvent.device_type()`.
# Member values are elided (`...`) in this stub.
class DeviceType(Enum):
    CPU = ...
    CUDA = ...
    XPU = ...
    MKLDNN = ...
    OPENGL = ...
    OPENCL = ...
    IDEEP = ...
    HIP = ...
    FPGA = ...
    MAIA = ...
    XLA = ...
    MTIA = ...
    MPS = ...
    HPU = ...
    Meta = ...
    Vulkan = ...
    Metal = ...
    PrivateUse1 = ...

# Event type returned (as nested lists) by `_ProfilerResult.legacy_events()`
# and `_disable_profiler_legacy()`.
class ProfilerEvent:
    # The three `*_elapsed_us` methods take a second event (`other`) and
    # return a float.
    def cpu_elapsed_us(self, other: ProfilerEvent) -> float: ...
    def cpu_memory_usage(self) -> int: ...
    def cuda_elapsed_us(self, other: ProfilerEvent) -> float: ...
    def privateuse1_elapsed_us(self, other: ProfilerEvent) -> float: ...
    def cuda_memory_usage(self) -> int: ...
    def device(self) -> int: ...
    def handle(self) -> int: ...
    def has_cuda(self) -> bool: ...
    def is_remote(self) -> bool: ...
    def kind(self) -> int: ...
    def name(self) -> str: ...
    def node_id(self) -> int: ...
    def sequence_nr(self) -> int: ...
    def shapes(self) -> list[list[int]]: ...
    def thread_id(self) -> int: ...
    def flops(self) -> float: ...
    def is_async(self) -> bool: ...

# Event type returned by `_ProfilerResult.events()`.
class _KinetoEvent:
    def name(self) -> str: ...
    def device_index(self) -> int: ...
    def device_resource_id(self) -> int: ...
    def start_ns(self) -> int: ...
    def end_ns(self) -> int: ...
    def duration_ns(self) -> int: ...
    def is_async(self) -> bool: ...
    def linked_correlation_id(self) -> int: ...
    def shapes(self) -> list[list[int]]: ...
    def dtypes(self) -> list[str]: ...
    def concrete_inputs(self) -> list[Any]: ...
    def kwinputs(self) -> dict[str, Any]: ...
    def device_type(self) -> DeviceType: ...
    def start_thread_id(self) -> int: ...
    def end_thread_id(self) -> int: ...
    def correlation_id(self) -> int: ...
    def fwd_thread_id(self) -> int: ...
    def stack(self) -> list[str]: ...
    def scope(self) -> int: ...
    def sequence_nr(self) -> int: ...
    # Unlike `ProfilerEvent`, `flops` and the `*_elapsed_us` methods here take
    # no `other` argument and are declared to return int.
    def flops(self) -> int: ...
    def cuda_elapsed_us(self) -> int: ...
    def privateuse1_elapsed_us(self) -> int: ...
    def is_user_annotation(self) -> bool: ...

# Result object returned by `_disable_profiler()`.
class _ProfilerResult:
    def events(self) -> list[_KinetoEvent]: ...
    def legacy_events(self) -> list[list[ProfilerEvent]]: ...
    def save(self, path: str) -> None: ...
    def experimental_event_tree(self) -> list[_ProfilerEvent]: ...
    def trace_start_ns(self) -> int: ...

# No members are declared for this class in the stub.
class SavedTensor: ...

# Module-level functions. `_enable_profiler` and `_prepare_profiler` share the
# same (config, activities) signature.
def _enable_profiler(
    config: ProfilerConfig,
    activities: set[ProfilerActivity],
) -> None: ...
def _prepare_profiler(
    config: ProfilerConfig,
    activities: set[ProfilerActivity],
) -> None: ...
def _toggle_collection_dynamic(
    enable: bool,
    activities: set[ProfilerActivity],
) -> None: ...
def _disable_profiler() -> _ProfilerResult: ...
def _profiler_enabled() -> bool: ...
def _add_metadata_json(key: str, value: str) -> None: ...
def _kineto_step() -> None: ...
def _get_current_graph_task_keep_graph() -> bool: ...
def _get_sequence_nr() -> int: ...
def kineto_available() -> bool: ...
# `_record_function_with_args_enter` returns a tensor; the matching `_exit`
# takes a tensor `handle`. `*args` is intentionally left unannotated.
def _record_function_with_args_enter(name: str, *args) -> torch.Tensor: ...
def _record_function_with_args_exit(handle: torch.Tensor) -> None: ...
def _supported_activities() -> set[ProfilerActivity]: ...
def _enable_record_function(enable: bool) -> None: ...
def _set_empty_test_observer(is_global: bool, sampling_prob: float) -> None: ...
# `pack_hook` maps a tensor to an arbitrary object; `unpack_hook` maps such an
# object back to a tensor.
def _push_saved_tensors_default_hooks(
    pack_hook: Callable[[torch.Tensor], Any],
    unpack_hook: Callable[[Any], torch.Tensor],
) -> None: ...
def _pop_saved_tensors_default_hooks() -> None: ...
def _unsafe_set_version_counter(t: torch.Tensor, prev_version: int) -> None: ...
def _enable_profiler_legacy(config: ProfilerConfig) -> None: ...
# Returns nested lists of `ProfilerEvent`, the same shape as
# `_ProfilerResult.legacy_events()`.
def _disable_profiler_legacy() -> list[list[ProfilerEvent]]: ...
def _profiler_type() -> ActiveProfilerType: ...
def _saved_tensors_hooks_enable() -> None: ...
def _saved_tensors_hooks_disable(message: str) -> None: ...
# Declared to return either a string or None.
def _saved_tensors_hooks_get_disabled_error_message() -> str | None: ...
# Takes a bool and returns a bool.
def _saved_tensors_hooks_set_tracing(is_tracing: bool) -> bool: ...

# Values accepted by `_set_creation_meta` and returned by `_get_creation_meta`.
# Member values are elided (`...`) in this stub.
class CreationMeta(Enum):
    DEFAULT = ...
    IN_CUSTOM_FUNCTION = ...
    MULTI_OUTPUT_NODE = ...
    NO_GRAD_MODE = ...
    INFERENCE_MODE = ...

def _set_creation_meta(t: torch.Tensor, creation_meta: CreationMeta) -> None: ...
def _get_creation_meta(t: torch.Tensor) -> CreationMeta: ...