xref: /aosp_15_r20/external/tensorflow/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h"
17 
18 #include <utility>
19 
20 #include "absl/time/clock.h"
21 #include "absl/time/time.h"
22 #include "tensorflow/core/distributed_runtime/call_options.h"
23 #include "tensorflow/core/distributed_runtime/master_interface.h"
24 #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h"
25 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/core/status.h"
28 #include "tensorflow/core/lib/strings/strcat.h"
29 #include "tensorflow/core/platform/env.h"
30 #include "tensorflow/core/platform/tracing.h"
31 #include "tensorflow/core/profiler/lib/traceme.h"
32 #include "tensorflow/core/protobuf/master.pb.h"
33 
34 namespace tensorflow {
35 
36 // GrpcRemoteMaster is an implementation of the MasterInterface
37 // that uses gRPC to talk to the Master service.
38 class GrpcRemoteMaster : public MasterInterface {
39   using MasterServiceStub = grpc::MasterService::Stub;
40 
41  public:
GrpcRemoteMaster(const SharedGrpcChannelPtr & client_channel)42   explicit GrpcRemoteMaster(const SharedGrpcChannelPtr& client_channel)
43       : stub_(grpc::MasterService::NewStub(client_channel)) {}
44 
~GrpcRemoteMaster()45   ~GrpcRemoteMaster() override {}
46 
CreateSession(CallOptions * call_options,const CreateSessionRequest * request,CreateSessionResponse * response)47   Status CreateSession(CallOptions* call_options,
48                        const CreateSessionRequest* request,
49                        CreateSessionResponse* response) override {
50     return CallWithRetry(call_options, request, response,
51                          &MasterServiceStub::CreateSession);
52   }
53 
ExtendSession(CallOptions * call_options,const ExtendSessionRequest * request,ExtendSessionResponse * response)54   Status ExtendSession(CallOptions* call_options,
55                        const ExtendSessionRequest* request,
56                        ExtendSessionResponse* response) override {
57     return CallWithRetry(call_options, request, response,
58                          &MasterServiceStub::ExtendSession);
59   }
60 
PartialRunSetup(CallOptions * call_options,const PartialRunSetupRequest * request,PartialRunSetupResponse * response)61   Status PartialRunSetup(CallOptions* call_options,
62                          const PartialRunSetupRequest* request,
63                          PartialRunSetupResponse* response) override {
64     return CallWithRetry(call_options, request, response,
65                          &MasterServiceStub::PartialRunSetup);
66   }
67 
RunStep(CallOptions * call_options,RunStepRequestWrapper * request,MutableRunStepResponseWrapper * response)68   Status RunStep(CallOptions* call_options, RunStepRequestWrapper* request,
69                  MutableRunStepResponseWrapper* response) override {
70     return CallWithRetry(call_options, &request->ToProto(),
71                          get_proto_from_wrapper(response),
72                          &MasterServiceStub::RunStep, "RunStep/Client");
73   }
74 
CloseSession(CallOptions * call_options,const CloseSessionRequest * request,CloseSessionResponse * response)75   Status CloseSession(CallOptions* call_options,
76                       const CloseSessionRequest* request,
77                       CloseSessionResponse* response) override {
78     return CallWithRetry(call_options, request, response,
79                          &MasterServiceStub::CloseSession);
80   }
81 
ListDevices(CallOptions * call_options,const ListDevicesRequest * request,ListDevicesResponse * response)82   Status ListDevices(CallOptions* call_options,
83                      const ListDevicesRequest* request,
84                      ListDevicesResponse* response) override {
85     return CallWithRetry(call_options, request, response,
86                          &MasterServiceStub::ListDevices);
87   }
88 
Reset(CallOptions * call_options,const ResetRequest * request,ResetResponse * response)89   Status Reset(CallOptions* call_options, const ResetRequest* request,
90                ResetResponse* response) override {
91     return CallWithRetry(call_options, request, response,
92                          &MasterServiceStub::Reset);
93   }
94 
MakeCallable(CallOptions * call_options,const MakeCallableRequest * request,MakeCallableResponse * response)95   Status MakeCallable(CallOptions* call_options,
96                       const MakeCallableRequest* request,
97                       MakeCallableResponse* response) override {
98     return CallWithRetry(call_options, request, response,
99                          &MasterServiceStub::MakeCallable);
100   }
RunCallable(CallOptions * call_options,const RunCallableRequest * request,RunCallableResponse * response)101   Status RunCallable(CallOptions* call_options,
102                      const RunCallableRequest* request,
103                      RunCallableResponse* response) override {
104     return CallWithRetry(call_options, request, response,
105                          &MasterServiceStub::RunCallable);
106   }
ReleaseCallable(CallOptions * call_options,const ReleaseCallableRequest * request,ReleaseCallableResponse * response)107   Status ReleaseCallable(CallOptions* call_options,
108                          const ReleaseCallableRequest* request,
109                          ReleaseCallableResponse* response) override {
110     return CallWithRetry(call_options, request, response,
111                          &MasterServiceStub::ReleaseCallable);
112   }
113 
114  private:
115   // Start tracing, attaching a unique ID to both the trace and the RPC.
NewTraceRpc(StringPiece name,::grpc::ClientContext * ctx)116   profiler::TraceMe* NewTraceRpc(StringPiece name, ::grpc::ClientContext* ctx) {
117     string trace_id = strings::StrCat(tracing::GetUniqueArg());
118     ctx->AddMetadata(GrpcIdKey(), trace_id);
119     return new profiler::TraceMe(
120         [&] { return strings::StrCat(name, ":", trace_id); },
121         profiler::TraceMeLevel::kInfo);
122   }
123 
124   template <typename Request, typename Response>
CallWithRetry(CallOptions * call_options,const Request * request,Response * response,::grpc::Status (MasterServiceStub::* pfunc)(::grpc::ClientContext *,const Request &,Response *),string trace_string={})125   Status CallWithRetry(CallOptions* call_options, const Request* request,
126                        Response* response,
127                        ::grpc::Status (MasterServiceStub::*pfunc)(
128                            ::grpc::ClientContext*, const Request&, Response*),
129                        string trace_string = {}) {
130     absl::Duration timeout = absl::Milliseconds(call_options->GetTimeout());
131     absl::Time expired_time = absl::FromUnixMicros(Env::Default()->NowMicros());
132     if (timeout > absl::ZeroDuration()) {
133       expired_time += timeout;
134     }
135     Status s;
136     for (int num_retries = 0;; ++num_retries) {
137       ::grpc::ClientContext ctx;
138       std::unique_ptr<profiler::TraceMe> trace;
139       if (!trace_string.empty()) {
140         trace.reset(NewTraceRpc(trace_string, &ctx));
141       }
142       ctx.set_fail_fast(false);
143       if (timeout > absl::ZeroDuration()) {
144         // We do not modify the timeout here to match legacy behavior. However,
145         // this could violate the contract of tensorflow::Session. If we retry
146         // an RPC just before the deadline is exceeded, we will still set the
147         // timeout to the original value. This leads to the overall timeout
148         // being double what was expected.
149         ctx.set_deadline(absl::ToChronoTime(absl::Now() + timeout));
150       }
151       s = FromGrpcStatus((stub_.get()->*pfunc)(&ctx, *request, response));
152       if (!errors::IsUnavailable(s)) {
153         return s;
154       }
155       // TODO(b/117162170): we may want to make this configurable.
156       constexpr int kMaxRetries = 10;
157       LOG(WARNING) << "RPC failed with status = \"" << s
158                    << "\" and grpc_error_string = \""
159                    << ctx.debug_error_string() << "\", maybe retrying the RPC";
160       if (num_retries >= kMaxRetries) {
161         LOG(WARNING) << "Too many retries, returning last status: " << s;
162         return s;
163       }
164       absl::Time now = absl::FromUnixMicros(Env::Default()->NowMicros());
165       const absl::Time deadline_with_backoff =
166           now + absl::Microseconds(ComputeBackoffMicroseconds(num_retries));
167       // Wait for a short period of time before retrying the RPC.  If our
168       // backoff would put us past the RPC deadline, we truncate it to ensure
169       // our RPC starts before the deadline.
170       const auto backoff_until = (timeout <= absl::ZeroDuration() ||
171                                   expired_time > deadline_with_backoff)
172                                      ? deadline_with_backoff
173                                      : expired_time;
174       Env::Default()->SleepForMicroseconds(
175           absl::ToInt64Microseconds(backoff_until - now));
176       now = absl::FromUnixMicros(Env::Default()->NowMicros());
177       if (now > expired_time && timeout > absl::ZeroDuration()) {
178         // If timeout_in_ms is set, exit the retry loop on timeout.
179         return errors::DeadlineExceeded(ctx.debug_error_string());
180       }
181     }
182   }
183 
184   std::unique_ptr<MasterServiceStub> stub_;
185 };
186 
NewGrpcMaster(const SharedGrpcChannelPtr & channel)187 MasterInterface* NewGrpcMaster(const SharedGrpcChannelPtr& channel) {
188   return new GrpcRemoteMaster(channel);
189 }
190 
191 }  // namespace tensorflow
192