1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/coordination/coordination_service_rpc_handler.h"
17 
18 #include <string>
19 #include <utility>
20 
21 #include "absl/time/time.h"
22 #include "tensorflow/core/distributed_runtime/coordination/coordination_service.h"
23 #include "tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h"
24 #include "tensorflow/core/distributed_runtime/coordination/coordination_service_error_util.h"
25 #include "tensorflow/core/platform/casts.h"
26 #include "tensorflow/core/platform/errors.h"
27 #include "tensorflow/core/platform/mutex.h"
28 #include "tensorflow/core/protobuf/coordination_service.pb.h"
29 
30 namespace tensorflow {
31 
SetAgentInstance(CoordinationServiceAgent * agent)32 void CoordinationServiceRpcHandler::SetAgentInstance(
33     CoordinationServiceAgent* agent) {
34   mutex_lock l(agent_mu_);
35   agent_ = agent;
36 }
37 
RegisterTaskAsync(const RegisterTaskRequest * request,RegisterTaskResponse * response,StatusCallback done)38 void CoordinationServiceRpcHandler::RegisterTaskAsync(
39     const RegisterTaskRequest* request, RegisterTaskResponse* response,
40     StatusCallback done) {
41   CoordinationServiceInterface* service =
42       CoordinationServiceInterface::GetCoordinationServiceInstance();
43   if (service == nullptr) {
44     done(MakeCoordinationError(
45         errors::Internal("Coordination service is not enabled.")));
46     return;
47   }
48   const CoordinatedTask& task = request->source_task();
49   const uint64_t incarnation = request->incarnation();
50   const uint64_t leader_incarnation = service->GetServiceIncarnation();
51   response->set_leader_incarnation(leader_incarnation);
52   done(service->RegisterTask(task, incarnation));
53 }
54 
HeartbeatAsync(const HeartbeatRequest * request,HeartbeatResponse * response,StatusCallback done)55 void CoordinationServiceRpcHandler::HeartbeatAsync(
56     const HeartbeatRequest* request, HeartbeatResponse* response,
57     StatusCallback done) {
58   CoordinationServiceInterface* service =
59       CoordinationServiceInterface::GetCoordinationServiceInstance();
60   if (service == nullptr) {
61     done(MakeCoordinationError(
62         errors::Internal("Coordination service is not enabled.")));
63     return;
64   }
65   const CoordinatedTask& task = request->source_task();
66   const uint64_t incarnation = request->incarnation();
67   const uint64_t leader_incarnation = service->GetServiceIncarnation();
68   Status s = service->RecordHeartbeat(task, incarnation);
69   if (!s.ok()) {
70     done(s);
71     return;
72   }
73   response->set_leader_incarnation(leader_incarnation);
74   done(OkStatus());
75 }
76 
WaitForAllTasksAsync(const WaitForAllTasksRequest * request,WaitForAllTasksResponse * response,StatusCallback done)77 void CoordinationServiceRpcHandler::WaitForAllTasksAsync(
78     const WaitForAllTasksRequest* request, WaitForAllTasksResponse* response,
79     StatusCallback done) {
80   CoordinationServiceInterface* service =
81       CoordinationServiceInterface::GetCoordinationServiceInstance();
82   if (service == nullptr) {
83     done(MakeCoordinationError(
84         errors::Internal("Coordination service is not enabled.")));
85     return;
86   }
87   service->WaitForAllTasks(
88       request->source_task(), request->local_device_info(),
89       [response, service, done = std::move(done)](Status s) {
90         if (s.ok()) {
91           *response->mutable_cluster_device_info() =
92               service->ListClusterDevices();
93         }
94         done(s);
95       });
96 }
97 
ShutdownTaskAsync(const ShutdownTaskRequest * request,ShutdownTaskResponse * response,StatusCallback done)98 void CoordinationServiceRpcHandler::ShutdownTaskAsync(
99     const ShutdownTaskRequest* request, ShutdownTaskResponse* response,
100     StatusCallback done) {
101   CoordinationServiceInterface* service =
102       CoordinationServiceInterface::GetCoordinationServiceInstance();
103   if (service == nullptr) {
104     done(MakeCoordinationError(
105         errors::Internal("Coordination service is not enabled.")));
106     return;
107   }
108   service->ShutdownTaskAsync(request->source_task(),
109                              [done](Status s) { done(s); });
110 }
111 
ResetTaskAsync(const ResetTaskRequest * request,ResetTaskResponse * response,StatusCallback done)112 void CoordinationServiceRpcHandler::ResetTaskAsync(
113     const ResetTaskRequest* request, ResetTaskResponse* response,
114     StatusCallback done) {
115   CoordinationServiceInterface* service =
116       CoordinationServiceInterface::GetCoordinationServiceInstance();
117   if (service == nullptr) {
118     done(MakeCoordinationError(
119         errors::Internal("Coordination service is not enabled.")));
120     return;
121   }
122   done(service->ResetTask(request->source_task()));
123 }
124 
ReportErrorToTaskAsync(const ReportErrorToTaskRequest * request,ReportErrorToTaskResponse * response,StatusCallback done)125 void CoordinationServiceRpcHandler::ReportErrorToTaskAsync(
126     const ReportErrorToTaskRequest* request,
127     ReportErrorToTaskResponse* response, StatusCallback done) {
128   tf_shared_lock l(agent_mu_);
129   if (agent_ == nullptr) {
130     done(MakeCoordinationError(errors::Internal(
131         "CoordinationServiceAgent is uninitialized or has already shutdown.")));
132     return;
133   }
134   const CoordinationServiceError& error_payload = request->error_payload();
135   Status error(static_cast<error::Code>(request->error_code()),
136                strings::StrCat("Error reported from /job:",
137                                error_payload.source_task().job_name(),
138                                "/task:", error_payload.source_task().task_id(),
139                                ": ", request->error_message()));
140   error = MakeCoordinationError(error, error_payload);
141   agent_->SetError(error);
142   done(OkStatus());
143 }
144 
ReportErrorToServiceAsync(const ReportErrorToServiceRequest * request,ReportErrorToServiceResponse * response,StatusCallback done)145 void CoordinationServiceRpcHandler::ReportErrorToServiceAsync(
146     const ReportErrorToServiceRequest* request,
147     ReportErrorToServiceResponse* response, StatusCallback done) {
148   CoordinationServiceInterface* service =
149       CoordinationServiceInterface::GetCoordinationServiceInstance();
150   if (service == nullptr) {
151     done(MakeCoordinationError(
152         errors::Internal("Coordination service is not enabled.")));
153     return;
154   }
155   done(service->ReportTaskError(
156       request->error_origin(),
157       MakeCoordinationError(
158           Status{static_cast<error::Code>(request->error_code()),
159                  request->error_message()},
160           request->error_origin(),
161           /*is_reported_error=*/true)));
162 }
163 
InsertKeyValueAsync(const InsertKeyValueRequest * request,InsertKeyValueResponse * response,StatusCallback done)164 void CoordinationServiceRpcHandler::InsertKeyValueAsync(
165     const InsertKeyValueRequest* request, InsertKeyValueResponse* response,
166     StatusCallback done) {
167   CoordinationServiceInterface* service =
168       CoordinationServiceInterface::GetCoordinationServiceInstance();
169   if (service == nullptr) {
170     done(MakeCoordinationError(
171         errors::Internal("Coordination service is not enabled.")));
172     return;
173   }
174   done(service->InsertKeyValue(request->kv().key(), request->kv().value()));
175 }
176 
GetKeyValueAsync(const GetKeyValueRequest * request,GetKeyValueResponse * response,StatusCallback done)177 void CoordinationServiceRpcHandler::GetKeyValueAsync(
178     const GetKeyValueRequest* request, GetKeyValueResponse* response,
179     StatusCallback done) {
180   CoordinationServiceInterface* service =
181       CoordinationServiceInterface::GetCoordinationServiceInstance();
182   if (service == nullptr) {
183     done(MakeCoordinationError(
184         errors::Internal("Coordination service is not enabled.")));
185     return;
186   }
187   response->mutable_kv()->set_key(request->key());
188   service->GetKeyValueAsync(
189       request->key(), [response, done = std::move(done)](
190                           const StatusOr<std::string>& status_or_value) {
191         if (status_or_value.ok()) {
192           response->mutable_kv()->set_value(status_or_value.ValueOrDie());
193         }
194         done(status_or_value.status());
195       });
196 }
197 
TryGetKeyValueAsync(const TryGetKeyValueRequest * request,TryGetKeyValueResponse * response,StatusCallback done)198 void CoordinationServiceRpcHandler::TryGetKeyValueAsync(
199     const TryGetKeyValueRequest* request, TryGetKeyValueResponse* response,
200     StatusCallback done) {
201   CoordinationServiceInterface* service =
202       CoordinationServiceInterface::GetCoordinationServiceInstance();
203   if (service == nullptr) {
204     done(MakeCoordinationError(
205         errors::Internal("Coordination service is not enabled.")));
206     return;
207   }
208   auto result = service->TryGetKeyValue(request->key());
209   if (!result.ok()) {
210     done(MakeCoordinationError(result.status()));
211     return;
212   }
213   response->mutable_kv()->set_key(request->key());
214   response->mutable_kv()->set_value(result.ValueOrDie());
215   done(Status::OK());
216 }
217 
GetKeyValueDirAsync(const GetKeyValueDirRequest * request,GetKeyValueDirResponse * response,StatusCallback done)218 void CoordinationServiceRpcHandler::GetKeyValueDirAsync(
219     const GetKeyValueDirRequest* request, GetKeyValueDirResponse* response,
220     StatusCallback done) {
221   CoordinationServiceInterface* service =
222       CoordinationServiceInterface::GetCoordinationServiceInstance();
223   if (service == nullptr) {
224     done(MakeCoordinationError(
225         errors::Internal("Coordination service is not enabled.")));
226     return;
227   }
228   std::vector<KeyValueEntry> results =
229       service->GetKeyValueDir(request->directory_key());
230   *response->mutable_kv() = {std::make_move_iterator(results.begin()),
231                              std::make_move_iterator(results.end())};
232   done(OkStatus());
233 }
234 
DeleteKeyValueAsync(const DeleteKeyValueRequest * request,DeleteKeyValueResponse * response,StatusCallback done)235 void CoordinationServiceRpcHandler::DeleteKeyValueAsync(
236     const DeleteKeyValueRequest* request, DeleteKeyValueResponse* response,
237     StatusCallback done) {
238   CoordinationServiceInterface* service =
239       CoordinationServiceInterface::GetCoordinationServiceInstance();
240   if (service == nullptr) {
241     done(MakeCoordinationError(
242         errors::Internal("Coordination service is not enabled.")));
243     return;
244   }
245   done(service->DeleteKeyValue(request->key()));
246 }
247 
BarrierAsync(const BarrierRequest * request,BarrierResponse * response,StatusCallback done)248 void CoordinationServiceRpcHandler::BarrierAsync(const BarrierRequest* request,
249                                                  BarrierResponse* response,
250                                                  StatusCallback done) {
251   CoordinationServiceInterface* service =
252       CoordinationServiceInterface::GetCoordinationServiceInstance();
253   if (service == nullptr) {
254     done(MakeCoordinationError(
255         errors::Internal("Coordination service is not enabled.")));
256     return;
257   }
258   std::vector<CoordinatedTask> tasks = {request->tasks().begin(),
259                                         request->tasks().end()};
260   service->BarrierAsync(
261       request->barrier_id(),
262       absl::Milliseconds(request->barrier_timeout_in_ms()),
263       request->source_task(), tasks,
264       [done = std::move(done)](const Status& status) { done(status); });
265 }
266 
CancelBarrierAsync(const CancelBarrierRequest * request,CancelBarrierResponse * response,StatusCallback done)267 void CoordinationServiceRpcHandler::CancelBarrierAsync(
268     const CancelBarrierRequest* request, CancelBarrierResponse* response,
269     StatusCallback done) {
270   CoordinationServiceInterface* service =
271       CoordinationServiceInterface::GetCoordinationServiceInstance();
272   if (service == nullptr) {
273     done(MakeCoordinationError(
274         errors::Internal("Coordination service is not enabled.")));
275     return;
276   }
277   done(service->CancelBarrier(request->barrier_id(), request->source_task()));
278 }
279 
280 }  // namespace tensorflow
281