xref: /aosp_15_r20/external/pytorch/aten/src/ATen/xpu/detail/XPUHooks.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/xpu/PinnedMemoryAllocator.h>
2 #include <ATen/xpu/XPUContext.h>
3 #include <ATen/xpu/XPUDevice.h>
4 #include <ATen/xpu/XPUGeneratorImpl.h>
5 #include <ATen/xpu/detail/XPUHooks.h>
6 #include <c10/util/CallOnce.h>
7 #include <c10/util/Logging.h>
8 #include <c10/xpu/XPUCachingAllocator.h>
9 
10 namespace at::xpu::detail {
11 
initXPU() const12 void XPUHooks::initXPU() const {
13   C10_LOG_API_USAGE_ONCE("aten.init.xpu");
14   const auto device_count = c10::xpu::device_count_ensure_non_zero();
15   c10::xpu::XPUCachingAllocator::init(device_count);
16 }
17 
hasXPU() const18 bool XPUHooks::hasXPU() const {
19   return true;
20 }
21 
showConfig() const22 std::string XPUHooks::showConfig() const {
23   return "XPU backend";
24 }
25 
getGlobalIdxFromDevice(const at::Device & device) const26 int32_t XPUHooks::getGlobalIdxFromDevice(const at::Device& device) const {
27   TORCH_CHECK(device.is_xpu(), "Only the XPU device type is expected.");
28 #ifdef _WIN32
29   TORCH_CHECK(
30       false,
31       "Default context is not supported on XPU on Windows. So we can NOT find its global index of the ATen device.");
32 #else
33   return at::xpu::getGlobalIdxFromDevice(device.index());
34 #endif
35 }
36 
getXPUGenerator(DeviceIndex device_index) const37 Generator XPUHooks::getXPUGenerator(DeviceIndex device_index) const {
38   return make_generator<at::XPUGeneratorImpl>(device_index);
39 }
40 
getDefaultXPUGenerator(DeviceIndex device_index) const41 const Generator& XPUHooks::getDefaultXPUGenerator(
42     DeviceIndex device_index) const {
43   return at::xpu::detail::getDefaultXPUGenerator(device_index);
44 }
45 
getDeviceFromPtr(void * data) const46 Device XPUHooks::getDeviceFromPtr(void* data) const {
47 #ifdef _WIN32
48   TORCH_CHECK(
49       false,
50       "Default context is not supported on XPU on Windows. So we can NOT find the ATen device of a pointer.");
51 #else
52   return at::xpu::getDeviceFromPtr(data);
53 #endif
54 }
55 
getNumGPUs() const56 c10::DeviceIndex XPUHooks::getNumGPUs() const {
57   return at::xpu::device_count();
58 }
59 
current_device() const60 DeviceIndex XPUHooks::current_device() const {
61   return c10::xpu::current_device();
62 }
63 
deviceSynchronize(DeviceIndex device_index) const64 void XPUHooks::deviceSynchronize(DeviceIndex device_index) const {
65   // Only the SYCL queues we have reserved will be synchronized, see Note
66   // [Synchronize Streams on Device].
67   c10::xpu::syncStreamsOnDevice(device_index);
68 }
69 
getPinnedMemoryAllocator() const70 Allocator* XPUHooks::getPinnedMemoryAllocator() const {
71   return at::xpu::getPinnedMemoryAllocator();
72 }
73 
isPinnedPtr(const void * data) const74 bool XPUHooks::isPinnedPtr(const void* data) const {
75   if (!at::xpu::is_available()) {
76     return false;
77   }
78 
79   return sycl::usm::alloc::host ==
80       sycl::get_pointer_type(data, c10::xpu::get_device_context());
81 }
82 
hasPrimaryContext(DeviceIndex device_index) const83 bool XPUHooks::hasPrimaryContext(DeviceIndex device_index) const {
84   // The default context is utilized for each device. So it always returns true.
85   return true;
86 }
87 
88 REGISTER_XPU_HOOKS(XPUHooks);
89 
90 } // namespace at::xpu::detail
91