# xref: /aosp_15_r20/external/pytorch/torch/_C/_autograd.pyi (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
from enum import Enum
from typing import Any, Callable

import torch
from torch._C._profiler import (
    _ProfilerEvent,
    ActiveProfilerType,
    ProfilerActivity,
    ProfilerConfig,
)

# Defined in torch/csrc/autograd/init.cpp

# Device kinds the profiler can attribute an event to; this is the return type
# of ``_KinetoEvent.device_type()``.  Member values are elided with ``...`` as
# usual in a stub -- the real enum comes from the native binding
# (torch/csrc/autograd/init.cpp).
# NOTE(review): presumably mirrors c10::DeviceType member-for-member; confirm
# against the binding before relying on member order or values.
class DeviceType(Enum):
    CPU = ...
    CUDA = ...
    XPU = ...
    MKLDNN = ...
    OPENGL = ...
    OPENCL = ...
    IDEEP = ...
    HIP = ...
    FPGA = ...
    MAIA = ...
    XLA = ...
    MTIA = ...
    MPS = ...
    HPU = ...
    Meta = ...
    Vulkan = ...
    Metal = ...
    PrivateUse1 = ...

# Event record produced by the legacy (non-Kineto) profiler.  Nested lists of
# these are returned by ``_disable_profiler_legacy()`` and by
# ``_ProfilerResult.legacy_events()``.
class ProfilerEvent:
    # Elapsed time between this event and ``other`` (the ``_us`` suffix
    # suggests microseconds).
    def cpu_elapsed_us(self, other: ProfilerEvent) -> float: ...
    def cpu_memory_usage(self) -> int: ...
    def cuda_elapsed_us(self, other: ProfilerEvent) -> float: ...
    def privateuse1_elapsed_us(self, other: ProfilerEvent) -> float: ...
    def cuda_memory_usage(self) -> int: ...
    def device(self) -> int: ...
    def handle(self) -> int: ...
    def has_cuda(self) -> bool: ...
    def is_remote(self) -> bool: ...
    # NOTE(review): the legacy profiler's Python-side parser appears to compare
    # kind() against strings such as "push"/"pop"/"mark"; verify the binding's
    # real return type before trusting ``int`` here.
    def kind(self) -> int: ...
    def name(self) -> str: ...
    def node_id(self) -> int: ...
    def sequence_nr(self) -> int: ...
    def shapes(self) -> list[list[int]]: ...
    def thread_id(self) -> int: ...
    # NOTE(review): _KinetoEvent.flops is annotated ``-> int`` while this one
    # says ``-> float``; confirm which type the native binding returns.
    def flops(self) -> float: ...
    def is_async(self) -> bool: ...

# Event recorded by the Kineto-based profiler; returned in bulk by
# ``_ProfilerResult.events()``.
class _KinetoEvent:
    def name(self) -> str: ...
    def device_index(self) -> int: ...
    def device_resource_id(self) -> int: ...
    # Timing accessors (the ``_ns`` suffix suggests nanoseconds).
    def start_ns(self) -> int: ...
    def end_ns(self) -> int: ...
    def duration_ns(self) -> int: ...
    def is_async(self) -> bool: ...
    def linked_correlation_id(self) -> int: ...
    # Input metadata (presumably one entry per op input -- confirm in init.cpp).
    def shapes(self) -> list[list[int]]: ...
    def dtypes(self) -> list[str]: ...
    def concrete_inputs(self) -> list[Any]: ...
    def kwinputs(self) -> dict[str, Any]: ...
    def device_type(self) -> DeviceType: ...
    def start_thread_id(self) -> int: ...
    def end_thread_id(self) -> int: ...
    def correlation_id(self) -> int: ...
    def fwd_thread_id(self) -> int: ...
    def stack(self) -> list[str]: ...
    def scope(self) -> int: ...
    def sequence_nr(self) -> int: ...
    # NOTE(review): ProfilerEvent.flops is annotated ``-> float``; confirm
    # which type the native binding actually returns.
    def flops(self) -> int: ...
    def cuda_elapsed_us(self) -> int: ...
    def privateuse1_elapsed_us(self) -> int: ...
    def is_user_annotation(self) -> bool: ...

# Result of a profiling session; returned by ``_disable_profiler()``.
class _ProfilerResult:
    def events(self) -> list[_KinetoEvent]: ...
    def legacy_events(self) -> list[list[ProfilerEvent]]: ...
    # Presumably writes the collected trace to ``path`` (format not declared here).
    def save(self, path: str) -> None: ...
    def experimental_event_tree(self) -> list[_ProfilerEvent]: ...
    def trace_start_ns(self) -> int: ...

