1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <setjmp.h>
17 #include <stdio.h>
18 #include <string.h>
19
20 #include <cmath>
21 #include <fstream>
22 #include <vector>
23
24 #include "tensorflow/cc/ops/const_op.h"
25 #include "tensorflow/cc/ops/image_ops.h"
26 #include "tensorflow/cc/ops/standard_ops.h"
27 #include "tensorflow/core/framework/graph.pb.h"
28 #include "tensorflow/core/framework/tensor.h"
29 #include "tensorflow/core/graph/default_device.h"
30 #include "tensorflow/core/graph/graph_def_builder.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/lib/core/stringpiece.h"
33 #include "tensorflow/core/lib/core/threadpool.h"
34 #include "tensorflow/core/lib/io/path.h"
35 #include "tensorflow/core/lib/strings/numbers.h"
36 #include "tensorflow/core/lib/strings/str_util.h"
37 #include "tensorflow/core/lib/strings/stringprintf.h"
38 #include "tensorflow/core/platform/init_main.h"
39 #include "tensorflow/core/platform/logging.h"
40 #include "tensorflow/core/platform/types.h"
41 #include "tensorflow/core/public/session.h"
42 #include "tensorflow/core/util/command_line_flags.h"
43
44 // These are all common classes it's handy to reference with no namespace.
45 using tensorflow::Flag;
46 using tensorflow::Tensor;
47 using tensorflow::Status;
48 using tensorflow::string;
49 using tensorflow::int32;
50 using tensorflow::uint8;
51
52 // Takes a file name, and loads a list of comma-separated box priors from it,
53 // one per line, and returns a vector of the values.
ReadLocationsFile(const string & file_name,std::vector<float> * result,size_t * found_label_count)54 Status ReadLocationsFile(const string& file_name, std::vector<float>* result,
55 size_t* found_label_count) {
56 std::ifstream file(file_name);
57 if (!file) {
58 return tensorflow::errors::NotFound("Labels file ", file_name,
59 " not found.");
60 }
61 result->clear();
62 string line;
63 while (std::getline(file, line)) {
64 std::vector<string> string_tokens = tensorflow::str_util::Split(line, ',');
65 result->reserve(string_tokens.size());
66 for (const string& string_token : string_tokens) {
67 float number;
68 CHECK(tensorflow::strings::safe_strtof(string_token, &number));
69 result->push_back(number);
70 }
71 }
72 *found_label_count = result->size();
73 return ::tensorflow::OkStatus();
74 }
75
76 // Given an image file name, read in the data, try to decode it as an image,
77 // resize it to the requested size, and then scale the values as desired.
ReadTensorFromImageFile(const string & file_name,const int input_height,const int input_width,const float input_mean,const float input_std,std::vector<Tensor> * out_tensors)78 Status ReadTensorFromImageFile(const string& file_name, const int input_height,
79 const int input_width, const float input_mean,
80 const float input_std,
81 std::vector<Tensor>* out_tensors) {
82 auto root = tensorflow::Scope::NewRootScope();
83 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
84
85 string input_name = "file_reader";
86 string original_name = "identity";
87 string output_name = "normalized";
88 auto file_reader =
89 tensorflow::ops::ReadFile(root.WithOpName(input_name), file_name);
90 // Now try to figure out what kind of file it is and decode it.
91 const int wanted_channels = 3;
92 tensorflow::Output image_reader;
93 if (tensorflow::str_util::EndsWith(file_name, ".png")) {
94 image_reader = DecodePng(root.WithOpName("png_reader"), file_reader,
95 DecodePng::Channels(wanted_channels));
96 } else if (tensorflow::str_util::EndsWith(file_name, ".gif")) {
97 image_reader = DecodeGif(root.WithOpName("gif_reader"), file_reader);
98 } else {
99 // Assume if it's neither a PNG nor a GIF then it must be a JPEG.
100 image_reader = DecodeJpeg(root.WithOpName("jpeg_reader"), file_reader,
101 DecodeJpeg::Channels(wanted_channels));
102 }
103
104 // Also return identity so that we can know the original dimensions and
105 // optionally save the image out with bounding boxes overlaid.
106 auto original_image = Identity(root.WithOpName(original_name), image_reader);
107
108 // Now cast the image data to float so we can do normal math on it.
109 auto float_caster = Cast(root.WithOpName("float_caster"), original_image,
110 tensorflow::DT_FLOAT);
111 // The convention for image ops in TensorFlow is that all images are expected
112 // to be in batches, so that they're four-dimensional arrays with indices of
113 // [batch, height, width, channel]. Because we only have a single image, we
114 // have to add a batch dimension of 1 to the start with ExpandDims().
115 auto dims_expander = ExpandDims(root, float_caster, 0);
116
117 // Bilinearly resize the image to fit the required dimensions.
118 auto resized = ResizeBilinear(
119 root, dims_expander,
120 Const(root.WithOpName("size"), {input_height, input_width}));
121 // Subtract the mean and divide by the scale.
122 Div(root.WithOpName(output_name), Sub(root, resized, {input_mean}),
123 {input_std});
124
125 // This runs the GraphDef network definition that we've just constructed, and
126 // returns the results in the output tensor.
127 tensorflow::GraphDef graph;
128 TF_RETURN_IF_ERROR(root.ToGraphDef(&graph));
129
130 std::unique_ptr<tensorflow::Session> session(
131 tensorflow::NewSession(tensorflow::SessionOptions()));
132 TF_RETURN_IF_ERROR(session->Create(graph));
133 TF_RETURN_IF_ERROR(
134 session->Run({}, {output_name, original_name}, {}, out_tensors));
135 return ::tensorflow::OkStatus();
136 }
137
SaveImage(const Tensor & tensor,const string & file_path)138 Status SaveImage(const Tensor& tensor, const string& file_path) {
139 LOG(INFO) << "Saving image to " << file_path;
140 CHECK(tensorflow::str_util::EndsWith(file_path, ".png"))
141 << "Only saving of png files is supported.";
142
143 auto root = tensorflow::Scope::NewRootScope();
144 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
145
146 string encoder_name = "encode";
147 string output_name = "file_writer";
148
149 tensorflow::Output image_encoder =
150 EncodePng(root.WithOpName(encoder_name), tensor);
151 tensorflow::ops::WriteFile file_saver = tensorflow::ops::WriteFile(
152 root.WithOpName(output_name), file_path, image_encoder);
153
154 tensorflow::GraphDef graph;
155 TF_RETURN_IF_ERROR(root.ToGraphDef(&graph));
156
157 std::unique_ptr<tensorflow::Session> session(
158 tensorflow::NewSession(tensorflow::SessionOptions()));
159 TF_RETURN_IF_ERROR(session->Create(graph));
160 std::vector<Tensor> outputs;
161 TF_RETURN_IF_ERROR(session->Run({}, {}, {output_name}, &outputs));
162
163 return ::tensorflow::OkStatus();
164 }
165
166 // Reads a model graph definition from disk, and creates a session object you
167 // can use to run it.
LoadGraph(const string & graph_file_name,std::unique_ptr<tensorflow::Session> * session)168 Status LoadGraph(const string& graph_file_name,
169 std::unique_ptr<tensorflow::Session>* session) {
170 tensorflow::GraphDef graph_def;
171 Status load_graph_status =
172 ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def);
173 if (!load_graph_status.ok()) {
174 return tensorflow::errors::NotFound("Failed to load compute graph at '",
175 graph_file_name, "'");
176 }
177 session->reset(tensorflow::NewSession(tensorflow::SessionOptions()));
178 Status session_create_status = (*session)->Create(graph_def);
179 if (!session_create_status.ok()) {
180 return session_create_status;
181 }
182 return ::tensorflow::OkStatus();
183 }
184
185 // Analyzes the output of the MultiBox graph to retrieve the highest scores and
186 // their positions in the tensor, which correspond to individual box detections.
GetTopDetections(const std::vector<Tensor> & outputs,int how_many_labels,Tensor * indices,Tensor * scores)187 Status GetTopDetections(const std::vector<Tensor>& outputs, int how_many_labels,
188 Tensor* indices, Tensor* scores) {
189 auto root = tensorflow::Scope::NewRootScope();
190 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
191
192 string output_name = "top_k";
193 TopK(root.WithOpName(output_name), outputs[0], how_many_labels);
194 // This runs the GraphDef network definition that we've just constructed, and
195 // returns the results in the output tensors.
196 tensorflow::GraphDef graph;
197 TF_RETURN_IF_ERROR(root.ToGraphDef(&graph));
198
199 std::unique_ptr<tensorflow::Session> session(
200 tensorflow::NewSession(tensorflow::SessionOptions()));
201 TF_RETURN_IF_ERROR(session->Create(graph));
202 // The TopK node returns two outputs, the scores and their original indices,
203 // so we have to append :0 and :1 to specify them both.
204 std::vector<Tensor> out_tensors;
205 TF_RETURN_IF_ERROR(session->Run({}, {output_name + ":0", output_name + ":1"},
206 {}, &out_tensors));
207 *scores = out_tensors[0];
208 *indices = out_tensors[1];
209 return ::tensorflow::OkStatus();
210 }
211
212 // Converts an encoded location to an actual box placement with the provided
213 // box priors.
DecodeLocation(const float * encoded_location,const float * box_priors,float * decoded_location)214 void DecodeLocation(const float* encoded_location, const float* box_priors,
215 float* decoded_location) {
216 bool non_zero = false;
217 for (int i = 0; i < 4; ++i) {
218 const float curr_encoding = encoded_location[i];
219 non_zero = non_zero || curr_encoding != 0.0f;
220
221 const float mean = box_priors[i * 2];
222 const float std_dev = box_priors[i * 2 + 1];
223
224 float currentLocation = curr_encoding * std_dev + mean;
225
226 currentLocation = std::max(currentLocation, 0.0f);
227 currentLocation = std::min(currentLocation, 1.0f);
228 decoded_location[i] = currentLocation;
229 }
230
231 if (!non_zero) {
232 LOG(WARNING) << "No non-zero encodings; check log for inference errors.";
233 }
234 }
235
// Maps a raw score to the range [0, 1] with the logistic (sigmoid) function,
// 1 / (1 + e^-x).
float DecodeScore(float encoded_score) {
  const float denominator = 1.0f + std::exp(-encoded_score);
  return 1.0f / denominator;
}
239
DrawBox(const int image_width,const int image_height,int left,int top,int right,int bottom,tensorflow::TTypes<uint8>::Flat * image)240 void DrawBox(const int image_width, const int image_height, int left, int top,
241 int right, int bottom, tensorflow::TTypes<uint8>::Flat* image) {
242 tensorflow::TTypes<uint8>::Flat image_ref = *image;
243
244 top = std::max(0, std::min(image_height - 1, top));
245 bottom = std::max(0, std::min(image_height - 1, bottom));
246
247 left = std::max(0, std::min(image_width - 1, left));
248 right = std::max(0, std::min(image_width - 1, right));
249
250 for (int i = 0; i < 3; ++i) {
251 uint8 val = i == 2 ? 255 : 0;
252 for (int x = left; x <= right; ++x) {
253 image_ref((top * image_width + x) * 3 + i) = val;
254 image_ref((bottom * image_width + x) * 3 + i) = val;
255 }
256 for (int y = top; y <= bottom; ++y) {
257 image_ref((y * image_width + left) * 3 + i) = val;
258 image_ref((y * image_width + right) * 3 + i) = val;
259 }
260 }
261 }
262
263 // Given the output of a model run, and the name of a file containing the labels
264 // this prints out the top five highest-scoring values.
PrintTopDetections(const std::vector<Tensor> & outputs,const string & labels_file_name,const int num_boxes,const int num_detections,const string & image_file_name,Tensor * original_tensor)265 Status PrintTopDetections(const std::vector<Tensor>& outputs,
266 const string& labels_file_name,
267 const int num_boxes,
268 const int num_detections,
269 const string& image_file_name,
270 Tensor* original_tensor) {
271 std::vector<float> locations;
272 size_t label_count;
273 Status read_labels_status =
274 ReadLocationsFile(labels_file_name, &locations, &label_count);
275 if (!read_labels_status.ok()) {
276 LOG(ERROR) << read_labels_status;
277 return read_labels_status;
278 }
279 CHECK_EQ(label_count, num_boxes * 8);
280
281 const int how_many_labels =
282 std::min(num_detections, static_cast<int>(label_count));
283 Tensor indices;
284 Tensor scores;
285 TF_RETURN_IF_ERROR(
286 GetTopDetections(outputs, how_many_labels, &indices, &scores));
287
288 tensorflow::TTypes<float>::Flat scores_flat = scores.flat<float>();
289
290 tensorflow::TTypes<int32>::Flat indices_flat = indices.flat<int32>();
291
292 const Tensor& encoded_locations = outputs[1];
293 auto locations_encoded = encoded_locations.flat<float>();
294
295 LOG(INFO) << original_tensor->DebugString();
296 const int image_width = original_tensor->shape().dim_size(1);
297 const int image_height = original_tensor->shape().dim_size(0);
298
299 tensorflow::TTypes<uint8>::Flat image_flat = original_tensor->flat<uint8>();
300
301 LOG(INFO) << "===== Top " << how_many_labels << " Detections ======";
302 for (int pos = 0; pos < how_many_labels; ++pos) {
303 const int label_index = indices_flat(pos);
304 const float score = scores_flat(pos);
305
306 float decoded_location[4];
307 DecodeLocation(&locations_encoded(label_index * 4),
308 &locations[label_index * 8], decoded_location);
309
310 float left = decoded_location[0] * image_width;
311 float top = decoded_location[1] * image_height;
312 float right = decoded_location[2] * image_width;
313 float bottom = decoded_location[3] * image_height;
314
315 LOG(INFO) << "Detection " << pos << ": "
316 << "L:" << left << " "
317 << "T:" << top << " "
318 << "R:" << right << " "
319 << "B:" << bottom << " "
320 << "(" << label_index << ") score: " << DecodeScore(score);
321
322 DrawBox(image_width, image_height, left, top, right, bottom, &image_flat);
323 }
324
325 if (!image_file_name.empty()) {
326 return SaveImage(*original_tensor, image_file_name);
327 }
328 return ::tensorflow::OkStatus();
329 }
330
main(int argc,char * argv[])331 int main(int argc, char* argv[]) {
332 // These are the command-line flags the program can understand.
333 // They define where the graph and input data is located, and what kind of
334 // input the model expects. If you train your own model, or use something
335 // other than multibox_model you'll need to update these.
336 string image =
337 "tensorflow/examples/multibox_detector/data/surfers.jpg";
338 string graph =
339 "tensorflow/examples/multibox_detector/data/"
340 "multibox_model.pb";
341 string box_priors =
342 "tensorflow/examples/multibox_detector/data/"
343 "multibox_location_priors.txt";
344 int32_t input_width = 224;
345 int32_t input_height = 224;
346 int32_t input_mean = 128;
347 int32_t input_std = 128;
348 int32_t num_detections = 5;
349 int32_t num_boxes = 784;
350 string input_layer = "ResizeBilinear";
351 string output_location_layer = "output_locations/Reshape";
352 string output_score_layer = "output_scores/Reshape";
353 string root_dir = "";
354 string image_out = "";
355
356 std::vector<Flag> flag_list = {
357 Flag("image", &image, "image to be processed"),
358 Flag("image_out", &image_out,
359 "location to save output image, if desired"),
360 Flag("graph", &graph, "graph to be executed"),
361 Flag("box_priors", &box_priors, "name of file containing box priors"),
362 Flag("input_width", &input_width, "resize image to this width in pixels"),
363 Flag("input_height", &input_height,
364 "resize image to this height in pixels"),
365 Flag("input_mean", &input_mean, "scale pixel values to this mean"),
366 Flag("input_std", &input_std, "scale pixel values to this std deviation"),
367 Flag("num_detections", &num_detections,
368 "number of top detections to return"),
369 Flag("num_boxes", &num_boxes,
370 "number of boxes defined by the location file"),
371 Flag("input_layer", &input_layer, "name of input layer"),
372 Flag("output_location_layer", &output_location_layer,
373 "name of location output layer"),
374 Flag("output_score_layer", &output_score_layer,
375 "name of score output layer"),
376 Flag("root_dir", &root_dir,
377 "interpret image and graph file names relative to this directory"),
378 };
379
380 string usage = tensorflow::Flags::Usage(argv[0], flag_list);
381 const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
382 if (!parse_result) {
383 LOG(ERROR) << usage;
384 return -1;
385 }
386
387 // We need to call this to set up global state for TensorFlow.
388 tensorflow::port::InitMain(argv[0], &argc, &argv);
389 if (argc > 1) {
390 LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
391 return -1;
392 }
393
394 // First we load and initialize the model.
395 std::unique_ptr<tensorflow::Session> session;
396 string graph_path = tensorflow::io::JoinPath(root_dir, graph);
397 Status load_graph_status = LoadGraph(graph_path, &session);
398 if (!load_graph_status.ok()) {
399 LOG(ERROR) << load_graph_status;
400 return -1;
401 }
402
403 // Get the image from disk as a float array of numbers, resized and normalized
404 // to the specifications the main graph expects.
405 std::vector<Tensor> image_tensors;
406 string image_path = tensorflow::io::JoinPath(root_dir, image);
407
408 Status read_tensor_status =
409 ReadTensorFromImageFile(image_path, input_height, input_width, input_mean,
410 input_std, &image_tensors);
411 if (!read_tensor_status.ok()) {
412 LOG(ERROR) << read_tensor_status;
413 return -1;
414 }
415 const Tensor& resized_tensor = image_tensors[0];
416
417 // Actually run the image through the model.
418 std::vector<Tensor> outputs;
419 Status run_status =
420 session->Run({{input_layer, resized_tensor}},
421 {output_score_layer, output_location_layer}, {}, &outputs);
422 if (!run_status.ok()) {
423 LOG(ERROR) << "Running model failed: " << run_status;
424 return -1;
425 }
426
427 Status print_status = PrintTopDetections(outputs, box_priors, num_boxes,
428 num_detections, image_out,
429 &image_tensors[1]);
430
431 if (!print_status.ok()) {
432 LOG(ERROR) << "Running print failed: " << print_status;
433 return -1;
434 }
435 return 0;
436 }
437