xref: /aosp_15_r20/external/pytorch/c10/test/core/Device_test.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <c10/core/Device.h>
4 #include <c10/core/DeviceType.h>
5 #include <c10/util/Exception.h>
6 
7 // -- Device -------------------------------------------------------
8 
9 struct ExpectedDeviceTestResult {
10   std::string device_string;
11   c10::DeviceType device_type;
12   c10::DeviceIndex device_index;
13 };
14 
// Checks c10::Device's string constructor: every entry of `valid_devices`
// must parse to the listed type and index, and every entry of
// `invalid_device_strings` must be rejected with c10::Error.
TEST(DeviceTest, BasicConstruction) {
  const std::vector<ExpectedDeviceTestResult> valid_devices = {
      {"cpu", c10::DeviceType::CPU, -1},
      {"cuda", c10::DeviceType::CUDA, -1},
      {"cpu:0", c10::DeviceType::CPU, 0},
      {"cuda:0", c10::DeviceType::CUDA, 0},
      {"cuda:1", c10::DeviceType::CUDA, 1},
      // Multi-digit index without a leading zero must be accepted.
      {"cuda:10", c10::DeviceType::CUDA, 10},
  };
  const std::vector<std::string> invalid_device_strings = {
      "",
      "cpu:x",
      "cpu:foo",
      "cuda:cuda",
      "cuda:",
      "cpu:0:0",
      "cpu:0:",
      "cpu:-1",
      "::",
      ":",
      "cpu:00",
      "cpu:01"};

  for (const auto& ds : valid_devices) {
    // Attach the input string to any assertion failure in this iteration.
    SCOPED_TRACE("Device String: " + ds.device_string);
    const c10::Device d(ds.device_string);
    ASSERT_EQ(d.type(), ds.device_type);
    ASSERT_EQ(d.index(), ds.device_index);
  }

  // Construction goes through a lambda on purpose: the statement
  // `c10::Device(ds);` inside EXPECT_THROW would be parsed as a declaration
  // of a variable named `ds`, not as a constructor call.
  auto make_device = [](const std::string& ds) { return c10::Device(ds); };

  for (const auto& ds : invalid_device_strings) {
    EXPECT_THROW(make_device(ds), c10::Error) << "Device String: " << ds;
  }
}
50 
// Registers a PrivateUse1 backend under a custom name, then checks that it is
// reported as registered and that the name is returned unchanged when the
// argument is `true` and upper-cased when it is `false`.
TEST(DeviceTypeTest, PrivateUseOneDeviceType) {
  const std::string backend = "my_privateuse1_backend";
  c10::register_privateuse1_backend(backend);
  ASSERT_TRUE(c10::is_privateuse1_backend_registered());

  const bool as_lower = true;
  ASSERT_EQ(c10::get_privateuse1_backend(as_lower), backend);
  ASSERT_EQ(
      c10::get_privateuse1_backend(!as_lower), "MY_PRIVATEUSE1_BACKEND");
}
57 
// Registering the PrivateUse1 backend under any in-tree device name must be
// rejected with c10::Error, and the rejected call must not change the
// backend name reported afterwards.
//
// NOTE(review): register_privateuse1_backend appears to mutate process-wide
// state, and PrivateUseOneDeviceType registers a backend earlier in this file.
// When tests run in declaration order, the throw here may therefore come from
// an "already registered" check rather than the in-tree-name check. Confirm
// against c10/core/DeviceType.cpp before relying on this test for the latter.
TEST(DeviceTypeTest, PrivateUseOneRegister) {
  const char* const kInTreeDeviceNames[] = {
      "cpu", "cuda", "hip", "mps", "xpu", "mtia"};
  for (const char* name : kInTreeDeviceNames) {
    ASSERT_THROW(c10::register_privateuse1_backend(name), c10::Error)
        << "Backend name: " << name;
    // The rejected name must not have become the registered backend.
    ASSERT_NE(c10::get_privateuse1_backend(true), name)
        << "Backend name: " << name;
  }
}
66