xref: /aosp_15_r20/external/executorch/examples/qualcomm/utils.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Qualcomm Innovation Center, Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import argparse
8import os
9import subprocess
10import sys
11from pathlib import Path
12
13from typing import Callable, List, Optional
14
15import numpy as np
16
17import torch
18from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner
19from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype
20from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
21from executorch.backends.qualcomm.utils.utils import (
22    capture_program,
23    generate_htp_compiler_spec,
24    generate_qnn_executorch_compiler_spec,
25    get_soc_to_arch_map,
26)
27from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
28from executorch.exir.backend.backend_api import to_backend
29from executorch.exir.capture._config import ExecutorchBackendConfig
30from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
31from torch.ao.quantization.observer import MovingAverageMinMaxObserver
32from torch.ao.quantization.quantize_pt2e import (
33    convert_pt2e,
34    prepare_pt2e,
35    prepare_qat_pt2e,
36)
37
38
class SimpleADB:
    """
    A wrapper class for communicating with Android device

    Attributes:
        qnn_sdk (str): QNN SDK path setup in environment variable
        build_path (str): Path where artifacts were built
        pte_path (list): Paths of executorch binaries; a single path given
            to the constructor is wrapped into a one-element list
        workspace (str): Folder for storing artifacts on android device
        device_id (str): Serial number of android device
        soc_model (str): Chipset of device
        host_id (str): Hostname of machine where device connects
        error_only (bool): Redirect stdio and leave error messages only
        shared_buffer (bool): Apply zero-copy mechanism in runtime
        dump_intermediate_outputs (bool): Pass --dump_intermediate_outputs
            to the runner
        runner (str): Runtime executor binary, relative to build_path
    """

    def __init__(
        self,
        qnn_sdk,
        build_path,
        pte_path,
        workspace,
        device_id,
        soc_model,
        host_id=None,
        error_only=False,
        shared_buffer=False,
        dump_intermediate_outputs=False,
        runner="examples/qualcomm/executor_runner/qnn_executor_runner",
    ):
        self.qnn_sdk = qnn_sdk
        self.build_path = build_path
        self.pte_path = pte_path if isinstance(pte_path, list) else [pte_path]
        self.workspace = workspace
        self.device_id = device_id
        self.soc_model = soc_model
        self.host_id = host_id
        # host-side scratch folder: input list / raw inputs are generated
        # next to the first .pte file
        self.working_dir = Path(self.pte_path[0]).parent.absolute()
        self.input_list_filename = "input_list.txt"
        self.etdump_path = f"{self.workspace}/etdump.etdp"
        self.dump_intermediate_outputs = dump_intermediate_outputs
        self.debug_output_path = f"{self.workspace}/debug_output.bin"
        self.output_folder = f"{self.workspace}/outputs"
        # raises KeyError when soc_model is missing from the arch map
        self.htp_arch = get_soc_to_arch_map()[soc_model]
        self.error_only = error_only
        self.shared_buffer = shared_buffer
        self.runner = runner

    def _adb(self, cmd):
        """
        Run `adb` with the given argument list against the configured device.

        stdout is discarded when error_only is set, otherwise it is forwarded
        to sys.stdout. The exit status of adb is not checked.
        """
        if not self.host_id:
            cmds = ["adb", "-s", self.device_id]
        else:
            cmds = ["adb", "-H", self.host_id, "-s", self.device_id]
        cmds.extend(cmd)

        subprocess.run(
            cmds, stdout=subprocess.DEVNULL if self.error_only else sys.stdout
        )

    def push(self, inputs=None, input_list=None, files=None):
        """
        Recreate the device workspace and upload everything needed to run.

        Args:
            inputs: Optional input samples, forwarded to generate_inputs
            input_list: Optional content of the input list file, forwarded
                to generate_inputs
            files: Optional iterable of extra host files to upload
        """
        # start from a clean workspace on device
        self._adb(["shell", f"rm -rf {self.workspace}"])
        self._adb(["shell", f"mkdir -p {self.workspace}"])

        # necessary artifacts
        artifacts = [
            *self.pte_path,
            f"{self.qnn_sdk}/lib/aarch64-android/libQnnHtp.so",
            (
                f"{self.qnn_sdk}/lib/hexagon-v{self.htp_arch}/"
                f"unsigned/libQnnHtpV{self.htp_arch}Skel.so"
            ),
            (
                f"{self.qnn_sdk}/lib/aarch64-android/"
                f"libQnnHtpV{self.htp_arch}Stub.so"
            ),
            f"{self.qnn_sdk}/lib/aarch64-android/libQnnHtpPrepare.so",
            f"{self.qnn_sdk}/lib/aarch64-android/libQnnSystem.so",
            f"{self.build_path}/{self.runner}",
            f"{self.build_path}/backends/qualcomm/libqnn_executorch_backend.so",
        ]
        input_list_file, input_files = generate_inputs(
            self.working_dir, self.input_list_filename, inputs, input_list
        )

        if input_list_file is not None:
            # prepare input list
            artifacts.append(input_list_file)

        # artifacts first, then input data, then custom files
        uploads = [*artifacts, *input_files]
        if files is not None:
            uploads.extend(files)
        for host_file in uploads:
            self._adb(["push", host_file, self.workspace])

    def execute(self, custom_runner_cmd=None, method_index=0):
        """
        Run the model on device.

        Args:
            custom_runner_cmd: Shell command executed verbatim instead of the
                default runner invocation
            method_index: Value passed as --method_index to the default runner
        """
        self._adb(["shell", f"mkdir -p {self.output_folder}"])
        # run the delegation
        if custom_runner_cmd is None:
            runner_args = [
                f"--model_path {os.path.basename(self.pte_path[0])}",
                f"--output_folder_path {self.output_folder}",
                f"--input_list_path {self.input_list_filename}",
                f"--etdump_path {self.etdump_path}",
                "--shared_buffer" if self.shared_buffer else "",
                f"--debug_output_path {self.debug_output_path}",
                "--dump_intermediate_outputs" if self.dump_intermediate_outputs else "",
                f"--method_index {method_index}",
            ]
            # disabled optional flags are dropped rather than joined as
            # empty strings
            qnn_executor_runner_args = " ".join(arg for arg in runner_args if arg)
            # push() uploads `runner` into the workspace under its base name,
            # so launch that file rather than a hard-coded binary name
            runner_binary = os.path.basename(self.runner)
            qnn_executor_runner_cmds = (
                f"cd {self.workspace} && "
                f"./{runner_binary} {qnn_executor_runner_args}"
            )
        else:
            qnn_executor_runner_cmds = custom_runner_cmd

        self._adb(["shell", f"{qnn_executor_runner_cmds}"])

    def pull(self, output_path, callback=None):
        """Pull the device output folder to output_path, then call callback if given."""
        self._adb(["pull", "-a", self.output_folder, output_path])
        if callback:
            callback()

    def pull_etdump(self, output_path, callback=None):
        """Pull the etdump file to output_path, then call callback if given."""
        self._adb(["pull", self.etdump_path, output_path])
        if callback:
            callback()

    def pull_debug_output(self, etdump_path, debug_ouput_path, callback=None):
        """
        Pull the etdump file and the debug output file, then call callback
        if given.

        The misspelled parameter name `debug_ouput_path` is kept so existing
        keyword callers continue to work.
        """
        self._adb(["pull", self.etdump_path, etdump_path])
        self._adb(["pull", self.debug_output_path, debug_ouput_path])
        if callback:
            callback()
