# mypy: allow-untyped-defs
r"""
This package enables an interface for accessing MTIA backend in python
"""

import threading
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import device as _device, Tensor
from torch._utils import _dummy_type, _LazySeedTracker, classproperty
from torch.types import Device

from ._utils import _get_device_index


_device_t = Union[_device, str, int, None]

# torch.mtia.Event/Stream is alias of torch.Event/Stream
Event = torch.Event
Stream = torch.Stream

_initialized = False
_queued_calls: List[
    Tuple[Callable[[], None], List[str]]
] = []  # don't invoke these until initialization occurs
_tls = threading.local()
_initialization_lock = threading.Lock()
_lazy_seed_tracker = _LazySeedTracker()


def init() -> None:
    r"""Initialize PyTorch's MTIA state. Safe to call more than once."""
    _lazy_init()


def is_initialized():
    r"""Return whether PyTorch's MTIA state has been initialized."""
    # A forked child inherits ``_initialized = True`` from its parent, but the
    # runtime is unusable there, so report uninitialized in that case.
    return _initialized and not _is_in_bad_fork()


def _is_in_bad_fork() -> bool:
    r"""Return True if this process is a forked child of an MTIA-initialized parent."""
    return torch._C._mtia_isInBadFork()


def _lazy_init() -> None:
    r"""Initialize the MTIA runtime once, replaying any queued lazy calls.

    Raises:
        RuntimeError: if called from a forked subprocess after the parent
            process already initialized MTIA.
        AssertionError: if torch was not compiled with MTIA support.
        DeferredMtiaCallError: if a lazily-queued call fails during replay.
    """
    global _initialized, _queued_calls
    if is_initialized() or hasattr(_tls, "is_initializing"):
        return
    with _initialization_lock:
        # We be double-checking locking, boys! This is OK because
        # the above test was GIL protected anyway. The inner test
        # is for when a thread blocked on some other thread which was
        # doing the initialization; when they get the lock, they will
        # find there is nothing left to do.
        if is_initialized():
            return
        # It is important to prevent other threads from entering _lazy_init
        # immediately, while we are still guaranteed to have the GIL, because some
        # of the C calls we make below will release the GIL
        if _is_in_bad_fork():
            raise RuntimeError(
                "Cannot re-initialize MTIA in forked subprocess. To use MTIA with "
                "multiprocessing, you must use the 'spawn' start method"
            )
        if not _is_compiled():
            raise AssertionError(
                "Torch not compiled with MTIA enabled. "
                "Ensure you have `import mtia.host_runtime.torch_mtia` in your python "
                "src file and include `//mtia/host_runtime/torch_mtia:torch_mtia` as "
                "your target dependency!"
            )

        torch._C._mtia_init()
        # Some of the queued calls may reentrantly call _lazy_init();
        # we need to just return without initializing in that case.
        # However, we must not let any *other* threads in!
        _tls.is_initializing = True

        for calls in _lazy_seed_tracker.get_calls():
            if calls:
                _queued_calls.append(calls)

        try:
            for queued_call, orig_traceback in _queued_calls:
                try:
                    queued_call()
                except Exception as e:
                    msg = (
                        f"MTIA call failed lazily at initialization with error: {str(e)}\n\n"
                        f"MTIA call was originally invoked at:\n\n{''.join(orig_traceback)}"
                    )
                    raise DeferredMtiaCallError(msg) from e
        finally:
            delattr(_tls, "is_initializing")
        _initialized = True


class DeferredMtiaCallError(Exception):
    r"""Raised when a lazily-queued MTIA call fails during initialization replay."""

    pass


def _is_compiled() -> bool:
    r"""Return true if compiled with MTIA support."""
    return torch._C._mtia_isBuilt()


def is_available() -> bool:
    r"""Return true if MTIA device is available."""
    if not _is_compiled():
        return False
    # MTIA has to init devices first to know if there is any devices available.
    return device_count() > 0


def synchronize(device: Optional[_device_t] = None) -> None:
    r"""Waits for all jobs in all streams on a MTIA device to complete.

    Args:
        device (torch.device or int, optional): selected device. Synchronizes
            the current device, given by :func:`~torch.mtia.current_device`,
            if :attr:`device` is ``None`` (default).
    """
    with torch.mtia.device(device):
        return torch._C._mtia_deviceSynchronize()


def device_count() -> int:
    r"""Return the number of MTIA devices available."""
    return torch._C._accelerator_hooks_device_count()


def current_device() -> int:
    r"""Return the index of a currently selected device."""
    return torch._C._accelerator_hooks_get_current_device()


def current_stream(device: Optional[_device_t] = None) -> Stream:
    r"""Return the currently selected :class:`Stream` for a given device.

    Args:
        device (torch.device or int, optional): selected device. Returns
            the currently selected :class:`Stream` for the current device, given
            by :func:`~torch.mtia.current_device`, if :attr:`device` is ``None``
            (default).
    """
    return torch._C._mtia_getCurrentStream(_get_device_index(device, optional=True))


def default_stream(device: Optional[_device_t] = None) -> Stream:
    r"""Return the default :class:`Stream` for a given device.

    Args:
        device (torch.device or int, optional): selected device. Returns
            the default :class:`Stream` for the current device, given by
            :func:`~torch.mtia.current_device`, if :attr:`device` is ``None``
            (default).
    """
    return torch._C._mtia_getDefaultStream(_get_device_index(device, optional=True))


