/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.h"

#include "grpcpp/support/byte_buffer.h"
#include "grpcpp/support/slice.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/io/proto_encode_helper.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/protobuf/worker.pb.h"

namespace tensorflow {
namespace grpc {

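// Serializes "proto" into a single freshly allocated grpc::Slice (copying the
// entire message, tensor contents included) and swaps that slice into
// "*result".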
void EncodeRecvTensorResponseToByteBuffer(const RecvTensorResponse& proto,
                                          ::grpc::ByteBuffer* result) {
  ::grpc::Slice slice(proto.ByteSizeLong());
  proto.SerializeWithCachedSizesToArray(
      const_cast<uint8*>(reinterpret_cast<const uint8*>(slice.begin())));
  ::grpc::ByteBuffer tmp(&slice, 1);
  result->Swap(&tmp);
}

// We generate a RecvTensorResponse protocol buffer encoding into "*result",
// but where possible, we share the underlying Tensor buffer for "val", to
// avoid an extra copy.
//
// We hand-encode the protocol buffer data in the following order:
//
// Let R be a RecvTensorResponse object we want to encode, logically
// constructed by filling in data from "is_dead" and "val" and filling
// in a few other fields as well.
//
// (Letters here are used in the code to refer back to which part of the
//  encoding the code is generating.)
//
// A:   <protocol buffer encoding of fields except R.tensor()>
// B1:  <tag encoding for RecvTensorResponse::tensor>
// B2:  <varint32 length of R.tensor() sub message>
// C:   <protocol buffer encoding of R.tensor() except for
//          R.tensor().tensor_content()>
// D1:  <tag encoding for TensorProto::tensor_content>
// D2:  <varint32 length of R.tensor().tensor_content() data>
// E:   <actual data for val's representation>
//
// If the tensor data is up to "kLargeTensorBytes", then A
// through E will all be encoded into "*result" in a single grpc::Slice.
//
// If the tensor data is larger than "kLargeTensorBytes", then A through
// D2 will be encoded in one grpc::Slice, and E will be encoded in a second
// grpc::Slice that points to the backing store for the tensor data, to avoid
// copying the tensor data (and the grpc::Slice setup will be arranged so as
// to dereference the underlying tensor data buffer when it is no longer
// needed in the "*result" ByteBuffer).
static int VarLengthEncodingSize(uint32 tag, size_t bytes) {
  return core::VarintLength(tag << 3) + core::VarintLength(bytes) + bytes;
}
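
// For example (illustrative): with tag 4 (TensorProto::kTensorContentFieldNumber)
// and a 300-byte payload, this returns
// VarintLength(32) + VarintLength(300) + 300 = 1 + 2 + 300 = 303 bytes.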

// Returns an upper bound in bytes of the protocol buffer encoding of
// the "skeleton" of "val" (all the data needed for dtype and the shape,
// but not the actual contents of "val").
static int SkeletonEncodingSizeUpperBound(const Tensor& val) {
  static const int kVarintMax64 = 10;  // Max length of varint64 encoding
  const int ndims = val.shape().dims();
  return (2 * kVarintMax64) +           // dtype
         (ndims * (4 * kVarintMax64));  // Shape: 4 varints per dim
}
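
// (The "4 varints per dim" above conservatively bounds, for each dimension:
// the dim tag, the length of the Dim submessage, the size-field tag, and the
// size value.  It is an upper bound, not the exact encoded size.)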

// Encode the skeleton for "val" (the encoded TensorProto contents:
// dtype and shape, but not the actual data) into "*e".  The backing
// store for "*e" must be of appropriate size to hold this encoding.
static void EncodeSkeleton(const Tensor& val, io::ProtoEncodeHelper* e) {
  // Encode val.dtype()
  e->WriteUint64(TensorProto::kDtypeFieldNumber, val.dtype());

  // Compute length of val.shape() proto encoding
  const int ndims = val.shape().dims();
  int tensor_shape_bytes = 0;
  for (int d = 0; d < ndims; d++) {
    int64_t dim_size = val.shape().dim_size(d);
    tensor_shape_bytes +=
        2 +  // TensorShapeProto dim tag + varintlength of submessage
        1 +  // TensorShapeProto_Dim::kSizeFieldNumber
        core::VarintLength(dim_size);
  }

  if (tensor_shape_bytes > 0) {
    e->WriteVarlengthBeginning(TensorProto::kTensorShapeFieldNumber,
                               tensor_shape_bytes);
    // Encode val.shape()
    for (int d = 0; d < ndims; d++) {
      int64_t dim_size = val.shape().dim_size(d);
      int64_t dim_varlen = 1 +  // TensorShapeProto_Dim::kSizeFieldNumber
                           core::VarintLength(dim_size);
      e->WriteVarlengthBeginning(TensorShapeProto::kDimFieldNumber, dim_varlen);
      e->WriteUint64(TensorShapeProto_Dim::kSizeFieldNumber, dim_size);
    }
  }

#ifndef NDEBUG
  {
    // Debug-mode only check to make sure the encoding above is
    // identical to the auto-generated protocol buffer encoding.
    TensorProto skeleton;
    skeleton.set_dtype(val.dtype());
    val.shape().AsProto(skeleton.mutable_tensor_shape());
    string tensor_except_contents;  // tensor() field except contents
    skeleton.AppendToString(&tensor_except_contents);
    TensorProto skeleton2;
    skeleton2.ParseFromString(string(e->data(), e->size()));
    string out;
    // Re-serialize the message parsed from the hand-encoded bytes so the
    // comparison below is against the auto-generated encoding.
    skeleton2.AppendToString(&out);
    DCHECK_EQ(tensor_except_contents, out) << skeleton.DebugString() << " vs\n"
                                           << skeleton2.DebugString();
  }
#endif
}
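
// For example (illustrative): for a DT_FLOAT tensor of shape {3, 4},
// EncodeSkeleton produces the bytes
//   08 01          dtype tag, DT_FLOAT
//   12 08          tensor_shape tag, 8-byte submessage
//   12 02 08 03    dim { size: 3 }
//   12 02 08 04    dim { size: 4 }
// assuming the standard proto wire format for these fields.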

void EncodeTensorToByteBuffer(bool is_dead, const Tensor& val, bool require_ack,
                              ::grpc::ByteBuffer* result) {
  const int kLargeTensorBytes = 1024;
  const int64_t kProtoBufLimitBytes = 1LL << 31;

  if (val.TotalBytes() > kProtoBufLimitBytes) {
    size_t exceeded_bytes = val.TotalBytes() - kProtoBufLimitBytes;
    LOG(FATAL) << "Cannot encode a Tensor that exceeds the 2GB protobuf limit. "
                  "Exceeded bytes: "
               << exceeded_bytes;
  }

  RecvTensorResponse response;
  if (is_dead) {
    response.set_is_dead(is_dead);
  }
  response.set_require_ack(require_ack);
  response.set_send_start_micros(Env::Default()->NowMicros());
  if (!DataTypeCanUseMemcpy(val.dtype())) {
    // Straightforward but slow path for complicated kinds of tensor data
    // TODO(jeff,sanjay): If this becomes an issue, we could
    // go directly from val -> ByteBuffer, with some effort.
    val.AsProtoTensorContent(response.mutable_tensor());

    // Encode full protocol buffer to a ByteBuffer
    EncodeRecvTensorResponseToByteBuffer(response, result);
  } else {
    // skeleton is the encoded TensorProto contents (dtype and shape), but
    // not the actual data
    gtl::InlinedVector<char, 128> skeleton(SkeletonEncodingSizeUpperBound(val));
    io::ProtoEncodeHelper e_skeleton(skeleton.data(), skeleton.size());
    EncodeSkeleton(val, &e_skeleton);

    StringPiece tdata = val.tensor_data();
    uint32 overall_tensor_proto_bytesize =
        (e_skeleton.size() +
         VarLengthEncodingSize(TensorProto::kTensorContentFieldNumber,
                               tdata.size()));
    string header;  // All of RecvTensorResponse except the tensor() field
    response.AppendToString(&header);

    size_t expected_size =
        (header.size() +
         VarLengthEncodingSize(RecvTensorResponse::kTensorFieldNumber,
                               overall_tensor_proto_bytesize));
    // If "share_tensor_slice_memory == false", we copy the tensor data to
    // the end of the buffer we are preparing that holds the rest of the
    // RecvTensorResponse protocol buffer.
    //
    // If "share_tensor_slice_memory == true", we arrange to share the
    // backing store of the data by creating a slice that also points to the
    // backing store, with appropriate reference counts to keep the
    // backing store alive as needed.
    //
    // We enable this behavior if the tensor is large.
    bool share_tensor_slice_memory = (tdata.size() > kLargeTensorBytes);

    size_t encoder_size = expected_size - tdata.size();

    // Encode all but the actual "tdata", but including the tag and
    // varlength header for the "tdata"
    gtl::InlinedVector<char, 1024> space(encoder_size);
    io::ProtoEncodeHelper e(space.data(), space.size());
    // (A)
    e.WriteRawBytes(header);

    // (B1) & (B2)
    e.WriteVarlengthBeginning(RecvTensorResponse::kTensorFieldNumber,
                              overall_tensor_proto_bytesize);
    // (C)
    e.WriteRawBytes(StringPiece(e_skeleton.data(), e_skeleton.size()));
    // (D1) & (D2)
    e.WriteVarlengthBeginning(TensorProto::kTensorContentFieldNumber,
                              tdata.size());

    // All but the tensor backing store are serialized now

    // Now allocate memory and put into the ByteBuffer
    ::grpc::Slice slices[2];
    int num_slices = 0;
    {
      size_t slice_len =
          e.size() + (share_tensor_slice_memory ? 0 : tdata.size());
      slices[0] = ::grpc::Slice(slice_len);
      memcpy(const_cast<uint8_t*>(slices[0].begin()), e.data(), e.size());
      if (!share_tensor_slice_memory) {
        // (E)
        memcpy(const_cast<uint8_t*>(slices[0].begin()) + e.size(), tdata.data(),
               tdata.size());
      }
      num_slices += 1;
    }

    if (share_tensor_slice_memory) {
      // (E) Encode tensor data, but by sharing backing store
      const TensorBuffer* buf = DMAHelper::buffer(&val);
      buf->Ref();
      slices[1] = ::grpc::Slice(
          const_cast<void*>(static_cast<const void*>(tdata.data())),
          tdata.size(),
          [](void* backing) { static_cast<TensorBuffer*>(backing)->Unref(); },
          const_cast<TensorBuffer*>(buf));
      num_slices += 1;
    }
    size_t total_bytes = 0;
    for (int i = 0; i < num_slices; i++) {
      total_bytes += slices[i].size();
    }
    CHECK_EQ(total_bytes, expected_size);

    ::grpc::ByteBuffer tmp(&slices[0], num_slices);
    result->Swap(&tmp);
  }
}
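
// Illustrative usage (a hypothetical caller, not part of this file):
//
//   Tensor val(DT_FLOAT, TensorShape({1024, 1024}));
//   ::grpc::ByteBuffer buf;
//   EncodeTensorToByteBuffer(/*is_dead=*/false, val, /*require_ack=*/false,
//                            &buf);
//
// Since 4MB of float data exceeds kLargeTensorBytes, "buf" ends up with two
// slices: one holding the hand-encoded proto header, and one aliasing the
// tensor's backing store, which stays Ref'd until gRPC releases the slice.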

}  // namespace grpc
}  // namespace tensorflow