#pragma once

#include <c10/core/DeviceGuard.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/core/impl/GPUTrace.h>
#include <c10/xpu/XPUCachingAllocator.h>
#include <c10/xpu/XPUFunctions.h>
#include <c10/xpu/XPUStream.h>

#include <vector>

namespace c10::xpu::impl {

// DeviceGuardImplInterface implementation for the XPU backend. Events are
// heap-allocated sycl::event objects, passed around through the interface as
// type-erased void* handles.
struct XPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
  static constexpr DeviceType static_type = kXPU;

  XPUGuardImpl() = default;

  explicit XPUGuardImpl(DeviceType t) {
    TORCH_INTERNAL_ASSERT(t == kXPU);
  }

  DeviceType type() const override {
    return kXPU;
  }

  Device exchangeDevice(Device d) const override {
    TORCH_INTERNAL_ASSERT(d.is_xpu());
    const auto old_device_index = c10::xpu::exchange_device(d.index());
    return Device(kXPU, old_device_index);
  }

  Device getDevice() const override {
    const auto device = c10::xpu::current_device();
    return Device(kXPU, device);
  }

  void setDevice(Device d) const override {
    TORCH_INTERNAL_ASSERT(d.is_xpu());
    c10::xpu::set_device(d.index());
  }

  void uncheckedSetDevice(Device d) const noexcept override {
    c10::xpu::set_device(d.index());
  }

  Stream getStream(Device d) const noexcept override {
    return getCurrentXPUStream(d.index()).unwrap();
  }

  Stream getNewStream(Device d, int priority = 0) const override {
    return getStreamFromPool(priority, d.index());
  }

  Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false)
      const override {
    return getStreamFromPool(isHighPriority, d.index());
  }

  // NB: These do NOT set the current device
  Stream exchangeStream(Stream s) const noexcept override {
    const XPUStream stream(s);
    const auto old_stream = getCurrentXPUStream(s.device().index());
    setCurrentXPUStream(stream);
    return old_stream.unwrap();
  }

  DeviceIndex deviceCount() const noexcept override {
    return c10::xpu::device_count();
  }

  // Event-related functions
  void destroyEvent(void* event, const DeviceIndex device_index)
      const noexcept override {
    if (!event)
      return;

    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_deletion(
          c10::kXPU, reinterpret_cast<uintptr_t>(event));
    }

    delete reinterpret_cast<sycl::event*>(event);
  }

  void record(
      void** event,
      const Stream& stream,
      const DeviceIndex device_index,
      const EventFlag flag) const override {
    TORCH_CHECK(
        device_index == -1 || device_index == stream.device_index(),
        "Event device index ",
        device_index,
        " does not match recording stream's device index ",
        stream.device_index(),
        ".");

    auto* xpu_event = reinterpret_cast<sycl::event*>(*event);
    const XPUStream xpu_stream{stream};

    // Delete the event previously recorded.
    if (xpu_event)
      delete xpu_event;
    // Record by submitting a barrier to the stream's queue and capturing the
    // sycl::event it returns.
    xpu_event =
        new sycl::event(xpu_stream.queue().ext_oneapi_submit_barrier());
    *event = reinterpret_cast<void*>(xpu_event);

    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_record(
          c10::kXPU,
          reinterpret_cast<uintptr_t>(xpu_event),
          reinterpret_cast<uintptr_t>(&xpu_stream.queue()));
    }
  }

  void block(void* event, const Stream& stream) const override {
    if (!event)
      return;
    auto* xpu_event = reinterpret_cast<sycl::event*>(event);
    std::vector<sycl::event> event_list{*xpu_event};
    const XPUStream xpu_stream(stream);
    // Make future work submitted to this stream wait on the event.
    xpu_stream.queue().ext_oneapi_submit_barrier(event_list);
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_wait(
          c10::kXPU,
          reinterpret_cast<uintptr_t>(xpu_event),
          reinterpret_cast<uintptr_t>(&xpu_stream.queue()));
    }
  }

  bool queryEvent(void* event) const override {
    using namespace sycl::info;
    // An event that was never recorded is trivially complete.
    if (!event)
      return true;
    auto* xpu_event = reinterpret_cast<sycl::event*>(event);
    return xpu_event->get_info<event::command_execution_status>() ==
        event_command_status::complete;
  }

  // Stream-related functions
  bool queryStream(const Stream& stream) const override {
    const XPUStream xpu_stream{stream};
    return xpu_stream.query();
  }

  void synchronizeStream(const Stream& stream) const override {
    const XPUStream xpu_stream{stream};
    xpu_stream.synchronize();
  }

  void synchronizeEvent(void* event) const override {
    if (!event)
      return;
    auto* xpu_event = reinterpret_cast<sycl::event*>(event);
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_synchronization(
          c10::kXPU, reinterpret_cast<uintptr_t>(xpu_event));
    }
    xpu_event->wait_and_throw();
  }

  void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream)
      const override {
    const XPUStream xpu_stream{stream};
    XPUCachingAllocator::recordStream(data_ptr, xpu_stream);
  }

  double elapsedTime(void* event1, void* event2, const DeviceIndex device_index)
      const override {
    TORCH_CHECK_NOT_IMPLEMENTED(
        false, "elapsedTime is not supported by XPU backend.");
  }
};

} // namespace c10::xpu::impl
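
// Usage sketch (illustrative only, not part of this header): XPUGuardImpl is
// not normally instantiated directly; once registered for kXPU it backs the
// generic device-guard machinery. A minimal example, assuming an XPU build
// where this guard implementation is registered:
//
//   #include <c10/core/DeviceGuard.h>
//   #include <c10/xpu/XPUStream.h>
//
//   void example() {
//     // Switches the current device to XPU device 0 (dispatching to
//     // XPUGuardImpl::setDevice); the previous device is restored when
//     // `guard` goes out of scope.
//     c10::DeviceGuard guard(c10::Device(c10::kXPU, 0));
//     // Fetches the current XPU stream; device-agnostic callers would see
//     // it as the generic c10::Stream returned by XPUGuardImpl::getStream.
//     c10::xpu::XPUStream stream = c10::xpu::getCurrentXPUStream();
//   }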