1 #include <gtest/gtest.h>
2
3 #include <c10/core/Device.h>
4 #include <c10/core/DeviceType.h>
5 #include <c10/util/Exception.h>
6
7 // -- Device -------------------------------------------------------
8
9 struct ExpectedDeviceTestResult {
10 std::string device_string;
11 c10::DeviceType device_type;
12 c10::DeviceIndex device_index;
13 };
14
// Checks that c10::Device's string constructor turns well-formed
// "<type>" / "<type>:<index>" strings into the expected (type, index) pair,
// and that malformed strings are rejected with c10::Error.
TEST(DeviceTest, BasicConstruction) {
  const std::vector<ExpectedDeviceTestResult> valid_devices = {
      {"cpu", c10::DeviceType::CPU, -1},
      {"cuda", c10::DeviceType::CUDA, -1},
      {"cpu:0", c10::DeviceType::CPU, 0},
      {"cuda:0", c10::DeviceType::CUDA, 0},
      {"cuda:1", c10::DeviceType::CUDA, 1},
  };

  // Every entry must throw: non-numeric or empty index, extra separators,
  // a negative index, bare separators, and indices with leading zeros.
  const std::vector<std::string> invalid_device_strings = {
      "cpu:x",
      "cpu:foo",
      "cuda:cuda",
      "cuda:",
      "cpu:0:0",
      "cpu:0:",
      "cpu:-1",
      "::",
      ":",
      "cpu:00",
      "cpu:01"};

  for (const auto& expected : valid_devices) {
    // Attach the input string to every failure reported in this iteration.
    SCOPED_TRACE("Device String: " + expected.device_string);
    const c10::Device d(expected.device_string);
    // EXPECT rather than ASSERT so one bad entry does not hide the rest.
    EXPECT_EQ(d.type(), expected.device_type);
    EXPECT_EQ(d.index(), expected.device_index);
    // has_index() must agree with the index the table expects.
    EXPECT_EQ(d.has_index(), expected.device_index != -1);
  }

  // Construct through a lambda: a bare `c10::Device(ds);` statement inside
  // EXPECT_THROW would be parsed as a declaration of a variable named `ds`.
  auto make_device = [](const std::string& ds) { return c10::Device(ds); };

  for (const auto& ds : invalid_device_strings) {
    EXPECT_THROW(make_device(ds), c10::Error) << "Device String: " << ds;
  }
}
50
// Registers a PrivateUse1 backend name, then checks that the registration is
// visible and that the name is reported both as given (lower_case == true)
// and upper-cased (lower_case == false).
// NOTE(review): registration is process-global state, so it persists into
// every test that runs after this one in the same binary.
TEST(DeviceTypeTest, PrivateUseOneDeviceType) {
  const std::string backend_name = "my_privateuse1_backend";
  c10::register_privateuse1_backend(backend_name);

  ASSERT_TRUE(c10::is_privateuse1_backend_registered());

  const bool lower_case = true;
  ASSERT_EQ(c10::get_privateuse1_backend(lower_case), backend_name);
  ASSERT_EQ(c10::get_privateuse1_backend(!lower_case), "MY_PRIVATEUSE1_BACKEND");
}
57
// Checks that in-tree device names cannot be claimed as the PrivateUse1
// backend name: each registration attempt must throw c10::Error.
// NOTE(review): registration is process-global. If an earlier test in this
// binary already registered a backend (PrivateUseOneDeviceType does), these
// calls may throw for that reason instead of because the name is in-tree.
// Run this test in isolation (e.g. --gtest_filter) to exercise the in-tree
// name check itself — confirm against register_privateuse1_backend's checks.
TEST(DeviceTypeTest, PrivateUseOneRegister) {
  const char* const in_tree_names[] = {
      "cpu", "cuda", "hip", "mps", "xpu", "mtia"};
  for (const char* name : in_tree_names) {
    ASSERT_THROW(c10::register_privateuse1_backend(name), c10::Error)
        << "Backend name: " << name;
  }
}
66