1*da0073e9SAndroid Build Coastguard Worker #pragma once 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Worker #include <c10/core/DeviceGuard.h> 4*da0073e9SAndroid Build Coastguard Worker #include <c10/core/impl/DeviceGuardImplInterface.h> 5*da0073e9SAndroid Build Coastguard Worker #include <c10/core/impl/GPUTrace.h> 6*da0073e9SAndroid Build Coastguard Worker #include <c10/xpu/XPUCachingAllocator.h> 7*da0073e9SAndroid Build Coastguard Worker #include <c10/xpu/XPUFunctions.h> 8*da0073e9SAndroid Build Coastguard Worker #include <c10/xpu/XPUStream.h> 9*da0073e9SAndroid Build Coastguard Worker 10*da0073e9SAndroid Build Coastguard Worker #include <vector> 11*da0073e9SAndroid Build Coastguard Worker 12*da0073e9SAndroid Build Coastguard Worker namespace c10::xpu::impl { 13*da0073e9SAndroid Build Coastguard Worker 14*da0073e9SAndroid Build Coastguard Worker struct XPUGuardImpl final : public c10::impl::DeviceGuardImplInterface { 15*da0073e9SAndroid Build Coastguard Worker static constexpr DeviceType static_type = kXPU; 16*da0073e9SAndroid Build Coastguard Worker 17*da0073e9SAndroid Build Coastguard Worker XPUGuardImpl() = default; 18*da0073e9SAndroid Build Coastguard Worker XPUGuardImplfinal19*da0073e9SAndroid Build Coastguard Worker explicit XPUGuardImpl(DeviceType t) { 20*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(t == kXPU); 21*da0073e9SAndroid Build Coastguard Worker } 22*da0073e9SAndroid Build Coastguard Worker typefinal23*da0073e9SAndroid Build Coastguard Worker DeviceType type() const override { 24*da0073e9SAndroid Build Coastguard Worker return kXPU; 25*da0073e9SAndroid Build Coastguard Worker } 26*da0073e9SAndroid Build Coastguard Worker exchangeDevicefinal27*da0073e9SAndroid Build Coastguard Worker Device exchangeDevice(Device d) const override { 28*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(d.is_xpu()); 29*da0073e9SAndroid Build Coastguard Worker const auto old_device_index = c10::xpu::exchange_device(d.index()); 30*da0073e9SAndroid Build Coastguard Worker return Device(kXPU, old_device_index); 31*da0073e9SAndroid Build Coastguard Worker } 32*da0073e9SAndroid Build Coastguard Worker getDevicefinal33*da0073e9SAndroid Build Coastguard Worker Device getDevice() const override { 34*da0073e9SAndroid Build Coastguard Worker const auto device = c10::xpu::current_device(); 35*da0073e9SAndroid Build Coastguard Worker return Device(kXPU, device); 36*da0073e9SAndroid Build Coastguard Worker } 37*da0073e9SAndroid Build Coastguard Worker setDevicefinal38*da0073e9SAndroid Build Coastguard Worker void setDevice(Device d) const override { 39*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(d.is_xpu()); 40*da0073e9SAndroid Build Coastguard Worker c10::xpu::set_device(d.index()); 41*da0073e9SAndroid Build Coastguard Worker } 42*da0073e9SAndroid Build Coastguard Worker uncheckedSetDevicefinal43*da0073e9SAndroid Build Coastguard Worker void uncheckedSetDevice(Device d) const noexcept override { 44*da0073e9SAndroid Build Coastguard Worker c10::xpu::set_device(d.index()); 45*da0073e9SAndroid Build Coastguard Worker } 46*da0073e9SAndroid Build Coastguard Worker getStreamfinal47*da0073e9SAndroid Build Coastguard Worker Stream getStream(Device d) const noexcept override { 48*da0073e9SAndroid Build Coastguard Worker return getCurrentXPUStream(d.index()).unwrap(); 49*da0073e9SAndroid Build Coastguard Worker } 50*da0073e9SAndroid Build Coastguard Worker 51*da0073e9SAndroid Build Coastguard Worker Stream getNewStream(Device d, int priority = 0) const override { 52*da0073e9SAndroid Build Coastguard Worker return getStreamFromPool(priority, d.index()); 53*da0073e9SAndroid Build Coastguard Worker } 54*da0073e9SAndroid Build Coastguard Worker 55*da0073e9SAndroid Build Coastguard Worker Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false) 56*da0073e9SAndroid Build Coastguard Worker const override { 57*da0073e9SAndroid Build Coastguard Worker return getStreamFromPool(isHighPriority, d.index()); 58*da0073e9SAndroid Build Coastguard Worker } 59*da0073e9SAndroid Build Coastguard Worker 60*da0073e9SAndroid Build Coastguard Worker // NB: These do NOT set the current device exchangeStreamfinal61*da0073e9SAndroid Build Coastguard Worker Stream exchangeStream(Stream s) const noexcept override { 62*da0073e9SAndroid Build Coastguard Worker const XPUStream stream(s); 63*da0073e9SAndroid Build Coastguard Worker const auto old_stream = getCurrentXPUStream(s.device().index()); 64*da0073e9SAndroid Build Coastguard Worker setCurrentXPUStream(stream); 65*da0073e9SAndroid Build Coastguard Worker return old_stream.unwrap(); 66*da0073e9SAndroid Build Coastguard Worker } 67*da0073e9SAndroid Build Coastguard Worker deviceCountfinal68*da0073e9SAndroid Build Coastguard Worker DeviceIndex deviceCount() const noexcept override { 69*da0073e9SAndroid Build Coastguard Worker return c10::xpu::device_count(); 70*da0073e9SAndroid Build Coastguard Worker } 71*da0073e9SAndroid Build Coastguard Worker 72*da0073e9SAndroid Build Coastguard Worker // Event-related functions destroyEventfinal73*da0073e9SAndroid Build Coastguard Worker void destroyEvent(void* event, const DeviceIndex device_index) 74*da0073e9SAndroid Build Coastguard Worker const noexcept override { 75*da0073e9SAndroid Build Coastguard Worker if (!event) 76*da0073e9SAndroid Build Coastguard Worker return; 77*da0073e9SAndroid Build Coastguard Worker 78*da0073e9SAndroid Build Coastguard Worker const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); 79*da0073e9SAndroid Build Coastguard Worker if (C10_UNLIKELY(interp)) { 80*da0073e9SAndroid Build Coastguard Worker (*interp)->trace_gpu_event_deletion( 81*da0073e9SAndroid Build Coastguard Worker c10::kXPU, reinterpret_cast<uintptr_t>(event)); 82*da0073e9SAndroid Build Coastguard Worker } 83*da0073e9SAndroid Build Coastguard Worker 84*da0073e9SAndroid Build Coastguard Worker delete reinterpret_cast<sycl::event*>(event); 85*da0073e9SAndroid Build Coastguard Worker } 86*da0073e9SAndroid Build Coastguard Worker recordfinal87*da0073e9SAndroid Build Coastguard Worker void record( 88*da0073e9SAndroid Build Coastguard Worker void** event, 89*da0073e9SAndroid Build Coastguard Worker const Stream& stream, 90*da0073e9SAndroid Build Coastguard Worker const DeviceIndex device_index, 91*da0073e9SAndroid Build Coastguard Worker const EventFlag flag) const override { 92*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK( 93*da0073e9SAndroid Build Coastguard Worker device_index == -1 || device_index == stream.device_index(), 94*da0073e9SAndroid Build Coastguard Worker "Event device index ", 95*da0073e9SAndroid Build Coastguard Worker device_index, 96*da0073e9SAndroid Build Coastguard Worker " does not match recording stream's device index ", 97*da0073e9SAndroid Build Coastguard Worker stream.device_index(), 98*da0073e9SAndroid Build Coastguard Worker "."); 99*da0073e9SAndroid Build Coastguard Worker 100*da0073e9SAndroid Build Coastguard Worker auto* xpu_event = reinterpret_cast<sycl::event*>(*event); 101*da0073e9SAndroid Build Coastguard Worker const XPUStream xpu_stream{stream}; 102*da0073e9SAndroid Build Coastguard Worker 103*da0073e9SAndroid Build Coastguard Worker // Delete the event previously recorded. 104*da0073e9SAndroid Build Coastguard Worker if (xpu_event) 105*da0073e9SAndroid Build Coastguard Worker delete xpu_event; 106*da0073e9SAndroid Build Coastguard Worker xpu_event = new sycl::event(xpu_stream.queue().ext_oneapi_submit_barrier()); 107*da0073e9SAndroid Build Coastguard Worker *event = reinterpret_cast<void*>(xpu_event); 108*da0073e9SAndroid Build Coastguard Worker 109*da0073e9SAndroid Build Coastguard Worker const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); 110*da0073e9SAndroid Build Coastguard Worker if (C10_UNLIKELY(interp)) { 111*da0073e9SAndroid Build Coastguard Worker (*interp)->trace_gpu_event_record( 112*da0073e9SAndroid Build Coastguard Worker c10::kXPU, 113*da0073e9SAndroid Build Coastguard Worker reinterpret_cast<uintptr_t>(xpu_event), 114*da0073e9SAndroid Build Coastguard Worker reinterpret_cast<uintptr_t>(&xpu_stream.queue())); 115*da0073e9SAndroid Build Coastguard Worker } 116*da0073e9SAndroid Build Coastguard Worker } 117*da0073e9SAndroid Build Coastguard Worker blockfinal118*da0073e9SAndroid Build Coastguard Worker void block(void* event, const Stream& stream) const override { 119*da0073e9SAndroid Build Coastguard Worker if (!event) 120*da0073e9SAndroid Build Coastguard Worker return; 121*da0073e9SAndroid Build Coastguard Worker auto* xpu_event = reinterpret_cast<sycl::event*>(event); 122*da0073e9SAndroid Build Coastguard Worker std::vector<sycl::event> event_list{*xpu_event}; 123*da0073e9SAndroid Build Coastguard Worker const XPUStream xpu_stream(stream); 124*da0073e9SAndroid Build Coastguard Worker xpu_stream.queue().ext_oneapi_submit_barrier(event_list); 125*da0073e9SAndroid Build Coastguard Worker const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); 126*da0073e9SAndroid Build Coastguard Worker if (C10_UNLIKELY(interp)) { 127*da0073e9SAndroid Build Coastguard Worker (*interp)->trace_gpu_event_wait( 128*da0073e9SAndroid Build Coastguard Worker c10::kXPU, 129*da0073e9SAndroid Build Coastguard Worker reinterpret_cast<uintptr_t>(xpu_event), 130*da0073e9SAndroid Build Coastguard Worker reinterpret_cast<uintptr_t>(&xpu_stream.queue())); 131*da0073e9SAndroid Build Coastguard Worker } 132*da0073e9SAndroid Build Coastguard Worker } 133*da0073e9SAndroid Build Coastguard Worker queryEventfinal134*da0073e9SAndroid Build Coastguard Worker bool queryEvent(void* event) const override { 135*da0073e9SAndroid Build Coastguard Worker using namespace sycl::info; 136*da0073e9SAndroid Build Coastguard Worker if (!event) 137*da0073e9SAndroid Build Coastguard Worker return true; 138*da0073e9SAndroid Build Coastguard Worker auto* xpu_event = reinterpret_cast<sycl::event*>(event); 139*da0073e9SAndroid Build Coastguard Worker return xpu_event->get_info<event::command_execution_status>() == 140*da0073e9SAndroid Build Coastguard Worker event_command_status::complete; 141*da0073e9SAndroid Build Coastguard Worker } 142*da0073e9SAndroid Build Coastguard Worker 143*da0073e9SAndroid Build Coastguard Worker // Stream-related functions queryStreamfinal144*da0073e9SAndroid Build Coastguard Worker bool queryStream(const Stream& stream) const override { 145*da0073e9SAndroid Build Coastguard Worker const XPUStream xpu_stream{stream}; 146*da0073e9SAndroid Build Coastguard Worker return xpu_stream.query(); 147*da0073e9SAndroid Build Coastguard Worker } 148*da0073e9SAndroid Build Coastguard Worker synchronizeStreamfinal149*da0073e9SAndroid Build Coastguard Worker void synchronizeStream(const Stream& stream) const override { 150*da0073e9SAndroid Build Coastguard Worker const XPUStream xpu_stream{stream}; 151*da0073e9SAndroid Build Coastguard Worker xpu_stream.synchronize(); 152*da0073e9SAndroid Build Coastguard Worker } 153*da0073e9SAndroid Build Coastguard Worker synchronizeEventfinal154*da0073e9SAndroid Build Coastguard Worker void synchronizeEvent(void* event) const override { 155*da0073e9SAndroid Build Coastguard Worker if (!event) 156*da0073e9SAndroid Build Coastguard Worker return; 157*da0073e9SAndroid Build Coastguard Worker auto* xpu_event = reinterpret_cast<sycl::event*>(event); 158*da0073e9SAndroid Build Coastguard Worker const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); 159*da0073e9SAndroid Build Coastguard Worker if (C10_UNLIKELY(interp)) { 160*da0073e9SAndroid Build Coastguard Worker (*interp)->trace_gpu_event_synchronization( 161*da0073e9SAndroid Build Coastguard Worker c10::kXPU, reinterpret_cast<uintptr_t>(xpu_event)); 162*da0073e9SAndroid Build Coastguard Worker } 163*da0073e9SAndroid Build Coastguard Worker xpu_event->wait_and_throw(); 164*da0073e9SAndroid Build Coastguard Worker } 165*da0073e9SAndroid Build Coastguard Worker recordDataPtrOnStreamfinal166*da0073e9SAndroid Build Coastguard Worker void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream) 167*da0073e9SAndroid Build Coastguard Worker const override { 168*da0073e9SAndroid Build Coastguard Worker const XPUStream xpu_stream{stream}; 169*da0073e9SAndroid Build Coastguard Worker XPUCachingAllocator::recordStream(data_ptr, xpu_stream); 170*da0073e9SAndroid Build Coastguard Worker } 171*da0073e9SAndroid Build Coastguard Worker elapsedTimefinal172*da0073e9SAndroid Build Coastguard Worker double elapsedTime(void* event1, void* event2, const DeviceIndex device_index) 173*da0073e9SAndroid Build Coastguard Worker const override { 174*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK_NOT_IMPLEMENTED( 175*da0073e9SAndroid Build Coastguard Worker false, "elapsedTime is not supported by XPU backend."); 176*da0073e9SAndroid Build Coastguard Worker } 177*da0073e9SAndroid Build Coastguard Worker }; 178*da0073e9SAndroid Build Coastguard Worker 179*da0073e9SAndroid Build Coastguard Worker } // namespace c10::xpu::impl 180