xref: /aosp_15_r20/external/tensorflow/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"

#include <string>

#include "grpcpp/generic/generic_stub.h"
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/error_payloads.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/protobuf/core_platform_payloads.pb.h"
#include "tensorflow/core/protobuf/eager_service.pb.h"
#include "tensorflow/core/util/env_var.h"

namespace tensorflow {
namespace eager {
namespace {

/* Retrieve the global env variable.
 * Setting the environment variable "TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE"
 * to true turns on asynchronous execution of remote ops. It means that when
 * executing an op on a remote worker, the client will no longer block waiting
 * for the response. Using the following code as an example:
 *
 * with tf.device('worker:0'):
 *   a = tf.matmul(...)
 *   b = tf.matmul(...)
 * logging.info('Requests sent')    # Probably not executed yet
 * logging.info('b: %s', b.numpy()) # Block until 'b' finished.
 *
 * Streaming RPC preserves order as well, so 'a' must be executed before
 * 'b' on 'worker:0'.
 *
 * When turning on this feature, you should explicitly wait for some result
 * from remote workers at the end of your python program. Otherwise, the
 * client may shut down remote workers without waiting for all pending ops.
 *
 * Note that the caller could still disable streaming enqueue, even though
 * EnableStreaming() returns true, if the caller's executor was created with
 * streaming enqueue disabled. EnableStreaming() is determined based on the
 * global env variable, which by default is turned on for the main executor.
 *
 * TODO(fishx): When exiting client, make sure all pending ops on remote
 * workers are finished.
 */
bool EnableStreaming() {
  bool result;
  TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE",
                                 true, &result));
  return result;
}
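
// A minimal usage sketch (assumed typical shell invocation, not part of this
// file): to force the synchronous fallback path for all eager clients, unset
// the flag before starting the Python program. The script name below is
// hypothetical.
//
//   export TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE=false
//   python my_program.py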

// Ref-counted thread to handle callbacks for completed requests on a gRPC
// completion queue. The thread might be shared by multiple eager clients, and
// each one of them should hold a reference count to ensure that the thread
// outlives the clients.
// To ensure that every tag in the completion queue is processed, this thread
// also holds a reference to itself and always waits until the ref count is
// one before exiting.
class GrpcEagerClientThread : public core::RefCounted {
 public:
  GrpcEagerClientThread() {
    // Hold a reference to ensure every completion tag gets processed.
    Ref();
    thread_.reset(Env::Default()->StartThread(
        ThreadOptions(), "eager_client_thread", [this]() {
          void* tag;
          bool ok;
          while (completion_queue_.Next(&tag, &ok)) {
            VLOG(4) << "GrpcEagerClientThread got next tag";
            GrpcClientCQTag* callback_tag = static_cast<GrpcClientCQTag*>(tag);
            callback_tag->OnCompleted(ok);
            VLOG(4) << "GrpcEagerClientThread blocking for next tag";
            if (RefCountIsOne()) {
              break;
            }
          }
          VLOG(4) << "GrpcEagerClientThread exiting";
          completion_queue_.Shutdown();
          // `this` holds the final reference so cannot directly Unref here.
          // Instead, schedule a separate thread to clean it up.
          Env::Default()->SchedClosure([this]() { this->Unref(); });
        }));
  }

  ~GrpcEagerClientThread() override {}

  ::grpc::CompletionQueue* completion_queue() { return &completion_queue_; }

 private:
  ::grpc::CompletionQueue completion_queue_;
  std::unique_ptr<Thread> thread_;
};

class GrpcEagerClient : public EagerClient {
 public:
  GrpcEagerClient(const tensorflow::SharedGrpcChannelPtr& channel,
                  GrpcEagerClientThread* thread, const string& target)
      : stub_(channel), thread_(thread), target_(target) {
    // Hold a reference to make sure the corresponding EagerClientThread
    // outlives the client.
    thread_->Ref();
    cq_ = thread->completion_queue();
  }
  ~GrpcEagerClient() override { thread_->Unref(); }

  bool allow_multiple_pending_requests() const override {
    return EnableStreaming();
  }

#define CLIENT_METHOD(method)                                             \
  void method##Async(const method##Request* request,                      \
                     method##Response* response, StatusCallback done)     \
      override {                                                          \
    StatusCallback done_wrapped = callback_wrapper(std::move(done));      \
    new RPCState<protobuf::Message>(                                      \
        &stub_, cq_, "/tensorflow.eager.EagerService/" #method, *request, \
        response, std::move(done_wrapped), /*call_opts=*/nullptr,         \
        /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,    \
        &target_);                                                        \
  }
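
  // For illustration only (hand-expanded, not generated code):
  // CLIENT_METHOD(CreateContext) expands roughly to the following member
  // function, which fires a unary RPC on the shared completion queue:
  //
  //   void CreateContextAsync(const CreateContextRequest* request,
  //                           CreateContextResponse* response,
  //                           StatusCallback done) override {
  //     StatusCallback done_wrapped = callback_wrapper(std::move(done));
  //     new RPCState<protobuf::Message>(
  //         &stub_, cq_, "/tensorflow.eager.EagerService/CreateContext",
  //         *request, response, std::move(done_wrapped),
  //         /*call_opts=*/nullptr, /*threadpool=*/nullptr, /*max_retries=*/0,
  //         /*fail_fast=*/true, &target_);
  //   }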

  CLIENT_METHOD(CreateContext);
  CLIENT_METHOD(UpdateContext);
  CLIENT_METHOD(WaitQueueDone);
  CLIENT_METHOD(KeepAlive);

#undef CLIENT_METHOD

#define CLIENT_CANCELABLE_METHOD(method)                                      \
  void method##Async(CallOptions* call_opts, const method##Request* request,  \
                     method##Response* response, StatusCallback done)         \
      override {                                                              \
    StatusCallback done_wrapped = callback_wrapper(std::move(done));          \
    new RPCState<protobuf::Message>(                                          \
        &stub_, cq_, "/tensorflow.eager.EagerService/" #method, *request,     \
        response, std::move(done_wrapped), call_opts, /*threadpool=*/nullptr, \
        /*max_retries=*/0, /*fail_fast=*/true, &target_);                     \
  }

  CLIENT_CANCELABLE_METHOD(Enqueue);
  CLIENT_CANCELABLE_METHOD(RunComponentFunction);

#undef CLIENT_CANCELABLE_METHOD

  void CloseContextAsync(const CloseContextRequest* request,
                         CloseContextResponse* response,
                         StatusCallback done) override {
    StatusCallback done_wrapped = callback_wrapper(std::move(done));
    new RPCState<protobuf::Message>(
        &stub_, cq_, "/tensorflow.eager.EagerService/CloseContext", *request,
        response, std::move(done_wrapped), /*call_opts=*/nullptr,
        /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
        &target_);

    VLOG(1) << "Sending RPC to close remote eager context "
            << request->DebugString();

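    // If a streaming enqueue call is still open for this context, cancel it
    // and drop its dispatcher so nothing keeps pointing at the closed remote
    // context.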
    mutex_lock l(mu_);
    const auto& it = enqueue_dispatchers_.find(request->context_id());
    if (it != enqueue_dispatchers_.end()) {
      it->second.CancelCall();
      enqueue_dispatchers_.erase(it);
    } else if (EnableStreaming()) {
      LOG(ERROR) << "Remote EagerContext with id " << request->context_id()
                 << " does not seem to exist.";
    }
  }

  void StreamingEnqueueAsync(bool enable_streaming_enqueue,
                             CallOptions* call_opts,
                             const EnqueueRequest* request,
                             EnqueueResponse* response,
                             StatusCallback done) override {
    StatusCallback done_wrapped = callback_wrapper(std::move(done));
    // Whether streaming enqueue is used is determined by two factors:
    // 1. The global env variable, as checked in EnableStreaming().
    // 2. The flag set in the eager executor.
    // Streaming enqueue is allowed only when both are enabled.
    if (EnableStreaming() && enable_streaming_enqueue) {
      mutex_lock l(mu_);
      auto it = enqueue_dispatchers_.find(request->context_id());
      if (it == enqueue_dispatchers_.end()) {
        auto it_and_bool = enqueue_dispatchers_.emplace(
            std::piecewise_construct,
            std::forward_as_tuple(request->context_id()),
            std::forward_as_tuple(
                &stub_, cq_,
                "/tensorflow.eager.EagerService/StreamingEnqueue"));
        it = it_and_bool.first;
      }
      // TODO(haoyuzhang): Consider supporting cancellation for streaming RPC?
      it->second.SendNextRequest(*request, response, std::move(done_wrapped));
    } else {
      Notification n;
      Status status;
      EnqueueAsync(call_opts, request, response,
                   [&n, &status](const Status& s) {
                     status.Update(s);
                     n.Notify();
                   });
      n.WaitForNotification();
      done_wrapped(status);
    }
  }

 private:
  ::grpc::GenericStub stub_;
  const GrpcEagerClientThread* thread_;
  const string target_;

  ::grpc::CompletionQueue* cq_;

  mutable mutex mu_;

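  // One open streaming enqueue call per remote context, keyed by context_id.
  // Entries are created lazily in StreamingEnqueueAsync() and canceled and
  // erased in CloseContextAsync().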
  std::unordered_map<uint64, StreamingRPCDispatcher<EnqueueResponse>>
      enqueue_dispatchers_ TF_GUARDED_BY(mu_);

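  // Wraps `done` so that (a) this client stays alive (via Ref/Unref) until
  // the callback has run, and (b) failed RPCs are counted in the eager client
  // error metrics, keyed by the error source recorded in the status payload.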
  StatusCallback callback_wrapper(StatusCallback done) {
    Ref();
    return [this, done = std::move(done)](const Status& status) {
      done(status);
      this->Unref();
      if (TF_PREDICT_FALSE(!status.ok())) {
        // Retrieve the location where the error was produced.
        auto error_source_payload = status.GetPayload(kErrorSource);

        if (error_source_payload.has_value()) {
          tensorflow::core::platform::ErrorSourceProto error_source_proto;
          error_source_proto.ParseFromString(
              std::string(*error_source_payload));  // NOLINT
          metrics::UpdateEagerClientErrorCounter(
              error_source_proto.ErrorSource_Name(
                  error_source_proto.error_source()),
              error_name(status.code()));
        } else {
          metrics::UpdateEagerClientErrorCounter("unknown",
                                                 error_name(status.code()));
        }
      }
    };
  }
};

class GrpcEagerClientCache : public EagerClientCache {
 public:
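  // Clients are spread over a small fixed pool of completion-queue polling
  // threads (four here); AssignClientToThread() keeps each target pinned to
  // the same thread.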
  explicit GrpcEagerClientCache(
      std::shared_ptr<tensorflow::GrpcChannelCache> cache)
      : next_round_robin_assignment_(0), cache_(cache), threads_(4) {
    for (int i = 0, end = threads_.size(); i < end; i++) {
      threads_[i].reset(new GrpcEagerClientThread());
    }
  }

  ~GrpcEagerClientCache() override { threads_.clear(); }

  Status GetClient(const string& target,
                   core::RefCountPtr<EagerClient>* client) override {
    mutex_lock l(clients_mu_);
    auto it = clients_.find(target);
    if (it == clients_.end()) {
      tensorflow::SharedGrpcChannelPtr shared =
          cache_->FindWorkerChannel(target);
      if (shared == nullptr) {
        return errors::InvalidArgument("Client for target ", target,
                                       " not found.");
      }
      int assigned_index = AssignClientToThread(target);
      GrpcEagerClientThread* thread = threads_[assigned_index].get();
      core::RefCountPtr<EagerClient> worker(
          new GrpcEagerClient(shared, thread, target));
      it = clients_.emplace(target, std::move(worker)).first;
    }

    it->second->Ref();
    client->reset(it->second.get());
    return OkStatus();
  }

 private:
  mutex assignment_mu_;
  std::unordered_map<std::string, size_t> target_assignments_
      TF_GUARDED_BY(assignment_mu_);
  size_t next_round_robin_assignment_ TF_GUARDED_BY(assignment_mu_);

  size_t AssignClientToThread(const string& target) {
    // Round-robin target assignment, but always keep the same target on the
    // same polling thread, as this is important for gRPC performance.
    mutex_lock lock(assignment_mu_);
    auto it = target_assignments_.find(target);
    if (it == target_assignments_.end()) {
      it = target_assignments_
               .insert(std::make_pair(
                   target, (next_round_robin_assignment_++) % threads_.size()))
               .first;
    }
    return it->second;
  }

  std::shared_ptr<tensorflow::GrpcChannelCache> cache_;
  mutable mutex clients_mu_;
  std::unordered_map<string, core::RefCountPtr<EagerClient>> clients_
      TF_GUARDED_BY(clients_mu_);
  std::vector<core::RefCountPtr<GrpcEagerClientThread>> threads_;
};

}  // namespace

EagerClientCache* NewGrpcEagerClientCache(
    std::shared_ptr<tensorflow::GrpcChannelCache> channel) {
  return new GrpcEagerClientCache(channel);
}

}  // namespace eager
}  // namespace tensorflow