xref: /aosp_15_r20/external/executorch/backends/qualcomm/tests/utils.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Qualcomm Innovation Center, Inc.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Workerimport collections
7*523fa7a6SAndroid Build Coastguard Workerimport copy
8*523fa7a6SAndroid Build Coastguard Workerimport os
9*523fa7a6SAndroid Build Coastguard Workerimport subprocess
10*523fa7a6SAndroid Build Coastguard Workerimport tempfile
11*523fa7a6SAndroid Build Coastguard Workerimport unittest
12*523fa7a6SAndroid Build Coastguard Workerfrom typing import Callable, Dict, List, Optional, Tuple
13*523fa7a6SAndroid Build Coastguard Worker
14*523fa7a6SAndroid Build Coastguard Workerimport numpy as np
15*523fa7a6SAndroid Build Coastguard Workerimport torch
16*523fa7a6SAndroid Build Coastguard Worker
17*523fa7a6SAndroid Build Coastguard Workerfrom executorch import exir
18*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner
19*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm.qnn_preprocess import QnnBackend
20*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype
21*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
22*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm.utils.utils import (
23*523fa7a6SAndroid Build Coastguard Worker    capture_program,
24*523fa7a6SAndroid Build Coastguard Worker    get_soc_to_chipset_map,
25*523fa7a6SAndroid Build Coastguard Worker)
26*523fa7a6SAndroid Build Coastguard Workerfrom executorch.devtools import generate_etrecord, Inspector
27*523fa7a6SAndroid Build Coastguard Workerfrom executorch.examples.qualcomm.utils import (
28*523fa7a6SAndroid Build Coastguard Worker    generate_inputs,
29*523fa7a6SAndroid Build Coastguard Worker    make_output_dir,
30*523fa7a6SAndroid Build Coastguard Worker    SimpleADB,
31*523fa7a6SAndroid Build Coastguard Worker)
32*523fa7a6SAndroid Build Coastguard Worker
33*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.backend.backend_api import to_backend
34*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.backend.compile_spec_schema import CompileSpec
35*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.dialects._ops import ops as exir_ops
36*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.pass_base import ExportPass
37*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
38*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.program import ExecutorchProgram, ExecutorchProgramManager
39*523fa7a6SAndroid Build Coastguard Workerfrom torch.ao.quantization.quantize_pt2e import (
40*523fa7a6SAndroid Build Coastguard Worker    convert_pt2e,
41*523fa7a6SAndroid Build Coastguard Worker    prepare_pt2e,
42*523fa7a6SAndroid Build Coastguard Worker    prepare_qat_pt2e,
43*523fa7a6SAndroid Build Coastguard Worker)
44*523fa7a6SAndroid Build Coastguard Worker
45*523fa7a6SAndroid Build Coastguard Worker
def generate_context_binary(
    module: torch.nn.Module,
    inputs: Dict[str, torch.Tensor],
    quantized: bool,
    artifact_dir: str,
):
    """Produce ``{artifact_dir}/model_ctx.bin`` from ``module`` with QNN SDK tools.

    The module is traced with ``torch.jit.trace`` and saved to
    ``{artifact_dir}/jit_module.pt``; then ``qnn-pytorch-converter``,
    ``qnn-model-lib-generator`` and ``qnn-context-binary-generator`` are run
    back to back in a single bash invocation.

    Args:
        module: Module to trace.
        inputs: Mapping of input name to example tensor. The values are used
            (in dict order) as trace inputs; the names and shapes are passed
            to the converter through ``-d``.
        quantized: When true, every input is dumped to
            ``{artifact_dir}/<name>.raw`` and listed in
            ``{artifact_dir}/input_list.txt``, which is then handed to the
            converter via ``--input_list``.
        artifact_dir: Existing directory that receives all generated files.

    Raises:
        AssertionError: If ``QNN_SDK_ROOT`` or ``ANDROID_NDK_ROOT`` is not set,
            or if ``model_ctx.bin`` was not produced. In the latter case the
            message carries the decoded stderr of the tool chain.
    """
    # we also expect clang showing in PATH or context may fail to generate
    qnn_sdk = os.environ.get("QNN_SDK_ROOT", None)
    ndk = os.environ.get("ANDROID_NDK_ROOT", None)
    assert qnn_sdk, "QNN_SDK_ROOT was not found in environment variable"
    assert ndk, "ANDROID_NDK_ROOT was not found in environment variable"

    inputs_tup = tuple(inputs.values())
    jit_module = torch.jit.trace(module, inputs_tup)
    torch.jit.save(jit_module, f"{artifact_dir}/jit_module.pt")

    # input data
    if quantized:
        input_list = []
        for name, data in inputs.items():
            file_name = f"{artifact_dir}/{name}.raw"
            data.detach().numpy().tofile(file_name)
            input_list.append(file_name)

        with open(f"{artifact_dir}/input_list.txt", "w") as f:
            f.write(" ".join(input_list))

    # flow of qnn tools
    target = "x86_64-linux-clang"

    # Join the dimensions explicitly: slicing str(tuple(shape)) would leave a
    # trailing comma for 1-D tensors ("(5,)" -> "5,").
    inputs_str = []
    for name, tensor in inputs.items():
        dims = ",".join(str(dim) for dim in tensor.shape)
        inputs_str.append(f"-d '{name}' {dims}")

    converter_args = [
        f"{qnn_sdk}/bin/{target}/qnn-pytorch-converter",
        f"-i {artifact_dir}/jit_module.pt",
        *inputs_str,
    ]
    if quantized:
        converter_args.append(f"--input_list {artifact_dir}/input_list.txt")
    converter_args += ["--preserve_io", f"-o {artifact_dir}/model.cpp"]

    lib_generator_args = [
        f"{qnn_sdk}/bin/{target}/qnn-model-lib-generator",
        f"-c {artifact_dir}/model.cpp",
        f"-t {target}",
        "-l model",
        f"-o {artifact_dir}/model_libs",
    ]

    context_generator_args = [
        f"{qnn_sdk}/bin/{target}/qnn-context-binary-generator",
        f"--model {artifact_dir}/model_libs/{target}/libmodel.so",
        f"--backend {qnn_sdk}/lib/{target}/libQnnHtp.so",
        "--binary_file model_ctx",
        f"--output_dir {artifact_dir}",
    ]

    # The steps are chained with ';' (not '&&') so each tool runs regardless of
    # the previous exit status; success is judged by the output file below.
    steps = [
        # setup qnn env
        f"source {qnn_sdk}/bin/envsetup.sh",
        " ".join(converter_args),
        " ".join(lib_generator_args),
        " ".join(context_generator_args),
    ]
    result = subprocess.run(
        "; ".join(steps) + ";",
        shell=True,
        executable="/bin/bash",
        capture_output=True,
    )

    # stderr is captured as bytes; decode it so the failure message is readable.
    stderr_text = result.stderr.decode("utf-8", errors="replace")
    assert os.path.isfile(f"{artifact_dir}/model_ctx.bin"), (
        f"failed to generate {artifact_dir}/model_ctx.bin, STDERR=\n{stderr_text}"
    )
