# mypy: allow-untyped-defs
r"""
CPU counterparts of the abstractions exposed by ``torch.cuda``.

Most entry points here are constant-returning or no-op stand-ins so that
device-agnostic code can call the same API regardless of the backend.
"""

from contextlib import AbstractContextManager
from typing import Any, Optional, Union

import torch

from .. import device as _device
from . import amp


__all__ = [
    "is_available",
    "synchronize",
    "current_device",
    "current_stream",
    "stream",
    "set_device",
    "device_count",
    "Stream",
    "StreamContext",
    "Event",
]

_device_t = Union[_device, str, int, None]


def _is_avx2_supported() -> bool:
    r"""Report whether the CPU supports AVX2, as answered by ``torch._C._cpu``."""
    return torch._C._cpu._is_avx2_supported()


def _is_avx512_supported() -> bool:
    r"""Report whether the CPU supports AVX512, as answered by ``torch._C._cpu``."""
    return torch._C._cpu._is_avx512_supported()


def _is_avx512_bf16_supported() -> bool:
    r"""Report whether the CPU supports AVX512_BF16, as answered by ``torch._C._cpu``."""
    return torch._C._cpu._is_avx512_bf16_supported()


def _is_vnni_supported() -> bool:
    r"""Report whether the CPU supports VNNI.

    Only the avx512_vnni query is consulted for now; avx2_vnni is not
    checked yet.
    """
    return torch._C._cpu._is_avx512_vnni_supported()


def _is_amx_tile_supported() -> bool:
    r"""Report whether the CPU supports AMX_TILE, as answered by ``torch._C._cpu``."""
    return torch._C._cpu._is_amx_tile_supported()


def _init_amx() -> bool:
    r"""Initialize AMX instructions.

    Returns:
        bool: whatever ``torch._C._cpu._init_amx()`` reports.
    """
    return torch._C._cpu._init_amx()


def is_available() -> bool:
    r"""Tell whether the CPU device is available. Always ``True``.

    N.B. Provided only so that device-agnostic code has something to call.
    """
    return True


def synchronize(device: _device_t = None) -> None:
    r"""No-op counterpart of a device-wide synchronization.

    Args:
        device (torch.device or int, optional): ignored, there's only one
            CPU device.

    N.B. Provided only so that device-agnostic code has something to call.
    """
    return None


class Stream:
    """
    Placeholder stream type whose methods do nothing.

    N.B. Provided only so that device-agnostic code has something to use.
    """

    def __init__(self, priority: int = -1) -> None:
        # ``priority`` is accepted for signature parity and then dropped.
        return None

    def wait_stream(self, stream) -> None:
        """Do nothing; ``stream`` is not inspected."""
        return None


class Event:
    """Placeholder event type: always reports completion, otherwise inert."""

    def query(self) -> bool:
        """Return ``True`` unconditionally."""
        return True

    def record(self, stream=None) -> None:
        """Do nothing; ``stream`` is not inspected."""
        return None

    def synchronize(self) -> None:
        """Do nothing."""
        return None

    def wait(self, stream=None) -> None:
        """Do nothing; ``stream`` is not inspected."""
        return None


# The module tracks a single "current" stream, starting at this default one.
_default_cpu_stream = Stream()
_current_stream = _default_cpu_stream


def current_stream(device: _device_t = None) -> Stream:
    r"""Return the :class:`Stream` that is currently selected.

    Args:
        device (torch.device or int, optional): Ignored.

    N.B. Provided only so that device-agnostic code has something to call.
    """
    return _current_stream


class StreamContext(AbstractContextManager):
    r"""Context manager that makes a given stream the current one.

    On entry the module-level current stream is remembered and replaced by
    the supplied stream; on exit the remembered stream is put back. When the
    supplied stream is ``None`` both entry and exit leave everything alone.

    N.B. Provided only so that device-agnostic code has something to use.
    """

    cur_stream: Optional[Stream]

    def __init__(self, stream):
        # Until ``__enter__`` records the real predecessor, fall back to the
        # default stream as the one to restore.
        self.prev_stream = _default_cpu_stream
        self.stream = stream

    def __enter__(self):
        global _current_stream

        selected = self.stream
        if selected is not None:
            self.prev_stream = _current_stream
            _current_stream = selected

    def __exit__(self, type: Any, value: Any, traceback: Any) -> None:
        global _current_stream

        if self.stream is not None:
            _current_stream = self.prev_stream


def stream(stream: Stream) -> AbstractContextManager:
    r"""Build a :class:`StreamContext` that selects ``stream``.

    N.B. Provided only so that device-agnostic code has something to call.
    """
    context = StreamContext(stream)
    return context


def device_count() -> int:
    r"""Return the number of CPU devices (not cores). Always 1.

    N.B. Provided only so that device-agnostic code has something to call.
    """
    return 1


def set_device(device: _device_t) -> None:
    r"""Accept a device selection and do nothing with it.

    N.B. Provided only so that device-agnostic code has something to call.
    """
    return None


def current_device() -> str:
    r"""Return the current device for cpu. Always ``'cpu'``.

    N.B. Provided only so that device-agnostic code has something to call.
    """
    return "cpu"