xref: /aosp_15_r20/external/pytorch/torch/cpu/__init__.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2r"""
3This package implements abstractions found in ``torch.cuda``
4to facilitate writing device-agnostic code.
5"""
6
7from contextlib import AbstractContextManager
8from typing import Any, Optional, Union
9
10import torch
11
12from .. import device as _device
13from . import amp
14
15
# Public names exported by ``from torch.cpu import *``; they mirror the
# ``torch.cuda`` entry points so device modules can be swapped in user code.
__all__ = [
    # device / availability queries
    "is_available",
    "synchronize",
    "current_device",
    "current_stream",
    "stream",
    "set_device",
    "device_count",
    # stream and event stand-ins
    "Stream",
    "StreamContext",
    "Event",
]
28
# Device specifier accepted by the CUDA-shaped signatures in this module
# (a torch.device, a string, an index, or None).
_device_t = Optional[Union[_device, str, int]]
30
31
32def _is_avx2_supported() -> bool:
33    r"""Returns a bool indicating if CPU supports AVX2."""
34    return torch._C._cpu._is_avx2_supported()
35
36
37def _is_avx512_supported() -> bool:
38    r"""Returns a bool indicating if CPU supports AVX512."""
39    return torch._C._cpu._is_avx512_supported()
40
41
42def _is_avx512_bf16_supported() -> bool:
43    r"""Returns a bool indicating if CPU supports AVX512_BF16."""
44    return torch._C._cpu._is_avx512_bf16_supported()
45
46
47def _is_vnni_supported() -> bool:
48    r"""Returns a bool indicating if CPU supports VNNI."""
49    # Note: Currently, it only checks avx512_vnni, will add the support of avx2_vnni later.
50    return torch._C._cpu._is_avx512_vnni_supported()
51
52
53def _is_amx_tile_supported() -> bool:
54    r"""Returns a bool indicating if CPU supports AMX_TILE."""
55    return torch._C._cpu._is_amx_tile_supported()
56
57
58def _init_amx() -> bool:
59    r"""Initializes AMX instructions."""
60    return torch._C._cpu._init_amx()
61
62
def is_available() -> bool:
    r"""Report whether the CPU backend can be used; it always can, so this is ``True``.

    N.B. This function only exists to facilitate device-agnostic code
    """
    available = True
    return available
70
71
def synchronize(device: _device_t = None) -> None:
    r"""Block until all kernels in all streams on the CPU device have finished.

    This implementation performs no work and returns immediately.

    Args:
        device (torch.device or int, optional): ignored, there's only one CPU device.

    N.B. This function only exists to facilitate device-agnostic code.
    """
    del device  # intentionally unused
80
81
class Stream:
    r"""Stand-in for :class:`torch.cuda.Stream` on the CPU.

    Every operation treats CPU work as already complete: waits and
    synchronization are no-ops and :meth:`query` always returns ``True``.
    The CUDA stream methods are provided so device-agnostic code such as
    ``current_stream().synchronize()`` also runs on the CPU.

    N.B. This class only exists to facilitate device-agnostic code
    """

    def __init__(self, priority: int = -1) -> None:
        # ``priority`` is accepted for signature compatibility and ignored.
        pass

    def wait_stream(self, stream) -> None:
        r"""Make this stream wait for ``stream``; a no-op here."""

    def wait_event(self, event) -> None:
        r"""Make this stream wait for ``event``; a no-op here."""

    def record_event(self, event=None):
        r"""Record ``event`` on this stream and return it.

        Args:
            event (Event, optional): event to record; a new :class:`Event`
                is created when ``None``.

        Returns:
            The recorded event.
        """
        if event is None:
            event = Event()
        event.record(self)
        return event

    def query(self) -> bool:
        r"""Return ``True``: there is never pending work on this stream."""
        return True

    def synchronize(self) -> None:
        r"""Wait for all work on this stream to complete; a no-op here."""
