/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/rpc/coordination/grpc_coordination_client.h"

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/distributed_runtime/coordination/coordination_client.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/protobuf/coordination_service.pb.h"

namespace tensorflow {
namespace {

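// Owns a gRPC completion queue and the thread that polls it. The polling loop
// pulls finished tags off the queue and runs their completion callbacks; the
// destructor shuts the queue down, which lets the loop (and the thread) exit.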
class GrpcCoordinationClientThread {
 public:
  GrpcCoordinationClientThread() {
    thread_.reset(Env::Default()->StartThread(
        ThreadOptions(), "coordination_client_thread", [this]() {
          void* tag;
          bool ok;
          while (completion_queue_.Next(&tag, &ok)) {
            VLOG(4) << "GrpcCoordinationClientThread got next tag";
            GrpcClientCQTag* callback_tag = static_cast<GrpcClientCQTag*>(tag);
            callback_tag->OnCompleted(ok);
            VLOG(4) << "GrpcCoordinationClientThread blocking for next tag";
          }
          VLOG(4) << "GrpcCoordinationClientThread exiting";
        }));
  }

  ~GrpcCoordinationClientThread() {
    completion_queue_.Shutdown();
    thread_.reset();
  }

  ::grpc::CompletionQueue* completion_queue() { return &completion_queue_; }

 private:
  ::grpc::CompletionQueue completion_queue_;
  std::unique_ptr<Thread> thread_;
};

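// CoordinationClient implementation that issues CoordinationService RPCs over
// a generic gRPC stub. The three-argument constructor shares a caller-provided
// completion queue; the two-argument constructor owns a dedicated
// GrpcCoordinationClientThread instead. Each *Async method hands the call to a
// heap-allocated RPCState, which invokes `done` and cleans itself up once the
// RPC finishes.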
class GrpcCoordinationClient : public CoordinationClient {
 public:
  GrpcCoordinationClient(SharedGrpcChannelPtr channel,
                         ::grpc::CompletionQueue* cq, const std::string& target)
      : stub_(channel), cq_(cq), target_(target) {}
  GrpcCoordinationClient(SharedGrpcChannelPtr channel,
                         const std::string& target)
      : stub_(channel), target_(target) {
    client_thread_ = std::make_unique<GrpcCoordinationClientThread>();
    cq_ = client_thread_->completion_queue();
  }
  ~GrpcCoordinationClient() override {}

  void RegisterTaskAsync(CallOptions* call_opts,
                         const RegisterTaskRequest* request,
                         RegisterTaskResponse* response,
                         StatusCallback done) override {
    new RPCState<protobuf::Message>(
        &stub_, cq_, "/tensorflow.CoordinationService/RegisterTask", *request,
        response, std::move(done), call_opts,
        /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/false,
        &target_);
  }

  void WaitForAllTasksAsync(const WaitForAllTasksRequest* request,
                            WaitForAllTasksResponse* response,
                            StatusCallback done) override {
    new RPCState<protobuf::Message>(
        &stub_, cq_, "/tensorflow.CoordinationService/WaitForAllTasks",
        *request, response, std::move(done), /*call_opts=*/nullptr,
        /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
        &target_);
  }

  void ShutdownTaskAsync(CallOptions* call_opts,
                         const ShutdownTaskRequest* request,
                         ShutdownTaskResponse* response,
                         StatusCallback done) override {
    new RPCState<protobuf::Message>(
        &stub_, cq_, "/tensorflow.CoordinationService/ShutdownTask", *request,
        response, std::move(done), call_opts,
        /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
        &target_);
  }

  void ResetTaskAsync(const ResetTaskRequest* request,
                      ResetTaskResponse* response,
                      StatusCallback done) override {
    new RPCState<protobuf::Message>(
        &stub_, cq_, "/tensorflow.CoordinationService/ResetTask", *request,
        response, std::move(done), /*call_opts=*/nullptr,
        /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
        &target_);
  }

  void HeartbeatAsync(CallOptions* call_opts, const HeartbeatRequest* request,
                      HeartbeatResponse* response,
                      StatusCallback done) override {
    // Different from other RPCs which do not retry by default, the Heartbeat
    // RPC should retry automatically to tolerate transient network issues.
    new RPCState<protobuf::Message>(
        &stub_, cq_, "/tensorflow.CoordinationService/Heartbeat", *request,
        response, std::move(done), call_opts, /*threadpool=*/nullptr,
        /*max_retries=*/3,
        /*fail_fast=*/true, &target_);
  }

  void ReportErrorToTaskAsync(CallOptions* call_opts,
                              const ReportErrorToTaskRequest* request,
                              ReportErrorToTaskResponse* response,
                              StatusCallback done) override {
    new RPCState<protobuf::Message>(
        &stub_, cq_, "/tensorflow.CoordinationService/ReportErrorToTask",
        *request, response, std::move(done), call_opts,
        /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
        &target_);
  }

  void ReportErrorToServiceAsync(const ReportErrorToServiceRequest* request,
                                 ReportErrorToServiceResponse* response,
                                 StatusCallback done) override {
    new RPCState<protobuf::Message>(
        &stub_, cq_, "/tensorflow.CoordinationService/ReportErrorToService",
        *request, response, std::move(done), /*call_opts=*/nullptr,
        /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
        &target_);
  }

  void InsertKeyValueAsync(const InsertKeyValueRequest* request,
                           InsertKeyValueResponse* response,
                           StatusCallback done) override {
    new RPCState<protobuf::Message>(
        &stub_, cq_, "/tensorflow.CoordinationService/InsertKeyValue", *request,
        response, std::move(done), /*call_opts=*/nullptr,
        /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
        &target_);
  }

  void GetKeyValueAsync(CallOptions* call_opts,
                        const GetKeyValueRequest* request,
                        GetKeyValueResponse* response,
                        StatusCallback done) override {
    new RPCState<protobuf::Message>(
        &stub_, cq_, "/tensorflow.CoordinationService/GetKeyValue", *request,
        response, std::move(done), call_opts,
        /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
        &target_);
  }

  void TryGetKeyValueAsync(const TryGetKeyValueRequest* request,
                           TryGetKeyValueResponse* response,
                           StatusCallback done) override {
    new RPCState<protobuf::Message>(
        &stub_, cq_, "/tensorflow.CoordinationService/TryGetKeyValue", *request,
        response, std::move(done), /*call_opts=*/nullptr,
        /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
        &target_);
  }

  void GetKeyValueDirAsync(const GetKeyValueDirRequest* request,
                           GetKeyValueDirResponse* response,
                           StatusCallback done) override {
    new RPCState<protobuf::Message>(
        &stub_, cq_, "/tensorflow.CoordinationService/GetKeyValueDir", *request,
        response, std::move(done), /*call_opts=*/nullptr,
        /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
        &target_);
  }

  void DeleteKeyValueAsync(const DeleteKeyValueRequest* request,
                           DeleteKeyValueResponse* response,
                           StatusCallback done) override {
    new RPCState<protobuf::Message>(
        &stub_, cq_, "/tensorflow.CoordinationService/DeleteKeyValue", *request,
        response, std::move(done), /*call_opts=*/nullptr,
        /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
        &target_);
  }

  void BarrierAsync(const BarrierRequest* request, BarrierResponse* response,
                    StatusCallback done) override {
    new RPCState<protobuf::Message>(
        &stub_, cq_, "/tensorflow.CoordinationService/Barrier", *request,
        response, std::move(done), /*call_opts=*/nullptr,
        /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
        &target_);
  }

  void CancelBarrierAsync(const CancelBarrierRequest* request,
                          CancelBarrierResponse* response,
                          StatusCallback done) override {
    new RPCState<protobuf::Message>(
        &stub_, cq_, "/tensorflow.CoordinationService/CancelBarrier", *request,
        response, std::move(done), /*call_opts=*/nullptr,
        /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
        &target_);
  }

 private:
  ::grpc::GenericStub stub_;
  ::grpc::CompletionQueue* cq_;
  const string target_;
  std::unique_ptr<GrpcCoordinationClientThread> client_thread_;
};

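// Cache that hands out one coordination client per target. Clients obtained
// via GetClient() share a small fixed pool of polling threads (assigned
// round-robin per target), while GetOwnedClient() returns a client that runs
// its own completion queue thread.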
class GrpcCoordinationClientCache : public CoordinationClientCache {
 public:
  explicit GrpcCoordinationClientCache(
      std::shared_ptr<GrpcChannelCache> channel_cache)
      : next_round_robin_assignment_(0),
        channel_cache_(channel_cache),
        threads_(4) {}

  ~GrpcCoordinationClientCache() override {}

  CoordinationClient* GetClient(const string& target) override {
    mutex_lock l(clients_mu_);
    auto it = clients_.find(target);
    if (it == clients_.end()) {
      SharedGrpcChannelPtr channel = channel_cache_->FindWorkerChannel(target);
      if (channel == nullptr) {
        VLOG(2) << "Coordination client for target " << target << " not found.";
      }
      int assigned_index = AssignClientToThread(target);
      auto coord_client = std::make_unique<GrpcCoordinationClient>(
          channel, threads_[assigned_index].completion_queue(), target);
      it = clients_.emplace(target, std::move(coord_client)).first;
    }
    return it->second.get();
  }

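  // Unlike GetClient(), the returned client owns a dedicated completion queue
  // thread rather than sharing one of the cache's polling threads, and is not
  // stored in `clients_`.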
  std::unique_ptr<CoordinationClient> GetOwnedClient(
      const string& target) override {
    SharedGrpcChannelPtr channel = channel_cache_->FindWorkerChannel(target);
    if (channel == nullptr) {
      VLOG(2) << "Coordination client for target " << target << " not found.";
    }
    return std::make_unique<GrpcCoordinationClient>(channel, target);
  }

 private:
  mutex assignment_mu_;
  std::unordered_map<std::string, size_t> target_assignments_
      TF_GUARDED_BY(assignment_mu_);
  size_t next_round_robin_assignment_ TF_GUARDED_BY(assignment_mu_);

  size_t AssignClientToThread(const string& target) {
    // Round-robin target assignment, but keeps the same target on the same
    // polling thread always, as this is important for gRPC performance
    mutex_lock lock(assignment_mu_);
    auto it = target_assignments_.find(target);
    if (it == target_assignments_.end()) {
      it = target_assignments_
               .insert(std::make_pair(
                   target, (next_round_robin_assignment_++) % threads_.size()))
               .first;
    }
    return it->second;
  }

  std::shared_ptr<GrpcChannelCache> channel_cache_;
  mutable mutex clients_mu_;
  std::unordered_map<std::string, std::unique_ptr<CoordinationClient>> clients_
      TF_GUARDED_BY(clients_mu_);
  std::vector<GrpcCoordinationClientThread> threads_;
};

}  // namespace

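// Factory functions; callers take ownership of the returned raw pointers.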
CoordinationClientCache* NewGrpcCoordinationClientCache(
    std::shared_ptr<GrpcChannelCache> channel_cache) {
  return new GrpcCoordinationClientCache(channel_cache);
}

CoordinationClient* NewGrpcCoordinationClient(
    std::shared_ptr<::grpc::Channel> channel) {
  // TODO(hanyangtay): Pass in the logical task name for better logging.
  return new GrpcCoordinationClient(
      channel, /*target=*/"unknown_target_for_coordination_leader");
}

}  // namespace tensorflow