xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/pjrt/distributed/service.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_SERVICE_H_
17 #define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_SERVICE_H_
18 
19 #include <string>
20 
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/synchronization/mutex.h"
23 #include "absl/synchronization/notification.h"
24 #include "absl/time/time.h"
25 #include "grpcpp/security/server_credentials.h"
26 #include "tensorflow/compiler/xla/pjrt/distributed/key_value_store.h"
27 #include "tensorflow/compiler/xla/pjrt/distributed/protocol.grpc.pb.h"
28 #include "tensorflow/compiler/xla/statusor.h"
29 #include "tensorflow/compiler/xla/types.h"
30 #include "tensorflow/core/distributed_runtime/coordination/coordination_service.h"
31 #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
32 #include "tensorflow/core/platform/env.h"
33 #include "tensorflow/core/platform/threadpool.h"
34 
35 namespace xla {
36 
// Integer type used to identify a node.
// NOTE(review): nothing else in this header refers to this alias (for example,
// ValidateNodeId() takes a plain `int`); confirm external users before
// changing the underlying type.
using NodeId = int;
38 
39 class DistributedRuntimeServiceImpl final
40     : public grpc::DistributedRuntimeService::Service {
41  public:
42   struct Options {
43     // Number of nodes in the job. Mandatory. Must be non-negative.
44     int num_nodes = -1;
45 
46     tensorflow::Env* env = tensorflow::Env::Default();
47 
48     // Interval at which the service should check for missed heartbeat RPCs
49     // from the clients.
50     absl::Duration heartbeat_interval = absl::Seconds(10);
51 
52     // Number of heartbeats that a client may miss in a row before the
53     // coordinator concludes that a client has vanished.
54     int max_missing_heartbeats = 10;
55 
56     // How long should we wait for all clients to call EnumerateDevices() before
57     // giving up?
58     absl::Duration enumerate_devices_timeout = absl::Seconds(60);
59 
60     // How long should we wait for all clients to call Shutdown() before giving
61     // up and returning a failure?
62     absl::Duration shutdown_timeout = absl::Seconds(60);
63   };
64   explicit DistributedRuntimeServiceImpl(const Options& options);
65   ~DistributedRuntimeServiceImpl() override;
66 
67   DistributedRuntimeServiceImpl(const DistributedRuntimeServiceImpl&) = delete;
68   DistributedRuntimeServiceImpl(DistributedRuntimeServiceImpl&&) = delete;
69   DistributedRuntimeServiceImpl& operator=(
70       const DistributedRuntimeServiceImpl&) = delete;
71   DistributedRuntimeServiceImpl&& operator=(DistributedRuntimeServiceImpl&&) =
72       delete;
73 
74   ::grpc::Status Connect(::grpc::ServerContext* context,
75                          const ConnectRequest* request,
76                          ConnectResponse* response) override;
77 
78   ::grpc::Status Shutdown(::grpc::ServerContext* context,
79                           const ShutdownRequest* request,
80                           ShutdownResponse* response) override;
81 
82   ::grpc::Status Heartbeat(::grpc::ServerContext* context,
83                            const HeartbeatRequest* request,
84                            HeartbeatResponse* response) override;
85 
86   ::grpc::Status EnumerateDevices(::grpc::ServerContext* context,
87                                   const EnumerateDevicesRequest* request,
88                                   EnumerateDevicesResponse* response) override;
89 
90   ::grpc::Status KeyValueGet(::grpc::ServerContext* context,
91                              const KeyValueGetRequest* request,
92                              KeyValueGetResponse* response) override;
93 
94   ::grpc::Status KeyValueSet(::grpc::ServerContext* context,
95                              const KeyValueSetRequest* request,
96                              KeyValueSetResponse* response) override;
97 
98   ::grpc::Status WaitAtBarrier(::grpc::ServerContext* context,
99                                const WaitAtBarrierRequest* request,
100                                WaitAtBarrierResponse* response) override;
101 
102  private:
103   // Entry point for the heartbeat checking thread.
104   void HeartbeatLoop();
105 
106   // Validates a session id number matches the current session id.
107   xla::Status ValidateSessionId(uint64_t session_id);
108 
109   // Validates a node id number.
110   xla::Status ValidateNodeId(int node_id);
111 
112   const Options options_;
113   const uint64_t session_id_;
114 
115   absl::Mutex mu_;
116   enum class State { kInitializing, kRunning, kClosed };
117   State state_ ABSL_GUARDED_BY(mu_) = State::kInitializing;
118   Status service_status_ ABSL_GUARDED_BY(mu_);
119 
120   // State for Connect() and heartbeats.
121   struct Node {
122     // Have we heard from a task with this ID?
123     bool present = false;
124 
125     // A unique ID belonging to the client. Used to identify the client that
126     // most recently called Connect() with a particular task id.
127     uint64_t client_id = 0;
128 
129     // When did we last receive a heartbeat from this task?
130     absl::Time last_heartbeat = absl::InfinitePast();
131   };
132   int num_nodes_present_ ABSL_GUARDED_BY(mu_) = 0;
133   std::vector<Node> nodes_ ABSL_GUARDED_BY(mu_);
134 
135   // State for EnumerateDevices.
136   int num_topologies_present_ ABSL_GUARDED_BY(mu_) = 0;
137   std::vector<LocalTopologyProto> local_topologies_ ABSL_GUARDED_BY(mu_);
138   std::optional<GlobalTopologyProto> topology_ ABSL_GUARDED_BY(mu_);
139 
140   // State for Shutdown(). Counter of how many nodes are blocked at the
141   // Shutdown() barrier.
142   int num_nodes_shutting_down_ ABSL_GUARDED_BY(mu_) = 0;
143 
144   // This dictionary tracks the number of nodes per barrier.
145   absl::flat_hash_map<std::string, int> barrier_id_to_num_nodes_
146       ABSL_GUARDED_BY(mu_);
147 
148   // Key-value store, used by distributed GPU code to share NCCL state.
149   KeyValueStore key_value_store_;
150 
151   // Notification that tells the heartbeat thread to stop.
152   absl::Notification stop_heartbeat_thread_;
153 
154   // Thread that checks for missing hearbeats from the clients periodically.
155   std::unique_ptr<tensorflow::Thread> heartbeat_thread_;
156 };
157 
158 class CoordinationServiceImpl {
159  public:
160   CoordinationServiceImpl(const DistributedRuntimeServiceImpl::Options& options,
161                           ::grpc::ServerBuilder* builder);
162   ~CoordinationServiceImpl();
163 
164   // Must be called after gRPC server has started.
165   void StartRpcThread();
166 
167   CoordinationServiceImpl(const CoordinationServiceImpl&) = delete;
168   CoordinationServiceImpl(CoordinationServiceImpl&&) = delete;
169   CoordinationServiceImpl& operator=(const CoordinationServiceImpl&) = delete;
170   CoordinationServiceImpl&& operator=(CoordinationServiceImpl&&) = delete;
171 
172  private:
173   tensorflow::Env* env_ = nullptr;  // Not owned.
174   std::unique_ptr<tensorflow::CoordinationServiceInterface> coord_service_;
175   std::unique_ptr<tensorflow::thread::ThreadPool> coord_compute_pool_;
176   std::unique_ptr<tensorflow::AsyncServiceInterface> coord_rpc_service_;
177   std::unique_ptr<tensorflow::Thread> coord_rpc_thread_;
178 };
179 
180 class DistributedRuntimeService {
181  public:
182   static xla::StatusOr<std::unique_ptr<DistributedRuntimeService>> Get(
183       const std::string& address,
184       std::shared_ptr<::grpc::ServerCredentials> credentials,
185       const DistributedRuntimeServiceImpl::Options& options,
186       bool use_coordination_service);
187 
188   explicit DistributedRuntimeService(
189       const DistributedRuntimeServiceImpl::Options& options,
190       ::grpc::ServerBuilder* builder, bool use_coordination_service);
191   ~DistributedRuntimeService();
192 
193   DistributedRuntimeService(const DistributedRuntimeService&) = delete;
194   DistributedRuntimeService(DistributedRuntimeService&&) = delete;
195   DistributedRuntimeService& operator=(const DistributedRuntimeService&) =
196       delete;
197   DistributedRuntimeService& operator=(DistributedRuntimeService&&) = delete;
198 
199   void Shutdown();
200 
server()201   ::grpc::Server* server() const { return server_.get(); }
202 
203  private:
204   std::unique_ptr<DistributedRuntimeServiceImpl> impl_;
205   std::unique_ptr<CoordinationServiceImpl> coord_impl_;
206   std::unique_ptr<::grpc::Server> server_;
207 };
208 
209 // Everything below this point is exposed only for tests.
210 
211 // Given a LocalTopologyProto object from each node, builds a
212 // GlobalTopologyProto that describes all nodes.
213 void BuildGlobalTopology(absl::Span<LocalTopologyProto> local_topologies,
214                          GlobalTopologyProto* global_topology);
215 
216 }  // namespace xla
217 
218 #endif  // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_SERVICE_H_
219