xref: /aosp_15_r20/external/armnn/samples/ImageClassification/run_classifier.py (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker#
2*89c4ff92SAndroid Build Coastguard Worker# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker# SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker#
5*89c4ff92SAndroid Build Coastguard Workerimport argparse
6*89c4ff92SAndroid Build Coastguard Workerfrom pathlib import Path
7*89c4ff92SAndroid Build Coastguard Workerfrom typing import Union
8*89c4ff92SAndroid Build Coastguard Worker
9*89c4ff92SAndroid Build Coastguard Workerimport tflite_runtime.interpreter as tflite
10*89c4ff92SAndroid Build Coastguard Workerfrom PIL import Image
11*89c4ff92SAndroid Build Coastguard Workerimport numpy as np
12*89c4ff92SAndroid Build Coastguard Worker
13*89c4ff92SAndroid Build Coastguard Worker
def _check_file(path: Path, suffixes: tuple, format_error: str):
    """Raise if *path* lacks one of *suffixes* (case-insensitive) or does not exist."""
    # Compare lower-cased so e.g. "photo.JPG" is accepted alongside "photo.jpg".
    if path.suffix.lower() not in suffixes:
        raise IOError(format_error)
    if not path.exists():
        # A single formatted message: passing two positional args to an
        # OSError subclass makes them (errno, strerror) and garbles str(e).
        raise FileNotFoundError(f"Cannot find {path}")


def check_args(args: argparse.Namespace):
    """Check the values used in the command-line have acceptable values.

    Each of the input image, model and label files must have the expected
    extension (compared case-insensitively) and exist on disk. Files are
    checked in that order, extension before existence. Every requested
    backend must be one of the backends this sample supports.

    args:
      - args: argparse.Namespace with ``input_image``, ``model_file`` and
        ``label_file`` as pathlib.Path objects and ``preferred_backends`` as
        an iterable of backend-name strings.

    returns:
      - None

    raises:
      - IOError: if a file does not have the expected extension.
      - FileNotFoundError: if a file does not exist.
      - ValueError: if an unsupported backend is requested.
    """
    _check_file(
        args.input_image,
        (".png", ".jpg", ".jpeg"),
        "--input_image option should point to an image file of the "
        "format .jpg, .jpeg, .png",
    )
    _check_file(
        args.model_file,
        (".tflite",),
        "--model_file should point to a tflite file.",
    )
    _check_file(
        args.label_file,
        (".txt",),
        "--label_file expects a .txt file.",
    )

    # check all args given in preferred backends make sense
    supported_backends = ("GpuAcc", "CpuAcc", "CpuRef")
    if any(backend not in supported_backends for backend in args.preferred_backends):
        raise ValueError(
            "Incorrect backends given. Please choose from "
            "'GpuAcc', 'CpuAcc', 'CpuRef'."
        )

    return None
53*89c4ff92SAndroid Build Coastguard Worker
54*89c4ff92SAndroid Build Coastguard Worker
def load_image(image_path: Path, model_input_dims: Union[tuple, list], grayscale: bool):
    """Load an image and put it into the layout the TensorFlow Lite model expects.

    The image is converted to a fixed channel layout and resized to the
    model's spatial size. A leading batch axis is then added.

    args:
      - image_path: pathlib.Path
      - model_input_dims: tuple (or array-like). (height, width)
      - grayscale: bool -> True for a single-channel ("L") image, False for
        a three-channel ("RGB") image.

    returns:
      - image: np.array of shape (1, height, width, channels). channels is 1
        when grayscale and 3 otherwise. Values are the uint8 pixels PIL
        produces for modes "L"/"RGB". No normalisation is applied
        (NOTE(review): fine for quantized models; confirm for float models).
    """
    height, width = model_input_dims
    with Image.open(image_path) as source:
        # Force a known channel layout. RGBA/palette PNGs or grayscale JPEGs
        # would otherwise give a channel count that does not match the model.
        # "L" is plain luminance; the previous "LA" also added an alpha channel.
        converted = source.convert("L" if grayscale else "RGB")
    # PIL sizes are (width, height), the reverse of the model's ordering.
    resized = converted.resize((width, height))

    pixels = np.asarray(resized)
    if grayscale:
        # Mode "L" arrays are 2-D; keep an explicit trailing channel axis so
        # the result stays (height, width, channels) like the RGB case.
        pixels = pixels[..., np.newaxis]

    # add the batch dimension
    return np.expand_dims(pixels, axis=0)
75*89c4ff92SAndroid Build Coastguard Worker
76*89c4ff92SAndroid Build Coastguard Worker
def load_delegate(delegate_path: Path, backends: list, logging_severity: str = "info"):
    """Load the Arm NN delegate.

    args:
      - delegate_path: pathlib.Path (or str) -> location of your libarmnnDelegate.so
      - backends: list -> backend names (e.g. "CpuAcc") in order of preference
      - logging_severity: str -> value passed as the delegate's
        "logging-severity" option; defaults to "info", as before.

    returns:
      - armnn_delegate: tflite.Delegate
    """
    # The delegate takes its backends as a single comma-separated string.
    backend_string = ",".join(backends)
    options = {"backends": backend_string, "logging-severity": logging_severity}
    # Hand over a plain string path, as load_tf_model does for the model file.
    # str() also leaves an already-string path unchanged.
    armnn_delegate = tflite.load_delegate(library=str(delegate_path), options=options)

    return armnn_delegate
96*89c4ff92SAndroid Build Coastguard Worker
97*89c4ff92SAndroid Build Coastguard Worker
def load_tf_model(model_path: Path, armnn_delegate: tflite.Delegate):
    """Build a TensorFlow Lite interpreter that executes through the Arm NN delegate.

    Tensors are allocated before returning, so the interpreter is ready for
    set_tensor/invoke.

    args:
      - model_path: pathlib.Path -> the .tflite model file
      - armnn_delegate: tflite.Delegate -> e.g. the object returned by load_delegate

    returns:
      - interpreter: tflite.Interpreter
    """
    delegates = [armnn_delegate]
    tflite_interpreter = tflite.Interpreter(
        experimental_delegates=delegates,
        model_path=model_path.as_posix(),
    )
    tflite_interpreter.allocate_tensors()
    return tflite_interpreter
114*89c4ff92SAndroid Build Coastguard Worker
115*89c4ff92SAndroid Build Coastguard Worker
def run_inference(interpreter, input_image):
    """Feed a pre-processed image through the interpreter and fetch the result.

    Only the model's first input tensor and first output tensor are used.

    args:
      - interpreter: tflite_runtime.interpreter.Interpreter whose tensors
        have been allocated (e.g. as returned by load_tf_model)
      - input_image: np.array matching the model's first input tensor

    returns:
      - output_data: np.array read from the model's first output tensor
    """
    first_input = interpreter.get_input_details()[0]
    first_output = interpreter.get_output_details()[0]

    interpreter.set_tensor(first_input["index"], input_image)
    interpreter.invoke()
    return interpreter.get_tensor(first_output["index"])
136*89c4ff92SAndroid Build Coastguard Worker
137*89c4ff92SAndroid Build Coastguard Worker
def create_mapping(label_mapping_p):
    """Creates a Python dictionary mapping an output index to a label.

    label_mapping[idx] = label, where idx is the 0-based line number in the
    label file. The file's line order must therefore match the model's
    output order.

    args:
      - label_mapping_p: pathlib.Path (or str) of a UTF-8 text file with one
        label per line

    returns:
      - label_mapping: dict[int, str] with line terminators removed
    """
    with open(label_mapping_p, encoding="utf-8") as label_mapping_raw:
        # Strip only the line terminator. Keeping "\n" (as before) leaked an
        # extra newline into the printed prediction; other whitespace may be
        # meaningful and is left alone.
        return {
            idx: line.rstrip("\r\n")
            for idx, line in enumerate(label_mapping_raw)
        }
157*89c4ff92SAndroid Build Coastguard Worker
158*89c4ff92SAndroid Build Coastguard Worker
def process_output(output_data, label_mapping):
    """Turn the raw model output into a label from the label-mapping file.

    The first entry along the leading axis (presumably the batch axis) is
    taken as the score vector. The index of its largest score is looked up
    in *label_mapping*. On ties np.argmax picks the first maximal index.

    args:
      - output_data: np.array, scores indexed as output_data[0][class_idx]
      - label_mapping: dict mapping class index -> label

    returns:
      - str: the label stored for the highest-scoring index.

    raises:
      - KeyError: if that index has no entry in label_mapping.
    """
    first_scores = output_data[0]
    best = np.argmax(first_scores)
    label = label_mapping[best]
    return label
173*89c4ff92SAndroid Build Coastguard Worker
174*89c4ff92SAndroid Build Coastguard Worker
def main(args):
    """Classify one image using the options parsed from the command line.

    Validates the arguments and builds an Arm NN-delegated TensorFlow Lite
    interpreter. It then runs the input image through the interpreter and
    prints the top label.

    args:
      - args: argparse.Namespace

    returns:
      - None
    """
    check_args(args)

    delegate = load_delegate(args.delegate_path, args.preferred_backends)
    interpreter = load_tf_model(args.model_file, delegate)

    # Dimensions 1 and 2 of the first input tensor are read as height and
    # width (an NHWC-style layout is assumed).
    model_shape = interpreter.get_input_details()[0]["shape"]
    target_hw = (model_shape[1], model_shape[2])
    image_batch = load_image(args.input_image, target_hw, grayscale=False)

    labels = create_mapping(args.label_file)
    scores = run_inference(interpreter, image_batch)
    prediction = process_output(scores, labels)

    print("Prediction: ", prediction)
    return None
204*89c4ff92SAndroid Build Coastguard Worker
205*89c4ff92SAndroid Build Coastguard Worker
if __name__ == "__main__":
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    # Every file-path option is a required pathlib.Path. They are registered
    # in this order so --help lists them as before.
    path_options = (
        ("--input_image", "File path of image file"),
        ("--model_file", "File path of the model tflite file"),
        ("--label_file", "File path of model labelmapping file"),
        ("--delegate_path", "File path of ArmNN delegate file"),
    )
    for flag, description in path_options:
        cli.add_argument(flag, help=description, type=Path, required=True)
    cli.add_argument(
        "--preferred_backends",
        help="list of backends in order of preference",
        type=str,
        nargs="+",
        required=False,
        default=["CpuAcc", "CpuRef"],
    )

    main(cli.parse_args())
242