1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/coordination/coordination_service_rpc_handler.h"
17
18 #include <string>
19 #include <utility>
20
21 #include "absl/time/time.h"
22 #include "tensorflow/core/distributed_runtime/coordination/coordination_service.h"
23 #include "tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h"
24 #include "tensorflow/core/distributed_runtime/coordination/coordination_service_error_util.h"
25 #include "tensorflow/core/platform/casts.h"
26 #include "tensorflow/core/platform/errors.h"
27 #include "tensorflow/core/platform/mutex.h"
28 #include "tensorflow/core/protobuf/coordination_service.pb.h"
29
30 namespace tensorflow {
31
SetAgentInstance(CoordinationServiceAgent * agent)32 void CoordinationServiceRpcHandler::SetAgentInstance(
33 CoordinationServiceAgent* agent) {
34 mutex_lock l(agent_mu_);
35 agent_ = agent;
36 }
37
RegisterTaskAsync(const RegisterTaskRequest * request,RegisterTaskResponse * response,StatusCallback done)38 void CoordinationServiceRpcHandler::RegisterTaskAsync(
39 const RegisterTaskRequest* request, RegisterTaskResponse* response,
40 StatusCallback done) {
41 CoordinationServiceInterface* service =
42 CoordinationServiceInterface::GetCoordinationServiceInstance();
43 if (service == nullptr) {
44 done(MakeCoordinationError(
45 errors::Internal("Coordination service is not enabled.")));
46 return;
47 }
48 const CoordinatedTask& task = request->source_task();
49 const uint64_t incarnation = request->incarnation();
50 const uint64_t leader_incarnation = service->GetServiceIncarnation();
51 response->set_leader_incarnation(leader_incarnation);
52 done(service->RegisterTask(task, incarnation));
53 }
54
HeartbeatAsync(const HeartbeatRequest * request,HeartbeatResponse * response,StatusCallback done)55 void CoordinationServiceRpcHandler::HeartbeatAsync(
56 const HeartbeatRequest* request, HeartbeatResponse* response,
57 StatusCallback done) {
58 CoordinationServiceInterface* service =
59 CoordinationServiceInterface::GetCoordinationServiceInstance();
60 if (service == nullptr) {
61 done(MakeCoordinationError(
62 errors::Internal("Coordination service is not enabled.")));
63 return;
64 }
65 const CoordinatedTask& task = request->source_task();
66 const uint64_t incarnation = request->incarnation();
67 const uint64_t leader_incarnation = service->GetServiceIncarnation();
68 Status s = service->RecordHeartbeat(task, incarnation);
69 if (!s.ok()) {
70 done(s);
71 return;
72 }
73 response->set_leader_incarnation(leader_incarnation);
74 done(OkStatus());
75 }
76
WaitForAllTasksAsync(const WaitForAllTasksRequest * request,WaitForAllTasksResponse * response,StatusCallback done)77 void CoordinationServiceRpcHandler::WaitForAllTasksAsync(
78 const WaitForAllTasksRequest* request, WaitForAllTasksResponse* response,
79 StatusCallback done) {
80 CoordinationServiceInterface* service =
81 CoordinationServiceInterface::GetCoordinationServiceInstance();
82 if (service == nullptr) {
83 done(MakeCoordinationError(
84 errors::Internal("Coordination service is not enabled.")));
85 return;
86 }
87 service->WaitForAllTasks(
88 request->source_task(), request->local_device_info(),
89 [response, service, done = std::move(done)](Status s) {
90 if (s.ok()) {
91 *response->mutable_cluster_device_info() =
92 service->ListClusterDevices();
93 }
94 done(s);
95 });
96 }
97
ShutdownTaskAsync(const ShutdownTaskRequest * request,ShutdownTaskResponse * response,StatusCallback done)98 void CoordinationServiceRpcHandler::ShutdownTaskAsync(
99 const ShutdownTaskRequest* request, ShutdownTaskResponse* response,
100 StatusCallback done) {
101 CoordinationServiceInterface* service =
102 CoordinationServiceInterface::GetCoordinationServiceInstance();
103 if (service == nullptr) {
104 done(MakeCoordinationError(
105 errors::Internal("Coordination service is not enabled.")));
106 return;
107 }
108 service->ShutdownTaskAsync(request->source_task(),
109 [done](Status s) { done(s); });
110 }
111
ResetTaskAsync(const ResetTaskRequest * request,ResetTaskResponse * response,StatusCallback done)112 void CoordinationServiceRpcHandler::ResetTaskAsync(
113 const ResetTaskRequest* request, ResetTaskResponse* response,
114 StatusCallback done) {
115 CoordinationServiceInterface* service =
116 CoordinationServiceInterface::GetCoordinationServiceInstance();
117 if (service == nullptr) {
118 done(MakeCoordinationError(
119 errors::Internal("Coordination service is not enabled.")));
120 return;
121 }
122 done(service->ResetTask(request->source_task()));
123 }
124
ReportErrorToTaskAsync(const ReportErrorToTaskRequest * request,ReportErrorToTaskResponse * response,StatusCallback done)125 void CoordinationServiceRpcHandler::ReportErrorToTaskAsync(
126 const ReportErrorToTaskRequest* request,
127 ReportErrorToTaskResponse* response, StatusCallback done) {
128 tf_shared_lock l(agent_mu_);
129 if (agent_ == nullptr) {
130 done(MakeCoordinationError(errors::Internal(
131 "CoordinationServiceAgent is uninitialized or has already shutdown.")));
132 return;
133 }
134 const CoordinationServiceError& error_payload = request->error_payload();
135 Status error(static_cast<error::Code>(request->error_code()),
136 strings::StrCat("Error reported from /job:",
137 error_payload.source_task().job_name(),
138 "/task:", error_payload.source_task().task_id(),
139 ": ", request->error_message()));
140 error = MakeCoordinationError(error, error_payload);
141 agent_->SetError(error);
142 done(OkStatus());
143 }
144
ReportErrorToServiceAsync(const ReportErrorToServiceRequest * request,ReportErrorToServiceResponse * response,StatusCallback done)145 void CoordinationServiceRpcHandler::ReportErrorToServiceAsync(
146 const ReportErrorToServiceRequest* request,
147 ReportErrorToServiceResponse* response, StatusCallback done) {
148 CoordinationServiceInterface* service =
149 CoordinationServiceInterface::GetCoordinationServiceInstance();
150 if (service == nullptr) {
151 done(MakeCoordinationError(
152 errors::Internal("Coordination service is not enabled.")));
153 return;
154 }
155 done(service->ReportTaskError(
156 request->error_origin(),
157 MakeCoordinationError(
158 Status{static_cast<error::Code>(request->error_code()),
159 request->error_message()},
160 request->error_origin(),
161 /*is_reported_error=*/true)));
162 }
163
InsertKeyValueAsync(const InsertKeyValueRequest * request,InsertKeyValueResponse * response,StatusCallback done)164 void CoordinationServiceRpcHandler::InsertKeyValueAsync(
165 const InsertKeyValueRequest* request, InsertKeyValueResponse* response,
166 StatusCallback done) {
167 CoordinationServiceInterface* service =
168 CoordinationServiceInterface::GetCoordinationServiceInstance();
169 if (service == nullptr) {
170 done(MakeCoordinationError(
171 errors::Internal("Coordination service is not enabled.")));
172 return;
173 }
174 done(service->InsertKeyValue(request->kv().key(), request->kv().value()));
175 }
176
GetKeyValueAsync(const GetKeyValueRequest * request,GetKeyValueResponse * response,StatusCallback done)177 void CoordinationServiceRpcHandler::GetKeyValueAsync(
178 const GetKeyValueRequest* request, GetKeyValueResponse* response,
179 StatusCallback done) {
180 CoordinationServiceInterface* service =
181 CoordinationServiceInterface::GetCoordinationServiceInstance();
182 if (service == nullptr) {
183 done(MakeCoordinationError(
184 errors::Internal("Coordination service is not enabled.")));
185 return;
186 }
187 response->mutable_kv()->set_key(request->key());
188 service->GetKeyValueAsync(
189 request->key(), [response, done = std::move(done)](
190 const StatusOr<std::string>& status_or_value) {
191 if (status_or_value.ok()) {
192 response->mutable_kv()->set_value(status_or_value.ValueOrDie());
193 }
194 done(status_or_value.status());
195 });
196 }
197
TryGetKeyValueAsync(const TryGetKeyValueRequest * request,TryGetKeyValueResponse * response,StatusCallback done)198 void CoordinationServiceRpcHandler::TryGetKeyValueAsync(
199 const TryGetKeyValueRequest* request, TryGetKeyValueResponse* response,
200 StatusCallback done) {
201 CoordinationServiceInterface* service =
202 CoordinationServiceInterface::GetCoordinationServiceInstance();
203 if (service == nullptr) {
204 done(MakeCoordinationError(
205 errors::Internal("Coordination service is not enabled.")));
206 return;
207 }
208 auto result = service->TryGetKeyValue(request->key());
209 if (!result.ok()) {
210 done(MakeCoordinationError(result.status()));
211 return;
212 }
213 response->mutable_kv()->set_key(request->key());
214 response->mutable_kv()->set_value(result.ValueOrDie());
215 done(Status::OK());
216 }
217
GetKeyValueDirAsync(const GetKeyValueDirRequest * request,GetKeyValueDirResponse * response,StatusCallback done)218 void CoordinationServiceRpcHandler::GetKeyValueDirAsync(
219 const GetKeyValueDirRequest* request, GetKeyValueDirResponse* response,
220 StatusCallback done) {
221 CoordinationServiceInterface* service =
222 CoordinationServiceInterface::GetCoordinationServiceInstance();
223 if (service == nullptr) {
224 done(MakeCoordinationError(
225 errors::Internal("Coordination service is not enabled.")));
226 return;
227 }
228 std::vector<KeyValueEntry> results =
229 service->GetKeyValueDir(request->directory_key());
230 *response->mutable_kv() = {std::make_move_iterator(results.begin()),
231 std::make_move_iterator(results.end())};
232 done(OkStatus());
233 }
234
DeleteKeyValueAsync(const DeleteKeyValueRequest * request,DeleteKeyValueResponse * response,StatusCallback done)235 void CoordinationServiceRpcHandler::DeleteKeyValueAsync(
236 const DeleteKeyValueRequest* request, DeleteKeyValueResponse* response,
237 StatusCallback done) {
238 CoordinationServiceInterface* service =
239 CoordinationServiceInterface::GetCoordinationServiceInstance();
240 if (service == nullptr) {
241 done(MakeCoordinationError(
242 errors::Internal("Coordination service is not enabled.")));
243 return;
244 }
245 done(service->DeleteKeyValue(request->key()));
246 }
247
BarrierAsync(const BarrierRequest * request,BarrierResponse * response,StatusCallback done)248 void CoordinationServiceRpcHandler::BarrierAsync(const BarrierRequest* request,
249 BarrierResponse* response,
250 StatusCallback done) {
251 CoordinationServiceInterface* service =
252 CoordinationServiceInterface::GetCoordinationServiceInstance();
253 if (service == nullptr) {
254 done(MakeCoordinationError(
255 errors::Internal("Coordination service is not enabled.")));
256 return;
257 }
258 std::vector<CoordinatedTask> tasks = {request->tasks().begin(),
259 request->tasks().end()};
260 service->BarrierAsync(
261 request->barrier_id(),
262 absl::Milliseconds(request->barrier_timeout_in_ms()),
263 request->source_task(), tasks,
264 [done = std::move(done)](const Status& status) { done(status); });
265 }
266
CancelBarrierAsync(const CancelBarrierRequest * request,CancelBarrierResponse * response,StatusCallback done)267 void CoordinationServiceRpcHandler::CancelBarrierAsync(
268 const CancelBarrierRequest* request, CancelBarrierResponse* response,
269 StatusCallback done) {
270 CoordinationServiceInterface* service =
271 CoordinationServiceInterface::GetCoordinationServiceInstance();
272 if (service == nullptr) {
273 done(MakeCoordinationError(
274 errors::Internal("Coordination service is not enabled.")));
275 return;
276 }
277 done(service->CancelBarrier(request->barrier_id(), request->source_task()));
278 }
279
280 } // namespace tensorflow
281