/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/pjrt/distributed/service.h"

#include <memory>
#include <string>

#include "absl/memory/memory.h"
#include "absl/time/time.h"
#include "grpcpp/server_builder.h"
#include "tensorflow/compiler/xla/pjrt/distributed/protocol.h"
#include "tensorflow/compiler/xla/pjrt/distributed/util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/distributed_runtime/coordination/coordination_service.h"
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/platform/threadpool.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/coordination_config.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

namespace {
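// Sentinel value stored in barrier_id_to_num_nodes_ once a barrier has timed
// out, so that later arrivals at that barrier fail fast instead of waiting.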
constexpr int kBarrierTimedOut = -1000;

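// Builds a minimal ServerDef for a single "jax_worker" job, translates the
// service options into a CoordinationServiceConfig, and enables a standalone
// coordination service instance.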
std::unique_ptr<tensorflow::CoordinationServiceInterface>
EnableCoordinationService(
    const xla::DistributedRuntimeServiceImpl::Options& options) {
  const std::string& job_name = "jax_worker";
  // TODO(b/205307544): Remove TensorFlow server def references once it is no
  // longer needed.
  tensorflow::ServerDef server_def;
  server_def.set_protocol("grpc");
  server_def.set_job_name(job_name);
  server_def.set_task_index(0);
  auto job_def = server_def.mutable_cluster()->add_job();
  job_def->set_name(job_name);
  for (int32_t i = 0; i < options.num_nodes; ++i) {
    job_def->mutable_tasks()->insert({i, "UNKNOWN_SERVER_ADDRESS"});
  }

  // Convert options to coordination service config.
  auto coordination_config = server_def.mutable_default_session_config()
                                 ->mutable_experimental()
                                 ->mutable_coordination_config();
  coordination_config->set_service_type("standalone");
  coordination_config->set_service_leader(
      absl::StrCat("/job:", job_name, "/task:0"));
  coordination_config->set_cluster_register_timeout_in_ms(
      absl::ToInt64Milliseconds(options.enumerate_devices_timeout));
  coordination_config->set_heartbeat_timeout_in_ms(absl::ToInt64Milliseconds(
      options.heartbeat_interval * options.max_missing_heartbeats));
  coordination_config->set_shutdown_barrier_timeout_in_ms(
      absl::ToInt64Milliseconds(options.shutdown_timeout));
  return tensorflow::CoordinationServiceInterface::EnableCoordinationService(
      "standalone", options.env, server_def, /*cache=*/nullptr);
}
}  // namespace

namespace xla {

DistributedRuntimeServiceImpl::DistributedRuntimeServiceImpl(
    const Options& options)
    : options_(options), session_id_(tensorflow::random::New64()) {
  nodes_.resize(options.num_nodes);
  local_topologies_.resize(options.num_nodes);
}

DistributedRuntimeServiceImpl::~DistributedRuntimeServiceImpl() {
  {
    absl::MutexLock lock(&mu_);
    state_ = State::kClosed;
    service_status_ =
        tensorflow::errors::FailedPrecondition("Service shutting down.");
    if (!stop_heartbeat_thread_.HasBeenNotified()) {
      stop_heartbeat_thread_.Notify();
    }
  }
}

// Steals the contents of `local_topologies`.
void BuildGlobalTopology(absl::Span<LocalTopologyProto> local_topologies,
                         GlobalTopologyProto* global_topology) {
  int next_global_device_id = 0;
  for (LocalTopologyProto& local : local_topologies) {
    for (DeviceProto& device : *local.mutable_devices()) {
      device.set_global_device_id(next_global_device_id++);
    }
    global_topology->add_nodes()->Swap(&local);
  }
}

xla::Status DistributedRuntimeServiceImpl::ValidateNodeId(int node_id) {
  if (node_id < 0) {
    return xla::InvalidArgument("Invalid node ID %d, must be non-negative",
                                node_id);
  }
  if (node_id >= options_.num_nodes) {
    return xla::FailedPrecondition(
        "Invalid node ID %d, must be in the range [0, %d)", node_id,
        options_.num_nodes);
  }
  return xla::OkStatus();
}

xla::Status DistributedRuntimeServiceImpl::ValidateSessionId(
    uint64_t session_id) {
  if (session_id != session_id_) {
    return xla::FailedPrecondition(
        "Session ID of request %llu does not match active session ID %llu",
        session_id, session_id_);
  }
  return xla::OkStatus();
}

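// Registers a node and blocks until either every node has called Connect() or
// the request's timeout expires. Node 0 then moves the service into the
// running state and starts the heartbeat-monitoring thread; every node
// receives the session ID in its response.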
::grpc::Status DistributedRuntimeServiceImpl::Connect(
    ::grpc::ServerContext* context, const ConnectRequest* request,
    ConnectResponse* response) {
  VLOG(10) << "Connect " << request->DebugString();
  if (request->protocol_version() != DistributedRuntimeProtocolVersion()) {
    return xla::ToGrpcStatus(xla::InvalidArgument("Invalid protocol version %d",
                                                  request->protocol_version()));
  }
  absl::MutexLock lock(&mu_);
  if (state_ != State::kInitializing) {
    // This most likely indicates that a client task was restarted but the
    // old master is still up. Clients should retry on failure.
    return xla::ToGrpcStatus(tensorflow::errors::Aborted(
        "Connect() called when system is not initializing."));
  }
  int node_id = request->node_id();
  xla::Status status = ValidateNodeId(node_id);
  if (!status.ok()) {
    return xla::ToGrpcStatus(status);
  }
  if (!nodes_[node_id].present) {
    nodes_[node_id].present = true;
    ++num_nodes_present_;
  }
  nodes_[node_id].client_id = request->client_id();

  auto all_nodes_present_or_duplicate_request = [&]() {
    mu_.AssertHeld();
    return num_nodes_present_ == nodes_.size() ||
           nodes_[node_id].client_id != request->client_id();
  };
  auto connect_timeout = absl::Milliseconds(request->timeout_milliseconds());
  if (!mu_.AwaitWithTimeout(
          absl::Condition(&all_nodes_present_or_duplicate_request),
          connect_timeout)) {
    nodes_[node_id].present = false;
    --num_nodes_present_;
    return xla::ToGrpcStatus(tensorflow::errors::DeadlineExceeded(
        "Timed out after ", absl::FormatDuration(connect_timeout),
        " waiting for all nodes to call Connect()"));
  }

  if (nodes_[node_id].client_id != request->client_id()) {
    // This might happen either if two nodes are erroneously configured with the
    // same ID number, or it might happen if a task fails and is restarted
    // while we are waiting for nodes to connect. To elaborate on the second
    // scenario, it would look like this:
    // * a task calls Connect() with a particular node_id and client_id.
    // * the task is killed and restarted, or alternatively the client's RPC
    //   times out and it decides to retry.
    // * the task calls Connect() again with the same node_id and a different
    //   client_id.
    // In this scenario we take whichever client showed up most recently and
    // evict the client with an out-of-date client ID.
    return xla::ToGrpcStatus(
        tensorflow::errors::Aborted("Duplicate node ID ", node_id));
  }

  if (node_id == 0) {
    state_ = State::kRunning;
    heartbeat_thread_.reset(options_.env->StartThread(
        tensorflow::ThreadOptions(), "pjrt_service_heartbeat",
        [this]() { HeartbeatLoop(); }));
  } else {
    auto running = [&]() {
      mu_.AssertHeld();
      return state_ == State::kRunning;
    };
    mu_.Await(absl::Condition(&running));
  }
  nodes_[node_id].last_heartbeat = absl::Now();
  response->set_session_id(session_id_);
  return ::grpc::Status::OK;
}

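// Blocks until every node has called Shutdown() or the shutdown timeout
// expires, then closes the service and stops the heartbeat thread.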
::grpc::Status DistributedRuntimeServiceImpl::Shutdown(
    ::grpc::ServerContext* context, const ShutdownRequest* request,
    ShutdownResponse* response) {
  VLOG(10) << "Shutdown " << request->DebugString();
  xla::Status status = ValidateSessionId(request->session_id());
  if (!status.ok()) {
    return xla::ToGrpcStatus(status);
  }
  absl::MutexLock lock(&mu_);
  if (state_ != State::kRunning) {
    if (!service_status_.ok()) {
      return xla::ToGrpcStatus(service_status_);
    }
    return xla::ToGrpcStatus(xla::FailedPrecondition(
        "Shutdown() called when system is not running."));
  }
  int node_id = request->node_id();
  status = ValidateNodeId(node_id);
  if (!status.ok()) {
    return xla::ToGrpcStatus(status);
  }
  ++num_nodes_shutting_down_;

  auto all_nodes_shutting_down = [&]() {
    mu_.AssertHeld();
    return num_nodes_shutting_down_ == nodes_.size() || !service_status_.ok();
  };
  if (!mu_.AwaitWithTimeout(absl::Condition(&all_nodes_shutting_down),
                            options_.shutdown_timeout)) {
    state_ = State::kClosed;
    return xla::ToGrpcStatus(tensorflow::errors::DeadlineExceeded(
        "Timed out after ", absl::FormatDuration(options_.shutdown_timeout),
        " waiting for all nodes to call Shutdown()"));
  }
  state_ = State::kClosed;
  if (!stop_heartbeat_thread_.HasBeenNotified()) {
    stop_heartbeat_thread_.Notify();
  }
  if (!service_status_.ok()) {
    return xla::ToGrpcStatus(service_status_);
  }
  return ::grpc::Status::OK;
}

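// Collects each node's LocalTopologyProto. Once every node has reported, node
// 0 assigns global device IDs and builds the global topology, which is then
// returned to all callers; waiting is bounded by enumerate_devices_timeout.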
::grpc::Status DistributedRuntimeServiceImpl::EnumerateDevices(
    ::grpc::ServerContext* context, const EnumerateDevicesRequest* request,
    EnumerateDevicesResponse* response) {
  VLOG(10) << "EnumerateDevices " << request->DebugString();
  xla::Status status = ValidateSessionId(request->session_id());
  if (!status.ok()) {
    return xla::ToGrpcStatus(status);
  }
  absl::MutexLock lock(&mu_);
  if (state_ != State::kRunning) {
    if (!service_status_.ok()) {
      return xla::ToGrpcStatus(service_status_);
    }
    return xla::ToGrpcStatus(xla::FailedPrecondition(
        "EnumerateDevices() called when system is not running."));
  }
  int node_id = request->local_topology().node_id();
  status = ValidateNodeId(node_id);
  if (!status.ok()) {
    return xla::ToGrpcStatus(status);
  }
  local_topologies_[node_id] = request->local_topology();
  ++num_topologies_present_;

  auto all_topologies_present = [&]() {
    mu_.AssertHeld();
    return num_topologies_present_ == nodes_.size() || !service_status_.ok();
  };
  if (!mu_.AwaitWithTimeout(absl::Condition(&all_topologies_present),
                            options_.enumerate_devices_timeout)) {
    return xla::ToGrpcStatus(tensorflow::errors::DeadlineExceeded(
        "Timed out after ",
        absl::FormatDuration(options_.enumerate_devices_timeout),
        " waiting for all nodes to call EnumerateDevices()"));
  }
  if (!service_status_.ok()) {
    return xla::ToGrpcStatus(service_status_);
  }

  if (node_id == 0) {
    topology_.emplace();
    BuildGlobalTopology(absl::Span<LocalTopologyProto>(local_topologies_),
                        &*topology_);
    local_topologies_.clear();
  } else {
    auto topology_ready = [&]() -> bool {
      mu_.AssertHeld();
      return topology_.has_value();
    };
    mu_.Await(absl::Condition(&topology_ready));
  }
  *response->mutable_global_topology() = *topology_;
  return ::grpc::Status::OK;
}

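// Records the caller's latest heartbeat time; HeartbeatLoop() below uses these
// timestamps to detect unresponsive nodes.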
::grpc::Status DistributedRuntimeServiceImpl::Heartbeat(
    ::grpc::ServerContext* context, const HeartbeatRequest* request,
    HeartbeatResponse* response) {
  VLOG(10) << "Heartbeat " << request->DebugString();
  xla::Status status = ValidateSessionId(request->session_id());
  if (!status.ok()) {
    return xla::ToGrpcStatus(status);
  }
  absl::MutexLock lock(&mu_);
  if (state_ != State::kRunning) {
    if (!service_status_.ok()) {
      return xla::ToGrpcStatus(service_status_);
    }
    return xla::ToGrpcStatus(xla::FailedPrecondition(
        "Heartbeat() called when system is not running."));
  }
  int node_id = request->node_id();
  status = ValidateNodeId(node_id);
  if (!status.ok()) {
    return xla::ToGrpcStatus(status);
  }
  nodes_[node_id].last_heartbeat = absl::Now();
  return ::grpc::Status::OK;
}

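// Runs on the heartbeat thread started by Connect(). Wakes up every
// heartbeat_interval and closes the service if any node has missed
// max_missing_heartbeats consecutive intervals.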
void DistributedRuntimeServiceImpl::HeartbeatLoop() {
  while (true) {
    stop_heartbeat_thread_.WaitForNotificationWithTimeout(
        options_.heartbeat_interval);
    VLOG(10) << "Checking heartbeats";
    if (stop_heartbeat_thread_.HasBeenNotified()) {
      VLOG(10) << "Heartbeat checking stopped.";
      return;
    }
    absl::Time now = absl::Now();
    absl::MutexLock lock(&mu_);
    for (size_t i = 0; i < nodes_.size(); ++i) {
      // If we haven't heard from the node for a number of heartbeat intervals,
      // declare that we are unhealthy.
      VLOG(10) << "Node " << i
               << " last heartbeat: " << nodes_[i].last_heartbeat;
      if (nodes_[i].last_heartbeat +
              options_.max_missing_heartbeats * options_.heartbeat_interval <
          now) {
        LOG(INFO) << "Missed heartbeats from node " << i << ". Shutting down.";
        state_ = State::kClosed;
        service_status_ = tensorflow::errors::Aborted(
            "Shutting down due to missed heartbeat from task ", i);
        return;
      }
    }
  }
}

::grpc::Status DistributedRuntimeServiceImpl::KeyValueGet(
    ::grpc::ServerContext* context, const KeyValueGetRequest* request,
    KeyValueGetResponse* response) {
  VLOG(10) << "KeyValueGet " << request->DebugString();
  xla::Status status = ValidateSessionId(request->session_id());
  if (!status.ok()) {
    return xla::ToGrpcStatus(status);
  }
  {
    absl::MutexLock lock(&mu_);
    if (state_ != State::kRunning) {
      if (!service_status_.ok()) {
        return xla::ToGrpcStatus(service_status_);
      }
      return xla::ToGrpcStatus(xla::FailedPrecondition(
          "KeyValueGet() called when system is not running."));
    }
  }
  return key_value_store_.Get(
      request->key(), absl::Milliseconds(request->timeout_milliseconds()),
      response->mutable_value());
}

::grpc::Status DistributedRuntimeServiceImpl::KeyValueSet(
    ::grpc::ServerContext* context, const KeyValueSetRequest* request,
    KeyValueSetResponse* response) {
  VLOG(10) << "KeyValueSet " << request->DebugString();
  xla::Status status = ValidateSessionId(request->session_id());
  if (!status.ok()) {
    return xla::ToGrpcStatus(status);
  }
  {
    absl::MutexLock lock(&mu_);
    if (state_ != State::kRunning) {
      if (!service_status_.ok()) {
        return xla::ToGrpcStatus(service_status_);
      }
      return xla::ToGrpcStatus(xla::FailedPrecondition(
          "KeyValueSet() called when system is not running; clients must call "
          "Connect() first"));
    }
  }
  return key_value_store_.Set(request->key(), request->value());
}

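// Blocks until all nodes have reached the barrier identified by barrier_id or
// the request's timeout expires. Each barrier_id may be used only once; a
// timeout is recorded via kBarrierTimedOut so that later arrivals at the same
// barrier return an error immediately.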
::grpc::Status DistributedRuntimeServiceImpl::WaitAtBarrier(
    ::grpc::ServerContext* context, const WaitAtBarrierRequest* request,
    WaitAtBarrierResponse* response) {
  VLOG(10) << "WaitAtBarrier " << request->DebugString();
  xla::Status status = ValidateSessionId(request->session_id());
  if (!status.ok()) {
    return xla::ToGrpcStatus(status);
  }
  absl::MutexLock lock(&mu_);
  if (state_ != State::kRunning) {
    if (!service_status_.ok()) {
      return xla::ToGrpcStatus(service_status_);
    }
    return xla::ToGrpcStatus(xla::FailedPrecondition(
        "WaitAtBarrier() called when system is not running."));
  }
  int node_id = request->node_id();
  status = ValidateNodeId(node_id);
  if (!status.ok()) {
    return xla::ToGrpcStatus(status);
  }

  std::string barrier_id = request->barrier_id();

  if (barrier_id_to_num_nodes_[barrier_id] == nodes_.size()) {
    return xla::ToGrpcStatus(
        xla::FailedPrecondition("Calling WaitAtBarrier with the same id "
                                "across barriers is not allowed. Please use "
                                "unique barrier ids across barriers."));
  }

  if (barrier_id_to_num_nodes_[barrier_id] == kBarrierTimedOut) {
    return xla::ToGrpcStatus(xla::FailedPrecondition(
        "A process timed out waiting at the barrier. Exiting early because the "
        "current process will also timeout."));
  }

  ++barrier_id_to_num_nodes_[barrier_id];

  absl::Duration timeout = absl::Milliseconds(request->timeout_milliseconds());
  auto all_nodes_at_barrier = [&]() {
    mu_.AssertHeld();
    return barrier_id_to_num_nodes_[barrier_id] == nodes_.size() ||
           !service_status_.ok();
  };
  // TODO(yashkatariya,hanyangtay): Do something similar to the coordination
  // service here.
  if (!mu_.AwaitWithTimeout(absl::Condition(&all_nodes_at_barrier), timeout)) {
    barrier_id_to_num_nodes_[barrier_id] = kBarrierTimedOut;
    return xla::ToGrpcStatus(tensorflow::errors::DeadlineExceeded(
        "Timed out after ", timeout,
        " waiting for all nodes to be at WaitAtBarrier()"));
  }

  if (!service_status_.ok()) {
    return xla::ToGrpcStatus(service_status_);
  }
  return ::grpc::Status::OK;
}

CoordinationServiceImpl::CoordinationServiceImpl(
    const DistributedRuntimeServiceImpl::Options& options,
    ::grpc::ServerBuilder* builder)
    : env_(options.env) {
  coord_service_ = EnableCoordinationService(options);
  coord_compute_pool_ = std::make_unique<tensorflow::thread::ThreadPool>(
      options.env, "CoordinationServiceRpcHandler",
      /*num_threads=*/4);
  coord_rpc_service_ =
      std::make_unique<tensorflow::GrpcCoordinationServiceImpl>(
          coord_compute_pool_.get(), builder);
  LOG(INFO) << "Experimental coordination service is enabled.";
}

CoordinationServiceImpl::~CoordinationServiceImpl() {
  // Service object must be destroyed to clear all pending RPCs before shutting
  // down the RPC service.
  coord_service_ = nullptr;
  coord_rpc_service_->Shutdown();
}

void CoordinationServiceImpl::StartRpcThread() {
  coord_rpc_thread_.reset(env_->StartThread(
      tensorflow::ThreadOptions(), "CoordinationServiceHandleRPCsLoop",
      [service = coord_rpc_service_.get()] { service->HandleRPCsLoop(); }));
}

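// Typical usage (a minimal sketch; the address, node count, and credentials
// below are illustrative and not taken from this file):
//
//   xla::DistributedRuntimeServiceImpl::Options service_options;
//   service_options.env = tensorflow::Env::Default();
//   service_options.num_nodes = 2;
//   auto service = xla::DistributedRuntimeService::Get(
//       "[::]:1234", ::grpc::InsecureServerCredentials(), service_options,
//       /*use_coordination_service=*/false);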
xla::StatusOr<std::unique_ptr<DistributedRuntimeService>>
DistributedRuntimeService::Get(
    const std::string& address,
    std::shared_ptr<::grpc::ServerCredentials> credentials,
    const DistributedRuntimeServiceImpl::Options& options,
    bool use_coordination_service) {
  ::grpc::ServerBuilder builder;
  builder.AddListeningPort(address, credentials);
  VLOG(1) << "Distributed runtime service address " << address;
  auto service = std::make_unique<DistributedRuntimeService>(
      options, &builder, use_coordination_service);
  if (!service->server_) {
    return xla::Unknown("Failed to start RPC server");
  }
  LOG(INFO) << "Jax service listening on " << address;
  return service;
}

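// When use_coordination_service is true, the experimental coordination service
// handles client RPCs; otherwise DistributedRuntimeServiceImpl is registered
// with the gRPC server directly.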
DistributedRuntimeService::DistributedRuntimeService(
    const DistributedRuntimeServiceImpl::Options& options,
    ::grpc::ServerBuilder* builder, bool use_coordination_service) {
  if (use_coordination_service) {
    coord_impl_ = std::make_unique<CoordinationServiceImpl>(options, builder);
    server_ = builder->BuildAndStart();
    coord_impl_->StartRpcThread();
  } else {
    impl_ = std::make_unique<DistributedRuntimeServiceImpl>(options);
    builder->RegisterService(impl_.get());
    server_ = builder->BuildAndStart();
  }
}

DistributedRuntimeService::~DistributedRuntimeService() { Shutdown(); }

void DistributedRuntimeService::Shutdown() {
  if (server_) {
    LOG(INFO) << "Jax service shutting down";
    server_->Shutdown();
    server_->Wait();
  }

  // Explicitly destroy coordination service before the gRPC server. This clears
  // all pending RPCs before the gRPC server is destroyed.
  coord_impl_ = nullptr;
  server_ = nullptr;
}

}  // namespace xla