1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/coordination/coordination_service.h"
17
18 #include <algorithm>
19 #include <iterator>
20 #include <string>
21 #include <utility>
22 #include <vector>
23
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/string_view.h"
28 #include "absl/synchronization/notification.h"
29 #include "absl/time/time.h"
30 #include "tensorflow/compiler/xla/pjrt/distributed/protocol.pb.h"
31 #include "tensorflow/core/distributed_runtime/call_options.h"
32 #include "tensorflow/core/distributed_runtime/coordination/coordination_client.h"
33 #include "tensorflow/core/distributed_runtime/coordination/coordination_service_error_util.h"
34 #include "tensorflow/core/platform/env.h"
35 #include "tensorflow/core/platform/errors.h"
36 #include "tensorflow/core/platform/macros.h"
37 #include "tensorflow/core/platform/mutex.h"
38 #include "tensorflow/core/platform/random.h"
39 #include "tensorflow/core/platform/status.h"
40 #include "tensorflow/core/platform/strcat.h"
41 #include "tensorflow/core/platform/thread_annotations.h"
42 #include "tensorflow/core/protobuf/cluster.pb.h"
43 #include "tensorflow/core/protobuf/config.pb.h"
44 #include "tensorflow/core/protobuf/coordination_config.pb.h"
45 #include "tensorflow/core/protobuf/coordination_service.pb.h"
46 #include "tensorflow/core/protobuf/tensorflow_server.pb.h"
47 #include "tensorflow/core/util/device_name_utils.h"
48
49 namespace tensorflow {
50 namespace {
51
52 constexpr absl::Duration kDevicePropagationTimeout = absl::Hours(1);
53 constexpr int kDefaultHeartbeatTimeoutMs = 10 * 1000; // 10 seconds
54 constexpr int kServiceToClientTimeoutMs = 10 * 1000; // 10 seconds
55 constexpr size_t kOngoingBarriersSoftLimit = 20;
56 constexpr char kHealthCheckThread[] = "CoordinationServiceHealthCheck";
57
GetTaskName(absl::string_view job_name,int task_id)58 std::string GetTaskName(absl::string_view job_name, int task_id) {
59 return strings::StrCat("/job:", job_name, "/replica:", 0, "/task:", task_id);
60 }
61
GetTaskName(const CoordinatedTask & task)62 std::string GetTaskName(const CoordinatedTask& task) {
63 return GetTaskName(task.job_name(), task.task_id());
64 }
65
GetTaskFromName(absl::string_view task_name)66 CoordinatedTask GetTaskFromName(absl::string_view task_name) {
67 DeviceNameUtils::ParsedName parsed;
68 DeviceNameUtils::ParseFullName(task_name, &parsed);
69 CoordinatedTask task;
70 task.set_job_name(parsed.job);
71 task.set_task_id(parsed.task);
72 return task;
73 }
74
is_multi_client_leader(const ServerDef & server_def)75 bool is_multi_client_leader(const ServerDef& server_def) {
76 const auto& config = server_def.default_session_config();
77 const std::string& leader =
78 config.experimental().coordination_config().service_leader();
79 const std::string& collective_leader =
80 config.experimental().collective_group_leader();
81 DeviceNameUtils::ParsedName leader_pn;
82 if (!leader.empty()) {
83 DeviceNameUtils::ParseFullName(leader, &leader_pn);
84 } else if (!collective_leader.empty()) {
85 LOG(INFO) << "No coordination leader is set, using the collective leader "
86 << collective_leader;
87 DeviceNameUtils::ParseFullName(collective_leader, &leader_pn);
88 } else {
89 LOG(INFO) << "No coordination leader is set, using the default /job:"
90 << server_def.job_name() << "/replica:0/task:0";
91 return server_def.task_index() == 0;
92 }
93 return server_def.job_name() == leader_pn.job &&
94 server_def.task_index() == leader_pn.task;
95 }
96
97 // Convenience structs to allow using CoordinatedTask as container keys.
98 struct CoordinatedTaskHash {
operator ()tensorflow::__anonf25d665d0111::CoordinatedTaskHash99 uint64_t operator()(const CoordinatedTask& task) const {
100 return absl::HashOf(task.job_name(), task.task_id());
101 }
102 };
103 struct CoordinatedTaskEqual {
operator ()tensorflow::__anonf25d665d0111::CoordinatedTaskEqual104 bool operator()(const CoordinatedTask& lhs,
105 const CoordinatedTask& rhs) const {
106 return lhs.job_name() == rhs.job_name() && lhs.task_id() == rhs.task_id();
107 }
108 };
109
// Standalone implementation of the coordination service.
//
// All cluster state is kept in-process: one TaskState per coordinated task,
// the state of every barrier, and a key-value store. The constructor starts a
// background thread (see StartCheckStaleness()) that detects heartbeat and
// barrier timeouts. Errors are pushed to tasks through `client_cache_` when a
// service-to-client connection exists.
class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface {
 public:
  // `client_cache` may be null, meaning there is no service-to-client
  // connection. Only a reference to `*env` is kept, so it must outlive this
  // object.
  CoordinationServiceStandaloneImpl(
      std::unique_ptr<CoordinationClientCache> client_cache, Env* env,
      const ServerDef& server_def);
  // Stops the service: pending GetKeyValue() callbacks are cancelled and
  // unpassed barriers are failed (see Stop()).
  ~CoordinationServiceStandaloneImpl() override { Stop(); }

  Status RegisterTask(const CoordinatedTask& task,
                      uint64_t incarnation) override;
  void WaitForAllTasks(const CoordinatedTask& task,
                       const CoordinationServiceDeviceInfo& devices,
                       StatusCallback done) override;
  void ShutdownTaskAsync(const CoordinatedTask& task,
                         StatusCallback done) override;
  Status ResetTask(const CoordinatedTask& task) override;
  Status RecordHeartbeat(const CoordinatedTask& task,
                         uint64_t incarnation) override;
  Status ReportTaskError(const CoordinatedTask& task, Status error) override;
  Status InsertKeyValue(const std::string& key,
                        const std::string& value) override;
  void GetKeyValueAsync(const std::string& key,
                        StatusOrValueCallback done) override;
  StatusOr<std::string> TryGetKeyValue(const std::string& key) override;
  std::vector<KeyValueEntry> GetKeyValueDir(
      absl::string_view directory_key) override;
  Status DeleteKeyValue(const std::string& key) override;
  void BarrierAsync(const std::string& barrier_id, absl::Duration timeout,
                    const CoordinatedTask& task,
                    const std::vector<CoordinatedTask>& participating_tasks,
                    StatusCallback done) override;
  Status CancelBarrier(const std::string& barrier_id,
                       const CoordinatedTask& task) override;

 private:
  // Returns the device info merged from every task that has called
  // WaitForAllTasks() so far.
  const CoordinationServiceDeviceInfo& ListClusterDevices() override
      TF_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
  // Returns the random id generated when this service instance was created.
  uint64_t GetServiceIncarnation() override;
  void StartCheckStaleness();  // Checks both heartbeat and barrier timeouts.
  // Cancels pending GetKeyValue() callbacks, fails unpassed barriers, clears
  // task state and signals the staleness-check thread to exit. Pass
  // `shut_staleness_thread=false` when calling from that thread itself: a
  // thread cannot destroy its own Thread object.
  void Stop(bool shut_staleness_thread = true);
  // Report service error to a specified task.
  void ReportServiceErrorToTaskAsync(const CoordinatedTask& destination_task,
                                     Status error);
  // Report error from a task to all other connected tasks if the task is not
  // recoverable.
  // Note: SetTaskError() must be called before propagating its error.
  void PropagateError(const CoordinatedTask& source_task,
                      bool is_reported_by_task = false)
      TF_LOCKS_EXCLUDED(state_mu_);
  void SetTaskError(absl::string_view task_name, Status error)
      TF_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
  void SetXlaGlobalDeviceIds() TF_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
  // Moves a task that is not already DISCONNECTED to DISCONNECTED and fails
  // every barrier the task is currently part of.
  Status DisconnectTask(const CoordinatedTask& task)
      TF_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);

  // Bookkeeping for one barrier, stored in `barriers_` under its id.
  struct BarrierState {
    bool passed = false;
    Status result = errors::Unknown(
        "Invalid barrier result.");  // Only valid if `passed` is true.
    // Absolute deadline; the staleness thread compares it with
    // Env::Default()->NowMicros().
    uint64_t deadline_in_micros = 0;
    int num_pending_tasks = 0;
    // Specifies which tasks have called the barrier so far.
    absl::flat_hash_map<CoordinatedTask, bool, CoordinatedTaskHash,
                        CoordinatedTaskEqual>
        tasks_at_barrier;
    // Completion callbacks of callers waiting on this barrier (presumably
    // invoked with `result` by PassBarrier() — confirm in its definition).
    std::vector<StatusCallback> done_callbacks;
  };
  // Marks `barrier` as passed with `result`. NOTE(review): definition not in
  // view; it is also used to fail barriers (timeouts, disconnects, shutdown).
  void PassBarrier(absl::string_view barrier_id, Status result,
                   BarrierState* barrier)
      TF_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
  // Check if participating tasks are specified correctly across barrier calls.
  bool ValidateTaskArgs(
      const std::vector<CoordinatedTask>& tasks_args,
      const absl::flat_hash_map<CoordinatedTask, bool, CoordinatedTaskHash,
                                CoordinatedTaskEqual>& tasks_at_barrier,
      int64_t cluster_size);
  // PropagateError() skips errors whose source job is recoverable; note that
  // it passes a job name despite the parameter being called `task_name`.
  // Presumably checks membership in `recoverable_jobs_` — confirm.
  bool isRecoverableJob(const absl::string_view task_name) const;

  class TaskState {
   public:
    // Task state maintained on the coordination service side.
    // State transition:
    //                Register           Heartbeat
    //   DISCONNECTED -------> CONNECTED --------> ERROR (timeout)
    //                             |   ReportError
    //                             +--------------> ERROR
    //
    // When task state becomes ERROR, propagate this status to other CONNECTED
    // tasks in the cluster. A task that is not DISCONNECTED can be moved back
    // to DISCONNECTED by ResetTask() or ShutdownTaskAsync().

    CoordinatedTaskState GetState() { return state_; }
    Status GetStatus() { return status_; }
    // Enters CONNECTED with an OK status and counts as a fresh heartbeat.
    void SetConnected(uint64_t task_incarnation);
    // Enters DISCONNECTED with an OK status; heartbeats are still accepted for
    // `grace_period_duration_us` from now.
    void Disconnect(uint64_t grace_period_duration_us);
    // Fails with the existing error or on an incarnation mismatch.
    Status RecordHeartbeat(uint64_t task_incarnation);
    int64_t TimeSinceLastHeartbeatMs();
    // This denotes the deadline after which we stop accepting heartbeats from a
    // disconnected task. This grace period accounts for the lag time between
    // the service recording the state change and the agent stopping heartbeats.
    uint64_t GetDisconnectedGracePeriodMicros();
    // Enters ERROR with `status`; ignored if the task is already in ERROR.
    void SetError(Status status);
    bool GetDeviceInfoCollected() { return device_info_collected_; }
    void MarkDeviceInfoCollected() { device_info_collected_ = true; }
    // Returns a copy of the ids of barriers this task has joined but not left.
    absl::flat_hash_set<std::string> GetOngoingBarriers();
    void JoinBarrier(absl::string_view barrier_id);
    void ExitBarrier(absl::string_view barrier_id);

   private:
    // Incarnation ID for CPU:0 on remote task.
    uint64_t task_incarnation_ = 0;

    CoordinatedTaskState state_ = CoordinatedTaskState::TASKSTATE_DISCONNECTED;
    // Set by SetError(); reset to OK by SetConnected() and Disconnect().
    Status status_;
    // Guards only `last_heartbeat_us_`.
    mutex last_heartbeat_mu_;
    // Has no initializer; first written by SetConnected(). The staleness
    // thread only reads it for CONNECTED tasks.
    uint64_t last_heartbeat_us_ TF_GUARDED_BY(last_heartbeat_mu_);
    // This denotes the deadline after which we stop accepting heartbeats from a
    // disconnected task. This grace period accounts for the lag time between
    // the service recording the state change and the agent stopping heartbeats.
    uint64_t disconnect_grace_period_us_ = 0;
    // Checks if task has called WaitForAllTasks() previously, which gathers the
    // local device info.
    bool device_info_collected_ = false;
    // For now, we assume there won't be many simultaneous barriers so we simply
    // use a set.
    absl::flat_hash_set<std::string> ongoing_barriers_for_task_;
  };

  // May be null: no service-to-client connection.
  std::unique_ptr<CoordinationClientCache> client_cache_;
  Env& env_;
  // Must stay declared before the barrier ids below, which are built from it
  // (members are initialized in declaration order).
  const uint64_t service_incarnation_ = random::New64();
  const uint64_t heartbeat_timeout_ms_;
  // A non-positive value disables the shutdown barrier; tasks then disconnect
  // individually in ShutdownTaskAsync().
  const absl::Duration shutdown_barrier_timeout_;

  // Ids of the internal barriers used by WaitForAllTasks() and
  // ShutdownTaskAsync(); they include the service incarnation.
  const std::string device_propagation_barrier_id_ =
      absl::StrCat("WaitForAllTasks::", std::to_string(service_incarnation_));
  const std::string shutdown_barrier_id_ =
      absl::StrCat("Shutdown::", std::to_string(service_incarnation_));

  mutex state_mu_;
  // Per-task state keyed by full task name (see GetTaskName()).
  absl::flat_hash_map<std::string, std::unique_ptr<TaskState>> cluster_state_
      TF_GUARDED_BY(state_mu_);
  CoordinationServiceDeviceInfo cluster_devices_ TF_GUARDED_BY(state_mu_);

  mutex kv_mu_;
  // Ordered map to store config key-values (keys normalized by NormalizeKey()).
  std::map<std::string, std::string> kv_store_ TF_GUARDED_BY(kv_mu_);
  // Callbacks of GetKeyValueAsync() calls waiting for a key to be inserted.
  absl::flat_hash_map<std::string, std::vector<StatusOrValueCallback>> get_cb_
      TF_GUARDED_BY(kv_mu_);

  // Stop() sets `shutting_down_` and notifies the cv to end the staleness
  // thread.
  mutex check_staleness_thread_shutdown_mu_;
  condition_variable check_staleness_thread_cv_;
  bool shutting_down_ TF_GUARDED_BY(check_staleness_thread_shutdown_mu_) =
      false;
  std::unique_ptr<Thread> check_staleness_thread_;

  absl::flat_hash_map<std::string, BarrierState> barriers_
      TF_GUARDED_BY(state_mu_);
  // For now, we assume there won't be many simultaneous barriers so we simply
  // use a set.
  absl::flat_hash_set<std::string> ongoing_barriers_ TF_GUARDED_BY(state_mu_);

  // Jobs whose task errors are not propagated to other tasks; populated in
  // the constructor from the coordination config.
  absl::flat_hash_set<std::string> recoverable_jobs_;

  TF_DISALLOW_COPY_AND_ASSIGN(CoordinationServiceStandaloneImpl);
};
275
SetConnected(uint64_t task_incarnation)276 void CoordinationServiceStandaloneImpl::TaskState::SetConnected(
277 uint64_t task_incarnation) {
278 state_ = CoordinatedTaskState::TASKSTATE_CONNECTED;
279 status_ = OkStatus();
280 task_incarnation_ = task_incarnation;
281 mutex_lock l(last_heartbeat_mu_);
282 last_heartbeat_us_ = Env::Default()->NowMicros();
283 }
284
Disconnect(uint64_t grace_period_duration_us)285 void CoordinationServiceStandaloneImpl::TaskState::Disconnect(
286 uint64_t grace_period_duration_us) {
287 disconnect_grace_period_us_ =
288 Env::Default()->NowMicros() + grace_period_duration_us;
289 state_ = CoordinatedTaskState::TASKSTATE_DISCONNECTED;
290 status_ = OkStatus();
291 }
292
SetError(const Status status)293 void CoordinationServiceStandaloneImpl::TaskState::SetError(
294 const Status status) {
295 if (state_ == CoordinatedTaskState::TASKSTATE_ERROR) return;
296 state_ = CoordinatedTaskState::TASKSTATE_ERROR;
297 status_ = status;
298 }
299
RecordHeartbeat(uint64_t task_incarnation)300 Status CoordinationServiceStandaloneImpl::TaskState::RecordHeartbeat(
301 uint64_t task_incarnation) {
302 if (!status_.ok()) return status_;
303 if (task_incarnation != task_incarnation_) {
304 return MakeCoordinationError(errors::Aborted(
305 "Incarnation ID mismatch: expecting ", task_incarnation_, " but got ",
306 task_incarnation, ". This means the remote task has restarted."));
307 }
308 mutex_lock l(last_heartbeat_mu_);
309 last_heartbeat_us_ = Env::Default()->NowMicros();
310 return OkStatus();
311 }
312
313 int64_t
TimeSinceLastHeartbeatMs()314 CoordinationServiceStandaloneImpl::TaskState::TimeSinceLastHeartbeatMs() {
315 mutex_lock l(last_heartbeat_mu_);
316 return (Env::Default()->NowMicros() - last_heartbeat_us_) / 1000;
317 }
318
319 uint64_t CoordinationServiceStandaloneImpl::TaskState::
GetDisconnectedGracePeriodMicros()320 GetDisconnectedGracePeriodMicros() {
321 return disconnect_grace_period_us_;
322 }
323
324 absl::flat_hash_set<std::string>
GetOngoingBarriers()325 CoordinationServiceStandaloneImpl::TaskState::GetOngoingBarriers() {
326 return ongoing_barriers_for_task_;
327 }
328
JoinBarrier(absl::string_view barrier_id)329 void CoordinationServiceStandaloneImpl::TaskState::JoinBarrier(
330 absl::string_view barrier_id) {
331 ongoing_barriers_for_task_.emplace(barrier_id);
332 }
333
ExitBarrier(absl::string_view barrier_id)334 void CoordinationServiceStandaloneImpl::TaskState::ExitBarrier(
335 absl::string_view barrier_id) {
336 ongoing_barriers_for_task_.erase(barrier_id);
337 }
CoordinationServiceStandaloneImpl(std::unique_ptr<CoordinationClientCache> client_cache,Env * env,const ServerDef & server_def)338 CoordinationServiceStandaloneImpl::CoordinationServiceStandaloneImpl(
339 std::unique_ptr<CoordinationClientCache> client_cache, Env* env,
340 const ServerDef& server_def)
341 : client_cache_(std::move(client_cache)),
342 env_(*env),
343 heartbeat_timeout_ms_([&server_def]() -> uint64_t {
344 const auto& configs = server_def.default_session_config()
345 .experimental()
346 .coordination_config();
347 return configs.heartbeat_timeout_in_ms() > 0
348 ? configs.heartbeat_timeout_in_ms()
349 : kDefaultHeartbeatTimeoutMs;
350 }()),
351 shutdown_barrier_timeout_(
352 absl::Milliseconds(server_def.default_session_config()
353 .experimental()
354 .coordination_config()
355 .shutdown_barrier_timeout_in_ms())) {
356 const auto& configs =
357 server_def.default_session_config().experimental().coordination_config();
358 const std::unordered_set<std::string> coordinated_jobs(
359 configs.coordinated_jobs().cbegin(), configs.coordinated_jobs().cend());
360 recoverable_jobs_ = absl::flat_hash_set<std::string>(
361 configs.recoverable_jobs().cbegin(), configs.recoverable_jobs().cend());
362 const auto& cluster_def = server_def.cluster();
363 for (const auto& job : cluster_def.job()) {
364 // If `coordinated_jobs` is specified, skip jobs that are not included there
365 if (!coordinated_jobs.empty() &&
366 coordinated_jobs.find(job.name()) == coordinated_jobs.end()) {
367 continue;
368 }
369 for (const auto& task : job.tasks()) {
370 const std::string& task_name = GetTaskName(job.name(), task.first);
371 cluster_state_.emplace(task_name, std::make_unique<TaskState>());
372 }
373 }
374 StartCheckStaleness();
375 }
376
377 // Checks both heartbeat and barrier timeouts in the same thread, since threads
378 // are a constrained resource.
StartCheckStaleness()379 void CoordinationServiceStandaloneImpl::StartCheckStaleness() {
380 check_staleness_thread_.reset(
381 env_.StartThread({}, kHealthCheckThread, [this]() {
382 const bool has_service_to_client_connection = client_cache_ != nullptr;
383 // Used to store stale tasks and barriers.
384 std::vector<absl::string_view> stale_task_names;
385 absl::flat_hash_map<std::string, BarrierState*> expired_barriers;
386 while (true) {
387 {
388 mutex_lock l(check_staleness_thread_shutdown_mu_);
389 check_staleness_thread_cv_.wait_for(l, std::chrono::seconds(1));
390 if (shutting_down_) {
391 return;
392 }
393 }
394 // Heartbeat check.
395 Status status = OkStatus();
396 {
397 mutex_lock l(state_mu_);
398 for (const auto& [task_name, task_state] : cluster_state_) {
399 // Skip tasks that are not registered or in error state
400 if (task_state->GetState() !=
401 CoordinatedTaskState::TASKSTATE_CONNECTED) {
402 continue;
403 }
404 const bool is_stale = task_state->TimeSinceLastHeartbeatMs() >
405 heartbeat_timeout_ms_;
406 VLOG(1) << "Checking staleness for " << task_name
407 << " stale?=" << is_stale;
408 if (is_stale) {
409 stale_task_names.push_back(task_name);
410 status = MakeCoordinationError(errors::Unavailable(
411 "Task ", task_name,
412 " heartbeat timeout. This indicates that the remote task "
413 "has failed, got preempted, or crashed unexpectedly."));
414 SetTaskError(task_name, status);
415 }
416 }
417 }
418 // Propagate heartbeat timeout errors to other connected tasks.
419 if (!stale_task_names.empty()) {
420 if (!has_service_to_client_connection) {
421 // Error cannot be propagated since there is no service-to-client
422 // connection, so shut down service instead. Note: we cannot
423 // destroy the thread within its own function. However, this
424 // thread will be destroyed once the function returns.
425 LOG(ERROR) << "Stopping coordination service as heartbeat has "
426 "timed out for "
427 << stale_task_names[0]
428 << " and there is no service-to-client connection";
429 Stop(/*shut_staleness_thread=*/false);
430 return;
431 }
432 for (const auto& stale_task_name : stale_task_names) {
433 PropagateError(GetTaskFromName(stale_task_name));
434 }
435 stale_task_names.clear();
436 }
437
438 // Barrier timeout check.
439 uint64_t current_time_micros = Env::Default()->NowMicros();
440 {
441 mutex_lock l(state_mu_);
442 // Gather barriers which have timed out.
443 for (const std::string& barrier_id : ongoing_barriers_) {
444 auto* barrier = &barriers_[barrier_id];
445 if (current_time_micros > barrier->deadline_in_micros) {
446 expired_barriers[barrier_id] = barrier;
447 }
448 }
449 // Pass these barriers with the time out error.
450 for (const auto& [barrier_id, barrier] : expired_barriers) {
451 const Status error =
452 MakeCoordinationError(errors::DeadlineExceeded(absl::StrCat(
453 "Barrier timed out. Barrier_id: ", barrier_id)));
454 PassBarrier(barrier_id, error, barrier);
455 }
456 }
457 if (!has_service_to_client_connection &&
458 expired_barriers.contains(shutdown_barrier_id_)) {
459 // Error cannot be propagated since there is no service-to-client
460 // connection, so shut down service instead. Note: we cannot
461 // destroy the thread within its own function. However, this
462 // thread will be destroyed once the function returns.
463 LOG(ERROR)
464 << "Stopping coordination service as shutdown barrier "
465 "timed out and there is no service-to-client connection.";
466 Stop(/*shut_staleness_thread=*/false);
467 }
468 // Reset this for the next barrier check.
469 expired_barriers.clear();
470 }
471 }));
472 }
473
Stop(bool shut_staleness_thread)474 void CoordinationServiceStandaloneImpl::Stop(bool shut_staleness_thread) {
475 {
476 mutex_lock l(kv_mu_);
477 for (const auto& [key, get_kv_callbacks] : get_cb_) {
478 for (const auto& get_kv_callback : get_kv_callbacks) {
479 get_kv_callback(errors::Cancelled(
480 absl::StrCat("Coordination service is shutting down. Cancelling "
481 "GetKeyValue() for key: ",
482 key)));
483 }
484 }
485 get_cb_.clear();
486 }
487 {
488 mutex_lock l(state_mu_);
489 cluster_state_.clear();
490 for (auto& [barrier_id, barrier] : barriers_) {
491 if (!barrier.passed) {
492 Status error = MakeCoordinationError(errors::Aborted(absl::StrCat(
493 "Barrier failed because service is shutting down. Barrier_id: ",
494 barrier_id)));
495 PassBarrier(barrier_id, error, &barrier);
496 }
497 }
498 barriers_.clear();
499 }
500 {
501 mutex_lock l(check_staleness_thread_shutdown_mu_);
502 shutting_down_ = true;
503 check_staleness_thread_cv_.notify_all();
504 }
505 if (shut_staleness_thread) {
506 check_staleness_thread_.reset();
507 }
508 }
509
RegisterTask(const CoordinatedTask & task,uint64_t incarnation)510 Status CoordinationServiceStandaloneImpl::RegisterTask(
511 const CoordinatedTask& task, uint64_t incarnation) {
512 const std::string& task_name = GetTaskName(task);
513
514 Status status;
515 {
516 mutex_lock l(state_mu_);
517 if (!cluster_state_.contains(task_name)) {
518 // Note: return early here as unexpected task register errors should not
519 // be propagated to other tasks.
520 return MakeCoordinationError(errors::InvalidArgument(
521 "Unexpected task registered with task_name=", task_name));
522 }
523 if (cluster_state_[task_name]->GetState() ==
524 CoordinatedTaskState::TASKSTATE_DISCONNECTED) {
525 // This task is currently disconnected (registering for the first time or
526 // has called ResetTask() previously).
527 cluster_state_[task_name]->SetConnected(incarnation);
528 LOG(INFO) << task_name
529 << " has connected to coordination service. Incarnation: "
530 << incarnation;
531 } else {
532 // This task is connected or already in error, which implies it has
533 // registered previously.
534 status = MakeCoordinationError(
535 errors::Aborted("Duplicate task registration with task_name=",
536 task_name),
537 task);
538 SetTaskError(task_name, status);
539 }
540 }
541 if (!status.ok()) {
542 PropagateError(task);
543 }
544 return status;
545 }
546
WaitForAllTasks(const CoordinatedTask & task,const CoordinationServiceDeviceInfo & devices,StatusCallback done)547 void CoordinationServiceStandaloneImpl::WaitForAllTasks(
548 const CoordinatedTask& task, const CoordinationServiceDeviceInfo& devices,
549 StatusCallback done) {
550 {
551 mutex_lock l(state_mu_);
552 const auto& task_state = cluster_state_.find(GetTaskName(task));
553 // Add task device info to global device state for the first time that task
554 // has called WaitForAllTasks().
555 if (task_state != cluster_state_.end() &&
556 !task_state->second->GetDeviceInfoCollected()) {
557 cluster_devices_.MergeFrom(devices);
558 task_state->second->MarkDeviceInfoCollected();
559 }
560 }
561 BarrierAsync(device_propagation_barrier_id_, kDevicePropagationTimeout, task,
562 {}, std::move(done));
563 }
564
ShutdownTaskAsync(const CoordinatedTask & task,StatusCallback done)565 void CoordinationServiceStandaloneImpl::ShutdownTaskAsync(
566 const CoordinatedTask& task, StatusCallback done) {
567 if (shutdown_barrier_timeout_ > absl::ZeroDuration()) {
568 // Impose shutdown barrier so that all tasks can disconnect together.
569 BarrierAsync(shutdown_barrier_id_, shutdown_barrier_timeout_, task, {},
570 done);
571 } else {
572 Status status;
573 {
574 mutex_lock l(state_mu_);
575 // Disconnect task from service individually.
576 status = DisconnectTask(task);
577 }
578 done(status);
579 }
580 }
581
ResetTask(const CoordinatedTask & task)582 Status CoordinationServiceStandaloneImpl::ResetTask(
583 const CoordinatedTask& task) {
584 mutex_lock l(state_mu_);
585 return DisconnectTask(task);
586 }
587
DisconnectTask(const CoordinatedTask & task)588 Status CoordinationServiceStandaloneImpl::DisconnectTask(
589 const CoordinatedTask& task) {
590 const std::string task_name = GetTaskName(task);
591 // Check if task is valid and not already disconnected.
592 if (!cluster_state_.contains(task_name)) {
593 return MakeCoordinationError(errors::InvalidArgument(
594 "Unexpected disconnect request with task_name=", task_name));
595 } else if (cluster_state_[task_name]->GetState() ==
596 CoordinatedTaskState::TASKSTATE_DISCONNECTED) {
597 return MakeCoordinationError(errors::FailedPrecondition(
598 "The task is already disconnected: ", task_name));
599 }
600
601 // Disconnect task and fail any ongoing barriers.
602 cluster_state_[task_name]->Disconnect(
603 /*grace_period_duration_us=*/heartbeat_timeout_ms_ * 1000);
604 for (const auto& barrier_id :
605 cluster_state_[task_name]->GetOngoingBarriers()) {
606 Status error = MakeCoordinationError(errors::Internal(absl::StrCat(
607 "Barrier failed from a disconnected task. Barrier Id: ", barrier_id,
608 ", Task: ", task_name)));
609 PassBarrier(barrier_id, error, &barriers_[barrier_id]);
610 }
611
612 LOG(INFO) << task_name << " has disconnected from coordination service.";
613 return OkStatus();
614 }
615
616 const CoordinationServiceDeviceInfo&
ListClusterDevices()617 CoordinationServiceStandaloneImpl::ListClusterDevices() {
618 return cluster_devices_;
619 }
620
GetServiceIncarnation()621 uint64_t CoordinationServiceStandaloneImpl::GetServiceIncarnation() {
622 return service_incarnation_;
623 }
624
ReportTaskError(const CoordinatedTask & task,Status error)625 Status CoordinationServiceStandaloneImpl::ReportTaskError(
626 const CoordinatedTask& task, Status error) {
627 const std::string& task_name = GetTaskName(task);
628 {
629 mutex_lock l(state_mu_);
630 if (!cluster_state_.contains(task_name)) {
631 return MakeCoordinationError(
632 errors::InvalidArgument("Unexpected request from task ", task_name));
633 } else if (cluster_state_[task_name]->GetState() !=
634 CoordinatedTaskState::TASKSTATE_CONNECTED) {
635 return MakeCoordinationError(errors::FailedPrecondition(
636 "The task is not connected or already has an error."));
637 } else {
638 SetTaskError(task_name, error);
639 }
640 }
641 PropagateError(task, /*is_reported_by_task=*/true);
642 return OkStatus();
643 }
644
RecordHeartbeat(const CoordinatedTask & task,uint64_t incarnation)645 Status CoordinationServiceStandaloneImpl::RecordHeartbeat(
646 const CoordinatedTask& task, uint64_t incarnation) {
647 const std::string& task_name = GetTaskName(task);
648 Status s = OkStatus();
649 {
650 mutex_lock l(state_mu_);
651 if (!cluster_state_.contains(task_name)) {
652 return MakeCoordinationError(errors::InvalidArgument(
653 "Unexpected task request with task_name=", task_name));
654 }
655 if (!cluster_state_[task_name]->GetStatus().ok()) {
656 return cluster_state_[task_name]->GetStatus();
657 } else if (cluster_state_[task_name]->GetState() ==
658 CoordinatedTaskState::TASKSTATE_DISCONNECTED &&
659 // We accept heartbeats for a short grace period to account for
660 // the lag time between the service recording the state change
661 // and the agent stopping heartbeats.
662 Env::Default()->NowMicros() >
663 cluster_state_[task_name]
664 ->GetDisconnectedGracePeriodMicros()) {
665 return MakeCoordinationError(errors::InvalidArgument(
666 "Task with task_name=", task_name,
667 " must be registered before sending heartbeat messages"));
668 }
669 s = cluster_state_[task_name]->RecordHeartbeat(incarnation);
670 }
671
672 // Set and propagate any heartbeat errors.
673 if (!s.ok()) {
674 {
675 mutex_lock l(state_mu_);
676 SetTaskError(task_name, s);
677 }
678 PropagateError(task);
679 }
680
681 return s;
682 }
683
ReportServiceErrorToTaskAsync(const CoordinatedTask & destination_task,Status error)684 void CoordinationServiceStandaloneImpl::ReportServiceErrorToTaskAsync(
685 const CoordinatedTask& destination_task, Status error) {
686 assert(!error.ok());
687
688 // Don't report error if there is no service-to-client connection.
689 if (client_cache_ == nullptr) {
690 LOG(ERROR) << error;
691 return;
692 }
693
694 auto request = std::make_shared<ReportErrorToTaskRequest>();
695 auto response = std::make_shared<ReportErrorToTaskResponse>();
696 request->set_error_code(error.code());
697 request->set_error_message(error.error_message());
698 CoordinatedTask* error_source =
699 request->mutable_error_payload()->mutable_source_task();
700 error_source->set_job_name("coordination_service");
701 auto call_opts = std::make_shared<CallOptions>();
702 call_opts->SetTimeout(kServiceToClientTimeoutMs);
703
704 const std::string task_name = GetTaskName(destination_task);
705 CoordinationClient* client = client_cache_->GetClient(task_name);
706 client->ReportErrorToTaskAsync(
707 call_opts.get(), request.get(), response.get(),
708 [request, response, task_name, call_opts](Status s) {
709 if (!s.ok()) {
710 LOG(ERROR) << "Encountered another error while reporting to "
711 << task_name << ": " << s;
712 }
713 });
714 }
715
PropagateError(const CoordinatedTask & source_task,bool is_reported_by_task)716 void CoordinationServiceStandaloneImpl::PropagateError(
717 const CoordinatedTask& source_task, bool is_reported_by_task) {
718 // If the error task is recoverable, do not propagate the error to other
719 // connected tasks.
720 if (isRecoverableJob(source_task.job_name())) return;
721 Status error;
722 {
723 mutex_lock l(state_mu_);
724 error = cluster_state_[GetTaskName(source_task)]->GetStatus();
725 }
726 assert(!error.ok());
727 ReportErrorToTaskRequest request;
728 request.set_error_code(error.code());
729 request.set_error_message(error.error_message());
730 CoordinationServiceError* payload = request.mutable_error_payload();
731 *payload->mutable_source_task() = source_task;
732 payload->set_is_reported_error(is_reported_by_task);
733 CallOptions call_opts;
734 call_opts.SetTimeout(kServiceToClientTimeoutMs);
735 std::vector<std::shared_ptr<absl::Notification>> notifications;
736
737 std::vector<absl::string_view> task_names;
738 {
739 tf_shared_lock l(state_mu_);
740 task_names.reserve(cluster_state_.size());
741 for (const auto& pair : cluster_state_) {
742 task_names.emplace_back(pair.first);
743 }
744 }
745 for (absl::string_view task : task_names) {
746 {
747 mutex_lock l(state_mu_);
748 // Propagate error only to tasks that are connected
749 if (cluster_state_[task]->GetState() !=
750 CoordinatedTaskState::TASKSTATE_CONNECTED)
751 continue;
752 }
753
754 // Don't propagate error if there is no service-to-client connection.
755 if (client_cache_ == nullptr) {
756 LOG(ERROR)
757 << "Stopping coordination service as there is no "
758 "service-to-client connection, but we encountered an error: "
759 << error;
760 Stop(/*shut_staleness_thread=*/false);
761 return;
762 }
763 CoordinationClient* client = client_cache_->GetClient(std::string(task));
764 auto response = std::make_shared<ReportErrorToTaskResponse>();
765 auto n = std::make_shared<absl::Notification>();
766 client->ReportErrorToTaskAsync(
767 &call_opts, &request, response.get(), [response, n, task](Status s) {
768 if (!s.ok()) {
769 LOG(ERROR) << "Encountered another error while reporting to "
770 << task << ": " << s;
771 }
772 n->Notify();
773 });
774 notifications.push_back(n);
775 }
776 for (auto& n : notifications) {
777 n->WaitForNotification();
778 }
779 }
780
// Utility for normalizing structured config key string.
// The normalized key will not have leading or trailing slashes, and all parts
// in the key path are separated by exactly one slash ('/').
// E.g., ///a//b/c// --> a/b/c
//
// The whole input is processed by length, so keys containing embedded NUL
// characters are preserved instead of being truncated at the first NUL.
std::string NormalizeKey(const StringPiece orig_key) {
  std::string norm_key;
  norm_key.reserve(orig_key.size());
  for (const char c : orig_key) {
    if (c == '/') {
      // Emit a separator only right after a non-slash character: this drops
      // leading slashes and collapses runs of slashes into one.
      if (!norm_key.empty() && norm_key.back() != '/') {
        norm_key.push_back('/');
      }
    } else {
      norm_key.push_back(c);
    }
  }
  // If ending with slash, remove the trailing slash.
  if (!norm_key.empty() && norm_key.back() == '/') {
    norm_key.pop_back();
  }
  return norm_key;
}
808
InsertKeyValue(const std::string & key,const std::string & value)809 Status CoordinationServiceStandaloneImpl::InsertKeyValue(
810 const std::string& key, const std::string& value) {
811 const std::string& norm_key = NormalizeKey(key);
812 mutex_lock l(kv_mu_);
813 if (kv_store_.find(norm_key) != kv_store_.end()) {
814 return MakeCoordinationError(
815 errors::AlreadyExists("Config key ", key, " already exists."));
816 }
817 kv_store_.emplace(norm_key, value);
818 auto iter = get_cb_.find(norm_key);
819 if (iter != get_cb_.end()) {
820 for (const auto& cb : iter->second) {
821 cb(value);
822 }
823 get_cb_.erase(iter);
824 }
825 return OkStatus();
826 }
827
GetKeyValueAsync(const std::string & key,StatusOrValueCallback done)828 void CoordinationServiceStandaloneImpl::GetKeyValueAsync(
829 const std::string& key, StatusOrValueCallback done) {
830 const std::string& norm_key = NormalizeKey(key);
831 mutex_lock l(kv_mu_);
832 const auto& iter = kv_store_.find(norm_key);
833 if (iter != kv_store_.end()) {
834 done(iter->second);
835 return;
836 }
837 auto cb_iter = get_cb_.find(norm_key);
838 if (cb_iter == get_cb_.end()) {
839 cb_iter =
840 get_cb_.emplace(norm_key, std::vector<StatusOrValueCallback>()).first;
841 }
842 cb_iter->second.emplace_back(std::move(done));
843 }
844
TryGetKeyValue(const std::string & key)845 StatusOr<std::string> CoordinationServiceStandaloneImpl::TryGetKeyValue(
846 const std::string& key) {
847 const std::string& norm_key = NormalizeKey(key);
848 mutex_lock l(kv_mu_);
849 const auto& iter = kv_store_.find(norm_key);
850 if (iter == kv_store_.end()) {
851 return errors::NotFound("Config key ", key, " not found.");
852 }
853 return iter->second;
854 }
855
GetKeyValueDir(absl::string_view directory_key)856 std::vector<KeyValueEntry> CoordinationServiceStandaloneImpl::GetKeyValueDir(
857 absl::string_view directory_key) {
858 std::vector<KeyValueEntry> kvs_in_directory;
859 const std::string norm_key = NormalizeKey(directory_key);
860 const std::string dir = absl::StrCat(norm_key, "/");
861
862 mutex_lock l(kv_mu_);
863 // Find first key in ordered map that has the directory prefix.
864 auto begin = kv_store_.lower_bound(dir);
865 std::map<std::string, std::string>::iterator it;
866 // Iterate through key range that match directory prefix.
867 for (it = begin; it != kv_store_.end(); ++it) {
868 // Stop once the next key does not have the directory prefix. Since keys are
869 // ordered, none of the other keys would have a matching prefix.
870 if (std::mismatch(dir.begin(), dir.end(), it->first.begin()).first !=
871 dir.end()) {
872 break;
873 }
874 KeyValueEntry kv;
875 kv.set_key(it->first);
876 kv.set_value(it->second);
877 kvs_in_directory.push_back(kv);
878 }
879
880 return kvs_in_directory;
881 }
882
DeleteKeyValue(const std::string & key)883 Status CoordinationServiceStandaloneImpl::DeleteKeyValue(
884 const std::string& key) {
885 const std::string& norm_key = NormalizeKey(key);
886 mutex_lock l(kv_mu_);
887 // Delete directory: find key range that match directory prefix
888 const std::string& dir = strings::StrCat(norm_key, "/");
889 auto begin = kv_store_.lower_bound(dir);
890 std::map<std::string, std::string>::iterator end;
891 for (end = begin; end != kv_store_.end(); end++) {
892 if (std::mismatch(dir.begin(), dir.end(), end->first.begin()).first !=
893 dir.end())
894 break;
895 }
896 kv_store_.erase(begin, end);
897 auto iter = kv_store_.find(norm_key);
898 if (iter != kv_store_.end()) {
899 kv_store_.erase(iter);
900 }
901 return OkStatus();
902 }
903
SetTaskError(absl::string_view task_name,Status error)904 void CoordinationServiceStandaloneImpl::SetTaskError(
905 absl::string_view task_name, Status error) {
906 cluster_state_[task_name]->SetError(error);
907 for (const auto& barrier_id :
908 cluster_state_[task_name]->GetOngoingBarriers()) {
909 Status error = MakeCoordinationError(errors::Internal(absl::StrCat(
910 "Barrier failed from a task error. Barrier Id: ", barrier_id,
911 ", Task: ", task_name)));
912 PassBarrier(barrier_id, error, &barriers_[barrier_id]);
913 }
914
915 LOG(ERROR) << task_name
916 << " has been set to ERROR in coordination service: " << error;
917 }
918
BarrierAsync(const std::string & barrier_id,absl::Duration timeout,const CoordinatedTask & task,const std::vector<CoordinatedTask> & participating_tasks,StatusCallback done)919 void CoordinationServiceStandaloneImpl::BarrierAsync(
920 const std::string& barrier_id, absl::Duration timeout,
921 const CoordinatedTask& task,
922 const std::vector<CoordinatedTask>& participating_tasks,
923 StatusCallback done) {
924 mutex_lock l(state_mu_);
925 auto pair = barriers_.try_emplace(barrier_id);
926 auto it = pair.first;
927 bool inserted = pair.second;
928 auto* barrier = &it->second;
929 // Create barrier for the first time.
930 if (inserted) {
931 // Initialize barrier state.
932 barrier->passed = false;
933 // Assume barrier is for entire cluster if no tasks are specified.
934 if (participating_tasks.empty()) {
935 for (const auto& task_state : cluster_state_) {
936 absl::string_view task_name = task_state.first;
937 barrier->tasks_at_barrier[GetTaskFromName(task_name)] = false;
938 }
939 } else {
940 for (const auto& task : participating_tasks) {
941 // Fail the barrier immediately if unexpected task is included in the
942 // barrier.
943 const std::string task_name = GetTaskName(task);
944 if (!cluster_state_.contains(task_name)) {
945 Status error = MakeCoordinationError(errors::InvalidArgument(
946 absl::StrCat("Unexpected task (", task_name,
947 ") that is not in the cluster called the barrier. "
948 "Barrier Id: ",
949 barrier_id)));
950 PassBarrier(barrier_id, error, barrier);
951 done(error);
952 return;
953 }
954 barrier->tasks_at_barrier[task] = false;
955 }
956 }
957 barrier->num_pending_tasks = barrier->tasks_at_barrier.size();
958
959 // Fail the barrier immediately if any tasks are already in error.
960 for (const auto& pending_task : barrier->tasks_at_barrier) {
961 const std::string task_name = GetTaskName(pending_task.first);
962 if (cluster_state_[task_name]->GetState() ==
963 CoordinatedTaskState::TASKSTATE_ERROR) {
964 Status error = MakeCoordinationError(errors::Internal(
965 absl::StrCat("Task (", task_name,
966 ") is already in error before the barrier "
967 "was called. Barrier Id: ",
968 barrier_id)));
969 PassBarrier(barrier_id, error, barrier);
970 done(error);
971 return;
972 }
973 }
974 barrier->deadline_in_micros =
975 Env::Default()->NowMicros() + (timeout / absl::Microseconds(1));
976
977 // Add ongoing barrier to cluster state.
978 ongoing_barriers_.emplace(barrier_id);
979 const size_t num_ongoing_barriers = ongoing_barriers_.size();
980 if (num_ongoing_barriers > kOngoingBarriersSoftLimit) {
981 LOG(WARNING) << "There is a high number of ongoing barriers in "
982 "coordination service: "
983 << num_ongoing_barriers;
984 }
985 for (const auto& pending_task : barrier->tasks_at_barrier) {
986 const CoordinatedTask& task = pending_task.first;
987 cluster_state_[GetTaskName(task)]->JoinBarrier(barrier_id);
988 }
989 }
990
991 // Barrier has already been passed, return previous result immediately.
992 if (barrier->passed) {
993 // Special hook for shutdown barrier to disconnect task.
994 if (barrier_id == shutdown_barrier_id_) {
995 Status s = DisconnectTask(task);
996 // Return any errors from the disconnect attempt, otherwise return the
997 // barrier status outside of this hook.
998 if (!s.ok()) {
999 done(s);
1000 return;
1001 }
1002 }
1003
1004 done(barrier->result);
1005 return;
1006 }
1007
1008 // Add pending callbacks.
1009 barrier->done_callbacks.push_back(done);
1010
1011 // Check if caller task is participating in the barrier.
1012 if (!barrier->tasks_at_barrier.contains(task)) {
1013 // Unexpected barrier call from a task not participating in the barrier.
1014 Status error = MakeCoordinationError(errors::InvalidArgument(
1015 absl::StrCat("A non-participating task (", GetTaskName(task),
1016 ") called the barrier: ", barrier_id)));
1017 PassBarrier(barrier_id, error, barrier);
1018 return;
1019 }
1020
1021 // Check if task args are specified consistently across barrier calls.
1022 if (!ValidateTaskArgs(participating_tasks, barrier->tasks_at_barrier,
1023 cluster_state_.size())) {
1024 Status error = MakeCoordinationError(errors::InvalidArgument(absl::StrCat(
1025 "Conflicting tasks specified for the same barrier: ", barrier_id)));
1026 PassBarrier(barrier_id, error, barrier);
1027 return;
1028 }
1029
1030 // Remove pending task.
1031 // We need to check if task made a repeated call after reaching the barrier.
1032 if (!barrier->tasks_at_barrier[task]) {
1033 barrier->tasks_at_barrier[task] = true;
1034 --barrier->num_pending_tasks;
1035
1036 if (barrier->num_pending_tasks == 0) {
1037 PassBarrier(barrier_id, OkStatus(), barrier);
1038 return;
1039 }
1040 }
1041 }
1042
CancelBarrier(const std::string & barrier_id,const CoordinatedTask & task)1043 Status CoordinationServiceStandaloneImpl::CancelBarrier(
1044 const std::string& barrier_id, const CoordinatedTask& task) {
1045 mutex_lock l(state_mu_);
1046 auto [it, inserted] = barriers_.try_emplace(barrier_id);
1047 auto* barrier = &it->second;
1048 if (inserted) {
1049 LOG(WARNING) << "Barrier (" << barrier_id
1050 << ") is cancelled before being created by task: "
1051 << GetTaskName(task);
1052 }
1053 // Barrier has already been passed.
1054 if (barrier->passed) {
1055 return MakeCoordinationError(errors::FailedPrecondition(absl::StrCat(
1056 "Barrier (", barrier_id, ") has already been passed with status code: ",
1057 barrier->result.code())));
1058 }
1059
1060 // Cancel barrier.
1061 Status cancelled = MakeCoordinationError(errors::Cancelled(absl::StrCat(
1062 "Barrier (", barrier_id, ") is cancelled by task: ", GetTaskName(task))));
1063 PassBarrier(barrier_id, cancelled, barrier);
1064
1065 return OkStatus();
1066 }
1067
// Mark barrier as passed.
//
// Records `result` on `barrier` and unregisters the barrier from every
// participating task and from ongoing_barriers_. It also runs the
// device-propagation and shutdown hooks when `barrier_id` names one of those
// barriers. Finally it invokes, then clears, every pending done callback with
// `result`.
//
// The visible callers (BarrierAsync, CancelBarrier) hold state_mu_ while
// calling this. Statement order matters: `barrier_id` must not be used once
// the callbacks start running (see the note below).
void CoordinationServiceStandaloneImpl::PassBarrier(
    absl::string_view barrier_id, Status result, BarrierState* barrier) {
  barrier->passed = true;
  barrier->result = result;
  // Special hook for device propagation barrier to set global device ids.
  // NOTE(review): this runs whether or not `result` is OK -- confirm that
  // assigning global ids on a failed barrier is intended.
  if (barrier_id == device_propagation_barrier_id_) {
    SetXlaGlobalDeviceIds();
  }
  for (const auto& task_at_barrier : barrier->tasks_at_barrier) {
    // Clean up task state (used as error hooks).
    // This undoes the JoinBarrier() registration done in BarrierAsync().
    const CoordinatedTask& task = task_at_barrier.first;
    cluster_state_[GetTaskName(task)]->ExitBarrier(barrier_id);
  }

  // Special hook for shutdown barrier to disconnect tasks at the barrier.
  if (barrier_id == shutdown_barrier_id_) {
    if (result.ok()) {
      LOG(INFO) << "Shutdown barrier in coordination service has passed.";
    } else {
      LOG(ERROR) << "Shutdown barrier in coordination service has failed: "
                 << result
                 << ". This suggests that at least one worker did not complete "
                    "its job, or was too slow/hanging in its execution.";
    }
    // Built once and sent to every straggling task in the loop below.
    Status shutdown_error = MakeCoordinationError(errors::Internal(
        absl::StrCat("Shutdown barrier has been passed with status: '",
                     barrier->result.ToString(),
                     "', but this task is not at the barrier yet.")));
    for (const auto& [task, at_barrier] : barrier->tasks_at_barrier) {
      if (at_barrier) {
        // Disconnect tasks that reached the barrier.
        // A failed disconnect is only logged; it does not change `result`.
        Status disconnect_status = DisconnectTask(task);
        if (!disconnect_status.ok()) {
          LOG(ERROR) << disconnect_status;
        }
      } else {
        // Propagate errors to straggling tasks that have not reached the
        // barrier. The barrier must have failed if any task did not reach the
        // barrier.
        ReportServiceErrorToTaskAsync(task, shutdown_error);
      }
    }
  }
  barrier->tasks_at_barrier.clear();
  ongoing_barriers_.erase(barrier_id);
  // Note: barrier_id shouldn't be referenced after this line as its lifetime
  // may be tied to one of the callbacks.
  // Propagate results to participating tasks.
  for (const auto& callback : barrier->done_callbacks) {
    callback(result);
  }
  barrier->done_callbacks.clear();
}
1122
ValidateTaskArgs(const std::vector<CoordinatedTask> & tasks_args,const absl::flat_hash_map<CoordinatedTask,bool,CoordinatedTaskHash,CoordinatedTaskEqual> & tasks_at_barrier,int64_t cluster_size)1123 bool CoordinationServiceStandaloneImpl::ValidateTaskArgs(
1124
1125 const std::vector<CoordinatedTask>& tasks_args,
1126 const absl::flat_hash_map<CoordinatedTask, bool, CoordinatedTaskHash,
1127 CoordinatedTaskEqual>& tasks_at_barrier,
1128 int64_t cluster_size) {
1129 if (tasks_args.empty()) {
1130 return tasks_at_barrier.size() == cluster_size;
1131 } else if (tasks_at_barrier.size() != tasks_args.size()) {
1132 return false;
1133 } else {
1134 for (const auto& task : tasks_args) {
1135 if (!tasks_at_barrier.contains(task)) {
1136 return false;
1137 }
1138 }
1139 }
1140 return true;
1141 }
1142
SetXlaGlobalDeviceIds()1143 void CoordinationServiceStandaloneImpl::SetXlaGlobalDeviceIds() {
1144 // No-op if TF devices are specified.
1145 if (cluster_devices_.has_xla()) {
1146 int global_id = 0;
1147 for (xla::LocalTopologyProto& local_topology :
1148 *cluster_devices_.mutable_xla()->mutable_devices()->mutable_nodes()) {
1149 for (xla::DeviceProto& device : *local_topology.mutable_devices()) {
1150 device.set_global_device_id(global_id);
1151 ++global_id;
1152 }
1153 }
1154 }
1155 }
1156 } // namespace
1157
EnableCoordinationService(Env * env,const ServerDef & server_def,std::unique_ptr<CoordinationClientCache> cache)1158 std::unique_ptr<CoordinationServiceInterface> EnableCoordinationService(
1159 Env* env, const ServerDef& server_def,
1160 std::unique_ptr<CoordinationClientCache> cache) {
1161 std::unique_ptr<CoordinationServiceInterface> coord_service;
1162 if (is_multi_client_leader(server_def)) {
1163 coord_service = std::make_unique<CoordinationServiceStandaloneImpl>(
1164 std::move(cache), env, server_def);
1165 }
1166 return coord_service;
1167 }
1168
isRecoverableJob(const absl::string_view task_name) const1169 bool CoordinationServiceStandaloneImpl::isRecoverableJob(
1170 const absl::string_view task_name) const {
1171 return recoverable_jobs_.find(task_name) != recoverable_jobs_.end();
1172 }
1173
// Register standalone coordination service implementation.
// Registers EnableCoordinationService() as the factory for the "standalone"
// coordination service. How the name is looked up is defined by the
// REGISTER_COORDINATION_SERVICE macro, which is not shown here.
REGISTER_COORDINATION_SERVICE("standalone", EnableCoordinationService);
1176
1177 } // namespace tensorflow
1178