xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/testing/tf_driver.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/lite/testing/tf_driver.h"

#include <cstdlib>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <utility>

#include "absl/strings/escaping.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/testing/join.h"
#include "tensorflow/lite/testing/split.h"
26 
27 namespace tflite {
28 namespace testing {
29 
30 namespace {
31 
// Creates a TensorFlow tensor of `type` whose dimensions are given by `dim`.
// The element contents are not initialized here; callers fill them in
// (see FillTensorWithData / FillTensorWithZeros).
tensorflow::Tensor CreateTensor(const tensorflow::DataType type,
                                const std::vector<int64_t>& dim) {
  // `dim` already stores int64_t, so its storage can be viewed directly as
  // the ArraySlice<int64_t> TensorShape expects; no pointer cast is needed.
  const tensorflow::TensorShape shape{
      tensorflow::gtl::ArraySlice<int64_t>{dim.data(), dim.size()}};
  return {type, shape};
}
38 
39 template <typename T>
FillTensorWithData(tensorflow::Tensor * tensor,const string & values_as_string)40 int FillTensorWithData(tensorflow::Tensor* tensor,
41                        const string& values_as_string) {
42   const auto& values = testing::Split<T>(values_as_string, ",");
43 
44   if (values.size() == tensor->NumElements()) {
45     auto data = tensor->flat<T>();
46     for (int i = 0; i < values.size(); i++) {
47       data(i) = values[i];
48     }
49   }
50 
51   return values.size();
52 }
53 
54 // Assumes 'values_as_string' is a hex string that gets converted into a
55 // TF Lite DynamicBuffer. Strings are then extracted and copied into the
56 // TensorFlow tensor.
FillTensorWithTfLiteHexString(tensorflow::Tensor * tensor,const string & values_as_string)57 int FillTensorWithTfLiteHexString(tensorflow::Tensor* tensor,
58                                   const string& values_as_string) {
59   string s = absl::HexStringToBytes(values_as_string);
60 
61   int num_strings = values_as_string.empty() ? 0 : GetStringCount(s.data());
62 
63   if (num_strings == tensor->NumElements()) {
64     auto data = tensor->flat<tensorflow::tstring>();
65     for (size_t i = 0; i < num_strings; ++i) {
66       auto ref = GetString(s.data(), i);
67       data(i).assign(ref.str, ref.len);
68     }
69   }
70 
71   return num_strings;
72 }
73 
74 template <typename T>
FillTensorWithZeros(tensorflow::Tensor * tensor)75 void FillTensorWithZeros(tensorflow::Tensor* tensor) {
76   auto data = tensor->flat<T>();
77   for (int i = 0; i < tensor->NumElements(); i++) {
78     data(i) = 0;
79   }
80 }
81 
82 template <typename T>
TensorDataToCsvString(const tensorflow::Tensor & tensor)83 string TensorDataToCsvString(const tensorflow::Tensor& tensor) {
84   const auto& data = tensor.flat<T>();
85   return Join(data.data(), data.size(), ",");
86 }
87 
TensorDataToTfLiteHexString(const tensorflow::Tensor & tensor)88 string TensorDataToTfLiteHexString(const tensorflow::Tensor& tensor) {
89   DynamicBuffer dynamic_buffer;
90 
91   auto data = tensor.flat<tensorflow::tstring>();
92   for (int i = 0; i < tensor.NumElements(); ++i) {
93     dynamic_buffer.AddString(data(i).data(), data(i).size());
94   }
95 
96   char* char_buffer = nullptr;
97   size_t size = dynamic_buffer.WriteToBuffer(&char_buffer);
98   string s = absl::BytesToHexString({char_buffer, size});
99   free(char_buffer);
100 
101   return s;
102 }
103 
104 }  // namespace
105 
TfDriver(const std::vector<string> & input_layer,const std::vector<string> & input_layer_type,const std::vector<string> & input_layer_shape,const std::vector<string> & output_layer)106 TfDriver::TfDriver(const std::vector<string>& input_layer,
107                    const std::vector<string>& input_layer_type,
108                    const std::vector<string>& input_layer_shape,
109                    const std::vector<string>& output_layer)
110     : input_names_(input_layer), output_names_(output_layer) {
111   CHECK_EQ(input_layer.size(), input_layer_type.size());
112   CHECK_EQ(input_layer.size(), input_layer_shape.size());
113 
114   input_ids_.resize(input_layer.size());
115   input_tensors_.reserve(input_layer.size());
116   input_types_.resize(input_layer.size());
117   input_shapes_.resize(input_layer.size());
118   for (int i = 0; i < input_layer.size(); i++) {
119     input_ids_[i] = i;
120     input_tensors_[input_layer[i]] = {};
121     CHECK(DataTypeFromString(input_layer_type[i], &input_types_[i]));
122     input_shapes_[i] = Split<int64_t>(input_layer_shape[i], ",");
123     input_name_to_id_[input_layer[i]] = i;
124   }
125 
126   output_ids_.resize(output_layer.size());
127   output_tensors_.reserve(output_layer.size());
128   for (int i = 0; i < output_layer.size(); i++) {
129     output_ids_[i] = i;
130     output_name_to_id_[output_layer[i]] = i;
131   }
132 }
133 
LoadModel(const string & bin_file_path)134 void TfDriver::LoadModel(const string& bin_file_path) {
135   if (!IsValid()) return;
136   std::ifstream model(bin_file_path);
137   if (model.fail()) {
138     Invalidate("Failed to find the model " + bin_file_path);
139     return;
140   }
141 
142   tensorflow::GraphDef graphdef;
143   if (!graphdef.ParseFromIstream(&model)) {
144     Invalidate("Failed to parse tensorflow graphdef");
145     return;
146   }
147 
148   tensorflow::SessionOptions options;
149   session_.reset(tensorflow::NewSession(options));
150   auto status = session_->Create(graphdef);
151   if (!status.ok()) {
152     Invalidate("Failed to create session. " + status.error_message());
153   }
154 }
155 
ReshapeTensor(const string & name,const string & csv_values)156 void TfDriver::ReshapeTensor(const string& name, const string& csv_values) {
157   if (!IsValid()) return;
158   int id = input_name_to_id_[name];
159   input_shapes_[id] = Split<int64_t>(csv_values, ",");
160   input_tensors_[input_names_[id]] =
161       CreateTensor(input_types_[id], input_shapes_[id]);
162   ResetTensor(name);
163 }
164 
ResetTensor(const std::string & name)165 void TfDriver::ResetTensor(const std::string& name) {
166   if (!IsValid()) return;
167   int id = input_name_to_id_[name];
168   auto tensor = input_tensors_[input_names_[id]];
169   switch (input_types_[id]) {
170     case tensorflow::DT_FLOAT: {
171       FillTensorWithZeros<float>(&tensor);
172       break;
173     }
174     case tensorflow::DT_INT32: {
175       FillTensorWithZeros<int32_t>(&tensor);
176       break;
177     }
178     default:
179       Invalidate(absl::StrCat("Unsupported tensor type ", input_types_[id],
180                               tensorflow::DataType_Name(input_types_[id]),
181                               " in ResetInput"));
182       return;
183   }
184 }
// Returns the computed value of output `name`, formatted by
// ReadOutput(const tensorflow::Tensor&). Returns "" and invalidates the
// driver if `name` is not a declared output or Invoke() has not yet produced
// a value for it.
string TfDriver::ReadOutput(const string& name) {
  if (!IsValid()) return "";
  // find() instead of operator[], so an unknown name is reported rather than
  // silently mapped to output 0.
  const auto it = output_name_to_id_.find(name);
  if (it == output_name_to_id_.end()) {
    Invalidate(absl::StrCat("Unknown output tensor ", name, " in ReadOutput"));
    return "";
  }
  const int id = it->second;
  // output_tensors_ is filled by Invoke(). Before that, indexing it would
  // read past the end of the vector.
  if (id < 0 || static_cast<size_t>(id) >= output_tensors_.size()) {
    Invalidate(absl::StrCat("Output tensor ", name,
                            " has not been computed; call Invoke() first"));
    return "";
  }
  return ReadOutput(output_tensors_[id]);
}
Invoke(const std::vector<std::pair<string,string>> & inputs)189 void TfDriver::Invoke(const std::vector<std::pair<string, string>>& inputs) {
190   if (!IsValid()) return;
191   for (const auto& input : inputs) {
192     auto id = input_name_to_id_[input.first];
193     auto tensor = CreateTensor(input_types_[id], input_shapes_[id]);
194     SetInput(input.second, &tensor);
195     input_tensors_[input_names_[id]] = tensor;
196   }
197   auto status = session_->Run({input_tensors_.begin(), input_tensors_.end()},
198                               output_names_, {}, &output_tensors_);
199   if (!status.ok()) {
200     Invalidate(absl::StrCat("TensorFlow failed to run graph:",
201                             status.error_message()));
202   }
203 }
204 
SetInput(const string & values_as_string,tensorflow::Tensor * tensor)205 void TfDriver::SetInput(const string& values_as_string,
206                         tensorflow::Tensor* tensor) {
207   int num_values_available = 0;
208   switch (tensor->dtype()) {
209     case tensorflow::DT_FLOAT:
210       num_values_available =
211           FillTensorWithData<float>(tensor, values_as_string);
212       break;
213     case tensorflow::DT_INT32:
214       num_values_available =
215           FillTensorWithData<int32_t>(tensor, values_as_string);
216       break;
217     case tensorflow::DT_UINT32:
218       num_values_available =
219           FillTensorWithData<uint32_t>(tensor, values_as_string);
220       break;
221     case tensorflow::DT_UINT8:
222       num_values_available =
223           FillTensorWithData<uint8_t>(tensor, values_as_string);
224       break;
225     case tensorflow::DT_STRING:
226       num_values_available =
227           FillTensorWithTfLiteHexString(tensor, values_as_string);
228       break;
229     default:
230       Invalidate(absl::StrCat("Unsupported tensor type ",
231                               tensorflow::DataType_Name(tensor->dtype()),
232                               " in SetInput"));
233       return;
234   }
235 
236   if (tensor->NumElements() != num_values_available) {
237     Invalidate(absl::StrCat("Needed ", tensor->NumElements(),
238                             " values for input tensor, but was given ",
239                             num_values_available, " instead."));
240   }
241 }
242 
ReadOutput(const tensorflow::Tensor & tensor)243 string TfDriver::ReadOutput(const tensorflow::Tensor& tensor) {
244   switch (tensor.dtype()) {
245     case tensorflow::DT_FLOAT:
246       return TensorDataToCsvString<float>(tensor);
247     case tensorflow::DT_INT32:
248       return TensorDataToCsvString<int32_t>(tensor);
249     case tensorflow::DT_UINT32:
250       return TensorDataToCsvString<uint32_t>(tensor);
251     case tensorflow::DT_INT64:
252       return TensorDataToCsvString<int64_t>(tensor);
253     case tensorflow::DT_UINT8:
254       return TensorDataToCsvString<uint8_t>(tensor);
255     case tensorflow::DT_STRING:
256       return TensorDataToTfLiteHexString(tensor);
257     case tensorflow::DT_BOOL:
258       return TensorDataToCsvString<bool>(tensor);
259     default:
260       Invalidate(absl::StrCat("Unsupported tensor type ",
261                               tensorflow::DataType_Name(tensor.dtype()),
262                               " in ReadOutput"));
263       return "";
264   }
265 }
266 
267 }  // namespace testing
268 }  // namespace tflite
269