185
186
def ptq_calibrate(captured_model, quantizer, dataset):
    """
    Prepare a captured model for post-training quantization and calibrate it.

    Args:
        captured_model: Model handed to prepare_pt2e
        quantizer: Quantizer handed to prepare_pt2e
        dataset: Either a callable that receives the prepared model and
            drives calibration itself, or an iterable whose items are
            unpacked as positional arguments of the prepared model

    Returns:
        The prepared model after the calibration data has been run through it
    """
    prepared = prepare_pt2e(captured_model, quantizer)
    print("Quantizing(PTQ) the model...")

    if callable(dataset):
        # caller-supplied calibration routine
        dataset(prepared)
        return prepared

    for sample in dataset:
        prepared(*sample)
    return prepared
197
198
def qat_train(ori_model, captured_model, quantizer, dataset):
    """
    Run a short quantization-aware training loop and convert the result.

    Args:
        ori_model: Not used by this function
        captured_model: Model handed to prepare_qat_pt2e
        quantizer: Quantizer handed to prepare_qat_pt2e
        dataset: Pair ``(data, targets)``; ``data[i]`` is unpacked as the
            positional arguments of the model and ``targets[i]`` is the
            matching CrossEntropyLoss target

    Returns:
        The result of convert_pt2e on the trained model switched to eval mode
    """
    batches, targets = dataset
    prepared = prepare_qat_pt2e(captured_model, quantizer)
    trainable = torch.ao.quantization.move_exported_model_to_train(prepared)
    sgd = torch.optim.SGD(trainable.parameters(), lr=0.00001)
    loss_fn = torch.nn.CrossEntropyLoss()

    # one optimizer step per item of `batches`; the log labels each as an epoch
    for step, batch in enumerate(batches):
        print(f"Epoch {step}")
        if step > 3:
            # Freeze quantizer parameters
            trainable.apply(torch.ao.quantization.disable_observer)
        if step > 2:
            # Freeze batch norm mean and variance estimates
            trainable.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

        prediction = trainable(*batch)
        loss = loss_fn(prediction, targets[step])
        sgd.zero_grad()
        loss.backward()
        sgd.step()

    eval_model = torch.ao.quantization.move_exported_model_to_eval(trainable)
    return torch.ao.quantization.quantize_pt2e.convert_pt2e(eval_model)
