xref: /aosp_15_r20/external/pytorch/c10/xpu/impl/XPUGuardImpl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/DeviceGuard.h>
4 #include <c10/core/impl/DeviceGuardImplInterface.h>
5 #include <c10/core/impl/GPUTrace.h>
6 #include <c10/xpu/XPUCachingAllocator.h>
7 #include <c10/xpu/XPUFunctions.h>
8 #include <c10/xpu/XPUStream.h>
9 
10 #include <vector>
11 
12 namespace c10::xpu::impl {
13 
14 struct XPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
15   static constexpr DeviceType static_type = kXPU;
16 
17   XPUGuardImpl() = default;
18 
XPUGuardImplfinal19   explicit XPUGuardImpl(DeviceType t) {
20     TORCH_INTERNAL_ASSERT(t == kXPU);
21   }
22 
typefinal23   DeviceType type() const override {
24     return kXPU;
25   }
26 
exchangeDevicefinal27   Device exchangeDevice(Device d) const override {
28     TORCH_INTERNAL_ASSERT(d.is_xpu());
29     const auto old_device_index = c10::xpu::exchange_device(d.index());
30     return Device(kXPU, old_device_index);
31   }
32 
getDevicefinal33   Device getDevice() const override {
34     const auto device = c10::xpu::current_device();
35     return Device(kXPU, device);
36   }
37 
setDevicefinal38   void setDevice(Device d) const override {
39     TORCH_INTERNAL_ASSERT(d.is_xpu());
40     c10::xpu::set_device(d.index());
41   }
42 
uncheckedSetDevicefinal43   void uncheckedSetDevice(Device d) const noexcept override {
44     c10::xpu::set_device(d.index());
45   }
46 
getStreamfinal47   Stream getStream(Device d) const noexcept override {
48     return getCurrentXPUStream(d.index()).unwrap();
49   }
50 
51   Stream getNewStream(Device d, int priority = 0) const override {
52     return getStreamFromPool(priority, d.index());
53   }
54 
55   Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false)
56       const override {
57     return getStreamFromPool(isHighPriority, d.index());
58   }
59 
60   // NB: These do NOT set the current device
exchangeStreamfinal61   Stream exchangeStream(Stream s) const noexcept override {
62     const XPUStream stream(s);
63     const auto old_stream = getCurrentXPUStream(s.device().index());
64     setCurrentXPUStream(stream);
65     return old_stream.unwrap();
66   }
67 
deviceCountfinal68   DeviceIndex deviceCount() const noexcept override {
69     return c10::xpu::device_count();
70   }
71 
72   // Event-related functions
destroyEventfinal73   void destroyEvent(void* event, const DeviceIndex device_index)
74       const noexcept override {
75     if (!event)
76       return;
77 
78     const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
79     if (C10_UNLIKELY(interp)) {
80       (*interp)->trace_gpu_event_deletion(
81           c10::kXPU, reinterpret_cast<uintptr_t>(event));
82     }
83 
84     delete reinterpret_cast<sycl::event*>(event);
85   }
86 
recordfinal87   void record(
88       void** event,
89       const Stream& stream,
90       const DeviceIndex device_index,
91       const EventFlag flag) const override {
92     TORCH_CHECK(
93         device_index == -1 || device_index == stream.device_index(),
94         "Event device index ",
95         device_index,
96         " does not match recording stream's device index ",
97         stream.device_index(),
98         ".");
99 
100     auto* xpu_event = reinterpret_cast<sycl::event*>(*event);
101     const XPUStream xpu_stream{stream};
102 
103     // Delete the event previously recorded.
104     if (xpu_event)
105       delete xpu_event;
106     xpu_event = new sycl::event(xpu_stream.queue().ext_oneapi_submit_barrier());
107     *event = reinterpret_cast<void*>(xpu_event);
108 
109     const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
110     if (C10_UNLIKELY(interp)) {
111       (*interp)->trace_gpu_event_record(
112           c10::kXPU,
113           reinterpret_cast<uintptr_t>(xpu_event),
114           reinterpret_cast<uintptr_t>(&xpu_stream.queue()));
115     }
116   }
117 
blockfinal118   void block(void* event, const Stream& stream) const override {
119     if (!event)
120       return;
121     auto* xpu_event = reinterpret_cast<sycl::event*>(event);
122     std::vector<sycl::event> event_list{*xpu_event};
123     const XPUStream xpu_stream(stream);
124     xpu_stream.queue().ext_oneapi_submit_barrier(event_list);
125     const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
126     if (C10_UNLIKELY(interp)) {
127       (*interp)->trace_gpu_event_wait(
128           c10::kXPU,
129           reinterpret_cast<uintptr_t>(xpu_event),
130           reinterpret_cast<uintptr_t>(&xpu_stream.queue()));
131     }
132   }
133 
queryEventfinal134   bool queryEvent(void* event) const override {
135     using namespace sycl::info;
136     if (!event)
137       return true;
138     auto* xpu_event = reinterpret_cast<sycl::event*>(event);
139     return xpu_event->get_info<event::command_execution_status>() ==
140         event_command_status::complete;
141   }
142 
143   // Stream-related functions
queryStreamfinal144   bool queryStream(const Stream& stream) const override {
145     const XPUStream xpu_stream{stream};
146     return xpu_stream.query();
147   }
148 
synchronizeStreamfinal149   void synchronizeStream(const Stream& stream) const override {
150     const XPUStream xpu_stream{stream};
151     xpu_stream.synchronize();
152   }
153 
synchronizeEventfinal154   void synchronizeEvent(void* event) const override {
155     if (!event)
156       return;
157     auto* xpu_event = reinterpret_cast<sycl::event*>(event);
158     const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
159     if (C10_UNLIKELY(interp)) {
160       (*interp)->trace_gpu_event_synchronization(
161           c10::kXPU, reinterpret_cast<uintptr_t>(xpu_event));
162     }
163     xpu_event->wait_and_throw();
164   }
165 
recordDataPtrOnStreamfinal166   void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream)
167       const override {
168     const XPUStream xpu_stream{stream};
169     XPUCachingAllocator::recordStream(data_ptr, xpu_stream);
170   }
171 
elapsedTimefinal172   double elapsedTime(void* event1, void* event2, const DeviceIndex device_index)
173       const override {
174     TORCH_CHECK_NOT_IMPLEMENTED(
175         false, "elapsedTime is not supported by XPU backend.");
176   }
177 };
178 
179 } // namespace c10::xpu::impl
180