1 #include <ATen/Context.h> 2 #include <ATen/DeviceAccelerator.h> 3 namespace at { 4 getAccelerator(bool checked)5std::optional<c10::DeviceType> getAccelerator(bool checked) { 6 #define DETECT_AND_ASSIGN_ACCELERATOR(device_name) \ 7 if (at::has##device_name()) { \ 8 device_type = k##device_name; \ 9 TORCH_CHECK( \ 10 !is_accelerator_detected, \ 11 "Cannot have ", \ 12 device_type.value(), \ 13 " with other accelerators."); \ 14 is_accelerator_detected = true; \ 15 } 16 17 if (is_privateuse1_backend_registered()) { 18 // We explicitly allow PrivateUse1 and another device at the same time as we 19 // use this for testing. Whenever a PrivateUse1 device is registered, use it 20 // first. 21 return kPrivateUse1; 22 } 23 std::optional<c10::DeviceType> device_type = std::nullopt; 24 bool is_accelerator_detected = false; 25 DETECT_AND_ASSIGN_ACCELERATOR(CUDA) 26 DETECT_AND_ASSIGN_ACCELERATOR(MTIA) 27 DETECT_AND_ASSIGN_ACCELERATOR(XPU) 28 DETECT_AND_ASSIGN_ACCELERATOR(HIP) 29 DETECT_AND_ASSIGN_ACCELERATOR(MPS) 30 if (checked) { 31 TORCH_CHECK( 32 device_type, "Cannot access accelerator device when none is available.") 33 } 34 return device_type; 35 36 #undef DETECT_AND_ASSIGN_ACCELERATOR 37 } 38 isAccelerator(c10::DeviceType d)39bool isAccelerator(c10::DeviceType d) { 40 switch (d) { 41 case at::kCUDA: 42 case at::kMTIA: 43 case at::kXPU: 44 case at::kHIP: 45 case at::kMPS: 46 case at::kPrivateUse1: 47 return true; 48 default: 49 return false; 50 } 51 } 52 53 } // namespace at 54