#include <c10/util/CallOnce.h>
#include <c10/util/Exception.h>
#include <c10/xpu/XPUFunctions.h>

#include <algorithm>
#include <iterator>
#include <limits>
#include <memory>
#include <vector>
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Worker namespace c10::xpu {
8*da0073e9SAndroid Build Coastguard Worker namespace {
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Worker /*
11*da0073e9SAndroid Build Coastguard Worker * Note [Device Management]
12*da0073e9SAndroid Build Coastguard Worker *
13*da0073e9SAndroid Build Coastguard Worker * An Intel GPU device qualifies as a type of SYCL device. This classification
14*da0073e9SAndroid Build Coastguard Worker * allows for the runtime querying of Intel GPU device information through the
15*da0073e9SAndroid Build Coastguard Worker * SYCL runtime library.
16*da0073e9SAndroid Build Coastguard Worker *
17*da0073e9SAndroid Build Coastguard Worker * Device status is managed through a SYCL device pool, with SYCL devices
18*da0073e9SAndroid Build Coastguard Worker * determined at runtime. There's currently a SYCL device pool that is lazily
19*da0073e9SAndroid Build Coastguard Worker * created and only initialized once, ensuring thread-local safety. Each device
20*da0073e9SAndroid Build Coastguard Worker * within the device pool shares the same default context.
21*da0073e9SAndroid Build Coastguard Worker */
22*da0073e9SAndroid Build Coastguard Worker c10::once_flag init_flag;
23*da0073e9SAndroid Build Coastguard Worker thread_local DeviceIndex curDeviceIndex = 0;
24*da0073e9SAndroid Build Coastguard Worker
25*da0073e9SAndroid Build Coastguard Worker struct DevicePool {
26*da0073e9SAndroid Build Coastguard Worker std::vector<std::unique_ptr<sycl::device>> devices;
27*da0073e9SAndroid Build Coastguard Worker std::unique_ptr<sycl::context> context;
28*da0073e9SAndroid Build Coastguard Worker } gDevicePool;
29*da0073e9SAndroid Build Coastguard Worker
enumDevices(std::vector<std::unique_ptr<sycl::device>> & devices)30*da0073e9SAndroid Build Coastguard Worker void enumDevices(std::vector<std::unique_ptr<sycl::device>>& devices) {
31*da0073e9SAndroid Build Coastguard Worker auto platform_list = sycl::platform::get_platforms();
32*da0073e9SAndroid Build Coastguard Worker // Enumerated GPU devices from the specific platform.
33*da0073e9SAndroid Build Coastguard Worker for (const auto& platform : platform_list) {
34*da0073e9SAndroid Build Coastguard Worker if (platform.get_backend() != sycl::backend::ext_oneapi_level_zero) {
35*da0073e9SAndroid Build Coastguard Worker continue;
36*da0073e9SAndroid Build Coastguard Worker }
37*da0073e9SAndroid Build Coastguard Worker auto device_list = platform.get_devices();
38*da0073e9SAndroid Build Coastguard Worker for (const auto& device : device_list) {
39*da0073e9SAndroid Build Coastguard Worker if (device.is_gpu()) {
40*da0073e9SAndroid Build Coastguard Worker devices.push_back(std::make_unique<sycl::device>(device));
41*da0073e9SAndroid Build Coastguard Worker }
42*da0073e9SAndroid Build Coastguard Worker }
43*da0073e9SAndroid Build Coastguard Worker }
44*da0073e9SAndroid Build Coastguard Worker }
45*da0073e9SAndroid Build Coastguard Worker
initGlobalDevicePoolState()46*da0073e9SAndroid Build Coastguard Worker inline void initGlobalDevicePoolState() {
47*da0073e9SAndroid Build Coastguard Worker // Enumerate all GPU devices and record them.
48*da0073e9SAndroid Build Coastguard Worker enumDevices(gDevicePool.devices);
49*da0073e9SAndroid Build Coastguard Worker if (gDevicePool.devices.empty()) {
50*da0073e9SAndroid Build Coastguard Worker TORCH_WARN("XPU device count is zero!");
51*da0073e9SAndroid Build Coastguard Worker return;
52*da0073e9SAndroid Build Coastguard Worker }
53*da0073e9SAndroid Build Coastguard Worker
54*da0073e9SAndroid Build Coastguard Worker #ifdef _WIN32
55*da0073e9SAndroid Build Coastguard Worker // default context feature is disabled by default on Windows.
56*da0073e9SAndroid Build Coastguard Worker std::vector<sycl::device> deviceList;
57*da0073e9SAndroid Build Coastguard Worker for (auto it = gDevicePool.devices.begin(); it != gDevicePool.devices.end();
58*da0073e9SAndroid Build Coastguard Worker ++it) {
59*da0073e9SAndroid Build Coastguard Worker deviceList.push_back(*(*it));
60*da0073e9SAndroid Build Coastguard Worker }
61*da0073e9SAndroid Build Coastguard Worker gDevicePool.context = std::make_unique<sycl::context>(deviceList);
62*da0073e9SAndroid Build Coastguard Worker #else
63*da0073e9SAndroid Build Coastguard Worker // The default context is utilized for each Intel GPU device, allowing the
64*da0073e9SAndroid Build Coastguard Worker // retrieval of the context from any GPU device.
65*da0073e9SAndroid Build Coastguard Worker gDevicePool.context = std::make_unique<sycl::context>(
66*da0073e9SAndroid Build Coastguard Worker gDevicePool.devices[0]->get_platform().ext_oneapi_get_default_context());
67*da0073e9SAndroid Build Coastguard Worker #endif
68*da0073e9SAndroid Build Coastguard Worker }
69*da0073e9SAndroid Build Coastguard Worker
initDevicePoolCallOnce()70*da0073e9SAndroid Build Coastguard Worker inline void initDevicePoolCallOnce() {
71*da0073e9SAndroid Build Coastguard Worker c10::call_once(init_flag, initGlobalDevicePoolState);
72*da0073e9SAndroid Build Coastguard Worker }
73*da0073e9SAndroid Build Coastguard Worker
initDeviceProperties(DeviceProp * device_prop,int device)74*da0073e9SAndroid Build Coastguard Worker void initDeviceProperties(DeviceProp* device_prop, int device) {
75*da0073e9SAndroid Build Coastguard Worker using namespace sycl::info;
76*da0073e9SAndroid Build Coastguard Worker using namespace sycl::ext;
77*da0073e9SAndroid Build Coastguard Worker // Get raw sycl device associated with device index.
78*da0073e9SAndroid Build Coastguard Worker auto& raw_device = *gDevicePool.devices[device];
79*da0073e9SAndroid Build Coastguard Worker
80*da0073e9SAndroid Build Coastguard Worker // Initialize the device properties associated with the specific device.
81*da0073e9SAndroid Build Coastguard Worker #define ASSIGN_DEVICE_PROP(property) \
82*da0073e9SAndroid Build Coastguard Worker device_prop->property = raw_device.get_info<device::property>();
83*da0073e9SAndroid Build Coastguard Worker
84*da0073e9SAndroid Build Coastguard Worker #define ASSIGN_EXT_DEVICE_PROP(property, default_value) \
85*da0073e9SAndroid Build Coastguard Worker device_prop->property = raw_device.has(sycl::aspect::ext_intel_##property) \
86*da0073e9SAndroid Build Coastguard Worker ? raw_device.get_info<intel::info::device::property>() \
87*da0073e9SAndroid Build Coastguard Worker : default_value;
88*da0073e9SAndroid Build Coastguard Worker
89*da0073e9SAndroid Build Coastguard Worker #define ASSIGN_DEVICE_ASPECT(member) \
90*da0073e9SAndroid Build Coastguard Worker device_prop->has_##member = raw_device.has(sycl::aspect::member);
91*da0073e9SAndroid Build Coastguard Worker
92*da0073e9SAndroid Build Coastguard Worker #define ASSIGN_EXP_CL_ASPECT(member) \
93*da0073e9SAndroid Build Coastguard Worker device_prop->has_##member = raw_device.ext_oneapi_supports_cl_extension( \
94*da0073e9SAndroid Build Coastguard Worker "cl_intel_" #member, &cl_version);
95*da0073e9SAndroid Build Coastguard Worker
96*da0073e9SAndroid Build Coastguard Worker AT_FORALL_XPU_DEVICE_PROPERTIES(ASSIGN_DEVICE_PROP);
97*da0073e9SAndroid Build Coastguard Worker
98*da0073e9SAndroid Build Coastguard Worker device_prop->platform_name =
99*da0073e9SAndroid Build Coastguard Worker raw_device.get_info<device::platform>().get_info<platform::name>();
100*da0073e9SAndroid Build Coastguard Worker
101*da0073e9SAndroid Build Coastguard Worker AT_FORALL_XPU_EXT_DEVICE_PROPERTIES(ASSIGN_EXT_DEVICE_PROP);
102*da0073e9SAndroid Build Coastguard Worker
103*da0073e9SAndroid Build Coastguard Worker AT_FORALL_XPU_DEVICE_ASPECT(ASSIGN_DEVICE_ASPECT);
104*da0073e9SAndroid Build Coastguard Worker
105*da0073e9SAndroid Build Coastguard Worker // TODO: Remove cl_version since it is unnecessary.
106*da0073e9SAndroid Build Coastguard Worker sycl::ext::oneapi::experimental::cl_version cl_version;
107*da0073e9SAndroid Build Coastguard Worker AT_FORALL_XPU_EXP_CL_ASPECT(ASSIGN_EXP_CL_ASPECT);
108*da0073e9SAndroid Build Coastguard Worker return;
109*da0073e9SAndroid Build Coastguard Worker }
110*da0073e9SAndroid Build Coastguard Worker
check_device(DeviceIndex device)111*da0073e9SAndroid Build Coastguard Worker inline void check_device(DeviceIndex device) {
112*da0073e9SAndroid Build Coastguard Worker // TODO: Use c10::Device::MAX_NUM_DEVICES directly. DeviceIndex is a int8_t
113*da0073e9SAndroid Build Coastguard Worker // value, and the maximum number of GPUs that PyTorch recognizes is 64. So, we
114*da0073e9SAndroid Build Coastguard Worker // have to check if there is an overflow happen. When DeviceIndex changes to
115*da0073e9SAndroid Build Coastguard Worker // int16_t and c10::Device::MAX_NUM_DEVICES is provided, we should use it
116*da0073e9SAndroid Build Coastguard Worker // directly to check if too many XPU devices are detected.
117*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
118*da0073e9SAndroid Build Coastguard Worker gDevicePool.devices.size() <= std::numeric_limits<DeviceIndex>::max(),
119*da0073e9SAndroid Build Coastguard Worker "Too many XPU devices, DeviceIndex overflowed");
120*da0073e9SAndroid Build Coastguard Worker auto total = static_cast<DeviceIndex>(gDevicePool.devices.size());
121*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
122*da0073e9SAndroid Build Coastguard Worker device >= 0 && device < total,
123*da0073e9SAndroid Build Coastguard Worker "device is out of range, device is ",
124*da0073e9SAndroid Build Coastguard Worker device,
125*da0073e9SAndroid Build Coastguard Worker ", total number of device is ",
126*da0073e9SAndroid Build Coastguard Worker total,
127*da0073e9SAndroid Build Coastguard Worker ".");
128*da0073e9SAndroid Build Coastguard Worker }
129*da0073e9SAndroid Build Coastguard Worker
130*da0073e9SAndroid Build Coastguard Worker } // anonymous namespace
131*da0073e9SAndroid Build Coastguard Worker
get_raw_device(DeviceIndex device)132*da0073e9SAndroid Build Coastguard Worker sycl::device& get_raw_device(DeviceIndex device) {
133*da0073e9SAndroid Build Coastguard Worker initDevicePoolCallOnce();
134*da0073e9SAndroid Build Coastguard Worker check_device(device);
135*da0073e9SAndroid Build Coastguard Worker return *gDevicePool.devices[device];
136*da0073e9SAndroid Build Coastguard Worker }
137*da0073e9SAndroid Build Coastguard Worker
get_device_context()138*da0073e9SAndroid Build Coastguard Worker sycl::context& get_device_context() {
139*da0073e9SAndroid Build Coastguard Worker initDevicePoolCallOnce();
140*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
141*da0073e9SAndroid Build Coastguard Worker gDevicePool.context,
142*da0073e9SAndroid Build Coastguard Worker "Device pool initialization failed, you might not have an XPU device.")
143*da0073e9SAndroid Build Coastguard Worker return *gDevicePool.context;
144*da0073e9SAndroid Build Coastguard Worker }
145*da0073e9SAndroid Build Coastguard Worker
get_device_properties(DeviceProp * device_prop,DeviceIndex device)146*da0073e9SAndroid Build Coastguard Worker void get_device_properties(DeviceProp* device_prop, DeviceIndex device) {
147*da0073e9SAndroid Build Coastguard Worker initDevicePoolCallOnce();
148*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(device_prop, "device_prop is an invalid pointer.");
149*da0073e9SAndroid Build Coastguard Worker check_device(device);
150*da0073e9SAndroid Build Coastguard Worker initDeviceProperties(device_prop, device);
151*da0073e9SAndroid Build Coastguard Worker }
152*da0073e9SAndroid Build Coastguard Worker
get_device_idx_from_pointer(void * ptr)153*da0073e9SAndroid Build Coastguard Worker DeviceIndex get_device_idx_from_pointer(void* ptr) {
154*da0073e9SAndroid Build Coastguard Worker initDevicePoolCallOnce();
155*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(ptr, "ptr is an invalid pointer.");
156*da0073e9SAndroid Build Coastguard Worker auto type = sycl::get_pointer_type(ptr, get_device_context());
157*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
158*da0073e9SAndroid Build Coastguard Worker type == sycl::usm::alloc::device, "ptr is not a device type pointer.");
159*da0073e9SAndroid Build Coastguard Worker
160*da0073e9SAndroid Build Coastguard Worker sycl::device raw_device = sycl::get_pointer_device(ptr, get_device_context());
161*da0073e9SAndroid Build Coastguard Worker auto match_device = [raw_device](const auto& device) -> bool {
162*da0073e9SAndroid Build Coastguard Worker return raw_device == *device;
163*da0073e9SAndroid Build Coastguard Worker };
164*da0073e9SAndroid Build Coastguard Worker auto it = std::find_if(
165*da0073e9SAndroid Build Coastguard Worker gDevicePool.devices.begin(), gDevicePool.devices.end(), match_device);
166*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
167*da0073e9SAndroid Build Coastguard Worker it != gDevicePool.devices.end(),
168*da0073e9SAndroid Build Coastguard Worker "Can't find the pointer from XPU devices.");
169*da0073e9SAndroid Build Coastguard Worker return static_cast<DeviceIndex>(
170*da0073e9SAndroid Build Coastguard Worker std::distance(gDevicePool.devices.begin(), it));
171*da0073e9SAndroid Build Coastguard Worker }
172*da0073e9SAndroid Build Coastguard Worker
device_count()173*da0073e9SAndroid Build Coastguard Worker DeviceIndex device_count() {
174*da0073e9SAndroid Build Coastguard Worker initDevicePoolCallOnce();
175*da0073e9SAndroid Build Coastguard Worker return static_cast<DeviceIndex>(gDevicePool.devices.size());
176*da0073e9SAndroid Build Coastguard Worker }
177*da0073e9SAndroid Build Coastguard Worker
device_count_ensure_non_zero()178*da0073e9SAndroid Build Coastguard Worker DeviceIndex device_count_ensure_non_zero() {
179*da0073e9SAndroid Build Coastguard Worker auto count = device_count();
180*da0073e9SAndroid Build Coastguard Worker // Zero gpus could produce a warning in `device_count` but we fail here.
181*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(count, "No XPU devices are available.");
182*da0073e9SAndroid Build Coastguard Worker return count;
183*da0073e9SAndroid Build Coastguard Worker }
184*da0073e9SAndroid Build Coastguard Worker
current_device()185*da0073e9SAndroid Build Coastguard Worker DeviceIndex current_device() {
186*da0073e9SAndroid Build Coastguard Worker initDevicePoolCallOnce();
187*da0073e9SAndroid Build Coastguard Worker return curDeviceIndex;
188*da0073e9SAndroid Build Coastguard Worker }
189*da0073e9SAndroid Build Coastguard Worker
set_device(DeviceIndex device)190*da0073e9SAndroid Build Coastguard Worker void set_device(DeviceIndex device) {
191*da0073e9SAndroid Build Coastguard Worker initDevicePoolCallOnce();
192*da0073e9SAndroid Build Coastguard Worker check_device(device);
193*da0073e9SAndroid Build Coastguard Worker curDeviceIndex = device;
194*da0073e9SAndroid Build Coastguard Worker }
195*da0073e9SAndroid Build Coastguard Worker
exchange_device(c10::DeviceIndex to_device)196*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex exchange_device(c10::DeviceIndex to_device) {
197*da0073e9SAndroid Build Coastguard Worker auto cur_device = current_device();
198*da0073e9SAndroid Build Coastguard Worker if (to_device == cur_device) {
199*da0073e9SAndroid Build Coastguard Worker return cur_device;
200*da0073e9SAndroid Build Coastguard Worker }
201*da0073e9SAndroid Build Coastguard Worker set_device(to_device);
202*da0073e9SAndroid Build Coastguard Worker return cur_device;
203*da0073e9SAndroid Build Coastguard Worker }
204*da0073e9SAndroid Build Coastguard Worker
maybe_exchange_device(c10::DeviceIndex to_device)205*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex maybe_exchange_device(c10::DeviceIndex to_device) {
206*da0073e9SAndroid Build Coastguard Worker return exchange_device(to_device);
207*da0073e9SAndroid Build Coastguard Worker }
208*da0073e9SAndroid Build Coastguard Worker
209*da0073e9SAndroid Build Coastguard Worker } // namespace c10::xpu
210