xref: /aosp_15_r20/external/pytorch/torch/cuda/streams.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import ctypes
3
4import torch
5from torch._streambase import _EventBase, _StreamBase
6from torch._utils import _dummy_type
7
8
# CPU-only builds of torch do not ship the CUDA stream/event C bindings.
# Install placeholder ("dummy") base classes in their stead so that this
# module can still be imported and the class definitions below evaluate.
if not hasattr(torch._C, "_CudaStreamBase"):
    # Define dummy base classes
    torch._C.__dict__["_CudaStreamBase"] = _dummy_type("_CudaStreamBase")
    torch._C.__dict__["_CudaEventBase"] = _dummy_type("_CudaEventBase")
13
14
15class Stream(torch._C._CudaStreamBase, _StreamBase):
16    r"""Wrapper around a CUDA stream.
17
18    A CUDA stream is a linear sequence of execution that belongs to a specific
19    device, independent from other streams.  See :ref:`cuda-semantics` for
20    details.
21
22    Args:
23        device(torch.device or int, optional): a device on which to allocate
24            the stream. If :attr:`device` is ``None`` (default) or a negative
25            integer, this will use the current device.
26        priority(int, optional): priority of the stream, should be 0 or
27            negative, where negative numbers indicate higher priority. By default,
28            streams have priority 0.
29
30    """
31
32    def __new__(cls, device=None, priority=0, **kwargs):
33        # setting device manager is expensive, so we avoid it unless necessary
34        if device is None or ("stream_id" in kwargs and "device_index" in kwargs):
35            return super().__new__(cls, priority=priority, **kwargs)
36        else:
37            with torch.cuda.device(device):
38                return super().__new__(cls, priority=priority, **kwargs)
39
40    def wait_event(self, event) -> None:
41        r"""Make all future work submitted to the stream wait for an event.
42
43        Args:
44            event (torch.cuda.Event): an event to wait for.
45
46        .. note:: This is a wrapper around ``cudaStreamWaitEvent()``: see
47           `CUDA Stream documentation`_ for more info.
48
49           This function returns without waiting for :attr:`event`: only future
50           operations are affected.
51
52        .. _CUDA Stream documentation:
53           https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html
54        """
55        event.wait(self)
56
57    def wait_stream(self, stream) -> None:
58        r"""Synchronize with another stream.
59
60        All future work submitted to this stream will wait until all kernels
61        submitted to a given stream at the time of call complete.
62
63        Args:
64            stream (Stream): a stream to synchronize.
65
66        .. note:: This function returns without waiting for currently enqueued
67           kernels in :attr:`stream`: only future operations are affected.
68        """
69        self.wait_event(stream.record_event())
70
71    def record_event(self, event=None):
72        r"""Record an event.
73
74        Args:
75            event (torch.cuda.Event, optional): event to record. If not given, a new one
76                will be allocated.
77
78        Returns:
79            Recorded event.
80        """
81        if event is None:
82            event = Event()
83        event.record(self)
84        return event
85
86    def query(self) -> bool:
87        r"""Check if all the work submitted has been completed.
88
89        Returns:
90            A boolean indicating if all kernels in this stream are completed.
91        """
92        return super().query()
93
94    def synchronize(self) -> None:
95        r"""Wait for all the kernels in this stream to complete.
96
97        .. note:: This is a wrapper around ``cudaStreamSynchronize()``: see
98           `CUDA Stream documentation`_ for more info.
99        """
100        super().synchronize()
101
102    @property
103    def _as_parameter_(self):
104        return ctypes.c_void_p(self.cuda_stream)
105
106    def __eq__(self, o) -> bool:
107        if isinstance(o, Stream):
108            return super().__eq__(o)
109        return False
110
111    def __hash__(self):
112        return hash((self.cuda_stream, self.device))
113
114    def __repr__(self):
115        return f"<torch.cuda.Stream device={self.device} cuda_stream={self.cuda_stream:#x}>"
116
117
class ExternalStream(Stream):
    r"""Wrapper around an externally allocated CUDA stream.

    This class is used to wrap streams allocated in other libraries in order
    to facilitate data exchange and multi-library interactions.

    .. note:: This class doesn't manage the stream life-cycle, it is the user
       responsibility to keep the referenced stream alive while this class is
       being used.

    Args:
        stream_ptr(int): Integer representation of the `cudaStream_t` value.
            allocated externally.
        device(torch.device or int, optional): the device where the stream
            was originally allocated. If device is specified incorrectly,
            subsequent launches using this stream may fail.
    """

    def __new__(cls, stream_ptr, device=None, **kwargs):
        # Enter the target device context while the base class adopts the
        # externally created stream pointer.
        device_ctx = torch.cuda.device(device)
        with device_ctx:
            return super().__new__(cls, stream_ptr=stream_ptr, **kwargs)

