#include <torch/cuda.h>

#include <ATen/Context.h>
#include <c10/core/DeviceGuard.h>
#include <c10/util/irange.h>

#include <cstddef>
#include <cstdint>
#include <mutex>

namespace torch {
namespace cuda {

/// Returns the number of CUDA devices visible to this process.
size_t device_count() {
  return at::detail::getCUDAHooks().getNumGPUs();
}

/// Returns true if a CUDA build, driver, and at least one GPU are present.
bool is_available() {
  // NB: the semantics of this are different from at::globalContext().hasCUDA();
  // ATen's function tells you if you have a working driver and CUDA build,
  // whereas this function also tells you if you actually have any GPUs.
  // This function matches the semantics of at::cuda::is_available().
  return cuda::device_count() > 0;
}

/// Returns true if CUDA is available and cuDNN can be used.
bool cudnn_is_available() {
  return is_available() && at::detail::getCUDAHooks().hasCuDNN();
}

/// Sets the seed for the current GPU.
void manual_seed(uint64_t seed) {
  if (is_available()) {
    auto index = at::detail::getCUDAHooks().current_device();
    auto gen = at::detail::getCUDAHooks().getDefaultCUDAGenerator(index);
    {
      // See Note [Acquire lock when using random generators]
      std::lock_guard<std::mutex> lock(gen.mutex());
      gen.set_current_seed(seed);
    }
  }
}

/// Sets the seed for all available GPUs.
void manual_seed_all(uint64_t seed) {
  auto num_gpu = device_count();
  for (const auto i : c10::irange(num_gpu)) {
    auto gen = at::detail::getCUDAHooks().getDefaultCUDAGenerator(i);
    {
      // See Note [Acquire lock when using random generators]
      std::lock_guard<std::mutex> lock(gen.mutex());
      gen.set_current_seed(seed);
    }
  }
}

/// Blocks until all kernels on the given CUDA device have completed.
/// A device_index of -1 refers to the current device.
void synchronize(int64_t device_index) {
  TORCH_CHECK(is_available(), "No CUDA GPUs are available");
  int64_t num_gpus = cuda::device_count();
  TORCH_CHECK(
      device_index == -1 || device_index < num_gpus,
      "Device index out of range: ",
      device_index);
  at::detail::getCUDAHooks().deviceSynchronize(device_index);
}

} // namespace cuda
} // namespace torch
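
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not compiled as part of this translation
// unit): a program linking against libtorch can call these helpers through
// the public torch::cuda:: API declared in torch/cuda.h. The device index 0
// and seed 42 below are arbitrary example values, not defaults of this file.
//
//   #include <torch/cuda.h>
//   #include <cstdio>
//
//   int main() {
//     if (!torch::cuda::is_available()) {
//       std::printf("no CUDA GPUs visible\n");
//       return 0;
//     }
//     std::printf("GPUs: %zu, cuDNN usable: %d\n",
//                 torch::cuda::device_count(),
//                 static_cast<int>(torch::cuda::cudnn_is_available()));
//     torch::cuda::manual_seed_all(42);  // seed every visible GPU
//     torch::cuda::synchronize(0);       // wait for device 0 to go idle
//     return 0;
//   }
// ---------------------------------------------------------------------------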