#pragma once

// This header provides C++ wrappers around commonly used CUDA API functions.
// The benefit of using C++ here is that we can raise an exception in the
// event of an error, rather than explicitly pass around error codes. This
// leads to more natural APIs.
//
// The naming convention used here matches the naming convention of torch.cuda
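//
// For example, a minimal sketch of the contrast (the exact exception type is
// an assumption here; the C10_CUDA_CHECK-based wrappers throw an exception
// derived from c10::Error):
//
//   // Raw CUDA runtime API: the caller must inspect the returned error code.
//   int raw_count = 0;
//   cudaError_t err = cudaGetDeviceCount(&raw_count);
//   if (err != cudaSuccess) { /* handle the error by hand */ }
//
//   // c10::cuda wrapper: failures surface as C++ exceptions instead.
//   try {
//     c10::cuda::set_device(0);
//   } catch (const c10::Error& e) {
//     // report or recover
//   }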

#include <c10/core/Device.h>
#include <c10/core/impl/GPUTrace.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAMacros.h>
#include <cuda_runtime_api.h>
#include <optional>
namespace c10::cuda {

// NB: In the past, we were inconsistent about whether or not this reported
// an error if there were driver problems or not. Based on experience
// interacting with users, it seems that people basically ~never want this
// function to fail; it should just return zero if things are not working.
// Oblige them.
// It still might log a warning for the user the first time it's invoked.
C10_CUDA_API DeviceIndex device_count() noexcept;

// Version of device_count that throws if no devices are detected
C10_CUDA_API DeviceIndex device_count_ensure_non_zero();
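
// A brief illustration of the difference between the two counts (sketch only;
// the "broken driver" scenario is assumed):
//
//   // Never throws: reports 0 if the driver is missing or not working.
//   if (c10::cuda::device_count() == 0) {
//     // fall back to CPU
//   }
//
//   // Throws instead of returning 0, for callers that require a GPU.
//   c10::DeviceIndex n = c10::cuda::device_count_ensure_non_zero();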

C10_CUDA_API DeviceIndex current_device();

C10_CUDA_API void set_device(DeviceIndex device);

C10_CUDA_API void device_synchronize();

C10_CUDA_API void warn_or_error_on_sync();
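
// Typical usage sketch for the device helpers above (in real code an RAII
// guard such as c10::cuda::CUDAGuard is usually preferable to a manual
// save/restore):
//
//   c10::DeviceIndex prev = c10::cuda::current_device();
//   c10::cuda::set_device(1);        // throws on failure
//   // ... launch work on device 1 ...
//   c10::cuda::device_synchronize(); // blocks until the device is idle
//   c10::cuda::set_device(prev);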

// Raw CUDA device management functions
C10_CUDA_API cudaError_t GetDeviceCount(int* dev_count);

C10_CUDA_API cudaError_t GetDevice(DeviceIndex* device);

C10_CUDA_API cudaError_t SetDevice(DeviceIndex device);

C10_CUDA_API cudaError_t MaybeSetDevice(DeviceIndex device);

C10_CUDA_API DeviceIndex ExchangeDevice(DeviceIndex device);

C10_CUDA_API DeviceIndex MaybeExchangeDevice(DeviceIndex device);

C10_CUDA_API void SetTargetDevice();
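
// Unlike the throwing wrappers above, the cudaError_t-returning functions
// leave error handling to the caller, e.g. via C10_CUDA_CHECK from
// CUDAException.h; the Exchange variants return the previous device index.
// A brief sketch:
//
//   c10::DeviceIndex cur = -1;
//   C10_CUDA_CHECK(c10::cuda::GetDevice(&cur));           // explicit check
//   c10::DeviceIndex prev = c10::cuda::ExchangeDevice(1);  // switch, keep old
//   // ... work on device 1 ...
//   C10_CUDA_CHECK(c10::cuda::SetDevice(prev));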

enum class SyncDebugMode { L_DISABLED = 0, L_WARN, L_ERROR };

// This is a holder for c10 global state (similar to at::globalContext() in
// ATen). Currently it's used to store the CUDA synchronization warning state,
// but it can be expanded to hold other related global state, e.g. to
// record stream usage.
class WarningState {
 public:
  void set_sync_debug_mode(SyncDebugMode l) {
    sync_debug_mode = l;
  }

  SyncDebugMode get_sync_debug_mode() {
    return sync_debug_mode;
  }

 private:
  SyncDebugMode sync_debug_mode = SyncDebugMode::L_DISABLED;
};

C10_CUDA_API __inline__ WarningState& warning_state() {
  static WarningState warning_state_;
  return warning_state_;
}
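
// Sketch of how the sync-debug state is usually toggled; the user-facing
// Python knob torch.cuda.set_sync_debug_mode is the typical entry point, and
// it routes into this state (shown here via the C++ accessor):
//
//   c10::cuda::warning_state().set_sync_debug_mode(
//       c10::cuda::SyncDebugMode::L_WARN);
//   // From now on, synchronizing helpers such as stream_synchronize() below
//   // call warn_or_error_on_sync() and emit a warning.
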
// The subsequent functions are defined in the header because for performance
// reasons we want them to be inline.
C10_CUDA_API void __inline__ memcpy_and_sync(
    void* dst,
    const void* src,
    int64_t nbytes,
    cudaMemcpyKind kind,
    cudaStream_t stream) {
  if (C10_UNLIKELY(
          warning_state().get_sync_debug_mode() != SyncDebugMode::L_DISABLED)) {
    warn_or_error_on_sync();
  }
  const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
  if (C10_UNLIKELY(interp)) {
    (*interp)->trace_gpu_stream_synchronization(
        c10::kCUDA, reinterpret_cast<uintptr_t>(stream));
  }
#if defined(TORCH_HIP_VERSION) && (TORCH_HIP_VERSION >= 301)
  C10_CUDA_CHECK(hipMemcpyWithStream(dst, src, nbytes, kind, stream));
#else
  C10_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, kind, stream));
  C10_CUDA_CHECK(cudaStreamSynchronize(stream));
#endif
}
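
// Usage sketch: copy a small device-side result back to the host and wait for
// it in one call (device_ptr and stream below are assumed to already exist):
//
//   float host_result = 0.f;
//   c10::cuda::memcpy_and_sync(
//       &host_result,           // dst: host buffer
//       device_ptr,             // src: device pointer (hypothetical)
//       sizeof(float),
//       cudaMemcpyDeviceToHost,
//       stream);                // copy is enqueued, then the stream is
//                               // synchronized before returning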

C10_CUDA_API void __inline__ stream_synchronize(cudaStream_t stream) {
  if (C10_UNLIKELY(
          warning_state().get_sync_debug_mode() != SyncDebugMode::L_DISABLED)) {
    warn_or_error_on_sync();
  }
  const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
  if (C10_UNLIKELY(interp)) {
    (*interp)->trace_gpu_stream_synchronization(
        c10::kCUDA, reinterpret_cast<uintptr_t>(stream));
  }
  C10_CUDA_CHECK(cudaStreamSynchronize(stream));
}

C10_CUDA_API bool hasPrimaryContext(DeviceIndex device_index);
C10_CUDA_API std::optional<DeviceIndex> getDeviceIndexWithPrimaryContext();
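
// Sketch of a typical use: prefer a device whose primary context already
// exists, rather than touching a fresh device and creating a new context as
// a side effect:
//
//   std::optional<c10::DeviceIndex> idx =
//       c10::cuda::getDeviceIndexWithPrimaryContext();
//   c10::DeviceIndex target = idx.value_or(0); // fall back to device 0
//   if (!c10::cuda::hasPrimaryContext(target)) {
//     // using this device will initialize a new primary context
//   }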

} // namespace c10::cuda