#pragma once

#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/core/impl/GPUTrace.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>

#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAStream.h>

#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/Stream.h>
#include <c10/core/impl/PyInterpreter.h>
#include <cuda_runtime_api.h>
#include <cstdint>
#include <optional>

namespace c10::cuda::impl {

struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
  static constexpr DeviceType static_type = DeviceType::CUDA;

  CUDAGuardImpl() = default;
  explicit CUDAGuardImpl(DeviceType t) {
    TORCH_INTERNAL_ASSERT(t == DeviceType::CUDA);
  }
  DeviceType type() const override {
    return DeviceType::CUDA;
  }
  Device exchangeDevice(Device d) const override {
    TORCH_INTERNAL_ASSERT(d.is_cuda());
    auto old_device_index = c10::cuda::ExchangeDevice(d.index());
    return Device(DeviceType::CUDA, old_device_index);
  }
  Device getDevice() const override {
    DeviceIndex device = 0;
    C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
    return Device(DeviceType::CUDA, device);
  }
  std::optional<Device> uncheckedGetDevice() const noexcept {
    DeviceIndex device{-1};
    const auto err = C10_CUDA_ERROR_HANDLED(c10::cuda::GetDevice(&device));
    C10_CUDA_CHECK_WARN(err);
    if (err != cudaSuccess) {
      return std::nullopt;
    }
    return Device(DeviceType::CUDA, device);
  }
  void setDevice(Device d) const override {
    TORCH_INTERNAL_ASSERT(d.is_cuda());
    C10_CUDA_CHECK(c10::cuda::SetDevice(d.index()));
  }
  void uncheckedSetDevice(Device d) const noexcept override {
    C10_CUDA_CHECK_WARN(c10::cuda::MaybeSetDevice(d.index()));
  }
  Stream getStream(Device d) const noexcept override {
    return getCurrentCUDAStream(d.index()).unwrap();
  }
  Stream getDefaultStream(Device d) const override {
    return getDefaultCUDAStream(d.index());
  }
  Stream getNewStream(Device d, int priority = 0) const override {
    return getStreamFromPool(priority, d.index());
  }
  Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false)
      const override {
    return getStreamFromPool(isHighPriority, d.index());
  }
  // NB: These do NOT set the current device
  Stream exchangeStream(Stream s) const noexcept override {
    CUDAStream cs(s);
    auto old_stream = getCurrentCUDAStream(s.device().index());
    setCurrentCUDAStream(cs);
    return old_stream.unwrap();
  }
  DeviceIndex deviceCount() const noexcept override {
    return device_count();
  }
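
  // Illustrative sketch (not part of this header): how the device methods
  // above are typically paired. `impl` is an assumed local used only for
  // illustration; in practice c10::DeviceGuard drives this exchange/restore
  // pattern through the DeviceGuardImplInterface registry.
  //
  //   CUDAGuardImpl impl;
  //   Device prev = impl.exchangeDevice(Device(DeviceType::CUDA, 1));
  //   // ... launch work on device 1 ...
  //   impl.setDevice(prev); // restore the original device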

  // Event-related functions
  void createEvent(cudaEvent_t* cuda_event, const EventFlag flag) const {
    // Maps PyTorch's Event::Flag to CUDA flag
    auto cuda_flag = cudaEventDefault;
    switch (flag) {
      case EventFlag::PYTORCH_DEFAULT:
        cuda_flag = cudaEventDisableTiming;
        break;
      case EventFlag::BACKEND_DEFAULT:
        cuda_flag = cudaEventDefault;
        break;
      default:
        TORCH_CHECK(false, "CUDA event received unknown flag");
    }

    C10_CUDA_CHECK(cudaEventCreateWithFlags(cuda_event, cuda_flag));
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_creation(
          c10::kCUDA, reinterpret_cast<uintptr_t>(cuda_event));
    }
  }

  void destroyEvent(void* event, const DeviceIndex device_index)
      const noexcept override {
    if (!event)
      return;
    auto cuda_event = static_cast<cudaEvent_t>(event);
    DeviceIndex orig_device{-1};
    C10_CUDA_CHECK_WARN(c10::cuda::GetDevice(&orig_device));
    C10_CUDA_CHECK_WARN(c10::cuda::SetDevice(device_index));
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_deletion(
          c10::kCUDA, reinterpret_cast<uintptr_t>(cuda_event));
    }
    C10_CUDA_CHECK_WARN(cudaEventDestroy(cuda_event));
    C10_CUDA_CHECK_WARN(c10::cuda::SetDevice(orig_device));
  }

  void record(
      void** event,
      const Stream& stream,
      const DeviceIndex device_index,
      const EventFlag flag) const override {
    TORCH_CHECK(
        device_index == -1 || device_index == stream.device_index(),
        "Event device index ",
        device_index,
        " does not match recording stream's device index ",
        stream.device_index(),
        ".");

    cudaEvent_t cuda_event = static_cast<cudaEvent_t>(*event);
    CUDAStream cuda_stream{stream};

    // Moves to stream's device to record
    const auto orig_device = getDevice();
    setDevice(stream.device());

    // Creates the event (lazily)
    if (!cuda_event)
      createEvent(&cuda_event, flag);
    C10_CUDA_CHECK(cudaEventRecord(cuda_event, cuda_stream));
    // Makes the void* point to the (possibly just allocated) CUDA event
    *event = cuda_event;
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_record(
          c10::kCUDA,
          reinterpret_cast<uintptr_t>(cuda_event),
          reinterpret_cast<uintptr_t>(cuda_stream.stream()));
    }

    // Resets device
    setDevice(orig_device);
  }

  void block(void* event, const Stream& stream) const override {
    if (!event)
      return;
    cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event);
    CUDAStream cuda_stream{stream};
    const auto orig_device = getDevice();
    setDevice(stream.device());
    C10_CUDA_CHECK(cudaStreamWaitEvent(
        cuda_stream,
        cuda_event,
        /*flags (must be zero)=*/0));
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_wait(
          c10::kCUDA,
          reinterpret_cast<uintptr_t>(cuda_event),
          reinterpret_cast<uintptr_t>(cuda_stream.stream()));
    }
    setDevice(orig_device);
  }

  // May be called from any device
  bool queryEvent(void* event) const override {
    if (!event)
      return true;
    cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event);
    // Note: cudaEventQuery can be safely called from any device
    const cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaEventQuery(cuda_event));
    if (err != cudaErrorNotReady) {
      C10_CUDA_CHECK(err);
    } else {
      // ignore and clear the error if not ready
      (void)cudaGetLastError();
    }
    return (err == cudaSuccess);
  }

  // Stream-related functions
  bool queryStream(const Stream& stream) const override {
    CUDAStream cuda_stream{stream};
    return cuda_stream.query();
  }

  void synchronizeStream(const Stream& stream) const override {
    CUDAStream cuda_stream{stream};
    cuda_stream.synchronize();
  }
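
  // Illustrative sketch (not part of this header): the lifecycle of an opaque
  // event through record()/block() above. `impl`, `s1`, and `s2` are assumed
  // locals (two c10::Streams on the same device); the void* starts null and
  // record() allocates the cudaEvent_t lazily.
  //
  //   void* ev = nullptr;
  //   impl.record(&ev, s1, /*device_index=*/-1, EventFlag::PYTORCH_DEFAULT);
  //   impl.block(ev, s2); // s2 now waits for the work recorded on s1
  //   if (impl.queryEvent(ev)) { /* event already complete */ }
  //   impl.destroyEvent(ev, s1.device_index());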

  void synchronizeEvent(void* event) const override {
    if (!event)
      return;
    cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event);
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_synchronization(
          c10::kCUDA, reinterpret_cast<uintptr_t>(cuda_event));
    }
    // Note: cudaEventSynchronize can be safely called from any device
    C10_CUDA_CHECK(cudaEventSynchronize(cuda_event));
  }

  void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream)
      const override {
    CUDAStream cuda_stream{stream};
    CUDACachingAllocator::recordStream(data_ptr, cuda_stream);
  }

  double elapsedTime(void* event1, void* event2, const DeviceIndex device_index)
      const override {
    TORCH_CHECK(
        event1 && event2,
        "Both events must be recorded before calculating elapsed time.");
    // Even though cudaEventElapsedTime can be safely called from any device,
    // if the current device is not initialized it would create a new CUDA
    // context, which consumes a lot of memory.
    DeviceIndex orig_device{-1};
    C10_CUDA_CHECK(c10::cuda::GetDevice(&orig_device));
    C10_CUDA_CHECK(c10::cuda::SetDevice(device_index));
    cudaEvent_t cuda_event1 = static_cast<cudaEvent_t>(event1);
    cudaEvent_t cuda_event2 = static_cast<cudaEvent_t>(event2);
    float time_ms = 0;
    // Raises cudaErrorNotReady if either event has been recorded but has not
    // yet completed
    C10_CUDA_CHECK(cudaEventElapsedTime(&time_ms, cuda_event1, cuda_event2));
    C10_CUDA_CHECK(c10::cuda::SetDevice(orig_device));
    return static_cast<double>(time_ms);
  }
};

} // namespace c10::cuda::impl
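
// Illustrative sketch (not part of this header): timing a region with
// elapsedTime(). Note that EventFlag::PYTORCH_DEFAULT maps to
// cudaEventDisableTiming above, so events passed to elapsedTime() should be
// created with EventFlag::BACKEND_DEFAULT (timing enabled). `impl` and
// `stream` are assumed locals.
//
//   void* start = nullptr;
//   void* stop = nullptr;
//   impl.record(&start, stream, /*device_index=*/-1, EventFlag::BACKEND_DEFAULT);
//   // ... enqueue kernels on `stream` ...
//   impl.record(&stop, stream, /*device_index=*/-1, EventFlag::BACKEND_DEFAULT);
//   impl.synchronizeEvent(stop); // both events must have completed
//   double ms = impl.elapsedTime(start, stop, stream.device_index());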