1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h"
17
18 #include <utility>
19
20 #include "absl/time/clock.h"
21 #include "absl/time/time.h"
22 #include "tensorflow/core/distributed_runtime/call_options.h"
23 #include "tensorflow/core/distributed_runtime/master_interface.h"
24 #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h"
25 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/core/status.h"
28 #include "tensorflow/core/lib/strings/strcat.h"
29 #include "tensorflow/core/platform/env.h"
30 #include "tensorflow/core/platform/tracing.h"
31 #include "tensorflow/core/profiler/lib/traceme.h"
32 #include "tensorflow/core/protobuf/master.pb.h"
33
34 namespace tensorflow {
35
36 // GrpcRemoteMaster is an implementation of the MasterInterface
37 // that uses gRPC to talk to the Master service.
38 class GrpcRemoteMaster : public MasterInterface {
39 using MasterServiceStub = grpc::MasterService::Stub;
40
41 public:
GrpcRemoteMaster(const SharedGrpcChannelPtr & client_channel)42 explicit GrpcRemoteMaster(const SharedGrpcChannelPtr& client_channel)
43 : stub_(grpc::MasterService::NewStub(client_channel)) {}
44
~GrpcRemoteMaster()45 ~GrpcRemoteMaster() override {}
46
CreateSession(CallOptions * call_options,const CreateSessionRequest * request,CreateSessionResponse * response)47 Status CreateSession(CallOptions* call_options,
48 const CreateSessionRequest* request,
49 CreateSessionResponse* response) override {
50 return CallWithRetry(call_options, request, response,
51 &MasterServiceStub::CreateSession);
52 }
53
ExtendSession(CallOptions * call_options,const ExtendSessionRequest * request,ExtendSessionResponse * response)54 Status ExtendSession(CallOptions* call_options,
55 const ExtendSessionRequest* request,
56 ExtendSessionResponse* response) override {
57 return CallWithRetry(call_options, request, response,
58 &MasterServiceStub::ExtendSession);
59 }
60
PartialRunSetup(CallOptions * call_options,const PartialRunSetupRequest * request,PartialRunSetupResponse * response)61 Status PartialRunSetup(CallOptions* call_options,
62 const PartialRunSetupRequest* request,
63 PartialRunSetupResponse* response) override {
64 return CallWithRetry(call_options, request, response,
65 &MasterServiceStub::PartialRunSetup);
66 }
67
RunStep(CallOptions * call_options,RunStepRequestWrapper * request,MutableRunStepResponseWrapper * response)68 Status RunStep(CallOptions* call_options, RunStepRequestWrapper* request,
69 MutableRunStepResponseWrapper* response) override {
70 return CallWithRetry(call_options, &request->ToProto(),
71 get_proto_from_wrapper(response),
72 &MasterServiceStub::RunStep, "RunStep/Client");
73 }
74
CloseSession(CallOptions * call_options,const CloseSessionRequest * request,CloseSessionResponse * response)75 Status CloseSession(CallOptions* call_options,
76 const CloseSessionRequest* request,
77 CloseSessionResponse* response) override {
78 return CallWithRetry(call_options, request, response,
79 &MasterServiceStub::CloseSession);
80 }
81
ListDevices(CallOptions * call_options,const ListDevicesRequest * request,ListDevicesResponse * response)82 Status ListDevices(CallOptions* call_options,
83 const ListDevicesRequest* request,
84 ListDevicesResponse* response) override {
85 return CallWithRetry(call_options, request, response,
86 &MasterServiceStub::ListDevices);
87 }
88
Reset(CallOptions * call_options,const ResetRequest * request,ResetResponse * response)89 Status Reset(CallOptions* call_options, const ResetRequest* request,
90 ResetResponse* response) override {
91 return CallWithRetry(call_options, request, response,
92 &MasterServiceStub::Reset);
93 }
94
MakeCallable(CallOptions * call_options,const MakeCallableRequest * request,MakeCallableResponse * response)95 Status MakeCallable(CallOptions* call_options,
96 const MakeCallableRequest* request,
97 MakeCallableResponse* response) override {
98 return CallWithRetry(call_options, request, response,
99 &MasterServiceStub::MakeCallable);
100 }
RunCallable(CallOptions * call_options,const RunCallableRequest * request,RunCallableResponse * response)101 Status RunCallable(CallOptions* call_options,
102 const RunCallableRequest* request,
103 RunCallableResponse* response) override {
104 return CallWithRetry(call_options, request, response,
105 &MasterServiceStub::RunCallable);
106 }
ReleaseCallable(CallOptions * call_options,const ReleaseCallableRequest * request,ReleaseCallableResponse * response)107 Status ReleaseCallable(CallOptions* call_options,
108 const ReleaseCallableRequest* request,
109 ReleaseCallableResponse* response) override {
110 return CallWithRetry(call_options, request, response,
111 &MasterServiceStub::ReleaseCallable);
112 }
113
114 private:
115 // Start tracing, attaching a unique ID to both the trace and the RPC.
NewTraceRpc(StringPiece name,::grpc::ClientContext * ctx)116 profiler::TraceMe* NewTraceRpc(StringPiece name, ::grpc::ClientContext* ctx) {
117 string trace_id = strings::StrCat(tracing::GetUniqueArg());
118 ctx->AddMetadata(GrpcIdKey(), trace_id);
119 return new profiler::TraceMe(
120 [&] { return strings::StrCat(name, ":", trace_id); },
121 profiler::TraceMeLevel::kInfo);
122 }
123
124 template <typename Request, typename Response>
CallWithRetry(CallOptions * call_options,const Request * request,Response * response,::grpc::Status (MasterServiceStub::* pfunc)(::grpc::ClientContext *,const Request &,Response *),string trace_string={})125 Status CallWithRetry(CallOptions* call_options, const Request* request,
126 Response* response,
127 ::grpc::Status (MasterServiceStub::*pfunc)(
128 ::grpc::ClientContext*, const Request&, Response*),
129 string trace_string = {}) {
130 absl::Duration timeout = absl::Milliseconds(call_options->GetTimeout());
131 absl::Time expired_time = absl::FromUnixMicros(Env::Default()->NowMicros());
132 if (timeout > absl::ZeroDuration()) {
133 expired_time += timeout;
134 }
135 Status s;
136 for (int num_retries = 0;; ++num_retries) {
137 ::grpc::ClientContext ctx;
138 std::unique_ptr<profiler::TraceMe> trace;
139 if (!trace_string.empty()) {
140 trace.reset(NewTraceRpc(trace_string, &ctx));
141 }
142 ctx.set_fail_fast(false);
143 if (timeout > absl::ZeroDuration()) {
144 // We do not modify the timeout here to match legacy behavior. However,
145 // this could violate the contract of tensorflow::Session. If we retry
146 // an RPC just before the deadline is exceeded, we will still set the
147 // timeout to the original value. This leads to the overall timeout
148 // being double what was expected.
149 ctx.set_deadline(absl::ToChronoTime(absl::Now() + timeout));
150 }
151 s = FromGrpcStatus((stub_.get()->*pfunc)(&ctx, *request, response));
152 if (!errors::IsUnavailable(s)) {
153 return s;
154 }
155 // TODO(b/117162170): we may want to make this configurable.
156 constexpr int kMaxRetries = 10;
157 LOG(WARNING) << "RPC failed with status = \"" << s
158 << "\" and grpc_error_string = \""
159 << ctx.debug_error_string() << "\", maybe retrying the RPC";
160 if (num_retries >= kMaxRetries) {
161 LOG(WARNING) << "Too many retries, returning last status: " << s;
162 return s;
163 }
164 absl::Time now = absl::FromUnixMicros(Env::Default()->NowMicros());
165 const absl::Time deadline_with_backoff =
166 now + absl::Microseconds(ComputeBackoffMicroseconds(num_retries));
167 // Wait for a short period of time before retrying the RPC. If our
168 // backoff would put us past the RPC deadline, we truncate it to ensure
169 // our RPC starts before the deadline.
170 const auto backoff_until = (timeout <= absl::ZeroDuration() ||
171 expired_time > deadline_with_backoff)
172 ? deadline_with_backoff
173 : expired_time;
174 Env::Default()->SleepForMicroseconds(
175 absl::ToInt64Microseconds(backoff_until - now));
176 now = absl::FromUnixMicros(Env::Default()->NowMicros());
177 if (now > expired_time && timeout > absl::ZeroDuration()) {
178 // If timeout_in_ms is set, exit the retry loop on timeout.
179 return errors::DeadlineExceeded(ctx.debug_error_string());
180 }
181 }
182 }
183
184 std::unique_ptr<MasterServiceStub> stub_;
185 };
186
NewGrpcMaster(const SharedGrpcChannelPtr & channel)187 MasterInterface* NewGrpcMaster(const SharedGrpcChannelPtr& channel) {
188 return new GrpcRemoteMaster(channel);
189 }
190
191 } // namespace tensorflow
192