/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/testing/tf_driver.h"

#include <fstream>
#include <iostream>
#include <string>

#include "absl/strings/escaping.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/testing/join.h"
#include "tensorflow/lite/testing/split.h"
namespace tflite {
namespace testing {

namespace {

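// Creates a tensorflow::Tensor of the given type and shape; its contents are
// filled in separately by the helpers below.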
tensorflow::Tensor CreateTensor(const tensorflow::DataType type,
                                const std::vector<int64_t>& dim) {
  tensorflow::TensorShape shape{tensorflow::gtl::ArraySlice<int64_t>{
      reinterpret_cast<const int64_t*>(dim.data()), dim.size()}};
  return {type, shape};
}

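// Parses a comma-separated list of values into 'tensor'. The tensor is only
// written when the number of parsed values matches tensor->NumElements();
// either way, the number of parsed values is returned.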
template <typename T>
int FillTensorWithData(tensorflow::Tensor* tensor,
                       const string& values_as_string) {
  const auto& values = testing::Split<T>(values_as_string, ",");

  if (values.size() == tensor->NumElements()) {
    auto data = tensor->flat<T>();
    for (int i = 0; i < values.size(); i++) {
      data(i) = values[i];
    }
  }

  return values.size();
}

// Assumes 'values_as_string' is a hex string that gets converted into a
// TF Lite DynamicBuffer. Strings are then extracted and copied into the
// TensorFlow tensor.
int FillTensorWithTfLiteHexString(tensorflow::Tensor* tensor,
                                  const string& values_as_string) {
  string s = absl::HexStringToBytes(values_as_string);

  int num_strings = values_as_string.empty() ? 0 : GetStringCount(s.data());

  if (num_strings == tensor->NumElements()) {
    auto data = tensor->flat<tensorflow::tstring>();
    for (size_t i = 0; i < num_strings; ++i) {
      auto ref = GetString(s.data(), i);
      data(i).assign(ref.str, ref.len);
    }
  }

  return num_strings;
}

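// Zero-initializes every element of 'tensor'.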
template <typename T>
void FillTensorWithZeros(tensorflow::Tensor* tensor) {
  auto data = tensor->flat<T>();
  for (int i = 0; i < tensor->NumElements(); i++) {
    data(i) = 0;
  }
}

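// Serializes the flattened contents of 'tensor' as a comma-separated string.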
template <typename T>
string TensorDataToCsvString(const tensorflow::Tensor& tensor) {
  const auto& data = tensor.flat<T>();
  return Join(data.data(), data.size(), ",");
}

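// Packs the strings held by 'tensor' into a TF Lite DynamicBuffer and returns
// the serialized buffer as a hex string (the inverse of
// FillTensorWithTfLiteHexString).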
string TensorDataToTfLiteHexString(const tensorflow::Tensor& tensor) {
  DynamicBuffer dynamic_buffer;

  auto data = tensor.flat<tensorflow::tstring>();
  for (int i = 0; i < tensor.NumElements(); ++i) {
    dynamic_buffer.AddString(data(i).data(), data(i).size());
  }

  char* char_buffer = nullptr;
  size_t size = dynamic_buffer.WriteToBuffer(&char_buffer);
  string s = absl::BytesToHexString({char_buffer, size});
  free(char_buffer);

  return s;
}

}  // namespace

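// The four vectors describe the graph's interface: one entry per input with
// its dtype name and comma-separated shape, plus the list of output names.
// A hypothetical call (names, types, and shapes are illustrative only):
//   TfDriver driver({"input"}, {"float"}, {"1,224,224,3"}, {"output"});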
TfDriver::TfDriver(const std::vector<string>& input_layer,
                   const std::vector<string>& input_layer_type,
                   const std::vector<string>& input_layer_shape,
                   const std::vector<string>& output_layer)
    : input_names_(input_layer), output_names_(output_layer) {
  CHECK_EQ(input_layer.size(), input_layer_type.size());
  CHECK_EQ(input_layer.size(), input_layer_shape.size());

  input_ids_.resize(input_layer.size());
  input_tensors_.reserve(input_layer.size());
  input_types_.resize(input_layer.size());
  input_shapes_.resize(input_layer.size());
  for (int i = 0; i < input_layer.size(); i++) {
    input_ids_[i] = i;
    input_tensors_[input_layer[i]] = {};
    CHECK(DataTypeFromString(input_layer_type[i], &input_types_[i]));
    input_shapes_[i] = Split<int64_t>(input_layer_shape[i], ",");
    input_name_to_id_[input_layer[i]] = i;
  }

  output_ids_.resize(output_layer.size());
  output_tensors_.reserve(output_layer.size());
  for (int i = 0; i < output_layer.size(); i++) {
    output_ids_[i] = i;
    output_name_to_id_[output_layer[i]] = i;
  }
}

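// Reads a binary GraphDef from 'bin_file_path' and creates a TensorFlow
// session for it. On failure the driver is invalidated rather than crashing.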
void TfDriver::LoadModel(const string& bin_file_path) {
  if (!IsValid()) return;
  std::ifstream model(bin_file_path);
  if (model.fail()) {
    Invalidate("Failed to find the model " + bin_file_path);
    return;
  }

  tensorflow::GraphDef graphdef;
  if (!graphdef.ParseFromIstream(&model)) {
    Invalidate("Failed to parse tensorflow graphdef");
    return;
  }

  tensorflow::SessionOptions options;
  session_.reset(tensorflow::NewSession(options));
  auto status = session_->Create(graphdef);
  if (!status.ok()) {
    Invalidate("Failed to create session. " + status.error_message());
  }
}

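// Re-creates the named input tensor with the shape given as comma-separated
// values, then resets its contents to zeros.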
void TfDriver::ReshapeTensor(const string& name, const string& csv_values) {
  if (!IsValid()) return;
  int id = input_name_to_id_[name];
  input_shapes_[id] = Split<int64_t>(csv_values, ",");
  input_tensors_[input_names_[id]] =
      CreateTensor(input_types_[id], input_shapes_[id]);
  ResetTensor(name);
}

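// Zero-fills the named input tensor. Only float and int32 inputs are
// supported here.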
void TfDriver::ResetTensor(const std::string& name) {
  if (!IsValid()) return;
  int id = input_name_to_id_[name];
  // Note: copying a tensorflow::Tensor shares the underlying buffer, so
  // zero-filling the copy also resets the stored input tensor.
  auto tensor = input_tensors_[input_names_[id]];
  switch (input_types_[id]) {
    case tensorflow::DT_FLOAT: {
      FillTensorWithZeros<float>(&tensor);
      break;
    }
    case tensorflow::DT_INT32: {
      FillTensorWithZeros<int32_t>(&tensor);
      break;
    }
    default:
      Invalidate(absl::StrCat("Unsupported tensor type ",
                              tensorflow::DataType_Name(input_types_[id]),
                              " in ResetTensor"));
      return;
  }
}

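// Serializes the named output tensor produced by the last Invoke() call.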
string TfDriver::ReadOutput(const string& name) {
  if (!IsValid()) return "";
  return ReadOutput(output_tensors_[output_name_to_id_[name]]);
}

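// Builds a fresh tensor for each (name, serialized values) pair, then runs
// the session and stores the resulting output tensors.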
void TfDriver::Invoke(const std::vector<std::pair<string, string>>& inputs) {
  if (!IsValid()) return;
  for (const auto& input : inputs) {
    auto id = input_name_to_id_[input.first];
    auto tensor = CreateTensor(input_types_[id], input_shapes_[id]);
    SetInput(input.second, &tensor);
    input_tensors_[input_names_[id]] = tensor;
  }
  auto status = session_->Run({input_tensors_.begin(), input_tensors_.end()},
                              output_names_, {}, &output_tensors_);
  if (!status.ok()) {
    Invalidate(absl::StrCat("TensorFlow failed to run graph: ",
                            status.error_message()));
  }
}

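// Fills 'tensor' from a comma-separated value list (or a hex string for
// DT_STRING tensors) and invalidates the driver if the element count does
// not match.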
void TfDriver::SetInput(const string& values_as_string,
                        tensorflow::Tensor* tensor) {
  int num_values_available = 0;
  switch (tensor->dtype()) {
    case tensorflow::DT_FLOAT:
      num_values_available =
          FillTensorWithData<float>(tensor, values_as_string);
      break;
    case tensorflow::DT_INT32:
      num_values_available =
          FillTensorWithData<int32_t>(tensor, values_as_string);
      break;
    case tensorflow::DT_UINT32:
      num_values_available =
          FillTensorWithData<uint32_t>(tensor, values_as_string);
      break;
    case tensorflow::DT_UINT8:
      num_values_available =
          FillTensorWithData<uint8_t>(tensor, values_as_string);
      break;
    case tensorflow::DT_STRING:
      num_values_available =
          FillTensorWithTfLiteHexString(tensor, values_as_string);
      break;
    default:
      Invalidate(absl::StrCat("Unsupported tensor type ",
                              tensorflow::DataType_Name(tensor->dtype()),
                              " in SetInput"));
      return;
  }

  if (tensor->NumElements() != num_values_available) {
    Invalidate(absl::StrCat("Needed ", tensor->NumElements(),
                            " values for input tensor, but was given ",
                            num_values_available, " instead."));
  }
}

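// Serializes a tensor to CSV, or to a TF Lite hex string for DT_STRING.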
string TfDriver::ReadOutput(const tensorflow::Tensor& tensor) {
  switch (tensor.dtype()) {
    case tensorflow::DT_FLOAT:
      return TensorDataToCsvString<float>(tensor);
    case tensorflow::DT_INT32:
      return TensorDataToCsvString<int32_t>(tensor);
    case tensorflow::DT_UINT32:
      return TensorDataToCsvString<uint32_t>(tensor);
    case tensorflow::DT_INT64:
      return TensorDataToCsvString<int64_t>(tensor);
    case tensorflow::DT_UINT8:
      return TensorDataToCsvString<uint8_t>(tensor);
    case tensorflow::DT_STRING:
      return TensorDataToTfLiteHexString(tensor);
    case tensorflow::DT_BOOL:
      return TensorDataToCsvString<bool>(tensor);
    default:
      Invalidate(absl::StrCat("Unsupported tensor type ",
                              tensorflow::DataType_Name(tensor.dtype()),
                              " in ReadOutput"));
      return "";
  }
}

}  // namespace testing
}  // namespace tflite