224
225
def make_quantizer(
    quant_dtype: Optional[QuantDtype] = QuantDtype.use_8a8w,
    custom_annotations=(),
    per_channel_conv=True,
    per_channel_linear=False,
    act_observer=MovingAverageMinMaxObserver,
    is_qat=False,
):
    """
    Build a QnnQuantizer configured from the given options.

    Args:
        quant_dtype: Passed to set_quant_config
        custom_annotations: Passed to add_custom_quant_annotations
        per_channel_conv: Passed to set_per_channel_conv_quant
        per_channel_linear: Passed to set_per_channel_linear_quant
        act_observer: Passed to set_quant_config
        is_qat: Passed to set_quant_config

    Returns:
        The configured QnnQuantizer instance
    """
    qnn_quantizer = QnnQuantizer()
    qnn_quantizer.add_custom_quant_annotations(custom_annotations)
    qnn_quantizer.set_per_channel_conv_quant(per_channel_conv)
    qnn_quantizer.set_per_channel_linear_quant(per_channel_linear)
    # applied last, after the per-channel switches above
    qnn_quantizer.set_quant_config(quant_dtype, is_qat, act_observer)
    return qnn_quantizer
240
241
242# TODO: refactor to support different backends
def build_executorch_binary(
    model,  # noqa: B006
    inputs,  # noqa: B006
    soc_model,
    file_name,
    dataset: List[torch.Tensor] | Callable[[torch.fx.GraphModule], None],
    skip_node_id_set=None,
    skip_node_op_set=None,
    quant_dtype: Optional[QuantDtype] = None,
    custom_quantizer=None,
    shared_buffer=False,
    metadata=None,
    dump_intermediate_outputs=False,
    custom_pass_config=frozenset(),
    qat_training_data=None,
):
    """
    Optionally quantize a model, lower it to the QNN backend and write
    ``{file_name}.pte``.

    Args:
        model: Model to export
        inputs: Example inputs used for export / capture
        soc_model: Name of the QcomChipset attribute to target
        file_name: Output path without the ``.pte`` suffix
        dataset: PTQ calibration data, see ptq_calibrate; only used when
            quant_dtype is given and qat_training_data is falsy
        skip_node_id_set: Forwarded to QnnPartitioner
        skip_node_op_set: Forwarded to QnnPartitioner
        quant_dtype: Quantization type; None keeps the model unquantized
            and requests fp16 in the HTP compiler spec
        custom_quantizer: Quantizer used instead of make_quantizer(...)
        shared_buffer: Forwarded to the compiler spec; also disables
            memory planning of graph inputs and outputs
        metadata: If given, used as constant_methods of an
            EdgeProgramManager
        dump_intermediate_outputs: Forwarded to the compiler spec
        custom_pass_config: Forwarded to capture_program
        qat_training_data: If truthy, ``(data, targets)`` for qat_train
    """
    # Single source of truth for "is this build quantized": a valid dtype
    # that happens to be falsy (e.g. an integer-valued enum member equal
    # to 0) must not flip the HTP spec back to fp16.
    is_quantized = quant_dtype is not None

    if is_quantized:
        captured_model = torch.export.export(model, inputs).module()
        if qat_training_data:
            quantizer = custom_quantizer or make_quantizer(
                quant_dtype=quant_dtype, is_qat=True
            )
            # qat training
            annotated_model = qat_train(
                model, captured_model, quantizer, qat_training_data
            )
        else:
            quantizer = custom_quantizer or make_quantizer(quant_dtype=quant_dtype)
            # ptq calibration
            annotated_model = ptq_calibrate(captured_model, quantizer, dataset)

        # NOTE(review): qat_train() already returns a convert_pt2e() result,
        # so the QAT path is converted a second time here - confirm this is
        # intended / harmless.
        quantized_model = convert_pt2e(annotated_model)
        edge_prog = capture_program(quantized_model, inputs, custom_pass_config)
    else:
        edge_prog = capture_program(model, inputs, custom_pass_config)

    backend_options = generate_htp_compiler_spec(use_fp16=not is_quantized)
    qnn_partitioner = QnnPartitioner(
        generate_qnn_executorch_compiler_spec(
            soc_model=getattr(QcomChipset, soc_model),
            backend_options=backend_options,
            shared_buffer=shared_buffer,
            dump_intermediate_outputs=dump_intermediate_outputs,
        ),
        skip_node_id_set,
        skip_node_op_set,
    )

    executorch_config = ExecutorchBackendConfig(
        # For shared buffer, user must pass the memory address
        # which is allocated by RPC memory to executor runner.
        # Therefore, won't want to pre-allocate
        # by memory manager in runtime.
        memory_planning_pass=MemoryPlanningPass(
            alloc_graph_input=not shared_buffer,
            alloc_graph_output=not shared_buffer,
        ),
    )

    if metadata is None:
        exported_program = to_backend(edge_prog.exported_program, qnn_partitioner)
        exported_program.graph_module.graph.print_tabular()
        exec_prog = to_edge(exported_program).to_executorch(config=executorch_config)
    else:
        edge_prog_mgr = EdgeProgramManager(
            edge_programs={"forward": edge_prog.exported_program},
            constant_methods=metadata,
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
        )
        edge_prog_mgr = edge_prog_mgr.to_backend(qnn_partitioner)
        exec_prog = edge_prog_mgr.to_executorch(config=executorch_config)

    # both lowering paths end with the same serialization step
    with open(f"{file_name}.pte", "wb") as file:
        file.write(exec_prog.buffer)
