/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"

#include <deque>
#include <memory>
#include <unordered_map>
#include <vector>

#include "grpcpp/alarm.h"
#include "grpcpp/server_builder.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/common_runtime/buf_rendezvous.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/distributed_runtime/graph_mgr.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_response_cache.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h"
#include "tensorflow/core/distributed_runtime/worker.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/scoped_memory_debug_annotation.h"
#include "tensorflow/core/protobuf/transport_options.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"

namespace tensorflow {

namespace {

// This macro creates a new request for the given RPC method name
// (e.g., `ENQUEUE_REQUEST(GetStatus, false);`), and enqueues it on
// `this->cq_`.
//
// This macro is invoked one or more times for each RPC method to
// ensure that there are sufficient completion queue entries to
// handle incoming requests without blocking.
//
// The implementation of the request handler for each RPC method
// must ensure that it calls ENQUEUE_REQUEST() for that RPC method,
// to keep accepting new requests.
#define ENQUEUE_REQUEST(method, supports_cancel)                             \
  do {                                                                       \
    mutex_lock l(shutdown_mu_);                                              \
    if (!is_shutdown_) {                                                     \
      Call<GrpcWorkerServiceThread, grpc::WorkerService::AsyncService,       \
           method##Request, method##Response>::                              \
          EnqueueRequestForMethod(                                           \
              worker_service_, cq_.get(),                                    \
              static_cast<int>(GrpcWorkerMethod::k##method),                 \
              &GrpcWorkerServiceThread::method##Handler, (supports_cancel)); \
    }                                                                        \
  } while (0)

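// SETUP_FOR_REQUEST enqueues `default_depth` pending calls for `method`, or
// the per-method override from `queue_depth_` if one was configured.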
#define SETUP_FOR_REQUEST(method, default_depth, supports_cancel)              \
  for (int i = 0;                                                              \
       i < gtl::FindWithDefault(queue_depth_,                                  \
                                static_cast<int>(GrpcWorkerMethod::k##method), \
                                default_depth);                                \
       ++i) {                                                                  \
    ENQUEUE_REQUEST(method, supports_cancel);                                  \
  }

// GrpcWorkerService spawns one or more GrpcWorkerServiceThreads to service
// requests. Each thread operates on an independent completion queue.
class GrpcWorkerServiceThread {
 public:
  explicit GrpcWorkerServiceThread(
      GrpcWorker* worker, ::grpc::ServerBuilder* builder,
      std::unordered_map<int, int> queue_depth, GrpcResponseCache* cache,
      grpc::WorkerService::AsyncService* worker_service)
      : worker_(worker),
        queue_depth_(queue_depth),
        cache_(cache),
        worker_service_(worker_service),
        is_shutdown_(false) {
    cq_ = builder->AddCompletionQueue();
  }

  void Start() {
    thread_.reset(
        worker_->env()->env->StartThread(ThreadOptions(), "grpc_worker_service",
                                         [this]() { HandleRPCsLoop(); }));
  }

  void Join() { thread_.reset(); }  // Blocks until thread exits

  void Shutdown() {
    {
      mutex_lock lock(shutdown_mu_);
      is_shutdown_ = true;
    }
    cq_->Shutdown();
  }

 private:
  // Add one or more completion queue entries for each worker method, then
  // begin servicing requests from the completion queue.
  void HandleRPCsLoop() {
    // TODO(ncteisen): This may require performance engineering. We can
    // change the number of threads, the number of handlers per thread,
    // or even decide to specialize certain threads to certain methods.
    SETUP_FOR_REQUEST(GetStatus, 1, false);
    SETUP_FOR_REQUEST(CreateWorkerSession, 1, false);
    SETUP_FOR_REQUEST(DeleteWorkerSession, 1, false);
    SETUP_FOR_REQUEST(CleanupAll, 1, false);
    SETUP_FOR_REQUEST(RegisterGraph, 1, false);
    SETUP_FOR_REQUEST(DeregisterGraph, 1, false);
    SETUP_FOR_REQUEST(Logging, 1, false);
    SETUP_FOR_REQUEST(Tracing, 1, false);
    SETUP_FOR_REQUEST(CompleteGroup, 10, true);
    SETUP_FOR_REQUEST(CompleteInstance, 10, true);
    SETUP_FOR_REQUEST(GetStepSequence, 10, true);
    SETUP_FOR_REQUEST(RecvBuf, 500, true);
    SETUP_FOR_REQUEST(RunGraph, 100, true);
    SETUP_FOR_REQUEST(CleanupGraph, 100, false);
    SETUP_FOR_REQUEST(MarkRecvFinished, 10, false);

    // TODO(ncteisen): Determine a better policy for enqueuing the
    // appropriate number of each request type.
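    // RecvTensor is enqueued separately (rather than via SETUP_FOR_REQUEST)
    // because its response is a raw ::grpc::ByteBuffer instead of a generated
    // protobuf message; see EnqueueRecvTensorRequestRaw().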
    for (int i = 0;
         i < gtl::FindWithDefault(
                 queue_depth_, static_cast<int>(GrpcWorkerMethod::kRecvTensor),
                 1000);
         ++i) {
      EnqueueRecvTensorRequestRaw();
    }

    void* tag;
    bool ok;

    while (cq_->Next(&tag, &ok)) {
      UntypedCall<GrpcWorkerServiceThread>::Tag* callback_tag =
          static_cast<UntypedCall<GrpcWorkerServiceThread>::Tag*>(tag);
      CHECK(callback_tag);
      callback_tag->OnCompleted(this, ok);
    }
  }

 private:
  void Schedule(std::function<void()> f) {
    worker_->env()->compute_pool->Schedule(std::move(f));
  }

  // The following section contains one request handler method per
  // RPC. The `FooHandler` method is called (indirectly) by
  // `HandleRPCsLoop()` when the next Foo RPC is received. Each
  // `FooHandler` call schedules a closure on `worker_->env()->compute_pool`,
  // and is responsible for requesting the next Foo call by calling
  // `ENQUEUE_REQUEST(Foo)`.
  template <class RequestMessage, class ResponseMessage>
  using WorkerCall =
      Call<GrpcWorkerServiceThread, grpc::WorkerService::AsyncService,
           RequestMessage, ResponseMessage>;

  // Handle all non-cancellable simple methods with a standard wrapper.
  // The boolean `may_block_on_compute_pool` indicates whether or not the
  // operation may block on activities (such as op execution) that run on the
  // compute pool.
#define HANDLE_CALL(method, may_block_on_compute_pool)                        \
  void method##Handler(WorkerCall<method##Request, method##Response>* call) { \
    auto closure = [this, call]() {                                           \
      Status s = worker_->method(&call->request, &call->response);            \
      if (!s.ok()) {                                                          \
        VLOG(3) << "Bad response from " << #method << ": " << s;              \
      }                                                                       \
      call->SendResponse(ToGrpcStatus(s));                                    \
    };                                                                        \
    if ((may_block_on_compute_pool)) {                                        \
      worker_->env()->env->SchedClosure(std::move(closure));                  \
    } else {                                                                  \
      worker_->env()->compute_pool->Schedule(std::move(closure));             \
    }                                                                         \
    ENQUEUE_REQUEST(method, false);                                           \
  }

  HANDLE_CALL(GetStatus, false);
  HANDLE_CALL(CreateWorkerSession, false);
  HANDLE_CALL(DeleteWorkerSession, true);
  HANDLE_CALL(CleanupAll, false);
  HANDLE_CALL(RegisterGraph, false);
  HANDLE_CALL(DeregisterGraph, false);
  HANDLE_CALL(CleanupGraph, false);
  HANDLE_CALL(Logging, false);
  HANDLE_CALL(Tracing, false);

#undef HANDLE_CALL

  void GetStepSequenceHandler(
      WorkerCall<GetStepSequenceRequest, GetStepSequenceResponse>* call) {
    Schedule([this, call]() {
      worker_->GetStepSequenceAsync(
          &call->request, &call->response, [call](const Status& s) {
            if (!s.ok()) {
              VLOG(3) << "Bad response from GetStepSequence:" << s;
            }
            call->SendResponse(ToGrpcStatus(s));
          });
    });
    ENQUEUE_REQUEST(GetStepSequence, true);
  }

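  // MarkRecvFinished is the client's acknowledgement that a cached
  // RecvTensor/RecvBuf response has been received, so the corresponding
  // response cache entry can be evicted.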
  void MarkRecvFinishedHandler(
      WorkerCall<MarkRecvFinishedRequest, MarkRecvFinishedResponse>* call) {
    VLOG(3) << "Clean cache entry for request " << call->request.request_id();
    worker_->RemoveCacheEntryForId(call->request.request_id());
    call->SendResponse(::grpc::Status::OK);
    ENQUEUE_REQUEST(MarkRecvFinished, false);
  }

  void RunGraphHandler(WorkerCall<RunGraphRequest, RunGraphResponse>* call) {
    Schedule([this, call]() {
      CallOptions* call_opts = new CallOptions;
      ProtoRunGraphRequest* wrapped_request =
          new ProtoRunGraphRequest(&call->request);
      NonOwnedProtoRunGraphResponse* wrapped_response =
          new NonOwnedProtoRunGraphResponse(&call->response);
      call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
      worker_->RunGraphAsync(call_opts, wrapped_request, wrapped_response,
                             [call, call_opts, wrapped_request,
                              wrapped_response](const Status& s) {
                               VLOG(3) << "RunGraph::Done";
                               if (!s.ok()) {
                                 VLOG(3) << "Bad response from RunGraph:" << s;
                               }
                               call->ClearCancelCallback();
                               delete call_opts;
                               delete wrapped_request;
                               delete wrapped_response;
                               call->SendResponse(ToGrpcStatus(s));
                             });
    });
    ENQUEUE_REQUEST(RunGraph, true);
  }

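  // RecvTensor is handled specially: GrpcWorker::GrpcRecvTensorAsync() encodes
  // the tensor directly into a ::grpc::ByteBuffer response, avoiding an extra
  // protobuf serialization of the tensor contents.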
  void RecvTensorHandlerRaw(
      WorkerCall<RecvTensorRequest, ::grpc::ByteBuffer>* call) {
    Schedule([this, call]() {
      CallOptions* call_opts = new CallOptions;
      call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });

      worker_->GrpcRecvTensorAsync(
          call_opts, &call->request, &call->response,
          [call, call_opts](const Status& s) {
            call->ClearCancelCallback();
            delete call_opts;
            if (!s.ok()) {
              VLOG(3) << "Bad response from RecvTensor:" << s;
            }
            call->SendResponse(ToGrpcStatus(s));
          });
    });
    EnqueueRecvTensorRequestRaw();
  }

  void RecvBufHandler(WorkerCall<RecvBufRequest, RecvBufResponse>* call) {
    Schedule([this, call]() {
      CallOptions* call_opts = new CallOptions;
      call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
      worker_->RecvBufAsync(call_opts, &call->request, &call->response,
                            [call, call_opts](const Status& s) {
                              call->ClearCancelCallback();
                              delete call_opts;
                              if (!s.ok()) {
                                VLOG(3) << "Bad response from RecvBuf:" << s;
                              }
                              call->SendResponse(ToGrpcStatus(s));
                            });
    });
    ENQUEUE_REQUEST(RecvBuf, true);
  }

  void CompleteGroupHandler(
      WorkerCall<CompleteGroupRequest, CompleteGroupResponse>* call) {
    Schedule([this, call]() {
      CallOptions* call_opts = new CallOptions;
      call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
      worker_->CompleteGroupAsync(
          call_opts, &call->request, &call->response,
          [call, call_opts](const Status& s) {
            call->ClearCancelCallback();
            delete call_opts;
            if (!s.ok()) {
              VLOG(3) << "Bad response from CompleteGroup:" << s;
            }
            call->SendResponse(ToGrpcStatus(s));
          });
    });
    ENQUEUE_REQUEST(CompleteGroup, true);
  }

  void CompleteInstanceHandler(
      WorkerCall<CompleteInstanceRequest, CompleteInstanceResponse>* call) {
    Schedule([this, call]() {
      CallOptions* call_opts = new CallOptions;
      call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
      worker_->CompleteInstanceAsync(
          call_opts, &call->request, &call->response,
          [call, call_opts](const Status& s) {
            call->ClearCancelCallback();
            delete call_opts;
            if (!s.ok()) {
              VLOG(3) << "Bad response from CompleteInstance:" << s;
            }
            call->SendResponse(ToGrpcStatus(s));
          });
    });
    ENQUEUE_REQUEST(CompleteInstance, false);
  }
#undef ENQUEUE_REQUEST

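  // Equivalent of ENQUEUE_REQUEST for RecvTensor, spelled out because the
  // response type is a raw ::grpc::ByteBuffer rather than a generated
  // protobuf message.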
  void EnqueueRecvTensorRequestRaw() {
    mutex_lock l(shutdown_mu_);
    if (!is_shutdown_) {
      Call<GrpcWorkerServiceThread, grpc::WorkerService::AsyncService,
           RecvTensorRequest, ::grpc::ByteBuffer>::
          EnqueueRequestForMethod(
              worker_service_, cq_.get(),
              static_cast<int>(GrpcWorkerMethod::kRecvTensor),
              &GrpcWorkerServiceThread::RecvTensorHandlerRaw,
              true /* supports cancel*/);
    }
  }

  GrpcWorker* const worker_ = nullptr;  // Not owned.
  std::unique_ptr<::grpc::ServerCompletionQueue> cq_;
  std::unique_ptr<Thread> thread_;
  std::unordered_map<int, int> queue_depth_;
  GrpcResponseCache* cache_;
  grpc::WorkerService::AsyncService* const worker_service_;

  mutex shutdown_mu_;
  bool is_shutdown_ TF_GUARDED_BY(shutdown_mu_);
  TF_DISALLOW_COPY_AND_ASSIGN(GrpcWorkerServiceThread);
};

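// GrpcWorkerService registers the worker RPC service on the given
// ServerBuilder and fans incoming requests out to
// `options.num_serving_threads` GrpcWorkerServiceThreads, each of which owns
// its own completion queue.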
class GrpcWorkerService : public AsyncServiceInterface {
 public:
  GrpcWorkerService(GrpcWorker* worker, ::grpc::ServerBuilder* builder,
                    GrpcWorkerServiceOptions options)
      : is_shutdown_(false) {
    builder->RegisterService(&worker_service_);

    for (int i = 0; i < options.num_serving_threads; i++) {
      threads_.emplace_back(
          new GrpcWorkerServiceThread(worker, builder, options.queue_depth,
                                      cache_.get(), &worker_service_));
    }
  }

  void Shutdown() override {
    bool did_shutdown = false;
    {
      mutex_lock l(service_shutdown_mu_);
      if (!is_shutdown_) {
        LOG(INFO) << "Shutting down GrpcWorkerService.";
        is_shutdown_ = true;
        did_shutdown = true;
      }
    }
    if (did_shutdown) {
      for (auto& worker_thread : threads_) {
        worker_thread->Shutdown();
      }
    }
  }

  // This method blocks forever handling requests from the completion queue.
  void HandleRPCsLoop() override {
    for (auto& worker_thread : threads_) {
      worker_thread->Start();
    }
    for (auto& worker_thread : threads_) {
      worker_thread->Join();
    }
  }

 private:
  grpc::WorkerService::AsyncService worker_service_;
  std::vector<std::unique_ptr<GrpcWorkerServiceThread>> threads_;

  std::unique_ptr<GrpcResponseCache> cache_;
  mutex service_shutdown_mu_;
  bool is_shutdown_ TF_GUARDED_BY(service_shutdown_mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(GrpcWorkerService);
};

}  // namespace

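// A note on recv_buf_max_chunk_: a positive configured value is used as the
// chunk size in bytes, a negative value disables chunking (recv_buf_max_chunk_
// becomes 0, so the tensor is sent as a single string), and the default of 0
// falls back to 4096-byte chunks.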
GrpcWorker::GrpcWorker(WorkerEnv* worker_env, const ConfigProto& config)
    : Worker(worker_env),
      recv_buf_max_chunk_(
          config.experimental().recv_buf_max_chunk() > 0
              ? config.experimental().recv_buf_max_chunk()
              : (config.experimental().recv_buf_max_chunk() < 0 ? 0 : 4096)) {
  if (config.rpc_options().cache_rpc_response()) {
    EnableResponseCache();
  }
}

void GrpcWorker::EnableResponseCache() {
  VLOG(3) << "Enabling gRPC tensor response cache.";
  response_cache_ = std::make_unique<GrpcResponseCache>();
}

// GrpcRecvTensorAsync: unlike the other Worker methods, which use protocol
// buffers for the response object, this method generates its response
// directly into a ::grpc::ByteBuffer to avoid the overhead of an extra
// protocol buffer serialization.
void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts,
                                     const RecvTensorRequest* request,
                                     ::grpc::ByteBuffer* response,
                                     StatusCallback done) {
  VLOG(3) << "GrpcRecvTensorAsync req: " << request->DebugString();
  const int64_t request_id = request->request_id();
  const int64_t step_id = request->step_id();

  bool cache_enabled = (response_cache_ != nullptr && request_id != 0);

  auto do_response = [response, done, cache_enabled](const Tensor& tensor,
                                                     bool is_dead,
                                                     const Status& status) {
    if (status.ok()) {
      grpc::EncodeTensorToByteBuffer(is_dead, tensor, cache_enabled, response);
    }
    done(status);
  };

  // If response cache is enabled and the response cache already contains the
  // request, we delegate this retry request to the response cache. Otherwise,
  // we add the request to the response cache and start the computation to
  // retrieve the requested data.
  if (cache_enabled &&
      response_cache_->QueueRequest(request_id, step_id, do_response)) {
    return;
  }

  auto rendezvous_done = [this, request_id, do_response, cache_enabled](
                             const Tensor& tensor, bool is_dead,
                             const Status& status) {
    if (cache_enabled) {
      // Data is ready. Process all pending requests in the response cache.
      response_cache_->OnRequestFinished(request_id, tensor, is_dead, status);
    } else {
      do_response(tensor, is_dead, status);
    }
  };

  auto fail = [&rendezvous_done](const Status& status) {
    rendezvous_done(Tensor(), false, status);
  };

  Status s = recent_request_ids_.TrackUnique(
      request_id, "RecvTensor (GrpcWorker)", *request);
  if (!s.ok()) {
    fail(s);
    return;
  }

  const string& key = request->rendezvous_key();
  TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str());
  Rendezvous::ParsedKey parsed;
  s = Rendezvous::ParseKey(key, &parsed);
  Device* src_dev = nullptr;
  if (s.ok()) {
    s = PrepareRecvTensor(parsed, &src_dev);
  }
  if (!s.ok()) {
    fail(s);
    return;
  }

  // Request the tensor associated with the rendezvous key.
  // At any time while waiting for the tensor to be produced, up until the
  // callback lambda body below starts executing, an RPC cancellation should
  // abort the rendezvous.
  // Note that gRPC can generate cancellations in response to transient network
  // failures, and the client might not observe any errors or cancellations but
  // simply wait for the responses. Aborting the step reports an error to the
  // client and avoids a permanent hang in distributed function execution.
  opts->SetCancelCallback([this, step_id]() {
    LOG(WARNING) << "RecvTensor cancelled for " << step_id;
    AbortStep(step_id);
  });
  env_->rendezvous_mgr->RecvLocalAsync(
      step_id, parsed,
      [opts, rendezvous_done, src_dev, request](
          const Status& status, const Rendezvous::Args& send_args,
          const Rendezvous::Args& recv_args, const Tensor& val,
          const bool is_dead) {
        opts->ClearCancelCallback();
        if (status.ok()) {
          // DMA can only be used for Tensors that do not fall into
          // the following three odd edge cases: 1) a zero-size
          // buffer, 2) a dead tensor which has an uninit value, and
          // 3) a tensor with the on_host allocation attribute,
          // i.e. it's in CPU RAM *independent of its assigned
          // device type*.
          const bool on_host = send_args.alloc_attrs.on_host();
          {
            // Non-DMA cases.
            if (src_dev->tensorflow_accelerator_device_info() && (!on_host)) {
              DeviceContext* send_dev_context = send_args.device_context;
              AllocatorAttributes alloc_attrs;
              alloc_attrs.set_gpu_compatible(true);
              alloc_attrs.set_on_host(true);
              Allocator* alloc = src_dev->GetAllocator(alloc_attrs);
              Tensor* copy = new Tensor(alloc, val.dtype(), val.shape());
              CHECK(send_dev_context)
                  << "send dev name: " << src_dev->name() << " gpu_info: "
                  << src_dev->tensorflow_accelerator_device_info();
              // "val" is on an accelerator device. Use the device_context to
              // fill the copy on host.
              StatusCallback copy_ready = [rendezvous_done, copy,
                                           is_dead](const Status& s) {
                // The value is now ready to be returned on the wire.
                rendezvous_done(*copy, is_dead, s);
                delete copy;
              };

              CopyDeviceToHost(&val, alloc, alloc, request->rendezvous_key(),
                               src_dev, copy, send_dev_context, copy_ready);
              return;
            }
          }
        }

        rendezvous_done(val, is_dead, status);
      });
}

namespace {
// If RecvBufRespExtra.tensor_content is a single large string, then gRPC
// can stall on the recv side when the string buffer needs to be enlarged,
// since the size is not sent in advance. Changing this field to a sequence
// of small strings costs some extra time on the send side, since we do
// some otherwise unnecessary copies, but it improves runtime overall by
// improving flow control. Best performance is likely achieved with a
// max_chunk_bytes equal to the memory page size.
//
// TODO(tucker): When proto3 supports [ctype=CORD] then change
// RecvBufRespExtra.tensor_content to a cord instead of a repeated string,
// and remove this function.
void SetTensorInRecvBufResp(int64_t max_chunk_bytes, const Tensor* tensor,
                            RecvBufResponse* response) {
  RecvBufRespExtra extra;
  int64_t num_bytes = tensor->TotalBytes();
  const char* head = reinterpret_cast<const char*>(DMAHelper::base(tensor));
  while (num_bytes > 0) {
    int64_t bytes =
        max_chunk_bytes > 0 ? std::min(num_bytes, max_chunk_bytes) : num_bytes;
    extra.add_tensor_content(std::string(head, bytes));
    head += bytes;
    num_bytes -= bytes;
  }
  response->mutable_transport_options()->PackFrom(extra);
}
}  // namespace

void GrpcWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
                              RecvBufResponse* response, StatusCallback done) {
  const int64_t request_id = request->request_id();
  const int64_t step_id = request->step_id();
  bool cache_enabled = (response_cache_ != nullptr && request_id != 0);

  auto do_response = [this, response, done, cache_enabled](
                         const Tensor& tensor, bool is_dead,
                         const Status& status) {
    if (status.ok()) {
      SetTensorInRecvBufResp(recv_buf_max_chunk_, &tensor, response);
    }
    response->set_send_start_micros(env_->env->NowMicros());
    response->set_require_ack(cache_enabled);
    done(status);
  };

  // If response cache is enabled and the response cache already contains the
  // request, we delegate this retry request to the response cache. Otherwise,
  // we add the request to the response cache and start the computation to
  // retrieve the requested data.
  if (cache_enabled &&
      response_cache_->QueueRequest(request_id, step_id, do_response)) {
    return;
  }

  auto rendezvous_done = [this, request_id, do_response, cache_enabled](
                             const Tensor& tensor, const Status& status) {
    if (cache_enabled) {
      // Data is ready. Process all pending requests in the response cache.
      response_cache_->OnRequestFinished(request_id, tensor, false, status);
    } else {
      do_response(tensor, false, status);
    }
  };

  auto fail = [&rendezvous_done](const Status& status) {
    rendezvous_done(Tensor(), status);
  };

  // This is a generic, low-performance implementation appropriate for gRPC.
  Status s = recent_request_ids_.TrackUnique(request_id, "RecvBuf (GrpcWorker)",
                                             *request);
  if (!s.ok()) {
    fail(s);
    return;
  }

  CollectiveExecutor::Handle ce_handle(
      env_->collective_executor_mgr->FindOrCreate(step_id), true);
  CollectiveRemoteAccess* rma = ce_handle.get()->remote_access();
  auto consumer_callback = [this, request, rendezvous_done](
                               const Status& status,
                               BufRendezvous::Hook* hook) {
    Status s = status;
    if (s.ok()) {
      if (hook == nullptr) {
        s = errors::Internal("Invalid null hook for key ",
                             request->buf_rendezvous_key());
      } else if (!DMAHelper::CanUseDMA(hook->prod_value)) {
        s = errors::Internal("Tensor value for key ",
                             request->buf_rendezvous_key(),
                             " is not of a type supported by RecvBuf");
      }
    } else {
      if (hook != nullptr) {
        LOG(ERROR) << "Got hook " << hook << " with status " << s
                   << " from ConsumeBuf";
      }
    }

    if (s.ok()) {
      // The RPC source tensor needs to be in CPU RAM. If not already
      // there make a copy using memory appropriate to the purpose.
      const size_t num_bytes = hook->prod_value->TotalBytes();
      const bool on_host =
          hook->prod_dev->attributes().device_type() == "CPU" ||
          hook->prod_attr.on_host();
      if ((!on_host) && (num_bytes > 0)) {
        Device* cpu_dev = nullptr;
        s = env_->device_mgr->LookupDevice("CPU:0", &cpu_dev);
        if (s.ok()) {
          AllocatorAttributes cpu_attr;
          cpu_attr.set_gpu_compatible(true);
          cpu_attr.set_nic_compatible(true);
          profiler::ScopedMemoryDebugAnnotation op_annotation(
              "GrpcWorker::RecvBufAsync::consumer_callback", request->step_id(),
              "dynamic", hook->prod_value->dtype(),
              [hook]() { return hook->prod_value->shape().DebugString(); });
          Tensor* cpu_tensor =
              new Tensor(cpu_dev->GetAllocator(cpu_attr),
                         hook->prod_value->dtype(), hook->prod_value->shape());
          hook->prod_ctx->CopyDeviceTensorToCPU(
              hook->prod_value, "empty_name", hook->prod_dev, cpu_tensor,
              [hook, cpu_tensor, rendezvous_done](const Status& s) {
                rendezvous_done(*cpu_tensor, s);
                BufRendezvous::DoneWithHook(hook);
                delete cpu_tensor;
              });
          return;
        }
      }
    }

    if (hook == nullptr) {
      rendezvous_done(Tensor(), s);
    } else {
      rendezvous_done(*hook->prod_value, s);
      BufRendezvous::DoneWithHook(hook);
    }
  };
  rma->buf_rendezvous()->ConsumeBuf(
      request->buf_rendezvous_key(), request->src_device(),
      request->src_incarnation(), consumer_callback,
      /*cancellation_manager=*/nullptr);
}

void GrpcWorker::LoggingAsync(const LoggingRequest* request,
                              LoggingResponse* response, StatusCallback done) {
  auto env = this->env();
  if (env) {
    auto session_mgr = env->session_mgr;
    if (session_mgr) {
      if (request->enable_rpc_logging()) {
        session_mgr->SetLogging(true);
      }
      // NOTE(mrry): Handle old masters that disable RPC logging by setting
      // `request->enable_rpc_logging` to `false`.
      if (request->disable_rpc_logging() ||
          (!request->enable_rpc_logging() &&
           request->fetch_step_id_size() == 0)) {
        session_mgr->SetLogging(false);
      }
      for (const auto& step_id : request->fetch_step_id()) {
        session_mgr->RetrieveLogs(step_id, response);
      }
      if (request->clear()) {
        session_mgr->ClearLogs();
      }
    }
  }
  done(OkStatus());
}

void GrpcWorker::CleanupGraphAsync(const CleanupGraphRequest* request,
                                   CleanupGraphResponse* response,
                                   StatusCallback done) {
  if (response_cache_) {
    // Clean up any stale response cache entries for this step. Stale entries
    // can occur if a worker crashes before acking a request.
    response_cache_->CleanEntriesForStep(request->step_id());
  }
  Worker::CleanupGraphAsync(request, response, done);
}

WorkerEnv* GrpcWorker::env() { return env_; }

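// Called when a MarkRecvFinished ack arrives for `request_id`; drops the
// corresponding entry from the response cache.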
void GrpcWorker::RemoveCacheEntryForId(int64_t request_id) {
  if (response_cache_) {
    response_cache_->EraseRequestId(request_id);
  }
}

std::unique_ptr<GrpcWorker> NewGrpcWorker(WorkerEnv* env,
                                          const ConfigProto& config) {
  return std::unique_ptr<GrpcWorker>(new GrpcWorker(env, config));
}

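// A rough usage sketch (illustrative, not part of this file): the server
// setup code is expected to create the worker, construct the service before
// the ::grpc::ServerBuilder is built, and then drive the completion-queue
// loop on a dedicated thread, e.g.:
//
//   std::unique_ptr<GrpcWorker> worker = NewGrpcWorker(worker_env, config);
//   std::unique_ptr<AsyncServiceInterface> service =
//       NewGrpcWorkerService(worker.get(), &builder, options);
//   // ... after builder.BuildAndStart():
//   // service->HandleRPCsLoop();  // Blocks; run it on its own thread.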
std::unique_ptr<AsyncServiceInterface> NewGrpcWorkerService(
    GrpcWorker* worker, ::grpc::ServerBuilder* builder,
    GrpcWorkerServiceOptions options) {
  return std::unique_ptr<AsyncServiceInterface>(
      new GrpcWorkerService(worker, builder, options));
}

}  // namespace tensorflow