# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import collections
import copy
import os
import subprocess
import tempfile
import unittest
from typing import Callable, Dict, List, Optional, Tuple

import numpy as np
import torch

from executorch import exir
from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner
from executorch.backends.qualcomm.qnn_preprocess import QnnBackend
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype
from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
from executorch.backends.qualcomm.utils.utils import (
    capture_program,
    get_soc_to_chipset_map,
)
from executorch.devtools import generate_etrecord, Inspector
from executorch.examples.qualcomm.utils import (
    generate_inputs,
    make_output_dir,
    SimpleADB,
)

from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
from executorch.exir.program import ExecutorchProgram, ExecutorchProgramManager
from torch.ao.quantization.quantize_pt2e import (
    convert_pt2e,
    prepare_pt2e,
    prepare_qat_pt2e,
)


def generate_context_binary(
    module: torch.nn.Module,
    inputs: Dict[str, torch.Tensor],
    quantized: bool,
    artifact_dir: str,
):
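    """Build a QNN context binary for ``module`` via the QNN SDK toolchain.

    The module is traced with torch.jit.trace, then run through
    qnn-pytorch-converter -> qnn-model-lib-generator ->
    qnn-context-binary-generator, leaving ``model_ctx.bin`` in
    ``artifact_dir``. QNN_SDK_ROOT and ANDROID_NDK_ROOT must be set, and
    clang must be on PATH.

    A minimal usage sketch (module, shapes, and directory are illustrative
    assumptions, not part of the original test flow):

        with tempfile.TemporaryDirectory() as tmp_dir:
            generate_context_binary(
                module=torch.nn.Linear(4, 4).eval(),
                inputs={"x": torch.randn(1, 4)},
                quantized=False,
                artifact_dir=tmp_dir,
            )
    """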
    # clang must also be discoverable in PATH, or context generation may fail
    qnn_sdk = os.environ.get("QNN_SDK_ROOT", None)
    ndk = os.environ.get("ANDROID_NDK_ROOT", None)
    assert qnn_sdk, "QNN_SDK_ROOT was not found in environment variables"
    assert ndk, "ANDROID_NDK_ROOT was not found in environment variables"

    inputs_tup = tuple(inputs.values())
    jit_module = torch.jit.trace(module, inputs_tup)
    torch.jit.save(jit_module, f"{artifact_dir}/jit_module.pt")

    # input data
    if quantized:
        input_list = []
        for name, data in inputs.items():
            file_name = f"{artifact_dir}/{name}.raw"
            data.detach().numpy().tofile(file_name)
            input_list.append(file_name)

        with open(f"{artifact_dir}/input_list.txt", "w") as f:
            f.write(" ".join(input_list))

    # QNN tool flow: converter -> model-lib-generator -> context-binary-generator
    target = "x86_64-linux-clang"
    inputs_str = [
        f"-d '{k}' {str(tuple(v.shape)).replace(' ', '')[1:-1]}"
        for k, v in inputs.items()
    ]
    cmds = [
        # setup qnn env
        f"source {qnn_sdk}/bin/envsetup.sh;"
        # qnn-pytorch-converter
        f"{qnn_sdk}/bin/{target}/qnn-pytorch-converter",
        f"-i {artifact_dir}/jit_module.pt",
        *inputs_str,
        f"--input_list {artifact_dir}/input_list.txt" if quantized else "",
        "--preserve_io",
        f"-o {artifact_dir}/model.cpp;",
        # qnn-model-lib-generator
        f"{qnn_sdk}/bin/{target}/qnn-model-lib-generator",
        f"-c {artifact_dir}/model.cpp",
        f"-t {target}",
        "-l model",
        f"-o {artifact_dir}/model_libs;",
        # qnn-context-binary-generator
        f"{qnn_sdk}/bin/{target}/qnn-context-binary-generator",
        f"--model {artifact_dir}/model_libs/{target}/libmodel.so",
        f"--backend {qnn_sdk}/lib/{target}/libQnnHtp.so",
        "--binary_file model_ctx",
        f"--output_dir {artifact_dir};",
    ]
    result = subprocess.run(
        " ".join(cmds),
        shell=True,
        executable="/bin/bash",
        capture_output=True,
    )
    assert os.path.isfile(
        f"{artifact_dir}/model_ctx.bin"
    ), f"Failed to generate context binary:\n{result.stderr.decode()}"


class TestQNN(unittest.TestCase):
    rtol: float = 0
    atol: float = 0
    host: str = ""
    device: str = ""
    build_folder: str = ""
    model: Optional[QcomChipset] = None
    compiler_specs: Optional[List[CompileSpec]] = None
    chipset_table = get_soc_to_chipset_map()
    error_only = False
    ip = "localhost"
    port = 8080
    executorch_root: str = ""
    artifact_dir: str = ""
    image_dataset: str = ""
    pretrained_weight: str = ""
    enable_profile: bool = False
    online_prepare: bool = False
    use_8a8w: str = "8a8w"
    use_16a16w: str = "16a16w"
    use_16a4w: str = "16a4w"
    shared_buffer: bool = False
    enable_x86_64: bool = False
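
    # The attributes above serve as suite-wide configuration; they are
    # expected to be assigned by the test harness before the tests run.
    # A minimal sketch (values are illustrative assumptions):
    #
    #   TestQNN.model = QcomChipset.SM8550
    #   TestQNN.build_folder = "build-android"
    #   TestQNN.device = "<device-serial>"
    #   TestQNN.compiler_specs = [...]  # e.g. from generate_qnn_executorch_compiler_spec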

    def _assert_outputs_equal(self, model_output, ref_output):
        self.assertEqual(len(ref_output), len(model_output))
        for i in range(len(ref_output)):
            self.assertTrue(
                torch.allclose(
                    model_output[i], ref_output[i], atol=self.atol, rtol=self.rtol
                ),
                msg=f"ref_output:\n{ref_output[i]}\n\nmodel_output:\n{model_output[i]}",
            )

    def _save_model_and_expected_output(
        self,
        module: torch.nn.Module,
        buffer: bytes,
        inputs: Tuple[torch.Tensor],
        dir_name: str,
    ) -> Tuple[str, List[torch.Tensor], str]:
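        """Serialize the .pte ``buffer`` and compute eager reference outputs.

        Returns an (input_list, ref_outputs, pte_fname) tuple: the
        newline-terminated listing of input file names expected by the
        runner, the detached reference outputs of ``module``, and the path
        of the saved .pte file.
        """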
        # Save the input data list to be executed
        input_list = ""
        for idx, _ in enumerate(inputs):
            input_name = f"input_0_{idx}.raw"
            input_list += input_name + " "
        input_list = input_list.strip() + "\n"

        ref_output = module(*inputs)

        # Save the expected output data to be verified
        ref_outputs = []
        if isinstance(ref_output, collections.OrderedDict):
            ref_outputs.append(ref_output["out"].detach())
        elif isinstance(ref_output, (list, tuple)):
            for output in ref_output:
                ref_outputs.append(output.detach())
        else:
            ref_outputs.append(ref_output.detach())

        pte_fname = f"{dir_name}/qnn_executorch_test.pte"
        with open(pte_fname, "wb") as file:
            file.write(buffer)

        return input_list, ref_outputs, pte_fname

    def verify_output(  # noqa: C901
        self,
        module: torch.nn.Module,
        sample_inputs: Tuple[torch.Tensor],
        executorch_prog: ExecutorchProgram | ExecutorchProgramManager,
        etrecord_path: str = "etrecord.bin",
        expected_profile_events: int = -1,
        expected_intermediate_events: int = -1,
        method_index: int = 0,
    ):
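        """Execute the lowered program and compare against eager outputs.

        Two paths are supported: with ``enable_x86_64`` the program runs on
        the host through qnn_executor_runner and the QNN x86_64 libraries;
        otherwise it is pushed to a device with SimpleADB and executed there.
        If ``expected_profile_events`` or ``expected_intermediate_events`` is
        not -1, the generated ETDump is also validated via Inspector.
        """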
        with tempfile.TemporaryDirectory() as tmp_dir:
            (
                input_list,
                ref_outputs,
                pte_fname,
            ) = self._save_model_and_expected_output(
                module,
                executorch_prog.buffer,
                sample_inputs,
                tmp_dir,
            )

            output_dir = f"{tmp_dir}/outputs"
            outputs = []
            etdump_path = f"{tmp_dir}/etdump.etdp"
            debug_output_path = f"{tmp_dir}/debug_output.bin"

            def post_process():
                for i, f in enumerate(sorted(os.listdir(output_dir))):
                    filename = os.path.join(output_dir, f)
                    output = np.fromfile(filename, dtype=ref_outputs[i].numpy().dtype)
                    output = torch.from_numpy(output).reshape(ref_outputs[i].shape)
                    outputs.append(output)

            def validate_profile():
                inspector = Inspector(etdump_path=etdump_path, etrecord=etrecord_path)
                self.assertEqual(
                    len(inspector.to_dataframe().index), expected_profile_events
                )

            def validate_intermediate_tensor():
                inspector = Inspector(
                    etdump_path=etdump_path, debug_buffer_path=debug_output_path
                )
                for event_block in inspector.event_blocks:
                    if event_block.name == "Execute":
                        self.assertEqual(
                            len(event_block.events), expected_intermediate_events
                        )

            if self.enable_x86_64:
                generate_inputs(tmp_dir, "input_list.txt", [sample_inputs], input_list)
                make_output_dir(output_dir)

                target = "x86_64-linux-clang"
                qnn_sdk = os.environ.get("QNN_SDK_ROOT", None)
                assert qnn_sdk, "QNN_SDK_ROOT was not found in environment variables"

                build_folder = self.build_folder
                if os.path.isabs(self.build_folder):
                    # honor the user's absolute path as-is
                    pass
                else:
                    # otherwise assume the path is relative to the current working directory
                    build_folder = os.path.join(os.getcwd(), self.build_folder)

                cmd = [
                    # qnn_executor_runner
                    f"{build_folder}/examples/qualcomm/executor_runner/qnn_executor_runner",
                    "--model_path",
                    pte_fname,
                    "--input_list_path",
                    f"{tmp_dir}/input_list.txt",
                    "--output_folder_path",
                    output_dir,
                    "--method_index",
                    str(method_index),
                ]
                if expected_intermediate_events != -1:
                    cmd.append("--dump_intermediate_outputs")

                env = dict(os.environ)
                env["LD_LIBRARY_PATH"] = f"{qnn_sdk}/lib/{target}/:{build_folder}/lib"
                proc = subprocess.run(
                    cmd,
                    stdout=subprocess.PIPE,
                    stderr=subprocess.STDOUT,
                    env=env,
                    cwd=tmp_dir,
                )

                self.assertEqual(
                    proc.returncode,
                    0,
                    f"The process running qnn_executor_runner returned {proc.returncode}, "
                    "STDOUT=\n"
                    f"{proc.stdout.decode('utf-8')}",
                )

                # Verify the outputs
                post_process()
                self._assert_outputs_equal(outputs, ref_outputs)

                # Verify the etdump
                if expected_profile_events != -1:
                    validate_profile()

                if expected_intermediate_events != -1:
                    validate_intermediate_tensor()
            else:
                adb = SimpleADB(
                    qnn_sdk=os.getenv("QNN_SDK_ROOT"),
                    build_path=self.build_folder,
                    pte_path=pte_fname,
                    workspace="/data/local/tmp/qnn_executorch_test",
                    device_id=self.device,
                    host_id=self.host,
                    soc_model=self.model,
                    error_only=self.error_only,
                    dump_intermediate_outputs=(expected_intermediate_events != -1),
                )
                adb.push(inputs=[sample_inputs], input_list=input_list)
                adb.execute(method_index=method_index)
                adb.pull(output_path=tmp_dir, callback=post_process)
                self._assert_outputs_equal(outputs, ref_outputs)

                if expected_profile_events != -1:
                    adb.pull_etdump(etdump_path, callback=validate_profile)

                if expected_intermediate_events != -1:
                    adb.pull_debug_output(
                        etdump_path,
                        debug_output_path,
                        callback=validate_intermediate_tensor,
                    )

    def lower_module_and_test_output(
        self,
        module: torch.nn.Module,
        sample_inputs: Tuple[torch.Tensor],
        expected_partitions: int = 1,
        expected_profile_events: int = -1,
        expected_intermediate_events: int = -1,
        assert_output_equal: bool = True,
        skip_node_id_set: Optional[set] = None,
        skip_node_op_set: Optional[set] = None,
    ):
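        """Partition and lower ``module`` to the QNN backend, then verify it.

        The flow is capture_program -> to_backend(QnnPartitioner) ->
        to_executorch. The test asserts that exactly ``expected_partitions``
        delegates were produced, that each delegate is the QNN backend, and
        optionally checks numerical parity and profiling events through
        ``verify_output``.
        """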
        qnn_partitioner = QnnPartitioner(
            self.compiler_specs, skip_node_id_set, skip_node_op_set
        )
        delegated_program = capture_program(module, sample_inputs)

        # a copy is needed for the ETRecord, since lowering modifies the graph in-place
        edge_copy = copy.deepcopy(delegated_program)

        delegated_program.exported_program = to_backend(
            delegated_program.exported_program, qnn_partitioner
        )
        exec_prog = delegated_program.to_executorch(
            exir.ExecutorchBackendConfig(
                # For shared buffers, the user must pass memory addresses
                # allocated by RPC memory to the executor runner, so the
                # runtime memory manager should not pre-allocate
                # graph inputs/outputs.
                memory_planning_pass=MemoryPlanningPass(
                    alloc_graph_input=not self.shared_buffer,
                    alloc_graph_output=not self.shared_buffer,
                ),
            )
        )

        # Assert the expected number of partitions, each delegated to the QNN backend
        self.assertEqual(
            len(exec_prog.program.execution_plan[0].delegates),
            expected_partitions,
        )
        for i in range(expected_partitions):
            self.assertEqual(
                exec_prog.program.execution_plan[0].delegates[i].id,
                QnnBackend.__name__,
            )

        etrecord_path = "etrecord.bin"
        if self.enable_profile:
            generate_etrecord(etrecord_path, edge_copy, exec_prog)
        # Check numerics
        if (
            assert_output_equal
            or expected_profile_events != -1
            or expected_intermediate_events != -1
        ):
            self.verify_output(
                module,
                sample_inputs,
                exec_prog,
                etrecord_path,
                expected_profile_events,
                expected_intermediate_events,
            )

    def get_qdq_module(
        self,
        module: torch.nn.Module,
        inputs: Tuple[torch.Tensor],
        is_conv_per_channel: bool = True,
        is_linear_per_channel: bool = False,
        custom_quant_annotations: Tuple[Callable] = (),
        quant_dtype: QuantDtype = QuantDtype.use_8a8w,
    ) -> torch.fx.GraphModule:
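        """Return a post-training-quantized module in Q/DQ representation.

        The flow is torch.export.export -> prepare_pt2e -> a single
        calibration pass over ``inputs`` -> convert_pt2e, asserting that
        quantize/dequantize ops actually appear in the converted graph.

        A minimal usage sketch (module and shapes are illustrative
        assumptions):

            qdq_module = self.get_qdq_module(
                torch.nn.Conv2d(3, 8, 3),
                (torch.randn(1, 3, 32, 32),),
                quant_dtype=QuantDtype.use_16a16w,
            )
        """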
        m = torch.export.export(module, inputs).module()

        quantizer = QnnQuantizer()
        quantizer.add_custom_quant_annotations(custom_quant_annotations)
        quantizer.set_per_channel_conv_quant(is_conv_per_channel)
        quantizer.set_per_channel_linear_quant(is_linear_per_channel)
        quantizer.set_quant_config(quant_dtype)

        prepared = prepare_pt2e(m, quantizer)
        prepared(*inputs)
        quantized_module = convert_pt2e(prepared)
        nodes = {node.target for node in quantized_module.graph.nodes}
        q_and_dq = {
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.quantized_decomposed.quantize_per_channel.default,
            torch.ops.quantized_decomposed.dequantize_per_channel.default,
        }
        self.assertTrue(nodes.intersection(q_and_dq))
        return quantized_module

    def get_prepared_qat_module(
        self,
        module: torch.nn.Module,
        inputs: Tuple[torch.Tensor],
        is_conv_per_channel: bool = True,
        is_linear_per_channel: bool = False,
        custom_quant_annotations: Tuple[Callable] = (),
        quant_dtype: QuantDtype = QuantDtype.use_8a8w,
    ) -> torch.fx.GraphModule:
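        """Return a module prepared for quantization-aware training (QAT).

        The module is exported with export_for_training, given a QAT quant
        config (only 8a8w is supported here), prepared with
        prepare_qat_pt2e, and switched to train mode so fake-quantized
        training steps can run on it.
        """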
        m = torch.export.export_for_training(module, inputs).module()

        quantizer = QnnQuantizer()
        quantizer.add_custom_quant_annotations(custom_quant_annotations)
        quantizer.set_per_channel_conv_quant(is_conv_per_channel)
        quantizer.set_per_channel_linear_quant(is_linear_per_channel)

        if quant_dtype == QuantDtype.use_8a8w:
            quantizer.set_quant_config(quant_dtype, is_qat=True)
        else:
            raise RuntimeError(f"QAT is not yet supported for {quant_dtype}")

        prepared = prepare_qat_pt2e(m, quantizer)
        return torch.ao.quantization.move_exported_model_to_train(prepared)

    def get_converted_sgd_trained_module(
        self,
        ori_module: torch.nn.Module,
        prepared: torch.nn.Module,
        inputs: Tuple[torch.Tensor],
    ) -> torch.fx.GraphModule:
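        """Run one SGD step on a QAT-prepared module and convert it.

        The eager-mode output of ``ori_module`` is used as the target of a
        single CrossEntropyLoss/SGD update on ``prepared``, after which the
        module is converted to its quantized Q/DQ form with convert_pt2e.
        """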
        optimizer = torch.optim.SGD(prepared.parameters(), lr=0.0001)
        criterion = torch.nn.CrossEntropyLoss()
        output = prepared(*inputs)
        loss = criterion(output, ori_module(*inputs))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return convert_pt2e(prepared)

    def split_graph(self, graph_module: torch.fx.GraphModule, division: int):
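        """Split the graph into roughly ``division`` partitions.

        A clone op is inserted after every ceil(num_call_function_nodes /
        division)-th call_function node; the inserted clones then serve as
        natural cut points for downstream partitioning (e.g. when listed in
        ``skip_node_op_set``).
        """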
        class SplitGraph(ExportPass):
            """
            Split the graph based on the number of nodes.
            """

            def __init__(self, shares):
                super().__init__()
                self.shares = shares

            def _insert_clone(self, graph_module: torch.fx.GraphModule) -> None:
                num_graph_nodes = 0
                for node in graph_module.graph.nodes:
                    num_graph_nodes += 1 if node.op == "call_function" else 0

                    if num_graph_nodes % self.shares != 0 or node.op != "call_function":
                        continue

                    with graph_module.graph.inserting_after(node):
                        users = list(node.users.keys())
                        inserted_node = graph_module.graph.create_node(
                            "call_function",
                            exir_ops.edge.aten.clone.default,
                            (node,),
                        )
                        inserted_node.meta["val"] = node.meta["val"]
                        if "quant_attrs" in node.meta:
                            inserted_node.meta["quant_attrs"] = node.meta["quant_attrs"]
                        for user in users:
                            user.replace_input_with(node, inserted_node)

            def call(self, graph_module: torch.fx.GraphModule):
                self._insert_clone(graph_module)
                graph_module.recompile()

        num_graph_nodes = 0
        for node in graph_module.graph.nodes:
            num_graph_nodes += 1 if node.op == "call_function" else 0

        # -(a // -b) is ceiling division, so each share holds ceil(nodes / division) nodes
        SplitGraph(-(num_graph_nodes // -division))(graph_module)