xref: /aosp_15_r20/external/tensorflow/tensorflow/core/distributed_runtime/rpc/grpc_state.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_
17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_
18 
19 #include <queue>
20 #include <utility>
21 
22 #include "grpcpp/generic/generic_stub.h"
23 #include "grpcpp/grpcpp.h"
24 #include "tensorflow/core/distributed_runtime/call_options.h"
25 #include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
26 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
27 #include "tensorflow/core/distributed_runtime/tensor_coding.h"
28 #include "tensorflow/core/lib/core/refcount.h"
29 #include "tensorflow/core/lib/core/status.h"
30 #include "tensorflow/core/lib/core/threadpool.h"
31 #include "tensorflow/core/lib/strings/strcat.h"
32 #include "tensorflow/core/platform/errors.h"
33 #include "tensorflow/core/platform/mutex.h"
34 #include "tensorflow/core/platform/notification.h"
35 #include "tensorflow/core/util/env_var.h"
36 
37 namespace tensorflow {
38 
39 // Object allocated per active RPC.
40 // Manage the state of a single asynchronous RPC request.  If `max_retries`
41 // is greater than 0, the request will be retried for any transient failures.
42 template <class Response>
43 class RPCState : public GrpcClientCQTag {
44  public:
45   RPCState(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq,
46            const ::grpc::string& method, const protobuf::Message& request,
47            Response* response, StatusCallback done, CallOptions* call_opts,
48            thread::ThreadPool* threadpool, int32_t max_retries = 0,
49            bool fail_fast = true, const string* target = nullptr)
50       : RPCState(
51             stub, cq, method, request, response, std::move(done), call_opts,
52             threadpool,
53             // 1) If GRPC_FAIL_FAST is set to 'true' or 'false',
54             // fail_fast=$GRPC_FAIL_FAST. See b/141948186.
55             // 2) Otherwise if GRPC_FAIL_FAST is set to 'use_caller', use the
56             // fail_fast from the caller. See b/140260119.
57             //
58             // Current default: use caller's fail_fast argument.
59             //
60             // NOTE: Callers mostly set fail_fast=true to prevent job hanging
61             // on worker task failures, except a few cases such as GetStatus
62             // in cluster initialization and collective param resolution.
63             [fail_fast, &done]() -> bool {
64               string fail_fast_env;
65               TF_CHECK_OK(ReadStringFromEnvVar("GRPC_FAIL_FAST", "use_caller",
66                                                &fail_fast_env));
67               string fail_fast_env_lower = absl::AsciiStrToLower(fail_fast_env);
68               if (fail_fast_env_lower == "true") {
69                 return true;
70               } else if (fail_fast_env_lower == "use_caller") {
71                 return fail_fast;
72               } else if (fail_fast_env_lower == "false") {
73                 return false;
74               } else {
75                 string error_message = strings::StrCat(
76                     "Invalid GRPC_FAIL_FAST config: ", fail_fast_env);
77                 LOG(WARNING) << error_message;
78                 done(errors::InvalidArgument(error_message));
79                 return false;
80               }
81             }(),
82             (call_opts != nullptr ? call_opts->GetTimeout() : 0), max_retries,
83             target) {}
84 
85   template <typename Request>
RPCState(::grpc::GenericStub * stub,::grpc::CompletionQueue * cq,const::grpc::string & method,const Request & request,Response * response,StatusCallback done,CallOptions * call_opts,thread::ThreadPool * threadpool,bool fail_fast,int64_t timeout_in_ms,int32_t max_retries,const string * target)86   RPCState(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq,
87            const ::grpc::string& method, const Request& request,
88            Response* response, StatusCallback done, CallOptions* call_opts,
89            thread::ThreadPool* threadpool, bool fail_fast,
90            int64_t timeout_in_ms, int32_t max_retries, const string* target)
91       : call_opts_(call_opts),
92         threadpool_(threadpool),
93         done_(std::move(done)),
94         timeout_in_ms_(timeout_in_ms),
95         max_retries_(max_retries),
96         cq_(cq),
97         stub_(stub),
98         method_(method),
99         fail_fast_(fail_fast),
100         target_(target) {
101     response_ = response;
102     ::grpc::Status s = GrpcMaybeUnparseProto(request, &request_buf_);
103     if (!s.ok()) {
104       LOG(ERROR) << "GrpcMaybeUnparseProto returned with non-ok status: "
105                  << s.error_message();
106       // Skip retry logic if we fail to parse our request.
107       done_(FromGrpcStatus(s));
108       delete this;
109       return;
110     }
111     StartCall();
112   }
113 
StartCall()114   void StartCall() {
115     context_.reset(new ::grpc::ClientContext());
116     context_->set_wait_for_ready(!fail_fast_);
117     if (timeout_in_ms_ > 0) {
118       context_->set_deadline(
119           gpr_time_from_millis(timeout_in_ms_, GPR_TIMESPAN));
120     }
121     if (call_opts_) {
122       call_opts_->SetCancelCallback([this]() { context_->TryCancel(); });
123     }
124 
125     VLOG(2) << "Starting call: " << method_;
126 
127     call_ = stub_->PrepareUnaryCall(context_.get(), method_, request_buf_, cq_);
128     call_->StartCall();
129     call_->Finish(&response_buf_, &status_, this);
130   }
131 
OnCompleted(bool ok)132   void OnCompleted(bool ok) override {
133     if (call_opts_) {
134       call_opts_->ClearCancelCallback();
135     }
136 
137     VLOG(2) << "Completed call: " << method_;
138 
139     Status s = FromGrpcStatus(status_);
140     if (s.ok() && !ok) {
141       // Since this function is only being used for processing the response
142       // to Finish for client-side unary calls, ok should never be false
143       s.Update(
144           errors::Internal("GRPC status is okay but CompletionQueueStatus is "
145                            "not.  This should never happen."));
146     }
147 
148     if (s.ok()) {
149       if (threadpool_) {
150         // Run parse and callback in another thread, returning this
151         // one to service more RPCs.
152         threadpool_->Schedule([this]() { ParseAndCallDone(); });
153       } else {
154         ParseAndCallDone();
155       }
156       return;
157     }
158 
159     VLOG(1) << method_ << " returned with non-ok status: " << s
160             << " Retries: " << num_retries_ << " Max: " << max_retries_ << "\n"
161             << context_->debug_error_string();
162     // Retry if we have any attempts left
163     if (++num_retries_ <= max_retries_ &&
164         (errors::IsUnavailable(s) || errors::IsUnknown(s))) {
165       response_buf_.Clear();
166       VLOG(1) << "Retrying call for " << method_ << "Retry: " << num_retries_
167               << " of " << max_retries_;
168 
169       ComputeRetryBackoffMs(/*min_backoff_ms=*/1, /*max_backoff_ms=*/10000);
170       int64_t backoff_us = retry_backoff_ms_ * 1000;
171       Env::Default()->SchedClosureAfter(/*micros=*/backoff_us,
172                                         [this]() { StartCall(); });
173     } else {
174       // Attach additional GRPC error information if any to the final status
175       string error_msg = s.error_message();
176       strings::StrAppend(&error_msg, "\nAdditional GRPC error information");
177       if (target_) {
178         strings::StrAppend(&error_msg, " from remote target ", *target_);
179       }
180       strings::StrAppend(&error_msg, ":\n:", context_->debug_error_string());
181       s = errors::CreateWithUpdatedMessage(s, error_msg);
182       // Always treat gRPC cancellation as a derived error. This ensures that
183       // other error types are preferred during status aggregation. (gRPC
184       // cancellation messages do not contain the original status message).
185       if (s.code() == tensorflow::error::Code::CANCELLED) {
186         s = StatusGroup::MakeDerived(s);
187       }
188 
189       done_(s);
190       delete this;
191     }
192   }
193 
ParseAndCallDone()194   void ParseAndCallDone() {
195     Status s;
196     if (!GrpcMaybeParseProto(&response_buf_, response_)) {
197       s.Update(errors::Internal("could not parse rpc response"));
198     }
199     done_(s);
200     delete this;
201   }
202 
203  private:
ComputeRetryBackoffMs(int min_backoff_ms,int max_backoff_ms)204   void ComputeRetryBackoffMs(int min_backoff_ms, int max_backoff_ms) {
205     constexpr float kBackoffBase = 1.3;
206     if (retry_backoff_ms_ < 0) {
207       retry_backoff_ms_ = min_backoff_ms;
208     } else {
209       retry_backoff_ms_ *= kBackoffBase;
210       if (retry_backoff_ms_ > max_backoff_ms) {
211         retry_backoff_ms_ = max_backoff_ms;
212       }
213     }
214   }
215 
216   CallOptions* call_opts_;
217   std::unique_ptr<::grpc::ClientContext> context_;
218   thread::ThreadPool* threadpool_;
219   std::unique_ptr<::grpc::GenericClientAsyncResponseReader> call_;
220   Response* response_;
221   ::grpc::ByteBuffer request_buf_;
222   ::grpc::ByteBuffer response_buf_;
223   ::grpc::Status status_;
224   StatusCallback done_;
225   int64_t timeout_in_ms_;
226 
227   size_t num_retries_ = 0;
228   size_t max_retries_;
229   double retry_backoff_ms_ = -1;
230 
231   ::grpc::CompletionQueue* cq_;
232   ::grpc::GenericStub* stub_;
233   ::grpc::string method_;
234   bool fail_fast_;
235   const string* target_;
236 };
237 
238 // Represents state associated with one streaming RPC call.
239 // Similarly to above, we extract the methods of StreamingRPCState that don't
240 // need to be templated into this abstract class.
241 // Currently, *StreamingRPCState does not support client closing the call as
242 // there is no use case for it - current clients keep the streaming call open
243 // as long as possible. If/when the need arises, support can be added
244 // by calling GenericClientAsyncReaderWriter::WritesDone with a new tag
245 // TagType::kClientFinished and handling the completion in a new callback.
246 class UntypedStreamingRPCState : public core::RefCounted {
247  public:
248   virtual void CallStarted(bool ok) = 0;
249   virtual void RequestWriteCompleted(bool ok) = 0;
250   virtual void ResponseReadCompleted(bool ok) = 0;
251   virtual void CallFinished(bool ok) = 0;
252 
253   virtual string DebugString() const = 0;
254 
255   class Tag : public GrpcClientCQTag {
256    public:
257     // One enum value per supported callback.
258     enum class TagType {
259       kCallStarted,
260       kRequestWriteCompleted,
261       kResponseReadCompleted,
262       kCallFinished,
263     };
264 
265     Tag(UntypedStreamingRPCState* streaming_state, Tag::TagType type);
266 
267     // Calls the callback associated with this tag and Unrefs
268     // `this->streaming_state_`.
269     void OnCompleted(bool ok) override;
270 
271    private:
272     // OnCompleted() consumes on reference each time it is called.
273     UntypedStreamingRPCState* const streaming_state_;
274     const Tag::TagType type_;
275   };
276 };
277 
// Returns a C string for `tag_type`. Defined out of line; presumably the
// enumerator's name for use in log messages — confirm against the definition.
const char* ToString(UntypedStreamingRPCState::Tag::TagType tag_type);
279 
280 // Represents a single request/response exchange between client and the server.
281 // A single streaming call contains a sequence of exchanges. Besides the
282 // messages, exchange contains:
283 //  - the user callback to invoke when exchange completes (response is received
284 //    or an error occurs).
285 //  - The current state of the exchange.
286 class Exchange {
287  public:
288   enum class State {
289     kExchangeCreated,
290     kRequestWriteIssued,
291     kRequestWriteCompleted,
292     kResponseReadIssued,
293   };
294 
Exchange(const::grpc::ByteBuffer & request_buf,protobuf::Message * response,StatusCallback cb,string debug_string)295   Exchange(const ::grpc::ByteBuffer& request_buf, protobuf::Message* response,
296            StatusCallback cb, string debug_string)
297       : state_(State::kExchangeCreated),
298         request_buf_(request_buf),
299         response_(response),
300         cb_(std::move(cb)),
301         debug_string_(std::move(debug_string)) {}
302 
request_buf()303   const ::grpc::ByteBuffer& request_buf() { return request_buf_; }
response_buf()304   ::grpc::ByteBuffer* response_buf() { return &response_buf_; }
305 
MarkRequestWriteIssued()306   void MarkRequestWriteIssued() {
307     DCHECK(state_ == State::kExchangeCreated);
308     state_ = State::kRequestWriteIssued;
309   }
MarkRequestWriteCompleted()310   void MarkRequestWriteCompleted() {
311     DCHECK(state_ == State::kRequestWriteIssued);
312     state_ = State::kRequestWriteCompleted;
313   }
MarkResponseReadIssued()314   void MarkResponseReadIssued() {
315     DCHECK(state_ == State::kRequestWriteCompleted);
316     state_ = State::kResponseReadIssued;
317   }
318 
319   // If `status` is success, completes this exchange by parsing the
320   // response_buf_ and invoking cb_ with Status::OK(). Else, invokes the
321   // callback with `status`.
322   void Complete(Status status);
323 
state()324   const State& state() const { return state_; }
325 
326   string DebugString() const;
327 
328  private:
329   State state_;
330   ::grpc::ByteBuffer request_buf_;
331   ::grpc::ByteBuffer response_buf_;
332   protobuf::Message* response_;
333   StatusCallback cb_;
334   string debug_string_;
335 };
336 
// Returns a C string for the exchange state `s`. Defined out of line;
// presumably the enumerator's name — confirm against the definition.
const char* ToString(Exchange::State s);

// Streams `state` to `os`. Defined out of line; presumably writes the same
// text as ToString(state) — confirm against the definition.
std::ostream& operator<<(std::ostream& os, const Exchange::State& state);
340 
341 // Represents a queue of exchanges.
342 // When a client sends a new request a new exchange is created and added to the
343 // end of the queue. Completed exchanges are popped from the front of the queue.
344 // An explicit exchange queue is needed to brdige the client, which can send new
345 // requests at any time, with gRPC infrastructure, which can handle a single
346 // read and a single write request at a time.
347 //
348 // As the exchange progresses (request sending initiated, request sending
349 // completed, response reading initiated) the queue helps to make sure that the
350 // right operation is issued on the right exchange at the right time.
351 //
352 // To satisfy gRPC constraints, the states of exchanges must be as follows
353 // starting from the front of the queue:
354 //  - 0 or 1 exchange in kResponseReadIssued state
355 //  - 0 or more exchanges in kRequestWriteCompleted state
356 //  - 0 or 1 exchange in kRequestWriteIssued state
357 //  - 0 or more exchanges in kExchangeCreated state
358 //
359 // Thread-compatible.
360 class ExchangeQueue {
361  public:
362   // Creates a new exchange and adds it to the end of the queue.
363   void Emplace(const ::grpc::ByteBuffer& request_buf,
364                protobuf::Message* response, StatusCallback cb,
365                std::string debug_string);
366 
367   // Returns an exchange for which we can initiate request writing, if any.
368   // Returns nullptr if there is no such exchange.
369   Exchange* GetReadyForRequestWriting();
370 
371   // Returns an exchange for which we can initiate response reading, if any.
372   // Returns nullptr if there is no such exchange.
373   Exchange* GetReadyForResponseReading();
374 
375   // Changes the state of the exchange that is current in kRequestWriteIssued
376   // state to kRequestWriteCompleted state.
377   // REQUIRES: There is an exchange in kRequestWriteIssued state.
378   void MarkRequestWriteCompleted();
379 
380   // Returns the exchange at the front of the queue.
381   // REQUIRES: ExchangeQueue is not empty.
382   Exchange& GetFront();
383 
384   // Removes the exchange at the front of the queue.
385   // REQUIRES: ExchangeQueue is not empty.
386   void PopFront();
387 
388   // Returns a string containing addresses and states of all exchanges in this
389   // queue.
390   string DebugString() const;
391 
392   // Swaps the contents of this and `other`.
393   void Swap(ExchangeQueue* other);
394 
395   // Completes all exchanges in this with `status`.
396   void CompleteAll(Status status);
397 
CallStarted()398   void CallStarted() { call_started_ = true; }
399 
400  private:
401   // Does nothing by default. Turn on VLOG(5) to enable.
402   // Checks that this ExchangeQueue is in a valid state.
403   // Kills the process if not.
404   void CheckInvariants();
405 
406   // We can't process any exchanges until the call has started.
407   bool call_started_ = false;
408 
409   // std::queue is based on std::deque by default. std::deque provides
410   // fairly strong iterator stability.
411   std::deque<Exchange> exchanges_;
412 };  // namespace tensorflow
413 
414 // Represents state associated with one streaming RPC call.
415 // Thread-safe
416 template <class Response>
417 class StreamingRPCState : public UntypedStreamingRPCState {
418  public:
419   // Default behavior is to set fail_fast = False and handle timeouts
420   // manually.
StreamingRPCState(std::unique_ptr<::grpc::GenericClientAsyncReaderWriter> call,const std::shared_ptr<::grpc::ClientContext> & context)421   StreamingRPCState(
422       std::unique_ptr<::grpc::GenericClientAsyncReaderWriter> call,
423       const std::shared_ptr<::grpc::ClientContext>& context)
424       : context_(context), call_(std::move(call)), call_state_(State::kActive) {
425     Ref();
426     VLOG(3) << "Created new StreamingRPCState " << this;
427     VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::StartCall";
428     call_->StartCall(&call_started_tag_);
429   }
430 
~StreamingRPCState()431   ~StreamingRPCState() override {
432     VLOG(3) << "Destructing StreamingRPCState " << this;
433   }
434 
435   // Attempts to send the next request. `done` is invoked when
436   // `response` has been filled with the data from the server, or if there
437   // is an error. `done` can be invoked before SendNextRequest returns.
438   // Return `true` if the call is alive and the `done` callback has or
439   // will be invoked. If the call is dead, returns `false`. `done` callback
440   // will not be invoked in this case.
441   // REQUIRES: The call has been started, i.e. WaitForCallStarted() has
442   // returned.
SendNextRequest(const protobuf::Message & request,Response * response,const StatusCallback & done)443   bool SendNextRequest(const protobuf::Message& request, Response* response,
444                        const StatusCallback& done) {
445     ::grpc::ByteBuffer request_buf;
446     ::grpc::Status s = GrpcMaybeUnparseProto(request, &request_buf);
447     if (!s.ok()) {
448       Status status = FromGrpcStatus(s);
449       LOG(ERROR) << "GrpcMaybeUnparseProto returned with non-ok status: "
450                  << status.ToString();
451       done(status);
452       return true;
453     }
454 
455     mutex_lock l(mu_);
456     if (call_state_ != State::kActive) {
457       // `done` is not invoked intentionally.
458       return false;
459     }
460     if (VLOG_IS_ON(3)) {
461       // If vlog 3 is enabled, include first 100 chars of request as debug
462       // string.
463       exchanges_.Emplace(request_buf, response, done,
464                          request.ShortDebugString().substr(0, 100));
465     } else {
466       exchanges_.Emplace(request_buf, response, done, "");
467     }
468     MaybeIssueRequestWriteLocked();
469     return true;
470   }
471 
CallStarted(bool ok)472   void CallStarted(bool ok) override {
473     VLOG(3) << "StreamingRPCState(" << this << ")::CallStarted(ok=" << ok
474             << ")";
475     mutex_lock l(mu_);
476     if (!ok) {
477       call_state_ = State::kDone;
478       return;
479     }
480     exchanges_.CallStarted();
481     // Now that the call has started, we can write our first request, if any.
482     MaybeIssueRequestWriteLocked();
483   }
484 
RequestWriteCompleted(bool ok)485   void RequestWriteCompleted(bool ok) override {
486     VLOG(3) << "StreamingRPCState(" << this
487             << ")::RequestWriteCompleted(ok=" << ok << ")";
488     mu_.lock();
489     if (call_state_ != State::kActive) {
490       mu_.unlock();
491       return;
492     }
493     exchanges_.MarkRequestWriteCompleted();
494     // Issue ResponseRead regardless of OK status on completing RequestWrite.
495     // If the underlying completion queue is in Not-OK status due to previous
496     // request failuress (i.e., `ok` from `Next` call on completion queue is
497     // False), delay the error in ResponseRead so we can get the remote error
498     // message from response buffer.
499     MaybeIssueResponseReadLocked();
500 
501     if (ok) {
502       MaybeIssueRequestWriteLocked();
503     }
504     mu_.unlock();
505   }
506 
ResponseReadCompleted(bool ok)507   void ResponseReadCompleted(bool ok) override {
508     VLOG(3) << "StreamingRPCState(" << this
509             << ")::ResponseReadCompleted(ok=" << ok << ")";
510     mu_.lock();
511     if (call_state_ != State::kActive) {
512       mu_.unlock();
513       return;
514     }
515     if (!ok) {
516       IssueCallFinishLocked();
517       mu_.unlock();
518       return;
519     }
520 
521     // Complete the exchange without holding the lock because user's
522     // callback can call back into this RPC code resulting in a deadlock.
523     // No other thread can pop this exchange while we release the lock because
524     // this is the only method that pops exchanges and it is called from a
525     // single thread that waits on completion queue events.
526     Exchange* e;
527     e = &exchanges_.GetFront();
528     mu_.unlock();
529 
530     e->Complete(OkStatus());
531 
532     {
533       mutex_lock l(mu_);
534       exchanges_.PopFront();
535       MaybeIssueResponseReadLocked();
536     }
537   }
538 
CallFinished(bool ok)539   void CallFinished(bool ok) override {
540     VLOG(3) << "StreamingRPCState(" << this << ")::CallFinished(ok=" << ok
541             << ")";
542     mu_.lock();
543     DCHECK(call_state_ != State::kActive);
544     if (call_state_ != State::kFinishing) {
545       mu_.unlock();
546       return;
547     }
548 
549     Status s = FromGrpcStatus(call_status_);
550     if (s.ok() && !ok) {
551       s.Update(
552           errors::Internal("GRPC status is okay but CompletionQueueStatus is "
553                            "not.  This should never happen.",
554                            context_->debug_error_string()));
555     }
556     // unlocks mu_
557     MarkDoneAndCompleteExchanges(s);
558   }
559 
DebugString()560   string DebugString() const override {
561     mutex_lock l(mu_);
562     return exchanges_.DebugString();
563   }
564 
565  private:
566   enum class State {
567     kActive,
568     kFinishing,
569     kDone,
570   };
571 
MarkDoneAndCompleteExchanges(Status status)572   void MarkDoneAndCompleteExchanges(Status status)
573       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) TF_UNLOCK_FUNCTION(mu_) {
574     call_state_ = State::kDone;
575     VLOG(2) << "Ending gRPC streaming call on the client side due to "
576             << status.ToString();
577     // Swap the exchanges_ into a temporary ExchangeQueue so that we can
578     // complete all exchanges without holding mu_ in case user callback
579     // reach back into this. This should be impossible now, but safer for
580     // the future.
581     ExchangeQueue queue;
582     exchanges_.Swap(&queue);
583     mu_.unlock();
584     queue.CompleteAll(status);
585   }
586 
MaybeIssueRequestWriteLocked()587   void MaybeIssueRequestWriteLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
588     Exchange* exchange = exchanges_.GetReadyForRequestWriting();
589     if (exchange == nullptr) {
590       // There are no queued exchanges, there is already an outstanding write,
591       // or there are no just created exchanges.
592       return;
593     }
594     exchange->MarkRequestWriteIssued();
595     Ref();
596     VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::Write";
597     call_->Write(exchange->request_buf(), &request_write_completed_tag_);
598   }
599 
MaybeIssueResponseReadLocked()600   void MaybeIssueResponseReadLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
601     Exchange* exchange = exchanges_.GetReadyForResponseReading();
602     if (exchange == nullptr) {
603       return;
604     }
605     exchange->MarkResponseReadIssued();
606     Ref();
607     VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::Read";
608     call_->Read(exchange->response_buf(), &response_read_completed_tag_);
609   }
610 
IssueCallFinishLocked()611   void IssueCallFinishLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
612     call_state_ = State::kFinishing;
613     Ref();
614     VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::Finish";
615     // We call finish in response to completed (with error) response reading tag
616     // on some exchange. We let this exchange hang in ResponseReadIssued state.
617     // ExchangeQueue makes sure that there is at most one exchange in this
618     // state. So, no new reads will be issued.
619     call_->Finish(&call_status_, &finished_tag_);
620   }
621 
622   // Holds state for a single request/response exchange between the client
623   // and the server.
624   typedef typename UntypedStreamingRPCState::Tag Tag;
625 
626   // Order of context_ and call_ is important because context_ must outlive
627   // call_.
628   const std::shared_ptr<const ::grpc::ClientContext> context_;
629   std::unique_ptr<::grpc::GenericClientAsyncReaderWriter> call_;
630 
631   mutable mutex mu_;
632   ExchangeQueue exchanges_ TF_GUARDED_BY(mu_);
633   State call_state_ TF_GUARDED_BY(mu_);
634   ::grpc::Status call_status_ TF_GUARDED_BY(mu_);
635 
636   // We can get away with having single instances of these tags per
637   // StreamingRPCState because we make sure (as gRPC requires) that
638   // there is at most one outstanding Read and at most one outstanding Write
639   // in the completion queue.
640   // Tags are immutable. No need to guard them.
641   Tag call_started_tag_{this, Tag::TagType::kCallStarted};
642   Tag request_write_completed_tag_{this, Tag::TagType::kRequestWriteCompleted};
643   Tag response_read_completed_tag_{this, Tag::TagType::kResponseReadCompleted};
644   Tag finished_tag_{this, Tag::TagType::kCallFinished};
645 };
646 
647 // Creates streaming calls and dispatches requests to them.
648 // In the common case, the client would create a StreamingRPCDispatcher for
649 // each bidirectional streaming RPC it might want to make. The first time, it
650 // calls SendNextRequest, a streaming call is initiated and the request is
651 // sent within this call. Initiation of the call blocks the client. If there are
652 // no errors, subsequent calls to SendNextRequest would use the already active
653 // call. If there was an error, the call object will be destroyed after all
654 // the callbacks for outstanding requests have been invoked. The next call to
655 // SendNextRequest will initiate a new call.
656 //
657 // Callbacks that are part of the same call, are invoked in the order they were
658 // provided, but callbacks across calls (a failed and a new one) can be invoked
659 // in any order.
660 //
661 // Thread-safe.
662 template <class Response>
663 class StreamingRPCDispatcher {
664  public:
StreamingRPCDispatcher(::grpc::GenericStub * stub,::grpc::CompletionQueue * cq,const::grpc::string & method)665   StreamingRPCDispatcher(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq,
666                          const ::grpc::string& method)
667       : stub_(stub), cq_(cq), method_(method) {}
668 
669   // Attempts to send the next request. If there is no active streaming call,
670   // starts one and sends the request on top of it. `done` is invoked when
671   // `response` has been filled with the data from the server, or if there
672   // is an error. `done` can be invoked before SendNextRequest returns.
SendNextRequest(const protobuf::Message & request,Response * response,StatusCallback done)673   void SendNextRequest(const protobuf::Message& request, Response* response,
674                        StatusCallback done) {
675     mutex_lock l(mu_);
676     if (state_ == nullptr) {
677       CreateStreamingState();
678     }
679 
680     bool is_call_alive = state_->SendNextRequest(request, response, done);
681     if (is_call_alive) {
682       return;
683     }
684 
685     // The attempt to send failed because the call was dead, create a new
686     // call and try again. When the call is dead SendNextRequest does not call
687     // `done`.
688     CreateStreamingState();
689 
690     is_call_alive = state_->SendNextRequest(request, response, done);
691     if (!is_call_alive) {
692       // Consider retrying to create and start a call few more times.
693       done(errors::Unknown("gRPC call failed right after it was created"));
694     }
695   }
696 
697   // Request to cancel the current streaming call. Non-blocking.
CancelCall()698   void CancelCall() {
699     mutex_lock l(mu_);
700     if (state_ == nullptr) {
701       return;
702     }
703     context_->TryCancel();
704     state_ = nullptr;
705   }
706 
707  private:
CreateStreamingState()708   void CreateStreamingState() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
709     // ClientContext cannot be reused across calls.
710     context_ = std::make_shared<::grpc::ClientContext>();
711     // Don't immediately fail StartCall if the channel is not ready. Wait for
712     // the channel to become ready.
713     context_->set_wait_for_ready(true);
714 
715     std::unique_ptr<::grpc::GenericClientAsyncReaderWriter> call =
716         stub_->PrepareCall(context_.get(), method_, cq_);
717 
718     state_.reset(new StreamingRPCState<Response>(std::move(call), context_));
719   }
720 
721   mutable mutex mu_;
722 
723   // Both are thread-safe
724   ::grpc::GenericStub* const stub_;
725   ::grpc::CompletionQueue* const cq_;
726 
727   // Does not need synchronization since it is constant.
728   const ::grpc::string method_;
729 
730   std::shared_ptr<::grpc::ClientContext> context_ TF_GUARDED_BY(mu_);
731   core::RefCountPtr<StreamingRPCState<Response>> state_ TF_GUARDED_BY(mu_);
732 };
733 
734 }  // namespace tensorflow
735 
736 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_
737