/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_

#include <queue>
#include <utility>

#include "grpcpp/generic/generic_stub.h"
#include "grpcpp/grpcpp.h"
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/distributed_runtime/tensor_coding.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/util/env_var.h"

namespace tensorflow {

// Object allocated per active RPC.
// Manages the state of a single asynchronous RPC request. If `max_retries`
// is greater than 0, the request will be retried for any transient failures.
template <class Response>
class RPCState : public GrpcClientCQTag {
 public:
  RPCState(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq,
           const ::grpc::string& method, const protobuf::Message& request,
           Response* response, StatusCallback done, CallOptions* call_opts,
           thread::ThreadPool* threadpool, int32_t max_retries = 0,
           bool fail_fast = true, const string* target = nullptr)
      : RPCState(
            stub, cq, method, request, response, std::move(done), call_opts,
            threadpool,
            // 1) If GRPC_FAIL_FAST is set to 'true' or 'false',
            // fail_fast=$GRPC_FAIL_FAST. See b/141948186.
            // 2) Otherwise if GRPC_FAIL_FAST is set to 'use_caller', use the
            // fail_fast from the caller. See b/140260119.
            //
            // Current default: use caller's fail_fast argument.
            //
            // NOTE: Callers mostly set fail_fast=true to prevent job hanging
            // on worker task failures, except a few cases such as GetStatus
            // in cluster initialization and collective param resolution.
            [fail_fast, &done]() -> bool {
              string fail_fast_env;
              TF_CHECK_OK(ReadStringFromEnvVar("GRPC_FAIL_FAST", "use_caller",
                                               &fail_fast_env));
              string fail_fast_env_lower = absl::AsciiStrToLower(fail_fast_env);
              if (fail_fast_env_lower == "true") {
                return true;
              } else if (fail_fast_env_lower == "use_caller") {
                return fail_fast;
              } else if (fail_fast_env_lower == "false") {
                return false;
              } else {
                string error_message = strings::StrCat(
                    "Invalid GRPC_FAIL_FAST config: ", fail_fast_env);
                LOG(WARNING) << error_message;
                done(errors::InvalidArgument(error_message));
                return false;
              }
            }(),
            (call_opts != nullptr ? call_opts->GetTimeout() : 0), max_retries,
            target) {}
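
  // Example usage (an illustrative sketch, not from this file; the stub,
  // completion queue, callback, and response type are hypothetical
  // stand-ins):
  //
  //   auto* call = new RPCState<GetStatusResponse>(
  //       stub, cq, "/tensorflow.WorkerService/GetStatus", request, &response,
  //       std::move(done), /*call_opts=*/nullptr, /*threadpool=*/nullptr);
  //
  // The object manages its own lifetime: it deletes itself after invoking
  // `done`, so the caller must not delete it.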

  template <typename Request>
  RPCState(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq,
           const ::grpc::string& method, const Request& request,
           Response* response, StatusCallback done, CallOptions* call_opts,
           thread::ThreadPool* threadpool, bool fail_fast,
           int64_t timeout_in_ms, int32_t max_retries, const string* target)
      : call_opts_(call_opts),
        threadpool_(threadpool),
        done_(std::move(done)),
        timeout_in_ms_(timeout_in_ms),
        max_retries_(max_retries),
        cq_(cq),
        stub_(stub),
        method_(method),
        fail_fast_(fail_fast),
        target_(target) {
    response_ = response;
    ::grpc::Status s = GrpcMaybeUnparseProto(request, &request_buf_);
    if (!s.ok()) {
      LOG(ERROR) << "GrpcMaybeUnparseProto returned with non-ok status: "
                 << s.error_message();
      // Skip retry logic if we fail to parse our request.
      done_(FromGrpcStatus(s));
      delete this;
      return;
    }
    StartCall();
  }

  void StartCall() {
    context_.reset(new ::grpc::ClientContext());
    context_->set_wait_for_ready(!fail_fast_);
    if (timeout_in_ms_ > 0) {
      context_->set_deadline(
          gpr_time_from_millis(timeout_in_ms_, GPR_TIMESPAN));
    }
    if (call_opts_) {
      call_opts_->SetCancelCallback([this]() { context_->TryCancel(); });
    }

    VLOG(2) << "Starting call: " << method_;

    call_ = stub_->PrepareUnaryCall(context_.get(), method_, request_buf_, cq_);
    call_->StartCall();
    call_->Finish(&response_buf_, &status_, this);
  }
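
  // OnCompleted() below is not called directly; a completion-queue polling
  // thread (elsewhere in the runtime) hands the tag back. A minimal sketch of
  // such a driver loop, assuming `cq` is the queue passed to the constructor:
  //
  //   void* tag;
  //   bool ok;
  //   while (cq->Next(&tag, &ok)) {
  //     static_cast<GrpcClientCQTag*>(tag)->OnCompleted(ok);
  //   }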

  void OnCompleted(bool ok) override {
    if (call_opts_) {
      call_opts_->ClearCancelCallback();
    }

    VLOG(2) << "Completed call: " << method_;

    Status s = FromGrpcStatus(status_);
    if (s.ok() && !ok) {
      // Since this function is only being used for processing the response
      // to Finish for client-side unary calls, ok should never be false.
      s.Update(
          errors::Internal("GRPC status is okay but CompletionQueueStatus is "
                           "not. This should never happen."));
    }

    if (s.ok()) {
      if (threadpool_) {
        // Run parse and callback in another thread, returning this
        // one to service more RPCs.
        threadpool_->Schedule([this]() { ParseAndCallDone(); });
      } else {
        ParseAndCallDone();
      }
      return;
    }

    VLOG(1) << method_ << " returned with non-ok status: " << s
            << " Retries: " << num_retries_ << " Max: " << max_retries_ << "\n"
            << context_->debug_error_string();
    // Retry if we have any attempts left.
    if (++num_retries_ <= max_retries_ &&
        (errors::IsUnavailable(s) || errors::IsUnknown(s))) {
      response_buf_.Clear();
      VLOG(1) << "Retrying call for " << method_ << " Retry: " << num_retries_
              << " of " << max_retries_;

      ComputeRetryBackoffMs(/*min_backoff_ms=*/1, /*max_backoff_ms=*/10000);
      int64_t backoff_us = retry_backoff_ms_ * 1000;
      Env::Default()->SchedClosureAfter(/*micros=*/backoff_us,
                                        [this]() { StartCall(); });
    } else {
      // Attach any additional GRPC error information to the final status.
      string error_msg = s.error_message();
      strings::StrAppend(&error_msg, "\nAdditional GRPC error information");
      if (target_) {
        strings::StrAppend(&error_msg, " from remote target ", *target_);
      }
      strings::StrAppend(&error_msg, ":\n:", context_->debug_error_string());
      s = errors::CreateWithUpdatedMessage(s, error_msg);
      // Always treat gRPC cancellation as a derived error. This ensures that
      // other error types are preferred during status aggregation. (gRPC
      // cancellation messages do not contain the original status message.)
      if (s.code() == tensorflow::error::Code::CANCELLED) {
        s = StatusGroup::MakeDerived(s);
      }

      done_(s);
      delete this;
    }
  }

  void ParseAndCallDone() {
    Status s;
    if (!GrpcMaybeParseProto(&response_buf_, response_)) {
      s.Update(errors::Internal("could not parse rpc response"));
    }
    done_(s);
    delete this;
  }

 private:
  void ComputeRetryBackoffMs(int min_backoff_ms, int max_backoff_ms) {
    constexpr float kBackoffBase = 1.3;
    if (retry_backoff_ms_ < 0) {
      retry_backoff_ms_ = min_backoff_ms;
    } else {
      retry_backoff_ms_ *= kBackoffBase;
      if (retry_backoff_ms_ > max_backoff_ms) {
        retry_backoff_ms_ = max_backoff_ms;
      }
    }
  }

  CallOptions* call_opts_;
  std::unique_ptr<::grpc::ClientContext> context_;
  thread::ThreadPool* threadpool_;
  std::unique_ptr<::grpc::GenericClientAsyncResponseReader> call_;
  Response* response_;
  ::grpc::ByteBuffer request_buf_;
  ::grpc::ByteBuffer response_buf_;
  ::grpc::Status status_;
  StatusCallback done_;
  int64_t timeout_in_ms_;

  size_t num_retries_ = 0;
  size_t max_retries_;
  double retry_backoff_ms_ = -1;

  ::grpc::CompletionQueue* cq_;
  ::grpc::GenericStub* stub_;
  ::grpc::string method_;
  bool fail_fast_;
  const string* target_;
};
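
// Note on retry pacing (illustrative, derived from ComputeRetryBackoffMs
// above): with kBackoffBase = 1.3 and the bounds passed from OnCompleted()
// (min 1 ms, max 10 s), successive retry delays grow geometrically as
// 1, 1.3, 1.69, 2.197, ... ms until they saturate at 10000 ms.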

// Represents state associated with one streaming RPC call.
// Similarly to above, we extract the methods of StreamingRPCState that don't
// need to be templated into this abstract class.
// Currently, *StreamingRPCState does not support the client closing the call,
// as there is no use case for it - current clients keep the streaming call
// open as long as possible. If/when the need arises, support can be added
// by calling GenericClientAsyncReaderWriter::WritesDone with a new tag
// TagType::kClientFinished and handling the completion in a new callback.
class UntypedStreamingRPCState : public core::RefCounted {
 public:
  virtual void CallStarted(bool ok) = 0;
  virtual void RequestWriteCompleted(bool ok) = 0;
  virtual void ResponseReadCompleted(bool ok) = 0;
  virtual void CallFinished(bool ok) = 0;

  virtual string DebugString() const = 0;

  class Tag : public GrpcClientCQTag {
   public:
    // One enum value per supported callback.
    enum class TagType {
      kCallStarted,
      kRequestWriteCompleted,
      kResponseReadCompleted,
      kCallFinished,
    };

    Tag(UntypedStreamingRPCState* streaming_state, Tag::TagType type);

    // Calls the callback associated with this tag and Unrefs
    // `this->streaming_state_`.
    void OnCompleted(bool ok) override;

   private:
    // OnCompleted() consumes one reference each time it is called.
    UntypedStreamingRPCState* const streaming_state_;
    const Tag::TagType type_;
  };
};

const char* ToString(UntypedStreamingRPCState::Tag::TagType tag_type);

// Represents a single request/response exchange between the client and the
// server. A single streaming call contains a sequence of exchanges. Besides
// the messages, an exchange contains:
// - the user callback to invoke when the exchange completes (a response is
//   received or an error occurs).
// - the current state of the exchange.
class Exchange {
 public:
  enum class State {
    kExchangeCreated,
    kRequestWriteIssued,
    kRequestWriteCompleted,
    kResponseReadIssued,
  };

  Exchange(const ::grpc::ByteBuffer& request_buf, protobuf::Message* response,
           StatusCallback cb, string debug_string)
      : state_(State::kExchangeCreated),
        request_buf_(request_buf),
        response_(response),
        cb_(std::move(cb)),
        debug_string_(std::move(debug_string)) {}

  const ::grpc::ByteBuffer& request_buf() { return request_buf_; }
  ::grpc::ByteBuffer* response_buf() { return &response_buf_; }

  void MarkRequestWriteIssued() {
    DCHECK(state_ == State::kExchangeCreated);
    state_ = State::kRequestWriteIssued;
  }
  void MarkRequestWriteCompleted() {
    DCHECK(state_ == State::kRequestWriteIssued);
    state_ = State::kRequestWriteCompleted;
  }
  void MarkResponseReadIssued() {
    DCHECK(state_ == State::kRequestWriteCompleted);
    state_ = State::kResponseReadIssued;
  }

  // If `status` is success, completes this exchange by parsing the
  // response_buf_ and invoking cb_ with Status::OK(). Else, invokes the
  // callback with `status`.
  void Complete(Status status);

  const State& state() const { return state_; }

  string DebugString() const;

 private:
  State state_;
  ::grpc::ByteBuffer request_buf_;
  ::grpc::ByteBuffer response_buf_;
  protobuf::Message* response_;
  StatusCallback cb_;
  string debug_string_;
};

const char* ToString(Exchange::State s);

std::ostream& operator<<(std::ostream& os, const Exchange::State& state);
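
// An exchange's state advances strictly forward through the Mark* calls above
// (illustrative):
//   kExchangeCreated -> kRequestWriteIssued -> kRequestWriteCompleted
//   -> kResponseReadIssued -> Complete() and removal from the queue.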

// Represents a queue of exchanges.
// When a client sends a new request, a new exchange is created and added to
// the end of the queue. Completed exchanges are popped from the front of the
// queue. An explicit exchange queue is needed to bridge the client, which can
// send new requests at any time, with the gRPC infrastructure, which can
// handle a single read and a single write request at a time.
//
// As the exchange progresses (request sending initiated, request sending
// completed, response reading initiated) the queue helps to make sure that the
// right operation is issued on the right exchange at the right time.
//
// To satisfy gRPC constraints, the states of exchanges must be as follows,
// starting from the front of the queue:
// - 0 or 1 exchange in kResponseReadIssued state
// - 0 or more exchanges in kRequestWriteCompleted state
// - 0 or 1 exchange in kRequestWriteIssued state
// - 0 or more exchanges in kExchangeCreated state
//
// Thread-compatible.
class ExchangeQueue {
 public:
  // Creates a new exchange and adds it to the end of the queue.
  void Emplace(const ::grpc::ByteBuffer& request_buf,
               protobuf::Message* response, StatusCallback cb,
               std::string debug_string);

  // Returns an exchange for which we can initiate request writing, if any.
  // Returns nullptr if there is no such exchange.
  Exchange* GetReadyForRequestWriting();

  // Returns an exchange for which we can initiate response reading, if any.
  // Returns nullptr if there is no such exchange.
  Exchange* GetReadyForResponseReading();

  // Changes the state of the exchange that is currently in
  // kRequestWriteIssued state to kRequestWriteCompleted state.
  // REQUIRES: There is an exchange in kRequestWriteIssued state.
  void MarkRequestWriteCompleted();

  // Returns the exchange at the front of the queue.
  // REQUIRES: ExchangeQueue is not empty.
  Exchange& GetFront();

  // Removes the exchange at the front of the queue.
  // REQUIRES: ExchangeQueue is not empty.
  void PopFront();

  // Returns a string containing addresses and states of all exchanges in this
  // queue.
  string DebugString() const;

  // Swaps the contents of this and `other`.
  void Swap(ExchangeQueue* other);

  // Completes all exchanges in this with `status`.
  void CompleteAll(Status status);

  void CallStarted() { call_started_ = true; }

 private:
  // Does nothing by default. Turn on VLOG(5) to enable.
  // Checks that this ExchangeQueue is in a valid state.
  // Kills the process if not.
  void CheckInvariants();

  // We can't process any exchanges until the call has started.
  bool call_started_ = false;

  // std::queue is based on std::deque by default. std::deque provides
  // fairly strong iterator stability.
  std::deque<Exchange> exchanges_;
};
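
// For example (an illustrative snapshot, not from this file), a queue
// satisfying the ordering constraint above might hold, front to back:
//   [kResponseReadIssued, kRequestWriteCompleted, kRequestWriteCompleted,
//    kRequestWriteIssued, kExchangeCreated, kExchangeCreated]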

// Represents state associated with one streaming RPC call.
// Thread-safe.
template <class Response>
class StreamingRPCState : public UntypedStreamingRPCState {
 public:
  // Default behavior is to set fail_fast = false and handle timeouts
  // manually.
  StreamingRPCState(
      std::unique_ptr<::grpc::GenericClientAsyncReaderWriter> call,
      const std::shared_ptr<::grpc::ClientContext>& context)
      : context_(context), call_(std::move(call)), call_state_(State::kActive) {
    Ref();
    VLOG(3) << "Created new StreamingRPCState " << this;
    VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::StartCall";
    call_->StartCall(&call_started_tag_);
  }

  ~StreamingRPCState() override {
    VLOG(3) << "Destructing StreamingRPCState " << this;
  }

  // Attempts to send the next request. `done` is invoked when
  // `response` has been filled with the data from the server, or if there
  // is an error. `done` can be invoked before SendNextRequest returns.
  // Returns `true` if the call is alive and the `done` callback has been or
  // will be invoked. If the call is dead, returns `false`. The `done`
  // callback will not be invoked in this case.
  // REQUIRES: The call has been started, i.e. WaitForCallStarted() has
  // returned.
  bool SendNextRequest(const protobuf::Message& request, Response* response,
                       const StatusCallback& done) {
    ::grpc::ByteBuffer request_buf;
    ::grpc::Status s = GrpcMaybeUnparseProto(request, &request_buf);
    if (!s.ok()) {
      Status status = FromGrpcStatus(s);
      LOG(ERROR) << "GrpcMaybeUnparseProto returned with non-ok status: "
                 << status.ToString();
      done(status);
      return true;
    }

    mutex_lock l(mu_);
    if (call_state_ != State::kActive) {
      // `done` is not invoked intentionally.
      return false;
    }
    if (VLOG_IS_ON(3)) {
      // If vlog 3 is enabled, include the first 100 chars of the request as a
      // debug string.
      exchanges_.Emplace(request_buf, response, done,
                         request.ShortDebugString().substr(0, 100));
    } else {
      exchanges_.Emplace(request_buf, response, done, "");
    }
    MaybeIssueRequestWriteLocked();
    return true;
  }
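
  // Example (an illustrative sketch; `state`, `request`, `response`, and
  // `done` are hypothetical):
  //
  //   if (!state->SendNextRequest(request, &response, done)) {
  //     // The call is dead: `done` was not and will not be invoked, so the
  //     // caller must start a new call and resend the request.
  //   }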

  void CallStarted(bool ok) override {
    VLOG(3) << "StreamingRPCState(" << this << ")::CallStarted(ok=" << ok
            << ")";
    mutex_lock l(mu_);
    if (!ok) {
      call_state_ = State::kDone;
      return;
    }
    exchanges_.CallStarted();
    // Now that the call has started, we can write our first request, if any.
    MaybeIssueRequestWriteLocked();
  }

  void RequestWriteCompleted(bool ok) override {
    VLOG(3) << "StreamingRPCState(" << this
            << ")::RequestWriteCompleted(ok=" << ok << ")";
    mu_.lock();
    if (call_state_ != State::kActive) {
      mu_.unlock();
      return;
    }
    exchanges_.MarkRequestWriteCompleted();
    // Issue ResponseRead regardless of OK status on completing RequestWrite.
    // If the underlying completion queue is in a not-OK status due to a
    // previous request failure (i.e., `ok` from the `Next` call on the
    // completion queue is false), delay the error in ResponseRead so we can
    // get the remote error message from the response buffer.
    MaybeIssueResponseReadLocked();

    if (ok) {
      MaybeIssueRequestWriteLocked();
    }
    mu_.unlock();
  }

  void ResponseReadCompleted(bool ok) override {
    VLOG(3) << "StreamingRPCState(" << this
            << ")::ResponseReadCompleted(ok=" << ok << ")";
    mu_.lock();
    if (call_state_ != State::kActive) {
      mu_.unlock();
      return;
    }
    if (!ok) {
      IssueCallFinishLocked();
      mu_.unlock();
      return;
    }

    // Complete the exchange without holding the lock because the user's
    // callback can call back into this RPC code, resulting in a deadlock.
    // No other thread can pop this exchange while we release the lock because
    // this is the only method that pops exchanges and it is called from a
    // single thread that waits on completion queue events.
    Exchange* e;
    e = &exchanges_.GetFront();
    mu_.unlock();

    e->Complete(OkStatus());

    {
      mutex_lock l(mu_);
      exchanges_.PopFront();
      MaybeIssueResponseReadLocked();
    }
  }

  void CallFinished(bool ok) override {
    VLOG(3) << "StreamingRPCState(" << this << ")::CallFinished(ok=" << ok
            << ")";
    mu_.lock();
    DCHECK(call_state_ != State::kActive);
    if (call_state_ != State::kFinishing) {
      mu_.unlock();
      return;
    }

    Status s = FromGrpcStatus(call_status_);
    if (s.ok() && !ok) {
      s.Update(
          errors::Internal("GRPC status is okay but CompletionQueueStatus is "
                           "not. This should never happen.",
                           context_->debug_error_string()));
    }
    // Unlocks mu_.
    MarkDoneAndCompleteExchanges(s);
  }

  string DebugString() const override {
    mutex_lock l(mu_);
    return exchanges_.DebugString();
  }

 private:
  enum class State {
    kActive,
    kFinishing,
    kDone,
  };
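
  // Transitions between these states (illustrative, derived from the methods
  // above):
  //   kActive --(read fails, Finish issued)--> kFinishing
  //   kFinishing --(Finish tag completes)--> kDone
  //   kActive --(StartCall fails)--> kDone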

  void MarkDoneAndCompleteExchanges(Status status)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) TF_UNLOCK_FUNCTION(mu_) {
    call_state_ = State::kDone;
    VLOG(2) << "Ending gRPC streaming call on the client side due to "
            << status.ToString();
    // Swap the exchanges_ into a temporary ExchangeQueue so that we can
    // complete all exchanges without holding mu_ in case a user callback
    // reaches back into this. This should be impossible now, but safer for
    // the future.
    ExchangeQueue queue;
    exchanges_.Swap(&queue);
    mu_.unlock();
    queue.CompleteAll(status);
  }

  void MaybeIssueRequestWriteLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    Exchange* exchange = exchanges_.GetReadyForRequestWriting();
    if (exchange == nullptr) {
      // There are no queued exchanges, there is already an outstanding write,
      // or there are no just created exchanges.
      return;
    }
    exchange->MarkRequestWriteIssued();
    Ref();
    VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::Write";
    call_->Write(exchange->request_buf(), &request_write_completed_tag_);
  }

  void MaybeIssueResponseReadLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    Exchange* exchange = exchanges_.GetReadyForResponseReading();
    if (exchange == nullptr) {
      return;
    }
    exchange->MarkResponseReadIssued();
    Ref();
    VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::Read";
    call_->Read(exchange->response_buf(), &response_read_completed_tag_);
  }

  void IssueCallFinishLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    call_state_ = State::kFinishing;
    Ref();
    VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::Finish";
    // We call Finish in response to a completed (with error) response-reading
    // tag on some exchange. We let this exchange hang in ResponseReadIssued
    // state. ExchangeQueue makes sure that there is at most one exchange in
    // this state. So, no new reads will be issued.
    call_->Finish(&call_status_, &finished_tag_);
  }

  // Holds state for a single request/response exchange between the client
  // and the server.
  typedef typename UntypedStreamingRPCState::Tag Tag;

  // Order of context_ and call_ is important because context_ must outlive
  // call_.
  const std::shared_ptr<const ::grpc::ClientContext> context_;
  std::unique_ptr<::grpc::GenericClientAsyncReaderWriter> call_;

  mutable mutex mu_;
  ExchangeQueue exchanges_ TF_GUARDED_BY(mu_);
  State call_state_ TF_GUARDED_BY(mu_);
  ::grpc::Status call_status_ TF_GUARDED_BY(mu_);

  // We can get away with having single instances of these tags per
  // StreamingRPCState because we make sure (as gRPC requires) that
  // there is at most one outstanding Read and at most one outstanding Write
  // in the completion queue.
  // Tags are immutable. No need to guard them.
  Tag call_started_tag_{this, Tag::TagType::kCallStarted};
  Tag request_write_completed_tag_{this, Tag::TagType::kRequestWriteCompleted};
  Tag response_read_completed_tag_{this, Tag::TagType::kResponseReadCompleted};
  Tag finished_tag_{this, Tag::TagType::kCallFinished};
};

// Creates streaming calls and dispatches requests to them.
// In the common case, the client would create a StreamingRPCDispatcher for
// each bidirectional streaming RPC it might want to make. The first time it
// calls SendNextRequest, a streaming call is initiated and the request is
// sent within this call. Initiation of the call blocks the client. If there
// are no errors, subsequent calls to SendNextRequest would use the already
// active call. If there was an error, the call object will be destroyed after
// all the callbacks for outstanding requests have been invoked. The next call
// to SendNextRequest will initiate a new call.
//
// Callbacks that are part of the same call are invoked in the order they were
// provided, but callbacks across calls (a failed one and a new one) can be
// invoked in any order.
//
// Thread-safe.
template <class Response>
class StreamingRPCDispatcher {
 public:
  StreamingRPCDispatcher(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq,
                         const ::grpc::string& method)
      : stub_(stub), cq_(cq), method_(method) {}

  // Attempts to send the next request. If there is no active streaming call,
  // starts one and sends the request on top of it. `done` is invoked when
  // `response` has been filled with the data from the server, or if there
  // is an error. `done` can be invoked before SendNextRequest returns.
  void SendNextRequest(const protobuf::Message& request, Response* response,
                       StatusCallback done) {
    mutex_lock l(mu_);
    if (state_ == nullptr) {
      CreateStreamingState();
    }

    bool is_call_alive = state_->SendNextRequest(request, response, done);
    if (is_call_alive) {
      return;
    }

    // The attempt to send failed because the call was dead. Create a new
    // call and try again. When the call is dead, SendNextRequest does not
    // call `done`.
    CreateStreamingState();

    is_call_alive = state_->SendNextRequest(request, response, done);
    if (!is_call_alive) {
      // Consider retrying to create and start a call a few more times.
      done(errors::Unknown("gRPC call failed right after it was created"));
    }
  }

  // Requests cancellation of the current streaming call. Non-blocking.
  void CancelCall() {
    mutex_lock l(mu_);
    if (state_ == nullptr) {
      return;
    }
    context_->TryCancel();
    state_ = nullptr;
  }

 private:
  void CreateStreamingState() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    // ClientContext cannot be reused across calls.
    context_ = std::make_shared<::grpc::ClientContext>();
    // Don't immediately fail StartCall if the channel is not ready. Wait for
    // the channel to become ready.
    context_->set_wait_for_ready(true);

    std::unique_ptr<::grpc::GenericClientAsyncReaderWriter> call =
        stub_->PrepareCall(context_.get(), method_, cq_);

    state_.reset(new StreamingRPCState<Response>(std::move(call), context_));
  }

  mutable mutex mu_;

  // Both are thread-safe.
  ::grpc::GenericStub* const stub_;
  ::grpc::CompletionQueue* const cq_;

  // Does not need synchronization since it is constant.
  const ::grpc::string method_;

  std::shared_ptr<::grpc::ClientContext> context_ TF_GUARDED_BY(mu_);
  core::RefCountPtr<StreamingRPCState<Response>> state_ TF_GUARDED_BY(mu_);
};
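
// Example usage (an illustrative sketch; the stub, completion queue, and
// method name are hypothetical stand-ins):
//
//   StreamingRPCDispatcher<RecvTensorResponse> dispatcher(
//       stub, cq, "/tensorflow.WorkerService/RecvTensorStream");
//   dispatcher.SendNextRequest(request, &response, done);
//   // If the underlying call fails, a later SendNextRequest starts a new one.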

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_