xref: /aosp_15_r20/external/executorch/examples/qualcomm/scripts/deeplab_v3.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Qualcomm Innovation Center, Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import json
8import os
9import random
10import re
11from multiprocessing.connection import Client
12
13import numpy as np
14import torch
15
16from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
17from executorch.examples.models.deeplab_v3 import DeepLabV3ResNet101Model
18from executorch.examples.qualcomm.utils import (
19    build_executorch_binary,
20    make_output_dir,
21    parse_skip_delegation_node,
22    segmentation_metrics,
23    setup_common_args_and_variables,
24    SimpleADB,
25)
26
27
def get_dataset(data_size, dataset_dir, download):
    """Load a random subset of the Pascal VOC 2012 ``val`` segmentation split.

    Args:
        data_size: Maximum number of samples to return. If the split holds
            fewer samples, all of them are returned; values <= 0 yield empty
            results.
        dataset_dir: Directory under which the ``voc_image`` dataset root is
            located (or created when downloading).
        download: Forwarded to ``VOCSegmentation``; when true the dataset is
            fetched through the torchvision API.

    Returns:
        A tuple ``(inputs, targets, input_list)``:
        ``inputs`` is a list of one-element tuples, each holding a
        preprocessed image tensor with a leading batch dimension added;
        ``targets`` is a list of numpy arrays built from the segmentation
        masks resized to 224x224; ``input_list`` is a string with one
        ``input_{index}_0.raw`` line per returned sample.
    """
    from torchvision import datasets, transforms

    input_size = (224, 224)
    preprocess = transforms.Compose(
        [
            transforms.Resize(input_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    dataset = datasets.VOCSegmentation(
        root=os.path.join(dataset_dir, "voc_image"),
        year="2012",
        image_set="val",
        transform=preprocess,
        download=download,
    )

    # Pick the random subset by index so that only the selected samples are
    # decoded and preprocessed, instead of materializing the entire split
    # just to shuffle it and keep the first ``data_size`` entries.
    sample_count = max(0, min(data_size, len(dataset)))
    chosen_indices = random.sample(range(len(dataset)), sample_count)

    # prepare input data
    inputs, targets = [], []
    for sample_index in chosen_indices:
        image, target = dataset[sample_index]
        inputs.append((image.unsqueeze(0),))
        # Only the image goes through ``preprocess``; the mask is resized
        # separately so it has the same spatial size as the model input.
        targets.append(np.array(target.resize(input_size)))

    input_list = "".join(f"input_{index}_0.raw\n" for index in range(sample_count))

    return inputs, targets, input_list
62
63
def main(args):
    """Build the quantized DeepLabV3 program and, unless compile-only, evaluate it on device.

    The model is lowered with ``build_executorch_binary`` using 8-bit
    activation / 8-bit weight quantization (``QuantDtype.use_8a8w``). When
    ``args.compile_only`` is set the function returns right after the build.
    Otherwise the inputs are pushed to the device through ``SimpleADB``, the
    program is executed, the outputs are pulled back and post-processed, and
    segmentation metrics are either sent as JSON to ``(args.ip, args.port)``
    or printed.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``artifact``, ``compile_only``, ``device``, ``download``,
            ``model``, ``shared_buffer``, ``build_folder``, ``host``, ``ip``
            and ``port`` (plus whatever ``parse_skip_delegation_node`` reads).

    Raises:
        RuntimeError: If ``args.compile_only`` is false and no device serial
            was given.
    """
    skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)

    # ensure the working directory exists.
    os.makedirs(args.artifact, exist_ok=True)

    if not args.compile_only and args.device is None:
        raise RuntimeError(
            "device serial is required if not compile only. "
            "Please specify a device serial by -s/--device argument."
        )

    data_num = 100
    if args.compile_only:
        # No dataset is needed just to compile; a random tensor is passed instead.
        inputs = [(torch.rand(1, 3, 224, 224),)]
    else:
        inputs, targets, input_list = get_dataset(
            data_size=data_num, dataset_dir=args.artifact, download=args.download
        )

    pte_filename = "dl3_qnn_q8"
    instance = DeepLabV3ResNet101Model()

    build_executorch_binary(
        instance.get_eager_model().eval(),
        instance.get_example_inputs(),
        args.model,
        f"{args.artifact}/{pte_filename}",
        inputs,
        skip_node_id_set=skip_node_id_set,
        skip_node_op_set=skip_node_op_set,
        quant_dtype=QuantDtype.use_8a8w,
        shared_buffer=args.shared_buffer,
    )

    if args.compile_only:
        return

    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
        shared_buffer=args.shared_buffer,
    )
    adb.push(inputs=inputs, input_list=input_list)
    adb.execute()

    # collect output data
    output_data_folder = f"{args.artifact}/outputs"
    make_output_dir(output_data_folder)

    # Label names handed to segmentation_metrics; the list length is also used
    # as the channel count when reshaping the raw model output.
    # NOTE(review): assumes the names are only used as report labels — confirm
    # against segmentation_metrics.
    classes = [
        "Background",
        "Aeroplane",
        "Bicycle",
        "Bird",
        "Boat",
        "Bottle",
        "Bus",
        "Car",
        "Cat",
        "Chair",
        "Cow",
        "DiningTable",
        "Dog",
        "Horse",
        "MotorBike",
        "Person",
        "PottedPlant",
        "Sheep",
        "Sofa",
        "Train",
        "TvMonitor",
    ]

    # Files named output_<n>_<1-9>.raw are the auxiliary outputs. The dot is
    # escaped so that only a literal ".raw" suffix matches.
    auxiliary_output = re.compile(r"^output_[0-9]+_[1-9]\.raw$")
    output_shape = [len(classes), 224, 224]

    def post_process():
        """Drop auxiliary outputs and turn each remaining file into a uint8 class map."""
        for f in os.listdir(output_data_folder):
            filename = os.path.join(output_data_folder, f)
            if auxiliary_output.match(f):
                os.remove(filename)
            else:
                output = np.fromfile(filename, dtype=np.float32)
                output = output.reshape(output_shape)
                # Replace the per-class scores by the index of the best class
                # per pixel, overwriting the file in place.
                output.argmax(0).astype(np.uint8).tofile(filename)

    adb.pull(output_path=args.artifact, callback=post_process)

    # segmentation metrics
    # Read back exactly as many predictions as samples were pushed, so a
    # dataset smaller than data_num does not lead to a missing-file error.
    predictions = [
        np.fromfile(
            os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.uint8
        )
        for i in range(len(targets))
    ]

    pa, mpa, miou, cls_iou = segmentation_metrics(predictions, targets, classes)
    if args.ip and args.port != -1:
        with Client((args.ip, args.port)) as conn:
            conn.send(
                json.dumps({"PA": float(pa), "MPA": float(mpa), "MIoU": float(miou)})
            )
    else:
        print(f"PA   : {pa}%")
        print(f"MPA  : {mpa}%")
        print(f"MIoU : {miou}%")
        print(f"CIoU : \n{json.dumps(cls_iou, indent=2)}")
177
178
if __name__ == "__main__":
    # Script entry point: extend the shared argument parser with the options
    # specific to this example, then run main().
    parser = setup_common_args_and_variables()

    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing generated artifacts by this example. Default ./deeplab_v3",
        default="./deeplab_v3",
        type=str,
    )

    parser.add_argument(
        "-d",
        "--download",
        help="If specified, download VOCSegmentation dataset by torchvision API",
        action="store_true",
        default=False,
    )

    args = parser.parse_args()
    try:
        main(args)
    except Exception as e:
        if args.ip and args.port != -1:
            # A listener address was supplied: report the failure to it as JSON.
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            # Re-raise the original exception unchanged so its concrete type
            # and traceback are kept, rather than wrapping it in a generic
            # Exception.
            raise
207