1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/device_set.h"
17
18 #include <vector>
19
20 #include "tensorflow/core/common_runtime/device_factory.h"
21 #include "tensorflow/core/lib/core/status.h"
22 #include "tensorflow/core/platform/test.h"
23
24 namespace tensorflow {
25 namespace {
26
27 // Return a fake device with the specified type and name.
Dev(const char * type,const char * name)28 static Device* Dev(const char* type, const char* name) {
29 class FakeDevice : public Device {
30 public:
31 explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
32 Status Sync() override { return OkStatus(); }
33 Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
34 };
35 DeviceAttributes attr;
36 attr.set_name(name);
37 attr.set_device_type(type);
38 return new FakeDevice(attr);
39 }
40
41 class DeviceSetTest : public ::testing::Test {
42 public:
AddDevice(const char * type,const char * name)43 Device* AddDevice(const char* type, const char* name) {
44 Device* d = Dev(type, name);
45 owned_.emplace_back(d);
46 devices_.AddDevice(d);
47 return d;
48 }
49
device_set() const50 const DeviceSet& device_set() const { return devices_; }
51
types() const52 std::vector<DeviceType> types() const {
53 return devices_.PrioritizedDeviceTypeList();
54 }
55
56 private:
57 DeviceSet devices_;
58 std::vector<std::unique_ptr<Device>> owned_;
59 };
60
61 class DummyFactory : public DeviceFactory {
62 public:
ListPhysicalDevices(std::vector<string> * devices)63 Status ListPhysicalDevices(std::vector<string>* devices) override {
64 return OkStatus();
65 }
CreateDevices(const SessionOptions & options,const string & name_prefix,std::vector<std::unique_ptr<Device>> * devices)66 Status CreateDevices(const SessionOptions& options, const string& name_prefix,
67 std::vector<std::unique_ptr<Device>>* devices) override {
68 return OkStatus();
69 }
70 };
71
72 // Assumes the default priority is '50'.
73 REGISTER_LOCAL_DEVICE_FACTORY("d1", DummyFactory);
74 REGISTER_LOCAL_DEVICE_FACTORY("d2", DummyFactory, 51);
75 REGISTER_LOCAL_DEVICE_FACTORY("d3", DummyFactory, 49);
76
// Checks the registered type priorities and that PrioritizedDeviceTypeList()
// lists each device type once, highest priority first, as devices are added.
TEST_F(DeviceSetTest, PrioritizedDeviceTypeList) {
  const DeviceType type_d1("d1");
  const DeviceType type_d2("d2");
  const DeviceType type_d3("d3");

  // Priorities come from the factory registrations earlier in this file.
  EXPECT_EQ(50, DeviceSet::DeviceTypeOrder(type_d1));
  EXPECT_EQ(51, DeviceSet::DeviceTypeOrder(type_d2));
  EXPECT_EQ(49, DeviceSet::DeviceTypeOrder(type_d3));

  // An empty set lists no types.
  std::vector<DeviceType> expected;
  EXPECT_EQ(expected, types());

  AddDevice("d1", "/job:a/replica:0/task:0/device:d1:0");
  expected = {type_d1};
  EXPECT_EQ(expected, types());

  // A second device of the same type does not add a second entry.
  AddDevice("d1", "/job:a/replica:0/task:0/device:d1:1");
  EXPECT_EQ(expected, types());

  // d2 has a higher priority than d1, so it is listed first.
  AddDevice("d2", "/job:a/replica:0/task:0/device:d2:0");
  expected = {type_d2, type_d1};
  EXPECT_EQ(expected, types());

  // d3 has a lower priority than d1, so it is listed last.
  AddDevice("d3", "/job:a/replica:0/task:0/device:d3:0");
  expected = {type_d2, type_d1, type_d3};
  EXPECT_EQ(expected, types());
}
104
// Checks that prioritized_devices() pairs each device with its type priority,
// highest first, and reflects a device added after an earlier query.
TEST_F(DeviceSetTest, prioritized_devices) {
  Device* dev_d1 = AddDevice("d1", "/job:a/replica:0/task:0/device:d1:0");
  Device* dev_d2 = AddDevice("d2", "/job:a/replica:0/task:0/device:d2:0");

  const PrioritizedDeviceVector two_devices{std::make_pair(dev_d2, 51),
                                            std::make_pair(dev_d1, 50)};
  EXPECT_EQ(device_set().prioritized_devices(), two_devices);

  // The result must be recomputed once another device has been added.
  Device* dev_d3 = AddDevice("d3", "/job:a/replica:0/task:0/device:d3:0");

  const PrioritizedDeviceVector three_devices{std::make_pair(dev_d2, 51),
                                              std::make_pair(dev_d1, 50),
                                              std::make_pair(dev_d3, 49)};
  EXPECT_EQ(device_set().prioritized_devices(), three_devices);
}
119
// Checks that prioritized_device_types() pairs each device type with its
// priority, highest first, and reflects a device added after an earlier query.
TEST_F(DeviceSetTest, prioritized_device_types) {
  const DeviceType type_d1("d1");
  const DeviceType type_d2("d2");
  const DeviceType type_d3("d3");

  AddDevice("d1", "/job:a/replica:0/task:0/device:d1:0");
  AddDevice("d2", "/job:a/replica:0/task:0/device:d2:0");

  const PrioritizedDeviceTypeVector two_types{std::make_pair(type_d2, 51),
                                              std::make_pair(type_d1, 50)};
  EXPECT_EQ(device_set().prioritized_device_types(), two_types);

  // The result must be recomputed once another device has been added.
  AddDevice("d3", "/job:a/replica:0/task:0/device:d3:0");

  const PrioritizedDeviceTypeVector three_types{std::make_pair(type_d2, 51),
                                                std::make_pair(type_d1, 50),
                                                std::make_pair(type_d3, 49)};
  EXPECT_EQ(device_set().prioritized_device_types(), three_types);
}
136
// Checks the order SortPrioritizedDeviceVector() produces for devices whose
// explicit priorities differ from their registered type priorities.
TEST_F(DeviceSetTest, SortPrioritizedDeviceVector) {
  Device* first_d1 = AddDevice("d1", "/job:a/replica:0/task:0/device:d1:0");
  Device* first_d2 = AddDevice("d2", "/job:a/replica:0/task:0/device:d2:0");
  Device* first_d3 = AddDevice("d3", "/job:a/replica:0/task:0/device:d3:0");
  Device* second_d1 = AddDevice("d1", "/job:a/replica:0/task:0/device:d1:1");
  Device* second_d2 = AddDevice("d2", "/job:a/replica:0/task:0/device:d2:1");
  Device* second_d3 = AddDevice("d3", "/job:a/replica:0/task:0/device:d3:1");

  // Deliberately shuffled input; sorted in place below.
  PrioritizedDeviceVector devices{
      std::make_pair(second_d3, 30), std::make_pair(first_d1, 10),
      std::make_pair(first_d2, 20),  std::make_pair(second_d3 == nullptr ? second_d3 : first_d3, 30),
      std::make_pair(second_d1, 20), std::make_pair(second_d2, 10)};

  device_set().SortPrioritizedDeviceVector(&devices);

  // Expected: explicit priority descending; within a tie, d2 precedes d1, and
  // for the same type device :0 precedes device :1.
  const PrioritizedDeviceVector expected{
      std::make_pair(first_d3, 30),  std::make_pair(second_d3, 30),
      std::make_pair(first_d2, 20),  std::make_pair(second_d1, 20),
      std::make_pair(second_d2, 10), std::make_pair(first_d1, 10)};
  EXPECT_EQ(devices, expected);
}
157
// Checks the order SortPrioritizedDeviceTypeVector() produces for device types
// with explicit priorities, including a tie between d1 and d3.
TEST_F(DeviceSetTest, SortPrioritizedDeviceTypeVector) {
  const DeviceType type_d1("d1");
  const DeviceType type_d2("d2");
  const DeviceType type_d3("d3");

  // Deliberately unordered input; sorted in place below.
  PrioritizedDeviceTypeVector device_types{std::make_pair(type_d3, 20),
                                           std::make_pair(type_d1, 20),
                                           std::make_pair(type_d2, 30)};

  device_set().SortPrioritizedDeviceTypeVector(&device_types);

  // Expected: highest explicit priority first; in the 20/20 tie d1 precedes
  // d3.
  const PrioritizedDeviceTypeVector expected{std::make_pair(type_d2, 30),
                                             std::make_pair(type_d1, 20),
                                             std::make_pair(type_d3, 20)};
  EXPECT_EQ(device_types, expected);
}
170
171 } // namespace
172 } // namespace tensorflow
173