1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Worker // This header provides C++ wrappers around commonly used CUDA API functions.
4*da0073e9SAndroid Build Coastguard Worker // The benefit of using C++ here is that we can raise an exception in the
5*da0073e9SAndroid Build Coastguard Worker // event of an error, rather than explicitly pass around error codes. This
6*da0073e9SAndroid Build Coastguard Worker // leads to more natural APIs.
7*da0073e9SAndroid Build Coastguard Worker //
8*da0073e9SAndroid Build Coastguard Worker // The naming convention used here matches the naming convention of torch.cuda
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Worker #include <c10/core/Device.h>
11*da0073e9SAndroid Build Coastguard Worker #include <c10/core/impl/GPUTrace.h>
12*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/CUDAException.h>
13*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/CUDAMacros.h>
14*da0073e9SAndroid Build Coastguard Worker #include <cuda_runtime_api.h>
15*da0073e9SAndroid Build Coastguard Worker namespace c10::cuda {
16*da0073e9SAndroid Build Coastguard Worker
// NB: In the past, we were inconsistent about whether or not this reported
// an error if there were driver problems or not. Based on experience
// interacting with users, it seems that people basically ~never want this
// function to fail; it should just return zero if things are not working.
// Oblige them.
// It still might log a warning for the user the first time it's invoked.
C10_CUDA_API DeviceIndex device_count() noexcept;

// Version of device_count that throws if no devices are detected.
C10_CUDA_API DeviceIndex device_count_ensure_non_zero();

// The four declarations below are only declared in this header; their
// definitions live elsewhere.
// NOTE(review): presumably current_device()/set_device() read and change the
// active CUDA device and report failures by throwing -- confirm against the
// implementation.
C10_CUDA_API DeviceIndex current_device();

C10_CUDA_API void set_device(DeviceIndex device);

C10_CUDA_API void device_synchronize();

// Called by the inline helpers in this header (memcpy_and_sync,
// stream_synchronize) whenever the sync debug mode is not L_DISABLED.
C10_CUDA_API void warn_or_error_on_sync();

// Raw CUDA device management functions.
// Unlike the wrappers above, most of these hand a cudaError_t back to the
// caller (see the return types).
// NOTE(review): presumably they do not throw and the caller must check the
// returned code -- confirm against the implementation.
C10_CUDA_API cudaError_t GetDeviceCount(int* dev_count);

C10_CUDA_API cudaError_t GetDevice(DeviceIndex* device);

C10_CUDA_API cudaError_t SetDevice(DeviceIndex device);

C10_CUDA_API cudaError_t MaybeSetDevice(DeviceIndex device);

// NOTE(review): the DeviceIndex returned by the two Exchange functions is
// presumably the device that was current before the call -- TODO confirm.
C10_CUDA_API DeviceIndex ExchangeDevice(DeviceIndex device);

C10_CUDA_API DeviceIndex MaybeExchangeDevice(DeviceIndex device);

C10_CUDA_API void SetTargetDevice();
50*da0073e9SAndroid Build Coastguard Worker
// Selects how CUDA synchronization points are reported. L_DISABLED is the
// value a freshly constructed WarningState starts with.
// NOTE(review): judging by the names, L_WARN and L_ERROR choose between
// warning and raising; the actual handling is in warn_or_error_on_sync(),
// which is not defined in this header -- confirm there.
enum class SyncDebugMode { L_DISABLED = 0, L_WARN, L_ERROR };

// This is a holder for c10 global state (similar to at GlobalContext).
// Currently it's used to store CUDA synchronization warning state, but it can
// be expanded to hold other related global state, e.g. to record stream usage.
//
// NOTE(review): the mode is a plain data member; this class performs no
// locking or atomic access of its own.
class WarningState {
 public:
  // Replaces the stored synchronization debug mode with `l`.
  void set_sync_debug_mode(SyncDebugMode l) {
    sync_debug_mode = l;
  }

  // Returns the stored synchronization debug mode.
  // Const-qualified: this is a pure accessor, so it must also be callable
  // through a `const WarningState&`.
  SyncDebugMode get_sync_debug_mode() const {
    return sync_debug_mode;
  }

 private:
  // Reporting is off until set_sync_debug_mode() is called.
  SyncDebugMode sync_debug_mode = SyncDebugMode::L_DISABLED;
};
70*da0073e9SAndroid Build Coastguard Worker
warning_state()71*da0073e9SAndroid Build Coastguard Worker C10_CUDA_API __inline__ WarningState& warning_state() {
72*da0073e9SAndroid Build Coastguard Worker static WarningState warning_state_;
73*da0073e9SAndroid Build Coastguard Worker return warning_state_;
74*da0073e9SAndroid Build Coastguard Worker }
75*da0073e9SAndroid Build Coastguard Worker // the subsequent functions are defined in the header because for performance
76*da0073e9SAndroid Build Coastguard Worker // reasons we want them to be inline
memcpy_and_sync(void * dst,const void * src,int64_t nbytes,cudaMemcpyKind kind,cudaStream_t stream)77*da0073e9SAndroid Build Coastguard Worker C10_CUDA_API void __inline__ memcpy_and_sync(
78*da0073e9SAndroid Build Coastguard Worker void* dst,
79*da0073e9SAndroid Build Coastguard Worker const void* src,
80*da0073e9SAndroid Build Coastguard Worker int64_t nbytes,
81*da0073e9SAndroid Build Coastguard Worker cudaMemcpyKind kind,
82*da0073e9SAndroid Build Coastguard Worker cudaStream_t stream) {
83*da0073e9SAndroid Build Coastguard Worker if (C10_UNLIKELY(
84*da0073e9SAndroid Build Coastguard Worker warning_state().get_sync_debug_mode() != SyncDebugMode::L_DISABLED)) {
85*da0073e9SAndroid Build Coastguard Worker warn_or_error_on_sync();
86*da0073e9SAndroid Build Coastguard Worker }
87*da0073e9SAndroid Build Coastguard Worker const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
88*da0073e9SAndroid Build Coastguard Worker if (C10_UNLIKELY(interp)) {
89*da0073e9SAndroid Build Coastguard Worker (*interp)->trace_gpu_stream_synchronization(
90*da0073e9SAndroid Build Coastguard Worker c10::kCUDA, reinterpret_cast<uintptr_t>(stream));
91*da0073e9SAndroid Build Coastguard Worker }
92*da0073e9SAndroid Build Coastguard Worker #if defined(TORCH_HIP_VERSION) && (TORCH_HIP_VERSION >= 301)
93*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(hipMemcpyWithStream(dst, src, nbytes, kind, stream));
94*da0073e9SAndroid Build Coastguard Worker #else
95*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, kind, stream));
96*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(cudaStreamSynchronize(stream));
97*da0073e9SAndroid Build Coastguard Worker #endif
98*da0073e9SAndroid Build Coastguard Worker }
99*da0073e9SAndroid Build Coastguard Worker
stream_synchronize(cudaStream_t stream)100*da0073e9SAndroid Build Coastguard Worker C10_CUDA_API void __inline__ stream_synchronize(cudaStream_t stream) {
101*da0073e9SAndroid Build Coastguard Worker if (C10_UNLIKELY(
102*da0073e9SAndroid Build Coastguard Worker warning_state().get_sync_debug_mode() != SyncDebugMode::L_DISABLED)) {
103*da0073e9SAndroid Build Coastguard Worker warn_or_error_on_sync();
104*da0073e9SAndroid Build Coastguard Worker }
105*da0073e9SAndroid Build Coastguard Worker const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
106*da0073e9SAndroid Build Coastguard Worker if (C10_UNLIKELY(interp)) {
107*da0073e9SAndroid Build Coastguard Worker (*interp)->trace_gpu_stream_synchronization(
108*da0073e9SAndroid Build Coastguard Worker c10::kCUDA, reinterpret_cast<uintptr_t>(stream));
109*da0073e9SAndroid Build Coastguard Worker }
110*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(cudaStreamSynchronize(stream));
111*da0073e9SAndroid Build Coastguard Worker }
112*da0073e9SAndroid Build Coastguard Worker
// Primary-context queries; both are only declared here and defined elsewhere.
// NOTE(review): presumably hasPrimaryContext() reports whether a CUDA primary
// context already exists on `device_index` -- confirm against the
// implementation.
C10_CUDA_API bool hasPrimaryContext(DeviceIndex device_index);
// NOTE(review): presumably returns the index of a device that already has a
// primary context, or std::nullopt when none does -- TODO confirm which device
// is chosen when several qualify.
C10_CUDA_API std::optional<DeviceIndex> getDeviceIndexWithPrimaryContext();
115*da0073e9SAndroid Build Coastguard Worker
116*da0073e9SAndroid Build Coastguard Worker } // namespace c10::cuda
117