# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""label_image for tflite.

Classifies a single image with a TensorFlow Lite model and prints the five
highest-scoring labels together with the inference time.
"""

import argparse
import time

import numpy as np
from PIL import Image
import tensorflow as tf


def load_labels(filename):
  """Read a label file and return one stripped label per line.

  Args:
    filename: path to a text file containing one label per line.

  Returns:
    A list of label strings in file order, with surrounding whitespace
    (including the trailing newline) removed.
  """
  with open(filename, 'r', encoding='utf-8') as f:
    return [line.strip() for line in f]


def _parse_delegate_options(option_string):
  """Parse 'key1: value1; key2: value2' into a dict of stripped strings.

  Blank segments (e.g. produced by a trailing ';') are ignored. Only the first
  ':' in a segment separates key from value, so values may contain ':'.

  Args:
    option_string: the raw --ext_delegate_options flag value.

  Returns:
    A dict mapping option names to option values, both as strings.

  Raises:
    RuntimeError: if a non-blank segment has no ':' or has an empty key.
  """
  options = {}
  for segment in option_string.split(';'):
    if not segment.strip():
      continue
    key, sep, value = segment.partition(':')
    if not sep or not key.strip():
      raise RuntimeError('Error parsing delegate option: ' + segment)
    options[key.strip()] = value.strip()
  return options


def main():
  """Classify one image with a TFLite model and print the top-5 predictions.

  Command-line flags select the image, model, label file, input
  normalization, interpreter thread count and an optional external delegate.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '-i',
      '--image',
      default='/tmp/grace_hopper.bmp',
      help='image to be classified')
  parser.add_argument(
      '-m',
      '--model_file',
      default='/tmp/mobilenet_v1_1.0_224_quant.tflite',
      help='.tflite model to be executed')
  parser.add_argument(
      '-l',
      '--label_file',
      default='/tmp/labels.txt',
      help='name of file containing labels')
  parser.add_argument(
      '--input_mean',
      default=127.5, type=float,
      help='input_mean')
  parser.add_argument(
      '--input_std',
      default=127.5, type=float,
      help='input standard deviation')
  parser.add_argument(
      '--num_threads', default=None, type=int, help='number of threads')
  parser.add_argument(
      '-e', '--ext_delegate', help='external_delegate_library path')
  parser.add_argument(
      '-o',
      '--ext_delegate_options',
      help=('external delegate options, '
            'format: "option1: value1; option2: value2"'))

  args = parser.parse_args()

  ext_delegate = None
  ext_delegate_options = {}

  # Parse external delegate options.
  if args.ext_delegate_options is not None:
    ext_delegate_options = _parse_delegate_options(args.ext_delegate_options)

  # Load the external delegate; the interpreter takes a list of delegates.
  if args.ext_delegate is not None:
    print('Loading external delegate from {} with args: {}'.format(
        args.ext_delegate, ext_delegate_options))
    ext_delegate = [
        tf.lite.experimental.load_delegate(args.ext_delegate,
                                           ext_delegate_options)
    ]

  interpreter = tf.lite.Interpreter(
      model_path=args.model_file,
      experimental_delegates=ext_delegate,
      num_threads=args.num_threads)
  interpreter.allocate_tensors()

  input_details = interpreter.get_input_details()
  output_details = interpreter.get_output_details()

  # A float32 input is normalized below; any other input dtype receives the
  # raw pixel array unchanged (assumes the model expects uint8 pixels).
  floating_model = input_details[0]['dtype'] == np.float32

  # NxHxWxC, H:1, W:2
  input_shape = input_details[0]['shape']
  height = input_shape[1]
  width = input_shape[2]
  img = Image.open(args.image)
  # Grayscale/RGBA/palette images would otherwise yield an array whose
  # channel count does not match a 3-channel model input.
  if len(input_shape) == 4 and input_shape[3] == 3:
    img = img.convert('RGB')
  img = img.resize((width, height))

  # add N dim
  input_data = np.expand_dims(img, axis=0)

  if floating_model:
    input_data = (np.float32(input_data) - args.input_mean) / args.input_std

  interpreter.set_tensor(input_details[0]['index'], input_data)

  # perf_counter is monotonic and high-resolution, suitable for intervals.
  start_time = time.perf_counter()
  interpreter.invoke()
  stop_time = time.perf_counter()

  output_data = interpreter.get_tensor(output_details[0]['index'])
  results = np.squeeze(output_data)

  # Indices of the five highest scores, best first.
  top_k = results.argsort()[-5:][::-1]
  labels = load_labels(args.label_file)
  for i in top_k:
    if floating_model:
      print('{:08.6f}: {}'.format(float(results[i]), labels[i]))
    else:
      # NOTE(review): assumes an unsigned 8-bit output in [0, 255]; int8
      # outputs or other scales would need output_details[0]['quantization'].
      print('{:08.6f}: {}'.format(float(results[i] / 255.0), labels[i]))

  print('time: {:.3f}ms'.format((stop_time - start_time) * 1000))


if __name__ == '__main__':
  main()