1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h"
17
#include <functional>
#include <limits>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
24
25 #include "absl/container/fixed_array.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/types/optional.h"
28 #include "tensorflow/c/eager/immediate_execution_distributed_manager.h"
29 #include "tensorflow/core/common_runtime/device_mgr.h"
30 #include "tensorflow/core/common_runtime/eager/context.h"
31 #include "tensorflow/core/common_runtime/eager/context_distributed_manager.h"
32 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
33 #include "tensorflow/core/common_runtime/eager/execute.h"
34 #include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
35 #include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
36 #include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle.h"
37 #include "tensorflow/core/distributed_runtime/message_wrappers.h"
38 #include "tensorflow/core/distributed_runtime/preemption/preemption_notifier.h"
39 #include "tensorflow/core/distributed_runtime/session_mgr.h"
40 #include "tensorflow/core/distributed_runtime/worker_cache.h"
41 #include "tensorflow/core/distributed_runtime/worker_env.h"
42 #include "tensorflow/core/framework/rendezvous.h"
43 #include "tensorflow/core/platform/errors.h"
44 #include "tensorflow/core/platform/host_info.h"
45 #include "tensorflow/core/platform/mutex.h"
46 #include "tensorflow/core/platform/refcount.h"
47 #include "tensorflow/core/platform/status.h"
48 #include "tensorflow/core/platform/stringprintf.h"
49 #include "tensorflow/core/profiler/lib/traceme.h"
50 #include "tensorflow/core/protobuf/coordination_config.pb.h"
51
52 namespace tensorflow {
53 namespace eager {
54
55 namespace {
GetNumRetvals(tensorflow::EagerContext * context,const string & op_name,const google::protobuf::Map<string,tensorflow::AttrValue> & attrs,int * num_retvals)56 Status GetNumRetvals(tensorflow::EagerContext* context, const string& op_name,
57 const google::protobuf::Map<string, tensorflow::AttrValue>& attrs,
58 int* num_retvals) {
59 const tensorflow::OpRegistrationData* op_reg_data = nullptr;
60 auto status = tensorflow::OpRegistry::Global()->LookUp(op_name, &op_reg_data);
61 if (errors::IsNotFound(status)) {
62 status = context->FindFunctionOpData(op_name, &op_reg_data);
63 }
64 TF_RETURN_IF_ERROR(status);
65
66 const tensorflow::OpDef& op_def = op_reg_data->op_def;
67
68 for (const auto& output_arg : op_def.output_arg()) {
69 if (!output_arg.number_attr().empty()) {
70 auto iter = attrs.find(output_arg.number_attr());
71 if (iter == attrs.end()) {
72 return errors::InvalidArgument("Unable to find number_attr ",
73 output_arg.number_attr(),
74 " for Op: ", op_name);
75 }
76 *num_retvals += iter->second.i();
77 } else if (!output_arg.type_list_attr().empty()) {
78 auto iter = attrs.find(output_arg.type_list_attr());
79 if (iter == attrs.end()) {
80 return errors::InvalidArgument("Unable to find type_list_attr ",
81 output_arg.type_list_attr(),
82 " for Op: ", op_name);
83 }
84 *num_retvals += iter->second.list().type_size();
85 } else {
86 *num_retvals += 1;
87 }
88 }
89
90 return OkStatus();
91 }
92
GetEagerOperationAndNumRetvals(const Operation & operation,EagerContext * eager_context,EagerExecutor * eager_executor,EagerOperation * eager_op,int * num_retvals)93 Status GetEagerOperationAndNumRetvals(const Operation& operation,
94 EagerContext* eager_context,
95 EagerExecutor* eager_executor,
96 EagerOperation* eager_op,
97 int* num_retvals) {
98 const char* name = operation.name().c_str(); // Shorthand
99 std::optional<tensorflow::EagerFunctionParams> remote_func_params =
100 std::nullopt;
101 if (operation.is_function()) {
102 if (operation.is_component_function()) {
103 remote_func_params = {operation.id(), /*is_component_function=*/true,
104 operation.func_step_id()};
105 } else {
106 remote_func_params = {operation.id(), /*is_component_function=*/false,
107 std::nullopt};
108 }
109 }
110 TF_RETURN_IF_ERROR(eager_op->Reset(name, operation.device().c_str(), false,
111 eager_executor, remote_func_params));
112
113 {
114 profiler::TraceMe activity("EagerService:RemoteTensorHandleInternal",
115 profiler::TraceMeLevel::kVerbose);
116 for (const auto& input : operation.op_inputs()) {
117 tensorflow::TensorHandle* handle;
118 if (input.has_remote_handle()) {
119 TF_RETURN_IF_ERROR(
120 eager_context->RemoteMgr()->DeserializeRemoteTensorHandle(
121 input.remote_handle(), &handle));
122 TF_RETURN_IF_ERROR(eager_op->AddInput(handle));
123 } else {
124 Tensor tensor;
125 if (!ParseTensorProtoToTensor(input.tensor(), &tensor)) {
126 return errors::InvalidArgument("Invalid TensorProto: ",
127 input.tensor().DebugString());
128 } else {
129 handle = TensorHandle::CreateLocalHandle(std::move(tensor), nullptr,
130 nullptr, eager_context);
131 TF_RETURN_IF_ERROR(eager_op->AddInput(handle));
132 }
133 }
134 // Unref handle since it has a ref as an input now.
135 handle->Unref();
136 }
137 }
138
139 for (const auto& attr : operation.attrs()) {
140 eager_op->MutableAttrs()->Set(attr.first, attr.second);
141 }
142
143 // TODO(nareshmodi): Consider caching this.
144 return GetNumRetvals(eager_context, operation.name(), operation.attrs(),
145 num_retvals);
146 }
147
TensorHandleProto(TensorHandle * handle,TensorProto * proto)148 Status TensorHandleProto(TensorHandle* handle, TensorProto* proto) {
149 const tensorflow::Tensor* t = nullptr;
150 TF_RETURN_IF_ERROR(handle->Tensor(&t));
151 t->AsProtoTensorContent(proto);
152 return OkStatus();
153 }
154
TensorHandleShape(TensorHandle * handle,TensorShapeProto * proto)155 Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) {
156 const tensorflow::Tensor* t = nullptr;
157
158 // TODO(nareshmodi): This call makes async calls sync calls. Fix this.
159 if (handle->Type() == TensorHandle::LOCAL) {
160 TF_RETURN_IF_ERROR(handle->Tensor(&t));
161
162 t->shape().AsProto(proto);
163 } else {
164 TensorShape shape;
165 TF_RETURN_IF_ERROR(handle->Shape(&shape));
166 shape.AsProto(proto);
167 }
168
169 return OkStatus();
170 }
171
AddOpRetvalsToResponse(EagerContext * eager_context,int op_id,int num_retvals,const std::vector<int32> & output_nums,TensorHandle ** retvals,std::function<TensorProto * ()> add_tensor_proto_fn,std::function<TensorShapeProto * ()> add_shape_proto_fn,std::function<string * ()> add_device_fn=nullptr)172 Status AddOpRetvalsToResponse(
173 EagerContext* eager_context, int op_id, int num_retvals,
174 const std::vector<int32>& output_nums, TensorHandle** retvals,
175 std::function<TensorProto*()> add_tensor_proto_fn,
176 std::function<TensorShapeProto*()> add_shape_proto_fn,
177 std::function<string*()> add_device_fn = nullptr) {
178 // retvals hold references to the allocated output tensor handles. If errors
179 // happen with adding some results to the response, aggregate the status in sg
180 // instead of directly returning the error, to make sure unref or ownership
181 // transfer completes for the rest of output tensor handles.
182 StatusGroup sg;
183 if (op_id == kInvalidOpId) {
184 // Copy the output tensors back along with the response, since the op id
185 // is invalid which cannot be added to RemoteMgr.
186 for (int i = 0; i < num_retvals; i++) {
187 sg.Update(TensorHandleProto(retvals[i], add_tensor_proto_fn()));
188 retvals[i]->Unref();
189 }
190 } else {
191 for (int i = 0; i < num_retvals; i++) {
192 sg.Update(TensorHandleShape(retvals[i], add_shape_proto_fn()));
193 if (add_device_fn) {
194 Device* device = retvals[i]->device();
195 *add_device_fn() = device ? device->name() : "";
196 }
197 if (retvals[i]->Type() == TensorHandle::REMOTE) {
198 retvals[i]->Unref();
199 } else {
200 const int output_num = output_nums.empty() ? i : output_nums.at(i);
201 eager_context->RemoteMgr()->AddOperationOutput(retvals[i], op_id,
202 output_num);
203 }
204 }
205 }
206 return sg.as_summary_status();
207 }
208 } // namespace
209
CreateContext(const CreateContextRequest * request,CreateContextResponse * response)210 Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
211 CreateContextResponse* response) {
212 {
213 mutex_lock l(contexts_mu_);
214 auto context_it = contexts_.find(request->context_id());
215 if (context_it != contexts_.end()) {
216 if (request->context_view_id() <
217 context_it->second->Context()->GetContextViewId()) {
218 return errors::InvalidArgument("EagerService:CreateContext failed. ",
219 "Context id: <", request->context_id(),
220 "> already exists.");
221 } else {
222 // For existing context with a stale context_view_id, close the old one
223 // and recreate with new view id. This is likely due to the worker
224 // disconnected and then reconnected after one or more cluster updates.
225 context_it->second->Unref();
226 contexts_.erase(context_it);
227 }
228 }
229 }
230 // make sure env_ , env_->rendezvous_mgr available
231 if (env_ == nullptr || env_->rendezvous_mgr == nullptr) {
232 return tensorflow::errors::Internal(
233 "invalid eager env_ or env_->rendezvous_mgr.");
234 }
235
236 auto* r = env_->rendezvous_mgr->Find(request->context_id());
237 auto session_name =
238 tensorflow::strings::StrCat("eager_", request->context_id());
239 if (VLOG_IS_ON(2)) {
240 VLOG(2) << "Creating context on /job:" << request->server_def().job_name()
241 << "/task:" << request->server_def().task_index();
242 for (const auto& da : request->cluster_device_attributes()) {
243 VLOG(2) << " " << da.name();
244 }
245 }
246 TF_RETURN_IF_ERROR(env_->session_mgr->CreateSession(
247 session_name, request->server_def(), request->cluster_device_attributes(),
248 request->server_def().default_session_config().isolate_session_state()));
249 int64_t context_id = request->context_id();
250 std::function<void()> session_destroyer = [this, context_id, session_name]() {
251 env_->rendezvous_mgr->Cleanup(context_id);
252 auto s = env_->session_mgr->DeleteSession(session_name);
253 if (!s.ok()) {
254 LOG(WARNING) << "Failed to destroy worker session '" << session_name
255 << "' due to " << s.error_message();
256 }
257 };
258
259 std::shared_ptr<WorkerSession> worker_session;
260 TF_RETURN_IF_ERROR(env_->session_mgr->WorkerSessionForSession(
261 session_name, &worker_session));
262
263 tensorflow::DeviceMgr* device_mgr = worker_session->device_mgr();
264
265 // Initialize remote tensor communication based on worker session.
266 TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
267 // Set the rendezvous as context-global instance for eager op-by-op execution.
268 r->SetRemoteEagerContextDefault();
269
270 std::function<Rendezvous*(const int64_t)> rendezvous_creator =
271 [worker_session, this](const int64_t step_id) {
272 auto* r = env_->rendezvous_mgr->Find(step_id);
273 r->Initialize(worker_session.get()).IgnoreError();
274 return r;
275 };
276
277 LOG(INFO) << "Creating " << (request->async() ? "async" : "sync")
278 << " eager service context with rendezvous_id on host "
279 << port::Hostname() << " " << worker_session->worker_name();
280 SessionOptions opts;
281 opts.config = request->server_def().default_session_config();
282 tensorflow::EagerContext* ctx = new tensorflow::EagerContext(
283 opts, tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
284 request->async(), device_mgr, false, r, worker_session->cluster_flr(),
285 env_->collective_executor_mgr.get());
286 // Ownership will be transferred to the ServerContext, or else in an error
287 // case ctx will be deleted by this unref.
288 core::ScopedUnref unref_ctx(ctx);
289
290 std::vector<string> remote_workers;
291 worker_session->worker_cache()->ListWorkers(&remote_workers);
292 remote_workers.erase(std::remove(remote_workers.begin(), remote_workers.end(),
293 worker_session->worker_name()),
294 remote_workers.end());
295
296 std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
297 TF_RETURN_IF_ERROR(worker_session->worker_cache()->GetEagerClientCache(
298 &remote_eager_workers));
299 DistributedFunctionLibraryRuntime* cluster_flr =
300 eager::CreateClusterFLR(request->context_id(), ctx, worker_session.get());
301
302 auto remote_mgr =
303 std::make_unique<tensorflow::eager::RemoteMgr>(/*is_master=*/false, ctx);
304 Status s = ctx->InitializeRemoteWorker(
305 std::move(remote_eager_workers), worker_session->remote_device_mgr(),
306 remote_workers, request->context_id(), request->context_view_id(),
307 std::move(rendezvous_creator), cluster_flr, std::move(remote_mgr),
308 std::move(session_destroyer));
309 if (!s.ok()) {
310 VLOG(1) << "EagerContext::InitializeRemoteWorker failed with "
311 << s.ToString();
312 return s;
313 }
314
315 #if !defined(IS_MOBILE_PLATFORM)
316 const auto& config = request->server_def().default_session_config();
317 const bool enable_coordination =
318 !config.experimental().coordination_config().service_type().empty();
319 if (enable_coordination) {
320 auto dist_mgr = std::make_unique<EagerContextDistributedManager>(ctx);
321 auto coord_agent = env_->session_mgr->GetCoordinationServiceAgent();
322 dist_mgr->SetCoordinationServiceAgent(coord_agent);
323 auto preemption_notifier =
324 PreemptionNotifier::CreatePreemptionNotifier("sigterm", Env::Default());
325 preemption_notifier->WillBePreemptedAtAsync(
326 [coord_agent](StatusOr<absl::Time> time_or_status) {
327 if (time_or_status.ok()) {
328 const auto& coord_task = coord_agent->GetOwnTask().ValueOrDie();
329 Status s = coord_agent->InsertKeyValue(
330 "TF_DEFAULT_PREEMPTION_NOTICE_KEY",
331 absl::StrCat("/job:", coord_task.job_name(),
332 "/task:", coord_task.task_id()));
333 if (!s.ok()) {
334 LOG(INFO) << "Preemption not exported to coordination service: "
335 << s;
336 }
337 }
338 });
339 ctx->SetDistributedManager(std::move(dist_mgr));
340 }
341 #endif // !IS_MOBILE_PLATFORM
342
343 std::vector<DeviceAttributes> device_attributes;
344 device_mgr->ListDeviceAttributes(&device_attributes);
345
346 for (const auto& da : device_attributes) {
347 *response->add_device_attributes() = da;
348 }
349 {
350 mutex_lock l(contexts_mu_);
351 auto context_it = contexts_.find(request->context_id());
352 if (context_it != contexts_.end()) {
353 return errors::InvalidArgument("EagerService:CreateContext failed. ",
354 "Context id: <", request->context_id(),
355 "> already exists.");
356 }
357 contexts_.emplace(request->context_id(),
358 new ServerContext(ctx, request->keep_alive_secs(), env_));
359 }
360
361 return OkStatus();
362 }
363
UpdateContext(const UpdateContextRequest * request,UpdateContextResponse * response)364 Status EagerServiceImpl::UpdateContext(const UpdateContextRequest* request,
365 UpdateContextResponse* response) {
366 // make sure env_ , env_->rendezvous_mgr available
367 if (env_ == nullptr || env_->rendezvous_mgr == nullptr) {
368 return tensorflow::errors::Internal(
369 "invalid eager env_ or env_->rendezvous_mgr.");
370 }
371
372 // Find the context to update by the requested context_id
373 ServerContext* server_context = nullptr;
374 TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &server_context));
375 core::ScopedUnref context_unref(server_context);
376
377 tensorflow::EagerContext* ctx = server_context->Context();
378 if (request->context_view_id() != ctx->GetContextViewId() + 1) {
379 return errors::InvalidArgument(
380 "EagerService:UpdateContext failed. Context id: <",
381 request->context_id(), "> currently at view #", ctx->GetContextViewId(),
382 " but received update request at view #", request->context_view_id(),
383 ". View id should only be continuously incremented.");
384 }
385 if (request->cluster_device_attributes_size() == 0) {
386 // In this case, the client indicates that the updated `server_def` and
387 // device info is irrelevant to this worker, since it is not connected to
388 // the updated ones (likely due to device filter settings). The worker
389 // simply needs to update view ID and does not update other internal state.
390 ctx->IncrementContextViewId();
391 VLOG(1) << "Processing simplified UpdateContextRequest on "
392 << ctx->HostCPU()->name();
393 return OkStatus();
394 }
395
396 auto session_name =
397 tensorflow::strings::StrCat("eager_", request->context_id());
398
399 TF_RETURN_IF_ERROR(
400 env_->session_mgr->UpdateSession(session_name, request->server_def(),
401 request->cluster_device_attributes()));
402
403 std::shared_ptr<WorkerSession> worker_session;
404 TF_RETURN_IF_ERROR(env_->session_mgr->WorkerSessionForSession(
405 session_name, &worker_session));
406
407 const tensorflow::DeviceMgr* device_mgr = worker_session->device_mgr();
408
409 std::vector<string> remote_workers;
410 worker_session->worker_cache()->ListWorkers(&remote_workers);
411 remote_workers.erase(std::remove(remote_workers.begin(), remote_workers.end(),
412 worker_session->worker_name()),
413 remote_workers.end());
414 VLOG(1) << "On existing server " << worker_session->worker_name()
415 << " updating remote workers";
416 if (VLOG_IS_ON(2)) {
417 for (const string& rw : remote_workers) {
418 VLOG(2) << "Remote worker " << rw;
419 }
420 }
421
422 std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
423 TF_RETURN_IF_ERROR(worker_session->worker_cache()->GetEagerClientCache(
424 &remote_eager_workers));
425
426 ctx->ClearCachesAndThreadExecutors();
427 Status s = ctx->UpdateRemoteWorker(std::move(remote_eager_workers),
428 remote_workers, request->context_id());
429 if (!s.ok()) {
430 VLOG(1) << "EagerContext::UpdateRemoteWorker failed with " << s.ToString();
431 return s;
432 }
433
434 std::vector<DeviceAttributes> device_attributes;
435 device_mgr->ListDeviceAttributes(&device_attributes);
436
437 for (const auto& da : device_attributes) {
438 *response->add_device_attributes() = da;
439 }
440
441 return OkStatus();
442 }
443
CreateMasterContext(const tensorflow::uint64 context_id,EagerContext * context)444 Status EagerServiceImpl::CreateMasterContext(
445 const tensorflow::uint64 context_id, EagerContext* context) {
446 {
447 mutex_lock l(contexts_mu_);
448 auto iter = contexts_.find(context_id);
449 if (iter != contexts_.end()) {
450 return errors::InvalidArgument(
451 "EagerService:CreateMasterContext failed. ", "Context id: <",
452 context_id, "> already exists.");
453 }
454 }
455 ServerContext* server_context =
456 ServerContext::CreateMasterContext(context, env_);
457 mutex_lock l(contexts_mu_);
458 contexts_.emplace(context_id, server_context);
459 return OkStatus();
460 }
461
RunComponentFunction(CallOptions * call_opts,const RunComponentFunctionRequest * request,RunComponentFunctionResponse * response,StatusCallback done)462 void EagerServiceImpl::RunComponentFunction(
463 CallOptions* call_opts, const RunComponentFunctionRequest* request,
464 RunComponentFunctionResponse* response, StatusCallback done) {
465 ServerContext* context = nullptr;
466 Status s = GetServerContext(request->context_id(), &context);
467 if (!s.ok()) {
468 done(s);
469 return;
470 }
471 core::ScopedUnref context_unref(context);
472
473 auto& operation = request->operation();
474 // This codepath should only be triggered for executing component function
475 if (!operation.is_function() || !operation.is_component_function()) {
476 done(errors::Internal(
477 "RunComponentFunction request can only be used to execute "
478 "component functions."));
479 return;
480 }
481
482 EagerContext* eager_context = context->Context();
483 EagerExecutor* eager_executor = &eager_context->Executor();
484
485 EagerOperation* op = new EagerOperation(eager_context);
486 int* num_retvals = new int(0);
487 s = GetEagerOperationAndNumRetvals(operation, eager_context, eager_executor,
488 op, num_retvals);
489 if (!s.ok()) {
490 delete num_retvals;
491 delete op;
492 done(s);
493 return;
494 }
495 if (!op->IsLocal()) {
496 delete num_retvals;
497 delete op;
498 done(errors::Internal(
499 "Received RunComponentFunction request with remote function device. "));
500 return;
501 }
502 s = op->SetAttrBool("is_component_function", true);
503 if (!s.ok()) {
504 delete num_retvals;
505 delete op;
506 done(errors::Internal("Error setting is_component_function attribute: ",
507 s.error_message()));
508 return;
509 }
510
511 auto* retvals = new absl::FixedArray<TensorHandle*>(*num_retvals);
512 VLOG(3) << "ServerContext: Calling EagerLocalExecuteAsync for op "
513 << operation.id();
514 std::vector<int32> output_nums;
515 for (const int32_t output_num : request->output_num()) {
516 output_nums.push_back(output_num);
517 }
518
519 auto cm = std::make_shared<CancellationManager>();
520 op->SetCancellationManager(cm.get());
521 call_opts->SetCancelCallback([cm] { cm->StartCancel(); });
522
523 context->Ref();
524 EagerLocalExecuteAsync(
525 op, retvals->data(), num_retvals,
526 [op, op_id = operation.id(), num_retvals, retvals, output_nums, cm,
527 call_opts, response, eager_context, context,
528 done = std::move(done)](const Status& status) {
529 call_opts->ClearCancelCallback();
530 auto wrapped_done = [&](const Status& status) {
531 context->Unref();
532 done(status);
533 delete op;
534 delete num_retvals;
535 delete retvals;
536 };
537 if (!status.ok()) {
538 wrapped_done(status);
539 return;
540 }
541 // The output device of a component function is the component device
542 // which is known on the default device of it's parent function.
543 wrapped_done(AddOpRetvalsToResponse(
544 eager_context, op_id, *num_retvals, output_nums, retvals->data(),
545 [response] { return response->add_tensor(); },
546 [response] { return response->add_shape(); }));
547 });
548 }
549
ExecuteOp(CallOptions * call_opts,const Operation & operation,EagerContext * eager_context,EagerExecutor * eager_executor,QueueResponse * queue_response)550 Status EagerServiceImpl::ExecuteOp(CallOptions* call_opts,
551 const Operation& operation,
552 EagerContext* eager_context,
553 EagerExecutor* eager_executor,
554 QueueResponse* queue_response) {
555 tensorflow::EagerOperation op(eager_context);
556 int num_retvals = 0;
557 TF_RETURN_IF_ERROR(GetEagerOperationAndNumRetvals(
558 operation, eager_context, eager_executor, &op, &num_retvals));
559
560 auto cm = std::make_shared<CancellationManager>();
561 if (call_opts) {
562 op.SetCancellationManager(cm.get());
563 call_opts->SetCancelCallback([cm] { cm->StartCancel(); });
564 }
565
566 absl::FixedArray<tensorflow::TensorHandle*> retvals(num_retvals);
567 VLOG(3) << "ServerContext: Calling EagerExecute for op " << operation.id();
568 TF_RETURN_IF_ERROR(op.Execute(
569 absl::MakeSpan(
570 reinterpret_cast<tensorflow::AbstractTensorHandle**>(retvals.data()),
571 num_retvals),
572 &num_retvals));
573
574 std::function<string*()> add_device_fn = nullptr;
575 // Send the output devices of a function back to let a client know where the
576 // outputs are. For a primitive op, an output devics is the op device which is
577 // known on a client.
578 if (op.is_function()) {
579 add_device_fn = [queue_response] { return queue_response->add_device(); };
580 }
581
582 return AddOpRetvalsToResponse(
583 eager_context, operation.id(), num_retvals, /*output_nums=*/{},
584 retvals.data(), [queue_response] { return queue_response->add_tensor(); },
585 [queue_response] { return queue_response->add_shape(); },
586 std::move(add_device_fn));
587 }
588
Enqueue(CallOptions * call_opts,const EnqueueRequest * request,EnqueueResponse * response,uint64 stream_id)589 Status EagerServiceImpl::Enqueue(CallOptions* call_opts,
590 const EnqueueRequest* request,
591 EnqueueResponse* response, uint64 stream_id) {
592 profiler::TraceMe activity(
593 [&] {
594 return absl::StrCat(
595 "EagerService:Enqueue#debug_str=", request->DebugString(), "#");
596 },
597 profiler::TraceMeLevel::kInfo);
598 ServerContext* context = nullptr;
599 TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
600 core::ScopedUnref context_unref(context);
601
602 EagerExecutor& executor =
603 stream_id == kInvalidStreamId
604 ? context->Context()->Executor()
605 : context->Context()->RemoteMgr()->GetOrCreateExecutorForStream(
606 stream_id);
607 Status s;
608 for (const auto& item : request->queue()) {
609 auto* queue_response = response->add_queue_response();
610 if (item.has_operation()) {
611 s = ExecuteOp(call_opts, item.operation(), context->Context(), &executor,
612 queue_response);
613 } else if (item.has_handle_to_decref()) {
614 auto handle_to_decref = std::make_unique<RemoteTensorHandleInternal>(
615 item.handle_to_decref());
616 auto node = std::make_unique<ClientTensorHandleDeleteNode>(
617 context, std::move(handle_to_decref));
618 s = context->Context()->Executor().AddOrExecute(std::move(node));
619 } else if (item.has_send_tensor()) {
620 s = SendTensor(item.send_tensor(), context->Context());
621 } else if (item.has_send_packed_handle()) {
622 s = SendPackedHandle(item.send_packed_handle(), context->Context());
623 } else if (item.has_register_function()) {
624 s = RegisterFunction(item.register_function(), context->Context());
625 } else if (item.has_cleanup_function()) {
626 s = CleanupFunction(item.cleanup_function());
627 } else {
628 DCHECK(item.has_sync_remote_executor_for_stream());
629 s = executor.WaitForAllPendingNodes();
630 }
631
632 if (!s.ok()) {
633 if (stream_id != kInvalidStreamId) {
634 context->Context()->RemoteMgr()->DeleteExecutorForStream(stream_id);
635 }
636 return s;
637 }
638 }
639
640 return OkStatus();
641 }
642
WaitQueueDone(const WaitQueueDoneRequest * request,WaitQueueDoneResponse * response)643 Status EagerServiceImpl::WaitQueueDone(const WaitQueueDoneRequest* request,
644 WaitQueueDoneResponse* response) {
645 ServerContext* context = nullptr;
646 TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
647 core::ScopedUnref context_unref(context);
648
649 if (request->op_id_size() > 0) {
650 return errors::Unimplemented(
651 "EagerServiceImpl::WaitQueueDone is not "
652 "implemented for particular op IDs.");
653 }
654 return context->Context()->Executor().WaitForAllPendingNodes();
655 }
656
KeepAlive(const KeepAliveRequest * request,KeepAliveResponse * response)657 Status EagerServiceImpl::KeepAlive(const KeepAliveRequest* request,
658 KeepAliveResponse* response) {
659 ServerContext* context = nullptr;
660 TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
661 core::ScopedUnref context_unref(context);
662
663 tensorflow::EagerContext* ctx = context->Context();
664 response->set_context_view_id(ctx->GetContextViewId());
665 return OkStatus();
666 }
667
CloseContext(const CloseContextRequest * request,CloseContextResponse * response)668 Status EagerServiceImpl::CloseContext(const CloseContextRequest* request,
669 CloseContextResponse* response) {
670 VLOG(1) << "Executing EagerService::CloseContext for context "
671 << request->context_id();
672 ServerContext* context = nullptr;
673 if (!GetServerContext(request->context_id(), &context).ok()) {
674 // Swallow the error here.
675 return OkStatus();
676 }
677 core::ScopedUnref context_unref(context);
678
679 if (request->context_view_id() < context->Context()->GetContextViewId()) {
680 // Swallow the error here.
681 LOG(INFO) << "Ignoring CloseContext request with a stale context_view_id "
682 << request->context_view_id() << " for context_id "
683 << request->context_id() << ". The current context_view_id is "
684 << context->Context()->GetContextViewId() << ".";
685 return OkStatus();
686 }
687
688 mutex_lock l(contexts_mu_);
689 contexts_.erase(request->context_id());
690
691 // GetServerContext returns a newly Reffed copy of ServerContext, which is
692 // unreffed by context_unref. Additionally, we need to unref it one time since
693 // we are releasing it from the map.
694 context->Unref();
695
696 return OkStatus();
697 }
698
RegisterFunction(const RegisterFunctionOp & register_function,EagerContext * eager_context)699 Status EagerServiceImpl::RegisterFunction(
700 const RegisterFunctionOp& register_function, EagerContext* eager_context) {
701 // If the function is a component of a multi-device function, we only need to
702 // register it locally.
703 return eager_context->AddFunctionDef(
704 register_function.function_def(), register_function.library(),
705 register_function.is_component_function());
706 }
707
CleanupFunction(const CleanupFunctionOp & cleanup_function)708 Status EagerServiceImpl::CleanupFunction(
709 const CleanupFunctionOp& cleanup_function) {
710 env_->rendezvous_mgr->Cleanup(cleanup_function.step_id());
711 return OkStatus();
712 }
713
SendTensor(const SendTensorOp & send_tensor,EagerContext * eager_context)714 Status EagerServiceImpl::SendTensor(const SendTensorOp& send_tensor,
715 EagerContext* eager_context) {
716 tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> tensors;
717 for (const auto& tensor_proto : send_tensor.tensors()) {
718 Tensor tensor;
719 if (!tensor.FromProto(tensor_proto)) {
720 return errors::InvalidArgument("Unable to parse tensor proto");
721 }
722
723 TensorHandle* tensor_handle = TensorHandle::CreateLocalHandle(
724 std::move(tensor), nullptr, nullptr, eager_context);
725 TensorHandle* copied_handle = nullptr;
726 Device* device;
727 TF_RETURN_IF_ERROR(eager_context->FindDeviceFromName(
728 send_tensor.device_name().c_str(), &device));
729 TF_RETURN_IF_ERROR(EagerCopyToDevice(tensor_handle, eager_context,
730 &eager_context->Executor(), device,
731 false, &copied_handle));
732 tensors.push_back(copied_handle);
733 tensor_handle->Unref();
734 }
735
736 eager_context->RemoteMgr()->AddOperationOutputs(tensors, send_tensor.op_id());
737
738 return OkStatus();
739 }
740
SendPackedHandle(const SendPackedHandleOp & send_packed_handle,EagerContext * eager_context)741 Status EagerServiceImpl::SendPackedHandle(
742 const SendPackedHandleOp& send_packed_handle, EagerContext* eager_context) {
743 if (send_packed_handle.handles().empty()) {
744 return errors::InvalidArgument("Handles should not be empty.");
745 }
746
747 std::vector<tensorflow::TensorHandle*> handles;
748 handles.resize(send_packed_handle.handles_size());
749 for (int i = 0; i < send_packed_handle.handles_size(); ++i) {
750 const auto& item = send_packed_handle.handles(i);
751 if (item.has_local_handle()) {
752 Tensor tensor;
753 if (!ParseTensorProtoToTensor(item.local_handle().tensor(), &tensor)) {
754 return errors::InvalidArgument(
755 "Invalid TensorProto: ",
756 item.local_handle().tensor().DebugString());
757 }
758 Device* op_device = nullptr;
759 TF_RETURN_IF_ERROR(eager_context->FindDeviceFromName(
760 item.local_handle().device().c_str(), &op_device));
761 handles[i] = TensorHandle::CreateLocalHandle(
762 std::move(tensor), /*d=*/nullptr, op_device, eager_context);
763 } else {
764 TF_RETURN_IF_ERROR(
765 eager_context->RemoteMgr()->DeserializeRemoteTensorHandle(
766 item.remote_handle(), &handles[i]));
767 }
768 }
769
770 tensorflow::TensorHandle* packed_handle = nullptr;
771 std::vector<tensorflow::TensorHandle*> handles_to_pack = handles;
772 // Create a unshaped packed TensorHandle.
773 TF_RETURN_IF_ERROR(TensorHandle::CreatePackedHandle(
774 std::move(handles_to_pack), handles.at(0)->dtype, TensorShape(),
775 send_packed_handle.device_name(), eager_context, &packed_handle));
776
777 for (auto* h : handles) {
778 // Unref handle since it has a ref in the packed handle now.
779 h->Unref();
780 }
781
782 eager_context->RemoteMgr()->AddOperationOutputs({packed_handle},
783 send_packed_handle.op_id());
784 return OkStatus();
785 }
786
GetServerContext(uint64 context_id,ServerContext ** server_context)787 tensorflow::Status EagerServiceImpl::GetServerContext(
788 uint64 context_id, ServerContext** server_context) {
789 tf_shared_lock l(contexts_mu_);
790 auto iter = contexts_.find(context_id);
791 if (iter == contexts_.end()) {
792 *server_context = nullptr;
793 return errors::Aborted(strings::Printf(
794 "Unable to find a context_id matching the specified one "
795 "(%llu). Perhaps the worker was restarted, or the context was GC'd?",
796 static_cast<unsigned long long>(context_id)));
797 }
798
799 *server_context = iter->second;
800 (*server_context)->Ref();
801
802 (*server_context)->RecordAccess();
803
804 return OkStatus();
805 }
806
807 } // namespace eager
808 } // namespace tensorflow
809