1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/coordination/coordination_service.h"
17 
18 #include <algorithm>
19 #include <iterator>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/string_view.h"
28 #include "absl/synchronization/notification.h"
29 #include "absl/time/time.h"
30 #include "tensorflow/compiler/xla/pjrt/distributed/protocol.pb.h"
31 #include "tensorflow/core/distributed_runtime/call_options.h"
32 #include "tensorflow/core/distributed_runtime/coordination/coordination_client.h"
33 #include "tensorflow/core/distributed_runtime/coordination/coordination_service_error_util.h"
34 #include "tensorflow/core/platform/env.h"
35 #include "tensorflow/core/platform/errors.h"
36 #include "tensorflow/core/platform/macros.h"
37 #include "tensorflow/core/platform/mutex.h"
38 #include "tensorflow/core/platform/random.h"
39 #include "tensorflow/core/platform/status.h"
40 #include "tensorflow/core/platform/strcat.h"
41 #include "tensorflow/core/platform/thread_annotations.h"
42 #include "tensorflow/core/protobuf/cluster.pb.h"
43 #include "tensorflow/core/protobuf/config.pb.h"
44 #include "tensorflow/core/protobuf/coordination_config.pb.h"
45 #include "tensorflow/core/protobuf/coordination_service.pb.h"
46 #include "tensorflow/core/protobuf/tensorflow_server.pb.h"
47 #include "tensorflow/core/util/device_name_utils.h"
48 
49 namespace tensorflow {
50 namespace {
51 
52 constexpr absl::Duration kDevicePropagationTimeout = absl::Hours(1);
53 constexpr int kDefaultHeartbeatTimeoutMs = 10 * 1000;  // 10 seconds
54 constexpr int kServiceToClientTimeoutMs = 10 * 1000;   // 10 seconds
55 constexpr size_t kOngoingBarriersSoftLimit = 20;
56 constexpr char kHealthCheckThread[] = "CoordinationServiceHealthCheck";
57 
GetTaskName(absl::string_view job_name,int task_id)58 std::string GetTaskName(absl::string_view job_name, int task_id) {
59   return strings::StrCat("/job:", job_name, "/replica:", 0, "/task:", task_id);
60 }
61 
GetTaskName(const CoordinatedTask & task)62 std::string GetTaskName(const CoordinatedTask& task) {
63   return GetTaskName(task.job_name(), task.task_id());
64 }
65 
GetTaskFromName(absl::string_view task_name)66 CoordinatedTask GetTaskFromName(absl::string_view task_name) {
67   DeviceNameUtils::ParsedName parsed;
68   DeviceNameUtils::ParseFullName(task_name, &parsed);
69   CoordinatedTask task;
70   task.set_job_name(parsed.job);
71   task.set_task_id(parsed.task);
72   return task;
73 }
74 
is_multi_client_leader(const ServerDef & server_def)75 bool is_multi_client_leader(const ServerDef& server_def) {
76   const auto& config = server_def.default_session_config();
77   const std::string& leader =
78       config.experimental().coordination_config().service_leader();
79   const std::string& collective_leader =
80       config.experimental().collective_group_leader();
81   DeviceNameUtils::ParsedName leader_pn;
82   if (!leader.empty()) {
83     DeviceNameUtils::ParseFullName(leader, &leader_pn);
84   } else if (!collective_leader.empty()) {
85     LOG(INFO) << "No coordination leader is set, using the collective leader "
86               << collective_leader;
87     DeviceNameUtils::ParseFullName(collective_leader, &leader_pn);
88   } else {
89     LOG(INFO) << "No coordination leader is set, using the default /job:"
90               << server_def.job_name() << "/replica:0/task:0";
91     return server_def.task_index() == 0;
92   }
93   return server_def.job_name() == leader_pn.job &&
94          server_def.task_index() == leader_pn.task;
95 }
96 
97 // Convenience structs to allow using CoordinatedTask as container keys.
98 struct CoordinatedTaskHash {
operator ()tensorflow::__anonf25d665d0111::CoordinatedTaskHash99   uint64_t operator()(const CoordinatedTask& task) const {
100     return absl::HashOf(task.job_name(), task.task_id());
101   }
102 };
103 struct CoordinatedTaskEqual {
operator ()tensorflow::__anonf25d665d0111::CoordinatedTaskEqual104   bool operator()(const CoordinatedTask& lhs,
105                   const CoordinatedTask& rhs) const {
106     return lhs.job_name() == rhs.job_name() && lhs.task_id() == rhs.task_id();
107   }
108 };
109 
110 // Standalone implementation of the coordination service.
111 class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface {
112  public:
113   CoordinationServiceStandaloneImpl(
114       std::unique_ptr<CoordinationClientCache> client_cache, Env* env,
115       const ServerDef& server_def);
~CoordinationServiceStandaloneImpl()116   ~CoordinationServiceStandaloneImpl() override { Stop(); }
117 
118   Status RegisterTask(const CoordinatedTask& task,
119                       uint64_t incarnation) override;
120   void WaitForAllTasks(const CoordinatedTask& task,
121                        const CoordinationServiceDeviceInfo& devices,
122                        StatusCallback done) override;
123   void ShutdownTaskAsync(const CoordinatedTask& task,
124                          StatusCallback done) override;
125   Status ResetTask(const CoordinatedTask& task) override;
126   Status RecordHeartbeat(const CoordinatedTask& task,
127                          uint64_t incarnation) override;
128   Status ReportTaskError(const CoordinatedTask& task, Status error) override;
129   Status InsertKeyValue(const std::string& key,
130                         const std::string& value) override;
131   void GetKeyValueAsync(const std::string& key,
132                         StatusOrValueCallback done) override;
133   StatusOr<std::string> TryGetKeyValue(const std::string& key) override;
134   std::vector<KeyValueEntry> GetKeyValueDir(
135       absl::string_view directory_key) override;
136   Status DeleteKeyValue(const std::string& key) override;
137   void BarrierAsync(const std::string& barrier_id, absl::Duration timeout,
138                     const CoordinatedTask& task,
139                     const std::vector<CoordinatedTask>& participating_tasks,
140                     StatusCallback done) override;
141   Status CancelBarrier(const std::string& barrier_id,
142                        const CoordinatedTask& task) override;
143 
144  private:
145   const CoordinationServiceDeviceInfo& ListClusterDevices() override
146       TF_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
147   uint64_t GetServiceIncarnation() override;
148   void StartCheckStaleness();  // Checks both heartbeat and barrier timeouts.
149   void Stop(bool shut_staleness_thread = true);
150   // Report service error to a specified task.
151   void ReportServiceErrorToTaskAsync(const CoordinatedTask& destination_task,
152                                      Status error);
153   // Report error from a task to all other connected tasks if the task is not
154   // recoverable.
155   // Note: SetTaskError() must be called before propagating its error.
156   void PropagateError(const CoordinatedTask& source_task,
157                       bool is_reported_by_task = false)
158       TF_LOCKS_EXCLUDED(state_mu_);
159   void SetTaskError(absl::string_view task_name, Status error)
160       TF_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
161   void SetXlaGlobalDeviceIds() TF_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
162   Status DisconnectTask(const CoordinatedTask& task)
163       TF_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
164 
165   struct BarrierState {
166     bool passed = false;
167     Status result = errors::Unknown(
168         "Invalid barrier result.");  // Only valid if `passed` is true.
169     uint64_t deadline_in_micros = 0;
170     int num_pending_tasks = 0;
171     // Specifies which tasks have called the barrier so far.
172     absl::flat_hash_map<CoordinatedTask, bool, CoordinatedTaskHash,
173                         CoordinatedTaskEqual>
174         tasks_at_barrier;
175     std::vector<StatusCallback> done_callbacks;
176   };
177   void PassBarrier(absl::string_view barrier_id, Status result,
178                    BarrierState* barrier)
179       TF_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
180   // Check if participating tasks are specified correctly across barrier calls.
181   bool ValidateTaskArgs(
182       const std::vector<CoordinatedTask>& tasks_args,
183       const absl::flat_hash_map<CoordinatedTask, bool, CoordinatedTaskHash,
184                                 CoordinatedTaskEqual>& tasks_at_barrier,
185       int64_t cluster_size);
186   bool isRecoverableJob(const absl::string_view task_name) const;
187 
188   class TaskState {
189    public:
190     // Task state maintained on the coordination service side.
191     // State transition:
192     //                Register           Heartbeat
193     //   DISCONNECTED -------> CONNECTED --------> ERROR (timeout)
194     //                              |   ReportError
195     //                              +--------------> ERROR
196     //
197     // When task state becomes ERROR, propagate this status to other CONNECTED
198     // tasks in the cluster.
199 
GetState()200     CoordinatedTaskState GetState() { return state_; }
GetStatus()201     Status GetStatus() { return status_; }
202     void SetConnected(uint64_t task_incarnation);
203     void Disconnect(uint64_t grace_period_duration_us);
204     Status RecordHeartbeat(uint64_t task_incarnation);
205     int64_t TimeSinceLastHeartbeatMs();
206     // This denotes the deadline after which we stop accepting heartbeats from a
207     // disconnected task. This grace period accounts for the lag time between
208     // the service recording the state change and the agent stopping heartbeats.
209     uint64_t GetDisconnectedGracePeriodMicros();
210     void SetError(Status status);
GetDeviceInfoCollected()211     bool GetDeviceInfoCollected() { return device_info_collected_; }
MarkDeviceInfoCollected()212     void MarkDeviceInfoCollected() { device_info_collected_ = true; }
213     absl::flat_hash_set<std::string> GetOngoingBarriers();
214     void JoinBarrier(absl::string_view barrier_id);
215     void ExitBarrier(absl::string_view barrier_id);
216 
217    private:
218     // Incarnation ID for CPU:0 on remote task.
219     uint64_t task_incarnation_ = 0;
220 
221     CoordinatedTaskState state_ = CoordinatedTaskState::TASKSTATE_DISCONNECTED;
222     Status status_;
223     mutex last_heartbeat_mu_;
224     uint64_t last_heartbeat_us_ TF_GUARDED_BY(last_heartbeat_mu_);
225     // This denotes the deadline after which we stop accepting heartbeats from a
226     // disconnected task. This grace period accounts for the lag time between
227     // the service recording the state change and the agent stopping heartbeats.
228     uint64_t disconnect_grace_period_us_ = 0;
229     // Checks if task has called WaitForAllTasks() previously, which gathers the
230     // local device info.
231     bool device_info_collected_ = false;
232     // For now, we assume there won't be many simultaneous barriers so we simply
233     // use a set.
234     absl::flat_hash_set<std::string> ongoing_barriers_for_task_;
235   };
236 
237   std::unique_ptr<CoordinationClientCache> client_cache_;
238   Env& env_;
239   const uint64_t service_incarnation_ = random::New64();
240   const uint64_t heartbeat_timeout_ms_;
241   const absl::Duration shutdown_barrier_timeout_;
242 
243   const std::string device_propagation_barrier_id_ =
244       absl::StrCat("WaitForAllTasks::", std::to_string(service_incarnation_));
245   const std::string shutdown_barrier_id_ =
246       absl::StrCat("Shutdown::", std::to_string(service_incarnation_));
247 
248   mutex state_mu_;
249   absl::flat_hash_map<std::string, std::unique_ptr<TaskState>> cluster_state_
250       TF_GUARDED_BY(state_mu_);
251   CoordinationServiceDeviceInfo cluster_devices_ TF_GUARDED_BY(state_mu_);
252 
253   mutex kv_mu_;
254   // Ordered map to store config key-values
255   std::map<std::string, std::string> kv_store_ TF_GUARDED_BY(kv_mu_);
256   absl::flat_hash_map<std::string, std::vector<StatusOrValueCallback>> get_cb_
257       TF_GUARDED_BY(kv_mu_);
258 
259   mutex check_staleness_thread_shutdown_mu_;
260   condition_variable check_staleness_thread_cv_;
261   bool shutting_down_ TF_GUARDED_BY(check_staleness_thread_shutdown_mu_) =
262       false;
263   std::unique_ptr<Thread> check_staleness_thread_;
264 
265   absl::flat_hash_map<std::string, BarrierState> barriers_
266       TF_GUARDED_BY(state_mu_);
267   // For now, we assume there won't be many simultaneous barriers so we simply
268   // use a set.
269   absl::flat_hash_set<std::string> ongoing_barriers_ TF_GUARDED_BY(state_mu_);
270 
271   absl::flat_hash_set<std::string> recoverable_jobs_;
272 
273   TF_DISALLOW_COPY_AND_ASSIGN(CoordinationServiceStandaloneImpl);
274 };
275 
SetConnected(uint64_t task_incarnation)276 void CoordinationServiceStandaloneImpl::TaskState::SetConnected(
277     uint64_t task_incarnation) {
278   state_ = CoordinatedTaskState::TASKSTATE_CONNECTED;
279   status_ = OkStatus();
280   task_incarnation_ = task_incarnation;
281   mutex_lock l(last_heartbeat_mu_);
282   last_heartbeat_us_ = Env::Default()->NowMicros();
283 }
284 
Disconnect(uint64_t grace_period_duration_us)285 void CoordinationServiceStandaloneImpl::TaskState::Disconnect(
286     uint64_t grace_period_duration_us) {
287   disconnect_grace_period_us_ =
288       Env::Default()->NowMicros() + grace_period_duration_us;
289   state_ = CoordinatedTaskState::TASKSTATE_DISCONNECTED;
290   status_ = OkStatus();
291 }
292 
SetError(const Status status)293 void CoordinationServiceStandaloneImpl::TaskState::SetError(
294     const Status status) {
295   if (state_ == CoordinatedTaskState::TASKSTATE_ERROR) return;
296   state_ = CoordinatedTaskState::TASKSTATE_ERROR;
297   status_ = status;
298 }
299 
RecordHeartbeat(uint64_t task_incarnation)300 Status CoordinationServiceStandaloneImpl::TaskState::RecordHeartbeat(
301     uint64_t task_incarnation) {
302   if (!status_.ok()) return status_;
303   if (task_incarnation != task_incarnation_) {
304     return MakeCoordinationError(errors::Aborted(
305         "Incarnation ID mismatch: expecting ", task_incarnation_, " but got ",
306         task_incarnation, ". This means the remote task has restarted."));
307   }
308   mutex_lock l(last_heartbeat_mu_);
309   last_heartbeat_us_ = Env::Default()->NowMicros();
310   return OkStatus();
311 }
312 
313 int64_t
TimeSinceLastHeartbeatMs()314 CoordinationServiceStandaloneImpl::TaskState::TimeSinceLastHeartbeatMs() {
315   mutex_lock l(last_heartbeat_mu_);
316   return (Env::Default()->NowMicros() - last_heartbeat_us_) / 1000;
317 }
318 
319 uint64_t CoordinationServiceStandaloneImpl::TaskState::
GetDisconnectedGracePeriodMicros()320     GetDisconnectedGracePeriodMicros() {
321   return disconnect_grace_period_us_;
322 }
323 
324 absl::flat_hash_set<std::string>
GetOngoingBarriers()325 CoordinationServiceStandaloneImpl::TaskState::GetOngoingBarriers() {
326   return ongoing_barriers_for_task_;
327 }
328 
JoinBarrier(absl::string_view barrier_id)329 void CoordinationServiceStandaloneImpl::TaskState::JoinBarrier(
330     absl::string_view barrier_id) {
331   ongoing_barriers_for_task_.emplace(barrier_id);
332 }
333 
ExitBarrier(absl::string_view barrier_id)334 void CoordinationServiceStandaloneImpl::TaskState::ExitBarrier(
335     absl::string_view barrier_id) {
336   ongoing_barriers_for_task_.erase(barrier_id);
337 }
CoordinationServiceStandaloneImpl(std::unique_ptr<CoordinationClientCache> client_cache,Env * env,const ServerDef & server_def)338 CoordinationServiceStandaloneImpl::CoordinationServiceStandaloneImpl(
339     std::unique_ptr<CoordinationClientCache> client_cache, Env* env,
340     const ServerDef& server_def)
341     : client_cache_(std::move(client_cache)),
342       env_(*env),
343       heartbeat_timeout_ms_([&server_def]() -> uint64_t {
344         const auto& configs = server_def.default_session_config()
345                                   .experimental()
346                                   .coordination_config();
347         return configs.heartbeat_timeout_in_ms() > 0
348                    ? configs.heartbeat_timeout_in_ms()
349                    : kDefaultHeartbeatTimeoutMs;
350       }()),
351       shutdown_barrier_timeout_(
352           absl::Milliseconds(server_def.default_session_config()
353                                  .experimental()
354                                  .coordination_config()
355                                  .shutdown_barrier_timeout_in_ms())) {
356   const auto& configs =
357       server_def.default_session_config().experimental().coordination_config();
358   const std::unordered_set<std::string> coordinated_jobs(
359       configs.coordinated_jobs().cbegin(), configs.coordinated_jobs().cend());
360   recoverable_jobs_ = absl::flat_hash_set<std::string>(
361       configs.recoverable_jobs().cbegin(), configs.recoverable_jobs().cend());
362   const auto& cluster_def = server_def.cluster();
363   for (const auto& job : cluster_def.job()) {
364     // If `coordinated_jobs` is specified, skip jobs that are not included there
365     if (!coordinated_jobs.empty() &&
366         coordinated_jobs.find(job.name()) == coordinated_jobs.end()) {
367       continue;
368     }
369     for (const auto& task : job.tasks()) {
370       const std::string& task_name = GetTaskName(job.name(), task.first);
371       cluster_state_.emplace(task_name, std::make_unique<TaskState>());
372     }
373   }
374   StartCheckStaleness();
375 }
376 
377 // Checks both heartbeat and barrier timeouts in the same thread, since threads
378 // are a constrained resource.
StartCheckStaleness()379 void CoordinationServiceStandaloneImpl::StartCheckStaleness() {
380   check_staleness_thread_.reset(
381       env_.StartThread({}, kHealthCheckThread, [this]() {
382         const bool has_service_to_client_connection = client_cache_ != nullptr;
383         // Used to store stale tasks and barriers.
384         std::vector<absl::string_view> stale_task_names;
385         absl::flat_hash_map<std::string, BarrierState*> expired_barriers;
386         while (true) {
387           {
388             mutex_lock l(check_staleness_thread_shutdown_mu_);
389             check_staleness_thread_cv_.wait_for(l, std::chrono::seconds(1));
390             if (shutting_down_) {
391               return;
392             }
393           }
394           // Heartbeat check.
395           Status status = OkStatus();
396           {
397             mutex_lock l(state_mu_);
398             for (const auto& [task_name, task_state] : cluster_state_) {
399               // Skip tasks that are not registered or in error state
400               if (task_state->GetState() !=
401                   CoordinatedTaskState::TASKSTATE_CONNECTED) {
402                 continue;
403               }
404               const bool is_stale = task_state->TimeSinceLastHeartbeatMs() >
405                                     heartbeat_timeout_ms_;
406               VLOG(1) << "Checking staleness for " << task_name
407                       << " stale?=" << is_stale;
408               if (is_stale) {
409                 stale_task_names.push_back(task_name);
410                 status = MakeCoordinationError(errors::Unavailable(
411                     "Task ", task_name,
412                     " heartbeat timeout. This indicates that the remote task "
413                     "has failed, got preempted, or crashed unexpectedly."));
414                 SetTaskError(task_name, status);
415               }
416             }
417           }
418           // Propagate heartbeat timeout errors to other connected tasks.
419           if (!stale_task_names.empty()) {
420             if (!has_service_to_client_connection) {
421               // Error cannot be propagated since there is no service-to-client
422               // connection, so shut down service instead. Note: we cannot
423               // destroy the thread within its own function. However, this
424               // thread will be destroyed once the function returns.
425               LOG(ERROR) << "Stopping coordination service as heartbeat has "
426                             "timed out for "
427                          << stale_task_names[0]
428                          << " and there is no service-to-client connection";
429               Stop(/*shut_staleness_thread=*/false);
430               return;
431             }
432             for (const auto& stale_task_name : stale_task_names) {
433               PropagateError(GetTaskFromName(stale_task_name));
434             }
435             stale_task_names.clear();
436           }
437 
438           // Barrier timeout check.
439           uint64_t current_time_micros = Env::Default()->NowMicros();
440           {
441             mutex_lock l(state_mu_);
442             // Gather barriers which have timed out.
443             for (const std::string& barrier_id : ongoing_barriers_) {
444               auto* barrier = &barriers_[barrier_id];
445               if (current_time_micros > barrier->deadline_in_micros) {
446                 expired_barriers[barrier_id] = barrier;
447               }
448             }
449             // Pass these barriers with the time out error.
450             for (const auto& [barrier_id, barrier] : expired_barriers) {
451               const Status error =
452                   MakeCoordinationError(errors::DeadlineExceeded(absl::StrCat(
453                       "Barrier timed out. Barrier_id: ", barrier_id)));
454               PassBarrier(barrier_id, error, barrier);
455             }
456           }
457           if (!has_service_to_client_connection &&
458               expired_barriers.contains(shutdown_barrier_id_)) {
459             // Error cannot be propagated since there is no service-to-client
460             // connection, so shut down service instead. Note: we cannot
461             // destroy the thread within its own function. However, this
462             // thread will be destroyed once the function returns.
463             LOG(ERROR)
464                 << "Stopping coordination service as shutdown barrier "
465                    "timed out and there is no service-to-client connection.";
466             Stop(/*shut_staleness_thread=*/false);
467           }
468           // Reset this for the next barrier check.
469           expired_barriers.clear();
470         }
471       }));
472 }
473 
Stop(bool shut_staleness_thread)474 void CoordinationServiceStandaloneImpl::Stop(bool shut_staleness_thread) {
475   {
476     mutex_lock l(kv_mu_);
477     for (const auto& [key, get_kv_callbacks] : get_cb_) {
478       for (const auto& get_kv_callback : get_kv_callbacks) {
479         get_kv_callback(errors::Cancelled(
480             absl::StrCat("Coordination service is shutting down. Cancelling "
481                          "GetKeyValue() for key: ",
482                          key)));
483       }
484     }
485     get_cb_.clear();
486   }
487   {
488     mutex_lock l(state_mu_);
489     cluster_state_.clear();
490     for (auto& [barrier_id, barrier] : barriers_) {
491       if (!barrier.passed) {
492         Status error = MakeCoordinationError(errors::Aborted(absl::StrCat(
493             "Barrier failed because service is shutting down. Barrier_id: ",
494             barrier_id)));
495         PassBarrier(barrier_id, error, &barrier);
496       }
497     }
498     barriers_.clear();
499   }
500   {
501     mutex_lock l(check_staleness_thread_shutdown_mu_);
502     shutting_down_ = true;
503     check_staleness_thread_cv_.notify_all();
504   }
505   if (shut_staleness_thread) {
506     check_staleness_thread_.reset();
507   }
508 }
509 
RegisterTask(const CoordinatedTask & task,uint64_t incarnation)510 Status CoordinationServiceStandaloneImpl::RegisterTask(
511     const CoordinatedTask& task, uint64_t incarnation) {
512   const std::string& task_name = GetTaskName(task);
513 
514   Status status;
515   {
516     mutex_lock l(state_mu_);
517     if (!cluster_state_.contains(task_name)) {
518       // Note: return early here as unexpected task register errors should not
519       // be propagated to other tasks.
520       return MakeCoordinationError(errors::InvalidArgument(
521           "Unexpected task registered with task_name=", task_name));
522     }
523     if (cluster_state_[task_name]->GetState() ==
524         CoordinatedTaskState::TASKSTATE_DISCONNECTED) {
525       // This task is currently disconnected (registering for the first time or
526       // has called ResetTask() previously).
527       cluster_state_[task_name]->SetConnected(incarnation);
528       LOG(INFO) << task_name
529                 << " has connected to coordination service. Incarnation: "
530                 << incarnation;
531     } else {
532       // This task is connected or already in error, which implies it has
533       // registered previously.
534       status = MakeCoordinationError(
535           errors::Aborted("Duplicate task registration with task_name=",
536                           task_name),
537           task);
538       SetTaskError(task_name, status);
539     }
540   }
541   if (!status.ok()) {
542     PropagateError(task);
543   }
544   return status;
545 }
546 
WaitForAllTasks(const CoordinatedTask & task,const CoordinationServiceDeviceInfo & devices,StatusCallback done)547 void CoordinationServiceStandaloneImpl::WaitForAllTasks(
548     const CoordinatedTask& task, const CoordinationServiceDeviceInfo& devices,
549     StatusCallback done) {
550   {
551     mutex_lock l(state_mu_);
552     const auto& task_state = cluster_state_.find(GetTaskName(task));
553     // Add task device info to global device state for the first time that task
554     // has called WaitForAllTasks().
555     if (task_state != cluster_state_.end() &&
556         !task_state->second->GetDeviceInfoCollected()) {
557       cluster_devices_.MergeFrom(devices);
558       task_state->second->MarkDeviceInfoCollected();
559     }
560   }
561   BarrierAsync(device_propagation_barrier_id_, kDevicePropagationTimeout, task,
562                {}, std::move(done));
563 }
564 
ShutdownTaskAsync(const CoordinatedTask & task,StatusCallback done)565 void CoordinationServiceStandaloneImpl::ShutdownTaskAsync(
566     const CoordinatedTask& task, StatusCallback done) {
567   if (shutdown_barrier_timeout_ > absl::ZeroDuration()) {
568     // Impose shutdown barrier so that all tasks can disconnect together.
569     BarrierAsync(shutdown_barrier_id_, shutdown_barrier_timeout_, task, {},
570                  done);
571   } else {
572     Status status;
573     {
574       mutex_lock l(state_mu_);
575       // Disconnect task from service individually.
576       status = DisconnectTask(task);
577     }
578     done(status);
579   }
580 }
581 
ResetTask(const CoordinatedTask & task)582 Status CoordinationServiceStandaloneImpl::ResetTask(
583     const CoordinatedTask& task) {
584   mutex_lock l(state_mu_);
585   return DisconnectTask(task);
586 }
587 
DisconnectTask(const CoordinatedTask & task)588 Status CoordinationServiceStandaloneImpl::DisconnectTask(
589     const CoordinatedTask& task) {
590   const std::string task_name = GetTaskName(task);
591   // Check if task is valid and not already disconnected.
592   if (!cluster_state_.contains(task_name)) {
593     return MakeCoordinationError(errors::InvalidArgument(
594         "Unexpected disconnect request with task_name=", task_name));
595   } else if (cluster_state_[task_name]->GetState() ==
596              CoordinatedTaskState::TASKSTATE_DISCONNECTED) {
597     return MakeCoordinationError(errors::FailedPrecondition(
598         "The task is already disconnected: ", task_name));
599   }
600 
601   // Disconnect task and fail any ongoing barriers.
602   cluster_state_[task_name]->Disconnect(
603       /*grace_period_duration_us=*/heartbeat_timeout_ms_ * 1000);
604   for (const auto& barrier_id :
605        cluster_state_[task_name]->GetOngoingBarriers()) {
606     Status error = MakeCoordinationError(errors::Internal(absl::StrCat(
607         "Barrier failed from a disconnected task. Barrier Id: ", barrier_id,
608         ", Task: ", task_name)));
609     PassBarrier(barrier_id, error, &barriers_[barrier_id]);
610   }
611 
612   LOG(INFO) << task_name << " has disconnected from coordination service.";
613   return OkStatus();
614 }
615 
616 const CoordinationServiceDeviceInfo&
ListClusterDevices()617 CoordinationServiceStandaloneImpl::ListClusterDevices() {
618   return cluster_devices_;
619 }
620 
GetServiceIncarnation()621 uint64_t CoordinationServiceStandaloneImpl::GetServiceIncarnation() {
622   return service_incarnation_;
623 }
624 
ReportTaskError(const CoordinatedTask & task,Status error)625 Status CoordinationServiceStandaloneImpl::ReportTaskError(
626     const CoordinatedTask& task, Status error) {
627   const std::string& task_name = GetTaskName(task);
628   {
629     mutex_lock l(state_mu_);
630     if (!cluster_state_.contains(task_name)) {
631       return MakeCoordinationError(
632           errors::InvalidArgument("Unexpected request from task ", task_name));
633     } else if (cluster_state_[task_name]->GetState() !=
634                CoordinatedTaskState::TASKSTATE_CONNECTED) {
635       return MakeCoordinationError(errors::FailedPrecondition(
636           "The task is not connected or already has an error."));
637     } else {
638       SetTaskError(task_name, error);
639     }
640   }
641   PropagateError(task, /*is_reported_by_task=*/true);
642   return OkStatus();
643 }
644 
RecordHeartbeat(const CoordinatedTask & task,uint64_t incarnation)645 Status CoordinationServiceStandaloneImpl::RecordHeartbeat(
646     const CoordinatedTask& task, uint64_t incarnation) {
647   const std::string& task_name = GetTaskName(task);
648   Status s = OkStatus();
649   {
650     mutex_lock l(state_mu_);
651     if (!cluster_state_.contains(task_name)) {
652       return MakeCoordinationError(errors::InvalidArgument(
653           "Unexpected task request with task_name=", task_name));
654     }
655     if (!cluster_state_[task_name]->GetStatus().ok()) {
656       return cluster_state_[task_name]->GetStatus();
657     } else if (cluster_state_[task_name]->GetState() ==
658                    CoordinatedTaskState::TASKSTATE_DISCONNECTED &&
659                // We accept heartbeats for a short grace period to account for
660                // the lag time between the service recording the state change
661                // and the agent stopping heartbeats.
662                Env::Default()->NowMicros() >
663                    cluster_state_[task_name]
664                        ->GetDisconnectedGracePeriodMicros()) {
665       return MakeCoordinationError(errors::InvalidArgument(
666           "Task with task_name=", task_name,
667           " must be registered before sending heartbeat messages"));
668     }
669     s = cluster_state_[task_name]->RecordHeartbeat(incarnation);
670   }
671 
672   // Set and propagate any heartbeat errors.
673   if (!s.ok()) {
674     {
675       mutex_lock l(state_mu_);
676       SetTaskError(task_name, s);
677     }
678     PropagateError(task);
679   }
680 
681   return s;
682 }
683 
ReportServiceErrorToTaskAsync(const CoordinatedTask & destination_task,Status error)684 void CoordinationServiceStandaloneImpl::ReportServiceErrorToTaskAsync(
685     const CoordinatedTask& destination_task, Status error) {
686   assert(!error.ok());
687 
688   // Don't report error if there is no service-to-client connection.
689   if (client_cache_ == nullptr) {
690     LOG(ERROR) << error;
691     return;
692   }
693 
694   auto request = std::make_shared<ReportErrorToTaskRequest>();
695   auto response = std::make_shared<ReportErrorToTaskResponse>();
696   request->set_error_code(error.code());
697   request->set_error_message(error.error_message());
698   CoordinatedTask* error_source =
699       request->mutable_error_payload()->mutable_source_task();
700   error_source->set_job_name("coordination_service");
701   auto call_opts = std::make_shared<CallOptions>();
702   call_opts->SetTimeout(kServiceToClientTimeoutMs);
703 
704   const std::string task_name = GetTaskName(destination_task);
705   CoordinationClient* client = client_cache_->GetClient(task_name);
706   client->ReportErrorToTaskAsync(
707       call_opts.get(), request.get(), response.get(),
708       [request, response, task_name, call_opts](Status s) {
709         if (!s.ok()) {
710           LOG(ERROR) << "Encountered another error while reporting to "
711                      << task_name << ": " << s;
712         }
713       });
714 }
715 
PropagateError(const CoordinatedTask & source_task,bool is_reported_by_task)716 void CoordinationServiceStandaloneImpl::PropagateError(
717     const CoordinatedTask& source_task, bool is_reported_by_task) {
718   // If the error task is recoverable, do not propagate the error to other
719   // connected tasks.
720   if (isRecoverableJob(source_task.job_name())) return;
721   Status error;
722   {
723     mutex_lock l(state_mu_);
724     error = cluster_state_[GetTaskName(source_task)]->GetStatus();
725   }
726   assert(!error.ok());
727   ReportErrorToTaskRequest request;
728   request.set_error_code(error.code());
729   request.set_error_message(error.error_message());
730   CoordinationServiceError* payload = request.mutable_error_payload();
731   *payload->mutable_source_task() = source_task;
732   payload->set_is_reported_error(is_reported_by_task);
733   CallOptions call_opts;
734   call_opts.SetTimeout(kServiceToClientTimeoutMs);
735   std::vector<std::shared_ptr<absl::Notification>> notifications;
736 
737   std::vector<absl::string_view> task_names;
738   {
739     tf_shared_lock l(state_mu_);
740     task_names.reserve(cluster_state_.size());
741     for (const auto& pair : cluster_state_) {
742       task_names.emplace_back(pair.first);
743     }
744   }
745   for (absl::string_view task : task_names) {
746     {
747       mutex_lock l(state_mu_);
748       // Propagate error only to tasks that are connected
749       if (cluster_state_[task]->GetState() !=
750           CoordinatedTaskState::TASKSTATE_CONNECTED)
751         continue;
752     }
753 
754     // Don't propagate error if there is no service-to-client connection.
755     if (client_cache_ == nullptr) {
756       LOG(ERROR)
757           << "Stopping coordination service as there is no "
758              "service-to-client connection, but we encountered an error: "
759           << error;
760       Stop(/*shut_staleness_thread=*/false);
761       return;
762     }
763     CoordinationClient* client = client_cache_->GetClient(std::string(task));
764     auto response = std::make_shared<ReportErrorToTaskResponse>();
765     auto n = std::make_shared<absl::Notification>();
766     client->ReportErrorToTaskAsync(
767         &call_opts, &request, response.get(), [response, n, task](Status s) {
768           if (!s.ok()) {
769             LOG(ERROR) << "Encountered another error while reporting to "
770                        << task << ": " << s;
771           }
772           n->Notify();
773         });
774     notifications.push_back(n);
775   }
776   for (auto& n : notifications) {
777     n->WaitForNotification();
778   }
779 }
780 
781 // Utility for normalizing structured config key string.
782 // The normalized key will not have leading or trailing slashes, and all parts
783 // in the key path are separated by exactly one slack ('/').
784 // E.g., ///a//b/c// --> a/b/c
NormalizeKey(const StringPiece orig_key)785 std::string NormalizeKey(const StringPiece orig_key) {
786   std::string norm_key = std::string(orig_key);
787   const char* src = norm_key.c_str();
788   std::string::iterator dst = norm_key.begin();
789 
790   // Parse all characters
791   while (*src) {
792     // Skip leading slashes
793     while (*src == '/') src++;
794     // Copy over all non-slash characters
795     while (*src && *src != '/') {
796       *dst++ = *src++;
797     }
798     // Allow one slash at the end of current directory
799     if (*src) {
800       *dst++ = *src++;
801     }
802   }
803   // If ending with slash, remove the trailing slash
804   if (dst > norm_key.begin() && *(dst - 1) == '/') dst--;
805   norm_key.resize(dst - norm_key.begin());
806   return norm_key;
807 }
808 
InsertKeyValue(const std::string & key,const std::string & value)809 Status CoordinationServiceStandaloneImpl::InsertKeyValue(
810     const std::string& key, const std::string& value) {
811   const std::string& norm_key = NormalizeKey(key);
812   mutex_lock l(kv_mu_);
813   if (kv_store_.find(norm_key) != kv_store_.end()) {
814     return MakeCoordinationError(
815         errors::AlreadyExists("Config key ", key, " already exists."));
816   }
817   kv_store_.emplace(norm_key, value);
818   auto iter = get_cb_.find(norm_key);
819   if (iter != get_cb_.end()) {
820     for (const auto& cb : iter->second) {
821       cb(value);
822     }
823     get_cb_.erase(iter);
824   }
825   return OkStatus();
826 }
827 
GetKeyValueAsync(const std::string & key,StatusOrValueCallback done)828 void CoordinationServiceStandaloneImpl::GetKeyValueAsync(
829     const std::string& key, StatusOrValueCallback done) {
830   const std::string& norm_key = NormalizeKey(key);
831   mutex_lock l(kv_mu_);
832   const auto& iter = kv_store_.find(norm_key);
833   if (iter != kv_store_.end()) {
834     done(iter->second);
835     return;
836   }
837   auto cb_iter = get_cb_.find(norm_key);
838   if (cb_iter == get_cb_.end()) {
839     cb_iter =
840         get_cb_.emplace(norm_key, std::vector<StatusOrValueCallback>()).first;
841   }
842   cb_iter->second.emplace_back(std::move(done));
843 }
844 
TryGetKeyValue(const std::string & key)845 StatusOr<std::string> CoordinationServiceStandaloneImpl::TryGetKeyValue(
846     const std::string& key) {
847   const std::string& norm_key = NormalizeKey(key);
848   mutex_lock l(kv_mu_);
849   const auto& iter = kv_store_.find(norm_key);
850   if (iter == kv_store_.end()) {
851     return errors::NotFound("Config key ", key, " not found.");
852   }
853   return iter->second;
854 }
855 
GetKeyValueDir(absl::string_view directory_key)856 std::vector<KeyValueEntry> CoordinationServiceStandaloneImpl::GetKeyValueDir(
857     absl::string_view directory_key) {
858   std::vector<KeyValueEntry> kvs_in_directory;
859   const std::string norm_key = NormalizeKey(directory_key);
860   const std::string dir = absl::StrCat(norm_key, "/");
861 
862   mutex_lock l(kv_mu_);
863   // Find first key in ordered map that has the directory prefix.
864   auto begin = kv_store_.lower_bound(dir);
865   std::map<std::string, std::string>::iterator it;
866   // Iterate through key range that match directory prefix.
867   for (it = begin; it != kv_store_.end(); ++it) {
868     // Stop once the next key does not have the directory prefix. Since keys are
869     // ordered, none of the other keys would have a matching prefix.
870     if (std::mismatch(dir.begin(), dir.end(), it->first.begin()).first !=
871         dir.end()) {
872       break;
873     }
874     KeyValueEntry kv;
875     kv.set_key(it->first);
876     kv.set_value(it->second);
877     kvs_in_directory.push_back(kv);
878   }
879 
880   return kvs_in_directory;
881 }
882 
DeleteKeyValue(const std::string & key)883 Status CoordinationServiceStandaloneImpl::DeleteKeyValue(
884     const std::string& key) {
885   const std::string& norm_key = NormalizeKey(key);
886   mutex_lock l(kv_mu_);
887   // Delete directory: find key range that match directory prefix
888   const std::string& dir = strings::StrCat(norm_key, "/");
889   auto begin = kv_store_.lower_bound(dir);
890   std::map<std::string, std::string>::iterator end;
891   for (end = begin; end != kv_store_.end(); end++) {
892     if (std::mismatch(dir.begin(), dir.end(), end->first.begin()).first !=
893         dir.end())
894       break;
895   }
896   kv_store_.erase(begin, end);
897   auto iter = kv_store_.find(norm_key);
898   if (iter != kv_store_.end()) {
899     kv_store_.erase(iter);
900   }
901   return OkStatus();
902 }
903 
SetTaskError(absl::string_view task_name,Status error)904 void CoordinationServiceStandaloneImpl::SetTaskError(
905     absl::string_view task_name, Status error) {
906   cluster_state_[task_name]->SetError(error);
907   for (const auto& barrier_id :
908        cluster_state_[task_name]->GetOngoingBarriers()) {
909     Status error = MakeCoordinationError(errors::Internal(absl::StrCat(
910         "Barrier failed from a task error. Barrier Id: ", barrier_id,
911         ", Task: ", task_name)));
912     PassBarrier(barrier_id, error, &barriers_[barrier_id]);
913   }
914 
915   LOG(ERROR) << task_name
916              << " has been set to ERROR in coordination service: " << error;
917 }
918 
BarrierAsync(const std::string & barrier_id,absl::Duration timeout,const CoordinatedTask & task,const std::vector<CoordinatedTask> & participating_tasks,StatusCallback done)919 void CoordinationServiceStandaloneImpl::BarrierAsync(
920     const std::string& barrier_id, absl::Duration timeout,
921     const CoordinatedTask& task,
922     const std::vector<CoordinatedTask>& participating_tasks,
923     StatusCallback done) {
924   mutex_lock l(state_mu_);
925   auto pair = barriers_.try_emplace(barrier_id);
926   auto it = pair.first;
927   bool inserted = pair.second;
928   auto* barrier = &it->second;
929   // Create barrier for the first time.
930   if (inserted) {
931     // Initialize barrier state.
932     barrier->passed = false;
933     // Assume barrier is for entire cluster if no tasks are specified.
934     if (participating_tasks.empty()) {
935       for (const auto& task_state : cluster_state_) {
936         absl::string_view task_name = task_state.first;
937         barrier->tasks_at_barrier[GetTaskFromName(task_name)] = false;
938       }
939     } else {
940       for (const auto& task : participating_tasks) {
941         // Fail the barrier immediately if unexpected task is included in the
942         // barrier.
943         const std::string task_name = GetTaskName(task);
944         if (!cluster_state_.contains(task_name)) {
945           Status error = MakeCoordinationError(errors::InvalidArgument(
946               absl::StrCat("Unexpected task (", task_name,
947                            ") that is not in the cluster called the barrier. "
948                            "Barrier Id: ",
949                            barrier_id)));
950           PassBarrier(barrier_id, error, barrier);
951           done(error);
952           return;
953         }
954         barrier->tasks_at_barrier[task] = false;
955       }
956     }
957     barrier->num_pending_tasks = barrier->tasks_at_barrier.size();
958 
959     // Fail the barrier immediately if any tasks are already in error.
960     for (const auto& pending_task : barrier->tasks_at_barrier) {
961       const std::string task_name = GetTaskName(pending_task.first);
962       if (cluster_state_[task_name]->GetState() ==
963           CoordinatedTaskState::TASKSTATE_ERROR) {
964         Status error = MakeCoordinationError(errors::Internal(
965             absl::StrCat("Task (", task_name,
966                          ") is already in error before the barrier "
967                          "was called. Barrier Id: ",
968                          barrier_id)));
969         PassBarrier(barrier_id, error, barrier);
970         done(error);
971         return;
972       }
973     }
974     barrier->deadline_in_micros =
975         Env::Default()->NowMicros() + (timeout / absl::Microseconds(1));
976 
977     // Add ongoing barrier to cluster state.
978     ongoing_barriers_.emplace(barrier_id);
979     const size_t num_ongoing_barriers = ongoing_barriers_.size();
980     if (num_ongoing_barriers > kOngoingBarriersSoftLimit) {
981       LOG(WARNING) << "There is a high number of ongoing barriers in "
982                       "coordination service: "
983                    << num_ongoing_barriers;
984     }
985     for (const auto& pending_task : barrier->tasks_at_barrier) {
986       const CoordinatedTask& task = pending_task.first;
987       cluster_state_[GetTaskName(task)]->JoinBarrier(barrier_id);
988     }
989   }
990 
991   // Barrier has already been passed, return previous result immediately.
992   if (barrier->passed) {
993     // Special hook for shutdown barrier to disconnect task.
994     if (barrier_id == shutdown_barrier_id_) {
995       Status s = DisconnectTask(task);
996       // Return any errors from the disconnect attempt, otherwise return the
997       // barrier status outside of this hook.
998       if (!s.ok()) {
999         done(s);
1000         return;
1001       }
1002     }
1003 
1004     done(barrier->result);
1005     return;
1006   }
1007 
1008   // Add pending callbacks.
1009   barrier->done_callbacks.push_back(done);
1010 
1011   // Check if caller task is participating in the barrier.
1012   if (!barrier->tasks_at_barrier.contains(task)) {
1013     // Unexpected barrier call from a task not participating in the barrier.
1014     Status error = MakeCoordinationError(errors::InvalidArgument(
1015         absl::StrCat("A non-participating task (", GetTaskName(task),
1016                      ") called the barrier: ", barrier_id)));
1017     PassBarrier(barrier_id, error, barrier);
1018     return;
1019   }
1020 
1021   // Check if task args are specified consistently across barrier calls.
1022   if (!ValidateTaskArgs(participating_tasks, barrier->tasks_at_barrier,
1023                         cluster_state_.size())) {
1024     Status error = MakeCoordinationError(errors::InvalidArgument(absl::StrCat(
1025         "Conflicting tasks specified for the same barrier: ", barrier_id)));
1026     PassBarrier(barrier_id, error, barrier);
1027     return;
1028   }
1029 
1030   // Remove pending task.
1031   // We need to check if task made a repeated call after reaching the barrier.
1032   if (!barrier->tasks_at_barrier[task]) {
1033     barrier->tasks_at_barrier[task] = true;
1034     --barrier->num_pending_tasks;
1035 
1036     if (barrier->num_pending_tasks == 0) {
1037       PassBarrier(barrier_id, OkStatus(), barrier);
1038       return;
1039     }
1040   }
1041 }
1042 
CancelBarrier(const std::string & barrier_id,const CoordinatedTask & task)1043 Status CoordinationServiceStandaloneImpl::CancelBarrier(
1044     const std::string& barrier_id, const CoordinatedTask& task) {
1045   mutex_lock l(state_mu_);
1046   auto [it, inserted] = barriers_.try_emplace(barrier_id);
1047   auto* barrier = &it->second;
1048   if (inserted) {
1049     LOG(WARNING) << "Barrier (" << barrier_id
1050                  << ") is cancelled before being created by task: "
1051                  << GetTaskName(task);
1052   }
1053   // Barrier has already been passed.
1054   if (barrier->passed) {
1055     return MakeCoordinationError(errors::FailedPrecondition(absl::StrCat(
1056         "Barrier (", barrier_id, ") has already been passed with status code: ",
1057         barrier->result.code())));
1058   }
1059 
1060   // Cancel barrier.
1061   Status cancelled = MakeCoordinationError(errors::Cancelled(absl::StrCat(
1062       "Barrier (", barrier_id, ") is cancelled by task: ", GetTaskName(task))));
1063   PassBarrier(barrier_id, cancelled, barrier);
1064 
1065   return OkStatus();
1066 }
1067 
// Marks `barrier` as passed with `result` and finalizes it:
//   - runs SetXlaGlobalDeviceIds() when `barrier_id` is the device
//     propagation barrier;
//   - removes the barrier from each participant's ongoing-barrier set
//     (ExitBarrier);
//   - for the shutdown barrier, disconnects the tasks that reached it and
//     reports an error to those that did not;
//   - clears the participant map, removes the id from ongoing_barriers_, and
//     invokes (then clears) every queued done callback with `result`.
// All callers visible in this file hold state_mu_ when calling this.
void CoordinationServiceStandaloneImpl::PassBarrier(
    absl::string_view barrier_id, Status result, BarrierState* barrier) {
  barrier->passed = true;
  barrier->result = result;
  // Special hook for device propagation barrier to set global device ids.
  if (barrier_id == device_propagation_barrier_id_) {
    SetXlaGlobalDeviceIds();
  }
  for (const auto& task_at_barrier : barrier->tasks_at_barrier) {
    // Clean up task state (used as error hooks).
    const CoordinatedTask& task = task_at_barrier.first;
    cluster_state_[GetTaskName(task)]->ExitBarrier(barrier_id);
  }

  // Special hook for shutdown barrier to disconnect tasks at the barrier.
  if (barrier_id == shutdown_barrier_id_) {
    if (result.ok()) {
      LOG(INFO) << "Shutdown barrier in coordination service has passed.";
    } else {
      LOG(ERROR) << "Shutdown barrier in coordination service has failed: "
                 << result
                 << ". This suggests that at least one worker did not complete "
                    "its job, or was too slow/hanging in its execution.";
    }
    Status shutdown_error = MakeCoordinationError(errors::Internal(
        absl::StrCat("Shutdown barrier has been passed with status: '",
                     barrier->result.ToString(),
                     "', but this task is not at the barrier yet.")));
    for (const auto& [task, at_barrier] : barrier->tasks_at_barrier) {
      if (at_barrier) {
        // Disconnect tasks that reached the barrier.
        Status disconnect_status = DisconnectTask(task);
        if (!disconnect_status.ok()) {
          LOG(ERROR) << disconnect_status;
        }
      } else {
        // Propagate errors to straggling tasks that have not reached the
        // barrier. The barrier must have failed if any task did not reach the
        // barrier.
        ReportServiceErrorToTaskAsync(task, shutdown_error);
      }
    }
  }
  barrier->tasks_at_barrier.clear();
  ongoing_barriers_.erase(barrier_id);
  // Note: barrier_id shouldn't be referenced after this line as its lifetime
  // may be tied to one of the callbacks.
  // Propagate results to participating tasks.
  for (const auto& callback : barrier->done_callbacks) {
    callback(result);
  }
  barrier->done_callbacks.clear();
}
1122 
ValidateTaskArgs(const std::vector<CoordinatedTask> & tasks_args,const absl::flat_hash_map<CoordinatedTask,bool,CoordinatedTaskHash,CoordinatedTaskEqual> & tasks_at_barrier,int64_t cluster_size)1123 bool CoordinationServiceStandaloneImpl::ValidateTaskArgs(
1124 
1125     const std::vector<CoordinatedTask>& tasks_args,
1126     const absl::flat_hash_map<CoordinatedTask, bool, CoordinatedTaskHash,
1127                               CoordinatedTaskEqual>& tasks_at_barrier,
1128     int64_t cluster_size) {
1129   if (tasks_args.empty()) {
1130     return tasks_at_barrier.size() == cluster_size;
1131   } else if (tasks_at_barrier.size() != tasks_args.size()) {
1132     return false;
1133   } else {
1134     for (const auto& task : tasks_args) {
1135       if (!tasks_at_barrier.contains(task)) {
1136         return false;
1137       }
1138     }
1139   }
1140   return true;
1141 }
1142 
SetXlaGlobalDeviceIds()1143 void CoordinationServiceStandaloneImpl::SetXlaGlobalDeviceIds() {
1144   // No-op if TF devices are specified.
1145   if (cluster_devices_.has_xla()) {
1146     int global_id = 0;
1147     for (xla::LocalTopologyProto& local_topology :
1148          *cluster_devices_.mutable_xla()->mutable_devices()->mutable_nodes()) {
1149       for (xla::DeviceProto& device : *local_topology.mutable_devices()) {
1150         device.set_global_device_id(global_id);
1151         ++global_id;
1152       }
1153     }
1154   }
1155 }
1156 }  // namespace
1157 
EnableCoordinationService(Env * env,const ServerDef & server_def,std::unique_ptr<CoordinationClientCache> cache)1158 std::unique_ptr<CoordinationServiceInterface> EnableCoordinationService(
1159     Env* env, const ServerDef& server_def,
1160     std::unique_ptr<CoordinationClientCache> cache) {
1161   std::unique_ptr<CoordinationServiceInterface> coord_service;
1162   if (is_multi_client_leader(server_def)) {
1163     coord_service = std::make_unique<CoordinationServiceStandaloneImpl>(
1164         std::move(cache), env, server_def);
1165   }
1166   return coord_service;
1167 }
1168 
isRecoverableJob(const absl::string_view task_name) const1169 bool CoordinationServiceStandaloneImpl::isRecoverableJob(
1170     const absl::string_view task_name) const {
1171   return recoverable_jobs_.find(task_name) != recoverable_jobs_.end();
1172 }
1173 
// Register the standalone coordination service implementation under the name
// "standalone", with EnableCoordinationService as its factory.
REGISTER_COORDINATION_SERVICE("standalone", EnableCoordinationService);
1176 
1177 }  // namespace tensorflow
1178