92
93
class Event:
    r"""Stand-in for :class:`torch.cuda.Event` on the CPU.

    The constructor accepts the same keyword arguments as
    :class:`torch.cuda.Event`, so device-agnostic code can write
    ``Event(enable_timing=True)``. Queries, synchronization and waits report
    completion immediately. When ``enable_timing`` is set, :meth:`record`
    captures a host timestamp that :meth:`elapsed_time` compares.

    N.B. This class only exists to facilitate device-agnostic code
    """

    def __init__(
        self,
        enable_timing: bool = False,
        blocking: bool = False,
        interprocess: bool = False,
    ) -> None:
        # ``blocking`` and ``interprocess`` are stored only so the attribute
        # set matches the CUDA event API; nothing here reads them.
        self.enable_timing = enable_timing
        self.blocking = blocking
        self.interprocess = interprocess
        # Host time (seconds, perf_counter) of the last record(); None until recorded.
        self._timestamp: Optional[float] = None

    def query(self) -> bool:
        r"""Return ``True``: the event is always considered complete."""
        return True

    def record(self, stream=None) -> None:
        r"""Record the event; ``stream`` is ignored.

        When timing is enabled, the current host time is saved for
        :meth:`elapsed_time`.
        """
        if self.enable_timing:
            import time

            self._timestamp = time.perf_counter()

    def synchronize(self) -> None:
        r"""Wait for the event to complete; a no-op here."""

    def wait(self, stream=None) -> None:
        r"""Make ``stream`` wait for this event; a no-op here."""

    def elapsed_time(self, end_event: "Event") -> float:
        r"""Return the milliseconds elapsed between this event and ``end_event``.

        Args:
            end_event (Event): an event recorded after this one.

        Raises:
            RuntimeError: if either event was created without
                ``enable_timing=True`` or has not been recorded yet.
        """
        if not (self.enable_timing and end_event.enable_timing):
            raise RuntimeError(
                "Both events must be created with argument 'enable_timing=True'."
            )
        if self._timestamp is None or end_event._timestamp is None:
            raise RuntimeError(
                "Both events must be recorded before calculating elapsed time."
            )
        return (end_event._timestamp - self._timestamp) * 1000.0
106
107
# One shared Stream instance serves as both the default and the initially
# current stream; StreamContext rebinds ``_current_stream`` while active.
_current_stream = _default_cpu_stream = Stream()
110
111
def current_stream(device: _device_t = None) -> Stream:
    r"""Return the :class:`Stream` currently selected on the CPU.

    Args:
        device (torch.device or int, optional): Ignored.

    N.B. This function only exists to facilitate device-agnostic code
    """
    del device  # intentionally unused
    return _current_stream
122
123
class StreamContext(AbstractContextManager):
    r"""Context manager that makes ``stream`` the current CPU stream.

    Inside the ``with`` body, :func:`current_stream` returns ``stream``; on
    exit, the stream that was current at entry is reinstated. Passing
    ``None`` turns both entry and exit into no-ops.

    N.B. This class only exists to facilitate device-agnostic code
    """

    # Declared but never assigned by this class; the selected stream lives
    # in ``self.stream``.
    cur_stream: Optional[Stream]

    def __init__(self, stream):
        # Default restore target; __enter__ overwrites it with the stream
        # that is current at entry time.
        self.prev_stream = _default_cpu_stream
        self.stream = stream

    def __enter__(self):
        global _current_stream
        if self.stream is not None:
            self.prev_stream = _current_stream
            _current_stream = self.stream

    def __exit__(self, type: Any, value: Any, traceback: Any) -> None:
        global _current_stream
        if self.stream is not None:
            _current_stream = self.prev_stream
153
154
def stream(stream: Stream) -> AbstractContextManager:
    r"""Build a :class:`StreamContext` that selects ``stream`` for a ``with`` block.

    N.B. This function only exists to facilitate device-agnostic code
    """
    context = StreamContext(stream)
    return context
162
163
def device_count() -> int:
    r"""Return how many CPU devices exist (devices, not cores); always 1.

    N.B. This function only exists to facilitate device-agnostic code
    """
    cpu_devices = 1
    return cpu_devices
170
171
def set_device(device: _device_t) -> None:
    r"""Select the current device; on the CPU there is nothing to select.

    Args:
        device (torch.device, str, int or None): ignored.

    N.B. This function only exists to facilitate device-agnostic code
    """
    del device  # intentionally unused
177
178
def current_device() -> str:
    r"""Return the name of the current CPU device, which is always ``'cpu'``.

    N.B. This function only exists to facilitate device-agnostic code
    """
    device_name = "cpu"
    return device_name
185