xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cuda/CUDAEvent.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/cuda/ATenCUDAGeneral.h>
4 #include <ATen/cuda/CUDAContext.h>
5 #include <c10/core/impl/GPUTrace.h>
6 #include <c10/cuda/CUDAStream.h>
7 #include <c10/cuda/CUDAGuard.h>
8 #include <ATen/cuda/Exceptions.h>
9 #include <c10/util/Exception.h>
10 
11 #include <cuda_runtime_api.h>
12 
13 #include <cstdint>
14 #include <utility>
15 
16 namespace at::cuda {
17 
18 /*
19 * CUDAEvents are movable not copyable wrappers around CUDA's events.
20 *
21 * CUDAEvents are constructed lazily when first recorded unless it is
22 * reconstructed from a cudaIpcEventHandle_t. The event has a device, and this
23 * device is acquired from the first recording stream. However, if reconstructed
24 * from a handle, the device should be explicitly specified; or if ipc_handle() is
25 * called before the event is ever recorded, it will use the current device.
26 * Later streams that record the event must match this device.
27 */
28 struct TORCH_CUDA_CPP_API CUDAEvent {
29   // Constructors
30   // Default value for `flags` is specified below - it's cudaEventDisableTiming
31   CUDAEvent() noexcept = default;
CUDAEventCUDAEvent32   CUDAEvent(unsigned int flags) noexcept : flags_{flags} {}
33 
CUDAEventCUDAEvent34   CUDAEvent(
35       DeviceIndex device_index, const cudaIpcEventHandle_t* handle) : device_index_(device_index) {
36       CUDAGuard guard(device_index_);
37 
38       AT_CUDA_CHECK(cudaIpcOpenEventHandle(&event_, *handle));
39       is_created_ = true;
40   }
41 
42   // Note: event destruction done on creating device to avoid creating a
43   // CUDA context on other devices.
~CUDAEventCUDAEvent44   ~CUDAEvent() {
45     try {
46       if (is_created_) {
47         CUDAGuard guard(device_index_);
48         const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
49         if (C10_UNLIKELY(interp)) {
50           (*interp)->trace_gpu_event_deletion(at::kCUDA, reinterpret_cast<uintptr_t>(event_));
51         }
52         AT_CUDA_CHECK(cudaEventDestroy(event_));
53       }
54     } catch (...) { /* No throw */ }
55   }
56 
57   CUDAEvent(const CUDAEvent&) = delete;
58   CUDAEvent& operator=(const CUDAEvent&) = delete;
59 
CUDAEventCUDAEvent60   CUDAEvent(CUDAEvent&& other) noexcept { moveHelper(std::move(other)); }
61   CUDAEvent& operator=(CUDAEvent&& other) noexcept {
62     if (this != &other) {
63       moveHelper(std::move(other));
64     }
65     return *this;
66   }
67 
cudaEvent_tCUDAEvent68   operator cudaEvent_t() const { return event(); }
69 
70   // Less than operator (to allow use in sets)
71   friend bool operator<(const CUDAEvent& left, const CUDAEvent& right) {
72     return left.event_ < right.event_;
73   }
74 
deviceCUDAEvent75   std::optional<at::Device> device() const {
76     if (is_created_) {
77       return at::Device(at::kCUDA, device_index_);
78     } else {
79       return {};
80     }
81   }
82 
isCreatedCUDAEvent83   bool isCreated() const { return is_created_; }
device_indexCUDAEvent84   DeviceIndex device_index() const {return device_index_;}
eventCUDAEvent85   cudaEvent_t event() const { return event_; }
86 
87   // Note: cudaEventQuery can be safely called from any device
queryCUDAEvent88   bool query() const {
89     if (!is_created_) {
90       return true;
91     }
92 
93     cudaError_t err = cudaEventQuery(event_);
94     if (err == cudaSuccess) {
95       return true;
96     } else if (err != cudaErrorNotReady) {
97       C10_CUDA_CHECK(err);
98     } else {
99       // ignore and clear the error if not ready
100       (void)cudaGetLastError();
101     }
102 
103     return false;
104   }
105 
recordCUDAEvent106   void record() { record(getCurrentCUDAStream()); }
107 
recordOnceCUDAEvent108   void recordOnce(const CUDAStream& stream) {
109     if (!was_recorded_) record(stream);
110   }
111 
112   // Note: cudaEventRecord must be called on the same device as the event.
recordCUDAEvent113   void record(const CUDAStream& stream) {
114     if (!is_created_) {
115       createEvent(stream.device_index());
116     }
117 
118     TORCH_CHECK(device_index_ == stream.device_index(), "Event device ", device_index_,
119       " does not match recording stream's device ", stream.device_index(), ".");
120     CUDAGuard guard(device_index_);
121     AT_CUDA_CHECK(cudaEventRecord(event_, stream));
122     const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
123     if (C10_UNLIKELY(interp)) {
124       (*interp)->trace_gpu_event_record(at::kCUDA,
125           reinterpret_cast<uintptr_t>(event_),
126           reinterpret_cast<uintptr_t>(stream.stream())
127       );
128     }
129     was_recorded_ = true;
130   }
131 
132   // Note: cudaStreamWaitEvent must be called on the same device as the stream.
133   // The event has no actual GPU resources associated with it.
blockCUDAEvent134   void block(const CUDAStream& stream) {
135     if (is_created_) {
136       CUDAGuard guard(stream.device_index());
137       AT_CUDA_CHECK(cudaStreamWaitEvent(stream, event_, 0));
138       const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
139       if (C10_UNLIKELY(interp)) {
140         (*interp)->trace_gpu_event_wait(at::kCUDA,
141             reinterpret_cast<uintptr_t>(event_),
142             reinterpret_cast<uintptr_t>(stream.stream())
143         );
144       }
145     }
146   }
147 
148   // Note: cudaEventElapsedTime can be safely called from any device
elapsed_timeCUDAEvent149   float elapsed_time(const CUDAEvent& other) const {
150     TORCH_CHECK(is_created_ && other.isCreated(),
151       "Both events must be recorded before calculating elapsed time.");
152     float time_ms = 0;
153     // We do not strictly have to set the device index to the same as our event,
154     // but if we don't and the current device is not initialized, it will
155     // create a new cuda context, which will consume a lot of memory.
156     CUDAGuard guard(device_index_);
157     // raise cudaErrorNotReady if either event is recorded but not yet completed
158     AT_CUDA_CHECK(cudaEventElapsedTime(&time_ms, event_, other.event_));
159     return time_ms;
160   }
161 
162   // Note: cudaEventSynchronize can be safely called from any device
synchronizeCUDAEvent163   void synchronize() const {
164     if (is_created_) {
165       const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
166       if (C10_UNLIKELY(interp)) {
167           (*interp)->trace_gpu_event_synchronization(at::kCUDA, reinterpret_cast<uintptr_t>(event_));
168       }
169       AT_CUDA_CHECK(cudaEventSynchronize(event_));
170     }
171   }
172 
173   // Note: cudaIpcGetEventHandle must be called on the same device as the event
ipc_handleCUDAEvent174   void ipc_handle(cudaIpcEventHandle_t * handle) {
175       if (!is_created_) {
176         // this CUDAEvent object was initially constructed from flags but event_
177         // is not created yet.
178         createEvent(getCurrentCUDAStream().device_index());
179       }
180       CUDAGuard guard(device_index_);
181       AT_CUDA_CHECK(cudaIpcGetEventHandle(handle, event_));
182   }
183 
184 private:
185   unsigned int flags_ = cudaEventDisableTiming;
186   bool is_created_ = false;
187   bool was_recorded_ = false;
188   DeviceIndex device_index_ = -1;
189   cudaEvent_t event_{};
190 
createEventCUDAEvent191   void createEvent(DeviceIndex device_index) {
192     device_index_ = device_index;
193     CUDAGuard guard(device_index_);
194     AT_CUDA_CHECK(cudaEventCreateWithFlags(&event_, flags_));
195     const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
196     if (C10_UNLIKELY(interp)) {
197       (*interp)->trace_gpu_event_creation(at::kCUDA, reinterpret_cast<uintptr_t>(event_));
198     }
199     is_created_ = true;
200   }
201 
moveHelperCUDAEvent202   void moveHelper(CUDAEvent&& other) {
203     std::swap(flags_, other.flags_);
204     std::swap(is_created_, other.is_created_);
205     std::swap(was_recorded_, other.was_recorded_);
206     std::swap(device_index_, other.device_index_);
207     std::swap(event_, other.event_);
208   }
209 };
210 
211 } // namespace at::cuda
212