1 /* Copyright 2020 Google LLC
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/pjrt/distributed/client.h"
17
18 #include <algorithm>
19 #include <chrono> // NOLINT
20 #include <memory>
21 #include <optional>
22 #include <random>
23 #include <string>
24 #include <utility>
25
26 #include "absl/synchronization/mutex.h"
27 #include "absl/synchronization/notification.h"
28 #include "absl/time/time.h"
29 #include "grpcpp/channel.h"
30 #include "tensorflow/compiler/xla/pjrt/distributed/protocol.h"
31 #include "tensorflow/compiler/xla/pjrt/distributed/util.h"
32 #include "tensorflow/compiler/xla/util.h"
33 #include "tensorflow/core/distributed_runtime/coordination/coordination_client.h"
34 #include "tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h"
35 #include "tensorflow/core/distributed_runtime/coordination/coordination_service_error_util.h"
36 #include "tensorflow/core/distributed_runtime/rpc/coordination/grpc_coordination_client.h"
37 #include "tensorflow/core/platform/errors.h"
38 #include "tensorflow/core/platform/random.h"
39 #include "tensorflow/core/protobuf/coordination_config.pb.h"
40 #include "tensorflow/core/protobuf/coordination_service.pb.h"
41
42 namespace xla {
43 class DistributedRuntimeClientImpl : public DistributedRuntimeClient {
44 public:
45 DistributedRuntimeClientImpl(std::shared_ptr<::grpc::Channel> channel,
46 const Options& options);
DistributedRuntimeClientImpl(std::shared_ptr<::grpc::Channel> channel)47 explicit DistributedRuntimeClientImpl(
48 std::shared_ptr<::grpc::Channel> channel)
49 : DistributedRuntimeClientImpl(channel, Options()) {}
50 ~DistributedRuntimeClientImpl() override;
51
52 xla::Status Connect() override;
53 xla::Status Shutdown() override;
54 xla::Status EnumerateDevices(const LocalTopologyProto& local_topology,
55 GlobalTopologyProto* global_topology) override;
56 xla::StatusOr<std::string> BlockingKeyValueGet(
57 std::string key, absl::Duration timeout) override;
58 xla::Status KeyValueSet(std::string key, std::string value) override;
59 xla::Status WaitAtBarrier(std::string barrier_id,
60 absl::Duration timeout) override;
61 xla::StatusOr<tensorflow::CoordinationServiceAgent*>
62 GetCoordinationServiceAgent() override;
63
64 private:
65 // Entry point for the heartbeat thread.
66 void HeartbeatLoop();
67
68 const std::unique_ptr<grpc::DistributedRuntimeService::Stub> stub_;
69 const DistributedRuntimeClient::Options options_;
70
71 // Possible states of the client.
72 // The only legal transitions are downwards in the order below. i.e., there is
73 // no way to reopen a closed client.
74 enum class State {
75 // The client has not yet connected to the server, i.e., had a Connect()
76 // RPC succeed.
77 kNotConnected,
78
79 // The client is connected to the server and as far as we are aware the
80 // connection is healthy.
81 kConnected,
82
83 // The client is in the process of shutting down, i.e., Shutdown() has been
84 // called.
85 kShuttingDown,
86
87 // The client has shut down its server connection, either due to an error
88 // or due to an explicit shutdown.
89 kClosed,
90 };
91
92 static absl::string_view StateToString(State state);
93
94 // state_ is protected by a mutex because the heartbeat thread needs to look
95 // at it.
96 absl::Mutex mu_;
97 State state_ ABSL_GUARDED_BY(mu_) = State::kNotConnected;
98
99 // A unique session ID, assigned by the server during Connect().
100 uint64_t session_id_;
101
102 // Notification that tells the heartbeat thread to stop running.
103 absl::Notification stop_heartbeats_;
104
105 // Thread responsible for performing heartbeats.
106 std::unique_ptr<tensorflow::Thread> heartbeat_thread_;
107 };
108
109 class DistributedRuntimeCoordinationServiceClient
110 : public DistributedRuntimeClient {
111 public:
112 DistributedRuntimeCoordinationServiceClient(
113 std::shared_ptr<::grpc::Channel> channel, const Options& options);
DistributedRuntimeCoordinationServiceClient(std::shared_ptr<::grpc::Channel> channel)114 explicit DistributedRuntimeCoordinationServiceClient(
115 std::shared_ptr<::grpc::Channel> channel)
116 : DistributedRuntimeCoordinationServiceClient(channel, Options()) {}
117 ~DistributedRuntimeCoordinationServiceClient() override;
118
119 xla::Status Connect() override;
120 xla::Status Shutdown() override;
121 xla::Status EnumerateDevices(const LocalTopologyProto& local_topology,
122 GlobalTopologyProto* global_topology) override;
123 xla::StatusOr<std::string> BlockingKeyValueGet(
124 std::string key, absl::Duration timeout) override;
125 xla::Status KeyValueSet(std::string key, std::string value) override;
126 xla::Status WaitAtBarrier(std::string barrier_id,
127 absl::Duration timeout) override;
128 xla::StatusOr<tensorflow::CoordinationServiceAgent*>
129 GetCoordinationServiceAgent() override;
130
131 private:
132 std::unique_ptr<tensorflow::CoordinationServiceAgent> coord_agent_;
133 tensorflow::CoordinationServiceConfig config_;
134 absl::Duration min_connect_barrier_timeout_;
135 int task_id_;
136 };
137
DistributedRuntimeClientImpl(std::shared_ptr<::grpc::Channel> channel,const Options & options)138 DistributedRuntimeClientImpl::DistributedRuntimeClientImpl(
139 std::shared_ptr<::grpc::Channel> channel, const Options& options)
140 : stub_(grpc::DistributedRuntimeService::NewStub(std::move(channel))),
141 options_(options) {}
142
~DistributedRuntimeClientImpl()143 DistributedRuntimeClientImpl::~DistributedRuntimeClientImpl() {
144 bool connected;
145 {
146 absl::MutexLock lock(&mu_);
147 connected = (state_ == State::kConnected);
148 }
149 if (connected) {
150 if (options_.shutdown_on_destruction) {
151 Status status = Shutdown();
152 if (!status.ok()) {
153 LOG(WARNING) << "PJRT shutdown failed: " << status;
154 }
155 } else {
156 if (!stop_heartbeats_.HasBeenNotified()) {
157 stop_heartbeats_.Notify();
158 }
159 }
160 }
161 }
162
StateToString(State state)163 /*static*/ absl::string_view DistributedRuntimeClientImpl::StateToString(
164 State state) {
165 switch (state) {
166 case State::kNotConnected:
167 return "kNotConnected";
168 case State::kConnected:
169 return "kConnected";
170 case State::kShuttingDown:
171 return "kShuttingDown";
172 case State::kClosed:
173 return "kClosed";
174 }
175 }
176
Connect()177 xla::Status DistributedRuntimeClientImpl::Connect() {
178 {
179 absl::MutexLock lock(&mu_);
180 if (state_ != State::kNotConnected) {
181 return xla::FailedPrecondition("Connect() called when client in state %s",
182 StateToString(state_));
183 }
184 }
185 ConnectRequest request;
186 request.set_protocol_version(DistributedRuntimeProtocolVersion());
187 request.set_timeout_milliseconds(
188 absl::ToInt64Milliseconds(options_.rpc_timeout) / 2);
189 request.set_node_id(options_.node_id);
190 VLOG(10) << "Connect: " << request.DebugString();
191 ConnectResponse response;
192 ::grpc::Status status;
193 absl::Time deadline = absl::Now() + options_.init_timeout;
194 int attempt = 0;
195 std::default_random_engine generator;
196 std::uniform_real_distribution<double> distribution(0.0, 1.0);
197 do {
198 ::grpc::ClientContext ctx;
199 ctx.set_fail_fast(false);
200 ctx.set_deadline(absl::ToChronoTime(absl::Now() + options_.rpc_timeout));
201 request.set_client_id(tensorflow::random::New64());
202 response.Clear();
203 status = stub_->Connect(&ctx, request, &response);
204 if (!status.ok()) {
205 VLOG(1) << "Connect failed() with status: " << FromGrpcStatus(status);
206 if (attempt % 10 == 0) {
207 LOG(INFO) << "Connect failed() with status: " << FromGrpcStatus(status);
208 }
209 // Exponential backoff with jitter. Note we will retry for `init_timeout`
210 // time in total; the `14` here corresponds to an ~16s maximum interval
211 // between connection attempts.
212 int backoff = 1 << std::min(14, attempt);
213 absl::SleepFor(absl::Milliseconds(backoff * distribution(generator)));
214 }
215 ++attempt;
216 } while (!status.ok() && absl::Now() < deadline);
217 if (!status.ok()) {
218 LOG(ERROR) << "Connect() failed after " << attempt << " retries in "
219 << options_.init_timeout
220 << "; most recent failure status: " << FromGrpcStatus(status);
221 return tensorflow::errors::DeadlineExceeded(
222 absl::StrFormat("Connect() timed out after %s with %d attempts. Most "
223 "recent failure was: %s",
224 absl::FormatDuration(options_.init_timeout), attempt,
225 FromGrpcStatus(status).ToString()));
226 }
227 VLOG(10) << "Connect() response: " << response.DebugString();
228 {
229 absl::MutexLock lock(&mu_);
230 state_ = State::kConnected;
231 }
232 session_id_ = response.session_id();
233
234 heartbeat_thread_.reset(options_.env->StartThread(
235 tensorflow::ThreadOptions(), "pjrt_distributed_heartbeat",
236 [this]() { HeartbeatLoop(); }));
237 LOG(INFO) << "Connected to distributed JAX controller";
238 return OkStatus();
239 }
240
EnumerateDevices(const LocalTopologyProto & local_topology,GlobalTopologyProto * global_topology)241 xla::Status DistributedRuntimeClientImpl::EnumerateDevices(
242 const LocalTopologyProto& local_topology,
243 GlobalTopologyProto* global_topology) {
244 {
245 absl::MutexLock lock(&mu_);
246 if (state_ != State::kConnected) {
247 return xla::FailedPrecondition(
248 "EnumerateDevices() called when client not connected.");
249 }
250 }
251 ::grpc::ClientContext ctx;
252 ctx.set_fail_fast(false);
253 ctx.set_deadline(absl::ToChronoTime(absl::Now() + options_.rpc_timeout));
254 EnumerateDevicesRequest request;
255 request.set_session_id(session_id_);
256 *request.mutable_local_topology() = local_topology;
257 request.mutable_local_topology()->set_node_id(options_.node_id);
258
259 VLOG(10) << "EnumerateDevices: " << request.DebugString();
260 EnumerateDevicesResponse response;
261 ::grpc::Status status = stub_->EnumerateDevices(&ctx, request, &response);
262 if (!status.ok()) {
263 return FromGrpcStatus(status);
264 }
265 VLOG(10) << "EnumerateDevices() response: " << response.DebugString();
266 response.mutable_global_topology()->Swap(global_topology);
267 return OkStatus();
268 }
269
Shutdown()270 xla::Status DistributedRuntimeClientImpl::Shutdown() {
271 LOG(INFO) << "Waiting for all distributed JAX tasks to shut down.";
272 ::grpc::ClientContext ctx;
273 {
274 absl::MutexLock lock(&mu_);
275 if (state_ != State::kConnected) {
276 return xla::FailedPrecondition(
277 "Shutdown() called when client not connected.");
278 }
279 state_ = State::kShuttingDown;
280 }
281 ctx.set_fail_fast(false);
282 ctx.set_deadline(absl::ToChronoTime(absl::Now() + options_.shutdown_timeout));
283 ShutdownRequest request;
284 request.set_session_id(session_id_);
285 VLOG(10) << "Shutdown: " << request.DebugString();
286 ShutdownResponse response;
287 ::grpc::Status status = stub_->Shutdown(&ctx, request, &response);
288
289 LOG(INFO) << "Distributed task shutdown result: " << FromGrpcStatus(status);
290 if (!status.ok()) {
291 return FromGrpcStatus(status);
292 }
293 if (!stop_heartbeats_.HasBeenNotified()) {
294 stop_heartbeats_.Notify();
295 }
296 VLOG(10) << "Shutdown() response: " << response.DebugString();
297 absl::MutexLock lock(&mu_);
298 state_ = State::kClosed;
299 return OkStatus();
300 }
301
BlockingKeyValueGet(std::string key,absl::Duration timeout)302 xla::StatusOr<std::string> DistributedRuntimeClientImpl::BlockingKeyValueGet(
303 std::string key, absl::Duration timeout) {
304 {
305 absl::MutexLock lock(&mu_);
306 if (state_ != State::kConnected) {
307 return xla::FailedPrecondition(
308 "BlockingKeyValueGet() called when client not connected.");
309 }
310 }
311 ::grpc::ClientContext ctx;
312 ctx.set_fail_fast(false);
313 ctx.set_deadline(absl::ToChronoTime(absl::Now() + timeout));
314 KeyValueGetRequest request;
315 request.set_session_id(session_id_);
316 request.set_key(std::move(key));
317 timeout = std::min(timeout, absl::Minutes(10)); // Avoid overflow
318 request.set_timeout_milliseconds(absl::ToInt64Milliseconds(timeout));
319 VLOG(10) << "BlockingKeyValueGet: " << request.DebugString();
320 KeyValueGetResponse response;
321 ::grpc::Status status = stub_->KeyValueGet(&ctx, request, &response);
322 if (!status.ok()) {
323 return FromGrpcStatus(status);
324 }
325 return response.value();
326 }
327
KeyValueSet(std::string key,std::string value)328 xla::Status DistributedRuntimeClientImpl::KeyValueSet(std::string key,
329 std::string value) {
330 {
331 absl::MutexLock lock(&mu_);
332 if (state_ != State::kConnected) {
333 return xla::FailedPrecondition(
334 "KeyValueSet() called when client not connected.");
335 }
336 }
337 ::grpc::ClientContext ctx;
338 ctx.set_fail_fast(false);
339 ctx.set_deadline(absl::ToChronoTime(absl::Now() + options_.rpc_timeout));
340 KeyValueSetRequest request;
341 request.set_session_id(session_id_);
342 request.set_key(std::move(key));
343 request.set_value(std::move(value));
344 VLOG(10) << "KeyValueSet: " << request.DebugString();
345 KeyValueSetResponse response;
346 ::grpc::Status status = stub_->KeyValueSet(&ctx, request, &response);
347 return FromGrpcStatus(status);
348 }
349
WaitAtBarrier(std::string barrier_id,absl::Duration timeout)350 xla::Status DistributedRuntimeClientImpl::WaitAtBarrier(
351 std::string barrier_id, absl::Duration timeout) {
352 {
353 absl::MutexLock lock(&mu_);
354 if (state_ != State::kConnected) {
355 return xla::FailedPrecondition(
356 "WaitAtBarrier() called when client not connected.");
357 }
358 }
359 ::grpc::ClientContext ctx;
360 ctx.set_fail_fast(false);
361 // Set timeout to be at least 5 seconds so that there is time for service-side
362 // timeout logic to execute.
363 ctx.set_deadline(
364 absl::ToChronoTime(absl::Now() + std::max(timeout, absl::Seconds(5))));
365 WaitAtBarrierRequest request;
366 request.set_session_id(session_id_);
367 request.set_barrier_id(std::move(barrier_id));
368 request.set_node_id(options_.node_id);
369 // TODO(yashkatariya,hanyuangtay): Change timeout_milliseconds to int64 in
370 // protocol.proto so that we don't need a minimum timeout here.
371 timeout = std::min(timeout, absl::Minutes(10)); // Avoid overflow
372 request.set_timeout_milliseconds(absl::ToInt64Milliseconds(timeout));
373 VLOG(10) << "WaitAtBarrier: " << request.DebugString();
374 WaitAtBarrierResponse response;
375 ::grpc::Status status = stub_->WaitAtBarrier(&ctx, request, &response);
376 return FromGrpcStatus(status);
377 }
378
379 xla::StatusOr<tensorflow::CoordinationServiceAgent*>
GetCoordinationServiceAgent()380 DistributedRuntimeClientImpl::GetCoordinationServiceAgent() {
381 return xla::Internal(
382 "Invoking GetCoordinationServiceAgent() while coordination service is "
383 "not enabled. Enable coordination service via "
384 "--jax_coordination_service.");
385 }
386
HeartbeatLoop()387 void DistributedRuntimeClientImpl::HeartbeatLoop() {
388 int num_missing_heartbeats = 0;
389 while (true) {
390 stop_heartbeats_.WaitForNotificationWithTimeout(
391 options_.heartbeat_interval);
392 if (stop_heartbeats_.HasBeenNotified()) {
393 return;
394 }
395
396 ::grpc::ClientContext ctx;
397 ctx.set_fail_fast(false);
398 ctx.set_deadline(
399 absl::ToChronoTime(absl::Now() + options_.heartbeat_interval));
400 HeartbeatRequest request;
401 request.set_session_id(session_id_);
402 request.set_node_id(options_.node_id);
403 VLOG(10) << "Heartbeat: " << request.DebugString();
404 HeartbeatResponse response;
405 ::grpc::Status status = stub_->Heartbeat(&ctx, request, &response);
406 if (status.ok()) {
407 VLOG(10) << "Heartbeat ok";
408 num_missing_heartbeats = 0;
409 } else {
410 ++num_missing_heartbeats;
411 VLOG(10) << "Heartbeat error, "
412 << options_.max_missing_heartbeats - num_missing_heartbeats
413 << " tries left: " << status.error_message();
414 bool is_transient_error =
415 (status.error_code() == ::grpc::StatusCode::DEADLINE_EXCEEDED ||
416 status.error_code() == ::grpc::StatusCode::UNAVAILABLE);
417 if (!stop_heartbeats_.HasBeenNotified() &&
418 (!is_transient_error ||
419 num_missing_heartbeats >= options_.max_missing_heartbeats)) {
420 // If we are shutting down, missed heartbeats are benign: they may
421 // simply mean that the server has shut down already before it saw
422 // the heartbeat request.
423 absl::MutexLock lock(&mu_);
424 if (state_ != State::kShuttingDown) {
425 options_.missed_heartbeat_callback(FromGrpcStatus(status),
426 !is_transient_error);
427 }
428 return;
429 }
430 }
431 }
432 }
433
434 DistributedRuntimeCoordinationServiceClient::
DistributedRuntimeCoordinationServiceClient(std::shared_ptr<::grpc::Channel> channel,const Options & options)435 DistributedRuntimeCoordinationServiceClient(
436 std::shared_ptr<::grpc::Channel> channel, const Options& options) {
437 // Convert options to coordination config.
438 tensorflow::CoordinationServiceConfig config;
439 config.set_service_type("standalone");
440 config.set_service_leader("/job:jax_worker/task:0");
441 config.set_cluster_register_timeout_in_ms(
442 absl::ToInt64Milliseconds(options.init_timeout));
443 min_connect_barrier_timeout_ = options.rpc_timeout;
444 config.set_heartbeat_timeout_in_ms(absl::ToInt64Milliseconds(
445 options.heartbeat_interval * options.max_missing_heartbeats));
446 config.set_shutdown_barrier_timeout_in_ms(
447 absl::ToInt64Milliseconds(options.shutdown_timeout));
448 config.set_agent_destruction_without_shutdown(
449 !options.shutdown_on_destruction);
450 auto error_fn =
451 [timeout_fn = options.missed_heartbeat_callback](const Status& status) {
452 LOG(ERROR) << "Coordination service agent in error status: " << status;
453 timeout_fn(status, /*coordinator_reported_failure=*/true);
454 };
455
456 std::unique_ptr<tensorflow::CoordinationClient> leader_client;
457 leader_client.reset(tensorflow::NewGrpcCoordinationClient(channel));
458 coord_agent_ = tensorflow::CreateCoordinationServiceAgent();
459 const Status status =
460 coord_agent_->Initialize(options.env, "jax_worker", options.node_id,
461 config, std::move(leader_client), error_fn);
462 if (!status.ok()) {
463 LOG(ERROR) << "Coordination agent failed to initialize: " << status;
464 }
465 task_id_ = options.node_id;
466 config_ = config;
467 }
468
469 DistributedRuntimeCoordinationServiceClient::
~DistributedRuntimeCoordinationServiceClient()470 ~DistributedRuntimeCoordinationServiceClient() {}
471
Connect()472 xla::Status DistributedRuntimeCoordinationServiceClient::Connect() {
473 Status s = tensorflow::errors::Unknown("Connection not attempted yet.");
474 absl::Duration timeout =
475 absl::Milliseconds(config_.cluster_register_timeout_in_ms());
476 absl::Time deadline = absl::Now() + timeout;
477 int attempt = 0;
478 std::default_random_engine generator;
479 std::uniform_real_distribution<double> distribution(0.0, 1.0);
480
481 do {
482 ++attempt;
483 s = coord_agent_->Connect();
484 if (s.ok()) {
485 absl::Duration barrier_timeout = deadline - absl::Now();
486 // Note: `init_timeout` in client options may be set to 0 so that the
487 // client only attempts to connect once. In that case, we provide some
488 // buffer time to wait for all tasks.
489 barrier_timeout = std::max(barrier_timeout, min_connect_barrier_timeout_);
490 s = coord_agent_->WaitAtBarrier("PjRT_Client_Connect", barrier_timeout,
491 /*tasks=*/{});
492 }
493 // Exponential backoff with jitter. Note we will retry for `init_timeout`
494 // time in total; the `14` here corresponds to an ~16s maximum interval
495 // between connection attempts.
496
497 int backoff = 1 << std::min(14, attempt);
498 absl::SleepFor(absl::Milliseconds(backoff * distribution(generator)));
499 } while (!s.ok() && absl::Now() < deadline &&
500 // Retries are only made for RPC errors. If a valid service error is
501 // returned, fail immediately.
502 s.GetPayload(tensorflow::CoordinationErrorPayloadKey()) ==
503 std::nullopt);
504 if (s.ok()) {
505 LOG(INFO) << "Connected to distributed JAX controller";
506 } else {
507 LOG(INFO) << "Failed to connect to distributed JAX controller: " << s;
508 }
509 return s;
510 }
511
Shutdown()512 xla::Status DistributedRuntimeCoordinationServiceClient::Shutdown() {
513 LOG(INFO) << "Distributed task shutdown initiated.";
514 Status s = coord_agent_->Shutdown();
515 LOG(INFO) << "Distributed task shutdown result: " << s;
516 return s;
517 }
518
EnumerateDevices(const LocalTopologyProto & local_topology,GlobalTopologyProto * global_topology)519 xla::Status DistributedRuntimeCoordinationServiceClient::EnumerateDevices(
520 const LocalTopologyProto& local_topology,
521 GlobalTopologyProto* global_topology) {
522 tensorflow::CoordinationServiceDeviceInfo devices;
523 LocalTopologyProto* device =
524 devices.mutable_xla()->mutable_devices()->add_nodes();
525 *device = local_topology;
526 device->set_node_id(task_id_);
527 Status s = coord_agent_->WaitForAllTasks(devices);
528 if (!s.ok()) return s;
529 *global_topology = coord_agent_->GetClusterDeviceInfo().xla().devices();
530 return OkStatus();
531 }
532
533 xla::StatusOr<std::string>
BlockingKeyValueGet(std::string key,absl::Duration timeout)534 DistributedRuntimeCoordinationServiceClient::BlockingKeyValueGet(
535 std::string key, absl::Duration timeout) {
536 return coord_agent_->GetKeyValue(key, timeout);
537 }
538
KeyValueSet(std::string key,std::string value)539 xla::Status DistributedRuntimeCoordinationServiceClient::KeyValueSet(
540 std::string key, std::string value) {
541 return coord_agent_->InsertKeyValue(key, value);
542 }
543
WaitAtBarrier(std::string barrier_id,absl::Duration timeout)544 xla::Status DistributedRuntimeCoordinationServiceClient::WaitAtBarrier(
545 std::string barrier_id, absl::Duration timeout) {
546 return coord_agent_->WaitAtBarrier(barrier_id, timeout, /*tasks=*/{});
547 }
548
549 xla::StatusOr<tensorflow::CoordinationServiceAgent*>
GetCoordinationServiceAgent()550 DistributedRuntimeCoordinationServiceClient::GetCoordinationServiceAgent() {
551 return coord_agent_.get();
552 }
553
GetDistributedRuntimeClient(std::shared_ptr<::grpc::Channel> channel,const DistributedRuntimeClient::Options & options,bool use_coordination_service)554 std::unique_ptr<DistributedRuntimeClient> GetDistributedRuntimeClient(
555 std::shared_ptr<::grpc::Channel> channel,
556 const DistributedRuntimeClient::Options& options,
557 bool use_coordination_service) {
558 if (use_coordination_service) {
559 return std::make_unique<xla::DistributedRuntimeCoordinationServiceClient>(
560 channel, options);
561 }
562 return std::make_unique<xla::DistributedRuntimeClientImpl>(channel, options);
563 }
564 } // namespace xla
565