/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/tensor_coding.h"

#include "google/protobuf/any.pb.h"

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"

namespace tensorflow {

TensorResponse::Source::~Source() {}

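// A minimal sketch of a Source implementation backed by an in-memory buffer,
// assuming only the contents() accessor used in this file (the class name is
// illustrative, not part of TensorFlow):
//
//   class ArraySource : public TensorResponse::Source {
//    public:
//     ArraySource(const void* data, int size) : stream_(data, size) {}
//     protobuf::io::ZeroCopyInputStream* contents() override {
//       return &stream_;
//     }
//
//    private:
//     protobuf::io::ArrayInputStream stream_;
//   };

// Resets all per-response state so this TensorResponse can be reused.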
void TensorResponse::Clear() {
  on_host_ = false;
  device_ = nullptr;
  alloc_attrs_ = AllocatorAttributes();
  allocator_ = nullptr;
  already_used_ = false;
  ClearTensor();
}

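// Drops the parsed metadata and releases any tensor storage held by this
// response.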
void TensorResponse::ClearTensor() {
  meta_.Clear();
  tensor_ = Tensor();
}

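// Records the device and allocator to use for incoming tensors. Parsing
// happens directly on the host when the allocator attributes request host
// memory or the device itself is a CPU.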
void TensorResponse::InitAlloc(DeviceBase* d, const AllocatorAttributes& aa) {
  Clear();
  device_ = d;
  alloc_attrs_ = aa;
  const DeviceAttributes& da = d->attributes();
  if (alloc_attrs_.on_host() || da.device_type() == "CPU") {
    on_host_ = true;
  }
  allocator_ = device_->GetAllocator(alloc_attrs_);
}

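// Takes ownership of *response (via swap) and materializes its tensor:
// directly through Tensor::FromProto on the host, or through the device's
// MakeTensorFromProto otherwise. The (potentially large) TensorProto is
// dropped afterwards to reduce memory usage.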
Status TensorResponse::InitFrom(RecvTensorResponse* response) {
  Status s;
  meta_.Swap(response);
  if (on_host_) {
    if (!tensor_.FromProto(allocator_, meta_.tensor())) {
      s = errors::InvalidArgument("Cannot parse tensor from response");
    }
  } else {
    s = device_->MakeTensorFromProto(meta_.tensor(), alloc_attrs_, &tensor_);
  }
  {
    TensorProto empty;
    meta_.mutable_tensor()->Swap(&empty);
  }
  meta_.clear_tensor();
  return s;
}

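// Copies the metadata from response and pre-allocates a tensor of the right
// dtype and shape; the tensor's content bytes are expected to be filled in
// later by the transport.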
void TensorResponse::InitPartial(const RecvTensorResponse& response,
                                 const AllocationAttributes& allocation_attr) {
  // Everything except content is present in *response.  Content will
  // arrive later; allocate a Tensor with appropriate storage for that
  // content.
  meta_ = response;
  TensorShape shape(meta_.tensor().tensor_shape());
  Tensor t(allocator_, meta_.tensor().dtype(), shape, allocation_attr);
  tensor_ = std::move(t);
}

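// Parses a serialized RecvTensorResponse from *source. Off-host tensors are
// parsed into a temporary proto and handed to the device; on-host tensors
// first try the hand-rolled fast path and fall back to a full proto parse.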
Status TensorResponse::ParseFrom(Source* source) {
  if (!on_host_) {
    protobuf::io::CodedInputStream input(source->contents());

    // Pre-parse into local storage, then delegate to device.
    if (!meta_.ParseFromCodedStream(&input) || !input.ConsumedEntireMessage()) {
      return errors::InvalidArgument("Cannot parse tensor from response");
    }
    Status s =
        device_->MakeTensorFromProto(meta_.tensor(), alloc_attrs_, &tensor_);
    // Reduce memory usage for big tensors.
    {
      TensorProto empty;
      meta_.mutable_tensor()->Swap(&empty);
    }
    meta_.clear_tensor();
    return s;
  }
  if (already_used_) {
    ClearTensor();
  }
  already_used_ = true;
  if (ParseFast(source)) return OkStatus();
  meta_.Clear();
  if (ParseSlow(source)) return OkStatus();
  return errors::InvalidArgument("Cannot parse tensor from response");
}

// Define some helper routines for decoding protocol buffer wire format data.
namespace {
// We only need some of the wiretype values for this code.
enum WireType {
  WIRETYPE_VARINT = 0,
  WIRETYPE_LENGTH_DELIMITED = 2,
};
inline int GetTagFieldNumber(uint32 tag) { return tag >> 3; }
inline WireType GetTagWireType(uint32 tag) {
  return static_cast<WireType>(tag & 0x7);
}

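// Reads a varint-encoded size and returns it as an int, rejecting values
// that do not fit.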
bool ReadVarintSizeAsInt(protobuf::io::CodedInputStream* input, int* result) {
  protobuf_uint64 v;
  if (input->ReadVarint64(&v) && v <= static_cast<uint64>(INT_MAX)) {
    *result = static_cast<int>(v);
    return true;
  } else {
    return false;
  }
}

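// Parses a length-delimited submessage from input into *value, enforcing
// both the length limit and protobuf's recursion-depth limit.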
bool ReadNestedMessage(protobuf::io::CodedInputStream* input,
                       protobuf::Message* value) {
  int length;
  if (!ReadVarintSizeAsInt(input, &length)) return false;
  std::pair<protobuf::io::CodedInputStream::Limit, int> p =
      input->IncrementRecursionDepthAndPushLimit(length);
  if (p.second < 0 || !value->MergePartialFromCodedStream(input)) return false;
  // Make sure that parsing stopped when the limit was hit, not at an endgroup
  // tag.
  return input->DecrementRecursionDepthAndPopLimit(p.first);
}

}  // namespace

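// Fast-path decoding of a serialized TensorProto. Only handles dtypes that
// can be memcpy'd, and requires dtype and tensor_shape to appear before
// tensor_content so the content can be read straight into a freshly
// allocated tensor. Returns false whenever the slow path must take over.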
bool TensorResponse::ParseTensorSubmessage(
    protobuf::io::CodedInputStream* input, TensorProto* tensor_meta) {
  bool seen_tensor_content = false;
  while (true) {
    auto p = input->ReadTagWithCutoff(127);
    int tag = GetTagFieldNumber(p.first);
    WireType wt = GetTagWireType(p.first);
    if (!p.second) {
      bool ok = (tag == 0);
      if (ok && !seen_tensor_content) {
        // No tensor content: could be because it's a zero-length tensor.
        TensorShape shape(tensor_meta->tensor_shape());
        Tensor t(allocator_, tensor_meta->dtype(), shape);
        tensor_ = std::move(t);
      }
      return ok;
    }
    switch (tag) {
      case TensorProto::kDtypeFieldNumber: {
        uint32 v;
        if ((wt != WIRETYPE_VARINT) || !input->ReadVarint32(&v)) return false;
        if (seen_tensor_content) return false;
        tensor_meta->set_dtype(static_cast<DataType>(static_cast<int>(v)));
        if (!DataTypeCanUseMemcpy(tensor_meta->dtype())) return false;
        break;
      }
      case TensorProto::kTensorShapeFieldNumber: {
        if ((wt != WIRETYPE_LENGTH_DELIMITED) ||
            !ReadNestedMessage(input, tensor_meta->mutable_tensor_shape()))
          return false;
        if (seen_tensor_content) return false;
        break;
      }
      case TensorProto::kVersionNumberFieldNumber: {
        uint32 v;
        if ((wt != WIRETYPE_VARINT) || !input->ReadVarint32(&v)) return false;
        if (seen_tensor_content) return false;
        tensor_meta->set_version_number(static_cast<int32>(v));
        break;
      }
      case TensorProto::kTensorContentFieldNumber: {
        // If we haven't seen the dtype and tensor_shape data first, we can't
        // deal with this in the fast path.
        if (seen_tensor_content) return false;
        if (wt != WIRETYPE_LENGTH_DELIMITED ||
            !tensor_meta->has_tensor_shape()) {
          return false;
        }
        int num_bytes;
        if (!ReadVarintSizeAsInt(input, &num_bytes)) return false;
        seen_tensor_content = true;
        TensorShape shape(tensor_meta->tensor_shape());
        Tensor t(allocator_, tensor_meta->dtype(), shape);
        StringPiece buf = t.tensor_data();
        if (static_cast<size_t>(num_bytes) != buf.size()) return false;
        // TODO(jeff,sanjay): Figure out a way to avoid this copy if
        // the underlying ZeroCopyInputStream data is properly aligned
        // and compatible with what allocator_ wants.
        if (!input->ReadRaw(const_cast<char*>(buf.data()), num_bytes))
          return false;
        tensor_ = std::move(t);
        break;
      }
      default: {
        // Some other tag our fast path code is not prepared to handle, so
        // return false.
        return false;
      }
    }
  }
}

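// Fast-path decoding of a serialized RecvTensorResponse. Handles exactly the
// fields the message currently defines; any unknown field or unexpected wire
// type makes it return false so ParseFrom can retry with ParseSlow.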
bool TensorResponse::ParseFast(Source* source) {
  protobuf::io::CodedInputStream input(source->contents());
  while (true) {
    auto p = input.ReadTagWithCutoff(127);
    int tag = GetTagFieldNumber(p.first);
    WireType wt = GetTagWireType(p.first);
    if (!p.second) {
      return (tag == 0);
    }
    switch (tag) {
      case RecvTensorResponse::kTensorFieldNumber: {
        if (wt != WIRETYPE_LENGTH_DELIMITED) return false;

        int length;
        if (!ReadVarintSizeAsInt(&input, &length)) return false;
        std::pair<protobuf::io::CodedInputStream::Limit, int> p =
            input.IncrementRecursionDepthAndPushLimit(length);
        if (p.second < 0 ||
            !ParseTensorSubmessage(&input, meta_.mutable_tensor())) {
          return false;
        }
        if (!input.DecrementRecursionDepthAndPopLimit(p.first)) {
          return false;
        }
        break;
      }
      case RecvTensorResponse::kIsDeadFieldNumber: {
        uint32 v;
        if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) return false;
        meta_.set_is_dead(v != 0);
        break;
      }
      case RecvTensorResponse::kSendStartMicrosFieldNumber: {
        protobuf_uint64 v;
        if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) return false;
        meta_.set_send_start_micros(static_cast<int64_t>(v));
        break;
      }
      case RecvTensorResponse::kTransportOptionsFieldNumber: {
        if ((wt != WIRETYPE_LENGTH_DELIMITED) ||
            !ReadNestedMessage(&input, meta_.mutable_transport_options()))
          return false;
        break;
      }
      case RecvTensorResponse::kRequireAckFieldNumber: {
        uint32 v;
        if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) return false;
        meta_.set_require_ack(v != 0);
        break;
      }
      default: {
        // Unknown tag that we can't handle on the fast path.
        return false;
      }
    }
  }

  return false;
}

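// Slow-path fallback: parse the entire RecvTensorResponse proto, convert the
// embedded TensorProto into a Tensor, then drop the proto's copy of the data.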
bool TensorResponse::ParseSlow(Source* source) {
  if (!meta_.ParseFromZeroCopyStream(source->contents())) {
    return false;
  }

  Tensor parsed(meta_.tensor().dtype());
  if (!parsed.FromProto(allocator_, meta_.tensor())) {
    return false;
  }
  tensor_ = std::move(parsed);

  // Reduce memory usage for big tensors.
  {
    TensorProto empty;
    meta_.mutable_tensor()->Swap(&empty);
  }
  meta_.clear_tensor();

  return true;
}

}  // namespace tensorflow