1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
17
18 #include <string>
19
20 #include "grpcpp/generic/generic_stub.h"
21 #include "tensorflow/core/distributed_runtime/call_options.h"
22 #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h"
23 #include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
24 #include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
25 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
26 #include "tensorflow/core/framework/metrics.h"
27 #include "tensorflow/core/lib/core/refcount.h"
28 #include "tensorflow/core/lib/core/status.h"
29 #include "tensorflow/core/platform/env.h"
30 #include "tensorflow/core/platform/error_payloads.h"
31 #include "tensorflow/core/platform/macros.h"
32 #include "tensorflow/core/protobuf/core_platform_payloads.pb.h"
33 #include "tensorflow/core/protobuf/eager_service.pb.h"
34 #include "tensorflow/core/util/env_var.h"
35
36 namespace tensorflow {
37 namespace eager {
38 namespace {
39
40 /* Retrieve the global env variable.
41 * Setting environment variable "TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE" to
 * true will turn on asynchronous execution of remote ops. This means that
 * when executing an op on a remote worker, the client will not block waiting
 * for the response. Consider the following code as an example:
45 *
46 * with tf.device('worker:0'):
47 * a = tf.matmul(...)
48 * b = tf.matmul(...)
 *   logging.info('Requests sent')    # Probably not executed yet
50 * logging.info('b: %s', b.numpy()) # Block until 'b' finished.
51 *
52 * Streaming RPC will preserve order as well. So 'a' must be executed before
53 * 'b' on 'worker:0'.
54 *
 * When turning on this feature, you should explicitly wait for some result
 * from remote workers at the end of your Python program. Otherwise, the client
 * may shut down remote workers without waiting for all pending ops.
58 *
59 * Note that the caller could still disable streaming enqueue, even though
 * EnableStreaming() returns true, if the caller's executor is set to disable
61 * streaming enqueue when the executor was created. EnableStreaming() is
62 * determined based on the global env variable, which by default is turned on
63 * for the main executor.
64 *
65 * TODO(fishx): When exiting client, make sure all pending ops on remote workers
66 * are finished.
67 */
EnableStreaming()68 bool EnableStreaming() {
69 bool result;
70 TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE",
71 true, &result));
72 return result;
73 }
74
// Ref-counted thread that handles callbacks for requests completed on a gRPC
// completion queue. The thread might be shared by multiple eager clients, and
// each one of them should hold a reference count to ensure that the thread
// outlives the clients.
// To ensure that every tag in the completion queue is processed, this thread
// also holds a reference to itself and always waits until the ref count is one
// before exiting.
class GrpcEagerClientThread : public core::RefCounted {
 public:
  // Starts the polling thread immediately. The self-reference taken here is
  // released by a closure scheduled once the polling loop has exited.
  GrpcEagerClientThread() {
    // Hold a reference to ensure every completion tag gets processed.
    Ref();
    thread_.reset(Env::Default()->StartThread(
        ThreadOptions(), "eager_client_thread", [this]() {
          void* tag;
          bool ok;
          // Next() blocks until an event is available and returns false once
          // the queue is shut down and drained.
          while (completion_queue_.Next(&tag, &ok)) {
            VLOG(4) << "GrpcEagerClientThread got next tag";
            // Every tag placed on this queue is assumed to be a
            // GrpcClientCQTag.
            GrpcClientCQTag* callback_tag = static_cast<GrpcClientCQTag*>(tag);
            callback_tag->OnCompleted(ok);
            VLOG(4) << "GrpcEagerClientThread blocking for next tag";
            // Only the self-reference remains, so (presumably) no client is
            // left to issue further work on this queue: stop polling.
            // NOTE(review): this check only runs after a tag has been
            // processed. If the last external reference is dropped while no
            // tag is pending, the loop appears to stay blocked in Next() —
            // confirm whether that is intended.
            if (RefCountIsOne()) {
              break;
            }
          }
          VLOG(4) << "GrpcEagerClientThread exiting";
          completion_queue_.Shutdown();
          // `this` holds the final reference so cannot directly Unref here.
          // Instead, schedule a separate thread to clean it up.
          // (Deleting `this` destroys `thread_`, which presumably joins this
          // very thread — doing that from inside it would self-join.)
          Env::Default()->SchedClosure([this]() { this->Unref(); });
        }));
  }

  ~GrpcEagerClientThread() override {}

  // Queue on which clients sharing this thread issue their RPCs.
  ::grpc::CompletionQueue* completion_queue() { return &completion_queue_; }

 private:
  ::grpc::CompletionQueue completion_queue_;
  // Polling thread started in the constructor.
  std::unique_ptr<Thread> thread_;
};
115
116 class GrpcEagerClient : public EagerClient {
117 public:
GrpcEagerClient(const tensorflow::SharedGrpcChannelPtr & channel,GrpcEagerClientThread * thread,const string & target)118 GrpcEagerClient(const tensorflow::SharedGrpcChannelPtr& channel,
119 GrpcEagerClientThread* thread, const string& target)
120 : stub_(channel), thread_(thread), target_(target) {
121 // Hold a reference to make sure the corresponding EagerClientThread
122 // outlives the client.
123 thread_->Ref();
124 cq_ = thread->completion_queue();
125 }
~GrpcEagerClient()126 ~GrpcEagerClient() override { thread_->Unref(); }
127
allow_multiple_pending_requests() const128 bool allow_multiple_pending_requests() const override {
129 return EnableStreaming();
130 }
131
132 #define CLIENT_METHOD(method) \
133 void method##Async(const method##Request* request, \
134 method##Response* response, StatusCallback done) \
135 override { \
136 StatusCallback done_wrapped = callback_wrapper(std::move(done)); \
137 new RPCState<protobuf::Message>( \
138 &stub_, cq_, "/tensorflow.eager.EagerService/" #method, *request, \
139 response, std::move(done_wrapped), /*call_opts=*/nullptr, \
140 /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true, \
141 &target_); \
142 }
143
144 CLIENT_METHOD(CreateContext);
145 CLIENT_METHOD(UpdateContext);
146 CLIENT_METHOD(WaitQueueDone);
147 CLIENT_METHOD(KeepAlive);
148
149 #undef CLIENT_METHOD
150
151 #define CLIENT_CANCELABLE_METHOD(method) \
152 void method##Async(CallOptions* call_opts, const method##Request* request, \
153 method##Response* response, StatusCallback done) \
154 override { \
155 StatusCallback done_wrapped = callback_wrapper(std::move(done)); \
156 new RPCState<protobuf::Message>( \
157 &stub_, cq_, "/tensorflow.eager.EagerService/" #method, *request, \
158 response, std::move(done_wrapped), call_opts, /*threadpool=*/nullptr, \
159 /*max_retries=*/0, /*fail_fast=*/true, &target_); \
160 }
161
162 CLIENT_CANCELABLE_METHOD(Enqueue);
163 CLIENT_CANCELABLE_METHOD(RunComponentFunction);
164
165 #undef CLIENT_CANCELABLE_METHOD
166
CloseContextAsync(const CloseContextRequest * request,CloseContextResponse * response,StatusCallback done)167 void CloseContextAsync(const CloseContextRequest* request,
168 CloseContextResponse* response,
169 StatusCallback done) override {
170 StatusCallback done_wrapped = callback_wrapper(std::move(done));
171 new RPCState<protobuf::Message>(
172 &stub_, cq_, "/tensorflow.eager.EagerService/CloseContext", *request,
173 response, std::move(done_wrapped), /*call_opts=*/nullptr,
174 /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
175 &target_);
176
177 VLOG(1) << "Sending RPC to close remote eager context "
178 << request->DebugString();
179
180 mutex_lock l(mu_);
181 const auto& it = enqueue_dispatchers_.find(request->context_id());
182 if (it != enqueue_dispatchers_.end()) {
183 it->second.CancelCall();
184 enqueue_dispatchers_.erase(it);
185 } else if (EnableStreaming()) {
186 LOG(ERROR) << "Remote EagerContext with id " << request->context_id()
187 << " does not seem to exist.";
188 }
189 }
190
StreamingEnqueueAsync(bool enable_streaming_enqueue,CallOptions * call_opts,const EnqueueRequest * request,EnqueueResponse * response,StatusCallback done)191 void StreamingEnqueueAsync(bool enable_streaming_enqueue,
192 CallOptions* call_opts,
193 const EnqueueRequest* request,
194 EnqueueResponse* response,
195 StatusCallback done) override {
196 StatusCallback done_wrapped = callback_wrapper(std::move(done));
197 // Whether streaming enqueue is used is determined based on 2 factors:
198 // 1. The global env variable, as checked in EnableStreaming().
199 // 2. The flag set in the eager executor.
200 // Streaming enqueue is allowed only when the both are enabled.
201 if (EnableStreaming() && enable_streaming_enqueue) {
202 mutex_lock l(mu_);
203 auto it = enqueue_dispatchers_.find(request->context_id());
204 if (it == enqueue_dispatchers_.end()) {
205 auto it_and_bool = enqueue_dispatchers_.emplace(
206 std::piecewise_construct,
207 std::forward_as_tuple(request->context_id()),
208 std::forward_as_tuple(
209 &stub_, cq_,
210 "/tensorflow.eager.EagerService/StreamingEnqueue"));
211 it = it_and_bool.first;
212 }
213 // TODO(haoyuzhang): Consider supporting cancellation for streaming RPC?
214 it->second.SendNextRequest(*request, response, std::move(done_wrapped));
215 } else {
216 Notification n;
217 Status status;
218 EnqueueAsync(call_opts, request, response,
219 [&n, &status](const Status& s) {
220 status.Update(s);
221 n.Notify();
222 });
223 n.WaitForNotification();
224 done_wrapped(status);
225 }
226 }
227
228 private:
229 ::grpc::GenericStub stub_;
230 const GrpcEagerClientThread* thread_;
231 const string target_;
232
233 ::grpc::CompletionQueue* cq_;
234
235 mutable mutex mu_;
236
237 std::unordered_map<uint64, StreamingRPCDispatcher<EnqueueResponse>>
238 enqueue_dispatchers_ TF_GUARDED_BY(mu_);
239
callback_wrapper(StatusCallback done)240 StatusCallback callback_wrapper(StatusCallback done) {
241 Ref();
242 return [this, done = std::move(done)](const Status& status) {
243 done(status);
244 this->Unref();
245 if (TF_PREDICT_FALSE(!status.ok())) {
246 // Retrieve the location where the error was produced.
247 auto error_source_payload = status.GetPayload(kErrorSource);
248
249 if (error_source_payload.has_value()) {
250 tensorflow::core::platform::ErrorSourceProto error_source_proto;
251 error_source_proto.ParseFromString(
252 std::string(*error_source_payload)); // NOLINT
253 metrics::UpdateEagerClientErrorCounter(
254 error_source_proto.ErrorSource_Name(
255 error_source_proto.error_source()),
256 error_name(status.code()));
257 } else {
258 metrics::UpdateEagerClientErrorCounter("unknown",
259 error_name(status.code()));
260 }
261 }
262 };
263 }
264 };
265
266 class GrpcEagerClientCache : public EagerClientCache {
267 public:
GrpcEagerClientCache(std::shared_ptr<tensorflow::GrpcChannelCache> cache)268 explicit GrpcEagerClientCache(
269 std::shared_ptr<tensorflow::GrpcChannelCache> cache)
270 : next_round_robin_assignment_(0), cache_(cache), threads_(4) {
271 for (int i = 0, end = threads_.size(); i < end; i++) {
272 threads_[i].reset(new GrpcEagerClientThread());
273 }
274 }
275
~GrpcEagerClientCache()276 ~GrpcEagerClientCache() override { threads_.clear(); }
277
GetClient(const string & target,core::RefCountPtr<EagerClient> * client)278 Status GetClient(const string& target,
279 core::RefCountPtr<EagerClient>* client) override {
280 mutex_lock l(clients_mu_);
281 auto it = clients_.find(target);
282 if (it == clients_.end()) {
283 tensorflow::SharedGrpcChannelPtr shared =
284 cache_->FindWorkerChannel(target);
285 if (shared == nullptr) {
286 return errors::InvalidArgument("Client for target ", target,
287 " not found.");
288 }
289 int assigned_index = AssignClientToThread(target);
290 GrpcEagerClientThread* thread = threads_[assigned_index].get();
291 core::RefCountPtr<EagerClient> worker(
292 new GrpcEagerClient(shared, thread, target));
293 it = clients_.emplace(target, std::move(worker)).first;
294 }
295
296 it->second->Ref();
297 client->reset(it->second.get());
298 return OkStatus();
299 }
300
301 private:
302 mutex assignment_mu_;
303 std::unordered_map<std::string, size_t> target_assignments_
304 TF_GUARDED_BY(assignment_mu_);
305 size_t next_round_robin_assignment_ TF_GUARDED_BY(assignment_mu_);
306
AssignClientToThread(const string & target)307 size_t AssignClientToThread(const string& target) {
308 // Round-robin target assignment, but keeps the same target on the same
309 // polling thread always, as this is important for gRPC performance
310 mutex_lock lock(assignment_mu_);
311 auto it = target_assignments_.find(target);
312 if (it == target_assignments_.end()) {
313 it = target_assignments_
314 .insert(std::make_pair(
315 target, (next_round_robin_assignment_++) % threads_.size()))
316 .first;
317 }
318 return it->second;
319 }
320
321 std::shared_ptr<tensorflow::GrpcChannelCache> cache_;
322 mutable mutex clients_mu_;
323 std::unordered_map<string, core::RefCountPtr<EagerClient>> clients_
324 TF_GUARDED_BY(clients_mu_);
325 std::vector<core::RefCountPtr<GrpcEagerClientThread>> threads_;
326 };
327
328 } // namespace
329
NewGrpcEagerClientCache(std::shared_ptr<tensorflow::GrpcChannelCache> channel)330 EagerClientCache* NewGrpcEagerClientCache(
331 std::shared_ptr<tensorflow::GrpcChannelCache> channel) {
332 return new GrpcEagerClientCache(channel);
333 }
334
335 } // namespace eager
336 } // namespace tensorflow
337