xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/cuda.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/cuda.h>
2 
3 #include <ATen/Context.h>
4 #include <c10/core/DeviceGuard.h>
5 #include <c10/util/irange.h>
6 
7 #include <cstddef>
8 
9 namespace torch {
10 namespace cuda {
11 
device_count()12 size_t device_count() {
13   return at::detail::getCUDAHooks().getNumGPUs();
14 }
15 
is_available()16 bool is_available() {
17   // NB: the semantics of this are different from at::globalContext().hasCUDA();
18   // ATen's function tells you if you have a working driver and CUDA build,
19   // whereas this function also tells you if you actually have any GPUs.
20   // This function matches the semantics of at::cuda::is_available()
21   return cuda::device_count() > 0;
22 }
23 
cudnn_is_available()24 bool cudnn_is_available() {
25   return is_available() && at::detail::getCUDAHooks().hasCuDNN();
26 }
27 
28 /// Sets the seed for the current GPU.
manual_seed(uint64_t seed)29 void manual_seed(uint64_t seed) {
30   if (is_available()) {
31     auto index = at::detail::getCUDAHooks().current_device();
32     auto gen = at::detail::getCUDAHooks().getDefaultCUDAGenerator(index);
33     {
34       // See Note [Acquire lock when using random generators]
35       std::lock_guard<std::mutex> lock(gen.mutex());
36       gen.set_current_seed(seed);
37     }
38   }
39 }
40 
41 /// Sets the seed for all available GPUs.
manual_seed_all(uint64_t seed)42 void manual_seed_all(uint64_t seed) {
43   auto num_gpu = device_count();
44   for (const auto i : c10::irange(num_gpu)) {
45     auto gen = at::detail::getCUDAHooks().getDefaultCUDAGenerator(i);
46     {
47       // See Note [Acquire lock when using random generators]
48       std::lock_guard<std::mutex> lock(gen.mutex());
49       gen.set_current_seed(seed);
50     }
51   }
52 }
53 
synchronize(int64_t device_index)54 void synchronize(int64_t device_index) {
55   TORCH_CHECK(is_available(), "No CUDA GPUs are available");
56   int64_t num_gpus = cuda::device_count();
57   TORCH_CHECK(
58       device_index == -1 || device_index < num_gpus,
59       "Device index out of range: ",
60       device_index);
61   at::detail::getCUDAHooks().deviceSynchronize(device_index);
62 }
63 
64 } // namespace cuda
65 } // namespace torch
66