xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/examples/python/label_image.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""label_image for tflite."""
16
17import argparse
18import time
19
20import numpy as np
21from PIL import Image
22import tensorflow as tf
23
24
def load_labels(filename):
  """Read a label file and return its lines as a list of strings.

  Args:
    filename: Path to a text file with one label per line. The file is
      decoded as UTF-8 rather than the platform's default encoding, so
      non-ASCII labels read the same on every OS.

  Returns:
    A list with one entry per line of the file, each stripped of leading and
    trailing whitespace (including the newline). The script below indexes
    this list by output class index.
  """
  with open(filename, 'r', encoding='utf-8') as f:
    # Iterate the file object directly; no need to materialize readlines().
    return [line.strip() for line in f]
28
29
def parse_delegate_options(spec):
  """Parse an external-delegate option string into a dict.

  Args:
    spec: Options in the form ``"key1: value1; key2: value2"``. Entries are
      separated by ';' and each entry is split on its FIRST ':' only, so a
      value may itself contain ':' (for example a Windows path). Empty
      entries, such as the one left by a trailing ';', are ignored.

  Returns:
    A dict mapping each stripped key to its stripped value (both str).

  Raises:
    RuntimeError: If a non-empty entry has no ':' or has an empty key.
  """
  options = {}
  for entry in spec.split(';'):
    if not entry.strip():
      continue
    key, sep, value = entry.partition(':')
    key = key.strip()
    if not sep or not key:
      raise RuntimeError('Error parsing delegate option: ' + entry)
    options[key] = value.strip()
  return options


# Command-line entry point: classify one image with a .tflite model and print
# the five highest-scoring labels plus the inference time.
if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '-i',
      '--image',
      default='/tmp/grace_hopper.bmp',
      help='image to be classified')
  parser.add_argument(
      '-m',
      '--model_file',
      default='/tmp/mobilenet_v1_1.0_224_quant.tflite',
      help='.tflite model to be executed')
  parser.add_argument(
      '-l',
      '--label_file',
      default='/tmp/labels.txt',
      help='name of file containing labels')
  parser.add_argument(
      '--input_mean',
      default=127.5, type=float,
      help='input_mean')
  parser.add_argument(
      '--input_std',
      default=127.5, type=float,
      help='input standard deviation')
  parser.add_argument(
      '--num_threads', default=None, type=int, help='number of threads')
  parser.add_argument(
      '-e', '--ext_delegate', help='external_delegate_library path')
  parser.add_argument(
      '-o',
      '--ext_delegate_options',
      help='external delegate options, '
      'format: "option1: value1; option2: value2"')

  args = parser.parse_args()

  ext_delegate = None
  ext_delegate_options = {}

  # Parse external delegate options.
  if args.ext_delegate_options is not None:
    ext_delegate_options = parse_delegate_options(args.ext_delegate_options)

  # Load the external delegate. Only `tf` is imported in this file, so the
  # loader is reached through tf.lite.experimental.
  if args.ext_delegate is not None:
    print('Loading external delegate from {} with args: {}'.format(
        args.ext_delegate, ext_delegate_options))
    ext_delegate = [
        tf.lite.experimental.load_delegate(args.ext_delegate,
                                           ext_delegate_options)
    ]

  interpreter = tf.lite.Interpreter(
      model_path=args.model_file,
      experimental_delegates=ext_delegate,
      num_threads=args.num_threads)
  interpreter.allocate_tensors()

  input_details = interpreter.get_input_details()
  output_details = interpreter.get_output_details()

  # Check the type of the input tensor: float32 inputs get mean/std
  # normalization, anything else is fed the raw image bytes.
  floating_model = input_details[0]['dtype'] == np.float32

  # NxHxWxC, H:1, W:2
  input_shape = input_details[0]['shape']
  height = input_shape[1]
  width = input_shape[2]

  # The `with` closes the image file once the resized copy exists. When the
  # model expects 3 channels, convert first so grayscale, palette or RGBA
  # inputs still match the tensor's channel count.
  with Image.open(args.image) as src:
    needs_rgb = len(input_shape) == 4 and input_shape[3] == 3
    img = (src.convert('RGB') if needs_rgb else src).resize((width, height))

  # Add the N (batch) dimension.
  input_data = np.expand_dims(np.asarray(img), axis=0)

  if floating_model:
    input_data = (np.float32(input_data) - args.input_mean) / args.input_std

  interpreter.set_tensor(input_details[0]['index'], input_data)

  # perf_counter is monotonic and high-resolution, unlike time.time().
  start_time = time.perf_counter()
  interpreter.invoke()
  stop_time = time.perf_counter()

  output_data = interpreter.get_tensor(output_details[0]['index'])
  results = np.squeeze(output_data)

  # Indices of the five highest scores, best first.
  top_k = results.argsort()[-5:][::-1]
  labels = load_labels(args.label_file)
  for i in top_k:
    if floating_model:
      print('{:08.6f}: {}'.format(float(results[i]), labels[i]))
    else:
      # NOTE(review): dividing by 255.0 assumes a uint8 output tensor in
      # [0, 255]; the tensor's quantization (scale, zero_point) would be exact.
      print('{:08.6f}: {}'.format(float(results[i] / 255.0), labels[i]))

  print('time: {:.3f}ms'.format((stop_time - start_time) * 1000))
129