1 #include <c10/util/CallOnce.h>
2 #include <c10/util/Exception.h>
3 #include <c10/xpu/XPUFunctions.h>
4
5 #include <vector>
6
7 namespace c10::xpu {
8 namespace {
9
/*
 * Note [Device Management]
 *
 * An Intel GPU device qualifies as a type of SYCL device. This classification
 * allows for the runtime querying of Intel GPU device information through the
 * SYCL runtime library.
 *
 * Device status is managed through a SYCL device pool, with SYCL devices
 * determined at runtime. The pool is created lazily and initialized exactly
 * once through a call-once flag, which keeps its initialization thread-safe.
 * All devices within the device pool share a single context: the platform's
 * default context, except on Windows, where the default context feature is
 * disabled by default and a context is created from the enumerated devices.
 */
22 c10::once_flag init_flag;
23 thread_local DeviceIndex curDeviceIndex = 0;
24
25 struct DevicePool {
26 std::vector<std::unique_ptr<sycl::device>> devices;
27 std::unique_ptr<sycl::context> context;
28 } gDevicePool;
29
enumDevices(std::vector<std::unique_ptr<sycl::device>> & devices)30 void enumDevices(std::vector<std::unique_ptr<sycl::device>>& devices) {
31 auto platform_list = sycl::platform::get_platforms();
32 // Enumerated GPU devices from the specific platform.
33 for (const auto& platform : platform_list) {
34 if (platform.get_backend() != sycl::backend::ext_oneapi_level_zero) {
35 continue;
36 }
37 auto device_list = platform.get_devices();
38 for (const auto& device : device_list) {
39 if (device.is_gpu()) {
40 devices.push_back(std::make_unique<sycl::device>(device));
41 }
42 }
43 }
44 }
45
initGlobalDevicePoolState()46 inline void initGlobalDevicePoolState() {
47 // Enumerate all GPU devices and record them.
48 enumDevices(gDevicePool.devices);
49 if (gDevicePool.devices.empty()) {
50 TORCH_WARN("XPU device count is zero!");
51 return;
52 }
53
54 #ifdef _WIN32
55 // default context feature is disabled by default on Windows.
56 std::vector<sycl::device> deviceList;
57 for (auto it = gDevicePool.devices.begin(); it != gDevicePool.devices.end();
58 ++it) {
59 deviceList.push_back(*(*it));
60 }
61 gDevicePool.context = std::make_unique<sycl::context>(deviceList);
62 #else
63 // The default context is utilized for each Intel GPU device, allowing the
64 // retrieval of the context from any GPU device.
65 gDevicePool.context = std::make_unique<sycl::context>(
66 gDevicePool.devices[0]->get_platform().ext_oneapi_get_default_context());
67 #endif
68 }
69
initDevicePoolCallOnce()70 inline void initDevicePoolCallOnce() {
71 c10::call_once(init_flag, initGlobalDevicePoolState);
72 }
73
initDeviceProperties(DeviceProp * device_prop,int device)74 void initDeviceProperties(DeviceProp* device_prop, int device) {
75 using namespace sycl::info;
76 using namespace sycl::ext;
77 // Get raw sycl device associated with device index.
78 auto& raw_device = *gDevicePool.devices[device];
79
80 // Initialize the device properties associated with the specific device.
81 #define ASSIGN_DEVICE_PROP(property) \
82 device_prop->property = raw_device.get_info<device::property>();
83
84 #define ASSIGN_EXT_DEVICE_PROP(property, default_value) \
85 device_prop->property = raw_device.has(sycl::aspect::ext_intel_##property) \
86 ? raw_device.get_info<intel::info::device::property>() \
87 : default_value;
88
89 #define ASSIGN_DEVICE_ASPECT(member) \
90 device_prop->has_##member = raw_device.has(sycl::aspect::member);
91
92 #define ASSIGN_EXP_CL_ASPECT(member) \
93 device_prop->has_##member = raw_device.ext_oneapi_supports_cl_extension( \
94 "cl_intel_" #member, &cl_version);
95
96 AT_FORALL_XPU_DEVICE_PROPERTIES(ASSIGN_DEVICE_PROP);
97
98 device_prop->platform_name =
99 raw_device.get_info<device::platform>().get_info<platform::name>();
100
101 AT_FORALL_XPU_EXT_DEVICE_PROPERTIES(ASSIGN_EXT_DEVICE_PROP);
102
103 AT_FORALL_XPU_DEVICE_ASPECT(ASSIGN_DEVICE_ASPECT);
104
105 // TODO: Remove cl_version since it is unnecessary.
106 sycl::ext::oneapi::experimental::cl_version cl_version;
107 AT_FORALL_XPU_EXP_CL_ASPECT(ASSIGN_EXP_CL_ASPECT);
108 return;
109 }
110
check_device(DeviceIndex device)111 inline void check_device(DeviceIndex device) {
112 // TODO: Use c10::Device::MAX_NUM_DEVICES directly. DeviceIndex is a int8_t
113 // value, and the maximum number of GPUs that PyTorch recognizes is 64. So, we
114 // have to check if there is an overflow happen. When DeviceIndex changes to
115 // int16_t and c10::Device::MAX_NUM_DEVICES is provided, we should use it
116 // directly to check if too many XPU devices are detected.
117 TORCH_CHECK(
118 gDevicePool.devices.size() <= std::numeric_limits<DeviceIndex>::max(),
119 "Too many XPU devices, DeviceIndex overflowed");
120 auto total = static_cast<DeviceIndex>(gDevicePool.devices.size());
121 TORCH_CHECK(
122 device >= 0 && device < total,
123 "device is out of range, device is ",
124 device,
125 ", total number of device is ",
126 total,
127 ".");
128 }
129
130 } // anonymous namespace
131
get_raw_device(DeviceIndex device)132 sycl::device& get_raw_device(DeviceIndex device) {
133 initDevicePoolCallOnce();
134 check_device(device);
135 return *gDevicePool.devices[device];
136 }
137
get_device_context()138 sycl::context& get_device_context() {
139 initDevicePoolCallOnce();
140 TORCH_CHECK(
141 gDevicePool.context,
142 "Device pool initialization failed, you might not have an XPU device.")
143 return *gDevicePool.context;
144 }
145
get_device_properties(DeviceProp * device_prop,DeviceIndex device)146 void get_device_properties(DeviceProp* device_prop, DeviceIndex device) {
147 initDevicePoolCallOnce();
148 TORCH_CHECK(device_prop, "device_prop is an invalid pointer.");
149 check_device(device);
150 initDeviceProperties(device_prop, device);
151 }
152
get_device_idx_from_pointer(void * ptr)153 DeviceIndex get_device_idx_from_pointer(void* ptr) {
154 initDevicePoolCallOnce();
155 TORCH_CHECK(ptr, "ptr is an invalid pointer.");
156 auto type = sycl::get_pointer_type(ptr, get_device_context());
157 TORCH_CHECK(
158 type == sycl::usm::alloc::device, "ptr is not a device type pointer.");
159
160 sycl::device raw_device = sycl::get_pointer_device(ptr, get_device_context());
161 auto match_device = [raw_device](const auto& device) -> bool {
162 return raw_device == *device;
163 };
164 auto it = std::find_if(
165 gDevicePool.devices.begin(), gDevicePool.devices.end(), match_device);
166 TORCH_CHECK(
167 it != gDevicePool.devices.end(),
168 "Can't find the pointer from XPU devices.");
169 return static_cast<DeviceIndex>(
170 std::distance(gDevicePool.devices.begin(), it));
171 }
172
device_count()173 DeviceIndex device_count() {
174 initDevicePoolCallOnce();
175 return static_cast<DeviceIndex>(gDevicePool.devices.size());
176 }
177
device_count_ensure_non_zero()178 DeviceIndex device_count_ensure_non_zero() {
179 auto count = device_count();
180 // Zero gpus could produce a warning in `device_count` but we fail here.
181 TORCH_CHECK(count, "No XPU devices are available.");
182 return count;
183 }
184
current_device()185 DeviceIndex current_device() {
186 initDevicePoolCallOnce();
187 return curDeviceIndex;
188 }
189
set_device(DeviceIndex device)190 void set_device(DeviceIndex device) {
191 initDevicePoolCallOnce();
192 check_device(device);
193 curDeviceIndex = device;
194 }
195
exchange_device(c10::DeviceIndex to_device)196 c10::DeviceIndex exchange_device(c10::DeviceIndex to_device) {
197 auto cur_device = current_device();
198 if (to_device == cur_device) {
199 return cur_device;
200 }
201 set_device(to_device);
202 return cur_device;
203 }
204
maybe_exchange_device(c10::DeviceIndex to_device)205 c10::DeviceIndex maybe_exchange_device(c10::DeviceIndex to_device) {
206 return exchange_device(to_device);
207 }
208
209 } // namespace c10::xpu
210