/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_SERVICE_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_SERVICE_H_

#include <memory>
#include <optional>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "grpcpp/security/server_credentials.h"
#include "tensorflow/compiler/xla/pjrt/distributed/key_value_store.h"
#include "tensorflow/compiler/xla/pjrt/distributed/protocol.grpc.pb.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/distributed_runtime/coordination/coordination_service.h"
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/threadpool.h"

namespace xla {

typedef int NodeId;

class DistributedRuntimeServiceImpl final
    : public grpc::DistributedRuntimeService::Service {
 public:
  struct Options {
    // Number of nodes in the job. Mandatory. Must be non-negative.
    int num_nodes = -1;

    tensorflow::Env* env = tensorflow::Env::Default();

    // Interval at which the service should check for missed heartbeat RPCs
    // from the clients.
    absl::Duration heartbeat_interval = absl::Seconds(10);

    // Number of heartbeats that a client may miss in a row before the
    // coordinator concludes that a client has vanished.
    int max_missing_heartbeats = 10;

    // How long should we wait for all clients to call EnumerateDevices()
    // before giving up?
    absl::Duration enumerate_devices_timeout = absl::Seconds(60);

    // How long should we wait for all clients to call Shutdown() before
    // giving up and returning a failure?
    absl::Duration shutdown_timeout = absl::Seconds(60);
  };
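
  // Illustrative Options configuration only (the node count and timeout
  // values below are arbitrary assumptions, not recommendations): a
  // coordinator for a hypothetical four-process job might be set up as
  //
  //   DistributedRuntimeServiceImpl::Options opts;
  //   opts.num_nodes = 4;
  //   opts.heartbeat_interval = absl::Seconds(10);
  //   opts.max_missing_heartbeats = 10;  // ~100s of silence before a node
  //                                      // is considered to have vanished.
  //   opts.shutdown_timeout = absl::Minutes(5);
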
  explicit DistributedRuntimeServiceImpl(const Options& options);
  ~DistributedRuntimeServiceImpl() override;

  DistributedRuntimeServiceImpl(const DistributedRuntimeServiceImpl&) = delete;
  DistributedRuntimeServiceImpl(DistributedRuntimeServiceImpl&&) = delete;
  DistributedRuntimeServiceImpl& operator=(
      const DistributedRuntimeServiceImpl&) = delete;
  DistributedRuntimeServiceImpl&& operator=(DistributedRuntimeServiceImpl&&) =
      delete;

  ::grpc::Status Connect(::grpc::ServerContext* context,
                         const ConnectRequest* request,
                         ConnectResponse* response) override;

  ::grpc::Status Shutdown(::grpc::ServerContext* context,
                          const ShutdownRequest* request,
                          ShutdownResponse* response) override;

  ::grpc::Status Heartbeat(::grpc::ServerContext* context,
                           const HeartbeatRequest* request,
                           HeartbeatResponse* response) override;

  ::grpc::Status EnumerateDevices(::grpc::ServerContext* context,
                                  const EnumerateDevicesRequest* request,
                                  EnumerateDevicesResponse* response) override;

  ::grpc::Status KeyValueGet(::grpc::ServerContext* context,
                             const KeyValueGetRequest* request,
                             KeyValueGetResponse* response) override;

  ::grpc::Status KeyValueSet(::grpc::ServerContext* context,
                             const KeyValueSetRequest* request,
                             KeyValueSetResponse* response) override;

  ::grpc::Status WaitAtBarrier(::grpc::ServerContext* context,
                               const WaitAtBarrierRequest* request,
                               WaitAtBarrierResponse* response) override;

 private:
  // Entry point for the heartbeat checking thread.
  void HeartbeatLoop();

  // Validates that a session id matches the current session id.
  xla::Status ValidateSessionId(uint64_t session_id);

  // Validates a node id.
  xla::Status ValidateNodeId(int node_id);

  const Options options_;
  const uint64_t session_id_;

  absl::Mutex mu_;
  enum class State { kInitializing, kRunning, kClosed };
  State state_ ABSL_GUARDED_BY(mu_) = State::kInitializing;
  Status service_status_ ABSL_GUARDED_BY(mu_);

  // State for Connect() and heartbeats.
  struct Node {
    // Have we heard from a task with this ID?
    bool present = false;

    // A unique ID belonging to the client. Used to identify the client that
    // most recently called Connect() with a particular task id.
    uint64_t client_id = 0;

    // When did we last receive a heartbeat from this task?
    absl::Time last_heartbeat = absl::InfinitePast();
  };
  int num_nodes_present_ ABSL_GUARDED_BY(mu_) = 0;
  std::vector<Node> nodes_ ABSL_GUARDED_BY(mu_);

  // State for EnumerateDevices().
  int num_topologies_present_ ABSL_GUARDED_BY(mu_) = 0;
  std::vector<LocalTopologyProto> local_topologies_ ABSL_GUARDED_BY(mu_);
  std::optional<GlobalTopologyProto> topology_ ABSL_GUARDED_BY(mu_);

  // State for Shutdown(). Counter of how many nodes are blocked at the
  // Shutdown() barrier.
  int num_nodes_shutting_down_ ABSL_GUARDED_BY(mu_) = 0;

  // Tracks how many nodes have arrived at each barrier, keyed by barrier id.
  absl::flat_hash_map<std::string, int> barrier_id_to_num_nodes_
      ABSL_GUARDED_BY(mu_);

  // Key-value store, used by distributed GPU code to share NCCL state.
  KeyValueStore key_value_store_;

  // Notification that tells the heartbeat thread to stop.
  absl::Notification stop_heartbeat_thread_;

  // Thread that periodically checks for missing heartbeats from the clients.
  std::unique_ptr<tensorflow::Thread> heartbeat_thread_;
};
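
// A typical session against the RPC surface above (informal sketch of how a
// client drives this service; not a normative specification):
//
//   Connect()           - each of options.num_nodes tasks checks in and
//                         learns the coordinator's session id.
//   EnumerateDevices()  - each task submits a LocalTopologyProto; once all
//                         have arrived they are merged into a single
//                         GlobalTopologyProto shared with every task.
//   KeyValueSet()/Get() - ad-hoc key-value exchange, e.g. NCCL state for
//                         distributed GPU initialization.
//   Heartbeat()         - sent periodically; max_missing_heartbeats missed
//                         heartbeats in a row mark a task as vanished.
//   Shutdown()          - acts as a barrier; the service shuts down once
//                         every task has called it, or reports failure after
//                         shutdown_timeout.
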
class CoordinationServiceImpl {
 public:
  CoordinationServiceImpl(const DistributedRuntimeServiceImpl::Options& options,
                          ::grpc::ServerBuilder* builder);
  ~CoordinationServiceImpl();

  // Must be called after the gRPC server has started.
  void StartRpcThread();

  CoordinationServiceImpl(const CoordinationServiceImpl&) = delete;
  CoordinationServiceImpl(CoordinationServiceImpl&&) = delete;
  CoordinationServiceImpl& operator=(const CoordinationServiceImpl&) = delete;
  CoordinationServiceImpl&& operator=(CoordinationServiceImpl&&) = delete;

 private:
  tensorflow::Env* env_ = nullptr;  // Not owned.
  std::unique_ptr<tensorflow::CoordinationServiceInterface> coord_service_;
  std::unique_ptr<tensorflow::thread::ThreadPool> coord_compute_pool_;
  std::unique_ptr<tensorflow::AsyncServiceInterface> coord_rpc_service_;
  std::unique_ptr<tensorflow::Thread> coord_rpc_thread_;
};

class DistributedRuntimeService {
 public:
  static xla::StatusOr<std::unique_ptr<DistributedRuntimeService>> Get(
      const std::string& address,
      std::shared_ptr<::grpc::ServerCredentials> credentials,
      const DistributedRuntimeServiceImpl::Options& options,
      bool use_coordination_service);

  explicit DistributedRuntimeService(
      const DistributedRuntimeServiceImpl::Options& options,
      ::grpc::ServerBuilder* builder, bool use_coordination_service);
  ~DistributedRuntimeService();

  DistributedRuntimeService(const DistributedRuntimeService&) = delete;
  DistributedRuntimeService(DistributedRuntimeService&&) = delete;
  DistributedRuntimeService& operator=(const DistributedRuntimeService&) =
      delete;
  DistributedRuntimeService& operator=(DistributedRuntimeService&&) = delete;

  void Shutdown();

  ::grpc::Server* server() const { return server_.get(); }

 private:
  std::unique_ptr<DistributedRuntimeServiceImpl> impl_;
  std::unique_ptr<CoordinationServiceImpl> coord_impl_;
  std::unique_ptr<::grpc::Server> server_;
};

// Everything below this point is exposed only for tests.

// Given a LocalTopologyProto object from each node, builds a
// GlobalTopologyProto that describes all nodes.
void BuildGlobalTopology(absl::Span<LocalTopologyProto> local_topologies,
                         GlobalTopologyProto* global_topology);

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_SERVICE_H_
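
// Illustrative usage sketch of DistributedRuntimeService::Get() only (the
// address, node count, and credentials below are arbitrary assumptions, not
// defaults of this library):
//
//   xla::DistributedRuntimeServiceImpl::Options opts;
//   opts.num_nodes = 2;
//   auto service_or = xla::DistributedRuntimeService::Get(
//       "[::]:12345", ::grpc::InsecureServerCredentials(), opts,
//       /*use_coordination_service=*/false);
//   // ... clients Connect(), EnumerateDevices(), exchange key-values ...
//   if (service_or.ok()) (*service_or)->Shutdown();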