xref: /aosp_15_r20/external/tensorflow/tensorflow/core/distributed_runtime/rpc/grpc_util.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_
17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_
18 
19 #include <memory>
20 #include <string>
21 
22 #include "grpcpp/grpcpp.h"
23 #include "grpcpp/impl/codegen/proto_utils.h"
24 #include "grpcpp/support/byte_buffer.h"
25 #include "tensorflow/core/distributed_runtime/error_payloads.h"
26 #include "tensorflow/core/distributed_runtime/tensor_coding.h"
27 #include "tensorflow/core/lib/core/status.h"
28 #include "tensorflow/core/lib/strings/stringprintf.h"
29 #include "tensorflow/core/platform/mutex.h"
30 #include "tensorflow/core/platform/protobuf.h"
31 #include "tensorflow/core/platform/stringpiece.h"
32 #include "tensorflow/core/protobuf/distributed_runtime_payloads.pb.h"
33 
34 namespace tensorflow {
35 
// Computes a randomized back-off delay to wait before re-issuing an RPC.
//
// `current_retry_attempt` is the number of retries attempted so far; the
// average delay returned grows as this count grows.
// `min_delay` and `max_delay` default to 1'000 and 10'000'000. Given the
// function name these are presumably microseconds (1 ms and 10 s) — see the
// implementation for exactly how they bound the result.
int64_t ComputeBackoffMicroseconds(int current_retry_attempt,
                                   int64_t min_delay = 1'000,
                                   int64_t max_delay = 10'000'000);
44 
45 // Thin wrapper around ::grpc::ProtoBufferReader to give TensorResponse an
46 // efficient byte reader from which to decode a RecvTensorResponse.
47 class GrpcByteSource : public TensorResponse::Source {
48  public:
GrpcByteSource(::grpc::ByteBuffer * buffer)49   explicit GrpcByteSource(::grpc::ByteBuffer* buffer) : buffer_(buffer) {}
~GrpcByteSource()50   ~GrpcByteSource() override { DeleteStream(); }
51 
52   typedef ::grpc::ProtoBufferReader Reader;
53 
contents()54   protobuf::io::ZeroCopyInputStream* contents() override {
55     DeleteStream();
56     stream_ = new (&space_) Reader(buffer_);
57     return stream_;
58   }
59 
60  private:
DeleteStream()61   void DeleteStream() {
62     if (stream_) {
63       stream_->~Reader();
64     }
65   }
66 
67   ::grpc::ByteBuffer* buffer_;  // Not owned
68   Reader* stream_ = nullptr;    // Points into space_ if non-nullptr
69   char space_[sizeof(Reader)];
70 };
71 
72 constexpr char kStreamRemovedMessage[] = "Stream removed";
73 
74 // Identify if the given grpc::Status corresponds to an HTTP stream removed
75 // error (see chttp2_transport.cc).
76 //
77 // When auto-reconnecting to a remote TensorFlow worker after it restarts, gRPC
78 // can return an UNKNOWN error code with a "Stream removed" error message.
79 // This should not be treated as an unrecoverable error.
80 //
81 // N.B. This is dependent on the error message from grpc remaining consistent.
IsStreamRemovedError(const::grpc::Status & s)82 inline bool IsStreamRemovedError(const ::grpc::Status& s) {
83   return !s.ok() && s.error_code() == ::grpc::StatusCode::UNKNOWN &&
84          s.error_message() == kStreamRemovedMessage;
85 }
86 
SerializePayloads(const::tensorflow::Status & s)87 inline std::string SerializePayloads(const ::tensorflow::Status& s) {
88   distributed_runtime::GrpcPayloadContainer container;
89   s.ForEachPayload(
90       [&container](tensorflow::StringPiece key, tensorflow::StringPiece value) {
91         (*container.mutable_payloads())[std::string(key)] = std::string(value);
92       });
93   return container.SerializeAsString();
94 }
95 
InsertSerializedPayloads(::tensorflow::Status & s,std::string payloads)96 inline void InsertSerializedPayloads(::tensorflow::Status& s,
97                                      std::string payloads) {
98   distributed_runtime::GrpcPayloadContainer container;
99   if (container.ParseFromString(payloads)) {
100     for (const auto& key_val : container.payloads()) {
101       s.SetPayload(key_val.first, key_val.second);
102     }
103   } else {
104     s.SetPayload(kGrpcPayloadsLost,
105                  distributed_runtime::GrpcPayloadsLost().SerializeAsString());
106   }
107 }
108 
FromGrpcStatus(const::grpc::Status & s)109 inline ::tensorflow::Status FromGrpcStatus(const ::grpc::Status& s) {
110   if (s.ok()) {
111     return OkStatus();
112   } else {
113     ::tensorflow::Status converted;
114     // Convert "UNKNOWN" stream removed errors into unavailable, to allow
115     // for retry upstream.
116     if (IsStreamRemovedError(s)) {
117       converted = Status(tensorflow::error::UNAVAILABLE, s.error_message());
118     }
119     converted = Status(static_cast<tensorflow::error::Code>(s.error_code()),
120                        s.error_message());
121     InsertSerializedPayloads(converted, s.error_details());
122     return converted;
123   }
124 }
125 
ToGrpcStatus(const::tensorflow::Status & s)126 inline ::grpc::Status ToGrpcStatus(const ::tensorflow::Status& s) {
127   if (s.ok()) {
128     return ::grpc::Status::OK;
129   } else {
130     if (s.error_message().size() > 3072 /* 3k bytes */) {
131       // TODO(b/62947679): Remove truncation once the gRPC issue is resolved.
132       string scratch =
133           strings::Printf("%.3072s ... [truncated]", s.error_message().c_str());
134       LOG(ERROR) << "Truncated error message: " << s;
135       return ::grpc::Status(static_cast<::grpc::StatusCode>(s.code()), scratch,
136                             SerializePayloads(s));
137     }
138     return ::grpc::Status(static_cast<::grpc::StatusCode>(s.code()),
139                           s.error_message(), SerializePayloads(s));
140   }
141 }
142 
143 typedef std::shared_ptr<::grpc::Channel> SharedGrpcChannelPtr;
144 
// Returns the constant identifier string "tf-rpc". Presumably used as a
// metadata/ID key by gRPC call sites — confirm at the callers.
inline string GrpcIdKey() {
  return string("tf-rpc");
}
146 
147 // Serialize src and store in *dst.
148 ::grpc::Status GrpcMaybeUnparseProto(const protobuf::Message& src,
149                                      ::grpc::ByteBuffer* dst);
150 
151 // Parse contents of src and initialize *dst with them.
152 bool GrpcMaybeParseProto(::grpc::ByteBuffer* src, protobuf::Message* dst);
153 
154 // Specialization for TensorResponse
155 bool GrpcMaybeParseProto(::grpc::ByteBuffer* src, TensorResponse* dst);
156 
157 // Copy string src to grpc buffer *dst.
158 ::grpc::Status GrpcMaybeUnparseProto(const string& src,
159                                      ::grpc::ByteBuffer* dst);
160 
161 // Copy grpc buffer src to string *dst.
162 bool GrpcMaybeParseProto(::grpc::ByteBuffer* src, string* dst);
163 
164 // Copy grpc buffer src to tstring *dst.
165 bool GrpcMaybeParseProto(::grpc::ByteBuffer* src, tstring* dst);
166 
167 }  // namespace tensorflow
168 
169 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_
170