/* Copyright 2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_CLIENT_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_CLIENT_H_

#include <functional>
#include <memory>
#include <string>

#include "absl/time/time.h"
#include "grpcpp/channel.h"
#include "tensorflow/compiler/xla/pjrt/distributed/protocol.grpc.pb.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/env.h"

namespace tensorflow {
class CoordinationServiceAgent;
}  // namespace tensorflow

namespace xla {

class DistributedRuntimeClient {
 public:
  struct Options {
    // This node's global ID. Required.
    int32_t node_id = -1;

    // Environment used for starting threads.
    tensorflow::Env* env = tensorflow::Env::Default();
    // RPC timeout used for RPCs that don't have their own timeouts.
    absl::Duration rpc_timeout = absl::Seconds(120);

    // Time period for which Connect() should be retried. The client will keep
    // trying to open the initial connection for this period, even if any
    // individual Connect() RPC fails. May be zero, in which case Connect() will
    // only be attempted once.
    absl::Duration init_timeout = absl::ZeroDuration();

    // How long to wait for all nodes to call Shutdown(). If the timeout
    // expires, then Shutdown() reports an error and returns control.
    absl::Duration shutdown_timeout = absl::Seconds(60);

    // Interval at which the client should send heartbeat RPCs to the
    // coordinator.
    absl::Duration heartbeat_interval = absl::Seconds(10);

    // How many heartbeat RPCs may fail, due to a possibly-ephemeral reason,
    // before we decide the coordinator has vanished and that we should shut
    // down.
    int max_missing_heartbeats = 10;

    // Callback invoked by the client when the coordinator reports a missing
    // heartbeat, or when we have not heard from the coordinator recently.
    // `coordinator_reported_failure` is true in the former case.
    // Exposed so tests can override this behavior to something non-fatal.
    std::function<void(xla::Status, bool coordinator_reported_failure)>
        missed_heartbeat_callback =
            [](xla::Status status, bool coordinator_reported_failure) {
              if (coordinator_reported_failure) {
                LOG(QFATAL)
                    << "Terminating process because the coordinator detected "
                       "missing heartbeats. This most likely indicates that "
                       "another task died; see the other task logs for more "
                       "details. Status: "
                    << status;
              } else {
                LOG(QFATAL)
                    << "Terminating process because of missing heartbeat "
                       "response from the coordinator. This most likely "
                       "indicates that the coordinator task died; see the "
                       "coordinator's task logs for more details. Status: "
                    << status;
              }
            };

    // For testing. Should the client explicitly Shutdown() on destruction?
    bool shutdown_on_destruction = true;
  };
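
  // An illustrative sketch of populating Options; the values below are
  // hypothetical examples, not recommended settings:
  //
  //   DistributedRuntimeClient::Options options;
  //   options.node_id = 0;                      // Required; unique per task.
  //   options.init_timeout = absl::Minutes(5);  // Retry Connect() for 5 min.
  //   // Tests may replace the fatal default callback, e.g.:
  //   options.missed_heartbeat_callback =
  //       [](xla::Status status, bool coordinator_reported_failure) {
  //         LOG(ERROR) << "Missed heartbeat: " << status;
  //       };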

  virtual ~DistributedRuntimeClient() {}

  // Connects to the master, and blocks until all clients have successfully
  // connected.
  // Not thread-safe, i.e., calls to Connect()/Shutdown()/EnumerateDevices()
  // must be serialized by some other means.
  virtual xla::Status Connect() = 0;

  // Reports to the master that the client is ready to shut down, and blocks
  // until all clients are ready to shut down or the shutdown timeout expires.
  // Not thread-safe.
  virtual xla::Status Shutdown() = 0;

  // Blocking enumeration of global devices. Used by the GPU platform.
  // Not thread-safe.
  virtual xla::Status EnumerateDevices(
      const LocalTopologyProto& local_topology,
      GlobalTopologyProto* global_topology) = 0;

  // The following APIs are thread-safe.
  virtual xla::StatusOr<std::string> BlockingKeyValueGet(
      std::string key, absl::Duration timeout) = 0;

  virtual xla::Status KeyValueSet(std::string key, std::string value) = 0;
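
  // A minimal key-value usage sketch; the key, value, and timeout below are
  // hypothetical. One node publishes a value and another blocks until it is
  // available:
  //
  //   // On node 0:
  //   TF_RETURN_IF_ERROR(client->KeyValueSet("node:0/address", "10.0.0.1"));
  //
  //   // On node 1:
  //   TF_ASSIGN_OR_RETURN(
  //       std::string address,
  //       client->BlockingKeyValueGet("node:0/address", absl::Minutes(1)));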

  // Blocks until all nodes are at the barrier or the barrier times out.
  // `barrier_id` should be unique across barriers.
  virtual xla::Status WaitAtBarrier(std::string barrier_id,
                                    absl::Duration timeout) = 0;
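
  // An illustrative sketch; the barrier name and timeout are hypothetical.
  // Every node must reach the barrier for it to be released:
  //
  //   TF_RETURN_IF_ERROR(
  //       client->WaitAtBarrier("init_done", absl::Seconds(30)));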

  // Returns a pointer to the coordination service agent, or an InternalError
  // if the client does not use the coordination service.
  virtual StatusOr<tensorflow::CoordinationServiceAgent*>
  GetCoordinationServiceAgent() = 0;
};

// Creates a distributed runtime client.
std::unique_ptr<DistributedRuntimeClient> GetDistributedRuntimeClient(
    std::shared_ptr<::grpc::Channel> channel,
    const DistributedRuntimeClient::Options& options,
    bool use_coordination_service);
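
// A sketch of the typical client lifecycle, assuming an insecure gRPC channel
// to a hypothetical coordinator address (real deployments should pick
// appropriate channel credentials):
//
//   auto channel = ::grpc::CreateChannel(
//       "coordinator:1234", ::grpc::InsecureChannelCredentials());
//   DistributedRuntimeClient::Options options;
//   options.node_id = 0;
//   auto client = GetDistributedRuntimeClient(
//       channel, options, /*use_coordination_service=*/true);
//   TF_RETURN_IF_ERROR(client->Connect());
//   // ... use the client, e.g. EnumerateDevices() or the key-value store ...
//   TF_RETURN_IF_ERROR(client->Shutdown());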

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_CLIENT_H_