xref: /aosp_15_r20/external/pytorch/aten/src/ATen/DeviceAccelerator.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/Context.h>
2 #include <ATen/DeviceAccelerator.h>
3 namespace at {
4 
getAccelerator(bool checked)5 std::optional<c10::DeviceType> getAccelerator(bool checked) {
6 #define DETECT_AND_ASSIGN_ACCELERATOR(device_name) \
7   if (at::has##device_name()) {                    \
8     device_type = k##device_name;                  \
9     TORCH_CHECK(                                   \
10         !is_accelerator_detected,                  \
11         "Cannot have ",                            \
12         device_type.value(),                       \
13         " with other accelerators.");              \
14     is_accelerator_detected = true;                \
15   }
16 
17   if (is_privateuse1_backend_registered()) {
18     // We explicitly allow PrivateUse1 and another device at the same time as we
19     // use this for testing. Whenever a PrivateUse1 device is registered, use it
20     // first.
21     return kPrivateUse1;
22   }
23   std::optional<c10::DeviceType> device_type = std::nullopt;
24   bool is_accelerator_detected = false;
25   DETECT_AND_ASSIGN_ACCELERATOR(CUDA)
26   DETECT_AND_ASSIGN_ACCELERATOR(MTIA)
27   DETECT_AND_ASSIGN_ACCELERATOR(XPU)
28   DETECT_AND_ASSIGN_ACCELERATOR(HIP)
29   DETECT_AND_ASSIGN_ACCELERATOR(MPS)
30   if (checked) {
31     TORCH_CHECK(
32         device_type, "Cannot access accelerator device when none is available.")
33   }
34   return device_type;
35 
36 #undef DETECT_AND_ASSIGN_ACCELERATOR
37 }
38 
isAccelerator(c10::DeviceType d)39 bool isAccelerator(c10::DeviceType d) {
40   switch (d) {
41     case at::kCUDA:
42     case at::kMTIA:
43     case at::kXPU:
44     case at::kHIP:
45     case at::kMPS:
46     case at::kPrivateUse1:
47       return true;
48     default:
49       return false;
50   }
51 }
52 
53 } // namespace at
54