# mypy: allow-untyped-defs
import ctypes

import torch
from torch._streambase import _EventBase, _StreamBase
from torch._utils import _dummy_type


if not hasattr(torch._C, "_CudaStreamBase"):
    # Define dummy base classes
    torch._C.__dict__["_CudaStreamBase"] = _dummy_type("_CudaStreamBase")
    torch._C.__dict__["_CudaEventBase"] = _dummy_type("_CudaEventBase")


class Stream(torch._C._CudaStreamBase, _StreamBase):
    r"""Wrapper around a CUDA stream.

    A CUDA stream is a linear sequence of execution that belongs to a specific
    device, independent from other streams. See :ref:`cuda-semantics` for
    details.

    Args:
        device (torch.device or int, optional): a device on which to allocate
            the stream. If :attr:`device` is ``None`` (default) or a negative
            integer, this will use the current device.
        priority (int, optional): priority of the stream, which should be 0 or
            negative, where negative numbers indicate higher priority. By default,
            streams have priority 0.

    """

    def __new__(cls, device=None, priority=0, **kwargs):
        # entering the device context manager is expensive, so avoid it unless necessary
        if device is None or ("stream_id" in kwargs and "device_index" in kwargs):
            return super().__new__(cls, priority=priority, **kwargs)
        else:
            with torch.cuda.device(device):
                return super().__new__(cls, priority=priority, **kwargs)

    def wait_event(self, event) -> None:
        r"""Make all future work submitted to the stream wait for an event.

        Args:
            event (torch.cuda.Event): an event to wait for.

        .. note:: This is a wrapper around ``cudaStreamWaitEvent()``: see
           `CUDA Stream documentation`_ for more info.

           This function returns without waiting for :attr:`event`: only future
           operations are affected.
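
        A minimal usage sketch, assuming at least one CUDA device is available
        (the tensor size is arbitrary). Kernels queued on ``s2`` after the call
        do not start until ``event``, recorded on ``s1``, has completed::

            >>> s1 = torch.cuda.Stream()
            >>> s2 = torch.cuda.Stream()
            >>> with torch.cuda.stream(s1):
            ...     x = torch.ones(1024, device="cuda") * 2
            ...     event = s1.record_event()
            >>> s2.wait_event(event)
            >>> with torch.cuda.stream(s2):
            ...     y = x + 1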

        .. _CUDA Stream documentation:
           https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html
        """
        event.wait(self)

    def wait_stream(self, stream) -> None:
        r"""Synchronize with another stream.

        All future work submitted to this stream will wait until all kernels
        submitted to :attr:`stream` at the time of the call have completed.

        Args:
            stream (Stream): a stream to synchronize.

        .. note:: This function returns without waiting for currently enqueued
           kernels in :attr:`stream`: only future operations are affected.
        """
        self.wait_event(stream.record_event())

    def record_event(self, event=None):
        r"""Record an event.

        Args:
            event (torch.cuda.Event, optional): event to record. If not given, a new one
                will be allocated.

        Returns:
            Recorded event.
        """
        if event is None:
            event = Event()
        event.record(self)
        return event

    def query(self) -> bool:
        r"""Check if all the work submitted has been completed.

        Returns:
            A boolean indicating if all kernels in this stream are completed.
        """
        return super().query()

    def synchronize(self) -> None:
        r"""Wait for all the kernels in this stream to complete.

        .. note:: This is a wrapper around ``cudaStreamSynchronize()``: see
           `CUDA Stream documentation`_ for more info.
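
        A minimal usage sketch, assuming at least one CUDA device is available.
        Once ``synchronize()`` returns, all work queued on ``s`` has finished,
        so :meth:`query` reports ``True``::

            >>> s = torch.cuda.Stream()
            >>> with torch.cuda.stream(s):
            ...     y = torch.randn(1024, device="cuda").sum()
            >>> s.synchronize()
            >>> s.query()
            True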
        """
        super().synchronize()

    @property
    def _as_parameter_(self):
        return ctypes.c_void_p(self.cuda_stream)

    def __eq__(self, o) -> bool:
        if isinstance(o, Stream):
            return super().__eq__(o)
        return False

    def __hash__(self):
        return hash((self.cuda_stream, self.device))

    def __repr__(self):
        return f"<torch.cuda.Stream device={self.device} cuda_stream={self.cuda_stream:#x}>"


class ExternalStream(Stream):
    r"""Wrapper around an externally allocated CUDA stream.

    This class is used to wrap streams allocated in other libraries in order
    to facilitate data exchange and multi-library interactions.

    .. note:: This class doesn't manage the stream life-cycle; it is the user's
       responsibility to keep the referenced stream alive while this class is
       being used.

    Args:
        stream_ptr (int): Integer representation of the ``cudaStream_t`` value
            allocated externally.
        device (torch.device or int, optional): the device where the stream
            was originally allocated. If :attr:`device` is specified incorrectly,
            subsequent launches using this stream may fail.
    """

    def __new__(cls, stream_ptr, device=None, **kwargs):
        with torch.cuda.device(device):
            return super().__new__(cls, stream_ptr=stream_ptr, **kwargs)


class Event(torch._C._CudaEventBase, _EventBase):
    r"""Wrapper around a CUDA event.

    CUDA events are synchronization markers that can be used to monitor the
    device's progress, to accurately measure timing, and to synchronize CUDA
    streams.

    The underlying CUDA events are lazily initialized when the event is first
    recorded or exported to another process. After creation, only streams on the
    same device may record the event.
    However, streams on any device can wait on the event.

    Args:
        enable_timing (bool, optional): indicates if the event should measure time
            (default: ``False``)
        blocking (bool, optional): if ``True``, :meth:`synchronize` uses blocking
            synchronization, so the host thread blocks until the event completes
            (default: ``False``)
        interprocess (bool, optional): if ``True``, the event can be shared between processes
            (default: ``False``)

    .. _CUDA Event documentation:
       https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__EVENT.html
    """

    def __new__(cls, enable_timing=False, blocking=False, interprocess=False):
        return super().__new__(
            cls,
            enable_timing=enable_timing,
            blocking=blocking,
            interprocess=interprocess,
        )

    @classmethod
    def from_ipc_handle(cls, device, handle):
        r"""Reconstruct an event from an IPC handle on the given device."""
        return super().from_ipc_handle(device, handle)

    def record(self, stream=None):
        r"""Record the event in a given stream.

        Uses ``torch.cuda.current_stream()`` if no stream is specified. The
        stream's device must match the event's device.
        """
        if stream is None:
            stream = torch.cuda.current_stream()
        super().record(stream)

    def wait(self, stream=None) -> None:
        r"""Make all future work submitted to the given stream wait for this event.

        Uses ``torch.cuda.current_stream()`` if no stream is specified.

        .. note:: This is a wrapper around ``cudaStreamWaitEvent()``: see
           `CUDA Event documentation`_ for more info.
        """
        if stream is None:
            stream = torch.cuda.current_stream()
        super().wait(stream)

    def query(self):
        r"""Check if all work currently captured by the event has completed.

        Returns:
            A boolean indicating if all work currently captured by the event has
            completed.
        """
        return super().query()

    def elapsed_time(self, end_event):
        r"""Return the time elapsed between this event and :attr:`end_event`.

        The time is reported in milliseconds and measures the interval from
        when this event was recorded to when ``end_event`` was recorded. Both
        events must have been created with ``enable_timing=True``.
        """
        return super().elapsed_time(end_event)

    def synchronize(self) -> None:
        r"""Wait for the event to complete.

        Waits until the completion of all work currently captured in this event.
        This prevents the CPU thread from proceeding until the event completes.

        .. note:: This is a wrapper around ``cudaEventSynchronize()``: see
           `CUDA Event documentation`_ for more info.
        """
        super().synchronize()

    def ipc_handle(self):
        r"""Return an IPC handle of this event.

        If not recorded yet, the event will use the current device.
        """
        return super().ipc_handle()

    @property
    def _as_parameter_(self):
        return ctypes.c_void_p(self.cuda_event)

    def __repr__(self) -> str:
        if self.cuda_event:
            return f"<torch.cuda.Event {self._as_parameter_.value:#x}>"
        else:
            return "<torch.cuda.Event uninitialized>"
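

# Usage sketch (illustrative only, not part of this module's API): time a
# matrix multiply with a pair of timing-enabled events. It assumes PyTorch was
# built with CUDA and a device is present; otherwise it only prints a notice.
# The 1024 x 1024 matrix size is an arbitrary choice for the example.
if __name__ == "__main__":
    if not torch.cuda.is_available():
        print("CUDA is not available; skipping the Event timing sketch.")
    else:
        start = Event(enable_timing=True)
        end = Event(enable_timing=True)
        a = torch.randn(1024, 1024, device="cuda")
        start.record()  # records on torch.cuda.current_stream()
        product = a @ a
        end.record()
        end.synchronize()  # block the host until `end` has completed
        print(f"matmul took {start.elapsed_time(end):.3f} ms")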