1*da0073e9SAndroid Build Coastguard Worker #pragma once 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Worker #include <c10/core/impl/DeviceGuardImplInterface.h> 4*da0073e9SAndroid Build Coastguard Worker #include <c10/core/impl/GPUTrace.h> 5*da0073e9SAndroid Build Coastguard Worker #include <c10/macros/Macros.h> 6*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Exception.h> 7*da0073e9SAndroid Build Coastguard Worker 8*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/CUDACachingAllocator.h> 9*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/CUDAException.h> 10*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/CUDAFunctions.h> 11*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/CUDAStream.h> 12*da0073e9SAndroid Build Coastguard Worker 13*da0073e9SAndroid Build Coastguard Worker #include <c10/core/Device.h> 14*da0073e9SAndroid Build Coastguard Worker #include <c10/core/DeviceType.h> 15*da0073e9SAndroid Build Coastguard Worker #include <c10/core/Stream.h> 16*da0073e9SAndroid Build Coastguard Worker #include <c10/core/impl/PyInterpreter.h> 17*da0073e9SAndroid Build Coastguard Worker #include <cuda_runtime_api.h> 18*da0073e9SAndroid Build Coastguard Worker #include <cstdint> 19*da0073e9SAndroid Build Coastguard Worker #include <optional> 20*da0073e9SAndroid Build Coastguard Worker 21*da0073e9SAndroid Build Coastguard Worker namespace c10::cuda::impl { 22*da0073e9SAndroid Build Coastguard Worker 23*da0073e9SAndroid Build Coastguard Worker struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface { 24*da0073e9SAndroid Build Coastguard Worker static constexpr DeviceType static_type = DeviceType::CUDA; 25*da0073e9SAndroid Build Coastguard Worker 26*da0073e9SAndroid Build Coastguard Worker CUDAGuardImpl() = default; CUDAGuardImplfinal27*da0073e9SAndroid Build Coastguard Worker explicit CUDAGuardImpl(DeviceType t) { 28*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(t == DeviceType::CUDA); 29*da0073e9SAndroid Build Coastguard Worker } typefinal30*da0073e9SAndroid Build Coastguard Worker DeviceType type() const override { 31*da0073e9SAndroid Build Coastguard Worker return DeviceType::CUDA; 32*da0073e9SAndroid Build Coastguard Worker } exchangeDevicefinal33*da0073e9SAndroid Build Coastguard Worker Device exchangeDevice(Device d) const override { 34*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(d.is_cuda()); 35*da0073e9SAndroid Build Coastguard Worker auto old_device_index = c10::cuda::ExchangeDevice(d.index()); 36*da0073e9SAndroid Build Coastguard Worker return Device(DeviceType::CUDA, old_device_index); 37*da0073e9SAndroid Build Coastguard Worker } getDevicefinal38*da0073e9SAndroid Build Coastguard Worker Device getDevice() const override { 39*da0073e9SAndroid Build Coastguard Worker DeviceIndex device = 0; 40*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(c10::cuda::GetDevice(&device)); 41*da0073e9SAndroid Build Coastguard Worker return Device(DeviceType::CUDA, device); 42*da0073e9SAndroid Build Coastguard Worker } uncheckedGetDevicefinal43*da0073e9SAndroid Build Coastguard Worker std::optional<Device> uncheckedGetDevice() const noexcept { 44*da0073e9SAndroid Build Coastguard Worker DeviceIndex device{-1}; 45*da0073e9SAndroid Build Coastguard Worker const auto err = C10_CUDA_ERROR_HANDLED(c10::cuda::GetDevice(&device)); 46*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK_WARN(err); 47*da0073e9SAndroid Build Coastguard Worker if (err != cudaSuccess) { 48*da0073e9SAndroid Build Coastguard Worker return std::nullopt; 49*da0073e9SAndroid Build Coastguard Worker } 50*da0073e9SAndroid Build Coastguard Worker return Device(DeviceType::CUDA, device); 51*da0073e9SAndroid Build Coastguard Worker } setDevicefinal52*da0073e9SAndroid Build Coastguard Worker void setDevice(Device d) const override { 53*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(d.is_cuda()); 54*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(c10::cuda::SetDevice(d.index())); 55*da0073e9SAndroid Build Coastguard Worker } uncheckedSetDevicefinal56*da0073e9SAndroid Build Coastguard Worker void uncheckedSetDevice(Device d) const noexcept override { 57*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK_WARN(c10::cuda::MaybeSetDevice(d.index())); 58*da0073e9SAndroid Build Coastguard Worker } getStreamfinal59*da0073e9SAndroid Build Coastguard Worker Stream getStream(Device d) const noexcept override { 60*da0073e9SAndroid Build Coastguard Worker return getCurrentCUDAStream(d.index()).unwrap(); 61*da0073e9SAndroid Build Coastguard Worker } getDefaultStreamfinal62*da0073e9SAndroid Build Coastguard Worker Stream getDefaultStream(Device d) const override { 63*da0073e9SAndroid Build Coastguard Worker return getDefaultCUDAStream(d.index()); 64*da0073e9SAndroid Build Coastguard Worker } 65*da0073e9SAndroid Build Coastguard Worker Stream getNewStream(Device d, int priority = 0) const override { 66*da0073e9SAndroid Build Coastguard Worker return getStreamFromPool(priority, d.index()); 67*da0073e9SAndroid Build Coastguard Worker } 68*da0073e9SAndroid Build Coastguard Worker Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false) 69*da0073e9SAndroid Build Coastguard Worker const override { 70*da0073e9SAndroid Build Coastguard Worker return getStreamFromPool(isHighPriority, d.index()); 71*da0073e9SAndroid Build Coastguard Worker } 72*da0073e9SAndroid Build Coastguard Worker // NB: These do NOT set the current device exchangeStreamfinal73*da0073e9SAndroid Build Coastguard Worker Stream exchangeStream(Stream s) const noexcept override { 74*da0073e9SAndroid Build Coastguard Worker CUDAStream cs(s); 75*da0073e9SAndroid Build Coastguard Worker auto old_stream = getCurrentCUDAStream(s.device().index()); 76*da0073e9SAndroid Build Coastguard Worker setCurrentCUDAStream(cs); 77*da0073e9SAndroid Build Coastguard Worker return old_stream.unwrap(); 78*da0073e9SAndroid Build Coastguard Worker } deviceCountfinal79*da0073e9SAndroid Build Coastguard Worker DeviceIndex deviceCount() const noexcept override { 80*da0073e9SAndroid Build Coastguard Worker return device_count(); 81*da0073e9SAndroid Build Coastguard Worker } 82*da0073e9SAndroid Build Coastguard Worker 83*da0073e9SAndroid Build Coastguard Worker // Event-related functions createEventfinal84*da0073e9SAndroid Build Coastguard Worker void createEvent(cudaEvent_t* cuda_event, const EventFlag flag) const { 85*da0073e9SAndroid Build Coastguard Worker // Maps PyTorch's Event::Flag to CUDA flag 86*da0073e9SAndroid Build Coastguard Worker auto cuda_flag = cudaEventDefault; 87*da0073e9SAndroid Build Coastguard Worker switch (flag) { 88*da0073e9SAndroid Build Coastguard Worker case EventFlag::PYTORCH_DEFAULT: 89*da0073e9SAndroid Build Coastguard Worker cuda_flag = cudaEventDisableTiming; 90*da0073e9SAndroid Build Coastguard Worker break; 91*da0073e9SAndroid Build Coastguard Worker case EventFlag::BACKEND_DEFAULT: 92*da0073e9SAndroid Build Coastguard Worker cuda_flag = cudaEventDefault; 93*da0073e9SAndroid Build Coastguard Worker break; 94*da0073e9SAndroid Build Coastguard Worker default: 95*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(false, "CUDA event received unknown flag"); 96*da0073e9SAndroid Build Coastguard Worker } 97*da0073e9SAndroid Build Coastguard Worker 98*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(cudaEventCreateWithFlags(cuda_event, cuda_flag)); 99*da0073e9SAndroid Build Coastguard Worker const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); 100*da0073e9SAndroid Build Coastguard Worker if (C10_UNLIKELY(interp)) { 101*da0073e9SAndroid Build Coastguard Worker (*interp)->trace_gpu_event_creation( 102*da0073e9SAndroid Build Coastguard Worker c10::kCUDA, reinterpret_cast<uintptr_t>(cuda_event)); 103*da0073e9SAndroid Build Coastguard Worker } 104*da0073e9SAndroid Build Coastguard Worker } 105*da0073e9SAndroid Build Coastguard Worker destroyEventfinal106*da0073e9SAndroid Build Coastguard Worker void destroyEvent(void* event, const DeviceIndex device_index) 107*da0073e9SAndroid Build Coastguard Worker const noexcept override { 108*da0073e9SAndroid Build Coastguard Worker if (!event) 109*da0073e9SAndroid Build Coastguard Worker return; 110*da0073e9SAndroid Build Coastguard Worker auto cuda_event = static_cast<cudaEvent_t>(event); 111*da0073e9SAndroid Build Coastguard Worker DeviceIndex orig_device{-1}; 112*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK_WARN(c10::cuda::GetDevice(&orig_device)); 113*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK_WARN(c10::cuda::SetDevice(device_index)); 114*da0073e9SAndroid Build Coastguard Worker const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); 115*da0073e9SAndroid Build Coastguard Worker if (C10_UNLIKELY(interp)) { 116*da0073e9SAndroid Build Coastguard Worker (*interp)->trace_gpu_event_deletion( 117*da0073e9SAndroid Build Coastguard Worker c10::kCUDA, reinterpret_cast<uintptr_t>(cuda_event)); 118*da0073e9SAndroid Build Coastguard Worker } 119*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK_WARN(cudaEventDestroy(cuda_event)); 120*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK_WARN(c10::cuda::SetDevice(orig_device)); 121*da0073e9SAndroid Build Coastguard Worker } 122*da0073e9SAndroid Build Coastguard Worker recordfinal123*da0073e9SAndroid Build Coastguard Worker void record( 124*da0073e9SAndroid Build Coastguard Worker void** event, 125*da0073e9SAndroid Build Coastguard Worker const Stream& stream, 126*da0073e9SAndroid Build Coastguard Worker const DeviceIndex device_index, 127*da0073e9SAndroid Build Coastguard Worker const EventFlag flag) const override { 128*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK( 129*da0073e9SAndroid Build Coastguard Worker device_index == -1 || device_index == stream.device_index(), 130*da0073e9SAndroid Build Coastguard Worker "Event device index ", 131*da0073e9SAndroid Build Coastguard Worker device_index, 132*da0073e9SAndroid Build Coastguard Worker " does not match recording stream's device index ", 133*da0073e9SAndroid Build Coastguard Worker stream.device_index(), 134*da0073e9SAndroid Build Coastguard Worker "."); 135*da0073e9SAndroid Build Coastguard Worker 136*da0073e9SAndroid Build Coastguard Worker cudaEvent_t cuda_event = static_cast<cudaEvent_t>(*event); 137*da0073e9SAndroid Build Coastguard Worker CUDAStream cuda_stream{stream}; 138*da0073e9SAndroid Build Coastguard Worker 139*da0073e9SAndroid Build Coastguard Worker // Moves to stream's device to record 140*da0073e9SAndroid Build Coastguard Worker const auto orig_device = getDevice(); 141*da0073e9SAndroid Build Coastguard Worker setDevice(stream.device()); 142*da0073e9SAndroid Build Coastguard Worker 143*da0073e9SAndroid Build Coastguard Worker // Creates the event (lazily) 144*da0073e9SAndroid Build Coastguard Worker if (!cuda_event) 145*da0073e9SAndroid Build Coastguard Worker createEvent(&cuda_event, flag); 146*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(cudaEventRecord(cuda_event, cuda_stream)); 147*da0073e9SAndroid Build Coastguard Worker // Makes the void* point to the (possibly just allocated) CUDA event 148*da0073e9SAndroid Build Coastguard Worker *event = cuda_event; 149*da0073e9SAndroid Build Coastguard Worker const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); 150*da0073e9SAndroid Build Coastguard Worker if (C10_UNLIKELY(interp)) { 151*da0073e9SAndroid Build Coastguard Worker (*interp)->trace_gpu_event_record( 152*da0073e9SAndroid Build Coastguard Worker c10::kCUDA, 153*da0073e9SAndroid Build Coastguard Worker reinterpret_cast<uintptr_t>(cuda_event), 154*da0073e9SAndroid Build Coastguard Worker reinterpret_cast<uintptr_t>(cuda_stream.stream())); 155*da0073e9SAndroid Build Coastguard Worker } 156*da0073e9SAndroid Build Coastguard Worker 157*da0073e9SAndroid Build Coastguard Worker // Resets device 158*da0073e9SAndroid Build Coastguard Worker setDevice(orig_device); 159*da0073e9SAndroid Build Coastguard Worker } 160*da0073e9SAndroid Build Coastguard Worker blockfinal161*da0073e9SAndroid Build Coastguard Worker void block(void* event, const Stream& stream) const override { 162*da0073e9SAndroid Build Coastguard Worker if (!event) 163*da0073e9SAndroid Build Coastguard Worker return; 164*da0073e9SAndroid Build Coastguard Worker cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event); 165*da0073e9SAndroid Build Coastguard Worker CUDAStream cuda_stream{stream}; 166*da0073e9SAndroid Build Coastguard Worker const auto orig_device = getDevice(); 167*da0073e9SAndroid Build Coastguard Worker setDevice(stream.device()); 168*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(cudaStreamWaitEvent( 169*da0073e9SAndroid Build Coastguard Worker cuda_stream, 170*da0073e9SAndroid Build Coastguard Worker cuda_event, 171*da0073e9SAndroid Build Coastguard Worker /*flags (must be zero)=*/0)); 172*da0073e9SAndroid Build Coastguard Worker const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); 173*da0073e9SAndroid Build Coastguard Worker if (C10_UNLIKELY(interp)) { 174*da0073e9SAndroid Build Coastguard Worker (*interp)->trace_gpu_event_wait( 175*da0073e9SAndroid Build Coastguard Worker c10::kCUDA, 176*da0073e9SAndroid Build Coastguard Worker reinterpret_cast<uintptr_t>(cuda_event), 177*da0073e9SAndroid Build Coastguard Worker reinterpret_cast<uintptr_t>(cuda_stream.stream())); 178*da0073e9SAndroid Build Coastguard Worker } 179*da0073e9SAndroid Build Coastguard Worker setDevice(orig_device); 180*da0073e9SAndroid Build Coastguard Worker } 181*da0073e9SAndroid Build Coastguard Worker 182*da0073e9SAndroid Build Coastguard Worker // May be called from any device queryEventfinal183*da0073e9SAndroid Build Coastguard Worker bool queryEvent(void* event) const override { 184*da0073e9SAndroid Build Coastguard Worker if (!event) 185*da0073e9SAndroid Build Coastguard Worker return true; 186*da0073e9SAndroid Build Coastguard Worker cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event); 187*da0073e9SAndroid Build Coastguard Worker // Note: cudaEventQuery can be safely called from any device 188*da0073e9SAndroid Build Coastguard Worker const cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaEventQuery(cuda_event)); 189*da0073e9SAndroid Build Coastguard Worker if (err != cudaErrorNotReady) { 190*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(err); 191*da0073e9SAndroid Build Coastguard Worker } else { 192*da0073e9SAndroid Build Coastguard Worker // ignore and clear the error if not ready 193*da0073e9SAndroid Build Coastguard Worker (void)cudaGetLastError(); 194*da0073e9SAndroid Build Coastguard Worker } 195*da0073e9SAndroid Build Coastguard Worker return (err == cudaSuccess); 196*da0073e9SAndroid Build Coastguard Worker } 197*da0073e9SAndroid Build Coastguard Worker 198*da0073e9SAndroid Build Coastguard Worker // Stream-related functions queryStreamfinal199*da0073e9SAndroid Build Coastguard Worker bool queryStream(const Stream& stream) const override { 200*da0073e9SAndroid Build Coastguard Worker CUDAStream cuda_stream{stream}; 201*da0073e9SAndroid Build Coastguard Worker return cuda_stream.query(); 202*da0073e9SAndroid Build Coastguard Worker } 203*da0073e9SAndroid Build Coastguard Worker synchronizeStreamfinal204*da0073e9SAndroid Build Coastguard Worker void synchronizeStream(const Stream& stream) const override { 205*da0073e9SAndroid Build Coastguard Worker CUDAStream cuda_stream{stream}; 206*da0073e9SAndroid Build Coastguard Worker cuda_stream.synchronize(); 207*da0073e9SAndroid Build Coastguard Worker } 208*da0073e9SAndroid Build Coastguard Worker synchronizeEventfinal209*da0073e9SAndroid Build Coastguard Worker void synchronizeEvent(void* event) const override { 210*da0073e9SAndroid Build Coastguard Worker if (!event) 211*da0073e9SAndroid Build Coastguard Worker return; 212*da0073e9SAndroid Build Coastguard Worker cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event); 213*da0073e9SAndroid Build Coastguard Worker const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); 214*da0073e9SAndroid Build Coastguard Worker if (C10_UNLIKELY(interp)) { 215*da0073e9SAndroid Build Coastguard Worker (*interp)->trace_gpu_event_synchronization( 216*da0073e9SAndroid Build Coastguard Worker c10::kCUDA, reinterpret_cast<uintptr_t>(cuda_event)); 217*da0073e9SAndroid Build Coastguard Worker } 218*da0073e9SAndroid Build Coastguard Worker // Note: cudaEventSynchronize can be safely called from any device 219*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(cudaEventSynchronize(cuda_event)); 220*da0073e9SAndroid Build Coastguard Worker } 221*da0073e9SAndroid Build Coastguard Worker recordDataPtrOnStreamfinal222*da0073e9SAndroid Build Coastguard Worker void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream) 223*da0073e9SAndroid Build Coastguard Worker const override { 224*da0073e9SAndroid Build Coastguard Worker CUDAStream cuda_stream{stream}; 225*da0073e9SAndroid Build Coastguard Worker CUDACachingAllocator::recordStream(data_ptr, cuda_stream); 226*da0073e9SAndroid Build Coastguard Worker } 227*da0073e9SAndroid Build Coastguard Worker elapsedTimefinal228*da0073e9SAndroid Build Coastguard Worker double elapsedTime(void* event1, void* event2, const DeviceIndex device_index) 229*da0073e9SAndroid Build Coastguard Worker const override { 230*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK( 231*da0073e9SAndroid Build Coastguard Worker event1 && event2, 232*da0073e9SAndroid Build Coastguard Worker "Both events must be recorded before calculating elapsed time."); 233*da0073e9SAndroid Build Coastguard Worker // Even though cudaEventElapsedTime can be safely called from any device, if 234*da0073e9SAndroid Build Coastguard Worker // the current device is not initialized, it will create a new cuda context, 235*da0073e9SAndroid Build Coastguard Worker // which will consume a lot of memory. 236*da0073e9SAndroid Build Coastguard Worker DeviceIndex orig_device{-1}; 237*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(c10::cuda::GetDevice(&orig_device)); 238*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(c10::cuda::SetDevice(device_index)); 239*da0073e9SAndroid Build Coastguard Worker cudaEvent_t cuda_event1 = static_cast<cudaEvent_t>(event1); 240*da0073e9SAndroid Build Coastguard Worker cudaEvent_t cuda_event2 = static_cast<cudaEvent_t>(event2); 241*da0073e9SAndroid Build Coastguard Worker float time_ms = 0; 242*da0073e9SAndroid Build Coastguard Worker // raise cudaErrorNotReady if either event is recorded but not yet completed 243*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(cudaEventElapsedTime(&time_ms, cuda_event1, cuda_event2)); 244*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(c10::cuda::SetDevice(orig_device)); 245*da0073e9SAndroid Build Coastguard Worker return static_cast<double>(time_ms); 246*da0073e9SAndroid Build Coastguard Worker } 247*da0073e9SAndroid Build Coastguard Worker }; 248*da0073e9SAndroid Build Coastguard Worker 249*da0073e9SAndroid Build Coastguard Worker } // namespace c10::cuda::impl 250