xref: /aosp_15_r20/external/pytorch/c10/xpu/XPUFunctions.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <c10/core/Device.h>
4*da0073e9SAndroid Build Coastguard Worker #include <c10/xpu/XPUDeviceProp.h>
5*da0073e9SAndroid Build Coastguard Worker #include <c10/xpu/XPUMacros.h>
6*da0073e9SAndroid Build Coastguard Worker 
7*da0073e9SAndroid Build Coastguard Worker // The naming convention used here matches the naming convention of torch.xpu
8*da0073e9SAndroid Build Coastguard Worker 
9*da0073e9SAndroid Build Coastguard Worker namespace c10::xpu {
10*da0073e9SAndroid Build Coastguard Worker 
11*da0073e9SAndroid Build Coastguard Worker // Returns the device count as a DeviceIndex. Logs a warning only once if no devices are detected.
12*da0073e9SAndroid Build Coastguard Worker C10_XPU_API DeviceIndex device_count();
13*da0073e9SAndroid Build Coastguard Worker 
14*da0073e9SAndroid Build Coastguard Worker // Returns a DeviceIndex count. Throws an error if no devices are detected.
15*da0073e9SAndroid Build Coastguard Worker C10_XPU_API DeviceIndex device_count_ensure_non_zero();
16*da0073e9SAndroid Build Coastguard Worker // NOTE(review): presumably returns the index of the currently selected XPU device - confirm in the implementation.
17*da0073e9SAndroid Build Coastguard Worker C10_XPU_API DeviceIndex current_device();
18*da0073e9SAndroid Build Coastguard Worker // NOTE(review): presumably makes `device` the current XPU device - confirm in the implementation.
19*da0073e9SAndroid Build Coastguard Worker C10_XPU_API void set_device(DeviceIndex device);
20*da0073e9SAndroid Build Coastguard Worker // NOTE(review): by name, presumably switches to `device` and returns the previously current index - confirm.
21*da0073e9SAndroid Build Coastguard Worker C10_XPU_API DeviceIndex exchange_device(DeviceIndex device);
22*da0073e9SAndroid Build Coastguard Worker // NOTE(review): how this differs from exchange_device is not visible in this header - confirm before relying on it.
23*da0073e9SAndroid Build Coastguard Worker C10_XPU_API DeviceIndex maybe_exchange_device(DeviceIndex to_device);
24*da0073e9SAndroid Build Coastguard Worker // Returns a mutable reference to a sycl::device, presumably the one at index `device`; referent lifetime is not visible here.
25*da0073e9SAndroid Build Coastguard Worker C10_XPU_API sycl::device& get_raw_device(DeviceIndex device);
26*da0073e9SAndroid Build Coastguard Worker // Returns a mutable reference to a sycl::context; which devices it covers is not visible here - TODO confirm.
27*da0073e9SAndroid Build Coastguard Worker C10_XPU_API sycl::context& get_device_context();
28*da0073e9SAndroid Build Coastguard Worker // NOTE(review): presumably fills `*device_prop` with the properties of `device` (out-parameter) - confirm, incl. null handling.
29*da0073e9SAndroid Build Coastguard Worker C10_XPU_API void get_device_properties(
30*da0073e9SAndroid Build Coastguard Worker     DeviceProp* device_prop,
31*da0073e9SAndroid Build Coastguard Worker     DeviceIndex device);
32*da0073e9SAndroid Build Coastguard Worker // NOTE(review): presumably maps `ptr` to the index of the device that owns it - confirm behavior for unknown pointers.
33*da0073e9SAndroid Build Coastguard Worker C10_XPU_API DeviceIndex get_device_idx_from_pointer(void* ptr);
34*da0073e9SAndroid Build Coastguard Worker 
35*da0073e9SAndroid Build Coastguard Worker } // namespace c10::xpu
36