# mypy: allow-untyped-defs
r"""
This package enables an interface for accessing the MTIA backend in Python.
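
Example (illustrative sketch; assumes an MTIA-enabled build with at least one
visible device)::

    import torch

    if torch.mtia.is_available():
        torch.mtia.set_device(0)          # select device 0
        s = torch.mtia.current_stream()   # stream of the current device
        torch.mtia.synchronize()          # wait for all queued work to finish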
"""

import threading
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import device as _device, Tensor
from torch._utils import _dummy_type, _LazySeedTracker, classproperty
from torch.types import Device

from ._utils import _get_device_index


_device_t = Union[_device, str, int, None]

# torch.mtia.Event/Stream are aliases of torch.Event/Stream
Event = torch.Event
Stream = torch.Stream

_initialized = False
_queued_calls: List[
    Tuple[Callable[[], None], List[str]]
] = []  # don't invoke these until initialization occurs
_tls = threading.local()
_initialization_lock = threading.Lock()
_lazy_seed_tracker = _LazySeedTracker()


def init():
    _lazy_init()


def is_initialized():
    r"""Return whether PyTorch's MTIA state has been initialized."""
    return _initialized and not _is_in_bad_fork()


def _is_in_bad_fork() -> bool:
    return torch._C._mtia_isInBadFork()


def _lazy_init() -> None:
    global _initialized, _queued_calls
    if is_initialized() or hasattr(_tls, "is_initializing"):
        return
    with _initialization_lock:
        # We be double-checking locking, boys! This is OK because
        # the above test was GIL protected anyway. The inner test
        # is for when a thread blocked on some other thread which was
        # doing the initialization; when they get the lock, they will
        # find there is nothing left to do.
        if is_initialized():
            return
        # It is important to prevent other threads from entering _lazy_init
        # immediately, while we are still guaranteed to have the GIL, because some
        # of the C calls we make below will release the GIL
        if _is_in_bad_fork():
            raise RuntimeError(
                "Cannot re-initialize MTIA in forked subprocess. To use MTIA with "
                "multiprocessing, you must use the 'spawn' start method"
            )
        if not _is_compiled():
            raise AssertionError(
                "Torch not compiled with MTIA enabled. "
                "Ensure you have `import mtia.host_runtime.torch_mtia` in your "
                "Python source file and include "
                "`//mtia/host_runtime/torch_mtia:torch_mtia` as your target dependency!"
            )

        torch._C._mtia_init()
        # Some of the queued calls may reentrantly call _lazy_init();
        # we need to just return without initializing in that case.
        # However, we must not let any *other* threads in!
        _tls.is_initializing = True

        for calls in _lazy_seed_tracker.get_calls():
            if calls:
                _queued_calls.append(calls)

        try:
            for queued_call, orig_traceback in _queued_calls:
                try:
                    queued_call()
                except Exception as e:
                    msg = (
                        f"MTIA call failed lazily at initialization with error: {str(e)}\n\n"
                        f"MTIA call was originally invoked at:\n\n{''.join(orig_traceback)}"
                    )
                    raise DeferredMtiaCallError(msg) from e
        finally:
            delattr(_tls, "is_initializing")
        _initialized = True


class DeferredMtiaCallError(Exception):
    pass


def _is_compiled() -> bool:
    r"""Return true if compiled with MTIA support."""
    return torch._C._mtia_isBuilt()


def is_available() -> bool:
    r"""Return true if an MTIA device is available."""
    if not _is_compiled():
        return False
    # MTIA has to initialize devices first to know whether any devices are available.
    return device_count() > 0


def synchronize(device: Optional[_device_t] = None) -> None:
    r"""Waits for all jobs in all streams on an MTIA device to complete.
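
    Example (illustrative sketch; assumes an MTIA-enabled build)::

        >>> # xdoctest: +SKIP("requires MTIA hardware")
        >>> torch.mtia.synchronize()   # drain all streams on the current device
        >>> torch.mtia.synchronize(0)  # drain all streams on device 0
    """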
    with torch.mtia.device(device):
        return torch._C._mtia_deviceSynchronize()


def device_count() -> int:
    r"""Return the number of MTIA devices available."""
    return torch._C._accelerator_hooks_device_count()


def current_device() -> int:
    r"""Return the index of the currently selected device."""
    return torch._C._accelerator_hooks_get_current_device()


def current_stream(device: Optional[_device_t] = None) -> Stream:
    r"""Return the currently selected :class:`Stream` for a given device.

    Args:
        device (torch.device or int, optional): selected device. Returns
            the currently selected :class:`Stream` for the current device, given
            by :func:`~torch.mtia.current_device`, if :attr:`device` is ``None``
            (default).
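
    Example (illustrative sketch; assumes an MTIA-enabled build)::

        >>> # xdoctest: +SKIP("requires MTIA hardware")
        >>> s = torch.mtia.current_stream()    # stream on the current device
        >>> s0 = torch.mtia.current_stream(0)  # stream on device 0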
    """
    return torch._C._mtia_getCurrentStream(_get_device_index(device, optional=True))


def default_stream(device: Optional[_device_t] = None) -> Stream:
    r"""Return the default :class:`Stream` for a given device.

    Args:
        device (torch.device or int, optional): selected device. Returns
            the default :class:`Stream` for the current device, given by
            :func:`~torch.mtia.current_device`, if :attr:`device` is ``None``
            (default).
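
    Example (illustrative sketch; assumes an MTIA-enabled build)::

        >>> # xdoctest: +SKIP("requires MTIA hardware")
        >>> torch.mtia.set_stream(torch.mtia.default_stream())  # back to the default stream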
    """
    return torch._C._mtia_getDefaultStream(_get_device_index(device, optional=True))


def memory_stats(device: Optional[_device_t] = None) -> Dict[str, Any]:
    r"""Return a dictionary of MTIA memory allocator statistics for a given device.

    Args:
        device (torch.device or int, optional): selected device. Returns
            statistics for the current device, given by :func:`~torch.mtia.current_device`,
            if :attr:`device` is ``None`` (default).
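
    Example (illustrative sketch; the exact key names are backend-defined)::

        >>> # xdoctest: +SKIP("requires MTIA hardware")
        >>> for name, value in torch.mtia.memory_stats().items():
        ...     print(name, value)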
    """
    if not is_initialized():
        return {}
    return torch._C._mtia_memoryStats(_get_device_index(device, optional=True))


def set_stream(stream: Stream):
    r"""Set the current stream. This is a wrapper API to set the stream.
        Usage of this function is discouraged in favor of the ``stream``
        context manager.

    Args:
        stream (Stream): selected stream. This function is a no-op
            if this argument is ``None``.
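
    Example (illustrative sketch; prefer :func:`torch.mtia.stream`)::

        >>> # xdoctest: +SKIP("requires MTIA hardware")
        >>> torch.mtia.set_stream(torch.mtia.default_stream())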
    """
    if stream is None:
        return
    torch._C._mtia_setCurrentStream(stream)


def set_device(device: _device_t) -> None:
    r"""Set the current device.

    Args:
        device (torch.device or int): selected device. This function is a no-op
            if this argument is negative.
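
    Example (illustrative sketch; assumes at least one MTIA device)::

        >>> # xdoctest: +SKIP("requires MTIA hardware")
        >>> torch.mtia.set_device(0)
        >>> assert torch.mtia.current_device() == 0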
    """
    device = _get_device_index(device)
    if device >= 0:
        torch._C._accelerator_hooks_set_current_device(device)


class device:
    r"""Context-manager that changes the selected device.

    Args:
        device (torch.device or int): device index to select. It's a no-op if
            this argument is a negative integer or ``None``.
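
    Example (illustrative sketch; assumes at least two MTIA devices)::

        >>> # xdoctest: +SKIP("requires MTIA hardware")
        >>> with torch.mtia.device(1):
        ...     assert torch.mtia.current_device() == 1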
    """

    def __init__(self, device: Any):
        self.idx = _get_device_index(device, optional=True)
        self.prev_idx = -1

    def __enter__(self):
        # Swap in the target device and remember the previous one
        # (no-op if the index is negative).
        self.prev_idx = torch._C._accelerator_hooks_maybe_exchange_device(self.idx)

    def __exit__(self, type: Any, value: Any, traceback: Any):
        # Restore the previously selected device; don't suppress exceptions.
        self.idx = torch._C._accelerator_hooks_maybe_exchange_device(self.prev_idx)
        return False


class StreamContext:
    r"""Context-manager that selects a given stream.

    All MTIA kernels queued within its context will be enqueued on a selected
    stream.

    Args:
        stream (Stream): selected stream. This manager is a no-op if it's
            ``None``.

    .. note:: Streams are per-device.
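
    Example (illustrative sketch; normally constructed via :func:`torch.mtia.stream`)::

        >>> # xdoctest: +SKIP("requires MTIA hardware")
        >>> with torch.mtia.StreamContext(torch.mtia.default_stream()):
        ...     pass  # work queued here targets the selected stream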
    """

    cur_stream: Optional["torch.mtia.Stream"]

    def __init__(self, stream: Optional["torch.mtia.Stream"]):
        self.cur_stream = None
        self.stream = stream
        self.idx = _get_device_index(None, True)
        if not torch.jit.is_scripting():
            if self.idx is None:
                self.idx = -1

        self.src_prev_stream = (
            None if not torch.jit.is_scripting() else torch.mtia.default_stream(None)
        )
        self.dst_prev_stream = (
            None if not torch.jit.is_scripting() else torch.mtia.default_stream(None)
        )

    def __enter__(self):
        # Local cur_stream variable for type refinement
        cur_stream = self.stream
        # Return if stream is None or MTIA device not available
        if cur_stream is None or self.idx == -1:
            return
        self.src_prev_stream = torch.mtia.current_stream(None)

        # If the stream is not on the current device, then
        # set the current stream on the device
        if self.src_prev_stream.device != cur_stream.device:
            with device(cur_stream.device):
                self.dst_prev_stream = torch.mtia.current_stream(cur_stream.device)
        torch.mtia.set_stream(cur_stream)

    def __exit__(self, type: Any, value: Any, traceback: Any):
        # Local cur_stream variable for type refinement
        cur_stream = self.stream
        # If stream is None or no MTIA device available, return
        if cur_stream is None or self.idx == -1:
            return

        # Reset the stream on the original device
        # and destination device
        if self.src_prev_stream.device != cur_stream.device:  # type: ignore[union-attr]
            torch.mtia.set_stream(self.dst_prev_stream)  # type: ignore[arg-type]
        torch.mtia.set_stream(self.src_prev_stream)  # type: ignore[arg-type]


def stream(stream: Optional["torch.mtia.Stream"]) -> StreamContext:
    r"""Wrap around the context-manager :class:`StreamContext` that selects a given stream.

    Arguments:
        stream (Stream): selected stream. This manager is a no-op if it's
            ``None``.

    .. note:: In eager mode ``stream`` is a :class:`Stream` object;
        ``torch.mtia.stream`` is not supported in JIT.
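
    Example (illustrative sketch; assumes an MTIA-enabled build)::

        >>> # xdoctest: +SKIP("requires MTIA hardware")
        >>> s = torch.mtia.current_stream()
        >>> with torch.mtia.stream(s):
        ...     pass  # kernels launched here are enqueued on ``s``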
    """
    return StreamContext(stream)


def get_rng_state(device: Union[int, str, torch.device] = "mtia") -> Tensor:
    r"""Returns the random number generator state as a ByteTensor.

    Args:
        device (torch.device or int, optional): The device to return the RNG state of.
            Default: ``'mtia'`` (i.e., ``torch.device('mtia')``, the current mtia device).
    """
    warnings.warn(
        "get_rng_state is not implemented in torch.mtia",
        UserWarning,
        stacklevel=2,
    )
    # Stub: RNG state is not implemented for MTIA yet; return a placeholder.
    return torch.zeros([1], dtype=torch.uint8, device=device)


def set_rng_state(
    new_state: Tensor, device: Union[int, str, torch.device] = "mtia"
) -> None:
    r"""Sets the random number generator state.

    Args:
        new_state (torch.ByteTensor): The desired state.
        device (torch.device or int, optional): The device to set the RNG state.
            Default: ``'mtia'`` (i.e., ``torch.device('mtia')``, the current mtia device).
    """
    warnings.warn(
        "set_rng_state is not implemented in torch.mtia",
        UserWarning,
        stacklevel=2,
    )


__all__ = [
    "init",
    "is_available",
    "is_initialized",
    "synchronize",
    "device_count",
    "current_device",
    "current_stream",
    "default_stream",
    "memory_stats",
    "set_device",
    "set_stream",
    "stream",
    "device",
    "set_rng_state",
    "get_rng_state",
]