def memory_stats(device: Optional[_device_t] = None) -> Dict[str, Any]:
    r"""Return a dictionary of MTIA memory allocator statistics for a given device.

    Args:
        device (torch.device or int, optional): selected device. Returns
            statistics for the current device, given by current_device(),
            if device is None (default).
    """
    # Avoid triggering lazy initialization just to report statistics: with no
    # initialized state there is nothing to report.
    if not is_initialized():
        return {}
    return torch._C._mtia_memoryStats(_get_device_index(device, optional=True))


def set_stream(stream: Stream) -> None:
    r"""Set the current stream. This is a wrapper API to set the stream.
        Usage of this function is discouraged in favor of the ``stream``
        context manager.

    Args:
        stream (Stream): selected stream. This function is a no-op
            if this argument is ``None``.
    """
    if stream is None:
        return
    torch._C._mtia_setCurrentStream(stream)


def set_device(device: _device_t) -> None:
    r"""Set the current device.

    Args:
        device (torch.device or int): selected device. This function is a no-op
            if this argument is negative.
    """
    device = _get_device_index(device)
    if device >= 0:
        torch._C._accelerator_hooks_set_current_device(device)


class device:
    r"""Context-manager that changes the selected device.

    Args:
        device (torch.device or int): device index to select. It's a no-op if
            this argument is a negative integer or ``None``.
    """

    def __init__(self, device: Any):
        self.idx = _get_device_index(device, optional=True)
        self.prev_idx = -1

    def __enter__(self):
        self.prev_idx = torch._C._accelerator_hooks_maybe_exchange_device(self.idx)

    def __exit__(self, type: Any, value: Any, traceback: Any):
        self.idx = torch._C._accelerator_hooks_maybe_exchange_device(self.prev_idx)
        return False


class StreamContext:
    r"""Context-manager that selects a given stream.

    All MTIA kernels queued within its context will be enqueued on a selected
    stream.

    Args:
        Stream (Stream): selected stream. This manager is a no-op if it's
            ``None``.
    .. note:: Streams are per-device.
    """

    cur_stream: Optional["torch.mtia.Stream"]

    def __init__(self, stream: Optional["torch.mtia.Stream"]):
        self.cur_stream = None
        self.stream = stream
        self.idx = _get_device_index(None, True)
        if not torch.jit.is_scripting():
            if self.idx is None:
                self.idx = -1

        # NOTE(review): under TorchScript the prev-stream attributes are seeded
        # with the default stream (so they have a concrete Stream type); in
        # eager mode they stay None and are filled lazily in __enter__.
        self.src_prev_stream = (
            None if not torch.jit.is_scripting() else torch.mtia.default_stream(None)
        )
        self.dst_prev_stream = (
            None if not torch.jit.is_scripting() else torch.mtia.default_stream(None)
        )

    def __enter__(self):
        # Local cur_stream variable for type refinement
        cur_stream = self.stream
        # Return if stream is None or MTIA device not available
        if cur_stream is None or self.idx == -1:
            return
        self.src_prev_stream = torch.mtia.current_stream(None)

        # If the stream is not on the current device, then
        # set the current stream on the device
        if self.src_prev_stream.device != cur_stream.device:
            with device(cur_stream.device):
                self.dst_prev_stream = torch.mtia.current_stream(cur_stream.device)
        torch.mtia.set_stream(cur_stream)

    def __exit__(self, type: Any, value: Any, traceback: Any):
        # Local cur_stream variable for type refinement
        cur_stream = self.stream
        # If stream is None or no MTIA device available, return
        if cur_stream is None or self.idx == -1:
            return

        # Reset the stream on the original device
        # and destination device
        if self.src_prev_stream.device != cur_stream.device:  # type: ignore[union-attr]
            torch.mtia.set_stream(self.dst_prev_stream)  # type: ignore[arg-type]
        torch.mtia.set_stream(self.src_prev_stream)  # type: ignore[arg-type]


def stream(stream: Optional["torch.mtia.Stream"]) -> StreamContext:
    r"""Wrap around the Context-manager StreamContext that selects a given stream.

    Arguments:
        stream (Stream): selected stream. This manager is a no-op if it's
            ``None``.
    .. note:: In eager mode stream is of type Stream class while in JIT it doesn't support torch.mtia.stream
    """
    return StreamContext(stream)


def get_rng_state(device: Union[int, str, torch.device] = "mtia") -> Tensor:
    r"""Returns the random number generator state as a ByteTensor.

    Args:
        device (torch.device or int, optional): The device to return the RNG state of.
            Default: ``'mtia'`` (i.e., ``torch.device('mtia')``, the current mtia device).
    """
    # RNG state is not yet implemented for MTIA; warn and return a placeholder.
    warnings.warn(
        "get_rng_state is not implemented in torch.mtia",
        UserWarning,
        stacklevel=2,
    )
    return torch.zeros([1], dtype=torch.uint8, device=device)


def set_rng_state(
    new_state: Tensor, device: Union[int, str, torch.device] = "mtia"
) -> None:
    r"""Sets the random number generator state.

    Args:
        new_state (torch.ByteTensor): The desired state
        device (torch.device or int, optional): The device to set the RNG state.
            Default: ``'mtia'`` (i.e., ``torch.device('mtia')``, the current mtia device).
    """
    # RNG state is not yet implemented for MTIA; the new state is ignored.
    warnings.warn(
        "set_rng_state is not implemented in torch.mtia",
        UserWarning,
        stacklevel=2,
    )


__all__ = [
    "init",
    "is_available",
    "is_initialized",
    "synchronize",
    "device_count",
    "current_device",
    "current_stream",
    "default_stream",
    "memory_stats",
    "set_device",
    "set_stream",
    "stream",
    "device",
    "set_rng_state",
    "get_rng_state",
]