xref: /aosp_15_r20/external/pytorch/c10/xpu/XPUFunctions.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include <c10/util/CallOnce.h>
2*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Exception.h>
3*da0073e9SAndroid Build Coastguard Worker #include <c10/xpu/XPUFunctions.h>
4*da0073e9SAndroid Build Coastguard Worker 
5*da0073e9SAndroid Build Coastguard Worker #include <vector>
6*da0073e9SAndroid Build Coastguard Worker 
7*da0073e9SAndroid Build Coastguard Worker namespace c10::xpu {
8*da0073e9SAndroid Build Coastguard Worker namespace {
9*da0073e9SAndroid Build Coastguard Worker 
10*da0073e9SAndroid Build Coastguard Worker /*
11*da0073e9SAndroid Build Coastguard Worker  * Note [Device Management]
12*da0073e9SAndroid Build Coastguard Worker  *
13*da0073e9SAndroid Build Coastguard Worker  * An Intel GPU device qualifies as a type of SYCL device. This classification
14*da0073e9SAndroid Build Coastguard Worker  * allows for the runtime querying of Intel GPU device information through the
15*da0073e9SAndroid Build Coastguard Worker  * SYCL runtime library.
16*da0073e9SAndroid Build Coastguard Worker  *
17*da0073e9SAndroid Build Coastguard Worker  * Device status is managed through a SYCL device pool, with SYCL devices
 * determined at runtime. The pool is created lazily and initialized exactly
 * once, which keeps its initialization thread-safe; the currently selected
 * device index, by contrast, is tracked per thread. Each device within the
 * device pool shares the same default context.
 */
22*da0073e9SAndroid Build Coastguard Worker c10::once_flag init_flag;
23*da0073e9SAndroid Build Coastguard Worker thread_local DeviceIndex curDeviceIndex = 0;
24*da0073e9SAndroid Build Coastguard Worker 
25*da0073e9SAndroid Build Coastguard Worker struct DevicePool {
26*da0073e9SAndroid Build Coastguard Worker   std::vector<std::unique_ptr<sycl::device>> devices;
27*da0073e9SAndroid Build Coastguard Worker   std::unique_ptr<sycl::context> context;
28*da0073e9SAndroid Build Coastguard Worker } gDevicePool;
29*da0073e9SAndroid Build Coastguard Worker 
enumDevices(std::vector<std::unique_ptr<sycl::device>> & devices)30*da0073e9SAndroid Build Coastguard Worker void enumDevices(std::vector<std::unique_ptr<sycl::device>>& devices) {
31*da0073e9SAndroid Build Coastguard Worker   auto platform_list = sycl::platform::get_platforms();
32*da0073e9SAndroid Build Coastguard Worker   // Enumerated GPU devices from the specific platform.
33*da0073e9SAndroid Build Coastguard Worker   for (const auto& platform : platform_list) {
34*da0073e9SAndroid Build Coastguard Worker     if (platform.get_backend() != sycl::backend::ext_oneapi_level_zero) {
35*da0073e9SAndroid Build Coastguard Worker       continue;
36*da0073e9SAndroid Build Coastguard Worker     }
37*da0073e9SAndroid Build Coastguard Worker     auto device_list = platform.get_devices();
38*da0073e9SAndroid Build Coastguard Worker     for (const auto& device : device_list) {
39*da0073e9SAndroid Build Coastguard Worker       if (device.is_gpu()) {
40*da0073e9SAndroid Build Coastguard Worker         devices.push_back(std::make_unique<sycl::device>(device));
41*da0073e9SAndroid Build Coastguard Worker       }
42*da0073e9SAndroid Build Coastguard Worker     }
43*da0073e9SAndroid Build Coastguard Worker   }
44*da0073e9SAndroid Build Coastguard Worker }
45*da0073e9SAndroid Build Coastguard Worker 
initGlobalDevicePoolState()46*da0073e9SAndroid Build Coastguard Worker inline void initGlobalDevicePoolState() {
47*da0073e9SAndroid Build Coastguard Worker   // Enumerate all GPU devices and record them.
48*da0073e9SAndroid Build Coastguard Worker   enumDevices(gDevicePool.devices);
49*da0073e9SAndroid Build Coastguard Worker   if (gDevicePool.devices.empty()) {
50*da0073e9SAndroid Build Coastguard Worker     TORCH_WARN("XPU device count is zero!");
51*da0073e9SAndroid Build Coastguard Worker     return;
52*da0073e9SAndroid Build Coastguard Worker   }
53*da0073e9SAndroid Build Coastguard Worker 
54*da0073e9SAndroid Build Coastguard Worker #ifdef _WIN32
55*da0073e9SAndroid Build Coastguard Worker   // default context feature is disabled by default on Windows.
56*da0073e9SAndroid Build Coastguard Worker   std::vector<sycl::device> deviceList;
57*da0073e9SAndroid Build Coastguard Worker   for (auto it = gDevicePool.devices.begin(); it != gDevicePool.devices.end();
58*da0073e9SAndroid Build Coastguard Worker        ++it) {
59*da0073e9SAndroid Build Coastguard Worker     deviceList.push_back(*(*it));
60*da0073e9SAndroid Build Coastguard Worker   }
61*da0073e9SAndroid Build Coastguard Worker   gDevicePool.context = std::make_unique<sycl::context>(deviceList);
62*da0073e9SAndroid Build Coastguard Worker #else
63*da0073e9SAndroid Build Coastguard Worker   // The default context is utilized for each Intel GPU device, allowing the
64*da0073e9SAndroid Build Coastguard Worker   // retrieval of the context from any GPU device.
65*da0073e9SAndroid Build Coastguard Worker   gDevicePool.context = std::make_unique<sycl::context>(
66*da0073e9SAndroid Build Coastguard Worker       gDevicePool.devices[0]->get_platform().ext_oneapi_get_default_context());
67*da0073e9SAndroid Build Coastguard Worker #endif
68*da0073e9SAndroid Build Coastguard Worker }
69*da0073e9SAndroid Build Coastguard Worker 
initDevicePoolCallOnce()70*da0073e9SAndroid Build Coastguard Worker inline void initDevicePoolCallOnce() {
71*da0073e9SAndroid Build Coastguard Worker   c10::call_once(init_flag, initGlobalDevicePoolState);
72*da0073e9SAndroid Build Coastguard Worker }
73*da0073e9SAndroid Build Coastguard Worker 
initDeviceProperties(DeviceProp * device_prop,int device)74*da0073e9SAndroid Build Coastguard Worker void initDeviceProperties(DeviceProp* device_prop, int device) {
75*da0073e9SAndroid Build Coastguard Worker   using namespace sycl::info;
76*da0073e9SAndroid Build Coastguard Worker   using namespace sycl::ext;
77*da0073e9SAndroid Build Coastguard Worker   // Get raw sycl device associated with device index.
78*da0073e9SAndroid Build Coastguard Worker   auto& raw_device = *gDevicePool.devices[device];
79*da0073e9SAndroid Build Coastguard Worker 
80*da0073e9SAndroid Build Coastguard Worker   // Initialize the device properties associated with the specific device.
81*da0073e9SAndroid Build Coastguard Worker #define ASSIGN_DEVICE_PROP(property) \
82*da0073e9SAndroid Build Coastguard Worker   device_prop->property = raw_device.get_info<device::property>();
83*da0073e9SAndroid Build Coastguard Worker 
84*da0073e9SAndroid Build Coastguard Worker #define ASSIGN_EXT_DEVICE_PROP(property, default_value)                      \
85*da0073e9SAndroid Build Coastguard Worker   device_prop->property = raw_device.has(sycl::aspect::ext_intel_##property) \
86*da0073e9SAndroid Build Coastguard Worker       ? raw_device.get_info<intel::info::device::property>()                 \
87*da0073e9SAndroid Build Coastguard Worker       : default_value;
88*da0073e9SAndroid Build Coastguard Worker 
89*da0073e9SAndroid Build Coastguard Worker #define ASSIGN_DEVICE_ASPECT(member) \
90*da0073e9SAndroid Build Coastguard Worker   device_prop->has_##member = raw_device.has(sycl::aspect::member);
91*da0073e9SAndroid Build Coastguard Worker 
92*da0073e9SAndroid Build Coastguard Worker #define ASSIGN_EXP_CL_ASPECT(member)                                       \
93*da0073e9SAndroid Build Coastguard Worker   device_prop->has_##member = raw_device.ext_oneapi_supports_cl_extension( \
94*da0073e9SAndroid Build Coastguard Worker       "cl_intel_" #member, &cl_version);
95*da0073e9SAndroid Build Coastguard Worker 
96*da0073e9SAndroid Build Coastguard Worker   AT_FORALL_XPU_DEVICE_PROPERTIES(ASSIGN_DEVICE_PROP);
97*da0073e9SAndroid Build Coastguard Worker 
98*da0073e9SAndroid Build Coastguard Worker   device_prop->platform_name =
99*da0073e9SAndroid Build Coastguard Worker       raw_device.get_info<device::platform>().get_info<platform::name>();
100*da0073e9SAndroid Build Coastguard Worker 
101*da0073e9SAndroid Build Coastguard Worker   AT_FORALL_XPU_EXT_DEVICE_PROPERTIES(ASSIGN_EXT_DEVICE_PROP);
102*da0073e9SAndroid Build Coastguard Worker 
103*da0073e9SAndroid Build Coastguard Worker   AT_FORALL_XPU_DEVICE_ASPECT(ASSIGN_DEVICE_ASPECT);
104*da0073e9SAndroid Build Coastguard Worker 
105*da0073e9SAndroid Build Coastguard Worker   // TODO: Remove cl_version since it is unnecessary.
106*da0073e9SAndroid Build Coastguard Worker   sycl::ext::oneapi::experimental::cl_version cl_version;
107*da0073e9SAndroid Build Coastguard Worker   AT_FORALL_XPU_EXP_CL_ASPECT(ASSIGN_EXP_CL_ASPECT);
108*da0073e9SAndroid Build Coastguard Worker   return;
109*da0073e9SAndroid Build Coastguard Worker }
110*da0073e9SAndroid Build Coastguard Worker 
check_device(DeviceIndex device)111*da0073e9SAndroid Build Coastguard Worker inline void check_device(DeviceIndex device) {
112*da0073e9SAndroid Build Coastguard Worker   // TODO: Use c10::Device::MAX_NUM_DEVICES directly. DeviceIndex is a int8_t
113*da0073e9SAndroid Build Coastguard Worker   // value, and the maximum number of GPUs that PyTorch recognizes is 64. So, we
114*da0073e9SAndroid Build Coastguard Worker   // have to check if there is an overflow happen. When DeviceIndex changes to
115*da0073e9SAndroid Build Coastguard Worker   // int16_t and c10::Device::MAX_NUM_DEVICES is provided, we should use it
116*da0073e9SAndroid Build Coastguard Worker   // directly to check if too many XPU devices are detected.
117*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
118*da0073e9SAndroid Build Coastguard Worker       gDevicePool.devices.size() <= std::numeric_limits<DeviceIndex>::max(),
119*da0073e9SAndroid Build Coastguard Worker       "Too many XPU devices, DeviceIndex overflowed");
120*da0073e9SAndroid Build Coastguard Worker   auto total = static_cast<DeviceIndex>(gDevicePool.devices.size());
121*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
122*da0073e9SAndroid Build Coastguard Worker       device >= 0 && device < total,
123*da0073e9SAndroid Build Coastguard Worker       "device is out of range, device is ",
124*da0073e9SAndroid Build Coastguard Worker       device,
125*da0073e9SAndroid Build Coastguard Worker       ", total number of device is ",
126*da0073e9SAndroid Build Coastguard Worker       total,
127*da0073e9SAndroid Build Coastguard Worker       ".");
128*da0073e9SAndroid Build Coastguard Worker }
129*da0073e9SAndroid Build Coastguard Worker 
130*da0073e9SAndroid Build Coastguard Worker } // anonymous namespace
131*da0073e9SAndroid Build Coastguard Worker 
get_raw_device(DeviceIndex device)132*da0073e9SAndroid Build Coastguard Worker sycl::device& get_raw_device(DeviceIndex device) {
133*da0073e9SAndroid Build Coastguard Worker   initDevicePoolCallOnce();
134*da0073e9SAndroid Build Coastguard Worker   check_device(device);
135*da0073e9SAndroid Build Coastguard Worker   return *gDevicePool.devices[device];
136*da0073e9SAndroid Build Coastguard Worker }
137*da0073e9SAndroid Build Coastguard Worker 
get_device_context()138*da0073e9SAndroid Build Coastguard Worker sycl::context& get_device_context() {
139*da0073e9SAndroid Build Coastguard Worker   initDevicePoolCallOnce();
140*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
141*da0073e9SAndroid Build Coastguard Worker       gDevicePool.context,
142*da0073e9SAndroid Build Coastguard Worker       "Device pool initialization failed, you might not have an XPU device.")
143*da0073e9SAndroid Build Coastguard Worker   return *gDevicePool.context;
144*da0073e9SAndroid Build Coastguard Worker }
145*da0073e9SAndroid Build Coastguard Worker 
get_device_properties(DeviceProp * device_prop,DeviceIndex device)146*da0073e9SAndroid Build Coastguard Worker void get_device_properties(DeviceProp* device_prop, DeviceIndex device) {
147*da0073e9SAndroid Build Coastguard Worker   initDevicePoolCallOnce();
148*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(device_prop, "device_prop is an invalid pointer.");
149*da0073e9SAndroid Build Coastguard Worker   check_device(device);
150*da0073e9SAndroid Build Coastguard Worker   initDeviceProperties(device_prop, device);
151*da0073e9SAndroid Build Coastguard Worker }
152*da0073e9SAndroid Build Coastguard Worker 
get_device_idx_from_pointer(void * ptr)153*da0073e9SAndroid Build Coastguard Worker DeviceIndex get_device_idx_from_pointer(void* ptr) {
154*da0073e9SAndroid Build Coastguard Worker   initDevicePoolCallOnce();
155*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(ptr, "ptr is an invalid pointer.");
156*da0073e9SAndroid Build Coastguard Worker   auto type = sycl::get_pointer_type(ptr, get_device_context());
157*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
158*da0073e9SAndroid Build Coastguard Worker       type == sycl::usm::alloc::device, "ptr is not a device type pointer.");
159*da0073e9SAndroid Build Coastguard Worker 
160*da0073e9SAndroid Build Coastguard Worker   sycl::device raw_device = sycl::get_pointer_device(ptr, get_device_context());
161*da0073e9SAndroid Build Coastguard Worker   auto match_device = [raw_device](const auto& device) -> bool {
162*da0073e9SAndroid Build Coastguard Worker     return raw_device == *device;
163*da0073e9SAndroid Build Coastguard Worker   };
164*da0073e9SAndroid Build Coastguard Worker   auto it = std::find_if(
165*da0073e9SAndroid Build Coastguard Worker       gDevicePool.devices.begin(), gDevicePool.devices.end(), match_device);
166*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
167*da0073e9SAndroid Build Coastguard Worker       it != gDevicePool.devices.end(),
168*da0073e9SAndroid Build Coastguard Worker       "Can't find the pointer from XPU devices.");
169*da0073e9SAndroid Build Coastguard Worker   return static_cast<DeviceIndex>(
170*da0073e9SAndroid Build Coastguard Worker       std::distance(gDevicePool.devices.begin(), it));
171*da0073e9SAndroid Build Coastguard Worker }
172*da0073e9SAndroid Build Coastguard Worker 
device_count()173*da0073e9SAndroid Build Coastguard Worker DeviceIndex device_count() {
174*da0073e9SAndroid Build Coastguard Worker   initDevicePoolCallOnce();
175*da0073e9SAndroid Build Coastguard Worker   return static_cast<DeviceIndex>(gDevicePool.devices.size());
176*da0073e9SAndroid Build Coastguard Worker }
177*da0073e9SAndroid Build Coastguard Worker 
device_count_ensure_non_zero()178*da0073e9SAndroid Build Coastguard Worker DeviceIndex device_count_ensure_non_zero() {
179*da0073e9SAndroid Build Coastguard Worker   auto count = device_count();
180*da0073e9SAndroid Build Coastguard Worker   // Zero gpus could produce a warning in `device_count` but we fail here.
181*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(count, "No XPU devices are available.");
182*da0073e9SAndroid Build Coastguard Worker   return count;
183*da0073e9SAndroid Build Coastguard Worker }
184*da0073e9SAndroid Build Coastguard Worker 
current_device()185*da0073e9SAndroid Build Coastguard Worker DeviceIndex current_device() {
186*da0073e9SAndroid Build Coastguard Worker   initDevicePoolCallOnce();
187*da0073e9SAndroid Build Coastguard Worker   return curDeviceIndex;
188*da0073e9SAndroid Build Coastguard Worker }
189*da0073e9SAndroid Build Coastguard Worker 
set_device(DeviceIndex device)190*da0073e9SAndroid Build Coastguard Worker void set_device(DeviceIndex device) {
191*da0073e9SAndroid Build Coastguard Worker   initDevicePoolCallOnce();
192*da0073e9SAndroid Build Coastguard Worker   check_device(device);
193*da0073e9SAndroid Build Coastguard Worker   curDeviceIndex = device;
194*da0073e9SAndroid Build Coastguard Worker }
195*da0073e9SAndroid Build Coastguard Worker 
exchange_device(c10::DeviceIndex to_device)196*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex exchange_device(c10::DeviceIndex to_device) {
197*da0073e9SAndroid Build Coastguard Worker   auto cur_device = current_device();
198*da0073e9SAndroid Build Coastguard Worker   if (to_device == cur_device) {
199*da0073e9SAndroid Build Coastguard Worker     return cur_device;
200*da0073e9SAndroid Build Coastguard Worker   }
201*da0073e9SAndroid Build Coastguard Worker   set_device(to_device);
202*da0073e9SAndroid Build Coastguard Worker   return cur_device;
203*da0073e9SAndroid Build Coastguard Worker }
204*da0073e9SAndroid Build Coastguard Worker 
maybe_exchange_device(c10::DeviceIndex to_device)205*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex maybe_exchange_device(c10::DeviceIndex to_device) {
206*da0073e9SAndroid Build Coastguard Worker   return exchange_device(to_device);
207*da0073e9SAndroid Build Coastguard Worker }
208*da0073e9SAndroid Build Coastguard Worker 
209*da0073e9SAndroid Build Coastguard Worker } // namespace c10::xpu
210