321
322
def make_output_dir(path: str):
    """
    Create `path` as an empty directory, discarding any previous content.

    An existing directory is removed recursively (including nested
    sub-directories) before being recreated. Parent directories are left
    in place.

    Args:
        path: Directory to (re)create

    Raises:
        OSError: If `path` exists but is not a directory, or removal or
            creation fails
    """
    import shutil

    if os.path.isdir(path):
        # rmtree handles nested directories and never touches the parents,
        # unlike an os.remove loop followed by os.removedirs
        shutil.rmtree(path)
    os.makedirs(path)
329
330
def topk_accuracy(predictions, targets, k):
    """
    Compute the mean top-k accuracy, in percent, over a list of predictions.

    Args:
        predictions: Sequence of numpy arrays of scores; each is converted
            with torch.from_numpy
        targets: Indexable of torch tensors holding the golden class ids,
            aligned with `predictions`
        k: Number of highest-scoring classes to consider

    Returns:
        The accuracy as a percentage (a torch tensor when `predictions` is
        non-empty)
    """

    def hit_ratio(scores, golden_ids):
        # a row counts as a hit when any of its k best indices equals the
        # golden id; mean over rows*k entries scaled back by k gives the
        # fraction of rows with a hit
        _, best = torch.topk(scores, k=k, sorted=True)
        expected = torch.reshape(golden_ids, [-1, 1])
        matches = (expected == best) * 1.0
        return torch.mean(matches) * k

    total = sum(
        hit_ratio(torch.from_numpy(scores), targets[position])
        for position, scores in enumerate(predictions)
    )
    return total * 100.0 / len(predictions)
