1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/data/service/auto_shard_rewriter.h"
16
17 #include <string>
18
19 #include "absl/strings/string_view.h"
20 #include "absl/strings/substitute.h"
21 #include "tensorflow/core/data/service/common.pb.h"
22 #include "tensorflow/core/data/service/test_util.h"
23 #include "tensorflow/core/framework/dataset_options.pb.h"
24 #include "tensorflow/core/framework/function_testlib.h"
25 #include "tensorflow/core/framework/graph.pb.h"
26 #include "tensorflow/core/framework/node_def.pb.h"
27 #include "tensorflow/core/framework/tensor.pb.h"
28 #include "tensorflow/core/framework/tensor_testutil.h"
29 #include "tensorflow/core/framework/types.pb.h"
30 #include "tensorflow/core/lib/core/status_test_util.h"
31 #include "tensorflow/core/platform/errors.h"
32 #include "tensorflow/core/platform/status_matchers.h"
33 #include "tensorflow/core/platform/statusor.h"
34 #include "tensorflow/core/platform/test.h"
35 #include "tensorflow/core/platform/tstring.h"
36 #include "tensorflow/core/platform/types.h"
37 #include "tensorflow/core/protobuf/data_service.pb.h"
38 #include "tensorflow/core/protobuf/error_codes.pb.h"
39
40 namespace tensorflow {
41 namespace data {
42 namespace {
43
44 using ::tensorflow::data::testing::EqualsProto;
45 using ::tensorflow::data::testing::RangeDatasetWithShardHint;
46 using ::tensorflow::data::testing::RangeSquareDataset;
47 using ::tensorflow::testing::IsOkAndHolds;
48 using ::tensorflow::testing::StatusIs;
49 using ::testing::HasSubstr;
50 using ::testing::SizeIs;
51
GetNode(const GraphDef & graph_def,absl::string_view name)52 StatusOr<NodeDef> GetNode(const GraphDef& graph_def, absl::string_view name) {
53 for (const NodeDef& node : graph_def.node()) {
54 if (node.name() == name) {
55 return node;
56 }
57 }
58 return errors::NotFound(absl::Substitute("Node $0 not found in graph $1.",
59 name, graph_def.ShortDebugString()));
60 }
61
GetValue(const GraphDef & graph_def,absl::string_view name)62 StatusOr<int64_t> GetValue(const GraphDef& graph_def, absl::string_view name) {
63 for (const NodeDef& node : graph_def.node()) {
64 if (node.name() == name) {
65 return node.attr().at("value").tensor().int64_val()[0];
66 }
67 }
68 return errors::NotFound(absl::Substitute("Node $0 not found in graph $1.",
69 name, graph_def.ShortDebugString()));
70 }
71
GetTaskDef(const ProcessingModeDef::ShardingPolicy sharding_policy,const int64_t num_workers,const int64_t worker_index)72 TaskDef GetTaskDef(const ProcessingModeDef::ShardingPolicy sharding_policy,
73 const int64_t num_workers, const int64_t worker_index) {
74 TaskDef task_def;
75 task_def.mutable_processing_mode_def()->set_sharding_policy(sharding_policy);
76 task_def.set_num_workers(num_workers);
77 task_def.set_worker_index(worker_index);
78 return task_def;
79 }
80
// With the FILE_OR_DATA policy, 3 workers and worker index 1, rewriting the
// range-square dataset graph must yield a "ShardDataset" node with three
// inputs; inputs 1 and 2 must resolve to 3 and 1, i.e. the configured worker
// count and worker index.
TEST(AutoShardRewriterTest, AutoShard) {
  const TaskDef task = GetTaskDef(ProcessingModeDef::FILE_OR_DATA,
                                  /*num_workers=*/3, /*worker_index=*/1);
  TF_ASSERT_OK_AND_ASSIGN(AutoShardRewriter auto_sharder,
                          AutoShardRewriter::Create(task));

  const DatasetDef range_square = RangeSquareDataset(10);
  TF_ASSERT_OK_AND_ASSIGN(
      GraphDef sharded_graph,
      auto_sharder.ApplyAutoShardRewrite(range_square.graph()));
  TF_ASSERT_OK_AND_ASSIGN(NodeDef shard_op,
                          GetNode(sharded_graph, "ShardDataset"));
  // Abort here if the node does not have three inputs; the reads below index
  // into inputs 1 and 2.
  ASSERT_THAT(shard_op.input(), SizeIs(3));
  const std::string& num_shards_input = shard_op.input(1);
  const std::string& index_input = shard_op.input(2);
  EXPECT_THAT(GetValue(sharded_graph, num_shards_input), IsOkAndHolds(3));
  EXPECT_THAT(GetValue(sharded_graph, index_input), IsOkAndHolds(1));
}
96
// With the DATA policy, 3 workers and worker index 1, the rewritten
// range-square graph must contain a "ShardDataset" node with three inputs
// whose inputs 1 and 2 resolve to 3 and 1.
TEST(AutoShardRewriterTest, ShardByData) {
  const TaskDef task = GetTaskDef(ProcessingModeDef::DATA,
                                  /*num_workers=*/3, /*worker_index=*/1);
  TF_ASSERT_OK_AND_ASSIGN(AutoShardRewriter auto_sharder,
                          AutoShardRewriter::Create(task));

  const DatasetDef range_square = RangeSquareDataset(10);
  TF_ASSERT_OK_AND_ASSIGN(
      GraphDef sharded_graph,
      auto_sharder.ApplyAutoShardRewrite(range_square.graph()));
  TF_ASSERT_OK_AND_ASSIGN(NodeDef shard_op,
                          GetNode(sharded_graph, "ShardDataset"));
  // The input count is a hard requirement for the indexed reads that follow.
  ASSERT_THAT(shard_op.input(), SizeIs(3));
  const std::string& num_shards_input = shard_op.input(1);
  const std::string& index_input = shard_op.input(2);
  EXPECT_THAT(GetValue(sharded_graph, num_shards_input), IsOkAndHolds(3));
  EXPECT_THAT(GetValue(sharded_graph, index_input), IsOkAndHolds(1));
}
112
// With the FILE policy, rewriting the range-square dataset graph is expected
// to fail with NOT_FOUND and a message mentioning an unshardable source
// dataset.
TEST(AutoShardRewriterTest, ShardByFile) {
  const TaskDef task = GetTaskDef(ProcessingModeDef::FILE,
                                  /*num_workers=*/3, /*worker_index=*/1);
  TF_ASSERT_OK_AND_ASSIGN(AutoShardRewriter auto_sharder,
                          AutoShardRewriter::Create(task));

  const DatasetDef range_square = RangeSquareDataset(10);
  const StatusOr<GraphDef> rewrite_result =
      auto_sharder.ApplyAutoShardRewrite(range_square.graph());
  EXPECT_THAT(rewrite_result,
              StatusIs(error::NOT_FOUND,
                       HasSubstr("Found an unshardable source dataset")));
}
124
// With the HINT policy, 3 workers and worker index 1, rewriting a range
// dataset built with a shard hint must yield a "ShardDataset" node with three
// inputs whose inputs 1 and 2 resolve to 3 and 1.
TEST(AutoShardRewriterTest, ShardByHint) {
  const TaskDef task = GetTaskDef(ProcessingModeDef::HINT,
                                  /*num_workers=*/3, /*worker_index=*/1);
  TF_ASSERT_OK_AND_ASSIGN(AutoShardRewriter auto_sharder,
                          AutoShardRewriter::Create(task));

  const DatasetDef hinted_range = RangeDatasetWithShardHint(10);
  TF_ASSERT_OK_AND_ASSIGN(
      GraphDef sharded_graph,
      auto_sharder.ApplyAutoShardRewrite(hinted_range.graph()));
  TF_ASSERT_OK_AND_ASSIGN(NodeDef shard_op,
                          GetNode(sharded_graph, "ShardDataset"));
  // Stop if the node does not have three inputs; inputs 1 and 2 are read next.
  ASSERT_THAT(shard_op.input(), SizeIs(3));
  const std::string& num_shards_input = shard_op.input(1);
  const std::string& index_input = shard_op.input(2);
  EXPECT_THAT(GetValue(sharded_graph, num_shards_input), IsOkAndHolds(3));
  EXPECT_THAT(GetValue(sharded_graph, index_input), IsOkAndHolds(1));
}
140
// With the OFF policy the rewrite must succeed and return a graph equal to
// the input graph.
TEST(AutoShardRewriterTest, NoShard) {
  const TaskDef task = GetTaskDef(ProcessingModeDef::OFF,
                                  /*num_workers=*/3, /*worker_index=*/1);
  TF_ASSERT_OK_AND_ASSIGN(AutoShardRewriter auto_sharder,
                          AutoShardRewriter::Create(task));

  const DatasetDef range_square = RangeSquareDataset(10);
  const GraphDef& original_graph = range_square.graph();
  const StatusOr<GraphDef> rewrite_result =
      auto_sharder.ApplyAutoShardRewrite(original_graph);
  EXPECT_THAT(rewrite_result, IsOkAndHolds(EqualsProto(original_graph)));
}
151
// Same expectations as the FILE_OR_DATA case, but for a range-square dataset
// built with a range of 0: a "ShardDataset" node with three inputs whose
// inputs 1 and 2 resolve to 3 and 1.
TEST(AutoShardRewriterTest, EmptyDataset) {
  const TaskDef task = GetTaskDef(ProcessingModeDef::FILE_OR_DATA,
                                  /*num_workers=*/3, /*worker_index=*/1);
  TF_ASSERT_OK_AND_ASSIGN(AutoShardRewriter auto_sharder,
                          AutoShardRewriter::Create(task));

  const DatasetDef empty_range = RangeSquareDataset(0);
  TF_ASSERT_OK_AND_ASSIGN(
      GraphDef sharded_graph,
      auto_sharder.ApplyAutoShardRewrite(empty_range.graph()));
  TF_ASSERT_OK_AND_ASSIGN(NodeDef shard_op,
                          GetNode(sharded_graph, "ShardDataset"));
  // The node must have three inputs before inputs 1 and 2 are read.
  ASSERT_THAT(shard_op.input(), SizeIs(3));
  const std::string& num_shards_input = shard_op.input(1);
  const std::string& index_input = shard_op.input(2);
  EXPECT_THAT(GetValue(sharded_graph, num_shards_input), IsOkAndHolds(3));
  EXPECT_THAT(GetValue(sharded_graph, index_input), IsOkAndHolds(1));
}
168
// With sharding enabled (FILE_OR_DATA) and zero workers, creating the
// rewriter succeeds but applying the rewrite must fail with INVALID_ARGUMENT
// and the exact message checked below.
TEST(AutoShardRewriterTest, NoWorkers) {
  const TaskDef task = GetTaskDef(ProcessingModeDef::FILE_OR_DATA,
                                  /*num_workers=*/0, /*worker_index=*/0);
  TF_ASSERT_OK_AND_ASSIGN(AutoShardRewriter auto_sharder,
                          AutoShardRewriter::Create(task));

  const DatasetDef range_square = RangeSquareDataset(10);
  const StatusOr<GraphDef> rewrite_result =
      auto_sharder.ApplyAutoShardRewrite(range_square.graph());
  EXPECT_THAT(rewrite_result,
              StatusIs(error::INVALID_ARGUMENT,
                       "num_workers should be >= 1, currently 0"));
}
181
// With the OFF policy, a worker count of zero is not an error: the rewrite
// must succeed and return a graph equal to the input graph.
TEST(AutoShardRewriterTest, NoWorkersWhenShardIsOff) {
  const TaskDef task = GetTaskDef(ProcessingModeDef::OFF,
                                  /*num_workers=*/0, /*worker_index=*/0);
  TF_ASSERT_OK_AND_ASSIGN(AutoShardRewriter auto_sharder,
                          AutoShardRewriter::Create(task));

  const DatasetDef range_square = RangeSquareDataset(10);
  const GraphDef& original_graph = range_square.graph();
  const StatusOr<GraphDef> rewrite_result =
      auto_sharder.ApplyAutoShardRewrite(original_graph);
  EXPECT_THAT(rewrite_result, IsOkAndHolds(EqualsProto(original_graph)));
}
192
// With 2 workers and worker index 5, applying the rewrite must fail with
// INVALID_ARGUMENT and the exact message checked below.
TEST(AutoShardRewriterTest, WorkerIndexOutOfRange) {
  const TaskDef task = GetTaskDef(ProcessingModeDef::FILE_OR_DATA,
                                  /*num_workers=*/2, /*worker_index=*/5);
  TF_ASSERT_OK_AND_ASSIGN(AutoShardRewriter auto_sharder,
                          AutoShardRewriter::Create(task));

  const DatasetDef range_square = RangeSquareDataset(10);
  const StatusOr<GraphDef> rewrite_result =
      auto_sharder.ApplyAutoShardRewrite(range_square.graph());
  EXPECT_THAT(rewrite_result,
              StatusIs(error::INVALID_ARGUMENT,
                       "index should be >= 0 and < 2, currently 5"));
}
205
// A resolver configured with a single port-less host reports NOT_FOUND for a
// worker until that worker has been added; afterwards the worker resolves to
// index 0.
TEST(WorkerIndexResolverTest, AddOneWorker) {
  const std::vector<std::string> configured_hosts = {"localhost"};
  WorkerIndexResolver resolver(configured_hosts);
  const absl::string_view worker = "localhost:12345";
  EXPECT_THAT(resolver.GetWorkerIndex(worker), StatusIs(error::NOT_FOUND));

  TF_EXPECT_OK(resolver.ValidateWorker(worker));
  resolver.AddWorker(worker);
  EXPECT_THAT(resolver.GetWorkerIndex(worker), IsOkAndHolds(0));
}
215
// Workers are validated and added in the reverse of the configured order;
// the resolved indexes are nevertheless expected to follow the configured
// order (task 0 -> 0, task 1 -> 1, task 2 -> 2).
TEST(WorkerIndexResolverTest, AddMultipleWorkers) {
  const std::vector<std::string> configured_hosts = {
      "/worker/task/0", "/worker/task/1", "/worker/task/2"};
  WorkerIndexResolver resolver(configured_hosts);
  const absl::string_view task0 = "/worker/task/0:34567";
  const absl::string_view task1 = "/worker/task/1:23456";
  const absl::string_view task2 = "/worker/task/2:12345";

  TF_EXPECT_OK(resolver.ValidateWorker(task2));
  TF_EXPECT_OK(resolver.ValidateWorker(task1));
  TF_EXPECT_OK(resolver.ValidateWorker(task0));
  resolver.AddWorker(task2);
  resolver.AddWorker(task1);
  resolver.AddWorker(task0);

  EXPECT_THAT(resolver.GetWorkerIndex(task0), IsOkAndHolds(0));
  EXPECT_THAT(resolver.GetWorkerIndex(task1), IsOkAndHolds(1));
  EXPECT_THAT(resolver.GetWorkerIndex(task2), IsOkAndHolds(2));
}
229
// The configured addresses use the literal port name "worker". Workers that
// register with exactly those addresses, in reverse order, are expected to
// resolve to the indexes of their configured positions.
TEST(WorkerIndexResolverTest, NamedPorts) {
  const std::vector<std::string> configured_addresses = {
      "/worker/task/0:worker", "/worker/task/1:worker",
      "/worker/task/2:worker"};
  WorkerIndexResolver resolver(configured_addresses);
  const absl::string_view task0 = "/worker/task/0:worker";
  const absl::string_view task1 = "/worker/task/1:worker";
  const absl::string_view task2 = "/worker/task/2:worker";

  TF_EXPECT_OK(resolver.ValidateWorker(task2));
  TF_EXPECT_OK(resolver.ValidateWorker(task1));
  TF_EXPECT_OK(resolver.ValidateWorker(task0));
  resolver.AddWorker(task2);
  resolver.AddWorker(task1);
  resolver.AddWorker(task0);

  EXPECT_THAT(resolver.GetWorkerIndex(task0), IsOkAndHolds(0));
  EXPECT_THAT(resolver.GetWorkerIndex(task1), IsOkAndHolds(1));
  EXPECT_THAT(resolver.GetWorkerIndex(task2), IsOkAndHolds(2));
}
247
// The configured addresses use the "%port_worker%" placeholder as the port.
// Workers registering with the port "worker" are expected to validate and to
// resolve to the indexes of their configured positions.
TEST(WorkerIndexResolverTest, DynamicPorts) {
  const std::vector<std::string> configured_addresses = {
      "/worker/task/0:%port_worker%", "/worker/task/1:%port_worker%",
      "/worker/task/2:%port_worker%"};
  WorkerIndexResolver resolver(configured_addresses);
  const absl::string_view task0 = "/worker/task/0:worker";
  const absl::string_view task1 = "/worker/task/1:worker";
  const absl::string_view task2 = "/worker/task/2:worker";

  TF_EXPECT_OK(resolver.ValidateWorker(task2));
  TF_EXPECT_OK(resolver.ValidateWorker(task1));
  TF_EXPECT_OK(resolver.ValidateWorker(task0));
  resolver.AddWorker(task2);
  resolver.AddWorker(task1);
  resolver.AddWorker(task0);

  EXPECT_THAT(resolver.GetWorkerIndex(task0), IsOkAndHolds(0));
  EXPECT_THAT(resolver.GetWorkerIndex(task1), IsOkAndHolds(1));
  EXPECT_THAT(resolver.GetWorkerIndex(task2), IsOkAndHolds(2));
}
265
// The configured addresses use the "%port%" placeholder as the port. Workers
// registering with numeric ports are expected to validate and to resolve to
// the indexes of their configured positions.
TEST(WorkerIndexResolverTest, AnonymousPorts) {
  const std::vector<std::string> configured_addresses = {
      "/worker/task/0:%port%", "/worker/task/1:%port%",
      "/worker/task/2:%port%"};
  WorkerIndexResolver resolver(configured_addresses);
  const absl::string_view task0 = "/worker/task/0:10002";
  const absl::string_view task1 = "/worker/task/1:10001";
  const absl::string_view task2 = "/worker/task/2:10000";

  TF_EXPECT_OK(resolver.ValidateWorker(task2));
  TF_EXPECT_OK(resolver.ValidateWorker(task1));
  TF_EXPECT_OK(resolver.ValidateWorker(task0));
  resolver.AddWorker(task2);
  resolver.AddWorker(task1);
  resolver.AddWorker(task0);

  EXPECT_THAT(resolver.GetWorkerIndex(task0), IsOkAndHolds(0));
  EXPECT_THAT(resolver.GetWorkerIndex(task1), IsOkAndHolds(1));
  EXPECT_THAT(resolver.GetWorkerIndex(task2), IsOkAndHolds(2));
}
280
// When the configured addresses already carry numeric ports, the indexes are
// expected to resolve before any worker is added, and to be unchanged after
// the same workers are validated and added.
TEST(WorkerIndexResolverTest, NumericPorts) {
  const absl::string_view task0 = "/worker/task/0:12345";
  const absl::string_view task1 = "/worker/task/1:23456";
  const absl::string_view task2 = "/worker/task/2:34567";
  const std::vector<std::string> configured_addresses = {
      "/worker/task/0:12345", "/worker/task/1:23456", "/worker/task/2:34567"};
  WorkerIndexResolver resolver(configured_addresses);
  EXPECT_THAT(resolver.GetWorkerIndex(task0), IsOkAndHolds(0));
  EXPECT_THAT(resolver.GetWorkerIndex(task1), IsOkAndHolds(1));
  EXPECT_THAT(resolver.GetWorkerIndex(task2), IsOkAndHolds(2));

  // Adding duplicate workers is a no-op.
  TF_EXPECT_OK(resolver.ValidateWorker(task2));
  TF_EXPECT_OK(resolver.ValidateWorker(task1));
  TF_EXPECT_OK(resolver.ValidateWorker(task0));
  resolver.AddWorker(task2);
  resolver.AddWorker(task1);
  resolver.AddWorker(task0);
  EXPECT_THAT(resolver.GetWorkerIndex(task0), IsOkAndHolds(0));
  EXPECT_THAT(resolver.GetWorkerIndex(task1), IsOkAndHolds(1));
  EXPECT_THAT(resolver.GetWorkerIndex(task2), IsOkAndHolds(2));
}
299
// Bracketed IPv6 hosts configured without ports: workers registering with
// those hosts plus a numeric port are expected to validate and to resolve to
// the indexes of their configured positions.
TEST(WorkerIndexResolverTest, IPv6Addresses) {
  const std::vector<std::string> configured_hosts = {
      "[1080:0:0:0:8:800:200C:417A]", "[1080:0:0:0:8:800:200C:417B]",
      "[1080:0:0:0:8:800:200C:417C]"};
  WorkerIndexResolver resolver(configured_hosts);
  const absl::string_view worker_a = "[1080:0:0:0:8:800:200C:417A]:12345";
  const absl::string_view worker_b = "[1080:0:0:0:8:800:200C:417B]:23456";
  const absl::string_view worker_c = "[1080:0:0:0:8:800:200C:417C]:34567";

  TF_EXPECT_OK(resolver.ValidateWorker(worker_a));
  TF_EXPECT_OK(resolver.ValidateWorker(worker_b));
  TF_EXPECT_OK(resolver.ValidateWorker(worker_c));
  resolver.AddWorker(worker_a);
  resolver.AddWorker(worker_b);
  resolver.AddWorker(worker_c);

  EXPECT_THAT(resolver.GetWorkerIndex(worker_a), IsOkAndHolds(0));
  EXPECT_THAT(resolver.GetWorkerIndex(worker_b), IsOkAndHolds(1));
  EXPECT_THAT(resolver.GetWorkerIndex(worker_c), IsOkAndHolds(2));
}
317
// Bracketed IPv6 hosts configured with the "%port%" placeholder: workers
// registering with numeric ports are expected to validate and to resolve to
// the indexes of their configured positions.
TEST(WorkerIndexResolverTest, IPv6AddressesWithDynamicPort) {
  const std::vector<std::string> configured_addresses = {
      "[1080:0:0:0:8:800:200C:417A]:%port%",
      "[1080:0:0:0:8:800:200C:417B]:%port%",
      "[1080:0:0:0:8:800:200C:417C]:%port%"};
  WorkerIndexResolver resolver(configured_addresses);
  const absl::string_view worker_a = "[1080:0:0:0:8:800:200C:417A]:12345";
  const absl::string_view worker_b = "[1080:0:0:0:8:800:200C:417B]:23456";
  const absl::string_view worker_c = "[1080:0:0:0:8:800:200C:417C]:34567";

  TF_EXPECT_OK(resolver.ValidateWorker(worker_a));
  TF_EXPECT_OK(resolver.ValidateWorker(worker_b));
  TF_EXPECT_OK(resolver.ValidateWorker(worker_c));
  resolver.AddWorker(worker_a);
  resolver.AddWorker(worker_b);
  resolver.AddWorker(worker_c);

  EXPECT_THAT(resolver.GetWorkerIndex(worker_a), IsOkAndHolds(0));
  EXPECT_THAT(resolver.GetWorkerIndex(worker_b), IsOkAndHolds(1));
  EXPECT_THAT(resolver.GetWorkerIndex(worker_c), IsOkAndHolds(2));
}
336
// The same "http://"-prefixed host is configured three times without a port.
// Three workers on that host with different ports are expected to receive
// indexes 0, 1 and 2 in the order they were added.
TEST(WorkerIndexResolverTest, AddressesWithProtocols) {
  const std::vector<std::string> configured_hosts = {
      "http://127.0.0.1", "http://127.0.0.1", "http://127.0.0.1"};
  WorkerIndexResolver resolver(configured_hosts);
  const absl::string_view first_worker = "http://127.0.0.1:12345";
  const absl::string_view second_worker = "http://127.0.0.1:23456";
  const absl::string_view third_worker = "http://127.0.0.1:34567";

  TF_EXPECT_OK(resolver.ValidateWorker(first_worker));
  TF_EXPECT_OK(resolver.ValidateWorker(second_worker));
  TF_EXPECT_OK(resolver.ValidateWorker(third_worker));
  resolver.AddWorker(first_worker);
  resolver.AddWorker(second_worker);
  resolver.AddWorker(third_worker);

  EXPECT_THAT(resolver.GetWorkerIndex(first_worker), IsOkAndHolds(0));
  EXPECT_THAT(resolver.GetWorkerIndex(second_worker), IsOkAndHolds(1));
  EXPECT_THAT(resolver.GetWorkerIndex(third_worker), IsOkAndHolds(2));
}
353
// The same "http://"-prefixed host is configured three times with the
// "%port_name%" placeholder. Three workers on that host with different
// numeric ports are expected to receive indexes 0, 1 and 2 in the order they
// were added.
TEST(WorkerIndexResolverTest, AddressesWithProtocolsAndDynamicPorts) {
  const std::vector<std::string> configured_addresses = {
      "http://127.0.0.1:%port_name%", "http://127.0.0.1:%port_name%",
      "http://127.0.0.1:%port_name%"};
  WorkerIndexResolver resolver(configured_addresses);
  const absl::string_view first_worker = "http://127.0.0.1:12345";
  const absl::string_view second_worker = "http://127.0.0.1:23456";
  const absl::string_view third_worker = "http://127.0.0.1:34567";

  TF_EXPECT_OK(resolver.ValidateWorker(first_worker));
  TF_EXPECT_OK(resolver.ValidateWorker(second_worker));
  TF_EXPECT_OK(resolver.ValidateWorker(third_worker));
  resolver.AddWorker(first_worker);
  resolver.AddWorker(second_worker);
  resolver.AddWorker(third_worker);

  EXPECT_THAT(resolver.GetWorkerIndex(first_worker), IsOkAndHolds(0));
  EXPECT_THAT(resolver.GetWorkerIndex(second_worker), IsOkAndHolds(1));
  EXPECT_THAT(resolver.GetWorkerIndex(third_worker), IsOkAndHolds(2));
}
371
// Host names that themselves contain colons, configured with a mix of
// "%port%" placeholders and one numeric port: workers are expected to
// validate and to resolve to the indexes of their configured positions.
TEST(WorkerIndexResolverTest, HostNameHasColons) {
  const std::vector<std::string> configured_addresses = {
      ":worker:task:0:%port%", ":worker:task:1:%port%",
      ":worker:task:2:34567"};
  WorkerIndexResolver resolver(configured_addresses);
  const absl::string_view task0 = ":worker:task:0:12345";
  const absl::string_view task1 = ":worker:task:1:23456";
  const absl::string_view task2 = ":worker:task:2:34567";

  TF_EXPECT_OK(resolver.ValidateWorker(task0));
  TF_EXPECT_OK(resolver.ValidateWorker(task1));
  TF_EXPECT_OK(resolver.ValidateWorker(task2));
  resolver.AddWorker(task0);
  resolver.AddWorker(task1);
  resolver.AddWorker(task2);

  EXPECT_THAT(resolver.GetWorkerIndex(task0), IsOkAndHolds(0));
  EXPECT_THAT(resolver.GetWorkerIndex(task1), IsOkAndHolds(1));
  EXPECT_THAT(resolver.GetWorkerIndex(task2), IsOkAndHolds(2));
}
386
// Once a worker has been added for each configured host, validating another
// worker on the same host with a different port is expected to fail with
// FAILED_PRECONDITION.
TEST(WorkerIndexResolverTest, ChangeWorkerPort) {
  const std::vector<std::string> configured_hosts = {
      "/worker/task/0", "/worker/task/1", "/worker/task/2"};
  WorkerIndexResolver resolver(configured_hosts);
  const absl::string_view task0 = "/worker/task/0:34567";
  const absl::string_view task1 = "/worker/task/1:23456";
  const absl::string_view task2 = "/worker/task/2:12345";

  TF_EXPECT_OK(resolver.ValidateWorker(task2));
  TF_EXPECT_OK(resolver.ValidateWorker(task1));
  TF_EXPECT_OK(resolver.ValidateWorker(task0));
  resolver.AddWorker(task2);
  resolver.AddWorker(task1);
  resolver.AddWorker(task0);

  // Same hosts as above, but each with a port that was never added.
  const absl::string_view task0_new_port = "/worker/task/0:99999";
  const absl::string_view task1_new_port = "/worker/task/1:99999";
  const absl::string_view task2_new_port = "/worker/task/2:99999";
  EXPECT_THAT(resolver.ValidateWorker(task0_new_port),
              StatusIs(error::FAILED_PRECONDITION,
                       HasSubstr("already running at the configured host")));
  EXPECT_THAT(resolver.ValidateWorker(task1_new_port),
              StatusIs(error::FAILED_PRECONDITION,
                       HasSubstr("already running at the configured host")));
  EXPECT_THAT(resolver.ValidateWorker(task2_new_port),
              StatusIs(error::FAILED_PRECONDITION,
                       HasSubstr("already running at the configured host")));
}
406
// Before any worker is added, every lookup is expected to return NOT_FOUND.
// A worker whose host is not in the configured list fails validation, and
// still resolves to NOT_FOUND even after AddWorker is called for it, while
// the configured workers resolve to the indexes of their configured
// positions.
TEST(WorkerIndexResolverTest, WorkerNotFound) {
  const std::vector<std::string> configured_hosts = {
      "/worker/task/0", "/worker/task/1", "/worker/task/2"};
  WorkerIndexResolver resolver(configured_hosts);
  const absl::string_view task0 = "/worker/task/0:34567";
  const absl::string_view task1 = "/worker/task/1:23456";
  const absl::string_view task2 = "/worker/task/2:12345";
  // This host is absent from `configured_hosts`.
  const absl::string_view unconfigured_task = "/worker/task/3:45678";

  EXPECT_THAT(resolver.GetWorkerIndex(task0), StatusIs(error::NOT_FOUND));
  EXPECT_THAT(resolver.GetWorkerIndex(task1), StatusIs(error::NOT_FOUND));
  EXPECT_THAT(resolver.GetWorkerIndex(task2), StatusIs(error::NOT_FOUND));
  EXPECT_THAT(resolver.GetWorkerIndex(unconfigured_task),
              StatusIs(error::NOT_FOUND));

  TF_EXPECT_OK(resolver.ValidateWorker(task2));
  TF_EXPECT_OK(resolver.ValidateWorker(task1));
  TF_EXPECT_OK(resolver.ValidateWorker(task0));
  EXPECT_THAT(resolver.ValidateWorker(unconfigured_task),
              StatusIs(error::FAILED_PRECONDITION,
                       HasSubstr("The worker's address is not configured")));
  resolver.AddWorker(unconfigured_task);
  resolver.AddWorker(task2);
  resolver.AddWorker(task1);
  resolver.AddWorker(task0);

  EXPECT_THAT(resolver.GetWorkerIndex(task0), IsOkAndHolds(0));
  EXPECT_THAT(resolver.GetWorkerIndex(task1), IsOkAndHolds(1));
  EXPECT_THAT(resolver.GetWorkerIndex(task2), IsOkAndHolds(2));
  EXPECT_THAT(
      resolver.GetWorkerIndex(unconfigured_task),
      StatusIs(error::NOT_FOUND,
               HasSubstr(
                   "Worker /worker/task/3:45678 is not in the workers list.")));
}
439
// The same port-less host is configured three times. Three workers on that
// host, each validated and then added in turn, are expected to receive
// indexes 0, 1 and 2 in the order they were added.
TEST(WorkerIndexResolverTest, MultipleWorkersInOneHost) {
  const std::vector<std::string> configured_hosts = {"localhost", "localhost",
                                                     "localhost"};
  WorkerIndexResolver resolver(configured_hosts);
  const absl::string_view first_worker = "localhost:12345";
  const absl::string_view second_worker = "localhost:23456";
  const absl::string_view third_worker = "localhost:34567";

  TF_EXPECT_OK(resolver.ValidateWorker(first_worker));
  resolver.AddWorker(first_worker);
  TF_EXPECT_OK(resolver.ValidateWorker(second_worker));
  resolver.AddWorker(second_worker);
  TF_EXPECT_OK(resolver.ValidateWorker(third_worker));
  resolver.AddWorker(third_worker);

  EXPECT_THAT(resolver.GetWorkerIndex(first_worker), IsOkAndHolds(0));
  EXPECT_THAT(resolver.GetWorkerIndex(second_worker), IsOkAndHolds(1));
  EXPECT_THAT(resolver.GetWorkerIndex(third_worker), IsOkAndHolds(2));
}
453
// Three slots are configured for one host. After three workers fill them,
// validating and re-adding those same workers is expected to keep succeeding,
// while validating a worker on a port not seen before is expected to fail
// with FAILED_PRECONDITION.
TEST(WorkerIndexResolverTest, MoreWorkersThanConfigured) {
  const std::vector<std::string> configured_addresses = {
      "localhost:%port%", "localhost:%port%", "localhost:%port%"};
  WorkerIndexResolver resolver(configured_addresses);
  const absl::string_view first_worker = "localhost:12345";
  const absl::string_view second_worker = "localhost:23456";
  const absl::string_view third_worker = "localhost:34567";

  // First pass: each worker is validated and then added.
  TF_EXPECT_OK(resolver.ValidateWorker(first_worker));
  resolver.AddWorker(first_worker);
  TF_EXPECT_OK(resolver.ValidateWorker(second_worker));
  resolver.AddWorker(second_worker);
  TF_EXPECT_OK(resolver.ValidateWorker(third_worker));
  resolver.AddWorker(third_worker);

  // Second pass: the same three workers go through the same steps again.
  TF_EXPECT_OK(resolver.ValidateWorker(first_worker));
  resolver.AddWorker(first_worker);
  TF_EXPECT_OK(resolver.ValidateWorker(second_worker));
  resolver.AddWorker(second_worker);
  TF_EXPECT_OK(resolver.ValidateWorker(third_worker));
  resolver.AddWorker(third_worker);

  // Workers on ports that were never added are rejected.
  const absl::string_view fourth_worker = "localhost:45678";
  const absl::string_view fifth_worker = "localhost:56789";
  EXPECT_THAT(resolver.ValidateWorker(fourth_worker),
              StatusIs(error::FAILED_PRECONDITION,
                       HasSubstr("already running at the configured host")));
  EXPECT_THAT(resolver.ValidateWorker(fifth_worker),
              StatusIs(error::FAILED_PRECONDITION,
                       HasSubstr("already running at the configured host")));
}
476
// The only configured address is the empty string. A real worker address is
// expected to fail validation with FAILED_PRECONDITION and to resolve to
// NOT_FOUND both before and after AddWorker is called for it.
TEST(WorkerIndexResolverTest, WorkerNotConfigured) {
  const std::vector<std::string> configured_addresses = {""};
  WorkerIndexResolver resolver(configured_addresses);
  const absl::string_view worker = "localhost:12345";

  EXPECT_THAT(resolver.GetWorkerIndex(worker), StatusIs(error::NOT_FOUND));
  EXPECT_THAT(resolver.ValidateWorker(worker),
              StatusIs(error::FAILED_PRECONDITION,
                       HasSubstr("The worker's address is not configured")));
  resolver.AddWorker(worker);
  EXPECT_THAT(resolver.GetWorkerIndex(worker), StatusIs(error::NOT_FOUND));
}
488 } // namespace
489 } // namespace data
490 } // namespace tensorflow
491