xref: /aosp_15_r20/external/pytorch/c10/xpu/impl/XPUGuardImpl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <c10/core/DeviceGuard.h>
4*da0073e9SAndroid Build Coastguard Worker #include <c10/core/impl/DeviceGuardImplInterface.h>
5*da0073e9SAndroid Build Coastguard Worker #include <c10/core/impl/GPUTrace.h>
6*da0073e9SAndroid Build Coastguard Worker #include <c10/xpu/XPUCachingAllocator.h>
7*da0073e9SAndroid Build Coastguard Worker #include <c10/xpu/XPUFunctions.h>
8*da0073e9SAndroid Build Coastguard Worker #include <c10/xpu/XPUStream.h>
9*da0073e9SAndroid Build Coastguard Worker 
10*da0073e9SAndroid Build Coastguard Worker #include <vector>
11*da0073e9SAndroid Build Coastguard Worker 
12*da0073e9SAndroid Build Coastguard Worker namespace c10::xpu::impl {
13*da0073e9SAndroid Build Coastguard Worker 
14*da0073e9SAndroid Build Coastguard Worker struct XPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
15*da0073e9SAndroid Build Coastguard Worker   static constexpr DeviceType static_type = kXPU;
16*da0073e9SAndroid Build Coastguard Worker 
17*da0073e9SAndroid Build Coastguard Worker   XPUGuardImpl() = default;
18*da0073e9SAndroid Build Coastguard Worker 
XPUGuardImplfinal19*da0073e9SAndroid Build Coastguard Worker   explicit XPUGuardImpl(DeviceType t) {
20*da0073e9SAndroid Build Coastguard Worker     TORCH_INTERNAL_ASSERT(t == kXPU);
21*da0073e9SAndroid Build Coastguard Worker   }
22*da0073e9SAndroid Build Coastguard Worker 
typefinal23*da0073e9SAndroid Build Coastguard Worker   DeviceType type() const override {
24*da0073e9SAndroid Build Coastguard Worker     return kXPU;
25*da0073e9SAndroid Build Coastguard Worker   }
26*da0073e9SAndroid Build Coastguard Worker 
exchangeDevicefinal27*da0073e9SAndroid Build Coastguard Worker   Device exchangeDevice(Device d) const override {
28*da0073e9SAndroid Build Coastguard Worker     TORCH_INTERNAL_ASSERT(d.is_xpu());
29*da0073e9SAndroid Build Coastguard Worker     const auto old_device_index = c10::xpu::exchange_device(d.index());
30*da0073e9SAndroid Build Coastguard Worker     return Device(kXPU, old_device_index);
31*da0073e9SAndroid Build Coastguard Worker   }
32*da0073e9SAndroid Build Coastguard Worker 
getDevicefinal33*da0073e9SAndroid Build Coastguard Worker   Device getDevice() const override {
34*da0073e9SAndroid Build Coastguard Worker     const auto device = c10::xpu::current_device();
35*da0073e9SAndroid Build Coastguard Worker     return Device(kXPU, device);
36*da0073e9SAndroid Build Coastguard Worker   }
37*da0073e9SAndroid Build Coastguard Worker 
setDevicefinal38*da0073e9SAndroid Build Coastguard Worker   void setDevice(Device d) const override {
39*da0073e9SAndroid Build Coastguard Worker     TORCH_INTERNAL_ASSERT(d.is_xpu());
40*da0073e9SAndroid Build Coastguard Worker     c10::xpu::set_device(d.index());
41*da0073e9SAndroid Build Coastguard Worker   }
42*da0073e9SAndroid Build Coastguard Worker 
uncheckedSetDevicefinal43*da0073e9SAndroid Build Coastguard Worker   void uncheckedSetDevice(Device d) const noexcept override {
44*da0073e9SAndroid Build Coastguard Worker     c10::xpu::set_device(d.index());
45*da0073e9SAndroid Build Coastguard Worker   }
46*da0073e9SAndroid Build Coastguard Worker 
getStreamfinal47*da0073e9SAndroid Build Coastguard Worker   Stream getStream(Device d) const noexcept override {
48*da0073e9SAndroid Build Coastguard Worker     return getCurrentXPUStream(d.index()).unwrap();
49*da0073e9SAndroid Build Coastguard Worker   }
50*da0073e9SAndroid Build Coastguard Worker 
51*da0073e9SAndroid Build Coastguard Worker   Stream getNewStream(Device d, int priority = 0) const override {
52*da0073e9SAndroid Build Coastguard Worker     return getStreamFromPool(priority, d.index());
53*da0073e9SAndroid Build Coastguard Worker   }
54*da0073e9SAndroid Build Coastguard Worker 
55*da0073e9SAndroid Build Coastguard Worker   Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false)
56*da0073e9SAndroid Build Coastguard Worker       const override {
57*da0073e9SAndroid Build Coastguard Worker     return getStreamFromPool(isHighPriority, d.index());
58*da0073e9SAndroid Build Coastguard Worker   }
59*da0073e9SAndroid Build Coastguard Worker 
60*da0073e9SAndroid Build Coastguard Worker   // NB: These do NOT set the current device
exchangeStreamfinal61*da0073e9SAndroid Build Coastguard Worker   Stream exchangeStream(Stream s) const noexcept override {
62*da0073e9SAndroid Build Coastguard Worker     const XPUStream stream(s);
63*da0073e9SAndroid Build Coastguard Worker     const auto old_stream = getCurrentXPUStream(s.device().index());
64*da0073e9SAndroid Build Coastguard Worker     setCurrentXPUStream(stream);
65*da0073e9SAndroid Build Coastguard Worker     return old_stream.unwrap();
66*da0073e9SAndroid Build Coastguard Worker   }
67*da0073e9SAndroid Build Coastguard Worker 
deviceCountfinal68*da0073e9SAndroid Build Coastguard Worker   DeviceIndex deviceCount() const noexcept override {
69*da0073e9SAndroid Build Coastguard Worker     return c10::xpu::device_count();
70*da0073e9SAndroid Build Coastguard Worker   }
71*da0073e9SAndroid Build Coastguard Worker 
72*da0073e9SAndroid Build Coastguard Worker   // Event-related functions
destroyEventfinal73*da0073e9SAndroid Build Coastguard Worker   void destroyEvent(void* event, const DeviceIndex device_index)
74*da0073e9SAndroid Build Coastguard Worker       const noexcept override {
75*da0073e9SAndroid Build Coastguard Worker     if (!event)
76*da0073e9SAndroid Build Coastguard Worker       return;
77*da0073e9SAndroid Build Coastguard Worker 
78*da0073e9SAndroid Build Coastguard Worker     const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
79*da0073e9SAndroid Build Coastguard Worker     if (C10_UNLIKELY(interp)) {
80*da0073e9SAndroid Build Coastguard Worker       (*interp)->trace_gpu_event_deletion(
81*da0073e9SAndroid Build Coastguard Worker           c10::kXPU, reinterpret_cast<uintptr_t>(event));
82*da0073e9SAndroid Build Coastguard Worker     }
83*da0073e9SAndroid Build Coastguard Worker 
84*da0073e9SAndroid Build Coastguard Worker     delete reinterpret_cast<sycl::event*>(event);
85*da0073e9SAndroid Build Coastguard Worker   }
86*da0073e9SAndroid Build Coastguard Worker 
recordfinal87*da0073e9SAndroid Build Coastguard Worker   void record(
88*da0073e9SAndroid Build Coastguard Worker       void** event,
89*da0073e9SAndroid Build Coastguard Worker       const Stream& stream,
90*da0073e9SAndroid Build Coastguard Worker       const DeviceIndex device_index,
91*da0073e9SAndroid Build Coastguard Worker       const EventFlag flag) const override {
92*da0073e9SAndroid Build Coastguard Worker     TORCH_CHECK(
93*da0073e9SAndroid Build Coastguard Worker         device_index == -1 || device_index == stream.device_index(),
94*da0073e9SAndroid Build Coastguard Worker         "Event device index ",
95*da0073e9SAndroid Build Coastguard Worker         device_index,
96*da0073e9SAndroid Build Coastguard Worker         " does not match recording stream's device index ",
97*da0073e9SAndroid Build Coastguard Worker         stream.device_index(),
98*da0073e9SAndroid Build Coastguard Worker         ".");
99*da0073e9SAndroid Build Coastguard Worker 
100*da0073e9SAndroid Build Coastguard Worker     auto* xpu_event = reinterpret_cast<sycl::event*>(*event);
101*da0073e9SAndroid Build Coastguard Worker     const XPUStream xpu_stream{stream};
102*da0073e9SAndroid Build Coastguard Worker 
103*da0073e9SAndroid Build Coastguard Worker     // Delete the event previously recorded.
104*da0073e9SAndroid Build Coastguard Worker     if (xpu_event)
105*da0073e9SAndroid Build Coastguard Worker       delete xpu_event;
106*da0073e9SAndroid Build Coastguard Worker     xpu_event = new sycl::event(xpu_stream.queue().ext_oneapi_submit_barrier());
107*da0073e9SAndroid Build Coastguard Worker     *event = reinterpret_cast<void*>(xpu_event);
108*da0073e9SAndroid Build Coastguard Worker 
109*da0073e9SAndroid Build Coastguard Worker     const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
110*da0073e9SAndroid Build Coastguard Worker     if (C10_UNLIKELY(interp)) {
111*da0073e9SAndroid Build Coastguard Worker       (*interp)->trace_gpu_event_record(
112*da0073e9SAndroid Build Coastguard Worker           c10::kXPU,
113*da0073e9SAndroid Build Coastguard Worker           reinterpret_cast<uintptr_t>(xpu_event),
114*da0073e9SAndroid Build Coastguard Worker           reinterpret_cast<uintptr_t>(&xpu_stream.queue()));
115*da0073e9SAndroid Build Coastguard Worker     }
116*da0073e9SAndroid Build Coastguard Worker   }
117*da0073e9SAndroid Build Coastguard Worker 
blockfinal118*da0073e9SAndroid Build Coastguard Worker   void block(void* event, const Stream& stream) const override {
119*da0073e9SAndroid Build Coastguard Worker     if (!event)
120*da0073e9SAndroid Build Coastguard Worker       return;
121*da0073e9SAndroid Build Coastguard Worker     auto* xpu_event = reinterpret_cast<sycl::event*>(event);
122*da0073e9SAndroid Build Coastguard Worker     std::vector<sycl::event> event_list{*xpu_event};
123*da0073e9SAndroid Build Coastguard Worker     const XPUStream xpu_stream(stream);
124*da0073e9SAndroid Build Coastguard Worker     xpu_stream.queue().ext_oneapi_submit_barrier(event_list);
125*da0073e9SAndroid Build Coastguard Worker     const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
126*da0073e9SAndroid Build Coastguard Worker     if (C10_UNLIKELY(interp)) {
127*da0073e9SAndroid Build Coastguard Worker       (*interp)->trace_gpu_event_wait(
128*da0073e9SAndroid Build Coastguard Worker           c10::kXPU,
129*da0073e9SAndroid Build Coastguard Worker           reinterpret_cast<uintptr_t>(xpu_event),
130*da0073e9SAndroid Build Coastguard Worker           reinterpret_cast<uintptr_t>(&xpu_stream.queue()));
131*da0073e9SAndroid Build Coastguard Worker     }
132*da0073e9SAndroid Build Coastguard Worker   }
133*da0073e9SAndroid Build Coastguard Worker 
queryEventfinal134*da0073e9SAndroid Build Coastguard Worker   bool queryEvent(void* event) const override {
135*da0073e9SAndroid Build Coastguard Worker     using namespace sycl::info;
136*da0073e9SAndroid Build Coastguard Worker     if (!event)
137*da0073e9SAndroid Build Coastguard Worker       return true;
138*da0073e9SAndroid Build Coastguard Worker     auto* xpu_event = reinterpret_cast<sycl::event*>(event);
139*da0073e9SAndroid Build Coastguard Worker     return xpu_event->get_info<event::command_execution_status>() ==
140*da0073e9SAndroid Build Coastguard Worker         event_command_status::complete;
141*da0073e9SAndroid Build Coastguard Worker   }
142*da0073e9SAndroid Build Coastguard Worker 
143*da0073e9SAndroid Build Coastguard Worker   // Stream-related functions
queryStreamfinal144*da0073e9SAndroid Build Coastguard Worker   bool queryStream(const Stream& stream) const override {
145*da0073e9SAndroid Build Coastguard Worker     const XPUStream xpu_stream{stream};
146*da0073e9SAndroid Build Coastguard Worker     return xpu_stream.query();
147*da0073e9SAndroid Build Coastguard Worker   }
148*da0073e9SAndroid Build Coastguard Worker 
synchronizeStreamfinal149*da0073e9SAndroid Build Coastguard Worker   void synchronizeStream(const Stream& stream) const override {
150*da0073e9SAndroid Build Coastguard Worker     const XPUStream xpu_stream{stream};
151*da0073e9SAndroid Build Coastguard Worker     xpu_stream.synchronize();
152*da0073e9SAndroid Build Coastguard Worker   }
153*da0073e9SAndroid Build Coastguard Worker 
synchronizeEventfinal154*da0073e9SAndroid Build Coastguard Worker   void synchronizeEvent(void* event) const override {
155*da0073e9SAndroid Build Coastguard Worker     if (!event)
156*da0073e9SAndroid Build Coastguard Worker       return;
157*da0073e9SAndroid Build Coastguard Worker     auto* xpu_event = reinterpret_cast<sycl::event*>(event);
158*da0073e9SAndroid Build Coastguard Worker     const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
159*da0073e9SAndroid Build Coastguard Worker     if (C10_UNLIKELY(interp)) {
160*da0073e9SAndroid Build Coastguard Worker       (*interp)->trace_gpu_event_synchronization(
161*da0073e9SAndroid Build Coastguard Worker           c10::kXPU, reinterpret_cast<uintptr_t>(xpu_event));
162*da0073e9SAndroid Build Coastguard Worker     }
163*da0073e9SAndroid Build Coastguard Worker     xpu_event->wait_and_throw();
164*da0073e9SAndroid Build Coastguard Worker   }
165*da0073e9SAndroid Build Coastguard Worker 
recordDataPtrOnStreamfinal166*da0073e9SAndroid Build Coastguard Worker   void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream)
167*da0073e9SAndroid Build Coastguard Worker       const override {
168*da0073e9SAndroid Build Coastguard Worker     const XPUStream xpu_stream{stream};
169*da0073e9SAndroid Build Coastguard Worker     XPUCachingAllocator::recordStream(data_ptr, xpu_stream);
170*da0073e9SAndroid Build Coastguard Worker   }
171*da0073e9SAndroid Build Coastguard Worker 
elapsedTimefinal172*da0073e9SAndroid Build Coastguard Worker   double elapsedTime(void* event1, void* event2, const DeviceIndex device_index)
173*da0073e9SAndroid Build Coastguard Worker       const override {
174*da0073e9SAndroid Build Coastguard Worker     TORCH_CHECK_NOT_IMPLEMENTED(
175*da0073e9SAndroid Build Coastguard Worker         false, "elapsedTime is not supported by XPU backend.");
176*da0073e9SAndroid Build Coastguard Worker   }
177*da0073e9SAndroid Build Coastguard Worker };
178*da0073e9SAndroid Build Coastguard Worker 
179*da0073e9SAndroid Build Coastguard Worker } // namespace c10::xpu::impl
180