xref: /aosp_15_r20/external/tensorflow/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h"
17 
18 #include <functional>
19 #include <memory>
20 #include <optional>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/container/fixed_array.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/types/optional.h"
28 #include "tensorflow/c/eager/immediate_execution_distributed_manager.h"
29 #include "tensorflow/core/common_runtime/device_mgr.h"
30 #include "tensorflow/core/common_runtime/eager/context.h"
31 #include "tensorflow/core/common_runtime/eager/context_distributed_manager.h"
32 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
33 #include "tensorflow/core/common_runtime/eager/execute.h"
34 #include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
35 #include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
36 #include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle.h"
37 #include "tensorflow/core/distributed_runtime/message_wrappers.h"
38 #include "tensorflow/core/distributed_runtime/preemption/preemption_notifier.h"
39 #include "tensorflow/core/distributed_runtime/session_mgr.h"
40 #include "tensorflow/core/distributed_runtime/worker_cache.h"
41 #include "tensorflow/core/distributed_runtime/worker_env.h"
42 #include "tensorflow/core/framework/rendezvous.h"
43 #include "tensorflow/core/platform/errors.h"
44 #include "tensorflow/core/platform/host_info.h"
45 #include "tensorflow/core/platform/mutex.h"
46 #include "tensorflow/core/platform/refcount.h"
47 #include "tensorflow/core/platform/status.h"
48 #include "tensorflow/core/platform/stringprintf.h"
49 #include "tensorflow/core/profiler/lib/traceme.h"
50 #include "tensorflow/core/protobuf/coordination_config.pb.h"
51 
52 namespace tensorflow {
53 namespace eager {
54 
55 namespace {
GetNumRetvals(tensorflow::EagerContext * context,const string & op_name,const google::protobuf::Map<string,tensorflow::AttrValue> & attrs,int * num_retvals)56 Status GetNumRetvals(tensorflow::EagerContext* context, const string& op_name,
57                      const google::protobuf::Map<string, tensorflow::AttrValue>& attrs,
58                      int* num_retvals) {
59   const tensorflow::OpRegistrationData* op_reg_data = nullptr;
60   auto status = tensorflow::OpRegistry::Global()->LookUp(op_name, &op_reg_data);
61   if (errors::IsNotFound(status)) {
62     status = context->FindFunctionOpData(op_name, &op_reg_data);
63   }
64   TF_RETURN_IF_ERROR(status);
65 
66   const tensorflow::OpDef& op_def = op_reg_data->op_def;
67 
68   for (const auto& output_arg : op_def.output_arg()) {
69     if (!output_arg.number_attr().empty()) {
70       auto iter = attrs.find(output_arg.number_attr());
71       if (iter == attrs.end()) {
72         return errors::InvalidArgument("Unable to find number_attr ",
73                                        output_arg.number_attr(),
74                                        " for Op: ", op_name);
75       }
76       *num_retvals += iter->second.i();
77     } else if (!output_arg.type_list_attr().empty()) {
78       auto iter = attrs.find(output_arg.type_list_attr());
79       if (iter == attrs.end()) {
80         return errors::InvalidArgument("Unable to find type_list_attr ",
81                                        output_arg.type_list_attr(),
82                                        " for Op: ", op_name);
83       }
84       *num_retvals += iter->second.list().type_size();
85     } else {
86       *num_retvals += 1;
87     }
88   }
89 
90   return OkStatus();
91 }
92 
GetEagerOperationAndNumRetvals(const Operation & operation,EagerContext * eager_context,EagerExecutor * eager_executor,EagerOperation * eager_op,int * num_retvals)93 Status GetEagerOperationAndNumRetvals(const Operation& operation,
94                                       EagerContext* eager_context,
95                                       EagerExecutor* eager_executor,
96                                       EagerOperation* eager_op,
97                                       int* num_retvals) {
98   const char* name = operation.name().c_str();  // Shorthand
99   std::optional<tensorflow::EagerFunctionParams> remote_func_params =
100       std::nullopt;
101   if (operation.is_function()) {
102     if (operation.is_component_function()) {
103       remote_func_params = {operation.id(), /*is_component_function=*/true,
104                             operation.func_step_id()};
105     } else {
106       remote_func_params = {operation.id(), /*is_component_function=*/false,
107                             std::nullopt};
108     }
109   }
110   TF_RETURN_IF_ERROR(eager_op->Reset(name, operation.device().c_str(), false,
111                                      eager_executor, remote_func_params));
112 
113   {
114     profiler::TraceMe activity("EagerService:RemoteTensorHandleInternal",
115                                profiler::TraceMeLevel::kVerbose);
116     for (const auto& input : operation.op_inputs()) {
117       tensorflow::TensorHandle* handle;
118       if (input.has_remote_handle()) {
119         TF_RETURN_IF_ERROR(
120             eager_context->RemoteMgr()->DeserializeRemoteTensorHandle(
121                 input.remote_handle(), &handle));
122         TF_RETURN_IF_ERROR(eager_op->AddInput(handle));
123       } else {
124         Tensor tensor;
125         if (!ParseTensorProtoToTensor(input.tensor(), &tensor)) {
126           return errors::InvalidArgument("Invalid TensorProto: ",
127                                          input.tensor().DebugString());
128         } else {
129           handle = TensorHandle::CreateLocalHandle(std::move(tensor), nullptr,
130                                                    nullptr, eager_context);
131           TF_RETURN_IF_ERROR(eager_op->AddInput(handle));
132         }
133       }
134       // Unref handle since it has a ref as an input now.
135       handle->Unref();
136     }
137   }
138 
139   for (const auto& attr : operation.attrs()) {
140     eager_op->MutableAttrs()->Set(attr.first, attr.second);
141   }
142 
143   // TODO(nareshmodi): Consider caching this.
144   return GetNumRetvals(eager_context, operation.name(), operation.attrs(),
145                        num_retvals);
146 }
147 
TensorHandleProto(TensorHandle * handle,TensorProto * proto)148 Status TensorHandleProto(TensorHandle* handle, TensorProto* proto) {
149   const tensorflow::Tensor* t = nullptr;
150   TF_RETURN_IF_ERROR(handle->Tensor(&t));
151   t->AsProtoTensorContent(proto);
152   return OkStatus();
153 }
154 
TensorHandleShape(TensorHandle * handle,TensorShapeProto * proto)155 Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) {
156   const tensorflow::Tensor* t = nullptr;
157 
158   // TODO(nareshmodi): This call makes async calls sync calls. Fix this.
159   if (handle->Type() == TensorHandle::LOCAL) {
160     TF_RETURN_IF_ERROR(handle->Tensor(&t));
161 
162     t->shape().AsProto(proto);
163   } else {
164     TensorShape shape;
165     TF_RETURN_IF_ERROR(handle->Shape(&shape));
166     shape.AsProto(proto);
167   }
168 
169   return OkStatus();
170 }
171 
AddOpRetvalsToResponse(EagerContext * eager_context,int op_id,int num_retvals,const std::vector<int32> & output_nums,TensorHandle ** retvals,std::function<TensorProto * ()> add_tensor_proto_fn,std::function<TensorShapeProto * ()> add_shape_proto_fn,std::function<string * ()> add_device_fn=nullptr)172 Status AddOpRetvalsToResponse(
173     EagerContext* eager_context, int op_id, int num_retvals,
174     const std::vector<int32>& output_nums, TensorHandle** retvals,
175     std::function<TensorProto*()> add_tensor_proto_fn,
176     std::function<TensorShapeProto*()> add_shape_proto_fn,
177     std::function<string*()> add_device_fn = nullptr) {
178   // retvals hold references to the allocated output tensor handles. If errors
179   // happen with adding some results to the response, aggregate the status in sg
180   // instead of directly returning the error, to make sure unref or ownership
181   // transfer completes for the rest of output tensor handles.
182   StatusGroup sg;
183   if (op_id == kInvalidOpId) {
184     // Copy the output tensors back along with the response, since the op id
185     // is invalid which cannot be added to RemoteMgr.
186     for (int i = 0; i < num_retvals; i++) {
187       sg.Update(TensorHandleProto(retvals[i], add_tensor_proto_fn()));
188       retvals[i]->Unref();
189     }
190   } else {
191     for (int i = 0; i < num_retvals; i++) {
192       sg.Update(TensorHandleShape(retvals[i], add_shape_proto_fn()));
193       if (add_device_fn) {
194         Device* device = retvals[i]->device();
195         *add_device_fn() = device ? device->name() : "";
196       }
197       if (retvals[i]->Type() == TensorHandle::REMOTE) {
198         retvals[i]->Unref();
199       } else {
200         const int output_num = output_nums.empty() ? i : output_nums.at(i);
201         eager_context->RemoteMgr()->AddOperationOutput(retvals[i], op_id,
202                                                        output_num);
203       }
204     }
205   }
206   return sg.as_summary_status();
207 }
208 }  // namespace
209 
CreateContext(const CreateContextRequest * request,CreateContextResponse * response)210 Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
211                                        CreateContextResponse* response) {
212   {
213     mutex_lock l(contexts_mu_);
214     auto context_it = contexts_.find(request->context_id());
215     if (context_it != contexts_.end()) {
216       if (request->context_view_id() <
217           context_it->second->Context()->GetContextViewId()) {
218         return errors::InvalidArgument("EagerService:CreateContext failed. ",
219                                        "Context id: <", request->context_id(),
220                                        "> already exists.");
221       } else {
222         // For existing context with a stale context_view_id, close the old one
223         // and recreate with new view id. This is likely due to the worker
224         // disconnected and then reconnected after one or more cluster updates.
225         context_it->second->Unref();
226         contexts_.erase(context_it);
227       }
228     }
229   }
230   // make sure env_ , env_->rendezvous_mgr available
231   if (env_ == nullptr || env_->rendezvous_mgr == nullptr) {
232     return tensorflow::errors::Internal(
233         "invalid eager env_ or env_->rendezvous_mgr.");
234   }
235 
236   auto* r = env_->rendezvous_mgr->Find(request->context_id());
237   auto session_name =
238       tensorflow::strings::StrCat("eager_", request->context_id());
239   if (VLOG_IS_ON(2)) {
240     VLOG(2) << "Creating context on /job:" << request->server_def().job_name()
241             << "/task:" << request->server_def().task_index();
242     for (const auto& da : request->cluster_device_attributes()) {
243       VLOG(2) << "    " << da.name();
244     }
245   }
246   TF_RETURN_IF_ERROR(env_->session_mgr->CreateSession(
247       session_name, request->server_def(), request->cluster_device_attributes(),
248       request->server_def().default_session_config().isolate_session_state()));
249   int64_t context_id = request->context_id();
250   std::function<void()> session_destroyer = [this, context_id, session_name]() {
251     env_->rendezvous_mgr->Cleanup(context_id);
252     auto s = env_->session_mgr->DeleteSession(session_name);
253     if (!s.ok()) {
254       LOG(WARNING) << "Failed to destroy worker session '" << session_name
255                    << "' due to " << s.error_message();
256     }
257   };
258 
259   std::shared_ptr<WorkerSession> worker_session;
260   TF_RETURN_IF_ERROR(env_->session_mgr->WorkerSessionForSession(
261       session_name, &worker_session));
262 
263   tensorflow::DeviceMgr* device_mgr = worker_session->device_mgr();
264 
265   // Initialize remote tensor communication based on worker session.
266   TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
267   // Set the rendezvous as context-global instance for eager op-by-op execution.
268   r->SetRemoteEagerContextDefault();
269 
270   std::function<Rendezvous*(const int64_t)> rendezvous_creator =
271       [worker_session, this](const int64_t step_id) {
272         auto* r = env_->rendezvous_mgr->Find(step_id);
273         r->Initialize(worker_session.get()).IgnoreError();
274         return r;
275       };
276 
277   LOG(INFO) << "Creating " << (request->async() ? "async" : "sync")
278             << " eager service context with rendezvous_id on host "
279             << port::Hostname() << " " << worker_session->worker_name();
280   SessionOptions opts;
281   opts.config = request->server_def().default_session_config();
282   tensorflow::EagerContext* ctx = new tensorflow::EagerContext(
283       opts, tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
284       request->async(), device_mgr, false, r, worker_session->cluster_flr(),
285       env_->collective_executor_mgr.get());
286   // Ownership will be transferred to the ServerContext, or else in an error
287   // case ctx will be deleted by this unref.
288   core::ScopedUnref unref_ctx(ctx);
289 
290   std::vector<string> remote_workers;
291   worker_session->worker_cache()->ListWorkers(&remote_workers);
292   remote_workers.erase(std::remove(remote_workers.begin(), remote_workers.end(),
293                                    worker_session->worker_name()),
294                        remote_workers.end());
295 
296   std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
297   TF_RETURN_IF_ERROR(worker_session->worker_cache()->GetEagerClientCache(
298       &remote_eager_workers));
299   DistributedFunctionLibraryRuntime* cluster_flr =
300       eager::CreateClusterFLR(request->context_id(), ctx, worker_session.get());
301 
302   auto remote_mgr =
303       std::make_unique<tensorflow::eager::RemoteMgr>(/*is_master=*/false, ctx);
304   Status s = ctx->InitializeRemoteWorker(
305       std::move(remote_eager_workers), worker_session->remote_device_mgr(),
306       remote_workers, request->context_id(), request->context_view_id(),
307       std::move(rendezvous_creator), cluster_flr, std::move(remote_mgr),
308       std::move(session_destroyer));
309   if (!s.ok()) {
310     VLOG(1) << "EagerContext::InitializeRemoteWorker failed with "
311             << s.ToString();
312     return s;
313   }
314 
315 #if !defined(IS_MOBILE_PLATFORM)
316   const auto& config = request->server_def().default_session_config();
317   const bool enable_coordination =
318       !config.experimental().coordination_config().service_type().empty();
319   if (enable_coordination) {
320     auto dist_mgr = std::make_unique<EagerContextDistributedManager>(ctx);
321     auto coord_agent = env_->session_mgr->GetCoordinationServiceAgent();
322     dist_mgr->SetCoordinationServiceAgent(coord_agent);
323     auto preemption_notifier =
324         PreemptionNotifier::CreatePreemptionNotifier("sigterm", Env::Default());
325     preemption_notifier->WillBePreemptedAtAsync(
326         [coord_agent](StatusOr<absl::Time> time_or_status) {
327           if (time_or_status.ok()) {
328             const auto& coord_task = coord_agent->GetOwnTask().ValueOrDie();
329             Status s = coord_agent->InsertKeyValue(
330                 "TF_DEFAULT_PREEMPTION_NOTICE_KEY",
331                 absl::StrCat("/job:", coord_task.job_name(),
332                              "/task:", coord_task.task_id()));
333             if (!s.ok()) {
334               LOG(INFO) << "Preemption not exported to coordination service: "
335                         << s;
336             }
337           }
338         });
339     ctx->SetDistributedManager(std::move(dist_mgr));
340   }
341 #endif  // !IS_MOBILE_PLATFORM
342 
343   std::vector<DeviceAttributes> device_attributes;
344   device_mgr->ListDeviceAttributes(&device_attributes);
345 
346   for (const auto& da : device_attributes) {
347     *response->add_device_attributes() = da;
348   }
349   {
350     mutex_lock l(contexts_mu_);
351     auto context_it = contexts_.find(request->context_id());
352     if (context_it != contexts_.end()) {
353       return errors::InvalidArgument("EagerService:CreateContext failed. ",
354                                      "Context id: <", request->context_id(),
355                                      "> already exists.");
356     }
357     contexts_.emplace(request->context_id(),
358                       new ServerContext(ctx, request->keep_alive_secs(), env_));
359   }
360 
361   return OkStatus();
362 }
363 
UpdateContext(const UpdateContextRequest * request,UpdateContextResponse * response)364 Status EagerServiceImpl::UpdateContext(const UpdateContextRequest* request,
365                                        UpdateContextResponse* response) {
366   // make sure env_ , env_->rendezvous_mgr available
367   if (env_ == nullptr || env_->rendezvous_mgr == nullptr) {
368     return tensorflow::errors::Internal(
369         "invalid eager env_ or env_->rendezvous_mgr.");
370   }
371 
372   // Find the context to update by the requested context_id
373   ServerContext* server_context = nullptr;
374   TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &server_context));
375   core::ScopedUnref context_unref(server_context);
376 
377   tensorflow::EagerContext* ctx = server_context->Context();
378   if (request->context_view_id() != ctx->GetContextViewId() + 1) {
379     return errors::InvalidArgument(
380         "EagerService:UpdateContext failed. Context id: <",
381         request->context_id(), "> currently at view #", ctx->GetContextViewId(),
382         " but received update request at view #", request->context_view_id(),
383         ". View id should only be continuously incremented.");
384   }
385   if (request->cluster_device_attributes_size() == 0) {
386     // In this case, the client indicates that the updated `server_def` and
387     // device info is irrelevant to this worker, since it is not connected to
388     // the updated ones (likely due to device filter settings). The worker
389     // simply needs to update view ID and does not update other internal state.
390     ctx->IncrementContextViewId();
391     VLOG(1) << "Processing simplified UpdateContextRequest on "
392             << ctx->HostCPU()->name();
393     return OkStatus();
394   }
395 
396   auto session_name =
397       tensorflow::strings::StrCat("eager_", request->context_id());
398 
399   TF_RETURN_IF_ERROR(
400       env_->session_mgr->UpdateSession(session_name, request->server_def(),
401                                        request->cluster_device_attributes()));
402 
403   std::shared_ptr<WorkerSession> worker_session;
404   TF_RETURN_IF_ERROR(env_->session_mgr->WorkerSessionForSession(
405       session_name, &worker_session));
406 
407   const tensorflow::DeviceMgr* device_mgr = worker_session->device_mgr();
408 
409   std::vector<string> remote_workers;
410   worker_session->worker_cache()->ListWorkers(&remote_workers);
411   remote_workers.erase(std::remove(remote_workers.begin(), remote_workers.end(),
412                                    worker_session->worker_name()),
413                        remote_workers.end());
414   VLOG(1) << "On existing server " << worker_session->worker_name()
415           << " updating remote workers";
416   if (VLOG_IS_ON(2)) {
417     for (const string& rw : remote_workers) {
418       VLOG(2) << "Remote worker " << rw;
419     }
420   }
421 
422   std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
423   TF_RETURN_IF_ERROR(worker_session->worker_cache()->GetEagerClientCache(
424       &remote_eager_workers));
425 
426   ctx->ClearCachesAndThreadExecutors();
427   Status s = ctx->UpdateRemoteWorker(std::move(remote_eager_workers),
428                                      remote_workers, request->context_id());
429   if (!s.ok()) {
430     VLOG(1) << "EagerContext::UpdateRemoteWorker failed with " << s.ToString();
431     return s;
432   }
433 
434   std::vector<DeviceAttributes> device_attributes;
435   device_mgr->ListDeviceAttributes(&device_attributes);
436 
437   for (const auto& da : device_attributes) {
438     *response->add_device_attributes() = da;
439   }
440 
441   return OkStatus();
442 }
443 
CreateMasterContext(const tensorflow::uint64 context_id,EagerContext * context)444 Status EagerServiceImpl::CreateMasterContext(
445     const tensorflow::uint64 context_id, EagerContext* context) {
446   {
447     mutex_lock l(contexts_mu_);
448     auto iter = contexts_.find(context_id);
449     if (iter != contexts_.end()) {
450       return errors::InvalidArgument(
451           "EagerService:CreateMasterContext failed. ", "Context id: <",
452           context_id, "> already exists.");
453     }
454   }
455   ServerContext* server_context =
456       ServerContext::CreateMasterContext(context, env_);
457   mutex_lock l(contexts_mu_);
458   contexts_.emplace(context_id, server_context);
459   return OkStatus();
460 }
461 
RunComponentFunction(CallOptions * call_opts,const RunComponentFunctionRequest * request,RunComponentFunctionResponse * response,StatusCallback done)462 void EagerServiceImpl::RunComponentFunction(
463     CallOptions* call_opts, const RunComponentFunctionRequest* request,
464     RunComponentFunctionResponse* response, StatusCallback done) {
465   ServerContext* context = nullptr;
466   Status s = GetServerContext(request->context_id(), &context);
467   if (!s.ok()) {
468     done(s);
469     return;
470   }
471   core::ScopedUnref context_unref(context);
472 
473   auto& operation = request->operation();
474   // This codepath should only be triggered for executing component function
475   if (!operation.is_function() || !operation.is_component_function()) {
476     done(errors::Internal(
477         "RunComponentFunction request can only be used to execute "
478         "component functions."));
479     return;
480   }
481 
482   EagerContext* eager_context = context->Context();
483   EagerExecutor* eager_executor = &eager_context->Executor();
484 
485   EagerOperation* op = new EagerOperation(eager_context);
486   int* num_retvals = new int(0);
487   s = GetEagerOperationAndNumRetvals(operation, eager_context, eager_executor,
488                                      op, num_retvals);
489   if (!s.ok()) {
490     delete num_retvals;
491     delete op;
492     done(s);
493     return;
494   }
495   if (!op->IsLocal()) {
496     delete num_retvals;
497     delete op;
498     done(errors::Internal(
499         "Received RunComponentFunction request with remote function device. "));
500     return;
501   }
502   s = op->SetAttrBool("is_component_function", true);
503   if (!s.ok()) {
504     delete num_retvals;
505     delete op;
506     done(errors::Internal("Error setting is_component_function attribute: ",
507                           s.error_message()));
508     return;
509   }
510 
511   auto* retvals = new absl::FixedArray<TensorHandle*>(*num_retvals);
512   VLOG(3) << "ServerContext: Calling EagerLocalExecuteAsync for op "
513           << operation.id();
514   std::vector<int32> output_nums;
515   for (const int32_t output_num : request->output_num()) {
516     output_nums.push_back(output_num);
517   }
518 
519   auto cm = std::make_shared<CancellationManager>();
520   op->SetCancellationManager(cm.get());
521   call_opts->SetCancelCallback([cm] { cm->StartCancel(); });
522 
523   context->Ref();
524   EagerLocalExecuteAsync(
525       op, retvals->data(), num_retvals,
526       [op, op_id = operation.id(), num_retvals, retvals, output_nums, cm,
527        call_opts, response, eager_context, context,
528        done = std::move(done)](const Status& status) {
529         call_opts->ClearCancelCallback();
530         auto wrapped_done = [&](const Status& status) {
531           context->Unref();
532           done(status);
533           delete op;
534           delete num_retvals;
535           delete retvals;
536         };
537         if (!status.ok()) {
538           wrapped_done(status);
539           return;
540         }
541         // The output device of a component function is the component device
542         // which is known on the default device of it's parent function.
543         wrapped_done(AddOpRetvalsToResponse(
544             eager_context, op_id, *num_retvals, output_nums, retvals->data(),
545             [response] { return response->add_tensor(); },
546             [response] { return response->add_shape(); }));
547       });
548 }
549 
ExecuteOp(CallOptions * call_opts,const Operation & operation,EagerContext * eager_context,EagerExecutor * eager_executor,QueueResponse * queue_response)550 Status EagerServiceImpl::ExecuteOp(CallOptions* call_opts,
551                                    const Operation& operation,
552                                    EagerContext* eager_context,
553                                    EagerExecutor* eager_executor,
554                                    QueueResponse* queue_response) {
555   tensorflow::EagerOperation op(eager_context);
556   int num_retvals = 0;
557   TF_RETURN_IF_ERROR(GetEagerOperationAndNumRetvals(
558       operation, eager_context, eager_executor, &op, &num_retvals));
559 
560   auto cm = std::make_shared<CancellationManager>();
561   if (call_opts) {
562     op.SetCancellationManager(cm.get());
563     call_opts->SetCancelCallback([cm] { cm->StartCancel(); });
564   }
565 
566   absl::FixedArray<tensorflow::TensorHandle*> retvals(num_retvals);
567   VLOG(3) << "ServerContext: Calling EagerExecute for op " << operation.id();
568   TF_RETURN_IF_ERROR(op.Execute(
569       absl::MakeSpan(
570           reinterpret_cast<tensorflow::AbstractTensorHandle**>(retvals.data()),
571           num_retvals),
572       &num_retvals));
573 
574   std::function<string*()> add_device_fn = nullptr;
575   // Send the output devices of a function back to let a client know where the
576   // outputs are. For a primitive op, an output devics is the op device which is
577   // known on a client.
578   if (op.is_function()) {
579     add_device_fn = [queue_response] { return queue_response->add_device(); };
580   }
581 
582   return AddOpRetvalsToResponse(
583       eager_context, operation.id(), num_retvals, /*output_nums=*/{},
584       retvals.data(), [queue_response] { return queue_response->add_tensor(); },
585       [queue_response] { return queue_response->add_shape(); },
586       std::move(add_device_fn));
587 }
588 
Enqueue(CallOptions * call_opts,const EnqueueRequest * request,EnqueueResponse * response,uint64 stream_id)589 Status EagerServiceImpl::Enqueue(CallOptions* call_opts,
590                                  const EnqueueRequest* request,
591                                  EnqueueResponse* response, uint64 stream_id) {
592   profiler::TraceMe activity(
593       [&] {
594         return absl::StrCat(
595             "EagerService:Enqueue#debug_str=", request->DebugString(), "#");
596       },
597       profiler::TraceMeLevel::kInfo);
598   ServerContext* context = nullptr;
599   TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
600   core::ScopedUnref context_unref(context);
601 
602   EagerExecutor& executor =
603       stream_id == kInvalidStreamId
604           ? context->Context()->Executor()
605           : context->Context()->RemoteMgr()->GetOrCreateExecutorForStream(
606                 stream_id);
607   Status s;
608   for (const auto& item : request->queue()) {
609     auto* queue_response = response->add_queue_response();
610     if (item.has_operation()) {
611       s = ExecuteOp(call_opts, item.operation(), context->Context(), &executor,
612                     queue_response);
613     } else if (item.has_handle_to_decref()) {
614       auto handle_to_decref = std::make_unique<RemoteTensorHandleInternal>(
615           item.handle_to_decref());
616       auto node = std::make_unique<ClientTensorHandleDeleteNode>(
617           context, std::move(handle_to_decref));
618       s = context->Context()->Executor().AddOrExecute(std::move(node));
619     } else if (item.has_send_tensor()) {
620       s = SendTensor(item.send_tensor(), context->Context());
621     } else if (item.has_send_packed_handle()) {
622       s = SendPackedHandle(item.send_packed_handle(), context->Context());
623     } else if (item.has_register_function()) {
624       s = RegisterFunction(item.register_function(), context->Context());
625     } else if (item.has_cleanup_function()) {
626       s = CleanupFunction(item.cleanup_function());
627     } else {
628       DCHECK(item.has_sync_remote_executor_for_stream());
629       s = executor.WaitForAllPendingNodes();
630     }
631 
632     if (!s.ok()) {
633       if (stream_id != kInvalidStreamId) {
634         context->Context()->RemoteMgr()->DeleteExecutorForStream(stream_id);
635       }
636       return s;
637     }
638   }
639 
640   return OkStatus();
641 }
642 
WaitQueueDone(const WaitQueueDoneRequest * request,WaitQueueDoneResponse * response)643 Status EagerServiceImpl::WaitQueueDone(const WaitQueueDoneRequest* request,
644                                        WaitQueueDoneResponse* response) {
645   ServerContext* context = nullptr;
646   TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
647   core::ScopedUnref context_unref(context);
648 
649   if (request->op_id_size() > 0) {
650     return errors::Unimplemented(
651         "EagerServiceImpl::WaitQueueDone is not "
652         "implemented for particular op IDs.");
653   }
654   return context->Context()->Executor().WaitForAllPendingNodes();
655 }
656 
KeepAlive(const KeepAliveRequest * request,KeepAliveResponse * response)657 Status EagerServiceImpl::KeepAlive(const KeepAliveRequest* request,
658                                    KeepAliveResponse* response) {
659   ServerContext* context = nullptr;
660   TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
661   core::ScopedUnref context_unref(context);
662 
663   tensorflow::EagerContext* ctx = context->Context();
664   response->set_context_view_id(ctx->GetContextViewId());
665   return OkStatus();
666 }
667 
CloseContext(const CloseContextRequest * request,CloseContextResponse * response)668 Status EagerServiceImpl::CloseContext(const CloseContextRequest* request,
669                                       CloseContextResponse* response) {
670   VLOG(1) << "Executing EagerService::CloseContext for context "
671           << request->context_id();
672   ServerContext* context = nullptr;
673   if (!GetServerContext(request->context_id(), &context).ok()) {
674     // Swallow the error here.
675     return OkStatus();
676   }
677   core::ScopedUnref context_unref(context);
678 
679   if (request->context_view_id() < context->Context()->GetContextViewId()) {
680     // Swallow the error here.
681     LOG(INFO) << "Ignoring CloseContext request with a stale context_view_id "
682               << request->context_view_id() << "  for context_id "
683               << request->context_id() << ". The current context_view_id is "
684               << context->Context()->GetContextViewId() << ".";
685     return OkStatus();
686   }
687 
688   mutex_lock l(contexts_mu_);
689   contexts_.erase(request->context_id());
690 
691   // GetServerContext returns a newly Reffed copy of ServerContext, which is
692   // unreffed by context_unref. Additionally, we need to unref it one time since
693   // we are releasing it from the map.
694   context->Unref();
695 
696   return OkStatus();
697 }
698 
RegisterFunction(const RegisterFunctionOp & register_function,EagerContext * eager_context)699 Status EagerServiceImpl::RegisterFunction(
700     const RegisterFunctionOp& register_function, EagerContext* eager_context) {
701   // If the function is a component of a multi-device function, we only need to
702   // register it locally.
703   return eager_context->AddFunctionDef(
704       register_function.function_def(), register_function.library(),
705       register_function.is_component_function());
706 }
707 
CleanupFunction(const CleanupFunctionOp & cleanup_function)708 Status EagerServiceImpl::CleanupFunction(
709     const CleanupFunctionOp& cleanup_function) {
710   env_->rendezvous_mgr->Cleanup(cleanup_function.step_id());
711   return OkStatus();
712 }
713 
SendTensor(const SendTensorOp & send_tensor,EagerContext * eager_context)714 Status EagerServiceImpl::SendTensor(const SendTensorOp& send_tensor,
715                                     EagerContext* eager_context) {
716   tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> tensors;
717   for (const auto& tensor_proto : send_tensor.tensors()) {
718     Tensor tensor;
719     if (!tensor.FromProto(tensor_proto)) {
720       return errors::InvalidArgument("Unable to parse tensor proto");
721     }
722 
723     TensorHandle* tensor_handle = TensorHandle::CreateLocalHandle(
724         std::move(tensor), nullptr, nullptr, eager_context);
725     TensorHandle* copied_handle = nullptr;
726     Device* device;
727     TF_RETURN_IF_ERROR(eager_context->FindDeviceFromName(
728         send_tensor.device_name().c_str(), &device));
729     TF_RETURN_IF_ERROR(EagerCopyToDevice(tensor_handle, eager_context,
730                                          &eager_context->Executor(), device,
731                                          false, &copied_handle));
732     tensors.push_back(copied_handle);
733     tensor_handle->Unref();
734   }
735 
736   eager_context->RemoteMgr()->AddOperationOutputs(tensors, send_tensor.op_id());
737 
738   return OkStatus();
739 }
740 
SendPackedHandle(const SendPackedHandleOp & send_packed_handle,EagerContext * eager_context)741 Status EagerServiceImpl::SendPackedHandle(
742     const SendPackedHandleOp& send_packed_handle, EagerContext* eager_context) {
743   if (send_packed_handle.handles().empty()) {
744     return errors::InvalidArgument("Handles should not be empty.");
745   }
746 
747   std::vector<tensorflow::TensorHandle*> handles;
748   handles.resize(send_packed_handle.handles_size());
749   for (int i = 0; i < send_packed_handle.handles_size(); ++i) {
750     const auto& item = send_packed_handle.handles(i);
751     if (item.has_local_handle()) {
752       Tensor tensor;
753       if (!ParseTensorProtoToTensor(item.local_handle().tensor(), &tensor)) {
754         return errors::InvalidArgument(
755             "Invalid TensorProto: ",
756             item.local_handle().tensor().DebugString());
757       }
758       Device* op_device = nullptr;
759       TF_RETURN_IF_ERROR(eager_context->FindDeviceFromName(
760           item.local_handle().device().c_str(), &op_device));
761       handles[i] = TensorHandle::CreateLocalHandle(
762           std::move(tensor), /*d=*/nullptr, op_device, eager_context);
763     } else {
764       TF_RETURN_IF_ERROR(
765           eager_context->RemoteMgr()->DeserializeRemoteTensorHandle(
766               item.remote_handle(), &handles[i]));
767     }
768   }
769 
770   tensorflow::TensorHandle* packed_handle = nullptr;
771   std::vector<tensorflow::TensorHandle*> handles_to_pack = handles;
772   // Create a unshaped packed TensorHandle.
773   TF_RETURN_IF_ERROR(TensorHandle::CreatePackedHandle(
774       std::move(handles_to_pack), handles.at(0)->dtype, TensorShape(),
775       send_packed_handle.device_name(), eager_context, &packed_handle));
776 
777   for (auto* h : handles) {
778     // Unref handle since it has a ref in the packed handle now.
779     h->Unref();
780   }
781 
782   eager_context->RemoteMgr()->AddOperationOutputs({packed_handle},
783                                                   send_packed_handle.op_id());
784   return OkStatus();
785 }
786 
GetServerContext(uint64 context_id,ServerContext ** server_context)787 tensorflow::Status EagerServiceImpl::GetServerContext(
788     uint64 context_id, ServerContext** server_context) {
789   tf_shared_lock l(contexts_mu_);
790   auto iter = contexts_.find(context_id);
791   if (iter == contexts_.end()) {
792     *server_context = nullptr;
793     return errors::Aborted(strings::Printf(
794         "Unable to find a context_id matching the specified one "
795         "(%llu). Perhaps the worker was restarted, or the context was GC'd?",
796         static_cast<unsigned long long>(context_id)));
797   }
798 
799   *server_context = iter->second;
800   (*server_context)->Ref();
801 
802   (*server_context)->RecordAccess();
803 
804   return OkStatus();
805 }
806 
807 }  // namespace eager
808 }  // namespace tensorflow
809