/* Copyright 2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/pjrt/distributed/client.h"

#include <algorithm>
#include <chrono>  // NOLINT
#include <memory>
#include <optional>
#include <random>
#include <string>
#include <utility>

#include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h"
#include "absl/time/time.h"
#include "grpcpp/channel.h"
#include "tensorflow/compiler/xla/pjrt/distributed/protocol.h"
#include "tensorflow/compiler/xla/pjrt/distributed/util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/distributed_runtime/coordination/coordination_client.h"
#include "tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h"
#include "tensorflow/core/distributed_runtime/coordination/coordination_service_error_util.h"
#include "tensorflow/core/distributed_runtime/rpc/coordination/grpc_coordination_client.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/protobuf/coordination_config.pb.h"
#include "tensorflow/core/protobuf/coordination_service.pb.h"

namespace xla {
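// This file defines two DistributedRuntimeClient implementations:
//  - DistributedRuntimeClientImpl talks to the XLA-specific
//    DistributedRuntimeService over gRPC and runs its own heartbeat thread.
//  - DistributedRuntimeCoordinationServiceClient delegates to TensorFlow's
//    CoordinationServiceAgent.
// GetDistributedRuntimeClient() at the bottom of the file selects between
// them based on `use_coordination_service`.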
class DistributedRuntimeClientImpl : public DistributedRuntimeClient {
 public:
  DistributedRuntimeClientImpl(std::shared_ptr<::grpc::Channel> channel,
                               const Options& options);
  explicit DistributedRuntimeClientImpl(
      std::shared_ptr<::grpc::Channel> channel)
      : DistributedRuntimeClientImpl(channel, Options()) {}
  ~DistributedRuntimeClientImpl() override;

  xla::Status Connect() override;
  xla::Status Shutdown() override;
  xla::Status EnumerateDevices(const LocalTopologyProto& local_topology,
                               GlobalTopologyProto* global_topology) override;
  xla::StatusOr<std::string> BlockingKeyValueGet(
      std::string key, absl::Duration timeout) override;
  xla::Status KeyValueSet(std::string key, std::string value) override;
  xla::Status WaitAtBarrier(std::string barrier_id,
                            absl::Duration timeout) override;
  xla::StatusOr<tensorflow::CoordinationServiceAgent*>
  GetCoordinationServiceAgent() override;

 private:
  // Entry point for the heartbeat thread.
  void HeartbeatLoop();

  const std::unique_ptr<grpc::DistributedRuntimeService::Stub> stub_;
  const DistributedRuntimeClient::Options options_;

  // Possible states of the client.
  // The only legal transitions are downwards in the order below, i.e., there
  // is no way to reopen a closed client.
  enum class State {
    // The client has not yet connected to the server, i.e., had a Connect()
    // RPC succeed.
    kNotConnected,

    // The client is connected to the server and as far as we are aware the
    // connection is healthy.
    kConnected,

    // The client is in the process of shutting down, i.e., Shutdown() has been
    // called.
    kShuttingDown,

    // The client has shut down its server connection, either due to an error
    // or due to an explicit shutdown.
    kClosed,
  };

  static absl::string_view StateToString(State state);

  // state_ is protected by a mutex because the heartbeat thread needs to look
  // at it.
  absl::Mutex mu_;
  State state_ ABSL_GUARDED_BY(mu_) = State::kNotConnected;

  // A unique session ID, assigned by the server during Connect().
  uint64_t session_id_;

  // Notification that tells the heartbeat thread to stop running.
  absl::Notification stop_heartbeats_;

  // Thread responsible for performing heartbeats.
  std::unique_ptr<tensorflow::Thread> heartbeat_thread_;
};

class DistributedRuntimeCoordinationServiceClient
    : public DistributedRuntimeClient {
 public:
  DistributedRuntimeCoordinationServiceClient(
      std::shared_ptr<::grpc::Channel> channel, const Options& options);
  explicit DistributedRuntimeCoordinationServiceClient(
      std::shared_ptr<::grpc::Channel> channel)
      : DistributedRuntimeCoordinationServiceClient(channel, Options()) {}
  ~DistributedRuntimeCoordinationServiceClient() override;

  xla::Status Connect() override;
  xla::Status Shutdown() override;
  xla::Status EnumerateDevices(const LocalTopologyProto& local_topology,
                               GlobalTopologyProto* global_topology) override;
  xla::StatusOr<std::string> BlockingKeyValueGet(
      std::string key, absl::Duration timeout) override;
  xla::Status KeyValueSet(std::string key, std::string value) override;
  xla::Status WaitAtBarrier(std::string barrier_id,
                            absl::Duration timeout) override;
  xla::StatusOr<tensorflow::CoordinationServiceAgent*>
  GetCoordinationServiceAgent() override;

 private:
  std::unique_ptr<tensorflow::CoordinationServiceAgent> coord_agent_;
  tensorflow::CoordinationServiceConfig config_;
  absl::Duration min_connect_barrier_timeout_;
  int task_id_;
};

DistributedRuntimeClientImpl::DistributedRuntimeClientImpl(
    std::shared_ptr<::grpc::Channel> channel, const Options& options)
    : stub_(grpc::DistributedRuntimeService::NewStub(std::move(channel))),
      options_(options) {}

DistributedRuntimeClientImpl::~DistributedRuntimeClientImpl() {
  bool connected;
  {
    absl::MutexLock lock(&mu_);
    connected = (state_ == State::kConnected);
  }
  if (connected) {
    if (options_.shutdown_on_destruction) {
      Status status = Shutdown();
      if (!status.ok()) {
        LOG(WARNING) << "PJRT shutdown failed: " << status;
      }
    } else {
      if (!stop_heartbeats_.HasBeenNotified()) {
        stop_heartbeats_.Notify();
      }
    }
  }
}

/*static*/ absl::string_view DistributedRuntimeClientImpl::StateToString(
    State state) {
  switch (state) {
    case State::kNotConnected:
      return "kNotConnected";
    case State::kConnected:
      return "kConnected";
    case State::kShuttingDown:
      return "kShuttingDown";
    case State::kClosed:
      return "kClosed";
  }
}

xla::Status DistributedRuntimeClientImpl::Connect() {
  {
    absl::MutexLock lock(&mu_);
    if (state_ != State::kNotConnected) {
      return xla::FailedPrecondition("Connect() called when client in state %s",
                                     StateToString(state_));
    }
  }
  ConnectRequest request;
  request.set_protocol_version(DistributedRuntimeProtocolVersion());
  request.set_timeout_milliseconds(
      absl::ToInt64Milliseconds(options_.rpc_timeout) / 2);
  request.set_node_id(options_.node_id);
  VLOG(10) << "Connect: " << request.DebugString();
  ConnectResponse response;
  ::grpc::Status status;
  absl::Time deadline = absl::Now() + options_.init_timeout;
  int attempt = 0;
  std::default_random_engine generator;
  std::uniform_real_distribution<double> distribution(0.0, 1.0);
  do {
    ::grpc::ClientContext ctx;
    ctx.set_fail_fast(false);
    ctx.set_deadline(absl::ToChronoTime(absl::Now() + options_.rpc_timeout));
    request.set_client_id(tensorflow::random::New64());
    response.Clear();
    status = stub_->Connect(&ctx, request, &response);
    if (!status.ok()) {
      VLOG(1) << "Connect() failed with status: " << FromGrpcStatus(status);
      if (attempt % 10 == 0) {
        LOG(INFO) << "Connect() failed with status: " << FromGrpcStatus(status);
      }
      // Exponential backoff with jitter. Note we will retry for `init_timeout`
      // time in total; the `14` here corresponds to an ~16s maximum interval
      // between connection attempts.
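      // For example, attempt 3 sleeps for a uniform duration in [0, 8) ms,
      // and any attempt >= 14 sleeps in [0, 16384) ms, i.e. up to ~16s.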
      int backoff = 1 << std::min(14, attempt);
      absl::SleepFor(absl::Milliseconds(backoff * distribution(generator)));
    }
    ++attempt;
  } while (!status.ok() && absl::Now() < deadline);
  if (!status.ok()) {
    LOG(ERROR) << "Connect() failed after " << attempt << " retries in "
               << options_.init_timeout
               << "; most recent failure status: " << FromGrpcStatus(status);
    return tensorflow::errors::DeadlineExceeded(
        absl::StrFormat("Connect() timed out after %s with %d attempts. Most "
                        "recent failure was: %s",
                        absl::FormatDuration(options_.init_timeout), attempt,
                        FromGrpcStatus(status).ToString()));
  }
  VLOG(10) << "Connect() response: " << response.DebugString();
  {
    absl::MutexLock lock(&mu_);
    state_ = State::kConnected;
  }
  session_id_ = response.session_id();

  heartbeat_thread_.reset(options_.env->StartThread(
      tensorflow::ThreadOptions(), "pjrt_distributed_heartbeat",
      [this]() { HeartbeatLoop(); }));
  LOG(INFO) << "Connected to distributed JAX controller";
  return OkStatus();
}

xla::Status DistributedRuntimeClientImpl::EnumerateDevices(
    const LocalTopologyProto& local_topology,
    GlobalTopologyProto* global_topology) {
  {
    absl::MutexLock lock(&mu_);
    if (state_ != State::kConnected) {
      return xla::FailedPrecondition(
          "EnumerateDevices() called when client not connected.");
    }
  }
  ::grpc::ClientContext ctx;
  ctx.set_fail_fast(false);
  ctx.set_deadline(absl::ToChronoTime(absl::Now() + options_.rpc_timeout));
  EnumerateDevicesRequest request;
  request.set_session_id(session_id_);
  *request.mutable_local_topology() = local_topology;
  request.mutable_local_topology()->set_node_id(options_.node_id);

  VLOG(10) << "EnumerateDevices: " << request.DebugString();
  EnumerateDevicesResponse response;
  ::grpc::Status status = stub_->EnumerateDevices(&ctx, request, &response);
  if (!status.ok()) {
    return FromGrpcStatus(status);
  }
  VLOG(10) << "EnumerateDevices() response: " << response.DebugString();
  response.mutable_global_topology()->Swap(global_topology);
  return OkStatus();
}

xla::Status DistributedRuntimeClientImpl::Shutdown() {
  LOG(INFO) << "Waiting for all distributed JAX tasks to shut down.";
  ::grpc::ClientContext ctx;
  {
    absl::MutexLock lock(&mu_);
    if (state_ != State::kConnected) {
      return xla::FailedPrecondition(
          "Shutdown() called when client not connected.");
    }
    state_ = State::kShuttingDown;
  }
  ctx.set_fail_fast(false);
  ctx.set_deadline(absl::ToChronoTime(absl::Now() + options_.shutdown_timeout));
  ShutdownRequest request;
  request.set_session_id(session_id_);
  VLOG(10) << "Shutdown: " << request.DebugString();
  ShutdownResponse response;
  ::grpc::Status status = stub_->Shutdown(&ctx, request, &response);

  LOG(INFO) << "Distributed task shutdown result: " << FromGrpcStatus(status);
  if (!status.ok()) {
    return FromGrpcStatus(status);
  }
  if (!stop_heartbeats_.HasBeenNotified()) {
    stop_heartbeats_.Notify();
  }
  VLOG(10) << "Shutdown() response: " << response.DebugString();
  absl::MutexLock lock(&mu_);
  state_ = State::kClosed;
  return OkStatus();
}

xla::StatusOr<std::string> DistributedRuntimeClientImpl::BlockingKeyValueGet(
    std::string key, absl::Duration timeout) {
  {
    absl::MutexLock lock(&mu_);
    if (state_ != State::kConnected) {
      return xla::FailedPrecondition(
          "BlockingKeyValueGet() called when client not connected.");
    }
  }
  ::grpc::ClientContext ctx;
  ctx.set_fail_fast(false);
  ctx.set_deadline(absl::ToChronoTime(absl::Now() + timeout));
  KeyValueGetRequest request;
  request.set_session_id(session_id_);
  request.set_key(std::move(key));
  timeout = std::min(timeout, absl::Minutes(10));  // Avoid overflow
  request.set_timeout_milliseconds(absl::ToInt64Milliseconds(timeout));
  VLOG(10) << "BlockingKeyValueGet: " << request.DebugString();
  KeyValueGetResponse response;
  ::grpc::Status status = stub_->KeyValueGet(&ctx, request, &response);
  if (!status.ok()) {
    return FromGrpcStatus(status);
  }
  return response.value();
}

xla::Status DistributedRuntimeClientImpl::KeyValueSet(std::string key,
                                                      std::string value) {
  {
    absl::MutexLock lock(&mu_);
    if (state_ != State::kConnected) {
      return xla::FailedPrecondition(
          "KeyValueSet() called when client not connected.");
    }
  }
  ::grpc::ClientContext ctx;
  ctx.set_fail_fast(false);
  ctx.set_deadline(absl::ToChronoTime(absl::Now() + options_.rpc_timeout));
  KeyValueSetRequest request;
  request.set_session_id(session_id_);
  request.set_key(std::move(key));
  request.set_value(std::move(value));
  VLOG(10) << "KeyValueSet: " << request.DebugString();
  KeyValueSetResponse response;
  ::grpc::Status status = stub_->KeyValueSet(&ctx, request, &response);
  return FromGrpcStatus(status);
}

xla::Status DistributedRuntimeClientImpl::WaitAtBarrier(
    std::string barrier_id, absl::Duration timeout) {
  {
    absl::MutexLock lock(&mu_);
    if (state_ != State::kConnected) {
      return xla::FailedPrecondition(
          "WaitAtBarrier() called when client not connected.");
    }
  }
  ::grpc::ClientContext ctx;
  ctx.set_fail_fast(false);
  // Set timeout to be at least 5 seconds so that there is time for
  // service-side timeout logic to execute.
  ctx.set_deadline(
      absl::ToChronoTime(absl::Now() + std::max(timeout, absl::Seconds(5))));
  WaitAtBarrierRequest request;
  request.set_session_id(session_id_);
  request.set_barrier_id(std::move(barrier_id));
  request.set_node_id(options_.node_id);
  // TODO(yashkatariya,hanyuangtay): Change timeout_milliseconds to int64 in
  // protocol.proto so that we don't need a minimum timeout here.
  timeout = std::min(timeout, absl::Minutes(10));  // Avoid overflow
  request.set_timeout_milliseconds(absl::ToInt64Milliseconds(timeout));
  VLOG(10) << "WaitAtBarrier: " << request.DebugString();
  WaitAtBarrierResponse response;
  ::grpc::Status status = stub_->WaitAtBarrier(&ctx, request, &response);
  return FromGrpcStatus(status);
}

xla::StatusOr<tensorflow::CoordinationServiceAgent*>
DistributedRuntimeClientImpl::GetCoordinationServiceAgent() {
  return xla::Internal(
      "Invoking GetCoordinationServiceAgent() while coordination service is "
      "not enabled. Enable coordination service via "
      "--jax_coordination_service.");
}

void DistributedRuntimeClientImpl::HeartbeatLoop() {
  int num_missing_heartbeats = 0;
  while (true) {
    stop_heartbeats_.WaitForNotificationWithTimeout(
        options_.heartbeat_interval);
    if (stop_heartbeats_.HasBeenNotified()) {
      return;
    }

    ::grpc::ClientContext ctx;
    ctx.set_fail_fast(false);
    ctx.set_deadline(
        absl::ToChronoTime(absl::Now() + options_.heartbeat_interval));
    HeartbeatRequest request;
    request.set_session_id(session_id_);
    request.set_node_id(options_.node_id);
    VLOG(10) << "Heartbeat: " << request.DebugString();
    HeartbeatResponse response;
    ::grpc::Status status = stub_->Heartbeat(&ctx, request, &response);
    if (status.ok()) {
      VLOG(10) << "Heartbeat ok";
      num_missing_heartbeats = 0;
    } else {
      ++num_missing_heartbeats;
      VLOG(10) << "Heartbeat error, "
               << options_.max_missing_heartbeats - num_missing_heartbeats
               << " tries left: " << status.error_message();
      bool is_transient_error =
          (status.error_code() == ::grpc::StatusCode::DEADLINE_EXCEEDED ||
           status.error_code() == ::grpc::StatusCode::UNAVAILABLE);
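      // Timeouts and unavailability are tolerated as transient for up to
      // `max_missing_heartbeats` consecutive misses; any other error, or
      // exhausting that allowance, reports a missed heartbeat below.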
      if (!stop_heartbeats_.HasBeenNotified() &&
          (!is_transient_error ||
           num_missing_heartbeats >= options_.max_missing_heartbeats)) {
        // If we are shutting down, missed heartbeats are benign: they may
        // simply mean that the server has shut down already before it saw
        // the heartbeat request.
        absl::MutexLock lock(&mu_);
        if (state_ != State::kShuttingDown) {
          options_.missed_heartbeat_callback(FromGrpcStatus(status),
                                             !is_transient_error);
        }
        return;
      }
    }
  }
}

DistributedRuntimeCoordinationServiceClient::
    DistributedRuntimeCoordinationServiceClient(
        std::shared_ptr<::grpc::Channel> channel, const Options& options) {
  // Convert options to coordination config.
  tensorflow::CoordinationServiceConfig config;
  config.set_service_type("standalone");
  config.set_service_leader("/job:jax_worker/task:0");
  config.set_cluster_register_timeout_in_ms(
      absl::ToInt64Milliseconds(options.init_timeout));
  min_connect_barrier_timeout_ = options.rpc_timeout;
  config.set_heartbeat_timeout_in_ms(absl::ToInt64Milliseconds(
      options.heartbeat_interval * options.max_missing_heartbeats));
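  // For example, a 10-second heartbeat interval with 10 allowed misses gives
  // the coordination service a 100-second window before it treats this task's
  // heartbeat as expired.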
  config.set_shutdown_barrier_timeout_in_ms(
      absl::ToInt64Milliseconds(options.shutdown_timeout));
  config.set_agent_destruction_without_shutdown(
      !options.shutdown_on_destruction);
  auto error_fn =
      [timeout_fn = options.missed_heartbeat_callback](const Status& status) {
        LOG(ERROR) << "Coordination service agent in error status: " << status;
        timeout_fn(status, /*coordinator_reported_failure=*/true);
      };

  std::unique_ptr<tensorflow::CoordinationClient> leader_client;
  leader_client.reset(tensorflow::NewGrpcCoordinationClient(channel));
  coord_agent_ = tensorflow::CreateCoordinationServiceAgent();
  const Status status =
      coord_agent_->Initialize(options.env, "jax_worker", options.node_id,
                               config, std::move(leader_client), error_fn);
  if (!status.ok()) {
    LOG(ERROR) << "Coordination agent failed to initialize: " << status;
  }
  task_id_ = options.node_id;
  config_ = config;
}

DistributedRuntimeCoordinationServiceClient::
    ~DistributedRuntimeCoordinationServiceClient() {}

xla::Status DistributedRuntimeCoordinationServiceClient::Connect() {
  Status s = tensorflow::errors::Unknown("Connection not attempted yet.");
  absl::Duration timeout =
      absl::Milliseconds(config_.cluster_register_timeout_in_ms());
  absl::Time deadline = absl::Now() + timeout;
  int attempt = 0;
  std::default_random_engine generator;
  std::uniform_real_distribution<double> distribution(0.0, 1.0);

  do {
    ++attempt;
    s = coord_agent_->Connect();
    if (s.ok()) {
      absl::Duration barrier_timeout = deadline - absl::Now();
      // Note: `init_timeout` in client options may be set to 0 so that the
      // client only attempts to connect once. In that case, we provide some
      // buffer time to wait for all tasks.
      barrier_timeout = std::max(barrier_timeout, min_connect_barrier_timeout_);
      s = coord_agent_->WaitAtBarrier("PjRT_Client_Connect", barrier_timeout,
                                      /*tasks=*/{});
    }
    // Exponential backoff with jitter. Note we will retry for `init_timeout`
    // time in total; the `14` here corresponds to an ~16s maximum interval
    // between connection attempts.
    int backoff = 1 << std::min(14, attempt);
    absl::SleepFor(absl::Milliseconds(backoff * distribution(generator)));
  } while (!s.ok() && absl::Now() < deadline &&
           // Retries are only made for RPC errors. If a valid service error is
           // returned, fail immediately.
           s.GetPayload(tensorflow::CoordinationErrorPayloadKey()) ==
               std::nullopt);
  if (s.ok()) {
    LOG(INFO) << "Connected to distributed JAX controller";
  } else {
    LOG(INFO) << "Failed to connect to distributed JAX controller: " << s;
  }
  return s;
}

xla::Status DistributedRuntimeCoordinationServiceClient::Shutdown() {
  LOG(INFO) << "Distributed task shutdown initiated.";
  Status s = coord_agent_->Shutdown();
  LOG(INFO) << "Distributed task shutdown result: " << s;
  return s;
}

xla::Status DistributedRuntimeCoordinationServiceClient::EnumerateDevices(
    const LocalTopologyProto& local_topology,
    GlobalTopologyProto* global_topology) {
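  // Each task contributes its local topology; WaitForAllTasks blocks until
  // every task in the cluster has done so, after which the aggregated global
  // topology is read back from the coordination service.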
  tensorflow::CoordinationServiceDeviceInfo devices;
  LocalTopologyProto* device =
      devices.mutable_xla()->mutable_devices()->add_nodes();
  *device = local_topology;
  device->set_node_id(task_id_);
  Status s = coord_agent_->WaitForAllTasks(devices);
  if (!s.ok()) return s;
  *global_topology = coord_agent_->GetClusterDeviceInfo().xla().devices();
  return OkStatus();
}

xla::StatusOr<std::string>
DistributedRuntimeCoordinationServiceClient::BlockingKeyValueGet(
    std::string key, absl::Duration timeout) {
  return coord_agent_->GetKeyValue(key, timeout);
}

xla::Status DistributedRuntimeCoordinationServiceClient::KeyValueSet(
    std::string key, std::string value) {
  return coord_agent_->InsertKeyValue(key, value);
}

xla::Status DistributedRuntimeCoordinationServiceClient::WaitAtBarrier(
    std::string barrier_id, absl::Duration timeout) {
  return coord_agent_->WaitAtBarrier(barrier_id, timeout, /*tasks=*/{});
}

xla::StatusOr<tensorflow::CoordinationServiceAgent*>
DistributedRuntimeCoordinationServiceClient::GetCoordinationServiceAgent() {
  return coord_agent_.get();
}

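// A minimal usage sketch (the target address, node id, and error-handling
// macro below are illustrative, not part of this file):
//
//   auto channel = ::grpc::CreateChannel(
//       "controller:8476", ::grpc::InsecureChannelCredentials());
//   DistributedRuntimeClient::Options options;
//   options.node_id = 0;
//   auto client = GetDistributedRuntimeClient(
//       channel, options, /*use_coordination_service=*/true);
//   TF_CHECK_OK(client->Connect());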
std::unique_ptr<DistributedRuntimeClient> GetDistributedRuntimeClient(
    std::shared_ptr<::grpc::Channel> channel,
    const DistributedRuntimeClient::Options& options,
    bool use_coordination_service) {
  if (use_coordination_service) {
    return std::make_unique<xla::DistributedRuntimeCoordinationServiceClient>(
        channel, options);
  }
  return std::make_unique<xla::DistributedRuntimeClientImpl>(channel, options);
}
}  // namespace xla