xref: /aosp_15_r20/external/tensorflow/tensorflow/core/data/service/auto_shard_rewriter_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/data/service/auto_shard_rewriter.h"
16 
17 #include <string>
18 
19 #include "absl/strings/string_view.h"
20 #include "absl/strings/substitute.h"
21 #include "tensorflow/core/data/service/common.pb.h"
22 #include "tensorflow/core/data/service/test_util.h"
23 #include "tensorflow/core/framework/dataset_options.pb.h"
24 #include "tensorflow/core/framework/function_testlib.h"
25 #include "tensorflow/core/framework/graph.pb.h"
26 #include "tensorflow/core/framework/node_def.pb.h"
27 #include "tensorflow/core/framework/tensor.pb.h"
28 #include "tensorflow/core/framework/tensor_testutil.h"
29 #include "tensorflow/core/framework/types.pb.h"
30 #include "tensorflow/core/lib/core/status_test_util.h"
31 #include "tensorflow/core/platform/errors.h"
32 #include "tensorflow/core/platform/status_matchers.h"
33 #include "tensorflow/core/platform/statusor.h"
34 #include "tensorflow/core/platform/test.h"
35 #include "tensorflow/core/platform/tstring.h"
36 #include "tensorflow/core/platform/types.h"
37 #include "tensorflow/core/protobuf/data_service.pb.h"
38 #include "tensorflow/core/protobuf/error_codes.pb.h"
39 
40 namespace tensorflow {
41 namespace data {
42 namespace {
43 
44 using ::tensorflow::data::testing::EqualsProto;
45 using ::tensorflow::data::testing::RangeDatasetWithShardHint;
46 using ::tensorflow::data::testing::RangeSquareDataset;
47 using ::tensorflow::testing::IsOkAndHolds;
48 using ::tensorflow::testing::StatusIs;
49 using ::testing::HasSubstr;
50 using ::testing::SizeIs;
51 
GetNode(const GraphDef & graph_def,absl::string_view name)52 StatusOr<NodeDef> GetNode(const GraphDef& graph_def, absl::string_view name) {
53   for (const NodeDef& node : graph_def.node()) {
54     if (node.name() == name) {
55       return node;
56     }
57   }
58   return errors::NotFound(absl::Substitute("Node $0 not found in graph $1.",
59                                            name, graph_def.ShortDebugString()));
60 }
61 
GetValue(const GraphDef & graph_def,absl::string_view name)62 StatusOr<int64_t> GetValue(const GraphDef& graph_def, absl::string_view name) {
63   for (const NodeDef& node : graph_def.node()) {
64     if (node.name() == name) {
65       return node.attr().at("value").tensor().int64_val()[0];
66     }
67   }
68   return errors::NotFound(absl::Substitute("Node $0 not found in graph $1.",
69                                            name, graph_def.ShortDebugString()));
70 }
71 
GetTaskDef(const ProcessingModeDef::ShardingPolicy sharding_policy,const int64_t num_workers,const int64_t worker_index)72 TaskDef GetTaskDef(const ProcessingModeDef::ShardingPolicy sharding_policy,
73                    const int64_t num_workers, const int64_t worker_index) {
74   TaskDef task_def;
75   task_def.mutable_processing_mode_def()->set_sharding_policy(sharding_policy);
76   task_def.set_num_workers(num_workers);
77   task_def.set_worker_index(worker_index);
78   return task_def;
79 }
80 
// FILE_OR_DATA policy with 3 workers and worker index 1: the rewritten
// range-square graph must contain a "ShardDataset" node with three inputs,
// where the nodes feeding inputs 1 and 2 hold the worker count and the worker
// index respectively.
TEST(AutoShardRewriterTest, AutoShard) {
  constexpr int64_t kNumWorkers = 3;
  constexpr int64_t kWorkerIndex = 1;
  TaskDef task =
      GetTaskDef(ProcessingModeDef::FILE_OR_DATA, kNumWorkers, kWorkerIndex);
  TF_ASSERT_OK_AND_ASSIGN(AutoShardRewriter auto_sharder,
                          AutoShardRewriter::Create(task));

  const DatasetDef range_square = RangeSquareDataset(10);
  TF_ASSERT_OK_AND_ASSIGN(
      GraphDef sharded_graph,
      auto_sharder.ApplyAutoShardRewrite(range_square.graph()));
  TF_ASSERT_OK_AND_ASSIGN(NodeDef shard_op,
                          GetNode(sharded_graph, "ShardDataset"));
  ASSERT_THAT(shard_op.input(), SizeIs(3));
  EXPECT_THAT(GetValue(sharded_graph, shard_op.input(1)),
              IsOkAndHolds(kNumWorkers));
  EXPECT_THAT(GetValue(sharded_graph, shard_op.input(2)),
              IsOkAndHolds(kWorkerIndex));
}
96 
// DATA policy with 3 workers and worker index 1: the rewritten range-square
// graph must contain a "ShardDataset" node with three inputs, where the nodes
// feeding inputs 1 and 2 hold the worker count and the worker index.
TEST(AutoShardRewriterTest, ShardByData) {
  constexpr int64_t kNumWorkers = 3;
  constexpr int64_t kWorkerIndex = 1;
  TaskDef task = GetTaskDef(ProcessingModeDef::DATA, kNumWorkers, kWorkerIndex);
  TF_ASSERT_OK_AND_ASSIGN(AutoShardRewriter auto_sharder,
                          AutoShardRewriter::Create(task));

  const DatasetDef range_square = RangeSquareDataset(10);
  TF_ASSERT_OK_AND_ASSIGN(
      GraphDef sharded_graph,
      auto_sharder.ApplyAutoShardRewrite(range_square.graph()));
  TF_ASSERT_OK_AND_ASSIGN(NodeDef shard_op,
                          GetNode(sharded_graph, "ShardDataset"));
  ASSERT_THAT(shard_op.input(), SizeIs(3));
  EXPECT_THAT(GetValue(sharded_graph, shard_op.input(1)),
              IsOkAndHolds(kNumWorkers));
  EXPECT_THAT(GetValue(sharded_graph, shard_op.input(2)),
              IsOkAndHolds(kWorkerIndex));
}
112 
// FILE policy applied to a range-square dataset: the rewrite is expected to
// fail with NOT_FOUND and mention an unshardable source dataset.
TEST(AutoShardRewriterTest, ShardByFile) {
  TaskDef task = GetTaskDef(ProcessingModeDef::FILE, /*num_workers=*/3,
                            /*worker_index=*/1);
  TF_ASSERT_OK_AND_ASSIGN(AutoShardRewriter auto_sharder,
                          AutoShardRewriter::Create(task));

  const DatasetDef range_square = RangeSquareDataset(10);
  const auto rewrite_result =
      auto_sharder.ApplyAutoShardRewrite(range_square.graph());
  EXPECT_THAT(rewrite_result,
              StatusIs(error::NOT_FOUND,
                       HasSubstr("Found an unshardable source dataset")));
}
124 
// HINT policy with 3 workers and worker index 1, applied to a range dataset
// built with a shard hint: the rewritten graph must contain a "ShardDataset"
// node with three inputs, where the nodes feeding inputs 1 and 2 hold the
// worker count and the worker index.
TEST(AutoShardRewriterTest, ShardByHint) {
  constexpr int64_t kNumWorkers = 3;
  constexpr int64_t kWorkerIndex = 1;
  TaskDef task = GetTaskDef(ProcessingModeDef::HINT, kNumWorkers, kWorkerIndex);
  TF_ASSERT_OK_AND_ASSIGN(AutoShardRewriter auto_sharder,
                          AutoShardRewriter::Create(task));

  const DatasetDef hinted_range = RangeDatasetWithShardHint(10);
  TF_ASSERT_OK_AND_ASSIGN(
      GraphDef sharded_graph,
      auto_sharder.ApplyAutoShardRewrite(hinted_range.graph()));
  TF_ASSERT_OK_AND_ASSIGN(NodeDef shard_op,
                          GetNode(sharded_graph, "ShardDataset"));
  ASSERT_THAT(shard_op.input(), SizeIs(3));
  EXPECT_THAT(GetValue(sharded_graph, shard_op.input(1)),
              IsOkAndHolds(kNumWorkers));
  EXPECT_THAT(GetValue(sharded_graph, shard_op.input(2)),
              IsOkAndHolds(kWorkerIndex));
}
140 
// OFF policy: the rewrite is expected to succeed and hand back a graph equal
// to the input graph.
TEST(AutoShardRewriterTest, NoShard) {
  TaskDef task = GetTaskDef(ProcessingModeDef::OFF, /*num_workers=*/3,
                            /*worker_index=*/1);
  TF_ASSERT_OK_AND_ASSIGN(AutoShardRewriter auto_sharder,
                          AutoShardRewriter::Create(task));

  const DatasetDef range_square = RangeSquareDataset(10);
  const auto& original_graph = range_square.graph();
  EXPECT_THAT(auto_sharder.ApplyAutoShardRewrite(original_graph),
              IsOkAndHolds(EqualsProto(original_graph)));
}
151 
// Same expectations as the FILE_OR_DATA case, but the range-square dataset is
// built with an argument of 0: a "ShardDataset" node with three inputs must
// still be present, with inputs 1 and 2 resolving to the worker count and the
// worker index.
TEST(AutoShardRewriterTest, EmptyDataset) {
  constexpr int64_t kNumWorkers = 3;
  constexpr int64_t kWorkerIndex = 1;
  TaskDef task =
      GetTaskDef(ProcessingModeDef::FILE_OR_DATA, kNumWorkers, kWorkerIndex);
  TF_ASSERT_OK_AND_ASSIGN(AutoShardRewriter auto_sharder,
                          AutoShardRewriter::Create(task));

  const DatasetDef empty_range_square = RangeSquareDataset(0);
  TF_ASSERT_OK_AND_ASSIGN(
      GraphDef sharded_graph,
      auto_sharder.ApplyAutoShardRewrite(empty_range_square.graph()));
  TF_ASSERT_OK_AND_ASSIGN(NodeDef shard_op,
                          GetNode(sharded_graph, "ShardDataset"));
  ASSERT_THAT(shard_op.input(), SizeIs(3));
  EXPECT_THAT(GetValue(sharded_graph, shard_op.input(1)),
              IsOkAndHolds(kNumWorkers));
  EXPECT_THAT(GetValue(sharded_graph, shard_op.input(2)),
              IsOkAndHolds(kWorkerIndex));
}
168 
// FILE_OR_DATA policy with zero workers: creating the rewriter succeeds, but
// applying the rewrite is expected to fail with INVALID_ARGUMENT and the exact
// message checked below.
TEST(AutoShardRewriterTest, NoWorkers) {
  TaskDef task = GetTaskDef(ProcessingModeDef::FILE_OR_DATA,
                            /*num_workers=*/0, /*worker_index=*/0);
  TF_ASSERT_OK_AND_ASSIGN(AutoShardRewriter auto_sharder,
                          AutoShardRewriter::Create(task));

  const DatasetDef range_square = RangeSquareDataset(10);
  const auto rewrite_result =
      auto_sharder.ApplyAutoShardRewrite(range_square.graph());
  EXPECT_THAT(rewrite_result,
              StatusIs(error::INVALID_ARGUMENT,
                       "num_workers should be >= 1, currently 0"));
}
181 
// OFF policy with zero workers: the rewrite is expected to succeed and return
// a graph equal to the input graph.
TEST(AutoShardRewriterTest, NoWorkersWhenShardIsOff) {
  TaskDef task = GetTaskDef(ProcessingModeDef::OFF, /*num_workers=*/0,
                            /*worker_index=*/0);
  TF_ASSERT_OK_AND_ASSIGN(AutoShardRewriter auto_sharder,
                          AutoShardRewriter::Create(task));

  const DatasetDef range_square = RangeSquareDataset(10);
  const auto& original_graph = range_square.graph();
  EXPECT_THAT(auto_sharder.ApplyAutoShardRewrite(original_graph),
              IsOkAndHolds(EqualsProto(original_graph)));
}
192 
// Worker index 5 with only 2 workers: creating the rewriter succeeds, but
// applying the rewrite is expected to fail with INVALID_ARGUMENT and the exact
// message checked below.
TEST(AutoShardRewriterTest, WorkerIndexOutOfRange) {
  TaskDef task = GetTaskDef(ProcessingModeDef::FILE_OR_DATA,
                            /*num_workers=*/2, /*worker_index=*/5);
  TF_ASSERT_OK_AND_ASSIGN(AutoShardRewriter auto_sharder,
                          AutoShardRewriter::Create(task));

  const DatasetDef range_square = RangeSquareDataset(10);
  const auto rewrite_result =
      auto_sharder.ApplyAutoShardRewrite(range_square.graph());
  EXPECT_THAT(rewrite_result,
              StatusIs(error::INVALID_ARGUMENT,
                       "index should be >= 0 and < 2, currently 5"));
}
205 
// A single configured host: the worker's index lookup fails with NOT_FOUND
// until the worker has been added, after which it resolves to index 0.
TEST(WorkerIndexResolverTest, AddOneWorker) {
  const char* const worker = "localhost:12345";
  WorkerIndexResolver resolver(std::vector<std::string>{"localhost"});
  EXPECT_THAT(resolver.GetWorkerIndex(worker), StatusIs(error::NOT_FOUND));

  TF_EXPECT_OK(resolver.ValidateWorker(worker));
  resolver.AddWorker(worker);
  EXPECT_THAT(resolver.GetWorkerIndex(worker), IsOkAndHolds(0));
}
215 
// Three port-less hosts are configured. Workers are validated and added in
// the reverse of the configured order, yet each is expected to resolve to the
// position of its host in the configuration.
TEST(WorkerIndexResolverTest, AddMultipleWorkers) {
  WorkerIndexResolver resolver(std::vector<std::string>{
      "/worker/task/0", "/worker/task/1", "/worker/task/2"});
  // Array position == index the worker is expected to resolve to.
  const char* const workers_by_index[] = {
      "/worker/task/0:34567", "/worker/task/1:23456", "/worker/task/2:12345"};
  for (int i = 2; i >= 0; --i) {
    TF_EXPECT_OK(resolver.ValidateWorker(workers_by_index[i]));
  }
  for (int i = 2; i >= 0; --i) {
    resolver.AddWorker(workers_by_index[i]);
  }
  for (int i = 0; i < 3; ++i) {
    EXPECT_THAT(resolver.GetWorkerIndex(workers_by_index[i]), IsOkAndHolds(i));
  }
}
229 
// The configured addresses use the named port "worker". Workers reporting the
// very same addresses are validated and added in reverse order and are
// expected to resolve to their configured positions.
TEST(WorkerIndexResolverTest, NamedPorts) {
  WorkerIndexResolver resolver(
      std::vector<std::string>{"/worker/task/0:worker", "/worker/task/1:worker",
                               "/worker/task/2:worker"});
  // Array position == index the worker is expected to resolve to.
  const char* const workers_by_index[] = {
      "/worker/task/0:worker", "/worker/task/1:worker",
      "/worker/task/2:worker"};
  for (int i = 2; i >= 0; --i) {
    TF_EXPECT_OK(resolver.ValidateWorker(workers_by_index[i]));
  }
  for (int i = 2; i >= 0; --i) {
    resolver.AddWorker(workers_by_index[i]);
  }
  for (int i = 0; i < 3; ++i) {
    EXPECT_THAT(resolver.GetWorkerIndex(workers_by_index[i]), IsOkAndHolds(i));
  }
}
247 
// The configured addresses end in the placeholder "%port_worker%". Workers
// reporting the port "worker" are validated and added in reverse order and
// are expected to resolve to their configured positions.
TEST(WorkerIndexResolverTest, DynamicPorts) {
  WorkerIndexResolver resolver(std::vector<std::string>{
      "/worker/task/0:%port_worker%", "/worker/task/1:%port_worker%",
      "/worker/task/2:%port_worker%"});
  // Array position == index the worker is expected to resolve to.
  const char* const workers_by_index[] = {
      "/worker/task/0:worker", "/worker/task/1:worker",
      "/worker/task/2:worker"};
  for (int i = 2; i >= 0; --i) {
    TF_EXPECT_OK(resolver.ValidateWorker(workers_by_index[i]));
  }
  for (int i = 2; i >= 0; --i) {
    resolver.AddWorker(workers_by_index[i]);
  }
  for (int i = 0; i < 3; ++i) {
    EXPECT_THAT(resolver.GetWorkerIndex(workers_by_index[i]), IsOkAndHolds(i));
  }
}
265 
// The configured addresses end in the placeholder "%port%". Workers reporting
// numeric ports are validated and added in reverse order and are expected to
// resolve to their configured positions.
TEST(WorkerIndexResolverTest, AnonymousPorts) {
  WorkerIndexResolver resolver(
      std::vector<std::string>{"/worker/task/0:%port%", "/worker/task/1:%port%",
                               "/worker/task/2:%port%"});
  // Array position == index the worker is expected to resolve to.
  const char* const workers_by_index[] = {
      "/worker/task/0:10002", "/worker/task/1:10001", "/worker/task/2:10000"};
  for (int i = 2; i >= 0; --i) {
    TF_EXPECT_OK(resolver.ValidateWorker(workers_by_index[i]));
  }
  for (int i = 2; i >= 0; --i) {
    resolver.AddWorker(workers_by_index[i]);
  }
  for (int i = 0; i < 3; ++i) {
    EXPECT_THAT(resolver.GetWorkerIndex(workers_by_index[i]), IsOkAndHolds(i));
  }
}
280 
// The configured addresses carry concrete numeric ports. The test expects
// them to resolve to their configured positions before any worker is added,
// and expects the same indices after the same addresses are validated and
// added (in reverse order).
TEST(WorkerIndexResolverTest, NumericPorts) {
  WorkerIndexResolver resolver(std::vector<std::string>{
      "/worker/task/0:12345", "/worker/task/1:23456", "/worker/task/2:34567"});
  // Array position == index the worker is expected to resolve to.
  const char* const workers_by_index[] = {
      "/worker/task/0:12345", "/worker/task/1:23456", "/worker/task/2:34567"};
  for (int i = 0; i < 3; ++i) {
    EXPECT_THAT(resolver.GetWorkerIndex(workers_by_index[i]), IsOkAndHolds(i));
  }

  // Adding duplicate workers is a no-op.
  for (int i = 2; i >= 0; --i) {
    TF_EXPECT_OK(resolver.ValidateWorker(workers_by_index[i]));
  }
  for (int i = 2; i >= 0; --i) {
    resolver.AddWorker(workers_by_index[i]);
  }
  for (int i = 0; i < 3; ++i) {
    EXPECT_THAT(resolver.GetWorkerIndex(workers_by_index[i]), IsOkAndHolds(i));
  }
}
299 
// The configured hosts are bracketed IPv6 addresses without ports. Workers
// reporting those hosts plus a numeric port are validated and added in the
// configured order and are expected to resolve to their configured positions.
TEST(WorkerIndexResolverTest, IPv6Addresses) {
  WorkerIndexResolver resolver(std::vector<std::string>{
      "[1080:0:0:0:8:800:200C:417A]", "[1080:0:0:0:8:800:200C:417B]",
      "[1080:0:0:0:8:800:200C:417C]"});
  // Array position == index the worker is expected to resolve to.
  const char* const workers_by_index[] = {
      "[1080:0:0:0:8:800:200C:417A]:12345",
      "[1080:0:0:0:8:800:200C:417B]:23456",
      "[1080:0:0:0:8:800:200C:417C]:34567"};
  for (const char* worker : workers_by_index) {
    TF_EXPECT_OK(resolver.ValidateWorker(worker));
  }
  for (const char* worker : workers_by_index) {
    resolver.AddWorker(worker);
  }
  for (int i = 0; i < 3; ++i) {
    EXPECT_THAT(resolver.GetWorkerIndex(workers_by_index[i]), IsOkAndHolds(i));
  }
}
317 
// The configured hosts are bracketed IPv6 addresses followed by the "%port%"
// placeholder. Workers reporting numeric ports are validated and added in the
// configured order and are expected to resolve to their configured positions.
TEST(WorkerIndexResolverTest, IPv6AddressesWithDynamicPort) {
  WorkerIndexResolver resolver(
      std::vector<std::string>{"[1080:0:0:0:8:800:200C:417A]:%port%",
                               "[1080:0:0:0:8:800:200C:417B]:%port%",
                               "[1080:0:0:0:8:800:200C:417C]:%port%"});
  // Array position == index the worker is expected to resolve to.
  const char* const workers_by_index[] = {
      "[1080:0:0:0:8:800:200C:417A]:12345",
      "[1080:0:0:0:8:800:200C:417B]:23456",
      "[1080:0:0:0:8:800:200C:417C]:34567"};
  for (const char* worker : workers_by_index) {
    TF_EXPECT_OK(resolver.ValidateWorker(worker));
  }
  for (const char* worker : workers_by_index) {
    resolver.AddWorker(worker);
  }
  for (int i = 0; i < 3; ++i) {
    EXPECT_THAT(resolver.GetWorkerIndex(workers_by_index[i]), IsOkAndHolds(i));
  }
}
336 
// The same "http://" host is configured three times without a port. Three
// workers on that host with distinct ports are validated and added in turn,
// and are expected to receive indices 0, 1 and 2 in the order they were added.
TEST(WorkerIndexResolverTest, AddressesWithProtocols) {
  WorkerIndexResolver resolver(std::vector<std::string>{
      "http://127.0.0.1", "http://127.0.0.1", "http://127.0.0.1"});
  // Array position == index the worker is expected to resolve to.
  const char* const workers_by_index[] = {"http://127.0.0.1:12345",
                                          "http://127.0.0.1:23456",
                                          "http://127.0.0.1:34567"};
  for (const char* worker : workers_by_index) {
    TF_EXPECT_OK(resolver.ValidateWorker(worker));
  }
  for (const char* worker : workers_by_index) {
    resolver.AddWorker(worker);
  }
  for (int i = 0; i < 3; ++i) {
    EXPECT_THAT(resolver.GetWorkerIndex(workers_by_index[i]), IsOkAndHolds(i));
  }
}
353 
// The same "http://" host is configured three times with the "%port_name%"
// placeholder. Three workers on that host with distinct numeric ports are
// validated and added in turn, and are expected to receive indices 0, 1 and 2
// in the order they were added.
TEST(WorkerIndexResolverTest, AddressesWithProtocolsAndDynamicPorts) {
  WorkerIndexResolver resolver(std::vector<std::string>{
      "http://127.0.0.1:%port_name%", "http://127.0.0.1:%port_name%",
      "http://127.0.0.1:%port_name%"});
  // Array position == index the worker is expected to resolve to.
  const char* const workers_by_index[] = {"http://127.0.0.1:12345",
                                          "http://127.0.0.1:23456",
                                          "http://127.0.0.1:34567"};
  for (const char* worker : workers_by_index) {
    TF_EXPECT_OK(resolver.ValidateWorker(worker));
  }
  for (const char* worker : workers_by_index) {
    resolver.AddWorker(worker);
  }
  for (int i = 0; i < 3; ++i) {
    EXPECT_THAT(resolver.GetWorkerIndex(workers_by_index[i]), IsOkAndHolds(i));
  }
}
371 
// Host names that themselves contain colons, configured with a mix of the
// "%port%" placeholder and a concrete port. Workers are validated and added
// in the configured order and are expected to resolve to their configured
// positions.
TEST(WorkerIndexResolverTest, HostNameHasColons) {
  WorkerIndexResolver resolver(
      std::vector<std::string>{":worker:task:0:%port%", ":worker:task:1:%port%",
                               ":worker:task:2:34567"});
  // Array position == index the worker is expected to resolve to.
  const char* const workers_by_index[] = {
      ":worker:task:0:12345", ":worker:task:1:23456", ":worker:task:2:34567"};
  for (const char* worker : workers_by_index) {
    TF_EXPECT_OK(resolver.ValidateWorker(worker));
  }
  for (const char* worker : workers_by_index) {
    resolver.AddWorker(worker);
  }
  for (int i = 0; i < 3; ++i) {
    EXPECT_THAT(resolver.GetWorkerIndex(workers_by_index[i]), IsOkAndHolds(i));
  }
}
386 
// After one worker has been added for each configured host, validating a
// worker on the same host with a different port is expected to fail with
// FAILED_PRECONDITION.
TEST(WorkerIndexResolverTest, ChangeWorkerPort) {
  WorkerIndexResolver resolver(std::vector<std::string>{
      "/worker/task/0", "/worker/task/1", "/worker/task/2"});
  const char* const initial_workers[] = {
      "/worker/task/2:12345", "/worker/task/1:23456", "/worker/task/0:34567"};
  for (const char* worker : initial_workers) {
    TF_EXPECT_OK(resolver.ValidateWorker(worker));
  }
  for (const char* worker : initial_workers) {
    resolver.AddWorker(worker);
  }

  const char* const workers_with_new_port[] = {
      "/worker/task/0:99999", "/worker/task/1:99999", "/worker/task/2:99999"};
  for (const char* worker : workers_with_new_port) {
    EXPECT_THAT(resolver.ValidateWorker(worker),
                StatusIs(error::FAILED_PRECONDITION,
                         HasSubstr("already running at the configured host")));
  }
}
406 
// Before any worker is added, every index lookup is expected to fail with
// NOT_FOUND. A worker whose host is not configured fails validation with
// FAILED_PRECONDITION and, even after AddWorker is called for it, its lookup
// is still expected to fail with NOT_FOUND, while the configured workers
// resolve to their configured positions.
TEST(WorkerIndexResolverTest, WorkerNotFound) {
  WorkerIndexResolver resolver(std::vector<std::string>{
      "/worker/task/0", "/worker/task/1", "/worker/task/2"});
  // Array position == index the worker is expected to resolve to once added.
  const char* const configured_workers[] = {
      "/worker/task/0:34567", "/worker/task/1:23456", "/worker/task/2:12345"};
  const char* const unconfigured_worker = "/worker/task/3:45678";

  for (const char* worker : configured_workers) {
    EXPECT_THAT(resolver.GetWorkerIndex(worker), StatusIs(error::NOT_FOUND));
  }
  EXPECT_THAT(resolver.GetWorkerIndex(unconfigured_worker),
              StatusIs(error::NOT_FOUND));

  for (int i = 2; i >= 0; --i) {
    TF_EXPECT_OK(resolver.ValidateWorker(configured_workers[i]));
  }
  EXPECT_THAT(resolver.ValidateWorker(unconfigured_worker),
              StatusIs(error::FAILED_PRECONDITION,
                       HasSubstr("The worker's address is not configured")));
  resolver.AddWorker(unconfigured_worker);
  for (int i = 2; i >= 0; --i) {
    resolver.AddWorker(configured_workers[i]);
  }

  for (int i = 0; i < 3; ++i) {
    EXPECT_THAT(resolver.GetWorkerIndex(configured_workers[i]),
                IsOkAndHolds(i));
  }
  EXPECT_THAT(
      resolver.GetWorkerIndex(unconfigured_worker),
      StatusIs(error::NOT_FOUND,
               HasSubstr(
                   "Worker /worker/task/3:45678 is not in the workers list.")));
}
439 
// "localhost" is configured three times. Three workers on that host with
// distinct ports are validated and added one after another, and are expected
// to receive indices 0, 1 and 2 in the order they were added.
TEST(WorkerIndexResolverTest, MultipleWorkersInOneHost) {
  WorkerIndexResolver resolver(
      std::vector<std::string>{"localhost", "localhost", "localhost"});
  // Array position == index the worker is expected to resolve to.
  const char* const workers_by_index[] = {"localhost:12345", "localhost:23456",
                                          "localhost:34567"};
  for (const char* worker : workers_by_index) {
    TF_EXPECT_OK(resolver.ValidateWorker(worker));
    resolver.AddWorker(worker);
  }
  for (int i = 0; i < 3; ++i) {
    EXPECT_THAT(resolver.GetWorkerIndex(workers_by_index[i]), IsOkAndHolds(i));
  }
}
453 
// "localhost:%port%" is configured three times. Validating and adding the
// same three workers twice is expected to succeed both times, but validating
// a fourth or fifth distinct worker on that host is expected to fail with
// FAILED_PRECONDITION.
TEST(WorkerIndexResolverTest, MoreWorkersThanConfigured) {
  WorkerIndexResolver resolver(std::vector<std::string>{
      "localhost:%port%", "localhost:%port%", "localhost:%port%"});
  const char* const accepted_workers[] = {"localhost:12345", "localhost:23456",
                                          "localhost:34567"};
  // Second pass re-registers the same workers and must still validate.
  for (int pass = 0; pass < 2; ++pass) {
    for (const char* worker : accepted_workers) {
      TF_EXPECT_OK(resolver.ValidateWorker(worker));
      resolver.AddWorker(worker);
    }
  }

  const char* const surplus_workers[] = {"localhost:45678", "localhost:56789"};
  for (const char* worker : surplus_workers) {
    EXPECT_THAT(resolver.ValidateWorker(worker),
                StatusIs(error::FAILED_PRECONDITION,
                         HasSubstr("already running at the configured host")));
  }
}
476 
// The only configured address is the empty string. The worker is expected to
// fail validation with FAILED_PRECONDITION, and its index lookup is expected
// to fail with NOT_FOUND both before and after AddWorker is called for it.
TEST(WorkerIndexResolverTest, WorkerNotConfigured) {
  const char* const worker = "localhost:12345";
  WorkerIndexResolver resolver(std::vector<std::string>{""});
  EXPECT_THAT(resolver.GetWorkerIndex(worker), StatusIs(error::NOT_FOUND));
  EXPECT_THAT(resolver.ValidateWorker(worker),
              StatusIs(error::FAILED_PRECONDITION,
                       HasSubstr("The worker's address is not configured")));
  resolver.AddWorker(worker);
  EXPECT_THAT(resolver.GetWorkerIndex(worker), StatusIs(error::NOT_FOUND));
}
488 }  // namespace
489 }  // namespace data
490 }  // namespace tensorflow
491