/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// A minimal but useful C++ example showing how to load an ImageNet-style object
// recognition TensorFlow model, prepare input images for it, run them through
// the graph, and interpret the results.
//
// It's designed to have as few dependencies and be as clear as possible, so
// it's more verbose than it could be in production code. In particular, using
// auto for the types of a lot of the returned values from TensorFlow calls can
// remove a lot of boilerplate, but I find the explicit types useful in sample
// code to make it simple to look up the classes involved.
//
// To use it, compile and then run in a working directory with the
// tensorflow/examples/label_image/data/ folder below it, and you should
// see the top five labels for the example Grace Hopper image printed. You can
// then customize it to use your own models or images by changing the file
// names at the top of the main() function.
//
// The inception_v3_2016_08_28_frozen.pb graph used by default is a frozen
// Inception v3 model.
//
// Note that, for GIF inputs, to reuse existing code, only single-frame ones
// are supported.
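//
// As a rough sketch of an invocation (the Bazel target name and output path
// below are assumptions about a standard TensorFlow checkout; the flag names
// and data paths come from the defaults in main() further down):
//
//   bazel build tensorflow/examples/label_image/...
//   bazel-bin/tensorflow/examples/label_image/label_image \
//     --image=tensorflow/examples/label_image/data/grace_hopper.jpg \
//     --graph=tensorflow/examples/label_image/data/inception_v3_2016_08_28_frozen.pb \
//     --labels=tensorflow/examples/label_image/data/imagenet_slim_labels.txt
//
// All of the flags are optional; the defaults in main() point at the same
// files, so running the binary with no arguments exercises the same path.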

#include <fstream>
#include <utility>
#include <vector>

#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"

// These are all common classes it's handy to reference with no namespace.
using tensorflow::Flag;
using tensorflow::int32;
using tensorflow::Status;
using tensorflow::string;
using tensorflow::Tensor;
using tensorflow::tstring;

// Takes a file name, and loads a list of labels from it, one per line, and
// returns a vector of the strings. It pads with empty strings so the length
// of the result is a multiple of 16, because our model expects that.
Status ReadLabelsFile(const string& file_name, std::vector<string>* result,
                      size_t* found_label_count) {
  std::ifstream file(file_name);
  if (!file) {
    return tensorflow::errors::NotFound("Labels file ", file_name,
                                        " not found.");
  }
  result->clear();
  string line;
  while (std::getline(file, line)) {
    result->push_back(line);
  }
  *found_label_count = result->size();
  const int padding = 16;
  while (result->size() % padding) {
    result->emplace_back();
  }
  return ::tensorflow::OkStatus();
}

static Status ReadEntireFile(tensorflow::Env* env, const string& filename,
                             Tensor* output) {
  tensorflow::uint64 file_size = 0;
  TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));

  string contents;
  contents.resize(file_size);

  std::unique_ptr<tensorflow::RandomAccessFile> file;
  TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));

  tensorflow::StringPiece data;
  TF_RETURN_IF_ERROR(file->Read(0, file_size, &data, &(contents)[0]));
  if (data.size() != file_size) {
    return tensorflow::errors::DataLoss("Truncated read of '", filename,
                                        "' expected ", file_size, " got ",
                                        data.size());
  }
  output->scalar<tstring>()() = tstring(data);
  return ::tensorflow::OkStatus();
}

// Given an image file name, read in the data, try to decode it as an image,
// resize it to the requested size, and then scale the values as desired.
Status ReadTensorFromImageFile(const string& file_name, const int input_height,
                               const int input_width, const float input_mean,
                               const float input_std,
                               std::vector<Tensor>* out_tensors) {
  auto root = tensorflow::Scope::NewRootScope();
  using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)

  string input_name = "file_reader";
  string output_name = "normalized";

  // Read file_name into a tensor named input.
  Tensor input(tensorflow::DT_STRING, tensorflow::TensorShape());
  TF_RETURN_IF_ERROR(
      ReadEntireFile(tensorflow::Env::Default(), file_name, &input));

  // Use a placeholder to read the input data.
  auto file_reader =
      Placeholder(root.WithOpName("input"), tensorflow::DataType::DT_STRING);

  std::vector<std::pair<string, tensorflow::Tensor>> inputs = {
      {"input", input},
  };

  // Now try to figure out what kind of file it is and decode it.
  const int wanted_channels = 3;
  tensorflow::Output image_reader;
  if (tensorflow::str_util::EndsWith(file_name, ".png")) {
    image_reader = DecodePng(root.WithOpName("png_reader"), file_reader,
                             DecodePng::Channels(wanted_channels));
  } else if (tensorflow::str_util::EndsWith(file_name, ".gif")) {
    // The GIF decoder returns a 4-D tensor; remove the first dim.
    image_reader =
        Squeeze(root.WithOpName("squeeze_first_dim"),
                DecodeGif(root.WithOpName("gif_reader"), file_reader));
  } else if (tensorflow::str_util::EndsWith(file_name, ".bmp")) {
    image_reader = DecodeBmp(root.WithOpName("bmp_reader"), file_reader);
  } else {
    // Assume if it's neither a PNG, a GIF, nor a BMP, then it must be a JPEG.
    image_reader = DecodeJpeg(root.WithOpName("jpeg_reader"), file_reader,
                              DecodeJpeg::Channels(wanted_channels));
  }
  // Now cast the image data to float so we can do normal math on it.
  auto float_caster =
      Cast(root.WithOpName("float_caster"), image_reader, tensorflow::DT_FLOAT);
  // The convention for image ops in TensorFlow is that all images are expected
  // to be in batches, so that they're four-dimensional arrays with indices of
  // [batch, height, width, channel]. Because we only have a single image, we
  // have to add a batch dimension of 1 to the start with ExpandDims().
  auto dims_expander = ExpandDims(root, float_caster, 0);
  // Bilinearly resize the image to fit the required dimensions.
  auto resized = ResizeBilinear(
      root, dims_expander,
      Const(root.WithOpName("size"), {input_height, input_width}));
  // Subtract the mean and divide by the scale.
  Div(root.WithOpName(output_name), Sub(root, resized, {input_mean}),
      {input_std});

  // This runs the GraphDef network definition that we've just constructed, and
  // returns the results in the output tensor.
  tensorflow::GraphDef graph;
  TF_RETURN_IF_ERROR(root.ToGraphDef(&graph));

  std::unique_ptr<tensorflow::Session> session(
      tensorflow::NewSession(tensorflow::SessionOptions()));
  TF_RETURN_IF_ERROR(session->Create(graph));
  TF_RETURN_IF_ERROR(session->Run({inputs}, {output_name}, {}, out_tensors));
  return ::tensorflow::OkStatus();
}

// Reads a model graph definition from disk, and creates a session object you
// can use to run it.
Status LoadGraph(const string& graph_file_name,
                 std::unique_ptr<tensorflow::Session>* session) {
  tensorflow::GraphDef graph_def;
  Status load_graph_status =
      ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def);
  if (!load_graph_status.ok()) {
    return tensorflow::errors::NotFound("Failed to load compute graph at '",
                                        graph_file_name, "'");
  }
  session->reset(tensorflow::NewSession(tensorflow::SessionOptions()));
  Status session_create_status = (*session)->Create(graph_def);
  if (!session_create_status.ok()) {
    return session_create_status;
  }
  return ::tensorflow::OkStatus();
}

// Analyzes the output of the Inception graph to retrieve the highest scores and
// their positions in the tensor, which correspond to categories.
Status GetTopLabels(const std::vector<Tensor>& outputs, int how_many_labels,
                    Tensor* indices, Tensor* scores) {
  auto root = tensorflow::Scope::NewRootScope();
  using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)

  string output_name = "top_k";
  TopK(root.WithOpName(output_name), outputs[0], how_many_labels);
  // This runs the GraphDef network definition that we've just constructed, and
  // returns the results in the output tensors.
  tensorflow::GraphDef graph;
  TF_RETURN_IF_ERROR(root.ToGraphDef(&graph));

  std::unique_ptr<tensorflow::Session> session(
      tensorflow::NewSession(tensorflow::SessionOptions()));
  TF_RETURN_IF_ERROR(session->Create(graph));
  // The TopK node returns two outputs, the scores and their original indices,
  // so we have to append :0 and :1 to specify them both.
  std::vector<Tensor> out_tensors;
  TF_RETURN_IF_ERROR(session->Run({}, {output_name + ":0", output_name + ":1"},
                                  {}, &out_tensors));
  *scores = out_tensors[0];
  *indices = out_tensors[1];
  return ::tensorflow::OkStatus();
}

// Given the output of a model run and the name of a file containing the
// labels, this prints out the top five highest-scoring values.
Status PrintTopLabels(const std::vector<Tensor>& outputs,
                      const string& labels_file_name) {
  std::vector<string> labels;
  size_t label_count;
  Status read_labels_status =
      ReadLabelsFile(labels_file_name, &labels, &label_count);
  if (!read_labels_status.ok()) {
    LOG(ERROR) << read_labels_status;
    return read_labels_status;
  }
  const int how_many_labels = std::min(5, static_cast<int>(label_count));
  Tensor indices;
  Tensor scores;
  TF_RETURN_IF_ERROR(GetTopLabels(outputs, how_many_labels, &indices, &scores));
  tensorflow::TTypes<float>::Flat scores_flat = scores.flat<float>();
  tensorflow::TTypes<int32>::Flat indices_flat = indices.flat<int32>();
  for (int pos = 0; pos < how_many_labels; ++pos) {
    const int label_index = indices_flat(pos);
    const float score = scores_flat(pos);
    LOG(INFO) << labels[label_index] << " (" << label_index << "): " << score;
  }
  return ::tensorflow::OkStatus();
}

// This is a testing function that returns whether the top label index is the
// one that's expected.
Status CheckTopLabel(const std::vector<Tensor>& outputs, int expected,
                     bool* is_expected) {
  *is_expected = false;
  Tensor indices;
  Tensor scores;
  const int how_many_labels = 1;
  TF_RETURN_IF_ERROR(GetTopLabels(outputs, how_many_labels, &indices, &scores));
  tensorflow::TTypes<int32>::Flat indices_flat = indices.flat<int32>();
  if (indices_flat(0) != expected) {
    LOG(ERROR) << "Expected label #" << expected << " but got #"
               << indices_flat(0);
    *is_expected = false;
  } else {
    *is_expected = true;
  }
  return ::tensorflow::OkStatus();
}

int main(int argc, char* argv[]) {
  // These are the command-line flags the program can understand.
  // They define where the graph and input data is located, and what kind of
  // input the model expects. If you train your own model, or use something
  // other than inception_v3, then you'll need to update these.
  string image = "tensorflow/examples/label_image/data/grace_hopper.jpg";
  string graph =
      "tensorflow/examples/label_image/data/inception_v3_2016_08_28_frozen.pb";
  string labels =
      "tensorflow/examples/label_image/data/imagenet_slim_labels.txt";
  int32_t input_width = 299;
  int32_t input_height = 299;
  float input_mean = 0;
  float input_std = 255;
  string input_layer = "input";
  string output_layer = "InceptionV3/Predictions/Reshape_1";
  bool self_test = false;
  string root_dir = "";
  std::vector<Flag> flag_list = {
      Flag("image", &image, "image to be processed"),
      Flag("graph", &graph, "graph to be executed"),
      Flag("labels", &labels, "name of file containing labels"),
      Flag("input_width", &input_width, "resize image to this width in pixels"),
      Flag("input_height", &input_height,
           "resize image to this height in pixels"),
      Flag("input_mean", &input_mean, "scale pixel values to this mean"),
      Flag("input_std", &input_std, "scale pixel values to this std deviation"),
      Flag("input_layer", &input_layer, "name of input layer"),
      Flag("output_layer", &output_layer, "name of output layer"),
      Flag("self_test", &self_test, "run a self test"),
      Flag("root_dir", &root_dir,
           "interpret image and graph file names relative to this directory"),
  };
  string usage = tensorflow::Flags::Usage(argv[0], flag_list);
  const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
  if (!parse_result) {
    LOG(ERROR) << usage;
    return -1;
  }

  // We need to call this to set up global state for TensorFlow.
  tensorflow::port::InitMain(argv[0], &argc, &argv);
  if (argc > 1) {
    LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
    return -1;
  }

  // First we load and initialize the model.
  std::unique_ptr<tensorflow::Session> session;
  string graph_path = tensorflow::io::JoinPath(root_dir, graph);
  Status load_graph_status = LoadGraph(graph_path, &session);
  if (!load_graph_status.ok()) {
    LOG(ERROR) << load_graph_status;
    return -1;
  }

  // Get the image from disk as a float array of numbers, resized and normalized
  // to the specifications the main graph expects.
  std::vector<Tensor> resized_tensors;
  string image_path = tensorflow::io::JoinPath(root_dir, image);
  Status read_tensor_status =
      ReadTensorFromImageFile(image_path, input_height, input_width, input_mean,
                              input_std, &resized_tensors);
  if (!read_tensor_status.ok()) {
    LOG(ERROR) << read_tensor_status;
    return -1;
  }
  const Tensor& resized_tensor = resized_tensors[0];

  // Actually run the image through the model.
  std::vector<Tensor> outputs;
  Status run_status = session->Run({{input_layer, resized_tensor}},
                                   {output_layer}, {}, &outputs);
  if (!run_status.ok()) {
    LOG(ERROR) << "Running model failed: " << run_status;
    return -1;
  }

  // This is for automated testing to make sure we get the expected result with
  // the default settings. We know that label 653 (military uniform) should be
  // the top label for the Admiral Hopper image.
  if (self_test) {
    bool expected_matches;
    Status check_status = CheckTopLabel(outputs, 653, &expected_matches);
    if (!check_status.ok()) {
      LOG(ERROR) << "Running check failed: " << check_status;
      return -1;
    }
    if (!expected_matches) {
      LOG(ERROR) << "Self-test failed!";
      return -1;
    }
  }

  // Do something interesting with the results we've generated.
  Status print_status = PrintTopLabels(outputs, labels);
  if (!print_status.ok()) {
    LOG(ERROR) << "Running print failed: " << print_status;
    return -1;
  }

  return 0;
}