139
140
class Event(torch._C._CudaEventBase, _EventBase):
    r"""Wrapper around a CUDA event.

    CUDA events are synchronization markers that can be used to monitor the
    device's progress, to accurately measure timing, and to synchronize CUDA
    streams.

    The underlying CUDA events are lazily initialized when the event is first
    recorded or exported to another process. After creation, only streams on the
    same device may record the event. However, streams on any device can wait on
    the event.

    Args:
        enable_timing (bool, optional): indicates if the event should measure time
            (default: ``False``)
        blocking (bool, optional): if ``True``, :meth:`wait` will be blocking (default: ``False``)
        interprocess (bool): if ``True``, the event can be shared between processes
            (default: ``False``)

    .. _CUDA Event Documentation:
       https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__EVENT.html
    """

    def __new__(cls, enable_timing=False, blocking=False, interprocess=False):
        # Forward the creation flags to the C++ base class unchanged.
        flags = {
            "enable_timing": enable_timing,
            "blocking": blocking,
            "interprocess": interprocess,
        }
        return super().__new__(cls, **flags)

    @classmethod
    def from_ipc_handle(cls, device, handle):
        r"""Reconstruct an event from an IPC handle on the given device."""
        return super().from_ipc_handle(device, handle)

    def record(self, stream=None):
        r"""Record the event in a given stream.

        Uses ``torch.cuda.current_stream()`` if no stream is specified. The
        stream's device must match the event's device.
        """
        target = torch.cuda.current_stream() if stream is None else stream
        super().record(target)

    def wait(self, stream=None) -> None:
        r"""Make all future work submitted to the given stream wait for this event.

        Use ``torch.cuda.current_stream()`` if no stream is specified.

        .. note:: This is a wrapper around ``cudaStreamWaitEvent()``: see
            `CUDA Event documentation`_ for more info.
        """
        target = torch.cuda.current_stream() if stream is None else stream
        super().wait(target)

    def query(self):
        r"""Check if all work currently captured by event has completed.

        Returns:
            A boolean indicating if all work currently captured by event has
            completed.
        """
        return super().query()

    def elapsed_time(self, end_event):
        r"""Return the time elapsed.

        Time reported in milliseconds after the event was recorded and
        before the end_event was recorded.
        """
        return super().elapsed_time(end_event)

    def synchronize(self) -> None:
        r"""Wait for the event to complete.

        Waits until the completion of all work currently captured in this event.
        This prevents the CPU thread from proceeding until the event completes.

        .. note:: This is a wrapper around ``cudaEventSynchronize()``: see
            `CUDA Event documentation`_ for more info.
        """
        super().synchronize()

    def ipc_handle(self):
        r"""Return an IPC handle of this event.

        If not recorded yet, the event will use the current device.
        """
        return super().ipc_handle()

    @property
    def _as_parameter_(self):
        # Expose the raw cudaEvent_t so instances can be passed directly to
        # ctypes-based foreign functions.
        return ctypes.c_void_p(self.cuda_event)

    def __repr__(self) -> str:
        if not self.cuda_event:
            # The underlying CUDA event has not been lazily created yet.
            return "<torch.cuda.Event uninitialized>"
        return f"<torch.cuda.Event {self._as_parameter_.value:#x}>"

243