xref: /aosp_15_r20/external/tensorflow/tensorflow/examples/multibox_detector/main.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <setjmp.h>
17 #include <stdio.h>
18 #include <string.h>
19 
20 #include <cmath>
21 #include <fstream>
22 #include <vector>
23 
24 #include "tensorflow/cc/ops/const_op.h"
25 #include "tensorflow/cc/ops/image_ops.h"
26 #include "tensorflow/cc/ops/standard_ops.h"
27 #include "tensorflow/core/framework/graph.pb.h"
28 #include "tensorflow/core/framework/tensor.h"
29 #include "tensorflow/core/graph/default_device.h"
30 #include "tensorflow/core/graph/graph_def_builder.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/lib/core/stringpiece.h"
33 #include "tensorflow/core/lib/core/threadpool.h"
34 #include "tensorflow/core/lib/io/path.h"
35 #include "tensorflow/core/lib/strings/numbers.h"
36 #include "tensorflow/core/lib/strings/str_util.h"
37 #include "tensorflow/core/lib/strings/stringprintf.h"
38 #include "tensorflow/core/platform/init_main.h"
39 #include "tensorflow/core/platform/logging.h"
40 #include "tensorflow/core/platform/types.h"
41 #include "tensorflow/core/public/session.h"
42 #include "tensorflow/core/util/command_line_flags.h"
43 
44 // These are all common classes it's handy to reference with no namespace.
45 using tensorflow::Flag;
46 using tensorflow::Tensor;
47 using tensorflow::Status;
48 using tensorflow::string;
49 using tensorflow::int32;
50 using tensorflow::uint8;
51 
52 // Takes a file name, and loads a list of comma-separated box priors from it,
53 // one per line, and returns a vector of the values.
ReadLocationsFile(const string & file_name,std::vector<float> * result,size_t * found_label_count)54 Status ReadLocationsFile(const string& file_name, std::vector<float>* result,
55                          size_t* found_label_count) {
56   std::ifstream file(file_name);
57   if (!file) {
58     return tensorflow::errors::NotFound("Labels file ", file_name,
59                                         " not found.");
60   }
61   result->clear();
62   string line;
63   while (std::getline(file, line)) {
64     std::vector<string> string_tokens = tensorflow::str_util::Split(line, ',');
65     result->reserve(string_tokens.size());
66     for (const string& string_token : string_tokens) {
67       float number;
68       CHECK(tensorflow::strings::safe_strtof(string_token, &number));
69       result->push_back(number);
70     }
71   }
72   *found_label_count = result->size();
73   return ::tensorflow::OkStatus();
74 }
75 
76 // Given an image file name, read in the data, try to decode it as an image,
77 // resize it to the requested size, and then scale the values as desired.
ReadTensorFromImageFile(const string & file_name,const int input_height,const int input_width,const float input_mean,const float input_std,std::vector<Tensor> * out_tensors)78 Status ReadTensorFromImageFile(const string& file_name, const int input_height,
79                                const int input_width, const float input_mean,
80                                const float input_std,
81                                std::vector<Tensor>* out_tensors) {
82   auto root = tensorflow::Scope::NewRootScope();
83   using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
84 
85   string input_name = "file_reader";
86   string original_name = "identity";
87   string output_name = "normalized";
88   auto file_reader =
89       tensorflow::ops::ReadFile(root.WithOpName(input_name), file_name);
90   // Now try to figure out what kind of file it is and decode it.
91   const int wanted_channels = 3;
92   tensorflow::Output image_reader;
93   if (tensorflow::str_util::EndsWith(file_name, ".png")) {
94     image_reader = DecodePng(root.WithOpName("png_reader"), file_reader,
95                              DecodePng::Channels(wanted_channels));
96   } else if (tensorflow::str_util::EndsWith(file_name, ".gif")) {
97     image_reader = DecodeGif(root.WithOpName("gif_reader"), file_reader);
98   } else {
99     // Assume if it's neither a PNG nor a GIF then it must be a JPEG.
100     image_reader = DecodeJpeg(root.WithOpName("jpeg_reader"), file_reader,
101                               DecodeJpeg::Channels(wanted_channels));
102   }
103 
104   // Also return identity so that we can know the original dimensions and
105   // optionally save the image out with bounding boxes overlaid.
106   auto original_image = Identity(root.WithOpName(original_name), image_reader);
107 
108   // Now cast the image data to float so we can do normal math on it.
109   auto float_caster = Cast(root.WithOpName("float_caster"), original_image,
110                            tensorflow::DT_FLOAT);
111   // The convention for image ops in TensorFlow is that all images are expected
112   // to be in batches, so that they're four-dimensional arrays with indices of
113   // [batch, height, width, channel]. Because we only have a single image, we
114   // have to add a batch dimension of 1 to the start with ExpandDims().
115   auto dims_expander = ExpandDims(root, float_caster, 0);
116 
117   // Bilinearly resize the image to fit the required dimensions.
118   auto resized = ResizeBilinear(
119       root, dims_expander,
120       Const(root.WithOpName("size"), {input_height, input_width}));
121   // Subtract the mean and divide by the scale.
122   Div(root.WithOpName(output_name), Sub(root, resized, {input_mean}),
123       {input_std});
124 
125   // This runs the GraphDef network definition that we've just constructed, and
126   // returns the results in the output tensor.
127   tensorflow::GraphDef graph;
128   TF_RETURN_IF_ERROR(root.ToGraphDef(&graph));
129 
130   std::unique_ptr<tensorflow::Session> session(
131       tensorflow::NewSession(tensorflow::SessionOptions()));
132   TF_RETURN_IF_ERROR(session->Create(graph));
133   TF_RETURN_IF_ERROR(
134       session->Run({}, {output_name, original_name}, {}, out_tensors));
135   return ::tensorflow::OkStatus();
136 }
137 
SaveImage(const Tensor & tensor,const string & file_path)138 Status SaveImage(const Tensor& tensor, const string& file_path) {
139   LOG(INFO) << "Saving image to " << file_path;
140   CHECK(tensorflow::str_util::EndsWith(file_path, ".png"))
141       << "Only saving of png files is supported.";
142 
143   auto root = tensorflow::Scope::NewRootScope();
144   using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
145 
146   string encoder_name = "encode";
147   string output_name = "file_writer";
148 
149   tensorflow::Output image_encoder =
150       EncodePng(root.WithOpName(encoder_name), tensor);
151   tensorflow::ops::WriteFile file_saver = tensorflow::ops::WriteFile(
152       root.WithOpName(output_name), file_path, image_encoder);
153 
154   tensorflow::GraphDef graph;
155   TF_RETURN_IF_ERROR(root.ToGraphDef(&graph));
156 
157   std::unique_ptr<tensorflow::Session> session(
158       tensorflow::NewSession(tensorflow::SessionOptions()));
159   TF_RETURN_IF_ERROR(session->Create(graph));
160   std::vector<Tensor> outputs;
161   TF_RETURN_IF_ERROR(session->Run({}, {}, {output_name}, &outputs));
162 
163   return ::tensorflow::OkStatus();
164 }
165 
166 // Reads a model graph definition from disk, and creates a session object you
167 // can use to run it.
LoadGraph(const string & graph_file_name,std::unique_ptr<tensorflow::Session> * session)168 Status LoadGraph(const string& graph_file_name,
169                  std::unique_ptr<tensorflow::Session>* session) {
170   tensorflow::GraphDef graph_def;
171   Status load_graph_status =
172       ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def);
173   if (!load_graph_status.ok()) {
174     return tensorflow::errors::NotFound("Failed to load compute graph at '",
175                                         graph_file_name, "'");
176   }
177   session->reset(tensorflow::NewSession(tensorflow::SessionOptions()));
178   Status session_create_status = (*session)->Create(graph_def);
179   if (!session_create_status.ok()) {
180     return session_create_status;
181   }
182   return ::tensorflow::OkStatus();
183 }
184 
185 // Analyzes the output of the MultiBox graph to retrieve the highest scores and
186 // their positions in the tensor, which correspond to individual box detections.
GetTopDetections(const std::vector<Tensor> & outputs,int how_many_labels,Tensor * indices,Tensor * scores)187 Status GetTopDetections(const std::vector<Tensor>& outputs, int how_many_labels,
188                         Tensor* indices, Tensor* scores) {
189   auto root = tensorflow::Scope::NewRootScope();
190   using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
191 
192   string output_name = "top_k";
193   TopK(root.WithOpName(output_name), outputs[0], how_many_labels);
194   // This runs the GraphDef network definition that we've just constructed, and
195   // returns the results in the output tensors.
196   tensorflow::GraphDef graph;
197   TF_RETURN_IF_ERROR(root.ToGraphDef(&graph));
198 
199   std::unique_ptr<tensorflow::Session> session(
200       tensorflow::NewSession(tensorflow::SessionOptions()));
201   TF_RETURN_IF_ERROR(session->Create(graph));
202   // The TopK node returns two outputs, the scores and their original indices,
203   // so we have to append :0 and :1 to specify them both.
204   std::vector<Tensor> out_tensors;
205   TF_RETURN_IF_ERROR(session->Run({}, {output_name + ":0", output_name + ":1"},
206                                   {}, &out_tensors));
207   *scores = out_tensors[0];
208   *indices = out_tensors[1];
209   return ::tensorflow::OkStatus();
210 }
211 
212 // Converts an encoded location to an actual box placement with the provided
213 // box priors.
DecodeLocation(const float * encoded_location,const float * box_priors,float * decoded_location)214 void DecodeLocation(const float* encoded_location, const float* box_priors,
215                     float* decoded_location) {
216   bool non_zero = false;
217   for (int i = 0; i < 4; ++i) {
218     const float curr_encoding = encoded_location[i];
219     non_zero = non_zero || curr_encoding != 0.0f;
220 
221     const float mean = box_priors[i * 2];
222     const float std_dev = box_priors[i * 2 + 1];
223 
224     float currentLocation = curr_encoding * std_dev + mean;
225 
226     currentLocation = std::max(currentLocation, 0.0f);
227     currentLocation = std::min(currentLocation, 1.0f);
228     decoded_location[i] = currentLocation;
229   }
230 
231   if (!non_zero) {
232     LOG(WARNING) << "No non-zero encodings; check log for inference errors.";
233   }
234 }
235 
// Applies the logistic function 1 / (1 + e^-x) to `encoded_score`.
float DecodeScore(float encoded_score) {
  const float denominator = 1 + std::exp(-encoded_score);
  return 1 / denominator;
}
239 
DrawBox(const int image_width,const int image_height,int left,int top,int right,int bottom,tensorflow::TTypes<uint8>::Flat * image)240 void DrawBox(const int image_width, const int image_height, int left, int top,
241              int right, int bottom, tensorflow::TTypes<uint8>::Flat* image) {
242   tensorflow::TTypes<uint8>::Flat image_ref = *image;
243 
244   top = std::max(0, std::min(image_height - 1, top));
245   bottom = std::max(0, std::min(image_height - 1, bottom));
246 
247   left = std::max(0, std::min(image_width - 1, left));
248   right = std::max(0, std::min(image_width - 1, right));
249 
250   for (int i = 0; i < 3; ++i) {
251     uint8 val = i == 2 ? 255 : 0;
252     for (int x = left; x <= right; ++x) {
253       image_ref((top * image_width + x) * 3 + i) = val;
254       image_ref((bottom * image_width + x) * 3 + i) = val;
255     }
256     for (int y = top; y <= bottom; ++y) {
257       image_ref((y * image_width + left) * 3 + i) = val;
258       image_ref((y * image_width + right) * 3 + i) = val;
259     }
260   }
261 }
262 
263 // Given the output of a model run, and the name of a file containing the labels
264 // this prints out the top five highest-scoring values.
PrintTopDetections(const std::vector<Tensor> & outputs,const string & labels_file_name,const int num_boxes,const int num_detections,const string & image_file_name,Tensor * original_tensor)265 Status PrintTopDetections(const std::vector<Tensor>& outputs,
266                           const string& labels_file_name,
267                           const int num_boxes,
268                           const int num_detections,
269                           const string& image_file_name,
270                           Tensor* original_tensor) {
271   std::vector<float> locations;
272   size_t label_count;
273   Status read_labels_status =
274       ReadLocationsFile(labels_file_name, &locations, &label_count);
275   if (!read_labels_status.ok()) {
276     LOG(ERROR) << read_labels_status;
277     return read_labels_status;
278   }
279   CHECK_EQ(label_count, num_boxes * 8);
280 
281   const int how_many_labels =
282       std::min(num_detections, static_cast<int>(label_count));
283   Tensor indices;
284   Tensor scores;
285   TF_RETURN_IF_ERROR(
286       GetTopDetections(outputs, how_many_labels, &indices, &scores));
287 
288   tensorflow::TTypes<float>::Flat scores_flat = scores.flat<float>();
289 
290   tensorflow::TTypes<int32>::Flat indices_flat = indices.flat<int32>();
291 
292   const Tensor& encoded_locations = outputs[1];
293   auto locations_encoded = encoded_locations.flat<float>();
294 
295   LOG(INFO) << original_tensor->DebugString();
296   const int image_width = original_tensor->shape().dim_size(1);
297   const int image_height = original_tensor->shape().dim_size(0);
298 
299   tensorflow::TTypes<uint8>::Flat image_flat = original_tensor->flat<uint8>();
300 
301   LOG(INFO) << "===== Top " << how_many_labels << " Detections ======";
302   for (int pos = 0; pos < how_many_labels; ++pos) {
303     const int label_index = indices_flat(pos);
304     const float score = scores_flat(pos);
305 
306     float decoded_location[4];
307     DecodeLocation(&locations_encoded(label_index * 4),
308                    &locations[label_index * 8], decoded_location);
309 
310     float left = decoded_location[0] * image_width;
311     float top = decoded_location[1] * image_height;
312     float right = decoded_location[2] * image_width;
313     float bottom = decoded_location[3] * image_height;
314 
315     LOG(INFO) << "Detection " << pos << ": "
316               << "L:" << left << " "
317               << "T:" << top << " "
318               << "R:" << right << " "
319               << "B:" << bottom << " "
320               << "(" << label_index << ") score: " << DecodeScore(score);
321 
322     DrawBox(image_width, image_height, left, top, right, bottom, &image_flat);
323   }
324 
325   if (!image_file_name.empty()) {
326     return SaveImage(*original_tensor, image_file_name);
327   }
328   return ::tensorflow::OkStatus();
329 }
330 
main(int argc,char * argv[])331 int main(int argc, char* argv[]) {
332   // These are the command-line flags the program can understand.
333   // They define where the graph and input data is located, and what kind of
334   // input the model expects. If you train your own model, or use something
335   // other than multibox_model you'll need to update these.
336   string image =
337       "tensorflow/examples/multibox_detector/data/surfers.jpg";
338   string graph =
339       "tensorflow/examples/multibox_detector/data/"
340       "multibox_model.pb";
341   string box_priors =
342       "tensorflow/examples/multibox_detector/data/"
343       "multibox_location_priors.txt";
344   int32_t input_width = 224;
345   int32_t input_height = 224;
346   int32_t input_mean = 128;
347   int32_t input_std = 128;
348   int32_t num_detections = 5;
349   int32_t num_boxes = 784;
350   string input_layer = "ResizeBilinear";
351   string output_location_layer = "output_locations/Reshape";
352   string output_score_layer = "output_scores/Reshape";
353   string root_dir = "";
354   string image_out = "";
355 
356   std::vector<Flag> flag_list = {
357       Flag("image", &image, "image to be processed"),
358       Flag("image_out", &image_out,
359            "location to save output image, if desired"),
360       Flag("graph", &graph, "graph to be executed"),
361       Flag("box_priors", &box_priors, "name of file containing box priors"),
362       Flag("input_width", &input_width, "resize image to this width in pixels"),
363       Flag("input_height", &input_height,
364            "resize image to this height in pixels"),
365       Flag("input_mean", &input_mean, "scale pixel values to this mean"),
366       Flag("input_std", &input_std, "scale pixel values to this std deviation"),
367       Flag("num_detections", &num_detections,
368            "number of top detections to return"),
369       Flag("num_boxes", &num_boxes,
370            "number of boxes defined by the location file"),
371       Flag("input_layer", &input_layer, "name of input layer"),
372       Flag("output_location_layer", &output_location_layer,
373            "name of location output layer"),
374       Flag("output_score_layer", &output_score_layer,
375            "name of score output layer"),
376       Flag("root_dir", &root_dir,
377            "interpret image and graph file names relative to this directory"),
378   };
379 
380   string usage = tensorflow::Flags::Usage(argv[0], flag_list);
381   const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
382   if (!parse_result) {
383     LOG(ERROR) << usage;
384     return -1;
385   }
386 
387   // We need to call this to set up global state for TensorFlow.
388   tensorflow::port::InitMain(argv[0], &argc, &argv);
389   if (argc > 1) {
390     LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
391     return -1;
392   }
393 
394   // First we load and initialize the model.
395   std::unique_ptr<tensorflow::Session> session;
396   string graph_path = tensorflow::io::JoinPath(root_dir, graph);
397   Status load_graph_status = LoadGraph(graph_path, &session);
398   if (!load_graph_status.ok()) {
399     LOG(ERROR) << load_graph_status;
400     return -1;
401   }
402 
403   // Get the image from disk as a float array of numbers, resized and normalized
404   // to the specifications the main graph expects.
405   std::vector<Tensor> image_tensors;
406   string image_path = tensorflow::io::JoinPath(root_dir, image);
407 
408   Status read_tensor_status =
409       ReadTensorFromImageFile(image_path, input_height, input_width, input_mean,
410                               input_std, &image_tensors);
411   if (!read_tensor_status.ok()) {
412     LOG(ERROR) << read_tensor_status;
413     return -1;
414   }
415   const Tensor& resized_tensor = image_tensors[0];
416 
417   // Actually run the image through the model.
418   std::vector<Tensor> outputs;
419   Status run_status =
420       session->Run({{input_layer, resized_tensor}},
421                    {output_score_layer, output_location_layer}, {}, &outputs);
422   if (!run_status.ok()) {
423     LOG(ERROR) << "Running model failed: " << run_status;
424     return -1;
425   }
426 
427   Status print_status = PrintTopDetections(outputs, box_priors, num_boxes,
428                                            num_detections, image_out,
429                                            &image_tensors[1]);
430 
431   if (!print_status.ok()) {
432     LOG(ERROR) << "Running print failed: " << print_status;
433     return -1;
434   }
435   return 0;
436 }
437