xref: /aosp_15_r20/external/pytorch/aten/src/ATen/xpu/XPUEvent.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
#include <ATen/xpu/XPUContext.h>

#include <memory>
#include <optional>
#include <utility>
#include <vector>
5 
6 namespace at::xpu {
7 
/*
 * XPUEvents are movable, non-copyable wrappers around a SYCL event. An XPUEvent
 * is constructed lazily when it is first recorded. It has a device, and this
 * device is acquired from the first recording stream. Later streams that
 * record the event must match the same device.
 *
 * Currently, XPUEvent does NOT support exporting an inter-process event from
 * another process via inter-process communication (IPC). This means that
 * sharing event handles between different processes is not available, which
 * could impact applications that rely on cross-process synchronization and
 * communication.
 */
20 struct TORCH_XPU_API XPUEvent {
21   // Constructors
22   XPUEvent(bool enable_timing = false) noexcept
23       : enable_timing_{enable_timing} {}
24 
~XPUEventXPUEvent25   ~XPUEvent() {
26     if (isCreated()) {
27       const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
28       if (C10_UNLIKELY(interp)) {
29         (*interp)->trace_gpu_event_deletion(
30             at::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
31       }
32     }
33   }
34 
35   XPUEvent(const XPUEvent&) = delete;
36   XPUEvent& operator=(const XPUEvent&) = delete;
37 
38   XPUEvent(XPUEvent&& other) = default;
39   XPUEvent& operator=(XPUEvent&& other) = default;
40 
41   operator sycl::event&() const {
42     return event();
43   }
44 
deviceXPUEvent45   std::optional<at::Device> device() const {
46     if (isCreated()) {
47       return at::Device(at::kXPU, device_index_);
48     } else {
49       return std::nullopt;
50     }
51   }
52 
isCreatedXPUEvent53   inline bool isCreated() const {
54     return (event_.get() != nullptr);
55   }
56 
device_indexXPUEvent57   DeviceIndex device_index() const {
58     return device_index_;
59   }
60 
eventXPUEvent61   sycl::event& event() const {
62     return *event_;
63   }
64 
queryXPUEvent65   bool query() const {
66     using namespace sycl::info;
67     if (!isCreated()) {
68       return true;
69     }
70 
71     return event().get_info<event::command_execution_status>() ==
72         event_command_status::complete;
73   }
74 
recordXPUEvent75   void record() {
76     record(getCurrentXPUStream());
77   }
78 
recordOnceXPUEvent79   void recordOnce(const XPUStream& stream) {
80     if (!isCreated()) {
81       record(stream);
82     }
83   }
84 
recordXPUEvent85   void record(const XPUStream& stream) {
86     if (!isCreated()) {
87       device_index_ = stream.device_index();
88       event_ = std::make_unique<sycl::event>(
89           stream.queue().ext_oneapi_submit_barrier());
90       const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
91       if (C10_UNLIKELY(interp)) {
92         (*interp)->trace_gpu_event_creation(
93             at::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
94       }
95     } else {
96       TORCH_CHECK(
97           device_index_ == stream.device_index(),
98           "Event device ",
99           device_index_,
100           " does not match recording stream's device ",
101           stream.device_index(),
102           ".");
103       event_.reset();
104       event_ = std::make_unique<sycl::event>(
105           stream.queue().ext_oneapi_submit_barrier());
106     }
107     const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
108     if (C10_UNLIKELY(interp)) {
109       (*interp)->trace_gpu_event_record(
110           at::kXPU,
111           reinterpret_cast<uintptr_t>(event_.get()),
112           reinterpret_cast<uintptr_t>(&stream.queue()));
113     }
114   }
115 
blockXPUEvent116   void block(const XPUStream& stream) {
117     if (isCreated()) {
118       std::vector<sycl::event> event_list{event()};
119       // Make this stream wait until event_ is completed.
120       stream.queue().ext_oneapi_submit_barrier(event_list);
121       const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
122       if (C10_UNLIKELY(interp)) {
123         (*interp)->trace_gpu_event_wait(
124             at::kXPU,
125             reinterpret_cast<uintptr_t>(event_.get()),
126             reinterpret_cast<uintptr_t>(&stream.queue()));
127       }
128     }
129   }
130 
elapsed_timeXPUEvent131   float elapsed_time(const XPUEvent& other) const {
132     TORCH_CHECK(
133         isCreated() && other.isCreated(),
134         "Both events must be recorded before calculating elapsed time.");
135     TORCH_CHECK(
136         query() && other.query(),
137         "Both events must be completed before calculating elapsed time.");
138     TORCH_CHECK(
139         enable_timing_ && other.enable_timing_,
140         "Both events must be created with argument 'enable_timing=True'.");
141     // TODO: provides the ability to time the execution of commands in a SYCL
142     // queue without enabling profiling on the entire queue
143     TORCH_CHECK_NOT_IMPLEMENTED(
144         false, "elapsed_time is not supported by XPUEvent.");
145   }
146 
synchronizeXPUEvent147   void synchronize() const {
148     if (isCreated()) {
149       const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
150       if (C10_UNLIKELY(interp)) {
151         (*interp)->trace_gpu_event_synchronization(
152             at::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
153       }
154       event().wait_and_throw();
155     }
156   }
157 
158  private:
159   bool enable_timing_ = false;
160   DeviceIndex device_index_ = -1;
161   // Only need to track the last event, as events in an in-order queue are
162   // executed sequentially.
163   std::unique_ptr<sycl::event> event_;
164 };
165 
166 } // namespace at::xpu
167