#include <c10/cuda/CUDAFunctions.h>
#include <c10/macros/Macros.h>

#include <limits>

namespace c10::cuda {

namespace {
// returns -1 on failure
int32_t driver_version() {
  int driver_version = -1;
  C10_CUDA_IGNORE_ERROR(cudaDriverGetVersion(&driver_version));
  return driver_version;
}

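// Queries the raw device count, translating benign failures into a count of
// zero (no devices; missing driver when fail_if_no_driver is false) and
// raising a descriptive error otherwise.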
int device_count_impl(bool fail_if_no_driver) {
  int count = 0;
  auto err = C10_CUDA_ERROR_HANDLED(c10::cuda::GetDeviceCount(&count));
  if (err == cudaSuccess) {
    return count;
  }
  // Clear out the error state, so we don't spuriously trigger someone else.
  // (This shouldn't really matter, since we won't be running very much CUDA
  // code in this regime.)
  cudaError_t last_err C10_UNUSED = cudaGetLastError();
  switch (err) {
    case cudaErrorNoDevice:
      // Zero devices is ok here
      count = 0;
      break;
    case cudaErrorInsufficientDriver: {
      auto version = driver_version();
      if (version <= 0) {
        if (!fail_if_no_driver) {
          // No CUDA driver means no devices
          count = 0;
          break;
        }
        TORCH_CHECK(
            false,
            "Found no NVIDIA driver on your system. Please check that you "
            "have an NVIDIA GPU and installed a driver from "
            "http://www.nvidia.com/Download/index.aspx");
      } else {
        TORCH_CHECK(
            false,
            "The NVIDIA driver on your system is too old (found version ",
            version,
            "). Please update your GPU driver by downloading and installing "
            "a new version from the URL: "
            "http://www.nvidia.com/Download/index.aspx Alternatively, go to: "
            "https://pytorch.org to install a PyTorch version that has been "
            "compiled with your version of the CUDA driver.");
      }
    } break;
    case cudaErrorInitializationError:
      TORCH_CHECK(
          false,
          "CUDA driver initialization failed, you might not "
          "have a CUDA GPU.");
      break;
    case cudaErrorUnknown:
      TORCH_CHECK(
          false,
          "CUDA unknown error - this may be due to an "
          "incorrectly set up environment, e.g. changing env "
          "variable CUDA_VISIBLE_DEVICES after program start. "
          "Setting the available devices to be zero.");
      break;
#if C10_ASAN_ENABLED
    case cudaErrorMemoryAllocation:
      // In ASAN mode, we know that a cudaErrorMemoryAllocation error will
      // pop up if compiled with NVCC (clang-cuda is fine)
      TORCH_CHECK(
          false,
          "Got 'out of memory' error while trying to initialize CUDA. "
          "CUDA with nvcc does not work well with ASAN and it's probably "
          "the reason. We will simply shut down CUDA support. If you "
          "would like to use GPUs, turn off ASAN.");
      break;
#endif // C10_ASAN_ENABLED
    default:
      TORCH_CHECK(
          false,
          "Unexpected error from cudaGetDeviceCount(). Did you run "
          "some cuda functions before calling NumCudaDevices() "
          "that might have already set an error? Error ",
          err,
          ": ",
          cudaGetErrorString(err));
  }
  return count;
}
} // namespace

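// Returns the number of CUDA devices, computed once and cached. Never
// throws: initialization failures are downgraded to a warning and a count
// of zero.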
DeviceIndex device_count() noexcept {
  // initialize number of devices only once
  static int count = []() {
    try {
      auto result = device_count_impl(/*fail_if_no_driver=*/false);
      TORCH_INTERNAL_ASSERT(
          result <= std::numeric_limits<DeviceIndex>::max(),
          "Too many CUDA devices, DeviceIndex overflowed");
      return result;
    } catch (const c10::Error& ex) {
      // We don't want to fail, but still log the warning
      // msg() returns the message without the stack trace
      TORCH_WARN("CUDA initialization: ", ex.msg());
      return 0;
    }
  }();
  return static_cast<DeviceIndex>(count);
}

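// Unlike device_count(), re-runs the query on every call so that a missing
// driver or a zero device count raises an exception instead of returning 0.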
DeviceIndex device_count_ensure_non_zero() {
  // Call the implementation every time to throw the exception
  int count = device_count_impl(/*fail_if_no_driver=*/true);
  // A zero count doesn't produce a warning in `device_count`, but we fail here
  TORCH_CHECK(count, "No CUDA GPUs are available");
  TORCH_INTERNAL_ASSERT(
      count <= std::numeric_limits<DeviceIndex>::max(),
      "Too many CUDA devices, DeviceIndex overflowed");
  return static_cast<DeviceIndex>(count);
}

DeviceIndex current_device() {
  DeviceIndex cur_device = -1;
  C10_CUDA_CHECK(c10::cuda::GetDevice(&cur_device));
  return cur_device;
}

void set_device(DeviceIndex device) {
  C10_CUDA_CHECK(c10::cuda::SetDevice(device));
}

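// Blocks until the current device has finished all preceding work, notifying
// the Python GPU trace hook (if one is registered) first.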
void device_synchronize() {
  const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
  if (C10_UNLIKELY(interp)) {
    (*interp)->trace_gpu_device_synchronization(c10::kCUDA);
  }
  C10_CUDA_CHECK(cudaDeviceSynchronize());
}

// This function must be called by callers that perform synchronizing CUDA
// operations, so that the proper error or warning is raised.
void warn_or_error_on_sync() {
  if (warning_state().get_sync_debug_mode() == SyncDebugMode::L_ERROR) {
    TORCH_CHECK(false, "called a synchronizing CUDA operation");
  } else if (warning_state().get_sync_debug_mode() == SyncDebugMode::L_WARN) {
    TORCH_WARN("called a synchronizing CUDA operation");
  }
}

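// Returns the index of a device that already has a primary context,
// preferring the current device; std::nullopt if no such device exists.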
std::optional<DeviceIndex> getDeviceIndexWithPrimaryContext() {
  // check current device first
  auto current_device_index = current_device();
  if (current_device_index >= 0) {
    if (hasPrimaryContext(current_device_index)) {
      return current_device_index;
    }
  }
  for (const auto device_index : c10::irange(at::cuda::device_count())) {
    if (device_index == current_device_index)
      continue;
    if (hasPrimaryContext(device_index)) {
      return device_index;
    }
  }
  return std::nullopt;
}

namespace _internal {
bool dummyHasPrimaryContext(C10_UNUSED DeviceIndex device_index) {
  TORCH_CHECK(false, "Should never be called");
}
bool (*hasPrimaryContext)(DeviceIndex) = dummyHasPrimaryContext;

// Private api to be called from CUDAHooks.cpp
C10_CUDA_API void setHasPrimaryContext(bool (*func)(DeviceIndex)) {
  hasPrimaryContext = func ? func : dummyHasPrimaryContext;
}
} // namespace _internal

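// Dispatches through the pointer installed by
// _internal::setHasPrimaryContext (wired up from CUDAHooks.cpp); the dummy
// default fails loudly if it is ever called before that happens.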
bool hasPrimaryContext(DeviceIndex device_index) {
  return _internal::hasPrimaryContext(device_index);
}

// Wrappers for raw CUDA device management functions
cudaError_t GetDeviceCount(int* dev_count) {
  return cudaGetDeviceCount(dev_count);
}

// This is a codepath for CUDA 12 that comes with a critical change in the
// behavior of `cudaSetDevice`. Unlike previous CUDA versions, which allocate
// the context lazily, CUDA 12.x eagerly allocates the primary context the
// moment `cudaSetDevice` is called. This can have dramatic consequences and
// pollute device memory in distributed runs. To avoid unnecessary context
// creation, a new function called `MaybeSetDevice` was introduced. This
// function is to be called in the device guard destructor and at the exit of
// the torch.cuda.device context manager. The behavior of `MaybeSetDevice` is
// quite simple: it calls `cudaSetDevice` if a context already exists on the
// targeted device; otherwise it simply saves the device index. This way we
// can keep PyTorch backward compatible for applications like this:
//
// ```
// import torch
// x = torch.empty(1, device="cuda:1") # no CUDA context on cuda:0 after this call
// y = torch.empty(1, device="cuda")   # CUDA context is created on cuda:0
// ```
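//
// For exposition only (this is not the real c10::cuda::CUDAGuard
// implementation), a minimal RAII guard built on these primitives might look
// like the sketch below; `IllustrativeGuard` is a hypothetical name:
//
// ```
// struct IllustrativeGuard { // hypothetical, for exposition only
//   explicit IllustrativeGuard(DeviceIndex idx)
//       : prev_(ExchangeDevice(idx)) {} // switch, remembering old device
//   ~IllustrativeGuard() {
//     // Restore lazily: creates no context on prev_ if none exists yet.
//     C10_CUDA_CHECK(MaybeSetDevice(prev_));
//   }
//   DeviceIndex prev_;
// };
// ```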
#if CUDA_VERSION >= 12000
thread_local DeviceIndex targetDeviceIndex = -1;

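// Returns the pending target device recorded by MaybeSetDevice or
// MaybeExchangeDevice, if any; otherwise queries cudaGetDevice.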
cudaError_t GetDevice(DeviceIndex* device) {
  if (targetDeviceIndex >= 0) {
    *device = targetDeviceIndex;
    return cudaSuccess;
  }
  int tmp_device = -1;
  auto err = cudaGetDevice(&tmp_device);
  if (err == cudaSuccess) {
    TORCH_INTERNAL_ASSERT(
        tmp_device >= 0 &&
            tmp_device <= std::numeric_limits<DeviceIndex>::max(),
        "cudaGetDevice returns invalid device ",
        tmp_device);
    *device = static_cast<DeviceIndex>(tmp_device);
  }
  return err;
}

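// Unconditionally switches to `device`, clearing any pending target index.
// The cudaSetDevice call is skipped when `device` is already current, since
// on CUDA 12 it would eagerly create the primary context.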
cudaError_t SetDevice(DeviceIndex device) {
  TORCH_CHECK(device >= 0, "device id must be non-negative, got ", device);
  targetDeviceIndex = -1;
  int cur_device = -1;
  C10_CUDA_CHECK(cudaGetDevice(&cur_device));
  if (device == cur_device) {
    return cudaSuccess;
  }
  return cudaSetDevice(device);
}

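// Switches to `device` only if its primary context already exists;
// otherwise just records the index in targetDeviceIndex, creating no
// context.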
cudaError_t MaybeSetDevice(DeviceIndex device) {
  if (hasPrimaryContext(device)) {
    return c10::cuda::SetDevice(device);
  }
  targetDeviceIndex = device;
  return cudaSuccess;
}

// This function always initializes the CUDA context
// on to_device
DeviceIndex ExchangeDevice(DeviceIndex to_device) {
  auto cur_device = targetDeviceIndex;
  targetDeviceIndex = -1;
  if (cur_device < 0) {
    int tmp_device = -1;
    C10_CUDA_CHECK(cudaGetDevice(&tmp_device));
    cur_device = static_cast<DeviceIndex>(tmp_device);
    if (to_device == cur_device) {
      return cur_device;
    }
  }
  C10_CUDA_CHECK(cudaSetDevice(to_device));
  return cur_device;
}

// This function does not initialize the CUDA context
// on to_device if it does not already exist
DeviceIndex MaybeExchangeDevice(DeviceIndex to_device) {
  int tmp_cur_device = -1;
  C10_CUDA_CHECK(cudaGetDevice(&tmp_cur_device));
  TORCH_INTERNAL_ASSERT(
      tmp_cur_device >= 0 &&
          tmp_cur_device <= std::numeric_limits<DeviceIndex>::max(),
      "cudaGetDevice returns invalid device ",
      tmp_cur_device);
  auto cur_device = static_cast<DeviceIndex>(tmp_cur_device);
  if (to_device == tmp_cur_device) {
    return cur_device;
  }
  if (hasPrimaryContext(to_device)) {
    C10_CUDA_CHECK(cudaSetDevice(to_device));
  } else {
    targetDeviceIndex = to_device;
  }
  return cur_device;
}

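// Applies a pending device change recorded by MaybeSetDevice or
// MaybeExchangeDevice, if there is one.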
void SetTargetDevice() {
  if (targetDeviceIndex >= 0) {
    C10_CUDA_CHECK(c10::cuda::SetDevice(targetDeviceIndex));
  }
}
#else
cudaError_t GetDevice(DeviceIndex* device) {
  int tmp_device = -1;
  auto err = cudaGetDevice(&tmp_device);
  if (err == cudaSuccess) {
    TORCH_INTERNAL_ASSERT(
        tmp_device >= 0 &&
            tmp_device <= std::numeric_limits<DeviceIndex>::max(),
        "cudaGetDevice returns invalid device ",
        tmp_device);
    *device = static_cast<DeviceIndex>(tmp_device);
  }
  return err;
}

cudaError_t SetDevice(DeviceIndex device) {
  TORCH_CHECK(device >= 0, "device id must be non-negative, got ", device);
  int cur_device = -1;
  C10_CUDA_CHECK(cudaGetDevice(&cur_device));
  if (device == cur_device) {
    return cudaSuccess;
  }
  return cudaSetDevice(device);
}

cudaError_t MaybeSetDevice(DeviceIndex device) {
  return c10::cuda::SetDevice(device);
}

DeviceIndex ExchangeDevice(DeviceIndex to_device) {
  DeviceIndex cur_device = -1;
  C10_CUDA_CHECK(c10::cuda::GetDevice(&cur_device));
  if (to_device == cur_device) {
    return cur_device;
  }
  C10_CUDA_CHECK(cudaSetDevice(to_device));
  return cur_device;
}

DeviceIndex MaybeExchangeDevice(DeviceIndex to_device) {
  return c10::cuda::ExchangeDevice(to_device);
}

void SetTargetDevice() {
  // no-op on CUDA versions < 12.x
}
#endif

} // namespace c10::cuda