xref: /aosp_15_r20/external/tensorflow/tensorflow/examples/label_image/label_image.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15
16import argparse
17
18import numpy as np
19import tensorflow as tf
20
21
def load_graph(model_file):
  """Load a serialized GraphDef into a new, isolated `tf.Graph`.

  Args:
    model_file: Path to a binary-serialized `GraphDef` protobuf (e.g. a frozen
      `.pb` model). Any path understood by `tf.io.gfile` is accepted.

  Returns:
    A `tf.Graph` holding the imported nodes. `import_graph_def` is called
    without an explicit `name`, so nodes land under its default "import/"
    name scope (callers look operations up as "import/<layer>").
  """
  graph = tf.Graph()
  # tf.GraphDef / tf.import_graph_def exist only in the TF1 namespace; the
  # compat.v1 aliases work on both TF 1.x and TF 2.x, matching the
  # tf.compat.v1.Session usage elsewhere in this script.
  graph_def = tf.compat.v1.GraphDef()

  # GFile reads local files as well as remote filesystems (e.g. gs://).
  with tf.io.gfile.GFile(model_file, "rb") as f:
    graph_def.ParseFromString(f.read())
  with graph.as_default():
    tf.compat.v1.import_graph_def(graph_def)

  return graph
32
33
def read_tensor_from_image_file(file_name,
                                input_height=299,
                                input_width=299,
                                input_mean=0,
                                input_std=255):
  """Decode an image file and preprocess it into a single-image model batch.

  The image is decoded, cast to float32, given a leading batch dimension of
  size 1, bilinearly resized to (input_height, input_width), and normalized as
  (pixel - input_mean) / input_std.

  Args:
    file_name: Path to the image. The decoder is chosen from the
      case-insensitive extension: ".png", ".gif", ".bmp"; anything else is
      decoded as JPEG.
    input_height: Target height in pixels after resizing.
    input_width: Target width in pixels after resizing.
    input_mean: Value subtracted from every pixel.
    input_std: Value every mean-subtracted pixel is divided by.

  Returns:
    A numpy float32 array of shape [1, input_height, input_width, channels].
    PNG and JPEG are decoded with channels=3; for GIF/BMP the channel count
    is whatever the TensorFlow decoder produces.
  """
  input_name = "file_reader"
  output_name = "normalized"
  lowered_name = file_name.lower()
  # Build the preprocessing ops in a private graph: repeated calls no longer
  # grow the global default graph, and the ops are built in graph mode (which
  # Session.run requires) even when TF2 eager execution is enabled.
  with tf.Graph().as_default():
    file_reader = tf.io.read_file(file_name, input_name)
    if lowered_name.endswith(".png"):
      image_reader = tf.io.decode_png(
          file_reader, channels=3, name="png_reader")
    elif lowered_name.endswith(".gif"):
      # decode_gif yields [num_frames, height, width, channels]. Take the
      # first frame: tf.squeeze would leave animated GIFs 4-D and would also
      # drop a height or width of size 1.
      image_reader = tf.io.decode_gif(file_reader, name="gif_reader")[0]
    elif lowered_name.endswith(".bmp"):
      image_reader = tf.io.decode_bmp(file_reader, name="bmp_reader")
    else:
      image_reader = tf.io.decode_jpeg(
          file_reader, channels=3, name="jpeg_reader")
    float_caster = tf.cast(image_reader, tf.float32)
    dims_expander = tf.expand_dims(float_caster, 0)
    # compat.v1 keeps the original resize_bilinear semantics on TF 1.x and 2.x.
    resized = tf.compat.v1.image.resize_bilinear(
        dims_expander, [input_height, input_width])
    normalized = tf.divide(
        tf.subtract(resized, [input_mean]), [input_std], name=output_name)
    # The context manager closes the session once the value is fetched,
    # instead of leaking it as the bare Session() did.
    with tf.compat.v1.Session() as sess:
      return sess.run(normalized)
57
58
def load_labels(label_file):
  """Read a label file and return its lines as a list of labels.

  Args:
    label_file: Path to a text file with one label per line. Any path
      understood by `tf.io.gfile` is accepted.

  Returns:
    The lines of the file, in order, with trailing whitespace (including the
    newline) stripped. The caller indexes this list by model output position.
  """
  # tf.gfile is TF1-only; tf.io.gfile is available on TF 1.x and TF 2.x.
  # The context manager closes the handle the original left open.
  with tf.io.gfile.GFile(label_file) as f:
    return [line.rstrip() for line in f.readlines()]
62
63
if __name__ == "__main__":
  # Defaults live in argparse itself instead of being patched in afterwards
  # with `if args.x:` checks, which silently ignored explicit falsy values
  # such as `--input_std 0`.
  parser = argparse.ArgumentParser()
  parser.add_argument(
      "--image",
      default="tensorflow/examples/label_image/data/grace_hopper.jpg",
      help="image to be processed")
  parser.add_argument(
      "--graph",
      default=(
          "tensorflow/examples/label_image/data/"
          "inception_v3_2016_08_28_frozen.pb"),
      help="graph/model to be executed")
  parser.add_argument(
      "--labels",
      default="tensorflow/examples/label_image/data/imagenet_slim_labels.txt",
      help="name of file containing labels")
  parser.add_argument(
      "--input_height", type=int, default=299, help="input height")
  parser.add_argument(
      "--input_width", type=int, default=299, help="input width")
  # float (not int) so fractional normalization constants like 127.5 are
  # accepted; integer strings still parse exactly as before.
  parser.add_argument(
      "--input_mean", type=float, default=0, help="input mean")
  parser.add_argument(
      "--input_std", type=float, default=255, help="input std")
  parser.add_argument(
      "--input_layer", default="input", help="name of input layer")
  parser.add_argument(
      "--output_layer",
      default="InceptionV3/Predictions/Reshape_1",
      help="name of output layer")
  args = parser.parse_args()
  if args.input_std == 0:
    parser.error("--input_std must be non-zero")

  graph = load_graph(args.graph)
  t = read_tensor_from_image_file(
      args.image,
      input_height=args.input_height,
      input_width=args.input_width,
      input_mean=args.input_mean,
      input_std=args.input_std)

  # load_graph imports nodes under the default "import/" name scope.
  input_operation = graph.get_operation_by_name("import/" + args.input_layer)
  output_operation = graph.get_operation_by_name("import/" + args.output_layer)

  with tf.compat.v1.Session(graph=graph) as sess:
    results = sess.run(output_operation.outputs[0],
                       {input_operation.outputs[0]: t})
  # Remove size-1 dimensions (e.g. the batch of one) so `results` can be
  # indexed directly by output position.
  results = np.squeeze(results)

  # Indices of the (up to) five highest scores, best first.
  top_k = results.argsort()[-5:][::-1]
  labels = load_labels(args.labels)
  for i in top_k:
    print(labels[i], results[i])
130