# xref: /aosp_15_r20/external/pytorch/torch/cuda/streams.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
import ctypes

import torch
from torch._streambase import _EventBase, _StreamBase
from torch._utils import _dummy_type
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Workerif not hasattr(torch._C, "_CudaStreamBase"):
10*da0073e9SAndroid Build Coastguard Worker    # Define dummy base classes
11*da0073e9SAndroid Build Coastguard Worker    torch._C.__dict__["_CudaStreamBase"] = _dummy_type("_CudaStreamBase")
12*da0073e9SAndroid Build Coastguard Worker    torch._C.__dict__["_CudaEventBase"] = _dummy_type("_CudaEventBase")
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Worker
15*da0073e9SAndroid Build Coastguard Workerclass Stream(torch._C._CudaStreamBase, _StreamBase):
16*da0073e9SAndroid Build Coastguard Worker    r"""Wrapper around a CUDA stream.
17*da0073e9SAndroid Build Coastguard Worker
18*da0073e9SAndroid Build Coastguard Worker    A CUDA stream is a linear sequence of execution that belongs to a specific
19*da0073e9SAndroid Build Coastguard Worker    device, independent from other streams.  See :ref:`cuda-semantics` for
20*da0073e9SAndroid Build Coastguard Worker    details.
21*da0073e9SAndroid Build Coastguard Worker
22*da0073e9SAndroid Build Coastguard Worker    Args:
23*da0073e9SAndroid Build Coastguard Worker        device(torch.device or int, optional): a device on which to allocate
24*da0073e9SAndroid Build Coastguard Worker            the stream. If :attr:`device` is ``None`` (default) or a negative
25*da0073e9SAndroid Build Coastguard Worker            integer, this will use the current device.
26*da0073e9SAndroid Build Coastguard Worker        priority(int, optional): priority of the stream, should be 0 or
27*da0073e9SAndroid Build Coastguard Worker            negative, where negative numbers indicate higher priority. By default,
28*da0073e9SAndroid Build Coastguard Worker            streams have priority 0.
29*da0073e9SAndroid Build Coastguard Worker
30*da0073e9SAndroid Build Coastguard Worker    """
31*da0073e9SAndroid Build Coastguard Worker
32*da0073e9SAndroid Build Coastguard Worker    def __new__(cls, device=None, priority=0, **kwargs):
33*da0073e9SAndroid Build Coastguard Worker        # setting device manager is expensive, so we avoid it unless necessary
34*da0073e9SAndroid Build Coastguard Worker        if device is None or ("stream_id" in kwargs and "device_index" in kwargs):
35*da0073e9SAndroid Build Coastguard Worker            return super().__new__(cls, priority=priority, **kwargs)
36*da0073e9SAndroid Build Coastguard Worker        else:
37*da0073e9SAndroid Build Coastguard Worker            with torch.cuda.device(device):
38*da0073e9SAndroid Build Coastguard Worker                return super().__new__(cls, priority=priority, **kwargs)
39*da0073e9SAndroid Build Coastguard Worker
40*da0073e9SAndroid Build Coastguard Worker    def wait_event(self, event) -> None:
41*da0073e9SAndroid Build Coastguard Worker        r"""Make all future work submitted to the stream wait for an event.
42*da0073e9SAndroid Build Coastguard Worker
43*da0073e9SAndroid Build Coastguard Worker        Args:
44*da0073e9SAndroid Build Coastguard Worker            event (torch.cuda.Event): an event to wait for.
45*da0073e9SAndroid Build Coastguard Worker
46*da0073e9SAndroid Build Coastguard Worker        .. note:: This is a wrapper around ``cudaStreamWaitEvent()``: see
47*da0073e9SAndroid Build Coastguard Worker           `CUDA Stream documentation`_ for more info.
48*da0073e9SAndroid Build Coastguard Worker
49*da0073e9SAndroid Build Coastguard Worker           This function returns without waiting for :attr:`event`: only future
50*da0073e9SAndroid Build Coastguard Worker           operations are affected.
51*da0073e9SAndroid Build Coastguard Worker
52*da0073e9SAndroid Build Coastguard Worker        .. _CUDA Stream documentation:
53*da0073e9SAndroid Build Coastguard Worker           https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html
54*da0073e9SAndroid Build Coastguard Worker        """
55*da0073e9SAndroid Build Coastguard Worker        event.wait(self)
56*da0073e9SAndroid Build Coastguard Worker
57*da0073e9SAndroid Build Coastguard Worker    def wait_stream(self, stream) -> None:
58*da0073e9SAndroid Build Coastguard Worker        r"""Synchronize with another stream.
59*da0073e9SAndroid Build Coastguard Worker
60*da0073e9SAndroid Build Coastguard Worker        All future work submitted to this stream will wait until all kernels
61*da0073e9SAndroid Build Coastguard Worker        submitted to a given stream at the time of call complete.
62*da0073e9SAndroid Build Coastguard Worker
63*da0073e9SAndroid Build Coastguard Worker        Args:
64*da0073e9SAndroid Build Coastguard Worker            stream (Stream): a stream to synchronize.
65*da0073e9SAndroid Build Coastguard Worker
66*da0073e9SAndroid Build Coastguard Worker        .. note:: This function returns without waiting for currently enqueued
67*da0073e9SAndroid Build Coastguard Worker           kernels in :attr:`stream`: only future operations are affected.
68*da0073e9SAndroid Build Coastguard Worker        """
69*da0073e9SAndroid Build Coastguard Worker        self.wait_event(stream.record_event())
70*da0073e9SAndroid Build Coastguard Worker
71*da0073e9SAndroid Build Coastguard Worker    def record_event(self, event=None):
72*da0073e9SAndroid Build Coastguard Worker        r"""Record an event.
73*da0073e9SAndroid Build Coastguard Worker
74*da0073e9SAndroid Build Coastguard Worker        Args:
75*da0073e9SAndroid Build Coastguard Worker            event (torch.cuda.Event, optional): event to record. If not given, a new one
76*da0073e9SAndroid Build Coastguard Worker                will be allocated.
77*da0073e9SAndroid Build Coastguard Worker
78*da0073e9SAndroid Build Coastguard Worker        Returns:
79*da0073e9SAndroid Build Coastguard Worker            Recorded event.
80*da0073e9SAndroid Build Coastguard Worker        """
81*da0073e9SAndroid Build Coastguard Worker        if event is None:
82*da0073e9SAndroid Build Coastguard Worker            event = Event()
83*da0073e9SAndroid Build Coastguard Worker        event.record(self)
84*da0073e9SAndroid Build Coastguard Worker        return event
85*da0073e9SAndroid Build Coastguard Worker
86*da0073e9SAndroid Build Coastguard Worker    def query(self) -> bool:
87*da0073e9SAndroid Build Coastguard Worker        r"""Check if all the work submitted has been completed.
88*da0073e9SAndroid Build Coastguard Worker
89*da0073e9SAndroid Build Coastguard Worker        Returns:
90*da0073e9SAndroid Build Coastguard Worker            A boolean indicating if all kernels in this stream are completed.
91*da0073e9SAndroid Build Coastguard Worker        """
92*da0073e9SAndroid Build Coastguard Worker        return super().query()
93*da0073e9SAndroid Build Coastguard Worker
94*da0073e9SAndroid Build Coastguard Worker    def synchronize(self) -> None:
95*da0073e9SAndroid Build Coastguard Worker        r"""Wait for all the kernels in this stream to complete.
96*da0073e9SAndroid Build Coastguard Worker
97*da0073e9SAndroid Build Coastguard Worker        .. note:: This is a wrapper around ``cudaStreamSynchronize()``: see
98*da0073e9SAndroid Build Coastguard Worker           `CUDA Stream documentation`_ for more info.
99*da0073e9SAndroid Build Coastguard Worker        """
100*da0073e9SAndroid Build Coastguard Worker        super().synchronize()
101*da0073e9SAndroid Build Coastguard Worker
102*da0073e9SAndroid Build Coastguard Worker    @property
103*da0073e9SAndroid Build Coastguard Worker    def _as_parameter_(self):
104*da0073e9SAndroid Build Coastguard Worker        return ctypes.c_void_p(self.cuda_stream)
105*da0073e9SAndroid Build Coastguard Worker
106*da0073e9SAndroid Build Coastguard Worker    def __eq__(self, o) -> bool:
107*da0073e9SAndroid Build Coastguard Worker        if isinstance(o, Stream):
108*da0073e9SAndroid Build Coastguard Worker            return super().__eq__(o)
109*da0073e9SAndroid Build Coastguard Worker        return False
110*da0073e9SAndroid Build Coastguard Worker
111*da0073e9SAndroid Build Coastguard Worker    def __hash__(self):
112*da0073e9SAndroid Build Coastguard Worker        return hash((self.cuda_stream, self.device))
113*da0073e9SAndroid Build Coastguard Worker
114*da0073e9SAndroid Build Coastguard Worker    def __repr__(self):
115*da0073e9SAndroid Build Coastguard Worker        return f"<torch.cuda.Stream device={self.device} cuda_stream={self.cuda_stream:#x}>"
116*da0073e9SAndroid Build Coastguard Worker
117*da0073e9SAndroid Build Coastguard Worker
class ExternalStream(Stream):
    r"""Wrapper around an externally allocated CUDA stream.

    This class is used to wrap streams allocated in other libraries in order
    to facilitate data exchange and multi-library interactions.

    .. note:: This class doesn't manage the stream life-cycle; it is the user's
       responsibility to keep the referenced stream alive while this class is
       being used.

    Args:
        stream_ptr(int): Integer representation of the `cudaStream_t` value
            allocated externally.
        device(torch.device or int, optional): the device where the stream
            was originally allocated. If device is specified incorrectly,
            subsequent launches using this stream may fail.
    """

    def __new__(cls, stream_ptr, device=None, **kwargs):
        # Build the wrapper inside the ``torch.cuda.device(device)`` context and
        # forward the raw pointer to the parent constructor as ``stream_ptr``.
        with torch.cuda.device(device):
            return super().__new__(cls, stream_ptr=stream_ptr, **kwargs)
