1 #include <gtest/gtest.h>
2
3 #include <c10/core/DeviceGuard.h>
4 #include <c10/core/Event.h>
5 #include <c10/xpu/XPUStream.h>
6 #include <c10/xpu/test/impl/XPUTest.h>
7
has_xpu()8 bool has_xpu() {
9 return c10::xpu::device_count() > 0;
10 }
11
// Checks c10::DeviceGuard together with the XPU stream pool:
//   * guarding an XPU device built without an explicit index leaves device 0
//     current;
//   * streams taken from the pool report the index of the device that was
//     current when they were requested;
//   * once a guard on device 1 goes out of scope, device 0 is current again.
TEST(XPUGuardTest, GuardBehavior) {
  // Nothing to verify without an XPU device.
  if (!has_xpu()) {
    return;
  }

  {
    const c10::Device default_xpu(c10::kXPU);
    const c10::DeviceGuard guard(default_xpu);
    EXPECT_EQ(c10::xpu::current_device(), 0);
  }

  // Request two pooled streams outside of any guard; the second request
  // passes `true` to getStreamFromPool.
  std::vector<c10::xpu::XPUStream> dev0_streams;
  dev0_streams.push_back(c10::xpu::getStreamFromPool());
  dev0_streams.push_back(c10::xpu::getStreamFromPool(true));
  for (const auto& stream : dev0_streams) {
    EXPECT_EQ(stream.device_index(), 0);
  }

  // Installing a stream as current must be observable through the getter.
  c10::xpu::setCurrentXPUStream(dev0_streams.front());
  EXPECT_EQ(c10::xpu::getCurrentXPUStream(), dev0_streams.front());

  // The remaining checks need a second device.
  if (c10::xpu::device_count() < 2) {
    return;
  }

  // Test DeviceGuard for XPU: streams requested while the guard on device 1
  // is alive are collected and inspected after the guard has been destroyed.
  std::vector<c10::xpu::XPUStream> dev1_streams;
  {
    const c10::Device second_xpu(c10::kXPU, 1);
    const c10::DeviceGuard guard(second_xpu);
    for (int i = 0; i < 2; ++i) {
      dev1_streams.push_back(c10::xpu::getStreamFromPool());
    }
  }

  for (const auto& stream : dev1_streams) {
    EXPECT_EQ(stream.device_index(), 1);
  }
  // With the guard gone, device 0 is expected to be current again.
  EXPECT_EQ(c10::xpu::current_device(), 0);
}
47
// Exercises c10::Event on XPU streams obtained through VirtualGuardImpl:
//   1. an event has no id until it is first recorded;
//   2. record() on one stream followed by block() on another orders a
//      host->device copy before the device->host copy that reads it back;
//   3. an event can be re-recorded, replacing its previously captured state;
//   4. two distinct events have different ids, and elapsedTime() between the
//      events used here throws c10::Error.
// Device buffers are allocated with sycl::malloc_device and released with
// sycl::free; each buffer is freed before any fatal assertion that follows it
// so that a failing assertion cannot leak it.
TEST(XPUGuardTest, EventBehavior) {
  // Nothing to verify without an XPU device.
  if (!has_xpu()) {
    return;
  }

  auto device = c10::Device(c10::kXPU, c10::xpu::current_device());
  c10::impl::VirtualGuardImpl impl(device.type());
  c10::Stream stream1 = impl.getStream(device);
  c10::Stream stream2 = impl.getStream(device);
  c10::Event event1(device.type());
  // event is lazily created.
  EXPECT_FALSE(event1.eventId());

  constexpr int numel = 1024;
  int hostData1[numel];
  initHostData(hostData1, numel);
  int hostData2[numel];
  clearHostData(hostData2, numel);

  auto xpu_stream1 = c10::xpu::XPUStream(stream1);
  int* deviceData1 = sycl::malloc_device<int>(numel, xpu_stream1);

  // Copy hostData1 to deviceData1 via stream1, and then copy deviceData1 to
  // hostData2 via stream2.
  xpu_stream1.queue().memcpy(deviceData1, hostData1, sizeof(int) * numel);
  // stream2 wait on stream1's completion.
  event1.record(stream1);
  event1.block(stream2);
  auto xpu_stream2 = c10::xpu::XPUStream(stream2);
  xpu_stream2.queue().memcpy(hostData2, deviceData1, sizeof(int) * numel);
  xpu_stream2.synchronize();

  EXPECT_TRUE(event1.query());
  validateHostData(hostData2, numel);
  // Re-record the same event on stream2 and wait for it on the host.
  event1.record(stream2);
  event1.synchronize();
  EXPECT_TRUE(event1.query());

  // Second round trip through deviceData1, this time re-recording event1 on
  // stream2 before the read-back copy is submitted.
  clearHostData(hostData2, numel);
  xpu_stream1.queue().memcpy(deviceData1, hostData1, sizeof(int) * numel);
  // stream2 wait on stream1's completion.
  event1.record(stream1);
  event1.block(stream2);
  // event1 will overwrite the previously captured state.
  event1.record(stream2);
  xpu_stream2.queue().memcpy(hostData2, deviceData1, sizeof(int) * numel);
  xpu_stream2.synchronize();
  EXPECT_TRUE(event1.query());
  validateHostData(hostData2, numel);

  clearHostData(hostData2, numel);
  // ensure deviceData1 and deviceData2 are different buffers: deviceData2 is
  // allocated while deviceData1 is still live, and only then is deviceData1
  // released.
  int* deviceData2 = sycl::malloc_device<int>(numel, xpu_stream1);
  sycl::free(deviceData1, c10::xpu::get_device_context());
  c10::Event event2(device.type());

  // Copy hostData1 to deviceData2 via stream1, and then copy deviceData2 to
  // hostData1 via stream1.
  xpu_stream1.queue().memcpy(deviceData2, hostData1, sizeof(int) * numel);
  event2.record(xpu_stream1);
  event2.synchronize();
  EXPECT_TRUE(event2.query());
  clearHostData(hostData1, numel);
  xpu_stream1.queue().memcpy(hostData1, deviceData2, sizeof(int) * numel);
  event2.record(xpu_stream1);
  event2.synchronize();
  EXPECT_TRUE(event2.query());
  EXPECT_NE(event1.eventId(), event2.eventId());

  // Release deviceData2 before the fatal assertion below: ASSERT_THROW
  // returns from the test body on failure, which would skip a free placed
  // after it. No further work touching deviceData2 is submitted past the
  // event2.record()/synchronize() pair above.
  sycl::free(deviceData2, c10::xpu::get_device_context());
  ASSERT_THROW(event1.elapsedTime(event2), c10::Error);
}
119