# mypy: allow-untyped-defs
import ctypes

import torch
from torch._streambase import _EventBase, _StreamBase
from torch._utils import _dummy_type


if not hasattr(torch._C, "_CudaStreamBase"):
    # CUDA bindings are absent from this build; install placeholder types so
    # that the class definitions below still import cleanly.
    torch._C.__dict__["_CudaStreamBase"] = _dummy_type("_CudaStreamBase")
    torch._C.__dict__["_CudaEventBase"] = _dummy_type("_CudaEventBase")


class Stream(torch._C._CudaStreamBase, _StreamBase):
    r"""Wrapper around a CUDA stream.

    A CUDA stream is an ordered sequence of work tied to a particular device
    and executing independently of other streams. See :ref:`cuda-semantics`
    for details.

    Args:
        device(torch.device or int, optional): a device on which to allocate
            the stream. If :attr:`device` is ``None`` (default) or a negative
            integer, this will use the current device.
        priority(int, optional): priority of the stream, should be 0 or
            negative, where negative numbers indicate higher priority. By default,
            streams have priority 0.

    """

    def __new__(cls, device=None, priority=0, **kwargs):
        # Installing a device guard is expensive, so only do it when the
        # caller actually requested a device AND we are not merely re-wrapping
        # an existing stream (signalled by stream_id + device_index kwargs).
        restoring = "stream_id" in kwargs and "device_index" in kwargs
        if device is not None and not restoring:
            with torch.cuda.device(device):
                return super().__new__(cls, priority=priority, **kwargs)
        return super().__new__(cls, priority=priority, **kwargs)

    def wait_event(self, event) -> None:
        r"""Make all future work submitted to the stream wait for an event.

        Args:
            event (torch.cuda.Event): an event to wait for.

        .. note:: This is a wrapper around ``cudaStreamWaitEvent()``: see
            `CUDA Stream documentation`_ for more info.

            This function returns without waiting for :attr:`event`: only future
            operations are affected.

        .. _CUDA Stream documentation:
           https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html
        """
        event.wait(self)

    def wait_stream(self, stream) -> None:
        r"""Synchronize with another stream.

        Everything enqueued on this stream after the call will wait until all
        kernels already submitted to :attr:`stream` have finished.

        Args:
            stream (Stream): a stream to synchronize.

        .. note:: This function returns without waiting for currently enqueued
            kernels in :attr:`stream`: only future operations are affected.
        """
        self.wait_event(stream.record_event())

    def record_event(self, event=None):
        r"""Record an event.

        Args:
            event (torch.cuda.Event, optional): event to record. If not given, a new one
                will be allocated.

        Returns:
            Recorded event.
        """
        event = Event() if event is None else event
        event.record(self)
        return event

    def query(self) -> bool:
        r"""Check if all the work submitted has been completed.

        Returns:
            A boolean indicating if all kernels in this stream are completed.
        """
        return super().query()

    def synchronize(self) -> None:
        r"""Wait for all the kernels in this stream to complete.

        .. note:: This is a wrapper around ``cudaStreamSynchronize()``: see
            `CUDA Stream documentation`_ for more info.
        """
        super().synchronize()

    @property
    def _as_parameter_(self):
        # Lets a Stream be passed directly to ctypes foreign calls.
        return ctypes.c_void_p(self.cuda_stream)

    def __eq__(self, o) -> bool:
        # Only another Stream can compare equal; defer to the C base otherwise.
        return super().__eq__(o) if isinstance(o, Stream) else False

    def __hash__(self):
        return hash((self.cuda_stream, self.device))

    def __repr__(self):
        return f"<torch.cuda.Stream device={self.device} cuda_stream={self.cuda_stream:#x}>"


class ExternalStream(Stream):
    r"""Wrapper around an externally allocated CUDA stream.

    Use this class to adopt a ``cudaStream_t`` created by another library so
    that tensors and kernels can be exchanged across library boundaries.

    .. note:: This class doesn't manage the stream life-cycle, it is the user
       responsibility to keep the referenced stream alive while this class is
       being used.

    Args:
        stream_ptr(int): Integer representation of the `cudaStream_t` value.
            allocated externally.
        device(torch.device or int, optional): the device where the stream
            was originally allocated. If device is specified incorrectly,
            subsequent launches using this stream may fail.
    """

    def __new__(cls, stream_ptr, device=None, **kwargs):
        with torch.cuda.device(device):
            return super().__new__(cls, stream_ptr=stream_ptr, **kwargs)


class Event(torch._C._CudaEventBase, _EventBase):
    r"""Wrapper around a CUDA event.

    CUDA events are synchronization markers used to track the device's
    progress, to take accurate timings, and to synchronize CUDA streams.

    The underlying CUDA event is created lazily, on the first record or
    inter-process export. Once created, only streams on the event's device may
    record it, although streams on any device may wait on it.

    Args:
        enable_timing (bool, optional): indicates if the event should measure time
            (default: ``False``)
        blocking (bool, optional): if ``True``, :meth:`wait` will be blocking (default: ``False``)
        interprocess (bool): if ``True``, the event can be shared between processes
            (default: ``False``)

    .. _CUDA Event Documentation:
       https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__EVENT.html
    """

    def __new__(cls, enable_timing=False, blocking=False, interprocess=False):
        return super().__new__(
            cls,
            enable_timing=enable_timing,
            blocking=blocking,
            interprocess=interprocess,
        )

    @classmethod
    def from_ipc_handle(cls, device, handle):
        r"""Reconstruct an event from an IPC handle on the given device."""
        return super().from_ipc_handle(device, handle)

    def record(self, stream=None):
        r"""Record the event in a given stream.

        Uses ``torch.cuda.current_stream()`` if no stream is specified. The
        stream's device must match the event's device.
        """
        super().record(torch.cuda.current_stream() if stream is None else stream)

    def wait(self, stream=None) -> None:
        r"""Make all future work submitted to the given stream wait for this event.

        Use ``torch.cuda.current_stream()`` if no stream is specified.

        .. note:: This is a wrapper around ``cudaStreamWaitEvent()``: see
            `CUDA Event documentation`_ for more info.
        """
        super().wait(torch.cuda.current_stream() if stream is None else stream)

    def query(self):
        r"""Check if all work currently captured by event has completed.

        Returns:
            A boolean indicating if all work currently captured by event has
            completed.
        """
        return super().query()

    def elapsed_time(self, end_event):
        r"""Return the time elapsed.

        Time reported in milliseconds after the event was recorded and
        before the end_event was recorded.
        """
        return super().elapsed_time(end_event)

    def synchronize(self) -> None:
        r"""Wait for the event to complete.

        Blocks the calling CPU thread until every piece of work currently
        captured by this event has finished.

        .. note:: This is a wrapper around ``cudaEventSynchronize()``: see
            `CUDA Event documentation`_ for more info.
        """
        super().synchronize()

    def ipc_handle(self):
        r"""Return an IPC handle of this event.

        If not recorded yet, the event will use the current device.
        """
        return super().ipc_handle()

    @property
    def _as_parameter_(self):
        # Lets an Event be passed directly to ctypes foreign calls.
        return ctypes.c_void_p(self.cuda_event)

    def __repr__(self) -> str:
        if not self.cuda_event:
            return "<torch.cuda.Event uninitialized>"
        return f"<torch.cuda.Event {self._as_parameter_.value:#x}>"