/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/pjrt/distributed/service.h"

#include <memory>
#include <string>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "grpcpp/server_builder.h"
#include "tensorflow/compiler/xla/pjrt/distributed/protocol.h"
#include "tensorflow/compiler/xla/pjrt/distributed/util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/distributed_runtime/coordination/coordination_service.h"
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/platform/threadpool.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/coordination_config.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

namespace {
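// Sentinel stored in barrier_id_to_num_nodes_ once a barrier wait has timed
// out, so that later arrivals at that barrier fail fast instead of waiting
// out the full timeout themselves.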
constexpr int kBarrierTimedOut = -1000;

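// Builds a standalone coordination service from the distributed runtime
// service options. The ServerDef is only a carrier for the coordination
// config (see the TODO below); the task addresses in it are never used, so
// they are filled with placeholders.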
std::unique_ptr<tensorflow::CoordinationServiceInterface>
EnableCoordinationService(
    const xla::DistributedRuntimeServiceImpl::Options& options) {
  const std::string job_name = "jax_worker";
  // TODO(b/205307544): Remove TensorFlow server def references once it is no
  // longer needed.
  tensorflow::ServerDef server_def;
  server_def.set_protocol("grpc");
  server_def.set_job_name(job_name);
  server_def.set_task_index(0);
  auto job_def = server_def.mutable_cluster()->add_job();
  job_def->set_name(job_name);
  for (int32_t i = 0; i < options.num_nodes; ++i) {
    job_def->mutable_tasks()->insert({i, "UNKNOWN_SERVER_ADDRESS"});
  }

  // Convert options to coordination service config.
  auto coordination_config = server_def.mutable_default_session_config()
                                 ->mutable_experimental()
                                 ->mutable_coordination_config();
  coordination_config->set_service_type("standalone");
  coordination_config->set_service_leader(
      absl::StrCat("/job:", job_name, "/task:0"));
  coordination_config->set_cluster_register_timeout_in_ms(
      absl::ToInt64Milliseconds(options.enumerate_devices_timeout));
  coordination_config->set_heartbeat_timeout_in_ms(absl::ToInt64Milliseconds(
      options.heartbeat_interval * options.max_missing_heartbeats));
  coordination_config->set_shutdown_barrier_timeout_in_ms(
      absl::ToInt64Milliseconds(options.shutdown_timeout));
  return tensorflow::CoordinationServiceInterface::EnableCoordinationService(
      "standalone", options.env, server_def, /*cache=*/nullptr);
}
}  // namespace

namespace xla {

DistributedRuntimeServiceImpl::DistributedRuntimeServiceImpl(
    const Options& options)
    : options_(options), session_id_(tensorflow::random::New64()) {
  nodes_.resize(options.num_nodes);
  local_topologies_.resize(options.num_nodes);
}

DistributedRuntimeServiceImpl::~DistributedRuntimeServiceImpl() {
  {
    absl::MutexLock lock(&mu_);
    state_ = State::kClosed;
    service_status_ =
        tensorflow::errors::FailedPrecondition("Service shutting down.");
    if (!stop_heartbeat_thread_.HasBeenNotified()) {
      stop_heartbeat_thread_.Notify();
    }
  }
}

// Steals the contents of `local_topologies`.
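// Global device IDs are assigned consecutively, in node order.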
void BuildGlobalTopology(absl::Span<LocalTopologyProto> local_topologies,
                         GlobalTopologyProto* global_topology) {
  int next_global_device_id = 0;
  for (LocalTopologyProto& local : local_topologies) {
    for (DeviceProto& device : *local.mutable_devices()) {
      device.set_global_device_id(next_global_device_id++);
    }
    global_topology->add_nodes()->Swap(&local);
  }
}

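// Checks that a client-supplied node ID lies in [0, num_nodes).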
xla::Status DistributedRuntimeServiceImpl::ValidateNodeId(int node_id) {
  if (node_id < 0) {
    return xla::InvalidArgument("Invalid node ID %d, must be non-negative",
                                node_id);
  }
  if (node_id >= options_.num_nodes) {
    return xla::FailedPrecondition(
        "Invalid node ID %d, must be in the range [0, %d)", node_id,
        options_.num_nodes);
  }
  return xla::OkStatus();
}

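// Rejects requests that carry a session ID from a different service
// incarnation (the session ID is drawn at random per service instance).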
xla::Status DistributedRuntimeServiceImpl::ValidateSessionId(
    uint64_t session_id) {
  if (session_id != session_id_) {
    return xla::FailedPrecondition(
        "Session ID of request %llu does not match active session ID %llu",
        session_id, session_id_);
  }
  return xla::OkStatus();
}

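// Startup rendezvous: blocks until every node has called Connect() or the
// request times out. Node 0 then transitions the service to the running
// state and starts the heartbeat-checking thread; all other nodes block
// until the service is running. On success, returns the session ID that
// clients must attach to every subsequent request.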
::grpc::Status DistributedRuntimeServiceImpl::Connect(
    ::grpc::ServerContext* context, const ConnectRequest* request,
    ConnectResponse* response) {
  VLOG(10) << "Connect " << request->DebugString();
  if (request->protocol_version() != DistributedRuntimeProtocolVersion()) {
    return xla::ToGrpcStatus(xla::InvalidArgument("Invalid protocol version %d",
                                                  request->protocol_version()));
  }
  absl::MutexLock lock(&mu_);
  if (state_ != State::kInitializing) {
    // This most likely indicates that a client task was restarted but the
    // old master is still up. Clients should retry on failure.
    return xla::ToGrpcStatus(tensorflow::errors::Aborted(
        "Connect() called when system is not initializing."));
  }
  int node_id = request->node_id();
  xla::Status status = ValidateNodeId(node_id);
  if (!status.ok()) {
    return xla::ToGrpcStatus(status);
  }
  if (!nodes_[node_id].present) {
    nodes_[node_id].present = true;
    ++num_nodes_present_;
  }
  nodes_[node_id].client_id = request->client_id();

  auto all_nodes_present_or_duplicate_request = [&]() {
    mu_.AssertHeld();
    return num_nodes_present_ == nodes_.size() ||
           nodes_[node_id].client_id != request->client_id();
  };
  auto connect_timeout = absl::Milliseconds(request->timeout_milliseconds());
  if (!mu_.AwaitWithTimeout(
          absl::Condition(&all_nodes_present_or_duplicate_request),
          connect_timeout)) {
    nodes_[node_id].present = false;
    --num_nodes_present_;
    return xla::ToGrpcStatus(tensorflow::errors::DeadlineExceeded(
        "Timed out after ", absl::FormatDuration(connect_timeout),
        " waiting for all nodes to call Connect()"));
  }

  if (nodes_[node_id].client_id != request->client_id()) {
    // This might happen either if two nodes are erroneously configured with
    // the same ID number, or if a task fails and is restarted while we are
    // waiting for nodes to connect. The second scenario looks like this:
    // * a task calls Connect() with a particular node_id and client_id.
    // * the task is killed and restarted, or alternatively the client's RPC
    //   times out and it decides to retry.
    // * the task calls Connect() again with the same node_id and a different
    //   client_id.
    // In this scenario we take whichever client showed up most recently and
    // evict the client with an out-of-date client ID.
    return xla::ToGrpcStatus(
        tensorflow::errors::Aborted("Duplicate node ID ", node_id));
  }

  if (node_id == 0) {
    state_ = State::kRunning;
    heartbeat_thread_.reset(options_.env->StartThread(
        tensorflow::ThreadOptions(), "pjrt_service_heartbeat",
        [this]() { HeartbeatLoop(); }));
  } else {
    auto running = [&]() {
      mu_.AssertHeld();
      return state_ == State::kRunning;
    };
    mu_.Await(absl::Condition(&running));
  }
  nodes_[node_id].last_heartbeat = absl::Now();
  response->set_session_id(session_id_);
  return ::grpc::Status::OK;
}

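// Coordinated shutdown: blocks until every node has called Shutdown() or the
// shutdown timeout expires, then closes the service and stops the heartbeat
// thread.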
::grpc::Status DistributedRuntimeServiceImpl::Shutdown(
    ::grpc::ServerContext* context, const ShutdownRequest* request,
    ShutdownResponse* response) {
  VLOG(10) << "Shutdown " << request->DebugString();
  xla::Status status = ValidateSessionId(request->session_id());
  if (!status.ok()) {
    return xla::ToGrpcStatus(status);
  }
  absl::MutexLock lock(&mu_);
  if (state_ != State::kRunning) {
    if (!service_status_.ok()) {
      return xla::ToGrpcStatus(service_status_);
    }
    return xla::ToGrpcStatus(xla::FailedPrecondition(
        "Shutdown() called when system is not running."));
  }
  int node_id = request->node_id();
  status = ValidateNodeId(node_id);
  if (!status.ok()) {
    return xla::ToGrpcStatus(status);
  }
  ++num_nodes_shutting_down_;

  auto all_nodes_shutting_down = [&]() {
    mu_.AssertHeld();
    return num_nodes_shutting_down_ == nodes_.size() || !service_status_.ok();
  };
  if (!mu_.AwaitWithTimeout(absl::Condition(&all_nodes_shutting_down),
                            options_.shutdown_timeout)) {
    state_ = State::kClosed;
    return xla::ToGrpcStatus(tensorflow::errors::DeadlineExceeded(
        "Timed out after ", absl::FormatDuration(options_.shutdown_timeout),
        " waiting for all nodes to call Shutdown()"));
  }
  state_ = State::kClosed;
  if (!stop_heartbeat_thread_.HasBeenNotified()) {
    stop_heartbeat_thread_.Notify();
  }
  if (!service_status_.ok()) {
    return xla::ToGrpcStatus(service_status_);
  }
  return ::grpc::Status::OK;
}

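// Collects each node's local device topology and blocks until all nodes have
// reported (or the timeout expires). Node 0 then stitches the local
// topologies into the global topology, which is returned to every caller.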
::grpc::Status DistributedRuntimeServiceImpl::EnumerateDevices(
    ::grpc::ServerContext* context, const EnumerateDevicesRequest* request,
    EnumerateDevicesResponse* response) {
  VLOG(10) << "EnumerateDevices " << request->DebugString();
  xla::Status status = ValidateSessionId(request->session_id());
  if (!status.ok()) {
    return xla::ToGrpcStatus(status);
  }
  absl::MutexLock lock(&mu_);
  if (state_ != State::kRunning) {
    if (!service_status_.ok()) {
      return xla::ToGrpcStatus(service_status_);
    }
    return xla::ToGrpcStatus(xla::FailedPrecondition(
        "EnumerateDevices() called when system is not running."));
  }
  int node_id = request->local_topology().node_id();
  status = ValidateNodeId(node_id);
  if (!status.ok()) {
    return xla::ToGrpcStatus(status);
  }
  local_topologies_[node_id] = request->local_topology();
  ++num_topologies_present_;

  auto all_topologies_present = [&]() {
    mu_.AssertHeld();
    return num_topologies_present_ == nodes_.size() || !service_status_.ok();
  };
  if (!mu_.AwaitWithTimeout(absl::Condition(&all_topologies_present),
                            options_.enumerate_devices_timeout)) {
    return xla::ToGrpcStatus(tensorflow::errors::DeadlineExceeded(
        "Timed out after ",
        absl::FormatDuration(options_.enumerate_devices_timeout),
        " waiting for all nodes to call EnumerateDevices()"));
  }
  if (!service_status_.ok()) {
    return xla::ToGrpcStatus(service_status_);
  }

  if (node_id == 0) {
    topology_.emplace();
    BuildGlobalTopology(absl::Span<LocalTopologyProto>(local_topologies_),
                        &*topology_);
    local_topologies_.clear();
  } else {
    auto topology_ready = [&]() -> bool {
      mu_.AssertHeld();
      return topology_.has_value();
    };
    mu_.Await(absl::Condition(&topology_ready));
  }
  *response->mutable_global_topology() = *topology_;
  return ::grpc::Status::OK;
}

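// Records a liveness heartbeat from a node; HeartbeatLoop() below detects
// nodes whose heartbeats stop arriving.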
::grpc::Status DistributedRuntimeServiceImpl::Heartbeat(
    ::grpc::ServerContext* context, const HeartbeatRequest* request,
    HeartbeatResponse* response) {
  VLOG(10) << "Heartbeat " << request->DebugString();
  xla::Status status = ValidateSessionId(request->session_id());
  if (!status.ok()) {
    return xla::ToGrpcStatus(status);
  }
  absl::MutexLock lock(&mu_);
  if (state_ != State::kRunning) {
    if (!service_status_.ok()) {
      return xla::ToGrpcStatus(service_status_);
    }
    return xla::ToGrpcStatus(xla::FailedPrecondition(
        "Heartbeat() called when system is not running."));
  }
  int node_id = request->node_id();
  status = ValidateNodeId(node_id);
  if (!status.ok()) {
    return xla::ToGrpcStatus(status);
  }
  nodes_[node_id].last_heartbeat = absl::Now();
  return ::grpc::Status::OK;
}

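// Runs on a dedicated thread. Wakes up every heartbeat_interval and, if any
// node has been silent for max_missing_heartbeats intervals, marks the whole
// service as failed so that RPCs blocked on service_status_ return the error
// to their callers.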
void DistributedRuntimeServiceImpl::HeartbeatLoop() {
  while (true) {
    stop_heartbeat_thread_.WaitForNotificationWithTimeout(
        options_.heartbeat_interval);
    VLOG(10) << "Checking heartbeats";
    if (stop_heartbeat_thread_.HasBeenNotified()) {
      VLOG(10) << "Heartbeat checking stopped.";
      return;
    }
    absl::Time now = absl::Now();
    absl::MutexLock lock(&mu_);
    for (size_t i = 0; i < nodes_.size(); ++i) {
      // If we haven't heard from the node for a number of heartbeat
      // intervals, declare that we are unhealthy.
      VLOG(10) << "Node " << i
               << " last heartbeat: " << nodes_[i].last_heartbeat;
      if (nodes_[i].last_heartbeat +
              options_.max_missing_heartbeats * options_.heartbeat_interval <
          now) {
        LOG(INFO) << "Missed heartbeats from node " << i << ". Shutting down.";
        state_ = State::kClosed;
        service_status_ = tensorflow::errors::Aborted(
            "Shutting down due to missed heartbeat from task ", i);
        return;
      }
    }
  }
}

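// Blocking key-value lookup: waits up to the client-supplied timeout for the
// key to be set via KeyValueSet().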
::grpc::Status DistributedRuntimeServiceImpl::KeyValueGet(
    ::grpc::ServerContext* context, const KeyValueGetRequest* request,
    KeyValueGetResponse* response) {
  VLOG(10) << "KeyValueGet " << request->DebugString();
  xla::Status status = ValidateSessionId(request->session_id());
  if (!status.ok()) {
    return xla::ToGrpcStatus(status);
  }
  {
    absl::MutexLock lock(&mu_);
    if (state_ != State::kRunning) {
      if (!service_status_.ok()) {
        return xla::ToGrpcStatus(service_status_);
      }
      return xla::ToGrpcStatus(xla::FailedPrecondition(
          "KeyValueGet() called when system is not running."));
    }
  }
  return key_value_store_.Get(
      request->key(), absl::Milliseconds(request->timeout_milliseconds()),
      response->mutable_value());
}

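// Stores a key-value pair for later retrieval via KeyValueGet().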
::grpc::Status DistributedRuntimeServiceImpl::KeyValueSet(
    ::grpc::ServerContext* context, const KeyValueSetRequest* request,
    KeyValueSetResponse* response) {
  VLOG(10) << "KeyValueSet " << request->DebugString();
  xla::Status status = ValidateSessionId(request->session_id());
  if (!status.ok()) {
    return xla::ToGrpcStatus(status);
  }
  {
    absl::MutexLock lock(&mu_);
    if (state_ != State::kRunning) {
      if (!service_status_.ok()) {
        return xla::ToGrpcStatus(service_status_);
      }
      return xla::ToGrpcStatus(xla::FailedPrecondition(
          "KeyValueSet() called when system is not running; clients must call "
          "Connect() first"));
    }
  }
  return key_value_store_.Set(request->key(), request->value());
}

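// Distributed barrier keyed by barrier_id: blocks until all nodes have
// arrived at a barrier with the same ID, or fails once the timeout expires.
// Barrier IDs are single-use; reusing an ID after its barrier has completed
// is an error.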
::grpc::Status DistributedRuntimeServiceImpl::WaitAtBarrier(
    ::grpc::ServerContext* context, const WaitAtBarrierRequest* request,
    WaitAtBarrierResponse* response) {
  VLOG(10) << "WaitAtBarrier " << request->DebugString();
  xla::Status status = ValidateSessionId(request->session_id());
  if (!status.ok()) {
    return xla::ToGrpcStatus(status);
  }
  absl::MutexLock lock(&mu_);
  if (state_ != State::kRunning) {
    if (!service_status_.ok()) {
      return xla::ToGrpcStatus(service_status_);
    }
    return xla::ToGrpcStatus(xla::FailedPrecondition(
        "WaitAtBarrier() called when system is not running."));
  }
  int node_id = request->node_id();
  status = ValidateNodeId(node_id);
  if (!status.ok()) {
    return xla::ToGrpcStatus(status);
  }

  std::string barrier_id = request->barrier_id();

  if (barrier_id_to_num_nodes_[barrier_id] == nodes_.size()) {
    return xla::ToGrpcStatus(
        xla::FailedPrecondition("Calling WaitAtBarrier with the same ID "
                                "across barriers is not allowed. Please use "
                                "unique barrier IDs across barriers."));
  }

  if (barrier_id_to_num_nodes_[barrier_id] == kBarrierTimedOut) {
    return xla::ToGrpcStatus(xla::FailedPrecondition(
        "A process timed out waiting at the barrier. Exiting early because "
        "the current process will also time out."));
  }

  ++barrier_id_to_num_nodes_[barrier_id];

  absl::Duration timeout = absl::Milliseconds(request->timeout_milliseconds());
  auto all_nodes_at_barrier = [&]() {
    mu_.AssertHeld();
    return barrier_id_to_num_nodes_[barrier_id] == nodes_.size() ||
           !service_status_.ok();
  };
  // TODO(yashkatariya,hanyangtay): Do something similar to the coordination
  // service here.
  if (!mu_.AwaitWithTimeout(absl::Condition(&all_nodes_at_barrier), timeout)) {
    barrier_id_to_num_nodes_[barrier_id] = kBarrierTimedOut;
    return xla::ToGrpcStatus(tensorflow::errors::DeadlineExceeded(
        "Timed out after ", absl::FormatDuration(timeout),
        " waiting for all nodes to be at WaitAtBarrier()"));
  }

  if (!service_status_.ok()) {
    return xla::ToGrpcStatus(service_status_);
  }
  return ::grpc::Status::OK;
}

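// Wires the (experimental) TensorFlow coordination service and its gRPC
// handler into the server being built.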
CoordinationServiceImpl::CoordinationServiceImpl(
    const DistributedRuntimeServiceImpl::Options& options,
    ::grpc::ServerBuilder* builder)
    : env_(options.env) {
  coord_service_ = EnableCoordinationService(options);
  coord_compute_pool_ = std::make_unique<tensorflow::thread::ThreadPool>(
      options.env, "CoordinationServiceRpcHandler",
      /*num_threads=*/4);
  coord_rpc_service_ =
      std::make_unique<tensorflow::GrpcCoordinationServiceImpl>(
          coord_compute_pool_.get(), builder);
  LOG(INFO) << "Experimental coordination service is enabled.";
}

CoordinationServiceImpl::~CoordinationServiceImpl() {
  // Service object must be destroyed to clear all pending RPCs before shutting
  // down the RPC service.
  coord_service_ = nullptr;
  coord_rpc_service_->Shutdown();
}

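// Starts the thread that polls the coordination service's RPC queue. Called
// only after the gRPC server has been built; see the DistributedRuntimeService
// constructor.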
void CoordinationServiceImpl::StartRpcThread() {
  coord_rpc_thread_.reset(env_->StartThread(
      tensorflow::ThreadOptions(), "CoordinationServiceHandleRPCsLoop",
      [service = coord_rpc_service_.get()] { service->HandleRPCsLoop(); }));
}

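// Factory: builds and starts a gRPC server hosting either the
// coordination-service-backed implementation or the legacy
// DistributedRuntimeServiceImpl, depending on `use_coordination_service`.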
xla::StatusOr<std::unique_ptr<DistributedRuntimeService>>
DistributedRuntimeService::Get(
    const std::string& address,
    std::shared_ptr<::grpc::ServerCredentials> credentials,
    const DistributedRuntimeServiceImpl::Options& options,
    bool use_coordination_service) {
  ::grpc::ServerBuilder builder;
  builder.AddListeningPort(address, credentials);
  VLOG(1) << "Distributed runtime service address " << address;
  auto service = std::make_unique<DistributedRuntimeService>(
      options, &builder, use_coordination_service);
  if (!service->server_) {
    return xla::Unknown("Failed to start RPC server");
  }
  LOG(INFO) << "Jax service listening on " << address;
  return service;
}

DistributedRuntimeService::DistributedRuntimeService(
    const DistributedRuntimeServiceImpl::Options& options,
    ::grpc::ServerBuilder* builder, bool use_coordination_service) {
  if (use_coordination_service) {
    coord_impl_ = std::make_unique<CoordinationServiceImpl>(options, builder);
    server_ = builder->BuildAndStart();
    coord_impl_->StartRpcThread();
  } else {
    impl_ = std::make_unique<DistributedRuntimeServiceImpl>(options);
    builder->RegisterService(impl_.get());
    server_ = builder->BuildAndStart();
  }
}

DistributedRuntimeService::~DistributedRuntimeService() { Shutdown(); }

void DistributedRuntimeService::Shutdown() {
  if (server_) {
    LOG(INFO) << "Jax service shutting down";
    server_->Shutdown();
    server_->Wait();
  }

  // Explicitly destroy coordination service before the gRPC server. This
  // clears all pending RPCs before the gRPC server is destroyed.
  coord_impl_ = nullptr;
  server_ = nullptr;
}

}  // namespace xla