xref: /aosp_15_r20/external/pytorch/c10/xpu/test/impl/XPUGuardTest.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <c10/core/DeviceGuard.h>
4 #include <c10/core/Event.h>
5 #include <c10/xpu/XPUStream.h>
6 #include <c10/xpu/test/impl/XPUTest.h>
7 
has_xpu()8 bool has_xpu() {
9   return c10::xpu::device_count() > 0;
10 }
11 
// Checks that c10::DeviceGuard selects the requested XPU device for the
// duration of its scope and restores the previous device afterwards. Also
// checks that pool streams belong to the current device and that
// setCurrentXPUStream/getCurrentXPUStream round-trip. The multi-device part
// only runs when more than one XPU is present.
TEST(XPUGuardTest, GuardBehavior) {
  if (!has_xpu()) {
    return;
  }

  {
    auto device = c10::Device(c10::kXPU);
    const c10::DeviceGuard device_guard(device);
    EXPECT_EQ(c10::xpu::current_device(), 0);
  }

  std::vector<c10::xpu::XPUStream> streams0 = {
      c10::xpu::getStreamFromPool(), c10::xpu::getStreamFromPool(true)};
  EXPECT_EQ(streams0[0].device_index(), 0);
  EXPECT_EQ(streams0[1].device_index(), 0);

  // The current stream is process-wide state. Remember it and put it back
  // after the check, so later tests in this binary do not silently run on
  // streams0[0] instead of the previously current stream.
  const c10::xpu::XPUStream original_stream = c10::xpu::getCurrentXPUStream();
  c10::xpu::setCurrentXPUStream(streams0[0]);
  EXPECT_EQ(c10::xpu::getCurrentXPUStream(), streams0[0]);
  c10::xpu::setCurrentXPUStream(original_stream);
  EXPECT_EQ(c10::xpu::getCurrentXPUStream(), original_stream);

  if (c10::xpu::device_count() <= 1) {
    return;
  }

  // Test DeviceGuard for XPU: inside the guard the current device is 1, and
  // streams taken from the pool belong to it. Leaving the scope restores
  // device 0.
  std::vector<c10::xpu::XPUStream> streams1;
  {
    auto device = c10::Device(c10::kXPU, 1);
    const c10::DeviceGuard device_guard(device);
    EXPECT_EQ(c10::xpu::current_device(), 1);
    streams1.push_back(c10::xpu::getStreamFromPool());
    streams1.push_back(c10::xpu::getStreamFromPool());
  }

  EXPECT_EQ(streams1[0].device_index(), 1);
  EXPECT_EQ(streams1[1].device_index(), 1);
  EXPECT_EQ(c10::xpu::current_device(), 0);
}
47 
TEST(XPUGuardTest,EventBehavior)48 TEST(XPUGuardTest, EventBehavior) {
49   if (!has_xpu()) {
50     return;
51   }
52 
53   auto device = c10::Device(c10::kXPU, c10::xpu::current_device());
54   c10::impl::VirtualGuardImpl impl(device.type());
55   c10::Stream stream1 = impl.getStream(device);
56   c10::Stream stream2 = impl.getStream(device);
57   c10::Event event1(device.type());
58   // event is lazily created.
59   EXPECT_FALSE(event1.eventId());
60 
61   constexpr int numel = 1024;
62   int hostData1[numel];
63   initHostData(hostData1, numel);
64   int hostData2[numel];
65   clearHostData(hostData2, numel);
66 
67   auto xpu_stream1 = c10::xpu::XPUStream(stream1);
68   int* deviceData1 = sycl::malloc_device<int>(numel, xpu_stream1);
69 
70   // Copy hostData1 to deviceData1 via stream1, and then copy deviceData1 to
71   // hostData2 via stream2.
72   xpu_stream1.queue().memcpy(deviceData1, hostData1, sizeof(int) * numel);
73   // stream2 wait on stream1's completion.
74   event1.record(stream1);
75   event1.block(stream2);
76   auto xpu_stream2 = c10::xpu::XPUStream(stream2);
77   xpu_stream2.queue().memcpy(hostData2, deviceData1, sizeof(int) * numel);
78   xpu_stream2.synchronize();
79 
80   EXPECT_TRUE(event1.query());
81   validateHostData(hostData2, numel);
82   event1.record(stream2);
83   event1.synchronize();
84   EXPECT_TRUE(event1.query());
85 
86   clearHostData(hostData2, numel);
87   xpu_stream1.queue().memcpy(deviceData1, hostData1, sizeof(int) * numel);
88   // stream2 wait on stream1's completion.
89   event1.record(stream1);
90   event1.block(stream2);
91   // event1 will overwrite the previously captured state.
92   event1.record(stream2);
93   xpu_stream2.queue().memcpy(hostData2, deviceData1, sizeof(int) * numel);
94   xpu_stream2.synchronize();
95   EXPECT_TRUE(event1.query());
96   validateHostData(hostData2, numel);
97 
98   clearHostData(hostData2, numel);
99   // ensure deviceData1 and deviceData2 are different buffers.
100   int* deviceData2 = sycl::malloc_device<int>(numel, xpu_stream1);
101   sycl::free(deviceData1, c10::xpu::get_device_context());
102   c10::Event event2(device.type());
103 
104   // Copy hostData1 to deviceData2 via stream1, and then copy deviceData2 to
105   // hostData1 via stream1.
106   xpu_stream1.queue().memcpy(deviceData2, hostData1, sizeof(int) * numel);
107   event2.record(xpu_stream1);
108   event2.synchronize();
109   EXPECT_TRUE(event2.query());
110   clearHostData(hostData1, numel);
111   xpu_stream1.queue().memcpy(hostData1, deviceData2, sizeof(int) * numel);
112   event2.record(xpu_stream1);
113   event2.synchronize();
114   EXPECT_TRUE(event2.query());
115   EXPECT_NE(event1.eventId(), event2.eventId());
116   ASSERT_THROW(event1.elapsedTime(event2), c10::Error);
117   sycl::free(deviceData2, c10::xpu::get_device_context());
118 }
119