109*523fa7a6SAndroid Build Coastguard Worker
110*523fa7a6SAndroid Build Coastguard Worker
class TestQNN(unittest.TestCase):
    """Base test case holding shared configuration and helpers for QNN tests.

    Configuration lives in class attributes; the helper methods lower a module
    to the QNN backend, execute the result, and compare it against the outputs
    of the original module.
    """

    # Relative / absolute tolerances forwarded to torch.allclose when runtime
    # outputs are compared against reference outputs.
    rtol: float = 0
    atol: float = 0
    # Forwarded to SimpleADB as host_id and device_id respectively.
    host: str = ""
    device: str = ""
    # Build directory: used to locate qnn_executor_runner for host runs and
    # passed to SimpleADB as build_path for device runs.
    build_folder: str = ""
    # Forwarded to SimpleADB as soc_model.
    model: QcomChipset = None
    # Compile specs handed to QnnPartitioner when a module is lowered.
    compiler_specs: List[CompileSpec] = None
    # Result of get_soc_to_chipset_map(), evaluated once when the class is created.
    chipset_table = get_soc_to_chipset_map()
    # Forwarded to SimpleADB as error_only.
    error_only = False
    # NOTE(review): ip, port, executorch_root, artifact_dir, image_dataset,
    # pretrained_weight and online_prepare are not read in the part of the
    # file visible here - presumably consumed by test scripts; confirm.
    ip = "localhost"
    port = 8080
    executorch_root: str = ""
    artifact_dir: str = ""
    image_dataset: str = ""
    pretrained_weight: str = ""
    # When true, lower_module_and_test_output writes an ETRecord.
    enable_profile: bool = False
    online_prepare: bool = False
    # String labels for quantization configurations.
    # NOTE(review): presumably matched against QuantDtype member names - confirm.
    use_8a8w: str = "8a8w"
    use_16a16w: str = "16a16w"
    use_16a4w: str = "16a4w"
    # When true, memory planning does not pre-allocate graph inputs/outputs.
    shared_buffer: bool = False
    # When true, verify_output runs qnn_executor_runner locally instead of
    # going through SimpleADB.
    enable_x86_64: bool = False
