xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/xpu.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/Context.h>
2 #include <torch/xpu.h>
3 
4 namespace torch::xpu {
5 
device_count()6 size_t device_count() {
7   return at::detail::getXPUHooks().getNumGPUs();
8 }
9 
is_available()10 bool is_available() {
11   return xpu::device_count() > 0;
12 }
13 
manual_seed(uint64_t seed)14 void manual_seed(uint64_t seed) {
15   if (is_available()) {
16     auto index = at::detail::getXPUHooks().current_device();
17     auto gen = at::detail::getXPUHooks().getDefaultXPUGenerator(index);
18     {
19       // See Note [Acquire lock when using random generators]
20       std::lock_guard<std::mutex> lock(gen.mutex());
21       gen.set_current_seed(seed);
22     }
23   }
24 }
25 
26 /// Sets the seed for all available GPUs.
manual_seed_all(uint64_t seed)27 void manual_seed_all(uint64_t seed) {
28   auto num_gpu = device_count();
29   for (const auto i : c10::irange(num_gpu)) {
30     auto gen = at::detail::getXPUHooks().getDefaultXPUGenerator(i);
31     {
32       // See Note [Acquire lock when using random generators]
33       std::lock_guard<std::mutex> lock(gen.mutex());
34       gen.set_current_seed(seed);
35     }
36   }
37 }
38 
synchronize(int64_t device_index)39 void synchronize(int64_t device_index) {
40   TORCH_CHECK(is_available(), "No XPU are available");
41   at::detail::getXPUHooks().deviceSynchronize(
42       static_cast<c10::DeviceIndex>(device_index));
43 }
44 
45 } // namespace torch::xpu
46