xref: /aosp_15_r20/external/tensorflow/tensorflow/core/common_runtime/device_set_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/common_runtime/device_set.h"

#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/test.h"
23 
24 namespace tensorflow {
25 namespace {
26 
27 // Return a fake device with the specified type and name.
Dev(const char * type,const char * name)28 static Device* Dev(const char* type, const char* name) {
29   class FakeDevice : public Device {
30    public:
31     explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
32     Status Sync() override { return OkStatus(); }
33     Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
34   };
35   DeviceAttributes attr;
36   attr.set_name(name);
37   attr.set_device_type(type);
38   return new FakeDevice(attr);
39 }
40 
41 class DeviceSetTest : public ::testing::Test {
42  public:
AddDevice(const char * type,const char * name)43   Device* AddDevice(const char* type, const char* name) {
44     Device* d = Dev(type, name);
45     owned_.emplace_back(d);
46     devices_.AddDevice(d);
47     return d;
48   }
49 
device_set() const50   const DeviceSet& device_set() const { return devices_; }
51 
types() const52   std::vector<DeviceType> types() const {
53     return devices_.PrioritizedDeviceTypeList();
54   }
55 
56  private:
57   DeviceSet devices_;
58   std::vector<std::unique_ptr<Device>> owned_;
59 };
60 
61 class DummyFactory : public DeviceFactory {
62  public:
ListPhysicalDevices(std::vector<string> * devices)63   Status ListPhysicalDevices(std::vector<string>* devices) override {
64     return OkStatus();
65   }
CreateDevices(const SessionOptions & options,const string & name_prefix,std::vector<std::unique_ptr<Device>> * devices)66   Status CreateDevices(const SessionOptions& options, const string& name_prefix,
67                        std::vector<std::unique_ptr<Device>>* devices) override {
68     return OkStatus();
69   }
70 };
71 
// Registers three fake device types with DummyFactory.
// - "d1" is registered without an explicit priority, so it gets the macro's
//   default. The PrioritizedDeviceTypeList test asserts that default is 50.
// - "d2" is registered just above that default (51).
// - "d3" is registered just below it (49).
REGISTER_LOCAL_DEVICE_FACTORY("d1", DummyFactory);
REGISTER_LOCAL_DEVICE_FACTORY("d2", DummyFactory, 51);
REGISTER_LOCAL_DEVICE_FACTORY("d3", DummyFactory, 49);
76 
// Checks the registered type priorities, then checks that
// PrioritizedDeviceTypeList() lists each distinct device type in the set
// exactly once, highest priority first, as devices are added.
TEST_F(DeviceSetTest, PrioritizedDeviceTypeList) {
  const DeviceType d1_type("d1");
  const DeviceType d2_type("d2");
  const DeviceType d3_type("d3");

  EXPECT_EQ(50, DeviceSet::DeviceTypeOrder(d1_type));
  EXPECT_EQ(51, DeviceSet::DeviceTypeOrder(d2_type));
  EXPECT_EQ(49, DeviceSet::DeviceTypeOrder(d3_type));

  // An empty set reports no device types.
  EXPECT_EQ(std::vector<DeviceType>{}, types());

  // One d1 device yields just d1.
  AddDevice("d1", "/job:a/replica:0/task:0/device:d1:0");
  EXPECT_EQ(std::vector<DeviceType>{d1_type}, types());

  // A second d1 device does not repeat the type.
  AddDevice("d1", "/job:a/replica:0/task:0/device:d1:1");
  EXPECT_EQ(std::vector<DeviceType>{d1_type}, types());

  // d2 (priority 51) is listed ahead of d1 (priority 50).
  AddDevice("d2", "/job:a/replica:0/task:0/device:d2:0");
  EXPECT_EQ((std::vector<DeviceType>{d2_type, d1_type}), types());

  // d3 (priority 49) is listed after d1.
  AddDevice("d3", "/job:a/replica:0/task:0/device:d3:0");
  EXPECT_EQ((std::vector<DeviceType>{d2_type, d1_type, d3_type}), types());
}
104 
// prioritized_devices() pairs every device with its type's priority, ordered
// highest first. It must also reflect a device added after an earlier call.
TEST_F(DeviceSetTest, prioritized_devices) {
  Device* const d1 = AddDevice("d1", "/job:a/replica:0/task:0/device:d1:0");
  Device* const d2 = AddDevice("d2", "/job:a/replica:0/task:0/device:d2:0");

  const PrioritizedDeviceVector expected_two_devices{std::make_pair(d2, 51),
                                                     std::make_pair(d1, 50)};
  EXPECT_EQ(device_set().prioritized_devices(), expected_two_devices);

  // Adding a device must invalidate and rebuild the cached ordering.
  Device* const d3 = AddDevice("d3", "/job:a/replica:0/task:0/device:d3:0");

  const PrioritizedDeviceVector expected_three_devices{
      std::make_pair(d2, 51), std::make_pair(d1, 50), std::make_pair(d3, 49)};
  EXPECT_EQ(device_set().prioritized_devices(), expected_three_devices);
}
119 
// prioritized_device_types() pairs each device type in the set with its
// priority, ordered highest first. It must also reflect a type introduced
// after an earlier call.
TEST_F(DeviceSetTest, prioritized_device_types) {
  const DeviceType d1_type("d1");
  const DeviceType d2_type("d2");
  const DeviceType d3_type("d3");

  AddDevice("d1", "/job:a/replica:0/task:0/device:d1:0");
  AddDevice("d2", "/job:a/replica:0/task:0/device:d2:0");

  const PrioritizedDeviceTypeVector expected_two_types{
      std::make_pair(d2_type, 51), std::make_pair(d1_type, 50)};
  EXPECT_EQ(device_set().prioritized_device_types(), expected_two_types);

  // Adding a device must invalidate and rebuild the cached ordering.
  AddDevice("d3", "/job:a/replica:0/task:0/device:d3:0");

  const PrioritizedDeviceTypeVector expected_three_types{
      std::make_pair(d2_type, 51), std::make_pair(d1_type, 50),
      std::make_pair(d3_type, 49)};
  EXPECT_EQ(device_set().prioritized_device_types(), expected_three_types);
}
136 
// SortPrioritizedDeviceVector() orders (device, priority) pairs by priority,
// highest first. The expected output pins down how ties are broken:
// - at priority 20 and 10, the d2 device (type priority 51) comes before the
//   d1 device (type priority 50);
// - at priority 30, d3:0 comes before d3:1.
TEST_F(DeviceSetTest, SortPrioritizedDeviceVector) {
  Device* const d1_0 = AddDevice("d1", "/job:a/replica:0/task:0/device:d1:0");
  Device* const d2_0 = AddDevice("d2", "/job:a/replica:0/task:0/device:d2:0");
  Device* const d3_0 = AddDevice("d3", "/job:a/replica:0/task:0/device:d3:0");
  Device* const d1_1 = AddDevice("d1", "/job:a/replica:0/task:0/device:d1:1");
  Device* const d2_1 = AddDevice("d2", "/job:a/replica:0/task:0/device:d2:1");
  Device* const d3_1 = AddDevice("d3", "/job:a/replica:0/task:0/device:d3:1");

  // Deliberately scrambled input that contains ties at every priority.
  PrioritizedDeviceVector devices{
      std::make_pair(d3_1, 30), std::make_pair(d1_0, 10),
      std::make_pair(d2_0, 20), std::make_pair(d3_0, 30),
      std::make_pair(d1_1, 20), std::make_pair(d2_1, 10)};

  device_set().SortPrioritizedDeviceVector(&devices);

  const PrioritizedDeviceVector expected{
      std::make_pair(d3_0, 30), std::make_pair(d3_1, 30),
      std::make_pair(d2_0, 20), std::make_pair(d1_1, 20),
      std::make_pair(d2_1, 10), std::make_pair(d1_0, 10)};
  EXPECT_EQ(devices, expected);
}
157 
// SortPrioritizedDeviceTypeVector() orders (type, priority) pairs by priority,
// highest first. For the tie at 20, d1 (registered priority 50) is expected
// before d3 (registered priority 49).
TEST_F(DeviceSetTest, SortPrioritizedDeviceTypeVector) {
  const DeviceType d1_type("d1");
  const DeviceType d2_type("d2");
  const DeviceType d3_type("d3");

  PrioritizedDeviceTypeVector device_types{std::make_pair(d3_type, 20),
                                           std::make_pair(d1_type, 20),
                                           std::make_pair(d2_type, 30)};

  device_set().SortPrioritizedDeviceTypeVector(&device_types);

  const PrioritizedDeviceTypeVector expected{std::make_pair(d2_type, 30),
                                             std::make_pair(d1_type, 20),
                                             std::make_pair(d3_type, 20)};
  EXPECT_EQ(device_types, expected);
}
170 
171 }  // namespace
172 }  // namespace tensorflow
173