1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // A minimal but useful C++ example showing how to load an Imagenet-style object
17 // recognition TensorFlow model, prepare input images for it, run them through
18 // the graph, and interpret the results.
19 //
20 // It's designed to have as few dependencies and be as clear as possible, so
21 // it's more verbose than it could be in production code. In particular, using
22 // auto for the types of a lot of the returned values from TensorFlow calls can
23 // remove a lot of boilerplate, but I find the explicit types useful in sample
24 // code to make it simple to look up the classes involved.
25 //
26 // To use it, compile and then run in a working directory with the
27 // learning/brain/tutorials/label_image/data/ folder below it, and you should
28 // see the top five labels for the example Lena image output. You can then
29 // customize it to use your own models or images by changing the file names at
30 // the top of the main() function.
31 //
32 // The googlenet_graph.pb file included by default is created from Inception.
33 //
34 // Note that, for GIF inputs, to reuse existing code, only single-frame ones
35 // are supported.
36
37 #include <fstream>
38 #include <utility>
39 #include <vector>
40
41 #include "tensorflow/cc/ops/const_op.h"
42 #include "tensorflow/cc/ops/image_ops.h"
43 #include "tensorflow/cc/ops/standard_ops.h"
44 #include "tensorflow/core/framework/graph.pb.h"
45 #include "tensorflow/core/framework/tensor.h"
46 #include "tensorflow/core/graph/default_device.h"
47 #include "tensorflow/core/graph/graph_def_builder.h"
48 #include "tensorflow/core/lib/core/errors.h"
49 #include "tensorflow/core/lib/core/stringpiece.h"
50 #include "tensorflow/core/lib/core/threadpool.h"
51 #include "tensorflow/core/lib/io/path.h"
52 #include "tensorflow/core/lib/strings/str_util.h"
53 #include "tensorflow/core/lib/strings/stringprintf.h"
54 #include "tensorflow/core/platform/env.h"
55 #include "tensorflow/core/platform/init_main.h"
56 #include "tensorflow/core/platform/logging.h"
57 #include "tensorflow/core/platform/types.h"
58 #include "tensorflow/core/public/session.h"
59 #include "tensorflow/core/util/command_line_flags.h"
60
61 // These are all common classes it's handy to reference with no namespace.
62 using tensorflow::Flag;
63 using tensorflow::int32;
64 using tensorflow::Status;
65 using tensorflow::string;
66 using tensorflow::Tensor;
67 using tensorflow::tstring;
68
69 // Takes a file name, and loads a list of labels from it, one per line, and
70 // returns a vector of the strings. It pads with empty strings so the length
71 // of the result is a multiple of 16, because our model expects that.
ReadLabelsFile(const string & file_name,std::vector<string> * result,size_t * found_label_count)72 Status ReadLabelsFile(const string& file_name, std::vector<string>* result,
73 size_t* found_label_count) {
74 std::ifstream file(file_name);
75 if (!file) {
76 return tensorflow::errors::NotFound("Labels file ", file_name,
77 " not found.");
78 }
79 result->clear();
80 string line;
81 while (std::getline(file, line)) {
82 result->push_back(line);
83 }
84 *found_label_count = result->size();
85 const int padding = 16;
86 while (result->size() % padding) {
87 result->emplace_back();
88 }
89 return ::tensorflow::OkStatus();
90 }
91
ReadEntireFile(tensorflow::Env * env,const string & filename,Tensor * output)92 static Status ReadEntireFile(tensorflow::Env* env, const string& filename,
93 Tensor* output) {
94 tensorflow::uint64 file_size = 0;
95 TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));
96
97 string contents;
98 contents.resize(file_size);
99
100 std::unique_ptr<tensorflow::RandomAccessFile> file;
101 TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));
102
103 tensorflow::StringPiece data;
104 TF_RETURN_IF_ERROR(file->Read(0, file_size, &data, &(contents)[0]));
105 if (data.size() != file_size) {
106 return tensorflow::errors::DataLoss("Truncated read of '", filename,
107 "' expected ", file_size, " got ",
108 data.size());
109 }
110 output->scalar<tstring>()() = tstring(data);
111 return ::tensorflow::OkStatus();
112 }
113
114 // Given an image file name, read in the data, try to decode it as an image,
115 // resize it to the requested size, and then scale the values as desired.
ReadTensorFromImageFile(const string & file_name,const int input_height,const int input_width,const float input_mean,const float input_std,std::vector<Tensor> * out_tensors)116 Status ReadTensorFromImageFile(const string& file_name, const int input_height,
117 const int input_width, const float input_mean,
118 const float input_std,
119 std::vector<Tensor>* out_tensors) {
120 auto root = tensorflow::Scope::NewRootScope();
121 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
122
123 string input_name = "file_reader";
124 string output_name = "normalized";
125
126 // read file_name into a tensor named input
127 Tensor input(tensorflow::DT_STRING, tensorflow::TensorShape());
128 TF_RETURN_IF_ERROR(
129 ReadEntireFile(tensorflow::Env::Default(), file_name, &input));
130
131 // use a placeholder to read input data
132 auto file_reader =
133 Placeholder(root.WithOpName("input"), tensorflow::DataType::DT_STRING);
134
135 std::vector<std::pair<string, tensorflow::Tensor>> inputs = {
136 {"input", input},
137 };
138
139 // Now try to figure out what kind of file it is and decode it.
140 const int wanted_channels = 3;
141 tensorflow::Output image_reader;
142 if (tensorflow::str_util::EndsWith(file_name, ".png")) {
143 image_reader = DecodePng(root.WithOpName("png_reader"), file_reader,
144 DecodePng::Channels(wanted_channels));
145 } else if (tensorflow::str_util::EndsWith(file_name, ".gif")) {
146 // gif decoder returns 4-D tensor, remove the first dim
147 image_reader =
148 Squeeze(root.WithOpName("squeeze_first_dim"),
149 DecodeGif(root.WithOpName("gif_reader"), file_reader));
150 } else if (tensorflow::str_util::EndsWith(file_name, ".bmp")) {
151 image_reader = DecodeBmp(root.WithOpName("bmp_reader"), file_reader);
152 } else {
153 // Assume if it's neither a PNG nor a GIF then it must be a JPEG.
154 image_reader = DecodeJpeg(root.WithOpName("jpeg_reader"), file_reader,
155 DecodeJpeg::Channels(wanted_channels));
156 }
157 // Now cast the image data to float so we can do normal math on it.
158 auto float_caster =
159 Cast(root.WithOpName("float_caster"), image_reader, tensorflow::DT_FLOAT);
160 // The convention for image ops in TensorFlow is that all images are expected
161 // to be in batches, so that they're four-dimensional arrays with indices of
162 // [batch, height, width, channel]. Because we only have a single image, we
163 // have to add a batch dimension of 1 to the start with ExpandDims().
164 auto dims_expander = ExpandDims(root, float_caster, 0);
165 // Bilinearly resize the image to fit the required dimensions.
166 auto resized = ResizeBilinear(
167 root, dims_expander,
168 Const(root.WithOpName("size"), {input_height, input_width}));
169 // Subtract the mean and divide by the scale.
170 Div(root.WithOpName(output_name), Sub(root, resized, {input_mean}),
171 {input_std});
172
173 // This runs the GraphDef network definition that we've just constructed, and
174 // returns the results in the output tensor.
175 tensorflow::GraphDef graph;
176 TF_RETURN_IF_ERROR(root.ToGraphDef(&graph));
177
178 std::unique_ptr<tensorflow::Session> session(
179 tensorflow::NewSession(tensorflow::SessionOptions()));
180 TF_RETURN_IF_ERROR(session->Create(graph));
181 TF_RETURN_IF_ERROR(session->Run({inputs}, {output_name}, {}, out_tensors));
182 return ::tensorflow::OkStatus();
183 }
184
185 // Reads a model graph definition from disk, and creates a session object you
186 // can use to run it.
LoadGraph(const string & graph_file_name,std::unique_ptr<tensorflow::Session> * session)187 Status LoadGraph(const string& graph_file_name,
188 std::unique_ptr<tensorflow::Session>* session) {
189 tensorflow::GraphDef graph_def;
190 Status load_graph_status =
191 ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def);
192 if (!load_graph_status.ok()) {
193 return tensorflow::errors::NotFound("Failed to load compute graph at '",
194 graph_file_name, "'");
195 }
196 session->reset(tensorflow::NewSession(tensorflow::SessionOptions()));
197 Status session_create_status = (*session)->Create(graph_def);
198 if (!session_create_status.ok()) {
199 return session_create_status;
200 }
201 return ::tensorflow::OkStatus();
202 }
203
204 // Analyzes the output of the Inception graph to retrieve the highest scores and
205 // their positions in the tensor, which correspond to categories.
GetTopLabels(const std::vector<Tensor> & outputs,int how_many_labels,Tensor * indices,Tensor * scores)206 Status GetTopLabels(const std::vector<Tensor>& outputs, int how_many_labels,
207 Tensor* indices, Tensor* scores) {
208 auto root = tensorflow::Scope::NewRootScope();
209 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
210
211 string output_name = "top_k";
212 TopK(root.WithOpName(output_name), outputs[0], how_many_labels);
213 // This runs the GraphDef network definition that we've just constructed, and
214 // returns the results in the output tensors.
215 tensorflow::GraphDef graph;
216 TF_RETURN_IF_ERROR(root.ToGraphDef(&graph));
217
218 std::unique_ptr<tensorflow::Session> session(
219 tensorflow::NewSession(tensorflow::SessionOptions()));
220 TF_RETURN_IF_ERROR(session->Create(graph));
221 // The TopK node returns two outputs, the scores and their original indices,
222 // so we have to append :0 and :1 to specify them both.
223 std::vector<Tensor> out_tensors;
224 TF_RETURN_IF_ERROR(session->Run({}, {output_name + ":0", output_name + ":1"},
225 {}, &out_tensors));
226 *scores = out_tensors[0];
227 *indices = out_tensors[1];
228 return ::tensorflow::OkStatus();
229 }
230
231 // Given the output of a model run, and the name of a file containing the labels
232 // this prints out the top five highest-scoring values.
PrintTopLabels(const std::vector<Tensor> & outputs,const string & labels_file_name)233 Status PrintTopLabels(const std::vector<Tensor>& outputs,
234 const string& labels_file_name) {
235 std::vector<string> labels;
236 size_t label_count;
237 Status read_labels_status =
238 ReadLabelsFile(labels_file_name, &labels, &label_count);
239 if (!read_labels_status.ok()) {
240 LOG(ERROR) << read_labels_status;
241 return read_labels_status;
242 }
243 const int how_many_labels = std::min(5, static_cast<int>(label_count));
244 Tensor indices;
245 Tensor scores;
246 TF_RETURN_IF_ERROR(GetTopLabels(outputs, how_many_labels, &indices, &scores));
247 tensorflow::TTypes<float>::Flat scores_flat = scores.flat<float>();
248 tensorflow::TTypes<int32>::Flat indices_flat = indices.flat<int32>();
249 for (int pos = 0; pos < how_many_labels; ++pos) {
250 const int label_index = indices_flat(pos);
251 const float score = scores_flat(pos);
252 LOG(INFO) << labels[label_index] << " (" << label_index << "): " << score;
253 }
254 return ::tensorflow::OkStatus();
255 }
256
257 // This is a testing function that returns whether the top label index is the
258 // one that's expected.
CheckTopLabel(const std::vector<Tensor> & outputs,int expected,bool * is_expected)259 Status CheckTopLabel(const std::vector<Tensor>& outputs, int expected,
260 bool* is_expected) {
261 *is_expected = false;
262 Tensor indices;
263 Tensor scores;
264 const int how_many_labels = 1;
265 TF_RETURN_IF_ERROR(GetTopLabels(outputs, how_many_labels, &indices, &scores));
266 tensorflow::TTypes<int32>::Flat indices_flat = indices.flat<int32>();
267 if (indices_flat(0) != expected) {
268 LOG(ERROR) << "Expected label #" << expected << " but got #"
269 << indices_flat(0);
270 *is_expected = false;
271 } else {
272 *is_expected = true;
273 }
274 return ::tensorflow::OkStatus();
275 }
276
// Entry point: parses flags, loads the frozen graph, preprocesses the image,
// runs inference, optionally self-tests, and prints the top-5 labels.
// Returns 0 on success, -1 on any failure.
int main(int argc, char* argv[]) {
  // These are the command-line flags the program can understand.
  // They define where the graph and input data is located, and what kind of
  // input the model expects. If you train your own model, or use something
  // other than inception_v3, then you'll need to update these.
  string image = "tensorflow/examples/label_image/data/grace_hopper.jpg";
  string graph =
      "tensorflow/examples/label_image/data/inception_v3_2016_08_28_frozen.pb";
  string labels =
      "tensorflow/examples/label_image/data/imagenet_slim_labels.txt";
  int32_t input_width = 299;
  int32_t input_height = 299;
  float input_mean = 0;
  float input_std = 255;
  string input_layer = "input";
  string output_layer = "InceptionV3/Predictions/Reshape_1";
  bool self_test = false;
  string root_dir = "";
  std::vector<Flag> flag_list = {
      Flag("image", &image, "image to be processed"),
      Flag("graph", &graph, "graph to be executed"),
      Flag("labels", &labels, "name of file containing labels"),
      Flag("input_width", &input_width, "resize image to this width in pixels"),
      Flag("input_height", &input_height,
           "resize image to this height in pixels"),
      Flag("input_mean", &input_mean, "scale pixel values to this mean"),
      Flag("input_std", &input_std, "scale pixel values to this std deviation"),
      Flag("input_layer", &input_layer, "name of input layer"),
      Flag("output_layer", &output_layer, "name of output layer"),
      Flag("self_test", &self_test, "run a self test"),
      Flag("root_dir", &root_dir,
           "interpret image and graph file names relative to this directory"),
  };
  string usage = tensorflow::Flags::Usage(argv[0], flag_list);
  const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
  if (!parse_result) {
    LOG(ERROR) << usage;
    return -1;
  }

  // We need to call this to set up global state for TensorFlow.
  tensorflow::port::InitMain(argv[0], &argc, &argv);
  // Flags::Parse removed the recognized flags; anything left over is an error.
  if (argc > 1) {
    LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
    return -1;
  }

  // First we load and initialize the model.
  std::unique_ptr<tensorflow::Session> session;
  string graph_path = tensorflow::io::JoinPath(root_dir, graph);
  Status load_graph_status = LoadGraph(graph_path, &session);
  if (!load_graph_status.ok()) {
    LOG(ERROR) << load_graph_status;
    return -1;
  }

  // Get the image from disk as a float array of numbers, resized and normalized
  // to the specifications the main graph expects.
  std::vector<Tensor> resized_tensors;
  string image_path = tensorflow::io::JoinPath(root_dir, image);
  Status read_tensor_status =
      ReadTensorFromImageFile(image_path, input_height, input_width, input_mean,
                              input_std, &resized_tensors);
  if (!read_tensor_status.ok()) {
    LOG(ERROR) << read_tensor_status;
    return -1;
  }
  const Tensor& resized_tensor = resized_tensors[0];

  // Actually run the image through the model.
  std::vector<Tensor> outputs;
  Status run_status = session->Run({{input_layer, resized_tensor}},
                                   {output_layer}, {}, &outputs);
  if (!run_status.ok()) {
    LOG(ERROR) << "Running model failed: " << run_status;
    return -1;
  }

  // This is for automated testing to make sure we get the expected result with
  // the default settings. We know that label 653 (military uniform) should be
  // the top label for the Admiral Hopper image.
  if (self_test) {
    bool expected_matches;
    Status check_status = CheckTopLabel(outputs, 653, &expected_matches);
    if (!check_status.ok()) {
      LOG(ERROR) << "Running check failed: " << check_status;
      return -1;
    }
    if (!expected_matches) {
      LOG(ERROR) << "Self-test failed!";
      return -1;
    }
  }

  // Do something interesting with the results we've generated.
  Status print_status = PrintTopLabels(outputs, labels);
  if (!print_status.ok()) {
    LOG(ERROR) << "Running print failed: " << print_status;
    return -1;
  }

  return 0;
}
380