139*da0073e9SAndroid Build Coastguard Worker
140*da0073e9SAndroid Build Coastguard Worker
141*da0073e9SAndroid Build Coastguard Workerclass Event(torch._C._CudaEventBase, _EventBase):
142*da0073e9SAndroid Build Coastguard Worker    r"""Wrapper around a CUDA event.
143*da0073e9SAndroid Build Coastguard Worker
144*da0073e9SAndroid Build Coastguard Worker    CUDA events are synchronization markers that can be used to monitor the
145*da0073e9SAndroid Build Coastguard Worker    device's progress, to accurately measure timing, and to synchronize CUDA
146*da0073e9SAndroid Build Coastguard Worker    streams.
147*da0073e9SAndroid Build Coastguard Worker
148*da0073e9SAndroid Build Coastguard Worker    The underlying CUDA events are lazily initialized when the event is first
149*da0073e9SAndroid Build Coastguard Worker    recorded or exported to another process. After creation, only streams on the
150*da0073e9SAndroid Build Coastguard Worker    same device may record the event. However, streams on any device can wait on
151*da0073e9SAndroid Build Coastguard Worker    the event.
152*da0073e9SAndroid Build Coastguard Worker
153*da0073e9SAndroid Build Coastguard Worker    Args:
154*da0073e9SAndroid Build Coastguard Worker        enable_timing (bool, optional): indicates if the event should measure time
155*da0073e9SAndroid Build Coastguard Worker            (default: ``False``)
156*da0073e9SAndroid Build Coastguard Worker        blocking (bool, optional): if ``True``, :meth:`wait` will be blocking (default: ``False``)
157*da0073e9SAndroid Build Coastguard Worker        interprocess (bool): if ``True``, the event can be shared between processes
158*da0073e9SAndroid Build Coastguard Worker            (default: ``False``)
159*da0073e9SAndroid Build Coastguard Worker
160*da0073e9SAndroid Build Coastguard Worker    .. _CUDA Event Documentation:
161*da0073e9SAndroid Build Coastguard Worker       https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__EVENT.html
162*da0073e9SAndroid Build Coastguard Worker    """
163*da0073e9SAndroid Build Coastguard Worker
164*da0073e9SAndroid Build Coastguard Worker    def __new__(cls, enable_timing=False, blocking=False, interprocess=False):
165*da0073e9SAndroid Build Coastguard Worker        return super().__new__(
166*da0073e9SAndroid Build Coastguard Worker            cls,
167*da0073e9SAndroid Build Coastguard Worker            enable_timing=enable_timing,
168*da0073e9SAndroid Build Coastguard Worker            blocking=blocking,
169*da0073e9SAndroid Build Coastguard Worker            interprocess=interprocess,
170*da0073e9SAndroid Build Coastguard Worker        )
171*da0073e9SAndroid Build Coastguard Worker
172*da0073e9SAndroid Build Coastguard Worker    @classmethod
173*da0073e9SAndroid Build Coastguard Worker    def from_ipc_handle(cls, device, handle):
174*da0073e9SAndroid Build Coastguard Worker        r"""Reconstruct an event from an IPC handle on the given device."""
175*da0073e9SAndroid Build Coastguard Worker        return super().from_ipc_handle(device, handle)
176*da0073e9SAndroid Build Coastguard Worker
177*da0073e9SAndroid Build Coastguard Worker    def record(self, stream=None):
178*da0073e9SAndroid Build Coastguard Worker        r"""Record the event in a given stream.
179*da0073e9SAndroid Build Coastguard Worker
180*da0073e9SAndroid Build Coastguard Worker        Uses ``torch.cuda.current_stream()`` if no stream is specified. The
181*da0073e9SAndroid Build Coastguard Worker        stream's device must match the event's device.
182*da0073e9SAndroid Build Coastguard Worker        """
183*da0073e9SAndroid Build Coastguard Worker        if stream is None:
184*da0073e9SAndroid Build Coastguard Worker            stream = torch.cuda.current_stream()
185*da0073e9SAndroid Build Coastguard Worker        super().record(stream)
186*da0073e9SAndroid Build Coastguard Worker
187*da0073e9SAndroid Build Coastguard Worker    def wait(self, stream=None) -> None:
188*da0073e9SAndroid Build Coastguard Worker        r"""Make all future work submitted to the given stream wait for this event.
189*da0073e9SAndroid Build Coastguard Worker
190*da0073e9SAndroid Build Coastguard Worker        Use ``torch.cuda.current_stream()`` if no stream is specified.
191*da0073e9SAndroid Build Coastguard Worker
192*da0073e9SAndroid Build Coastguard Worker        .. note:: This is a wrapper around ``cudaStreamWaitEvent()``: see
193*da0073e9SAndroid Build Coastguard Worker            `CUDA Event documentation`_ for more info.
194*da0073e9SAndroid Build Coastguard Worker        """
195*da0073e9SAndroid Build Coastguard Worker        if stream is None:
196*da0073e9SAndroid Build Coastguard Worker            stream = torch.cuda.current_stream()
197*da0073e9SAndroid Build Coastguard Worker        super().wait(stream)
198*da0073e9SAndroid Build Coastguard Worker
199*da0073e9SAndroid Build Coastguard Worker    def query(self):
200*da0073e9SAndroid Build Coastguard Worker        r"""Check if all work currently captured by event has completed.
201*da0073e9SAndroid Build Coastguard Worker
202*da0073e9SAndroid Build Coastguard Worker        Returns:
203*da0073e9SAndroid Build Coastguard Worker            A boolean indicating if all work currently captured by event has
204*da0073e9SAndroid Build Coastguard Worker            completed.
205*da0073e9SAndroid Build Coastguard Worker        """
206*da0073e9SAndroid Build Coastguard Worker        return super().query()
207*da0073e9SAndroid Build Coastguard Worker
208*da0073e9SAndroid Build Coastguard Worker    def elapsed_time(self, end_event):
209*da0073e9SAndroid Build Coastguard Worker        r"""Return the time elapsed.
210*da0073e9SAndroid Build Coastguard Worker
211*da0073e9SAndroid Build Coastguard Worker        Time reported in milliseconds after the event was recorded and
212*da0073e9SAndroid Build Coastguard Worker        before the end_event was recorded.
213*da0073e9SAndroid Build Coastguard Worker        """
214*da0073e9SAndroid Build Coastguard Worker        return super().elapsed_time(end_event)
215*da0073e9SAndroid Build Coastguard Worker
216*da0073e9SAndroid Build Coastguard Worker    def synchronize(self) -> None:
217*da0073e9SAndroid Build Coastguard Worker        r"""Wait for the event to complete.
218*da0073e9SAndroid Build Coastguard Worker
219*da0073e9SAndroid Build Coastguard Worker        Waits until the completion of all work currently captured in this event.
220*da0073e9SAndroid Build Coastguard Worker        This prevents the CPU thread from proceeding until the event completes.
221*da0073e9SAndroid Build Coastguard Worker
222*da0073e9SAndroid Build Coastguard Worker         .. note:: This is a wrapper around ``cudaEventSynchronize()``: see
223*da0073e9SAndroid Build Coastguard Worker            `CUDA Event documentation`_ for more info.
224*da0073e9SAndroid Build Coastguard Worker        """
225*da0073e9SAndroid Build Coastguard Worker        super().synchronize()
226*da0073e9SAndroid Build Coastguard Worker
227*da0073e9SAndroid Build Coastguard Worker    def ipc_handle(self):
228*da0073e9SAndroid Build Coastguard Worker        r"""Return an IPC handle of this event.
229*da0073e9SAndroid Build Coastguard Worker
230*da0073e9SAndroid Build Coastguard Worker        If not recorded yet, the event will use the current device.
231*da0073e9SAndroid Build Coastguard Worker        """
232*da0073e9SAndroid Build Coastguard Worker        return super().ipc_handle()
233*da0073e9SAndroid Build Coastguard Worker
234*da0073e9SAndroid Build Coastguard Worker    @property
235*da0073e9SAndroid Build Coastguard Worker    def _as_parameter_(self):
236*da0073e9SAndroid Build Coastguard Worker        return ctypes.c_void_p(self.cuda_event)
237*da0073e9SAndroid Build Coastguard Worker
238*da0073e9SAndroid Build Coastguard Worker    def __repr__(self) -> str:
239*da0073e9SAndroid Build Coastguard Worker        if self.cuda_event:
240*da0073e9SAndroid Build Coastguard Worker            return f"<torch.cuda.Event {self._as_parameter_.value:#x}>"
241*da0073e9SAndroid Build Coastguard Worker        else:
242*da0073e9SAndroid Build Coastguard Worker            return "<torch.cuda.Event uninitialized>"
243