344
345
def segmentation_metrics(predictions, targets, classes):
    """
    Compute segmentation quality metrics from a confusion matrix.

    Args:
        predictions: Iterable of numpy arrays of predicted class ids
        targets: Iterable of numpy arrays of golden class ids, paired with
            `predictions`; golden ids >= len(classes) are ignored
        classes: Sequence of class names; its length is the class count

    Returns:
        Tuple ``(pa, mpa, miou, cls_iou)``: pixel accuracy, mean per-class
        pixel accuracy, mean IoU, and a dict mapping class name to IoU
    """
    num_classes = len(classes)
    eps = 1e-6

    # rows are golden classes, columns are predicted classes
    confusion = np.zeros((num_classes, num_classes))
    for golden, predicted in zip(targets, predictions):
        flat_golden = golden.flatten()
        flat_predicted = predicted.flatten()
        valid = flat_golden < num_classes
        # encode each (golden, predicted) pair as one flat cell index
        cell_ids = num_classes * flat_golden[valid].astype(int) + flat_predicted[valid]
        counts = np.bincount(cell_ids, minlength=num_classes**2)
        confusion += counts.reshape(num_classes, num_classes)

    pa = np.diag(confusion).sum() / (confusion.sum() + eps)
    mpa = np.mean(np.diag(confusion) / (confusion.sum(axis=1) + eps))
    union = confusion.sum(axis=1) + confusion.sum(axis=0) - np.diag(confusion) + eps
    iou = np.diag(confusion) / union
    miou = np.mean(iou)
    cls_iou = dict(zip(classes, iou))
    return (pa, mpa, miou, cls_iou)
372
373
def get_imagenet_dataset(
    dataset_path, data_size, image_shape, crop_size=None, shuffle=True
):
    """
    Load up to `data_size` preprocessed samples from an image folder.

    Args:
        dataset_path: Root folder handed to torchvision's ImageFolder
        data_size: Maximum number of samples to collect
        image_shape: Size handed to transforms.Resize; its first element is
            the crop size when `crop_size` is falsy
        crop_size: Size handed to transforms.CenterCrop
        shuffle: Forwarded to the DataLoader

    Returns:
        Tuple ``(inputs, targets, input_list)``: a list of one-element
        tuples holding each feature tensor, the list of targets, and a
        newline-terminated listing of ``input_{index}_0.raw`` names
    """
    from torchvision import datasets, transforms

    pipeline = transforms.Compose(
        [
            transforms.Resize(image_shape),
            transforms.CenterCrop(crop_size or image_shape[0]),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    )
    loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(dataset_path, transform=pipeline),
        shuffle=shuffle,
    )

    # prepare input data
    inputs, targets, list_lines = [], [], []
    for index, sample in enumerate(loader):
        if index >= data_size:
            break
        feature, target = sample
        inputs.append((feature,))
        targets.append(target)
        list_lines.append(f"input_{index}_0.raw\n")

    return inputs, targets, "".join(list_lines)
