1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_
17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_
18
19 #include <memory>
20 #include <string>
21
22 #include "grpcpp/grpcpp.h"
23 #include "grpcpp/impl/codegen/proto_utils.h"
24 #include "grpcpp/support/byte_buffer.h"
25 #include "tensorflow/core/distributed_runtime/error_payloads.h"
26 #include "tensorflow/core/distributed_runtime/tensor_coding.h"
27 #include "tensorflow/core/lib/core/status.h"
28 #include "tensorflow/core/lib/strings/stringprintf.h"
29 #include "tensorflow/core/platform/mutex.h"
30 #include "tensorflow/core/platform/protobuf.h"
31 #include "tensorflow/core/platform/stringpiece.h"
32 #include "tensorflow/core/protobuf/distributed_runtime_payloads.pb.h"
33
34 namespace tensorflow {
35
// Given the total number of RPC retries attempted, return a randomized
// amount of time to delay before retrying the request.
//
// The average computed backoff increases with the number of RPCs attempted.
// See implementation for details on the calculations.
//
// Args:
//   current_retry_attempt: number of retries attempted so far.
//   min_delay, max_delay: presumably the lower/upper bounds on the returned
//     delay (1 ms and 10 s by default if the unit is microseconds, as the
//     function name suggests) — NOTE(review): confirm against the
//     implementation, which is not part of this header.
//
// Returns: the delay to wait, in microseconds per the function name.
int64_t ComputeBackoffMicroseconds(int current_retry_attempt,
                                   int64_t min_delay = 1000,
                                   int64_t max_delay = 10000000);
44
45 // Thin wrapper around ::grpc::ProtoBufferReader to give TensorResponse an
46 // efficient byte reader from which to decode a RecvTensorResponse.
47 class GrpcByteSource : public TensorResponse::Source {
48 public:
GrpcByteSource(::grpc::ByteBuffer * buffer)49 explicit GrpcByteSource(::grpc::ByteBuffer* buffer) : buffer_(buffer) {}
~GrpcByteSource()50 ~GrpcByteSource() override { DeleteStream(); }
51
52 typedef ::grpc::ProtoBufferReader Reader;
53
contents()54 protobuf::io::ZeroCopyInputStream* contents() override {
55 DeleteStream();
56 stream_ = new (&space_) Reader(buffer_);
57 return stream_;
58 }
59
60 private:
DeleteStream()61 void DeleteStream() {
62 if (stream_) {
63 stream_->~Reader();
64 }
65 }
66
67 ::grpc::ByteBuffer* buffer_; // Not owned
68 Reader* stream_ = nullptr; // Points into space_ if non-nullptr
69 char space_[sizeof(Reader)];
70 };
71
72 constexpr char kStreamRemovedMessage[] = "Stream removed";
73
74 // Identify if the given grpc::Status corresponds to an HTTP stream removed
75 // error (see chttp2_transport.cc).
76 //
77 // When auto-reconnecting to a remote TensorFlow worker after it restarts, gRPC
78 // can return an UNKNOWN error code with a "Stream removed" error message.
79 // This should not be treated as an unrecoverable error.
80 //
81 // N.B. This is dependent on the error message from grpc remaining consistent.
IsStreamRemovedError(const::grpc::Status & s)82 inline bool IsStreamRemovedError(const ::grpc::Status& s) {
83 return !s.ok() && s.error_code() == ::grpc::StatusCode::UNKNOWN &&
84 s.error_message() == kStreamRemovedMessage;
85 }
86
SerializePayloads(const::tensorflow::Status & s)87 inline std::string SerializePayloads(const ::tensorflow::Status& s) {
88 distributed_runtime::GrpcPayloadContainer container;
89 s.ForEachPayload(
90 [&container](tensorflow::StringPiece key, tensorflow::StringPiece value) {
91 (*container.mutable_payloads())[std::string(key)] = std::string(value);
92 });
93 return container.SerializeAsString();
94 }
95
InsertSerializedPayloads(::tensorflow::Status & s,std::string payloads)96 inline void InsertSerializedPayloads(::tensorflow::Status& s,
97 std::string payloads) {
98 distributed_runtime::GrpcPayloadContainer container;
99 if (container.ParseFromString(payloads)) {
100 for (const auto& key_val : container.payloads()) {
101 s.SetPayload(key_val.first, key_val.second);
102 }
103 } else {
104 s.SetPayload(kGrpcPayloadsLost,
105 distributed_runtime::GrpcPayloadsLost().SerializeAsString());
106 }
107 }
108
FromGrpcStatus(const::grpc::Status & s)109 inline ::tensorflow::Status FromGrpcStatus(const ::grpc::Status& s) {
110 if (s.ok()) {
111 return OkStatus();
112 } else {
113 ::tensorflow::Status converted;
114 // Convert "UNKNOWN" stream removed errors into unavailable, to allow
115 // for retry upstream.
116 if (IsStreamRemovedError(s)) {
117 converted = Status(tensorflow::error::UNAVAILABLE, s.error_message());
118 }
119 converted = Status(static_cast<tensorflow::error::Code>(s.error_code()),
120 s.error_message());
121 InsertSerializedPayloads(converted, s.error_details());
122 return converted;
123 }
124 }
125
ToGrpcStatus(const::tensorflow::Status & s)126 inline ::grpc::Status ToGrpcStatus(const ::tensorflow::Status& s) {
127 if (s.ok()) {
128 return ::grpc::Status::OK;
129 } else {
130 if (s.error_message().size() > 3072 /* 3k bytes */) {
131 // TODO(b/62947679): Remove truncation once the gRPC issue is resolved.
132 string scratch =
133 strings::Printf("%.3072s ... [truncated]", s.error_message().c_str());
134 LOG(ERROR) << "Truncated error message: " << s;
135 return ::grpc::Status(static_cast<::grpc::StatusCode>(s.code()), scratch,
136 SerializePayloads(s));
137 }
138 return ::grpc::Status(static_cast<::grpc::StatusCode>(s.code()),
139 s.error_message(), SerializePayloads(s));
140 }
141 }
142
143 typedef std::shared_ptr<::grpc::Channel> SharedGrpcChannelPtr;
144
GrpcIdKey()145 inline string GrpcIdKey() { return "tf-rpc"; }
146
// Serialize src and store in *dst.
// NOTE(review): presumably returns a non-OK status if serialization fails —
// confirm in the implementation (not part of this header).
::grpc::Status GrpcMaybeUnparseProto(const protobuf::Message& src,
                                     ::grpc::ByteBuffer* dst);

// Parse contents of src and initialize *dst with them.
// NOTE(review): presumably returns false if src does not hold a valid
// message — confirm in the implementation.
bool GrpcMaybeParseProto(::grpc::ByteBuffer* src, protobuf::Message* dst);

// Overload (not a template specialization) for TensorResponse destinations;
// decodes src into *dst.
bool GrpcMaybeParseProto(::grpc::ByteBuffer* src, TensorResponse* dst);

// Copy string src to grpc buffer *dst.
::grpc::Status GrpcMaybeUnparseProto(const string& src,
                                     ::grpc::ByteBuffer* dst);

// Copy grpc buffer src to string *dst.
bool GrpcMaybeParseProto(::grpc::ByteBuffer* src, string* dst);

// Copy grpc buffer src to tstring *dst.
bool GrpcMaybeParseProto(::grpc::ByteBuffer* src, tstring* dst);
166
167 } // namespace tensorflow
168
169 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_
170