xref: /aosp_15_r20/external/tensorflow/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
17 
18 #include <deque>
19 #include <memory>
20 #include <unordered_map>
21 #include <vector>
22 
23 #include "grpcpp/alarm.h"
24 #include "grpcpp/server_builder.h"
25 #include "absl/container/flat_hash_map.h"
26 #include "tensorflow/core/common_runtime/buf_rendezvous.h"
27 #include "tensorflow/core/common_runtime/copy_tensor.h"
28 #include "tensorflow/core/common_runtime/device.h"
29 #include "tensorflow/core/common_runtime/device_mgr.h"
30 #include "tensorflow/core/common_runtime/dma_helper.h"
31 #include "tensorflow/core/common_runtime/local_device.h"
32 #include "tensorflow/core/common_runtime/process_util.h"
33 #include "tensorflow/core/common_runtime/step_stats_collector.h"
34 #include "tensorflow/core/distributed_runtime/graph_mgr.h"
35 #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
36 #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
37 #include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
38 #include "tensorflow/core/distributed_runtime/rpc/grpc_response_cache.h"
39 #include "tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.h"
40 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
41 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h"
42 #include "tensorflow/core/distributed_runtime/worker.h"
43 #include "tensorflow/core/distributed_runtime/worker_cache.h"
44 #include "tensorflow/core/distributed_runtime/worker_session.h"
45 #include "tensorflow/core/framework/cancellation.h"
46 #include "tensorflow/core/framework/collective.h"
47 #include "tensorflow/core/framework/tensor.h"
48 #include "tensorflow/core/lib/core/errors.h"
49 #include "tensorflow/core/lib/core/status.h"
50 #include "tensorflow/core/lib/gtl/map_util.h"
51 #include "tensorflow/core/lib/strings/strcat.h"
52 #include "tensorflow/core/lib/strings/stringprintf.h"
53 #include "tensorflow/core/platform/logging.h"
54 #include "tensorflow/core/platform/mutex.h"
55 #include "tensorflow/core/platform/tracing.h"
56 #include "tensorflow/core/profiler/lib/scoped_memory_debug_annotation.h"
57 #include "tensorflow/core/protobuf/transport_options.pb.h"
58 #include "tensorflow/core/protobuf/worker.pb.h"
59 
60 namespace tensorflow {
61 
62 namespace {
63 
// ENQUEUE_REQUEST(method, supports_cancel) asks gRPC for one more incoming
// `method` RPC on `this->cq_`, e.g. `ENQUEUE_REQUEST(GetStatus, false);`.
//
// Every posted request occupies one completion-queue entry, so the macro is
// invoked one or more times per RPC method to keep enough entries outstanding
// for incoming requests to be accepted without blocking.
//
// Each `FooHandler` must re-post its own method with ENQUEUE_REQUEST(Foo, ...)
// or the service stops accepting Foo calls. Nothing is posted once
// `is_shutdown_` has been set; the flag is read under `shutdown_mu_`.
#define ENQUEUE_REQUEST(method, supports_cancel)                         \
  do {                                                                   \
    mutex_lock enqueue_lock(shutdown_mu_);                               \
    if (is_shutdown_) break;                                             \
    using EnqueuedCall =                                                 \
        Call<GrpcWorkerServiceThread, grpc::WorkerService::AsyncService, \
             method##Request, method##Response>;                         \
    EnqueuedCall::EnqueueRequestForMethod(                               \
        worker_service_, cq_.get(),                                      \
        static_cast<int>(GrpcWorkerMethod::k##method),                   \
        &GrpcWorkerServiceThread::method##Handler, (supports_cancel));   \
  } while (0)
87 
// SETUP_FOR_REQUEST(method, default_depth, supports_cancel) posts the initial
// batch of `method` requests: `queue_depth_[kmethod]` of them when that entry
// exists, `default_depth` otherwise.
//
// The depth is looked up once, not on every loop iteration. The expansion is a
// single `do { } while (0)` statement, so it is safe after an unbraced
// `if` / `else`.
#define SETUP_FOR_REQUEST(method, default_depth, supports_cancel)    \
  do {                                                               \
    const int setup_depth = gtl::FindWithDefault(                    \
        queue_depth_, static_cast<int>(GrpcWorkerMethod::k##method), \
        (default_depth));                                            \
    for (int i = 0; i < setup_depth; ++i) {                          \
      ENQUEUE_REQUEST(method, supports_cancel);                      \
    }                                                                \
  } while (0)
96 
97 // GrpcWorkerService spawns one or more GrpcWorkerServiceThreads to service
98 // requests.  Each thread operates on an independent completion queue.
99 class GrpcWorkerServiceThread {
100  public:
GrpcWorkerServiceThread(GrpcWorker * worker,::grpc::ServerBuilder * builder,std::unordered_map<int,int> queue_depth,GrpcResponseCache * cache,grpc::WorkerService::AsyncService * worker_service)101   explicit GrpcWorkerServiceThread(
102       GrpcWorker* worker, ::grpc::ServerBuilder* builder,
103       std::unordered_map<int, int> queue_depth, GrpcResponseCache* cache,
104       grpc::WorkerService::AsyncService* worker_service)
105       : worker_(worker),
106         queue_depth_(queue_depth),
107         cache_(cache),
108         worker_service_(worker_service),
109         is_shutdown_(false) {
110     cq_ = builder->AddCompletionQueue();
111   }
112 
Start()113   void Start() {
114     thread_.reset(
115         worker_->env()->env->StartThread(ThreadOptions(), "grpc_worker_service",
116                                          [this]() { HandleRPCsLoop(); }));
117   }
118 
Join()119   void Join() { thread_.reset(); }  // Blocks until thread exits
120 
Shutdown()121   void Shutdown() {
122     {
123       mutex_lock lock(shutdown_mu_);
124       is_shutdown_ = true;
125     }
126     cq_->Shutdown();
127   }
128 
129  private:
130   // Add one or more completion queue entries for each worker method, then
131   // begin servicing requests from the completion queue.
HandleRPCsLoop()132   void HandleRPCsLoop() {
133     // TODO(ncteisen): This may require performance engineering. We can
134     // change the number of threads, the number of handlers per thread,
135     // or even decide to specialize certain threads to certain methods.
136     SETUP_FOR_REQUEST(GetStatus, 1, false);
137     SETUP_FOR_REQUEST(CreateWorkerSession, 1, false);
138     SETUP_FOR_REQUEST(DeleteWorkerSession, 1, false);
139     SETUP_FOR_REQUEST(CleanupAll, 1, false);
140     SETUP_FOR_REQUEST(RegisterGraph, 1, false);
141     SETUP_FOR_REQUEST(DeregisterGraph, 1, false);
142     SETUP_FOR_REQUEST(Logging, 1, false);
143     SETUP_FOR_REQUEST(Tracing, 1, false);
144     SETUP_FOR_REQUEST(CompleteGroup, 10, true);
145     SETUP_FOR_REQUEST(CompleteInstance, 10, true);
146     SETUP_FOR_REQUEST(GetStepSequence, 10, true);
147     SETUP_FOR_REQUEST(RecvBuf, 500, true);
148     SETUP_FOR_REQUEST(RunGraph, 100, true);
149     SETUP_FOR_REQUEST(CleanupGraph, 100, false);
150     SETUP_FOR_REQUEST(MarkRecvFinished, 10, false);
151 
152     // TODO(ncteisen): Determine a better policy for enqueuing the
153     // appropriate number of each request type.
154     for (int i = 0;
155          i < gtl::FindWithDefault(
156                  queue_depth_, static_cast<int>(GrpcWorkerMethod::kRecvTensor),
157                  1000);
158          ++i) {
159       EnqueueRecvTensorRequestRaw();
160     }
161 
162     void* tag;
163     bool ok;
164 
165     while (cq_->Next(&tag, &ok)) {
166       UntypedCall<GrpcWorkerServiceThread>::Tag* callback_tag =
167           static_cast<UntypedCall<GrpcWorkerServiceThread>::Tag*>(tag);
168       CHECK(callback_tag);
169       callback_tag->OnCompleted(this, ok);
170     }
171   }
172 
173  private:
Schedule(std::function<void ()> f)174   void Schedule(std::function<void()> f) {
175     worker_->env()->compute_pool->Schedule(std::move(f));
176   }
177 
178   // The following section contains one request handler method per
179   // RPC. The `FooHandler` method is called (indirectly) by
180   // `HandleRPCsLoop()` when the next Foo RPC is received. Each
181   // `FooHandler` call schedules a closure on `worker_->env()->compute_pool`,
182   // and is responsible for requesting the next Foo call by calling
183   // `ENQUEUE_REQUEST(Foo)`.
184   template <class RequestMessage, class ResponseMessage>
185   using WorkerCall =
186       Call<GrpcWorkerServiceThread, grpc::WorkerService::AsyncService,
187            RequestMessage, ResponseMessage>;
188 
189   // Handle all non-cancellable simple methods with a standard wrapper.
190   // The boolean `may_block_on_compute_pool` indicates whether or not the
191   // operation may block on activities (such as op execution) that run on the
192   // compute pool.
193 #define HANDLE_CALL(method, may_block_on_compute_pool)                        \
194   void method##Handler(WorkerCall<method##Request, method##Response>* call) { \
195     auto closure = [this, call]() {                                           \
196       Status s = worker_->method(&call->request, &call->response);            \
197       if (!s.ok()) {                                                          \
198         VLOG(3) << "Bad response from " << #method << ": " << s;              \
199       }                                                                       \
200       call->SendResponse(ToGrpcStatus(s));                                    \
201     };                                                                        \
202     if ((may_block_on_compute_pool)) {                                        \
203       worker_->env()->env->SchedClosure(std::move(closure));                  \
204     } else {                                                                  \
205       worker_->env()->compute_pool->Schedule(std::move(closure));             \
206     }                                                                         \
207     ENQUEUE_REQUEST(method, false);                                           \
208   }
209 
210   HANDLE_CALL(GetStatus, false);
211   HANDLE_CALL(CreateWorkerSession, false);
212   HANDLE_CALL(DeleteWorkerSession, true);
213   HANDLE_CALL(CleanupAll, false);
214   HANDLE_CALL(RegisterGraph, false);
215   HANDLE_CALL(DeregisterGraph, false);
216   HANDLE_CALL(CleanupGraph, false);
217   HANDLE_CALL(Logging, false);
218   HANDLE_CALL(Tracing, false);
219 
220 #undef HANDLE_CALL
221 
GetStepSequenceHandler(WorkerCall<GetStepSequenceRequest,GetStepSequenceResponse> * call)222   void GetStepSequenceHandler(
223       WorkerCall<GetStepSequenceRequest, GetStepSequenceResponse>* call) {
224     Schedule([this, call]() {
225       worker_->GetStepSequenceAsync(
226           &call->request, &call->response, [call](const Status& s) {
227             VLOG(3) << "Bad response from GetStepSequence:" << s;
228             call->SendResponse(ToGrpcStatus(s));
229           });
230     });
231     ENQUEUE_REQUEST(GetStepSequence, true);
232   }
233 
MarkRecvFinishedHandler(WorkerCall<MarkRecvFinishedRequest,MarkRecvFinishedResponse> * call)234   void MarkRecvFinishedHandler(
235       WorkerCall<MarkRecvFinishedRequest, MarkRecvFinishedResponse>* call) {
236     VLOG(3) << "Clean cache entry for request " << call->request.request_id();
237     worker_->RemoveCacheEntryForId(call->request.request_id());
238     call->SendResponse(::grpc::Status::OK);
239     ENQUEUE_REQUEST(MarkRecvFinished, false);
240   }
241 
RunGraphHandler(WorkerCall<RunGraphRequest,RunGraphResponse> * call)242   void RunGraphHandler(WorkerCall<RunGraphRequest, RunGraphResponse>* call) {
243     Schedule([this, call]() {
244       CallOptions* call_opts = new CallOptions;
245       ProtoRunGraphRequest* wrapped_request =
246           new ProtoRunGraphRequest(&call->request);
247       NonOwnedProtoRunGraphResponse* wrapped_response =
248           new NonOwnedProtoRunGraphResponse(&call->response);
249       call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
250       worker_->RunGraphAsync(call_opts, wrapped_request, wrapped_response,
251                              [call, call_opts, wrapped_request,
252                               wrapped_response](const Status& s) {
253                                VLOG(3) << "RunGraph::Done";
254                                if (!s.ok()) {
255                                  VLOG(3) << "Bad response from RunGraph:" << s;
256                                }
257                                call->ClearCancelCallback();
258                                delete call_opts;
259                                delete wrapped_request;
260                                delete wrapped_response;
261                                call->SendResponse(ToGrpcStatus(s));
262                              });
263     });
264     ENQUEUE_REQUEST(RunGraph, true);
265   }
266 
RecvTensorHandlerRaw(WorkerCall<RecvTensorRequest,::grpc::ByteBuffer> * call)267   void RecvTensorHandlerRaw(
268       WorkerCall<RecvTensorRequest, ::grpc::ByteBuffer>* call) {
269     Schedule([this, call]() {
270       CallOptions* call_opts = new CallOptions;
271       call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
272 
273       worker_->GrpcRecvTensorAsync(
274           call_opts, &call->request, &call->response,
275           [call, call_opts](const Status& s) {
276             call->ClearCancelCallback();
277             delete call_opts;
278             if (!s.ok()) {
279               VLOG(3) << "Bad response from RecvTensor:" << s;
280             }
281             call->SendResponse(ToGrpcStatus(s));
282           });
283     });
284     EnqueueRecvTensorRequestRaw();
285   }
286 
RecvBufHandler(WorkerCall<RecvBufRequest,RecvBufResponse> * call)287   void RecvBufHandler(WorkerCall<RecvBufRequest, RecvBufResponse>* call) {
288     Schedule([this, call]() {
289       CallOptions* call_opts = new CallOptions;
290       call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
291       worker_->RecvBufAsync(call_opts, &call->request, &call->response,
292                             [call, call_opts](const Status& s) {
293                               call->ClearCancelCallback();
294                               delete call_opts;
295                               if (!s.ok()) {
296                                 VLOG(3) << "Bad response from RecvBuf:" << s;
297                               }
298                               call->SendResponse(ToGrpcStatus(s));
299                             });
300     });
301     ENQUEUE_REQUEST(RecvBuf, true);
302   }
303 
CompleteGroupHandler(WorkerCall<CompleteGroupRequest,CompleteGroupResponse> * call)304   void CompleteGroupHandler(
305       WorkerCall<CompleteGroupRequest, CompleteGroupResponse>* call) {
306     Schedule([this, call]() {
307       CallOptions* call_opts = new CallOptions;
308       call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
309       worker_->CompleteGroupAsync(
310           call_opts, &call->request, &call->response,
311           [call, call_opts](const Status& s) {
312             call->ClearCancelCallback();
313             delete call_opts;
314             if (!s.ok()) {
315               VLOG(3) << "Bad response from CompleteGroup:" << s;
316             }
317             call->SendResponse(ToGrpcStatus(s));
318           });
319     });
320     ENQUEUE_REQUEST(CompleteGroup, true);
321   }
322 
CompleteInstanceHandler(WorkerCall<CompleteInstanceRequest,CompleteInstanceResponse> * call)323   void CompleteInstanceHandler(
324       WorkerCall<CompleteInstanceRequest, CompleteInstanceResponse>* call) {
325     Schedule([this, call]() {
326       CallOptions* call_opts = new CallOptions;
327       call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
328       worker_->CompleteInstanceAsync(
329           call_opts, &call->request, &call->response,
330           [call, call_opts](const Status& s) {
331             call->ClearCancelCallback();
332             delete call_opts;
333             if (!s.ok()) {
334               VLOG(3) << "Bad response from CompleteInstance:" << s;
335             }
336             call->SendResponse(ToGrpcStatus(s));
337           });
338     });
339     ENQUEUE_REQUEST(CompleteInstance, false);
340   }
341 #undef ENQUEUE_REQUEST
342 
EnqueueRecvTensorRequestRaw()343   void EnqueueRecvTensorRequestRaw() {
344     mutex_lock l(shutdown_mu_);
345     if (!is_shutdown_) {
346       Call<GrpcWorkerServiceThread, grpc::WorkerService::AsyncService,
347            RecvTensorRequest, ::grpc::ByteBuffer>::
348           EnqueueRequestForMethod(
349               worker_service_, cq_.get(),
350               static_cast<int>(GrpcWorkerMethod::kRecvTensor),
351               &GrpcWorkerServiceThread::RecvTensorHandlerRaw,
352               true /* supports cancel*/);
353     }
354   }
355 
356   GrpcWorker* const worker_ = nullptr;  // Not owned.
357   std::unique_ptr<::grpc::ServerCompletionQueue> cq_;
358   std::unique_ptr<Thread> thread_;
359   std::unordered_map<int, int> queue_depth_;
360   GrpcResponseCache* cache_;
361   grpc::WorkerService::AsyncService* const worker_service_;
362 
363   mutex shutdown_mu_;
364   bool is_shutdown_ TF_GUARDED_BY(shutdown_mu_);
365   TF_DISALLOW_COPY_AND_ASSIGN(GrpcWorkerServiceThread);
366 };
367 
368 class GrpcWorkerService : public AsyncServiceInterface {
369  public:
GrpcWorkerService(GrpcWorker * worker,::grpc::ServerBuilder * builder,GrpcWorkerServiceOptions options)370   GrpcWorkerService(GrpcWorker* worker, ::grpc::ServerBuilder* builder,
371                     GrpcWorkerServiceOptions options)
372       : is_shutdown_(false) {
373     builder->RegisterService(&worker_service_);
374 
375     for (int i = 0; i < options.num_serving_threads; i++) {
376       threads_.emplace_back(
377           new GrpcWorkerServiceThread(worker, builder, options.queue_depth,
378                                       cache_.get(), &worker_service_));
379     }
380   }
381 
Shutdown()382   void Shutdown() override {
383     bool did_shutdown = false;
384     {
385       mutex_lock l(service_shutdown_mu_);
386       if (!is_shutdown_) {
387         LOG(INFO) << "Shutting down GrpcWorkerService.";
388         is_shutdown_ = true;
389         did_shutdown = true;
390       }
391     }
392     if (did_shutdown) {
393       for (auto& worker_thread : threads_) {
394         worker_thread->Shutdown();
395       }
396     }
397   }
398 
399   // This method blocks forever handling requests from the completion queue.
HandleRPCsLoop()400   void HandleRPCsLoop() override {
401     for (auto& worker_thread : threads_) {
402       worker_thread->Start();
403     }
404     for (auto& worker_thread : threads_) {
405       worker_thread->Join();
406     }
407   }
408 
409  private:
410   grpc::WorkerService::AsyncService worker_service_;
411   std::vector<std::unique_ptr<GrpcWorkerServiceThread>> threads_;
412 
413   std::unique_ptr<GrpcResponseCache> cache_;
414   mutex service_shutdown_mu_;
415   bool is_shutdown_ TF_GUARDED_BY(service_shutdown_mu_);
416 
417   TF_DISALLOW_COPY_AND_ASSIGN(GrpcWorkerService);
418 };
419 
420 }  // namespace
421 
GrpcWorker(WorkerEnv * worker_env,const ConfigProto & config)422 GrpcWorker::GrpcWorker(WorkerEnv* worker_env, const ConfigProto& config)
423     : Worker(worker_env),
424       recv_buf_max_chunk_(
425           config.experimental().recv_buf_max_chunk() > 0
426               ? config.experimental().recv_buf_max_chunk()
427               : (config.experimental().recv_buf_max_chunk() < 0 ? 0 : 4096)) {
428   if (config.rpc_options().cache_rpc_response()) {
429     EnableResponseCache();
430   }
431 }
432 
EnableResponseCache()433 void GrpcWorker::EnableResponseCache() {
434   VLOG(3) << "Enabling gRPC tensor response cache.";
435   response_cache_ = std::make_unique<GrpcResponseCache>();
436 }
437 
438 // GrpcRecvTensorAsync: unlike the other Worker methods, which use protocol
439 // buffers for a response object, to avoid extra protocol buffer serialization
440 // overhead we generate our response directly into a ::grpc::ByteBuffer object
GrpcRecvTensorAsync(CallOptions * opts,const RecvTensorRequest * request,::grpc::ByteBuffer * response,StatusCallback done)441 void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts,
442                                      const RecvTensorRequest* request,
443                                      ::grpc::ByteBuffer* response,
444                                      StatusCallback done) {
445   VLOG(3) << "GrpcRecvTensorAsync req: " << request->DebugString();
446   const int64_t request_id = request->request_id();
447   const int64_t step_id = request->step_id();
448 
449   bool cache_enabled = (response_cache_ != nullptr && request_id != 0);
450 
451   auto do_response = [response, done, cache_enabled](const Tensor& tensor,
452                                                      bool is_dead,
453                                                      const Status& status) {
454     if (status.ok()) {
455       grpc::EncodeTensorToByteBuffer(is_dead, tensor, cache_enabled, response);
456     }
457     done(status);
458   };
459 
460   // If response cache is enabled and the response cache already contains the
461   // request, we delegate this retry request to the response cache. Otherwise,
462   // we add the request to the response cache and start the computation to
463   // retrieve the requested data.
464   if (cache_enabled &&
465       response_cache_->QueueRequest(request_id, step_id, do_response)) {
466     return;
467   }
468 
469   auto rendezvous_done = [this, request_id, do_response, cache_enabled](
470                              const Tensor& tensor, bool is_dead,
471                              const Status& status) {
472     if (cache_enabled) {
473       // Data is ready. Process all pending requests in the response cache.
474       response_cache_->OnRequestFinished(request_id, tensor, is_dead, status);
475     } else {
476       do_response(tensor, is_dead, status);
477     }
478   };
479 
480   auto fail = [&rendezvous_done](const Status& status) {
481     rendezvous_done(Tensor(), false, status);
482   };
483 
484   Status s = recent_request_ids_.TrackUnique(
485       request_id, "RecvTensor (GrpcWorker)", *request);
486   if (!s.ok()) {
487     fail(s);
488     return;
489   }
490 
491   const string& key = request->rendezvous_key();
492   TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str());
493   Rendezvous::ParsedKey parsed;
494   s = Rendezvous::ParseKey(key, &parsed);
495   Device* src_dev = nullptr;
496   if (s.ok()) {
497     s = PrepareRecvTensor(parsed, &src_dev);
498   }
499   if (!s.ok()) {
500     fail(s);
501     return;
502   }
503 
504   // Request the tensor associated with the rendezvous key.
505   // Any time while waiting for the tensor to be produced, up until the start of
506   // execution of the callback lambda body below, an RPC cancellation should
507   // abort the rendezvous.
508   // Note that gRPC can generate cancellations in response to transient network
509   // failures, and the client might not observe any errors or cancellations but
510   // simply waits for the responses. Aborting the step would report an error to
511   // the client, and avoid permanent hanging in distributed function execution.
512   opts->SetCancelCallback([this, step_id]() {
513     LOG(WARNING) << "RecvTensor cancelled for " << step_id;
514     AbortStep(step_id);
515   });
516   env_->rendezvous_mgr->RecvLocalAsync(
517       step_id, parsed,
518       [opts, rendezvous_done, src_dev, request](
519           const Status& status, const Rendezvous::Args& send_args,
520           const Rendezvous::Args& recv_args, const Tensor& val,
521           const bool is_dead) {
522         opts->ClearCancelCallback();
523         if (status.ok()) {
524           // DMA can only be used for Tensors that do not fall into
525           // the following three odd edge cases: 1) a zero-size
526           // buffer, 2) a dead tensor which has an uninit value, and
527           // 3) the tensor has the on_host allocation attribute,
528           // i.e. it's in CPU RAM *independent of its assigned
529           // device type*.
530           const bool on_host = send_args.alloc_attrs.on_host();
531           {
532             // Non-DMA cases.
533             if (src_dev->tensorflow_accelerator_device_info() && (!on_host)) {
534               DeviceContext* send_dev_context = send_args.device_context;
535               AllocatorAttributes alloc_attrs;
536               alloc_attrs.set_gpu_compatible(true);
537               alloc_attrs.set_on_host(true);
538               Allocator* alloc = src_dev->GetAllocator(alloc_attrs);
539               Tensor* copy = new Tensor(alloc, val.dtype(), val.shape());
540               CHECK(send_dev_context)
541                   << "send dev name: " << src_dev->name() << " gpu_info: "
542                   << src_dev->tensorflow_accelerator_device_info();
543               // "val" is on an accelerator device. Uses the device_context to
544               // fill the copy on host.
545               StatusCallback copy_ready = [rendezvous_done, copy,
546                                            is_dead](const Status& s) {
547                 // The value is now ready to be returned on the wire.
548                 rendezvous_done(*copy, is_dead, s);
549                 delete copy;
550               };
551 
552               CopyDeviceToHost(&val, alloc, alloc, request->rendezvous_key(),
553                                src_dev, copy, send_dev_context, copy_ready);
554               return;
555             }
556           }
557         }
558 
559         rendezvous_done(val, is_dead, status);
560       });
561 }
562 
563 namespace {
564 // If RecvBufRespExtra.tensor_content is a single large string, then gRPC
565 // can stall on the recv side when the string buffer needs to be enlarged,
566 // since the size is not sent in advance.  Changing this field to a sequence
567 // of small strings costs some extra time on the send side, since we do
568 // some otherwise unnecessary copies, but it improves runtime overall by
569 // improving flow control.  Best performance is likely achieved with a
570 // max_chunk_bytes equal to the memory page size.
571 //
572 // TODO(tucker): When proto3 supports [ctype=CORD] then change
573 // RecvBufRespExtra.tensor_content to a cord instead of a repeated string,
574 // and remove this function.
SetTensorInRecvBufResp(int64_t max_chunk_bytes,const Tensor * tensor,RecvBufResponse * response)575 void SetTensorInRecvBufResp(int64_t max_chunk_bytes, const Tensor* tensor,
576                             RecvBufResponse* response) {
577   RecvBufRespExtra extra;
578   int64_t num_bytes = tensor->TotalBytes();
579   const char* head = reinterpret_cast<const char*>(DMAHelper::base(tensor));
580   while (num_bytes > 0) {
581     int64_t bytes =
582         max_chunk_bytes > 0 ? std::min(num_bytes, max_chunk_bytes) : num_bytes;
583     extra.add_tensor_content(std::string(head, bytes));
584     head += bytes;
585     num_bytes -= bytes;
586   }
587   response->mutable_transport_options()->PackFrom(extra);
588 }
589 }  // namespace
590 
RecvBufAsync(CallOptions * opts,const RecvBufRequest * request,RecvBufResponse * response,StatusCallback done)591 void GrpcWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
592                               RecvBufResponse* response, StatusCallback done) {
593   const int64_t request_id = request->request_id();
594   const int64_t step_id = request->step_id();
595   bool cache_enabled = (response_cache_ != nullptr && request_id != 0);
596 
597   auto do_response = [this, response, done, cache_enabled](
598                          const Tensor& tensor, bool is_dead,
599                          const Status& status) {
600     if (status.ok()) {
601       SetTensorInRecvBufResp(recv_buf_max_chunk_, &tensor, response);
602     }
603     response->set_send_start_micros(env_->env->NowMicros());
604     response->set_require_ack(cache_enabled);
605     done(status);
606   };
607 
608   // If response cache is enabled and the response cache already contains the
609   // request, we delegate this retry request to the response cache. Otherwise,
610   // we add the request to the response cache and start the computation to
611   // retrieve the requested data.
612   if (cache_enabled &&
613       response_cache_->QueueRequest(request_id, step_id, do_response)) {
614     return;
615   }
616 
617   auto rendezvous_done = [this, request_id, do_response, cache_enabled](
618                              const Tensor& tensor, const Status& status) {
619     if (cache_enabled) {
620       // Data is ready. Process all pending requests in the response cache.
621       response_cache_->OnRequestFinished(request_id, tensor, false, status);
622     } else {
623       do_response(tensor, false, status);
624     }
625   };
626 
627   auto fail = [&rendezvous_done](const Status& status) {
628     rendezvous_done(Tensor(), status);
629   };
630 
631   // This is a generic, low performance implementation appropriate for grpc.
632   Status s = recent_request_ids_.TrackUnique(request_id, "RecvBuf (GrpcWorker)",
633                                              *request);
634   if (!s.ok()) {
635     fail(s);
636     return;
637   }
638 
639   CollectiveExecutor::Handle ce_handle(
640       env_->collective_executor_mgr->FindOrCreate(step_id), true);
641   CollectiveRemoteAccess* rma = ce_handle.get()->remote_access();
642   auto consumer_callback = [this, request, rendezvous_done](
643                                const Status& status,
644                                BufRendezvous::Hook* hook) {
645     Status s = status;
646     if (s.ok()) {
647       if (hook == nullptr) {
648         s = errors::Internal("Invalid null hook for key ",
649                              request->buf_rendezvous_key());
650       }
651       if (!DMAHelper::CanUseDMA(hook->prod_value)) {
652         s = errors::Internal("Tensor value for key ",
653                              request->buf_rendezvous_key(),
654                              " is not of a type supported by RecvBuf");
655       }
656     } else {
657       if (hook != nullptr) {
658         LOG(ERROR) << "Got hook " << hook << " with status " << s
659                    << " from ConsumeBuf";
660       }
661     }
662 
663     if (s.ok()) {
664       // The RPC source tensor needs to be in CPU RAM.  If not already
665       // there make a copy using memory appropriate to the purpose.
666       const size_t num_bytes = hook->prod_value->TotalBytes();
667       const bool on_host =
668           hook->prod_dev->attributes().device_type() == "CPU" ||
669           hook->prod_attr.on_host();
670       if ((!on_host) && (num_bytes > 0)) {
671         Device* cpu_dev = nullptr;
672         s = env_->device_mgr->LookupDevice("CPU:0", &cpu_dev);
673         if (s.ok()) {
674           AllocatorAttributes cpu_attr;
675           cpu_attr.set_gpu_compatible(true);
676           cpu_attr.set_nic_compatible(true);
677           profiler::ScopedMemoryDebugAnnotation op_annotation(
678               "GrpcWorker::RecvBufAsync::consumer_callback", request->step_id(),
679               "dynamic", hook->prod_value->dtype(),
680               [hook]() { return hook->prod_value->shape().DebugString(); });
681           Tensor* cpu_tensor =
682               new Tensor(cpu_dev->GetAllocator(cpu_attr),
683                          hook->prod_value->dtype(), hook->prod_value->shape());
684           hook->prod_ctx->CopyDeviceTensorToCPU(
685               hook->prod_value, "empty_name", hook->prod_dev, cpu_tensor,
686               [hook, cpu_tensor, rendezvous_done](const Status& s) {
687                 rendezvous_done(*cpu_tensor, s);
688                 BufRendezvous::DoneWithHook(hook);
689                 delete cpu_tensor;
690               });
691           return;
692         }
693       }
694     }
695 
696     if (hook == nullptr) {
697       rendezvous_done(Tensor(), s);
698     } else {
699       rendezvous_done(*hook->prod_value, s);
700       BufRendezvous::DoneWithHook(hook);
701     }
702   };
703   rma->buf_rendezvous()->ConsumeBuf(
704       request->buf_rendezvous_key(), request->src_device(),
705       request->src_incarnation(), consumer_callback,
706       /*cancellation_manager=*/nullptr);
707 }
708 
LoggingAsync(const LoggingRequest * request,LoggingResponse * response,StatusCallback done)709 void GrpcWorker::LoggingAsync(const LoggingRequest* request,
710                               LoggingResponse* response, StatusCallback done) {
711   auto env = this->env();
712   if (env) {
713     auto session_mgr = env->session_mgr;
714     if (session_mgr) {
715       if (request->enable_rpc_logging()) {
716         session_mgr->SetLogging(true);
717       }
718       // NOTE(mrry): Handle old masters that disable RPC logging by setting
719       // `request->enable_rpc_logging` to `false`.
720       if (request->disable_rpc_logging() ||
721           (!request->enable_rpc_logging() &&
722            request->fetch_step_id_size() == 0)) {
723         session_mgr->SetLogging(false);
724       }
725       for (const auto& step_id : request->fetch_step_id()) {
726         session_mgr->RetrieveLogs(step_id, response);
727       }
728       if (request->clear()) {
729         session_mgr->ClearLogs();
730       }
731     }
732   }
733   done(OkStatus());
734 }
735 
CleanupGraphAsync(const CleanupGraphRequest * request,CleanupGraphResponse * response,StatusCallback done)736 void GrpcWorker::CleanupGraphAsync(const CleanupGraphRequest* request,
737                                    CleanupGraphResponse* response,
738                                    StatusCallback done) {
739   if (response_cache_) {
740     // Cleanup any stale response cache entries for this step. This can occur if
741     // a worker crashes before acking a request.
742     response_cache_->CleanEntriesForStep(request->step_id());
743   }
744   Worker::CleanupGraphAsync(request, response, done);
745 }
746 
env()747 WorkerEnv* GrpcWorker::env() { return env_; }
748 
RemoveCacheEntryForId(int64_t request_id)749 void GrpcWorker::RemoveCacheEntryForId(int64_t request_id) {
750   if (response_cache_) {
751     response_cache_->EraseRequestId(request_id);
752   }
753 }
754 
NewGrpcWorker(WorkerEnv * env,const ConfigProto & config)755 std::unique_ptr<GrpcWorker> NewGrpcWorker(WorkerEnv* env,
756                                           const ConfigProto& config) {
757   return std::unique_ptr<GrpcWorker>(new GrpcWorker(env, config));
758 }
759 
NewGrpcWorkerService(GrpcWorker * worker,::grpc::ServerBuilder * builder,GrpcWorkerServiceOptions options)760 std::unique_ptr<AsyncServiceInterface> NewGrpcWorkerService(
761     GrpcWorker* worker, ::grpc::ServerBuilder* builder,
762     GrpcWorkerServiceOptions options) {
763   return std::unique_ptr<AsyncServiceInterface>(
764       new GrpcWorkerService(worker, builder, options));
765 }
766 
767 }  // namespace tensorflow
768