xref: /aosp_15_r20/external/pytorch/c10/cuda/impl/CUDAGuardImpl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/impl/DeviceGuardImplInterface.h>
4 #include <c10/core/impl/GPUTrace.h>
5 #include <c10/macros/Macros.h>
6 #include <c10/util/Exception.h>
7 
8 #include <c10/cuda/CUDACachingAllocator.h>
9 #include <c10/cuda/CUDAException.h>
10 #include <c10/cuda/CUDAFunctions.h>
11 #include <c10/cuda/CUDAStream.h>
12 
13 #include <c10/core/Device.h>
14 #include <c10/core/DeviceType.h>
15 #include <c10/core/Stream.h>
16 #include <c10/core/impl/PyInterpreter.h>
17 #include <cuda_runtime_api.h>
18 #include <cstdint>
19 #include <optional>
20 
21 namespace c10::cuda::impl {
22 
// CUDA backend implementation of the generic DeviceGuardImplInterface.
// Provides device get/set/exchange, stream accessors, and event
// create/record/block/query/synchronize operations that the device-agnostic
// guard machinery (c10::impl::DeviceGuard, c10::Event, etc.) dispatches to.
struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
  // Device type this implementation services; used by the guard registry.
  static constexpr DeviceType static_type = DeviceType::CUDA;

  CUDAGuardImpl() = default;
  // Constructor used by the generic (type-erased) guard machinery; only
  // DeviceType::CUDA is a valid argument.
  explicit CUDAGuardImpl(DeviceType t) {
    TORCH_INTERNAL_ASSERT(t == DeviceType::CUDA);
  }
  DeviceType type() const override {
    return DeviceType::CUDA;
  }
  // Makes d current and returns the previously current CUDA device.
  Device exchangeDevice(Device d) const override {
    TORCH_INTERNAL_ASSERT(d.is_cuda());
    auto old_device_index = c10::cuda::ExchangeDevice(d.index());
    return Device(DeviceType::CUDA, old_device_index);
  }
  // Returns the current CUDA device; throws (via C10_CUDA_CHECK) on failure.
  Device getDevice() const override {
    DeviceIndex device = 0;
    C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
    return Device(DeviceType::CUDA, device);
  }
  // Non-throwing variant of getDevice(): warns and returns nullopt on any
  // CUDA error instead of raising.
  std::optional<Device> uncheckedGetDevice() const noexcept {
    DeviceIndex device{-1};
    const auto err = C10_CUDA_ERROR_HANDLED(c10::cuda::GetDevice(&device));
    C10_CUDA_CHECK_WARN(err);
    if (err != cudaSuccess) {
      return std::nullopt;
    }
    return Device(DeviceType::CUDA, device);
  }
  void setDevice(Device d) const override {
    TORCH_INTERNAL_ASSERT(d.is_cuda());
    C10_CUDA_CHECK(c10::cuda::SetDevice(d.index()));
  }
  // noexcept path: uses MaybeSetDevice and only warns on failure, since this
  // is called from guard destructors where throwing is not allowed.
  void uncheckedSetDevice(Device d) const noexcept override {
    C10_CUDA_CHECK_WARN(c10::cuda::MaybeSetDevice(d.index()));
  }
  // Returns the current (thread-local) CUDA stream for device d, as a
  // type-erased c10::Stream.
  Stream getStream(Device d) const noexcept override {
    return getCurrentCUDAStream(d.index()).unwrap();
  }
  Stream getDefaultStream(Device d) const override {
    return getDefaultCUDAStream(d.index());
  }
  // Returns a fresh stream from the pool at the given priority
  // (0 = default priority; semantics of other values are defined by
  // getStreamFromPool).
  Stream getNewStream(Device d, int priority = 0) const override {
    return getStreamFromPool(priority, d.index());
  }
  Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false)
      const override {
    return getStreamFromPool(isHighPriority, d.index());
  }
  // NB: These do NOT set the current device
  // Makes s the current stream for its device and returns the stream that
  // was previously current there.
  Stream exchangeStream(Stream s) const noexcept override {
    CUDAStream cs(s);
    auto old_stream = getCurrentCUDAStream(s.device().index());
    setCurrentCUDAStream(cs);
    return old_stream.unwrap();
  }
  DeviceIndex deviceCount() const noexcept override {
    return device_count();
  }

  // Event-related functions

  // Creates a CUDA event with flags derived from the PyTorch EventFlag and
  // notifies the Python GPU-trace hook (if one is installed).
  // Note: this is a helper, not an override (the event is created on the
  // CURRENT device; callers set the device first).
  void createEvent(cudaEvent_t* cuda_event, const EventFlag flag) const {
    // Maps PyTorch's Event::Flag to CUDA flag
    auto cuda_flag = cudaEventDefault;
    switch (flag) {
      case EventFlag::PYTORCH_DEFAULT:
        // PyTorch default disables timing for lower overhead.
        cuda_flag = cudaEventDisableTiming;
        break;
      case EventFlag::BACKEND_DEFAULT:
        cuda_flag = cudaEventDefault;
        break;
      default:
        TORCH_CHECK(false, "CUDA event received unknown flag");
    }

    C10_CUDA_CHECK(cudaEventCreateWithFlags(cuda_event, cuda_flag));
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_creation(
          c10::kCUDA, reinterpret_cast<uintptr_t>(cuda_event));
    }
  }

  // Destroys the event on its owning device, then restores the previously
  // current device. noexcept: all CUDA errors are downgraded to warnings.
  void destroyEvent(void* event, const DeviceIndex device_index)
      const noexcept override {
    if (!event)
      return;
    auto cuda_event = static_cast<cudaEvent_t>(event);
    DeviceIndex orig_device{-1};
    C10_CUDA_CHECK_WARN(c10::cuda::GetDevice(&orig_device));
    C10_CUDA_CHECK_WARN(c10::cuda::SetDevice(device_index));
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      // Trace deletion before the event handle is actually destroyed.
      (*interp)->trace_gpu_event_deletion(
          c10::kCUDA, reinterpret_cast<uintptr_t>(cuda_event));
    }
    C10_CUDA_CHECK_WARN(cudaEventDestroy(cuda_event));
    C10_CUDA_CHECK_WARN(c10::cuda::SetDevice(orig_device));
  }

  // Records *event on the given stream, lazily creating the event on first
  // use. device_index == -1 means "use the stream's device"; otherwise it
  // must match the stream's device.
  void record(
      void** event,
      const Stream& stream,
      const DeviceIndex device_index,
      const EventFlag flag) const override {
    TORCH_CHECK(
        device_index == -1 || device_index == stream.device_index(),
        "Event device index ",
        device_index,
        " does not match recording stream's device index ",
        stream.device_index(),
        ".");

    cudaEvent_t cuda_event = static_cast<cudaEvent_t>(*event);
    CUDAStream cuda_stream{stream};

    // Moves to stream's device to record
    const auto orig_device = getDevice();
    setDevice(stream.device());

    // Creates the event (lazily)
    if (!cuda_event)
      createEvent(&cuda_event, flag);
    C10_CUDA_CHECK(cudaEventRecord(cuda_event, cuda_stream));
    // Makes the void* point to the (possibly just allocated) CUDA event
    *event = cuda_event;
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_record(
          c10::kCUDA,
          reinterpret_cast<uintptr_t>(cuda_event),
          reinterpret_cast<uintptr_t>(cuda_stream.stream()));
    }

    // Resets device
    setDevice(orig_device);
  }

  // Makes the given stream wait (asynchronously, on-device) for the event.
  // A null event (never recorded) is a no-op.
  void block(void* event, const Stream& stream) const override {
    if (!event)
      return;
    cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event);
    CUDAStream cuda_stream{stream};
    const auto orig_device = getDevice();
    setDevice(stream.device());
    C10_CUDA_CHECK(cudaStreamWaitEvent(
        cuda_stream,
        cuda_event,
        /*flags (must be zero)=*/0));
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_wait(
          c10::kCUDA,
          reinterpret_cast<uintptr_t>(cuda_event),
          reinterpret_cast<uintptr_t>(cuda_stream.stream()));
    }
    setDevice(orig_device);
  }

  // May be called from any device
  // Returns true iff the event has completed (or was never recorded).
  bool queryEvent(void* event) const override {
    if (!event)
      return true;
    cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event);
    // Note: cudaEventQuery can be safely called from any device
    const cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaEventQuery(cuda_event));
    if (err != cudaErrorNotReady) {
      C10_CUDA_CHECK(err);
    } else {
      // ignore and clear the error if not ready
      // (cudaGetLastError() pops the sticky cudaErrorNotReady so later
      // unrelated CUDA calls don't observe it)
      (void)cudaGetLastError();
    }
    return (err == cudaSuccess);
  }

  // Stream-related functions
  bool queryStream(const Stream& stream) const override {
    CUDAStream cuda_stream{stream};
    return cuda_stream.query();
  }

  // Blocks the host until all work enqueued on the stream has completed.
  void synchronizeStream(const Stream& stream) const override {
    CUDAStream cuda_stream{stream};
    cuda_stream.synchronize();
  }

  // Blocks the host until the event completes; null event is a no-op.
  void synchronizeEvent(void* event) const override {
    if (!event)
      return;
    cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event);
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_synchronization(
          c10::kCUDA, reinterpret_cast<uintptr_t>(cuda_event));
    }
    // Note: cudaEventSynchronize can be safely called from any device
    C10_CUDA_CHECK(cudaEventSynchronize(cuda_event));
  }

  // Informs the caching allocator that data_ptr is used on this stream, so
  // its memory is not reused until the stream's pending work finishes.
  void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream)
      const override {
    CUDAStream cuda_stream{stream};
    CUDACachingAllocator::recordStream(data_ptr, cuda_stream);
  }

  // Returns the elapsed time in milliseconds between two recorded events,
  // switching to device_index first (see note below) and restoring the
  // original device afterwards. Throws if either event pointer is null or
  // if an event has not yet completed (cudaErrorNotReady).
  double elapsedTime(void* event1, void* event2, const DeviceIndex device_index)
      const override {
    TORCH_CHECK(
        event1 && event2,
        "Both events must be recorded before calculating elapsed time.");
    // Even though cudaEventElapsedTime can be safely called from any device, if
    // the current device is not initialized, it will create a new cuda context,
    // which will consume a lot of memory.
    DeviceIndex orig_device{-1};
    C10_CUDA_CHECK(c10::cuda::GetDevice(&orig_device));
    C10_CUDA_CHECK(c10::cuda::SetDevice(device_index));
    cudaEvent_t cuda_event1 = static_cast<cudaEvent_t>(event1);
    cudaEvent_t cuda_event2 = static_cast<cudaEvent_t>(event2);
    float time_ms = 0;
    // raise cudaErrorNotReady if either event is recorded but not yet completed
    C10_CUDA_CHECK(cudaEventElapsedTime(&time_ms, cuda_event1, cuda_event2));
    C10_CUDA_CHECK(c10::cuda::SetDevice(orig_device));
    return static_cast<double>(time_ms);
  }
};
248 
249 } // namespace c10::cuda::impl
250