/* Copyright 2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_CLIENT_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_CLIENT_H_

#include <functional>
#include <memory>
#include <string>

#include "absl/time/time.h"
#include "grpcpp/channel.h"
#include "tensorflow/compiler/xla/pjrt/distributed/protocol.grpc.pb.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/env.h"

namespace tensorflow {
class CoordinationServiceAgent;
}  // namespace tensorflow

namespace xla {

class DistributedRuntimeClient {
 public:
  struct Options {
    // This node's global ID. Required.
    int32_t node_id = -1;

    // Environment used for starting threads.
    tensorflow::Env* env = tensorflow::Env::Default();

    // RPC timeout used for RPCs that don't have their own timeouts.
    absl::Duration rpc_timeout = absl::Seconds(120);

    // Time period for which Connect() should be retried. The client will keep
    // trying to open the initial connection for this period, even if any
    // individual Connect() RPC fails. May be zero, in which case Connect()
    // will only be attempted once.
    absl::Duration init_timeout = absl::ZeroDuration();

    // How long to wait for all nodes to call Shutdown(). If the timeout
    // expires, then Shutdown() reports an error and returns control.
    absl::Duration shutdown_timeout = absl::Seconds(60);

    // Interval at which the client should send heartbeat RPCs to the
    // coordinator.
    absl::Duration heartbeat_interval = absl::Seconds(10);

    // How many heartbeat RPCs may fail, due to a possibly-ephemeral reason,
    // before we decide the coordinator has vanished and that we should shut
    // down.
    int max_missing_heartbeats = 10;

    // Callback invoked by the client when the coordinator reports a missing
    // heartbeat, or when we have not heard from the coordinator recently.
    // `coordinator_reported_failure` is true in the former case.
    // Exposed so tests can override this behavior to something non-fatal.
    std::function<void(xla::Status, bool coordinator_reported_failure)>
        missed_heartbeat_callback =
            [](xla::Status status, bool coordinator_reported_failure) {
              if (coordinator_reported_failure) {
                LOG(QFATAL)
                    << "Terminating process because the coordinator detected "
                       "missing heartbeats. This most likely indicates that "
                       "another task died; see the other task logs for more "
                       "details. Status: "
                    << status;
              } else {
                LOG(QFATAL)
                    << "Terminating process because of missing heartbeat "
                       "response from the coordinator. This most likely "
                       "indicates that the coordinator task died; see the "
                       "coordinator's task logs for more details. Status: "
                    << status;
              }
            };

    // For testing. Should the client explicitly Shutdown() on destruction?
    bool shutdown_on_destruction = true;
  };

  virtual ~DistributedRuntimeClient() {}

  // Connects to the master, and blocks until all clients have successfully
  // connected.
  // Not thread-safe, i.e., calls to Connect()/Shutdown()/EnumerateDevices()
  // must be serialized by some other means.
  virtual xla::Status Connect() = 0;

  // Reports to the master that the client is ready to shut down, and blocks
  // until all clients are ready to shut down or the shutdown timeout expires.
  // Not thread-safe.
  virtual xla::Status Shutdown() = 0;

  // Blocking enumeration of global devices. Used by the GPU platform.
  // Not thread-safe.
  virtual xla::Status EnumerateDevices(
      const LocalTopologyProto& local_topology,
      GlobalTopologyProto* global_topology) = 0;

  // The following APIs are thread-safe.
  virtual xla::StatusOr<std::string> BlockingKeyValueGet(
      std::string key, absl::Duration timeout) = 0;

  virtual xla::Status KeyValueSet(std::string key, std::string value) = 0;

  // Blocks until all nodes are at the barrier or the barrier times out.
  // `barrier_id` should be unique across barriers.
  virtual xla::Status WaitAtBarrier(std::string barrier_id,
                                    absl::Duration timeout) = 0;

  // Returns a pointer to the coordination service agent, or an InternalError
  // if the client does not use the coordination service.
  virtual StatusOr<tensorflow::CoordinationServiceAgent*>
  GetCoordinationServiceAgent() = 0;
};

// Creates a distributed runtime client.
std::unique_ptr<DistributedRuntimeClient> GetDistributedRuntimeClient(
    std::shared_ptr<::grpc::Channel> channel,
    const DistributedRuntimeClient::Options& options,
    bool use_coordination_service);
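
// A minimal usage sketch, kept in a comment so this header stays
// declaration-only. It assumes a coordinator service is already listening at
// `coordinator_address` and that `my_address` holds this task's address; both
// names are hypothetical placeholders, not part of this API. TF_CHECK_OK is
// used only to keep the sketch short; real callers should propagate the
// returned xla::Status instead.
//
//   DistributedRuntimeClient::Options options;
//   options.node_id = 0;
//   // Tests may swap in a non-fatal heartbeat callback:
//   options.missed_heartbeat_callback =
//       [](xla::Status status, bool coordinator_reported_failure) {
//         LOG(ERROR) << "Missed heartbeat: " << status;
//       };
//   auto channel = ::grpc::CreateChannel(
//       coordinator_address, ::grpc::InsecureChannelCredentials());
//   auto client = GetDistributedRuntimeClient(
//       channel, options, /*use_coordination_service=*/false);
//   TF_CHECK_OK(client->Connect());
//   TF_CHECK_OK(client->KeyValueSet("node/0/address", my_address));
//   TF_CHECK_OK(client->WaitAtBarrier("init_done", absl::Seconds(60)));
//   // ... run the distributed program ...
//   TF_CHECK_OK(client->Shutdown());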

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_CLIENT_H_