# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classify one image with a frozen TensorFlow graph and print the top labels.

The script loads a serialized ``GraphDef`` (by default the Inception v3 frozen
model), decodes, resizes and normalizes a single image, feeds it into the
graph's input layer and prints the five highest-scoring labels together with
their scores.
"""

import argparse

import numpy as np
import tensorflow as tf

# Number of highest-scoring predictions printed by ``main``.
_TOP_K = 5


def load_graph(model_file):
  """Load a binary-serialized ``GraphDef`` into a fresh ``tf.Graph``.

  Args:
    model_file: Path to a binary-serialized ``GraphDef`` (e.g. a frozen
      ``.pb`` file).

  Returns:
    A new ``tf.Graph`` holding the imported nodes. ``import_graph_def`` is
    called without a ``name`` argument, so every imported node name is
    prefixed with ``import/``.
  """
  graph = tf.Graph()
  graph_def = tf.compat.v1.GraphDef()

  with tf.io.gfile.GFile(model_file, "rb") as f:
    graph_def.ParseFromString(f.read())
  with graph.as_default():
    tf.compat.v1.import_graph_def(graph_def)

  return graph


def read_tensor_from_image_file(file_name,
                                input_height=299,
                                input_width=299,
                                input_mean=0,
                                input_std=255):
  """Decode, resize and normalize an image file into a batched float array.

  The preprocessing ops are built in a private graph and evaluated in a
  session that is closed before returning, so repeated calls neither add ops
  to the default graph nor leak sessions.

  Args:
    file_name: Path of the image. The decoder is chosen from the
      case-insensitive extension: ``.png``, ``.gif`` and ``.bmp`` have their
      own decoders; anything else is decoded as JPEG.
    input_height: Height, in pixels, the image is resized to.
    input_width: Width, in pixels, the image is resized to.
    input_mean: Value subtracted from every pixel after resizing.
    input_std: Value every pixel is divided by after the mean is subtracted.

  Returns:
    A NumPy float32 array of shape ``[1, input_height, input_width, C]``
    holding ``(pixels - input_mean) / input_std``. ``C`` is 3 for PNG, GIF
    and JPEG inputs; BMP inputs keep the decoder's default channel count.
  """
  input_name = "file_reader"
  output_name = "normalized"
  lowered_name = file_name.lower()

  graph = tf.Graph()
  with graph.as_default():
    file_reader = tf.io.read_file(file_name, name=input_name)
    if lowered_name.endswith(".png"):
      image_reader = tf.io.decode_png(
          file_reader, channels=3, name="png_reader")
    elif lowered_name.endswith(".gif"):
      # decode_gif returns [num_frames, height, width, 3]. Keep the first
      # frame: squeezing would leave multi-frame GIFs 4-D and would also
      # drop a height or width of 1.
      image_reader = tf.io.decode_gif(file_reader, name="gif_reader")[0]
    elif lowered_name.endswith(".bmp"):
      image_reader = tf.io.decode_bmp(file_reader, name="bmp_reader")
    else:
      image_reader = tf.io.decode_jpeg(
          file_reader, channels=3, name="jpeg_reader")
    float_caster = tf.cast(image_reader, tf.float32)
    # Bilinear resizing operates on a 4-D batch: add a leading batch axis.
    dims_expander = tf.expand_dims(float_caster, 0)
    resized = tf.compat.v1.image.resize_bilinear(
        dims_expander, [input_height, input_width])
    normalized = tf.divide(
        tf.subtract(resized, [input_mean]), [input_std], name=output_name)

  with tf.compat.v1.Session(graph=graph) as sess:
    return sess.run(normalized)


def load_labels(label_file):
  """Read a label file with one label per line.

  Args:
    label_file: Path of a text file containing one label per line.

  Returns:
    A list of the file's lines, in file order, with trailing whitespace
    (including the newline) stripped.
  """
  with tf.io.gfile.GFile(label_file) as f:
    return [line.rstrip() for line in f.readlines()]


def main():
  """Parse flags, classify the requested image and print the top labels."""
  parser = argparse.ArgumentParser()
  parser.add_argument(
      "--image",
      default="tensorflow/examples/label_image/data/grace_hopper.jpg",
      help="image to be processed")
  parser.add_argument(
      "--graph",
      default=("tensorflow/examples/label_image/data/"
               "inception_v3_2016_08_28_frozen.pb"),
      help="graph/model to be executed")
  parser.add_argument(
      "--labels",
      default="tensorflow/examples/label_image/data/imagenet_slim_labels.txt",
      help="name of file containing labels")
  parser.add_argument(
      "--input_height", type=int, default=299, help="input height")
  parser.add_argument(
      "--input_width", type=int, default=299, help="input width")
  # float rather than int so non-integer normalizations (e.g. 127.5) work.
  parser.add_argument(
      "--input_mean", type=float, default=0, help="input mean")
  parser.add_argument(
      "--input_std", type=float, default=255, help="input std")
  parser.add_argument(
      "--input_layer", default="input", help="name of input layer")
  parser.add_argument(
      "--output_layer",
      default="InceptionV3/Predictions/Reshape_1",
      help="name of output layer")
  args = parser.parse_args()

  graph = load_graph(args.graph)
  t = read_tensor_from_image_file(
      args.image,
      input_height=args.input_height,
      input_width=args.input_width,
      input_mean=args.input_mean,
      input_std=args.input_std)

  # load_graph imports nodes under the default "import/" name scope.
  input_operation = graph.get_operation_by_name("import/" + args.input_layer)
  output_operation = graph.get_operation_by_name(
      "import/" + args.output_layer)

  with tf.compat.v1.Session(graph=graph) as sess:
    results = sess.run(output_operation.outputs[0],
                       {input_operation.outputs[0]: t})
  results = np.squeeze(results)

  # argsort is ascending: take the last _TOP_K indices and reverse them so
  # the highest score is printed first.
  top_k = results.argsort()[-_TOP_K:][::-1]
  labels = load_labels(args.labels)
  for i in top_k:
    print(labels[i], results[i])


if __name__ == "__main__":
  main()