408
409
def setup_common_args_and_variables():
    """
    Build the argument parser shared by the example scripts.

    Also verifies that the QNN_SDK_ROOT environment variable is set and
    prints its value.

    Returns:
        The populated argparse.ArgumentParser

    Raises:
        RuntimeError: If QNN_SDK_ROOT is not present in the environment
    """
    # (flags, add_argument keyword options), in registration order
    option_table = [
        (
            ("-m", "--model"),
            {
                "help": "SoC model of current device. e.g. 'SM8550' for Snapdragon 8 Gen 2.",
                "type": str,
                "required": True,
            },
        ),
        (
            ("-b", "--build_folder"),
            {
                "help": "path to cmake binary directory for android, e.g., /path/to/build-android",
                "type": str,
                "required": True,
            },
        ),
        (
            ("-H", "--host"),
            {
                "help": "hostname where android device is connected.",
                "default": None,
                "type": str,
            },
        ),
        (
            ("--ip",),
            {
                "help": "IPC address for delivering execution result",
                "default": "",
                "type": str,
            },
        ),
        (
            ("--port",),
            {
                "help": "IPC port for delivering execution result",
                "default": -1,
                "type": int,
            },
        ),
        (
            ("-S", "--skip_delegate_node_ids"),
            {
                "help": "If specified, skip delegation for the specified node based on node ids. Node ids should be seperated by comma. e.g., aten_relu_default_10,aten_relu_default_2",
                "default": None,
                "type": str,
            },
        ),
        (
            ("-f", "--skip_delegate_node_ops"),
            {
                "help": "If specified, skip delegation for the specified op. Node ops should be seperated by comma. e.g., aten.add.Tensor,aten.relu.default",
                "default": None,
                "type": str,
            },
        ),
        (
            ("-c", "--compile_only"),
            {
                "help": "If specified, only compile the model.",
                "action": "store_true",
                "default": False,
            },
        ),
        (
            ("-s", "--device"),
            {
                "help": "serial number for android device communicated via ADB.",
                "type": str,
            },
        ),
        (
            ("-z", "--shared_buffer"),
            {
                "help": "Enables usage of shared buffer between application and backend for graph I/O.",
                "action": "store_true",
            },
        ),
        (
            ("--skip_push",),
            {
                "help": "If specified, skip pushing files to device.",
                "action": "store_true",
                "default": False,
            },
        ),
        (
            ("--dump_intermediate_outputs",),
            {
                "help": "If specified, enable dump intermediate outputs",
                "action": "store_true",
                "default": False,
            },
        ),
    ]

    parser = argparse.ArgumentParser()
    for flags, options in option_table:
        parser.add_argument(*flags, **options)

    # QNN_SDK_ROOT might also be an argument, but it is used in various places.
    # So maybe it's fine to just use the environment.
    qnn_sdk_root = os.environ.get("QNN_SDK_ROOT")
    if qnn_sdk_root is None:
        raise RuntimeError("Environment variable QNN_SDK_ROOT must be set")
    print(f"QNN_SDK_ROOT={qnn_sdk_root}")

    return parser
510
511
def parse_skip_delegation_node(args):
    """
    Turn the comma-separated skip options into sets.

    Args:
        args: Object exposing `skip_delegate_node_ids` and
            `skip_delegate_node_ops`, each either None or a comma-separated
            string

    Returns:
        Tuple ``(skip_node_id_set, skip_node_op_set)``; a set is empty when
        the matching option is None
    """
    skip_node_id_set, skip_node_op_set = set(), set()

    raw_ids = args.skip_delegate_node_ids
    if raw_ids is not None:
        skip_node_id_set = set(raw_ids.split(","))
        print("Skipping following node ids: ", skip_node_id_set)

    raw_ops = args.skip_delegate_node_ops
    if raw_ops is not None:
        skip_node_op_set = set(raw_ops.split(","))
        print("Skipping following node ops: ", skip_node_op_set)

    return skip_node_id_set, skip_node_op_set
525
526
def generate_inputs(dest_path: str, file_name: str, inputs=None, input_list=None):
    """
    Write the input list file and the raw input tensors under `dest_path`.

    Args:
        dest_path: Folder receiving the generated files
        file_name: Name of the input list file inside `dest_path`
        inputs: Optional iterable of samples; each sample is an iterable of
            tensors, written to ``input_{sample}_{tensor}.raw``
        input_list: Optional text content of the input list file

    Returns:
        Tuple ``(input_list_file, input_files)``: path of the written list
        file (None when `input_list` is None) and the list of raw file
        paths written (empty when `inputs` is None)
    """
    # Prepare input list
    list_path = None
    if input_list is not None:
        list_path = f"{dest_path}/{file_name}"
        with open(list_path, "w") as handle:
            handle.write(input_list)

    # Prepare input data
    raw_paths = []
    if inputs is not None:
        for sample_idx, sample in enumerate(inputs):
            for tensor_idx, tensor in enumerate(sample):
                raw_path = f"{dest_path}/input_{sample_idx}_{tensor_idx}.raw"
                tensor.detach().numpy().tofile(raw_path)
                raw_paths.append(raw_path)

    return list_path, raw_paths
547