134*523fa7a6SAndroid Build Coastguard Worker
135*523fa7a6SAndroid Build Coastguard Worker    def _assert_outputs_equal(self, model_output, ref_output):
136*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(len(ref_output) == len(model_output))
137*523fa7a6SAndroid Build Coastguard Worker        for i in range(len(ref_output)):
138*523fa7a6SAndroid Build Coastguard Worker            self.assertTrue(
139*523fa7a6SAndroid Build Coastguard Worker                torch.allclose(
140*523fa7a6SAndroid Build Coastguard Worker                    model_output[i], ref_output[i], atol=self.atol, rtol=self.rtol
141*523fa7a6SAndroid Build Coastguard Worker                ),
142*523fa7a6SAndroid Build Coastguard Worker                msg=f"ref_output:\n{ref_output[i]}\n\nmodel_output:\n{model_output[i]}",
143*523fa7a6SAndroid Build Coastguard Worker            )
144*523fa7a6SAndroid Build Coastguard Worker
145*523fa7a6SAndroid Build Coastguard Worker    def _save_model_and_expected_output(
146*523fa7a6SAndroid Build Coastguard Worker        self,
147*523fa7a6SAndroid Build Coastguard Worker        module: torch.nn.Module,
148*523fa7a6SAndroid Build Coastguard Worker        buffer: exir.ExirExportedProgram,
149*523fa7a6SAndroid Build Coastguard Worker        inputs: Tuple[torch.Tensor],
150*523fa7a6SAndroid Build Coastguard Worker        dir_name: str,
151*523fa7a6SAndroid Build Coastguard Worker    ) -> None:
152*523fa7a6SAndroid Build Coastguard Worker        # Save the input data list to be executed
153*523fa7a6SAndroid Build Coastguard Worker        input_list = ""
154*523fa7a6SAndroid Build Coastguard Worker        for idx, _ in enumerate(inputs):
155*523fa7a6SAndroid Build Coastguard Worker            input_name = f"input_0_{idx}.raw"
156*523fa7a6SAndroid Build Coastguard Worker            input_list += input_name + " "
157*523fa7a6SAndroid Build Coastguard Worker        input_list = input_list.strip() + "\n"
158*523fa7a6SAndroid Build Coastguard Worker
159*523fa7a6SAndroid Build Coastguard Worker        ref_output = module(*inputs)
160*523fa7a6SAndroid Build Coastguard Worker
161*523fa7a6SAndroid Build Coastguard Worker        # Save the expected output data to be verified
162*523fa7a6SAndroid Build Coastguard Worker        ref_outputs = []
163*523fa7a6SAndroid Build Coastguard Worker        if isinstance(ref_output, collections.OrderedDict):
164*523fa7a6SAndroid Build Coastguard Worker            ref_outputs.append(ref_output["out"].detach())
165*523fa7a6SAndroid Build Coastguard Worker        elif isinstance(ref_output, (list, tuple)):
166*523fa7a6SAndroid Build Coastguard Worker            for output in ref_output:
167*523fa7a6SAndroid Build Coastguard Worker                ref_outputs.append(output.detach())
168*523fa7a6SAndroid Build Coastguard Worker        else:
169*523fa7a6SAndroid Build Coastguard Worker            ref_outputs.append(ref_output.detach())
170*523fa7a6SAndroid Build Coastguard Worker
171*523fa7a6SAndroid Build Coastguard Worker        pte_fname = f"{dir_name}/qnn_executorch_test.pte"
172*523fa7a6SAndroid Build Coastguard Worker        with open(pte_fname, "wb") as file:
173*523fa7a6SAndroid Build Coastguard Worker            file.write(buffer)
174*523fa7a6SAndroid Build Coastguard Worker
175*523fa7a6SAndroid Build Coastguard Worker        return input_list, ref_outputs, pte_fname
176*523fa7a6SAndroid Build Coastguard Worker
177*523fa7a6SAndroid Build Coastguard Worker    def verify_output(  # noqa: C901
178*523fa7a6SAndroid Build Coastguard Worker        self,
179*523fa7a6SAndroid Build Coastguard Worker        module: torch.nn.Module,
180*523fa7a6SAndroid Build Coastguard Worker        sample_inputs: Tuple[torch.Tensor],
181*523fa7a6SAndroid Build Coastguard Worker        executorch_prog: ExecutorchProgram | ExecutorchProgramManager,
182*523fa7a6SAndroid Build Coastguard Worker        etrecord_path: str = "etrecord.bin",
183*523fa7a6SAndroid Build Coastguard Worker        expected_profile_events: int = -1,
184*523fa7a6SAndroid Build Coastguard Worker        expected_intermediate_events: int = -1,
185*523fa7a6SAndroid Build Coastguard Worker        method_index: int = 0,
186*523fa7a6SAndroid Build Coastguard Worker    ):
187*523fa7a6SAndroid Build Coastguard Worker        with tempfile.TemporaryDirectory() as tmp_dir:
188*523fa7a6SAndroid Build Coastguard Worker            (
189*523fa7a6SAndroid Build Coastguard Worker                input_list,
190*523fa7a6SAndroid Build Coastguard Worker                ref_outputs,
191*523fa7a6SAndroid Build Coastguard Worker                pte_fname,
192*523fa7a6SAndroid Build Coastguard Worker            ) = self._save_model_and_expected_output(
193*523fa7a6SAndroid Build Coastguard Worker                module,
194*523fa7a6SAndroid Build Coastguard Worker                executorch_prog.buffer,
195*523fa7a6SAndroid Build Coastguard Worker                sample_inputs,
196*523fa7a6SAndroid Build Coastguard Worker                tmp_dir,
197*523fa7a6SAndroid Build Coastguard Worker            )
198*523fa7a6SAndroid Build Coastguard Worker
199*523fa7a6SAndroid Build Coastguard Worker            output_dir = f"{tmp_dir}/outputs"
200*523fa7a6SAndroid Build Coastguard Worker            outputs = []
201*523fa7a6SAndroid Build Coastguard Worker            etdump_path = f"{tmp_dir}/etdump.etdp"
202*523fa7a6SAndroid Build Coastguard Worker            debug_output_path = f"{tmp_dir}/debug_output.bin"
203*523fa7a6SAndroid Build Coastguard Worker
204*523fa7a6SAndroid Build Coastguard Worker            def post_process():
205*523fa7a6SAndroid Build Coastguard Worker                for i, f in enumerate(sorted(os.listdir(output_dir))):
206*523fa7a6SAndroid Build Coastguard Worker                    filename = os.path.join(output_dir, f)
207*523fa7a6SAndroid Build Coastguard Worker                    output = np.fromfile(filename, dtype=ref_outputs[i].numpy().dtype)
208*523fa7a6SAndroid Build Coastguard Worker                    output = torch.from_numpy(output).reshape(ref_outputs[i].shape)
209*523fa7a6SAndroid Build Coastguard Worker                    outputs.append(output)
210*523fa7a6SAndroid Build Coastguard Worker
211*523fa7a6SAndroid Build Coastguard Worker            def validate_profile():
212*523fa7a6SAndroid Build Coastguard Worker                inspector = Inspector(etdump_path=etdump_path, etrecord=etrecord_path)
213*523fa7a6SAndroid Build Coastguard Worker                self.assertTrue(
214*523fa7a6SAndroid Build Coastguard Worker                    len(inspector.to_dataframe().index) == expected_profile_events
215*523fa7a6SAndroid Build Coastguard Worker                )
216*523fa7a6SAndroid Build Coastguard Worker
217*523fa7a6SAndroid Build Coastguard Worker            def validate_intermediate_tensor():
218*523fa7a6SAndroid Build Coastguard Worker                inspector = Inspector(
219*523fa7a6SAndroid Build Coastguard Worker                    etdump_path=etdump_path, debug_buffer_path=debug_output_path
220*523fa7a6SAndroid Build Coastguard Worker                )
221*523fa7a6SAndroid Build Coastguard Worker                for event_block in inspector.event_blocks:
222*523fa7a6SAndroid Build Coastguard Worker                    if event_block.name == "Execute":
223*523fa7a6SAndroid Build Coastguard Worker                        self.assertTrue(
224*523fa7a6SAndroid Build Coastguard Worker                            len(event_block.events) == expected_intermediate_events
225*523fa7a6SAndroid Build Coastguard Worker                        )
226*523fa7a6SAndroid Build Coastguard Worker
227*523fa7a6SAndroid Build Coastguard Worker            if self.enable_x86_64:
228*523fa7a6SAndroid Build Coastguard Worker                generate_inputs(tmp_dir, "input_list.txt", [sample_inputs], input_list)
229*523fa7a6SAndroid Build Coastguard Worker                make_output_dir(output_dir)
230*523fa7a6SAndroid Build Coastguard Worker
231*523fa7a6SAndroid Build Coastguard Worker                target = "x86_64-linux-clang"
232*523fa7a6SAndroid Build Coastguard Worker                qnn_sdk = os.environ.get("QNN_SDK_ROOT", None)
233*523fa7a6SAndroid Build Coastguard Worker                assert qnn_sdk, "QNN_SDK_ROOT was not found in environment variable"
234*523fa7a6SAndroid Build Coastguard Worker
235*523fa7a6SAndroid Build Coastguard Worker                build_folder = self.build_folder
236*523fa7a6SAndroid Build Coastguard Worker                if os.path.isabs(self.build_folder):
237*523fa7a6SAndroid Build Coastguard Worker                    # obey user's opinion
238*523fa7a6SAndroid Build Coastguard Worker                    pass
239*523fa7a6SAndroid Build Coastguard Worker                else:
240*523fa7a6SAndroid Build Coastguard Worker                    # ok, assuming the user give a relative path to cwd
241*523fa7a6SAndroid Build Coastguard Worker                    build_folder = os.path.join(os.getcwd(), self.build_folder)
242*523fa7a6SAndroid Build Coastguard Worker
243*523fa7a6SAndroid Build Coastguard Worker                cmd = [
244*523fa7a6SAndroid Build Coastguard Worker                    # qnn_executor_runner
245*523fa7a6SAndroid Build Coastguard Worker                    f"{build_folder}/examples/qualcomm/executor_runner/qnn_executor_runner",
246*523fa7a6SAndroid Build Coastguard Worker                    "--model_path",
247*523fa7a6SAndroid Build Coastguard Worker                    pte_fname,
248*523fa7a6SAndroid Build Coastguard Worker                    "--input_list_path",
249*523fa7a6SAndroid Build Coastguard Worker                    f"{tmp_dir}/input_list.txt",
250*523fa7a6SAndroid Build Coastguard Worker                    "--output_folder_path",
251*523fa7a6SAndroid Build Coastguard Worker                    output_dir,
252*523fa7a6SAndroid Build Coastguard Worker                    "--method_index",
253*523fa7a6SAndroid Build Coastguard Worker                    str(method_index),
254*523fa7a6SAndroid Build Coastguard Worker                ]
255*523fa7a6SAndroid Build Coastguard Worker                if expected_intermediate_events != -1:
256*523fa7a6SAndroid Build Coastguard Worker                    cmd.append("--dump_intermediate_outputs")
257*523fa7a6SAndroid Build Coastguard Worker
258*523fa7a6SAndroid Build Coastguard Worker                env = dict(os.environ)
259*523fa7a6SAndroid Build Coastguard Worker                env["LD_LIBRARY_PATH"] = f"{qnn_sdk}/lib/{target}/:{build_folder}/lib"
260*523fa7a6SAndroid Build Coastguard Worker                proc = subprocess.run(
261*523fa7a6SAndroid Build Coastguard Worker                    cmd,
262*523fa7a6SAndroid Build Coastguard Worker                    stdout=subprocess.PIPE,
263*523fa7a6SAndroid Build Coastguard Worker                    stderr=subprocess.STDOUT,
264*523fa7a6SAndroid Build Coastguard Worker                    env=env,
265*523fa7a6SAndroid Build Coastguard Worker                    cwd=tmp_dir,
266*523fa7a6SAndroid Build Coastguard Worker                )
267*523fa7a6SAndroid Build Coastguard Worker
268*523fa7a6SAndroid Build Coastguard Worker                self.assertEqual(
269*523fa7a6SAndroid Build Coastguard Worker                    proc.returncode,
270*523fa7a6SAndroid Build Coastguard Worker                    0,
271*523fa7a6SAndroid Build Coastguard Worker                    f"The process running qnn_executorch_runner return {proc.returncode}, "
272*523fa7a6SAndroid Build Coastguard Worker                    "STDOUT=\n"
273*523fa7a6SAndroid Build Coastguard Worker                    f"{proc.stdout.decode('utf-8')}",
274*523fa7a6SAndroid Build Coastguard Worker                )
275*523fa7a6SAndroid Build Coastguard Worker
276*523fa7a6SAndroid Build Coastguard Worker                # Verify the outputs
277*523fa7a6SAndroid Build Coastguard Worker                post_process()
278*523fa7a6SAndroid Build Coastguard Worker                self._assert_outputs_equal(outputs, ref_outputs)
279*523fa7a6SAndroid Build Coastguard Worker
280*523fa7a6SAndroid Build Coastguard Worker                # Verify the etdump
281*523fa7a6SAndroid Build Coastguard Worker                if expected_profile_events != -1:
282*523fa7a6SAndroid Build Coastguard Worker                    validate_profile()
283*523fa7a6SAndroid Build Coastguard Worker
284*523fa7a6SAndroid Build Coastguard Worker                if expected_intermediate_events != -1:
285*523fa7a6SAndroid Build Coastguard Worker                    validate_intermediate_tensor()
286*523fa7a6SAndroid Build Coastguard Worker            else:
287*523fa7a6SAndroid Build Coastguard Worker                adb = SimpleADB(
288*523fa7a6SAndroid Build Coastguard Worker                    qnn_sdk=os.getenv("QNN_SDK_ROOT"),
289*523fa7a6SAndroid Build Coastguard Worker                    build_path=self.build_folder,
290*523fa7a6SAndroid Build Coastguard Worker                    pte_path=pte_fname,
291*523fa7a6SAndroid Build Coastguard Worker                    workspace="/data/local/tmp/qnn_executorch_test",
292*523fa7a6SAndroid Build Coastguard Worker                    device_id=self.device,
293*523fa7a6SAndroid Build Coastguard Worker                    host_id=self.host,
294*523fa7a6SAndroid Build Coastguard Worker                    soc_model=self.model,
295*523fa7a6SAndroid Build Coastguard Worker                    error_only=self.error_only,
296*523fa7a6SAndroid Build Coastguard Worker                    dump_intermediate_outputs=(
297*523fa7a6SAndroid Build Coastguard Worker                        True if expected_intermediate_events != -1 else False
298*523fa7a6SAndroid Build Coastguard Worker                    ),
299*523fa7a6SAndroid Build Coastguard Worker                )
300*523fa7a6SAndroid Build Coastguard Worker                adb.push(inputs=[sample_inputs], input_list=input_list)
301*523fa7a6SAndroid Build Coastguard Worker                adb.execute(method_index=method_index)
302*523fa7a6SAndroid Build Coastguard Worker                adb.pull(output_path=tmp_dir, callback=post_process)
303*523fa7a6SAndroid Build Coastguard Worker                self._assert_outputs_equal(outputs, ref_outputs)
304*523fa7a6SAndroid Build Coastguard Worker
305*523fa7a6SAndroid Build Coastguard Worker                if expected_profile_events != -1:
306*523fa7a6SAndroid Build Coastguard Worker                    adb.pull_etdump(etdump_path, callback=validate_profile)
307*523fa7a6SAndroid Build Coastguard Worker
308*523fa7a6SAndroid Build Coastguard Worker                if expected_intermediate_events != -1:
309*523fa7a6SAndroid Build Coastguard Worker                    adb.pull_debug_output(
310*523fa7a6SAndroid Build Coastguard Worker                        etdump_path,
311*523fa7a6SAndroid Build Coastguard Worker                        debug_output_path,
312*523fa7a6SAndroid Build Coastguard Worker                        callback=validate_intermediate_tensor,
313*523fa7a6SAndroid Build Coastguard Worker                    )
314*523fa7a6SAndroid Build Coastguard Worker
315*523fa7a6SAndroid Build Coastguard Worker    def lower_module_and_test_output(
316*523fa7a6SAndroid Build Coastguard Worker        self,
317*523fa7a6SAndroid Build Coastguard Worker        module: torch.nn.Module,
318*523fa7a6SAndroid Build Coastguard Worker        sample_inputs: Tuple[torch.Tensor],
319*523fa7a6SAndroid Build Coastguard Worker        expected_partitions: int = 1,
320*523fa7a6SAndroid Build Coastguard Worker        expected_profile_events: int = -1,
321*523fa7a6SAndroid Build Coastguard Worker        expected_intermediate_events: int = -1,
322*523fa7a6SAndroid Build Coastguard Worker        assert_output_equal: bool = True,
323*523fa7a6SAndroid Build Coastguard Worker        skip_node_id_set: set = None,
324*523fa7a6SAndroid Build Coastguard Worker        skip_node_op_set: set = None,
325*523fa7a6SAndroid Build Coastguard Worker    ):
326*523fa7a6SAndroid Build Coastguard Worker        qnn_partitioner = QnnPartitioner(
327*523fa7a6SAndroid Build Coastguard Worker            self.compiler_specs, skip_node_id_set, skip_node_op_set
328*523fa7a6SAndroid Build Coastguard Worker        )
329*523fa7a6SAndroid Build Coastguard Worker        delegated_program = capture_program(module, sample_inputs)
330*523fa7a6SAndroid Build Coastguard Worker
331*523fa7a6SAndroid Build Coastguard Worker        # this is needed for the ETRecord as lowering modifies the graph in-place
332*523fa7a6SAndroid Build Coastguard Worker        edge_copy = copy.deepcopy(delegated_program)
333*523fa7a6SAndroid Build Coastguard Worker
334*523fa7a6SAndroid Build Coastguard Worker        delegated_program.exported_program = to_backend(
335*523fa7a6SAndroid Build Coastguard Worker            delegated_program.exported_program, qnn_partitioner
336*523fa7a6SAndroid Build Coastguard Worker        )
337*523fa7a6SAndroid Build Coastguard Worker        exec_prog = delegated_program.to_executorch(
338*523fa7a6SAndroid Build Coastguard Worker            exir.ExecutorchBackendConfig(
339*523fa7a6SAndroid Build Coastguard Worker                # For shared buffer, user must pass the memory address
340*523fa7a6SAndroid Build Coastguard Worker                # which is allocated by RPC memory to executor runner.
341*523fa7a6SAndroid Build Coastguard Worker                # Therefore, won't want to pre-allocate
342*523fa7a6SAndroid Build Coastguard Worker                # by memory manager in runtime.
343*523fa7a6SAndroid Build Coastguard Worker                memory_planning_pass=MemoryPlanningPass(
344*523fa7a6SAndroid Build Coastguard Worker                    alloc_graph_input=not self.shared_buffer,
345*523fa7a6SAndroid Build Coastguard Worker                    alloc_graph_output=not self.shared_buffer,
346*523fa7a6SAndroid Build Coastguard Worker                ),
347*523fa7a6SAndroid Build Coastguard Worker            )
348*523fa7a6SAndroid Build Coastguard Worker        )
349*523fa7a6SAndroid Build Coastguard Worker
350*523fa7a6SAndroid Build Coastguard Worker        # Assert the backend name is qnn
351*523fa7a6SAndroid Build Coastguard Worker        self.assertEqual(
352*523fa7a6SAndroid Build Coastguard Worker            len(exec_prog.program.execution_plan[0].delegates),
353*523fa7a6SAndroid Build Coastguard Worker            expected_partitions,
354*523fa7a6SAndroid Build Coastguard Worker        )
355*523fa7a6SAndroid Build Coastguard Worker        for i in range(expected_partitions):
356*523fa7a6SAndroid Build Coastguard Worker            self.assertEqual(
357*523fa7a6SAndroid Build Coastguard Worker                exec_prog.program.execution_plan[0].delegates[i].id,
358*523fa7a6SAndroid Build Coastguard Worker                QnnBackend.__name__,
359*523fa7a6SAndroid Build Coastguard Worker            )
360*523fa7a6SAndroid Build Coastguard Worker
361*523fa7a6SAndroid Build Coastguard Worker        etrecord_path = "etrecord.bin"
362*523fa7a6SAndroid Build Coastguard Worker        if self.enable_profile:
363*523fa7a6SAndroid Build Coastguard Worker            generate_etrecord(etrecord_path, edge_copy, exec_prog)
364*523fa7a6SAndroid Build Coastguard Worker        # Check numerics
365*523fa7a6SAndroid Build Coastguard Worker        if (
366*523fa7a6SAndroid Build Coastguard Worker            assert_output_equal
367*523fa7a6SAndroid Build Coastguard Worker            or expected_profile_events != -1
368*523fa7a6SAndroid Build Coastguard Worker            or expected_intermediate_events != -1
369*523fa7a6SAndroid Build Coastguard Worker        ):
370*523fa7a6SAndroid Build Coastguard Worker            self.verify_output(
371*523fa7a6SAndroid Build Coastguard Worker                module,
372*523fa7a6SAndroid Build Coastguard Worker                sample_inputs,
373*523fa7a6SAndroid Build Coastguard Worker                exec_prog,
374*523fa7a6SAndroid Build Coastguard Worker                etrecord_path,
375*523fa7a6SAndroid Build Coastguard Worker                expected_profile_events,
376*523fa7a6SAndroid Build Coastguard Worker                expected_intermediate_events,
377*523fa7a6SAndroid Build Coastguard Worker            )
378*523fa7a6SAndroid Build Coastguard Worker
379*523fa7a6SAndroid Build Coastguard Worker    def get_qdq_module(
380*523fa7a6SAndroid Build Coastguard Worker        self,
381*523fa7a6SAndroid Build Coastguard Worker        module: torch.nn.Module,
382*523fa7a6SAndroid Build Coastguard Worker        inputs: Tuple[torch.Tensor],
383*523fa7a6SAndroid Build Coastguard Worker        is_conv_per_channel: Optional[bool] = True,
384*523fa7a6SAndroid Build Coastguard Worker        is_linear_per_channel: Optional[bool] = False,
385*523fa7a6SAndroid Build Coastguard Worker        custom_quant_annotations: Tuple[Callable] = (),
386*523fa7a6SAndroid Build Coastguard Worker        quant_dtype: QuantDtype = QuantDtype.use_8a8w,
387*523fa7a6SAndroid Build Coastguard Worker    ) -> torch.fx.GraphModule:
388*523fa7a6SAndroid Build Coastguard Worker        m = torch.export.export(module, inputs).module()
389*523fa7a6SAndroid Build Coastguard Worker
390*523fa7a6SAndroid Build Coastguard Worker        quantizer = QnnQuantizer()
391*523fa7a6SAndroid Build Coastguard Worker        quantizer.add_custom_quant_annotations(custom_quant_annotations)
392*523fa7a6SAndroid Build Coastguard Worker        quantizer.set_per_channel_conv_quant(is_conv_per_channel)
393*523fa7a6SAndroid Build Coastguard Worker        quantizer.set_per_channel_linear_quant(is_linear_per_channel)
394*523fa7a6SAndroid Build Coastguard Worker        quantizer.set_quant_config(quant_dtype)
395*523fa7a6SAndroid Build Coastguard Worker
396*523fa7a6SAndroid Build Coastguard Worker        prepared = prepare_pt2e(m, quantizer)
397*523fa7a6SAndroid Build Coastguard Worker        prepared(*inputs)
398*523fa7a6SAndroid Build Coastguard Worker        quantized_module = convert_pt2e(prepared)
399*523fa7a6SAndroid Build Coastguard Worker        nodes = {node.target for node in quantized_module.graph.nodes}
400*523fa7a6SAndroid Build Coastguard Worker        q_and_dq = {
401*523fa7a6SAndroid Build Coastguard Worker            torch.ops.quantized_decomposed.quantize_per_tensor.default,
402*523fa7a6SAndroid Build Coastguard Worker            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
403*523fa7a6SAndroid Build Coastguard Worker            torch.ops.quantized_decomposed.quantize_per_channel.default,
404*523fa7a6SAndroid Build Coastguard Worker            torch.ops.quantized_decomposed.dequantize_per_channel.default,
405*523fa7a6SAndroid Build Coastguard Worker        }
406*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(nodes.intersection(q_and_dq))
407*523fa7a6SAndroid Build Coastguard Worker        return quantized_module
408*523fa7a6SAndroid Build Coastguard Worker
409*523fa7a6SAndroid Build Coastguard Worker    def get_prepared_qat_module(
410*523fa7a6SAndroid Build Coastguard Worker        self,
411*523fa7a6SAndroid Build Coastguard Worker        module: torch.nn.Module,
412*523fa7a6SAndroid Build Coastguard Worker        inputs: Tuple[torch.Tensor],
413*523fa7a6SAndroid Build Coastguard Worker        is_conv_per_channel: Optional[bool] = True,
414*523fa7a6SAndroid Build Coastguard Worker        is_linear_per_channel: Optional[bool] = False,
415*523fa7a6SAndroid Build Coastguard Worker        custom_quant_annotations: Tuple[Callable] = (),
416*523fa7a6SAndroid Build Coastguard Worker        quant_dtype: QuantDtype = QuantDtype.use_8a8w,
417*523fa7a6SAndroid Build Coastguard Worker    ) -> torch.fx.GraphModule:
418*523fa7a6SAndroid Build Coastguard Worker        m = torch.export.export_for_training(module, inputs).module()
419*523fa7a6SAndroid Build Coastguard Worker
420*523fa7a6SAndroid Build Coastguard Worker        quantizer = QnnQuantizer()
421*523fa7a6SAndroid Build Coastguard Worker        quantizer.add_custom_quant_annotations(custom_quant_annotations)
422*523fa7a6SAndroid Build Coastguard Worker        quantizer.set_per_channel_conv_quant(is_conv_per_channel)
423*523fa7a6SAndroid Build Coastguard Worker        quantizer.set_per_channel_linear_quant(is_linear_per_channel)
424*523fa7a6SAndroid Build Coastguard Worker
425*523fa7a6SAndroid Build Coastguard Worker        if quant_dtype == QuantDtype.use_8a8w:
426*523fa7a6SAndroid Build Coastguard Worker            quantizer.set_quant_config(quant_dtype, is_qat=True)
427*523fa7a6SAndroid Build Coastguard Worker        else:
428*523fa7a6SAndroid Build Coastguard Worker            raise RuntimeError("Shuld not be here")
429*523fa7a6SAndroid Build Coastguard Worker
430*523fa7a6SAndroid Build Coastguard Worker        prepared = prepare_qat_pt2e(m, quantizer)
431*523fa7a6SAndroid Build Coastguard Worker        return torch.ao.quantization.move_exported_model_to_train(prepared)
432*523fa7a6SAndroid Build Coastguard Worker
433*523fa7a6SAndroid Build Coastguard Worker    def get_converted_sgd_trained_module(
434*523fa7a6SAndroid Build Coastguard Worker        self,
435*523fa7a6SAndroid Build Coastguard Worker        ori_module: torch.nn.Module,
436*523fa7a6SAndroid Build Coastguard Worker        prepared: torch.nn.Module,
437*523fa7a6SAndroid Build Coastguard Worker        inputs: Tuple[torch.Tensor],
438*523fa7a6SAndroid Build Coastguard Worker    ) -> torch.fx.GraphModule:
439*523fa7a6SAndroid Build Coastguard Worker        optimizer = torch.optim.SGD(prepared.parameters(), lr=0.0001)
440*523fa7a6SAndroid Build Coastguard Worker        criterion = torch.nn.CrossEntropyLoss()
441*523fa7a6SAndroid Build Coastguard Worker        output = prepared(*inputs)
442*523fa7a6SAndroid Build Coastguard Worker        loss = criterion(output, ori_module(*inputs))
443*523fa7a6SAndroid Build Coastguard Worker        optimizer.zero_grad()
444*523fa7a6SAndroid Build Coastguard Worker        loss.backward()
445*523fa7a6SAndroid Build Coastguard Worker        optimizer.step()
446*523fa7a6SAndroid Build Coastguard Worker        return torch.ao.quantization.quantize_pt2e.convert_pt2e(prepared)
447*523fa7a6SAndroid Build Coastguard Worker
448*523fa7a6SAndroid Build Coastguard Worker    def split_graph(self, graph_module: torch.fx.GraphModule, division: int):
449*523fa7a6SAndroid Build Coastguard Worker        class SplitGraph(ExportPass):
450*523fa7a6SAndroid Build Coastguard Worker            """
451*523fa7a6SAndroid Build Coastguard Worker            Split graph based on number of nodes.
452*523fa7a6SAndroid Build Coastguard Worker            """
453*523fa7a6SAndroid Build Coastguard Worker
454*523fa7a6SAndroid Build Coastguard Worker            def __init__(self, shares):
455*523fa7a6SAndroid Build Coastguard Worker                super().__init__()
456*523fa7a6SAndroid Build Coastguard Worker                self.shares = shares
457*523fa7a6SAndroid Build Coastguard Worker
458*523fa7a6SAndroid Build Coastguard Worker            def _insert_clone(
459*523fa7a6SAndroid Build Coastguard Worker                self, graph_module: torch.fx.GraphModule
460*523fa7a6SAndroid Build Coastguard Worker            ) -> torch.fx.GraphModule:
461*523fa7a6SAndroid Build Coastguard Worker                num_graph_nodes = 0
462*523fa7a6SAndroid Build Coastguard Worker                for node in graph_module.graph.nodes:
463*523fa7a6SAndroid Build Coastguard Worker                    num_graph_nodes += 1 if node.op == "call_function" else 0
464*523fa7a6SAndroid Build Coastguard Worker
465*523fa7a6SAndroid Build Coastguard Worker                    if num_graph_nodes % self.shares != 0 or node.op != "call_function":
466*523fa7a6SAndroid Build Coastguard Worker                        continue
467*523fa7a6SAndroid Build Coastguard Worker
468*523fa7a6SAndroid Build Coastguard Worker                    with graph_module.graph.inserting_after(node):
469*523fa7a6SAndroid Build Coastguard Worker                        users = list(node.users.keys())
470*523fa7a6SAndroid Build Coastguard Worker                        inserted_node = graph_module.graph.create_node(
471*523fa7a6SAndroid Build Coastguard Worker                            "call_function",
472*523fa7a6SAndroid Build Coastguard Worker                            exir_ops.edge.aten.clone.default,
473*523fa7a6SAndroid Build Coastguard Worker                            (node,),
474*523fa7a6SAndroid Build Coastguard Worker                        )
475*523fa7a6SAndroid Build Coastguard Worker                        inserted_node.meta["val"] = node.meta["val"]
476*523fa7a6SAndroid Build Coastguard Worker                        if "quant_attrs" in node.meta:
477*523fa7a6SAndroid Build Coastguard Worker                            inserted_node.meta["quant_attrs"] = node.meta["quant_attrs"]
478*523fa7a6SAndroid Build Coastguard Worker                        for user in users:
479*523fa7a6SAndroid Build Coastguard Worker                            user.replace_input_with(node, inserted_node)
480*523fa7a6SAndroid Build Coastguard Worker
481*523fa7a6SAndroid Build Coastguard Worker            def call(self, graph_module: torch.fx.GraphModule):
482*523fa7a6SAndroid Build Coastguard Worker                self._insert_clone(graph_module)
483*523fa7a6SAndroid Build Coastguard Worker                graph_module.recompile()
484*523fa7a6SAndroid Build Coastguard Worker
485*523fa7a6SAndroid Build Coastguard Worker        num_graph_nodes = 0
486*523fa7a6SAndroid Build Coastguard Worker        for node in graph_module.graph.nodes:
487*523fa7a6SAndroid Build Coastguard Worker            num_graph_nodes += 1 if node.op == "call_function" else 0
488*523fa7a6SAndroid Build Coastguard Worker
489*523fa7a6SAndroid Build Coastguard Worker        SplitGraph(-(num_graph_nodes // -division))(graph_module)
490