# Opaque type exported by the autograd bindings; this stub declares no members.
class SavedTensor: ...

# Profiler control entry points.  Each takes a ``ProfilerConfig`` (or an
# ``enable`` flag) plus the set of ``ProfilerActivity`` kinds it applies to.
# NOTE(review): _prepare_profiler presumably does setup ahead of
# _enable_profiler; confirm the required call order in init.cpp.
def _enable_profiler(
    config: ProfilerConfig,
    activities: set[ProfilerActivity],
) -> None: ...
def _prepare_profiler(
    config: ProfilerConfig,
    activities: set[ProfilerActivity],
) -> None: ...
# Presumably switches collection on/off for ``activities`` while a session is live.
def _toggle_collection_dynamic(
    enable: bool,
    activities: set[ProfilerActivity],
) -> None: ...
# Presumably the counterpart of ``_enable_profiler``; hands back the collected
# results as a ``_ProfilerResult``.
def _disable_profiler() -> _ProfilerResult: ...
# Profiler state queries, record-function helpers and test hooks.
def _profiler_enabled() -> bool: ...
def _add_metadata_json(key: str, value: str) -> None: ...
def _kineto_step() -> None: ...
def _get_current_graph_task_keep_graph() -> bool: ...
def _get_sequence_nr() -> int: ...
def kineto_available() -> bool: ...
# The tensor returned by ``_record_function_with_args_enter`` is the ``handle``
# later passed to ``_record_function_with_args_exit`` (presumably closing the
# range opened under ``name``).  Extra positional args are explicitly ``Any``.
def _record_function_with_args_enter(name: str, *args: Any) -> torch.Tensor: ...
def _record_function_with_args_exit(handle: torch.Tensor) -> None: ...
def _supported_activities() -> set[ProfilerActivity]: ...
def _enable_record_function(enable: bool) -> None: ...
def _set_empty_test_observer(is_global: bool, sampling_prob: float) -> None: ...
# Saved-tensor default hooks: ``pack_hook`` takes a Tensor and returns any
# object; ``unpack_hook`` takes such an object and returns a Tensor.  The pop
# call presumably removes the most recently pushed pair.
def _push_saved_tensors_default_hooks(
    pack_hook: Callable[[torch.Tensor], Any],
    unpack_hook: Callable[[Any], torch.Tensor],
) -> None: ...
def _pop_saved_tensors_default_hooks() -> None: ...
# NOTE(review): "unsafe" presumably means the version counter of ``t`` is
# overwritten with ``prev_version`` without the usual checks; confirm.
def _unsafe_set_version_counter(t: torch.Tensor, prev_version: int) -> None: ...
# Legacy (non-Kineto) profiler; results come from ``_disable_profiler_legacy``.
def _enable_profiler_legacy(config: ProfilerConfig) -> None: ...
# Presumably the counterpart of ``_enable_profiler_legacy``; returns nested
# lists of ``ProfilerEvent`` (the meaning of the outer list -- e.g. per thread --
# is not visible here; confirm in init.cpp).
def _disable_profiler_legacy() -> list[list[ProfilerEvent]]: ...
def _profiler_type() -> ActiveProfilerType: ...
# Saved-tensor hook gating.  ``_saved_tensors_hooks_disable`` takes a message
# string and ``_saved_tensors_hooks_get_disabled_error_message`` returns
# ``str | None`` (presumably ``None`` while hooks are enabled -- confirm).
def _saved_tensors_hooks_enable() -> None: ...
def _saved_tensors_hooks_disable(message: str) -> None: ...
def _saved_tensors_hooks_get_disabled_error_message() -> str | None: ...
# NOTE(review): the bool result presumably reports the previous tracing state.
def _saved_tensors_hooks_set_tracing(is_tracing: bool) -> bool: ...

# Tag recording how a tensor was created; read and written through
# ``_get_creation_meta`` / ``_set_creation_meta``.  Member values are elided
# in this stub; the real enum comes from the native binding.
class CreationMeta(Enum):
    DEFAULT = ...
    IN_CUSTOM_FUNCTION = ...
    MULTI_OUTPUT_NODE = ...
    NO_GRAD_MODE = ...
    INFERENCE_MODE = ...

# Setter/getter pair for the ``CreationMeta`` associated with tensor ``t``.
def _set_creation_meta(t: torch.Tensor, creation_meta: CreationMeta) -> None: ...
def _get_creation_meta(t: torch.Tensor) -> CreationMeta: ...
