1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/rpc/coordination/grpc_coordination_client.h"
17 
18 #include <string>
19 #include <utility>
20 
21 #include "tensorflow/core/distributed_runtime/coordination/coordination_client.h"
22 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
23 #include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
24 #include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
25 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
26 #include "tensorflow/core/platform/thread_annotations.h"
27 #include "tensorflow/core/protobuf/coordination_service.pb.h"
28 
29 namespace tensorflow {
30 namespace {
31 
32 class GrpcCoordinationClientThread {
33  public:
GrpcCoordinationClientThread()34   GrpcCoordinationClientThread() {
35     thread_.reset(Env::Default()->StartThread(
36         ThreadOptions(), "coordination_client_thread", [this]() {
37           void* tag;
38           bool ok;
39           while (completion_queue_.Next(&tag, &ok)) {
40             VLOG(4) << "GrpcCoordinationClientThread got next tag";
41             GrpcClientCQTag* callback_tag = static_cast<GrpcClientCQTag*>(tag);
42             callback_tag->OnCompleted(ok);
43             VLOG(4) << "GrpcCoordinationClientThread blocking for next tag";
44           }
45           VLOG(4) << "GrpcCoordinationClientThread exiting";
46         }));
47   }
48 
~GrpcCoordinationClientThread()49   ~GrpcCoordinationClientThread() {
50     completion_queue_.Shutdown();
51     thread_.reset();
52   }
53 
completion_queue()54   ::grpc::CompletionQueue* completion_queue() { return &completion_queue_; }
55 
56  private:
57   ::grpc::CompletionQueue completion_queue_;
58   std::unique_ptr<Thread> thread_;
59 };
60 
61 class GrpcCoordinationClient : public CoordinationClient {
62  public:
GrpcCoordinationClient(SharedGrpcChannelPtr channel,::grpc::CompletionQueue * cq,const std::string & target)63   GrpcCoordinationClient(SharedGrpcChannelPtr channel,
64                          ::grpc::CompletionQueue* cq, const std::string& target)
65       : stub_(channel), cq_(cq), target_(target) {}
GrpcCoordinationClient(SharedGrpcChannelPtr channel,const std::string & target)66   GrpcCoordinationClient(SharedGrpcChannelPtr channel,
67                          const std::string& target)
68       : stub_(channel), target_(target) {
69     client_thread_ = std::make_unique<GrpcCoordinationClientThread>();
70     cq_ = client_thread_->completion_queue();
71   }
~GrpcCoordinationClient()72   ~GrpcCoordinationClient() override {}
73 
RegisterTaskAsync(CallOptions * call_opts,const RegisterTaskRequest * request,RegisterTaskResponse * response,StatusCallback done)74   void RegisterTaskAsync(CallOptions* call_opts,
75                          const RegisterTaskRequest* request,
76                          RegisterTaskResponse* response,
77                          StatusCallback done) override {
78     new RPCState<protobuf::Message>(
79         &stub_, cq_, "/tensorflow.CoordinationService/RegisterTask", *request,
80         response, std::move(done), call_opts,
81         /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/false,
82         &target_);
83   }
84 
WaitForAllTasksAsync(const WaitForAllTasksRequest * request,WaitForAllTasksResponse * response,StatusCallback done)85   void WaitForAllTasksAsync(const WaitForAllTasksRequest* request,
86                             WaitForAllTasksResponse* response,
87                             StatusCallback done) override {
88     new RPCState<protobuf::Message>(
89         &stub_, cq_, "/tensorflow.CoordinationService/WaitForAllTasks",
90         *request, response, std::move(done), /*call_opts=*/nullptr,
91         /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
92         &target_);
93   }
94 
ShutdownTaskAsync(CallOptions * call_opts,const ShutdownTaskRequest * request,ShutdownTaskResponse * response,StatusCallback done)95   void ShutdownTaskAsync(CallOptions* call_opts,
96                          const ShutdownTaskRequest* request,
97                          ShutdownTaskResponse* response,
98                          StatusCallback done) override {
99     new RPCState<protobuf::Message>(
100         &stub_, cq_, "/tensorflow.CoordinationService/ShutdownTask", *request,
101         response, std::move(done), call_opts,
102         /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
103         &target_);
104   }
105 
ResetTaskAsync(const ResetTaskRequest * request,ResetTaskResponse * response,StatusCallback done)106   void ResetTaskAsync(const ResetTaskRequest* request,
107                       ResetTaskResponse* response,
108                       StatusCallback done) override {
109     new RPCState<protobuf::Message>(
110         &stub_, cq_, "/tensorflow.CoordinationService/ResetTask", *request,
111         response, std::move(done), /*call_opts=*/nullptr,
112         /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
113         &target_);
114   }
115 
HeartbeatAsync(CallOptions * call_opts,const HeartbeatRequest * request,HeartbeatResponse * response,StatusCallback done)116   void HeartbeatAsync(CallOptions* call_opts, const HeartbeatRequest* request,
117                       HeartbeatResponse* response,
118                       StatusCallback done) override {
119     // Different from other RPCs which do not retry by default, the Heartbeat
120     // RPC should retry automatically to tolerate transient network issues.
121     new RPCState<protobuf::Message>(
122         &stub_, cq_, "/tensorflow.CoordinationService/Heartbeat", *request,
123         response, std::move(done), call_opts, /*threadpool=*/nullptr,
124         /*max_retries=*/3,
125         /*fail_fast=*/true, &target_);
126   }
127 
ReportErrorToTaskAsync(CallOptions * call_opts,const ReportErrorToTaskRequest * request,ReportErrorToTaskResponse * response,StatusCallback done)128   void ReportErrorToTaskAsync(CallOptions* call_opts,
129                               const ReportErrorToTaskRequest* request,
130                               ReportErrorToTaskResponse* response,
131                               StatusCallback done) override {
132     new RPCState<protobuf::Message>(
133         &stub_, cq_, "/tensorflow.CoordinationService/ReportErrorToTask",
134         *request, response, std::move(done), call_opts,
135         /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
136         &target_);
137   }
138 
ReportErrorToServiceAsync(const ReportErrorToServiceRequest * request,ReportErrorToServiceResponse * response,StatusCallback done)139   void ReportErrorToServiceAsync(const ReportErrorToServiceRequest* request,
140                                  ReportErrorToServiceResponse* response,
141                                  StatusCallback done) override {
142     new RPCState<protobuf::Message>(
143         &stub_, cq_, "/tensorflow.CoordinationService/ReportErrorToService",
144         *request, response, std::move(done), /*call_opts=*/nullptr,
145         /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
146         &target_);
147   }
148 
InsertKeyValueAsync(const InsertKeyValueRequest * request,InsertKeyValueResponse * response,StatusCallback done)149   void InsertKeyValueAsync(const InsertKeyValueRequest* request,
150                            InsertKeyValueResponse* response,
151                            StatusCallback done) override {
152     new RPCState<protobuf::Message>(
153         &stub_, cq_, "/tensorflow.CoordinationService/InsertKeyValue", *request,
154         response, std::move(done), /*call_opts=*/nullptr,
155         /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
156         &target_);
157   }
158 
GetKeyValueAsync(CallOptions * call_opts,const GetKeyValueRequest * request,GetKeyValueResponse * response,StatusCallback done)159   void GetKeyValueAsync(CallOptions* call_opts,
160                         const GetKeyValueRequest* request,
161                         GetKeyValueResponse* response,
162                         StatusCallback done) override {
163     new RPCState<protobuf::Message>(
164         &stub_, cq_, "/tensorflow.CoordinationService/GetKeyValue", *request,
165         response, std::move(done), call_opts,
166         /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
167         &target_);
168   }
169 
TryGetKeyValueAsync(const TryGetKeyValueRequest * request,TryGetKeyValueResponse * response,StatusCallback done)170   void TryGetKeyValueAsync(const TryGetKeyValueRequest* request,
171                            TryGetKeyValueResponse* response,
172                            StatusCallback done) override {
173     new RPCState<protobuf::Message>(
174         &stub_, cq_, "/tensorflow.CoordinationService/TryGetKeyValue", *request,
175         response, std::move(done), /*call_opts=*/nullptr,
176         /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
177         &target_);
178   }
179 
GetKeyValueDirAsync(const GetKeyValueDirRequest * request,GetKeyValueDirResponse * response,StatusCallback done)180   void GetKeyValueDirAsync(const GetKeyValueDirRequest* request,
181                            GetKeyValueDirResponse* response,
182                            StatusCallback done) override {
183     new RPCState<protobuf::Message>(
184         &stub_, cq_, "/tensorflow.CoordinationService/GetKeyValueDir", *request,
185         response, std::move(done), /*call_opts=*/nullptr,
186         /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
187         &target_);
188   }
189 
DeleteKeyValueAsync(const DeleteKeyValueRequest * request,DeleteKeyValueResponse * response,StatusCallback done)190   void DeleteKeyValueAsync(const DeleteKeyValueRequest* request,
191                            DeleteKeyValueResponse* response,
192                            StatusCallback done) override {
193     new RPCState<protobuf::Message>(
194         &stub_, cq_, "/tensorflow.CoordinationService/DeleteKeyValue", *request,
195         response, std::move(done), /*call_opts=*/nullptr,
196         /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
197         &target_);
198   }
199 
BarrierAsync(const BarrierRequest * request,BarrierResponse * response,StatusCallback done)200   void BarrierAsync(const BarrierRequest* request, BarrierResponse* response,
201                     StatusCallback done) override {
202     new RPCState<protobuf::Message>(
203         &stub_, cq_, "/tensorflow.CoordinationService/Barrier", *request,
204         response, std::move(done), /*call_opts=*/nullptr,
205         /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
206         &target_);
207   }
208 
CancelBarrierAsync(const CancelBarrierRequest * request,CancelBarrierResponse * response,StatusCallback done)209   void CancelBarrierAsync(const CancelBarrierRequest* request,
210                           CancelBarrierResponse* response,
211                           StatusCallback done) override {
212     new RPCState<protobuf::Message>(
213         &stub_, cq_, "/tensorflow.CoordinationService/CancelBarrier", *request,
214         response, std::move(done), /*call_opts=*/nullptr,
215         /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
216         &target_);
217   }
218 
219  private:
220   ::grpc::GenericStub stub_;
221   ::grpc::CompletionQueue* cq_;
222   const string target_;
223   std::unique_ptr<GrpcCoordinationClientThread> client_thread_;
224 };
225 
226 class GrpcCoordinationClientCache : public CoordinationClientCache {
227  public:
GrpcCoordinationClientCache(std::shared_ptr<GrpcChannelCache> channel_cache)228   explicit GrpcCoordinationClientCache(
229       std::shared_ptr<GrpcChannelCache> channel_cache)
230       : next_round_robin_assignment_(0),
231         channel_cache_(channel_cache),
232         threads_(4) {}
233 
~GrpcCoordinationClientCache()234   ~GrpcCoordinationClientCache() override {}
235 
GetClient(const string & target)236   CoordinationClient* GetClient(const string& target) override {
237     mutex_lock l(clients_mu_);
238     auto it = clients_.find(target);
239     if (it == clients_.end()) {
240       SharedGrpcChannelPtr channel = channel_cache_->FindWorkerChannel(target);
241       if (channel == nullptr) {
242         VLOG(2) << "Coordination client for target " << target << " not found.";
243       }
244       int assigned_index = AssignClientToThread(target);
245       auto coord_client = std::make_unique<GrpcCoordinationClient>(
246           channel, threads_[assigned_index].completion_queue(), target);
247       it = clients_.emplace(target, std::move(coord_client)).first;
248     }
249     return it->second.get();
250   }
251 
GetOwnedClient(const string & target)252   std::unique_ptr<CoordinationClient> GetOwnedClient(
253       const string& target) override {
254     SharedGrpcChannelPtr channel = channel_cache_->FindWorkerChannel(target);
255     if (channel == nullptr) {
256       VLOG(2) << "Coordination client for target " << target << " not found.";
257     }
258     return std::make_unique<GrpcCoordinationClient>(channel, target);
259   }
260 
261  private:
262   mutex assignment_mu_;
263   std::unordered_map<std::string, size_t> target_assignments_
264       TF_GUARDED_BY(assignment_mu_);
265   size_t next_round_robin_assignment_ TF_GUARDED_BY(assignment_mu_);
266 
AssignClientToThread(const string & target)267   size_t AssignClientToThread(const string& target) {
268     // Round-robin target assignment, but keeps the same target on the same
269     // polling thread always, as this is important for gRPC performance
270     mutex_lock lock(assignment_mu_);
271     auto it = target_assignments_.find(target);
272     if (it == target_assignments_.end()) {
273       it = target_assignments_
274                .insert(std::make_pair(
275                    target, (next_round_robin_assignment_++) % threads_.size()))
276                .first;
277     }
278     return it->second;
279   }
280 
281   std::shared_ptr<GrpcChannelCache> channel_cache_;
282   mutable mutex clients_mu_;
283   std::unordered_map<std::string, std::unique_ptr<CoordinationClient>> clients_
284       TF_GUARDED_BY(clients_mu_);
285   std::vector<GrpcCoordinationClientThread> threads_;
286 };
287 
288 }  // namespace
289 
NewGrpcCoordinationClientCache(std::shared_ptr<GrpcChannelCache> channel_cache)290 CoordinationClientCache* NewGrpcCoordinationClientCache(
291     std::shared_ptr<GrpcChannelCache> channel_cache) {
292   return new GrpcCoordinationClientCache(channel_cache);
293 }
294 
NewGrpcCoordinationClient(std::shared_ptr<::grpc::Channel> channel)295 CoordinationClient* NewGrpcCoordinationClient(
296     std::shared_ptr<::grpc::Channel> channel) {
297   // TODO(hanyangtay): Pass in the logical task name for better logging.
298   return new GrpcCoordinationClient(
299       channel, /*target=*/"unknown_target_for_coordination_leader");
300 }
301 
302 }  // namespace tensorflow
303