# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import collections
import copy
import os
import subprocess
import tempfile
import unittest
from typing import Callable, Dict, List, Optional, Tuple

import numpy as np
import torch

from executorch import exir
from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner
from executorch.backends.qualcomm.qnn_preprocess import QnnBackend
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype
from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
from executorch.backends.qualcomm.utils.utils import (
    capture_program,
    get_soc_to_chipset_map,
)
from executorch.devtools import generate_etrecord, Inspector
from executorch.examples.qualcomm.utils import (
    generate_inputs,
    make_output_dir,
    SimpleADB,
)

from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
from executorch.exir.program import ExecutorchProgram, ExecutorchProgramManager
from torch.ao.quantization.quantize_pt2e import (
    convert_pt2e,
    prepare_pt2e,
    prepare_qat_pt2e,
)


def generate_context_binary(
    module: torch.nn.Module,
    inputs: Dict[str, torch.Tensor],
    quantized: bool,
    artifact_dir: str,
):
    """Build a QNN context binary for ``module`` with the QNN SDK x86_64 tools.

    The module is traced with TorchScript and saved under ``artifact_dir``;
    then ``qnn-pytorch-converter``, ``qnn-model-lib-generator`` and
    ``qnn-context-binary-generator`` are chained in a single bash invocation.

    Args:
        module: Module to trace with ``torch.jit.trace``.
        inputs: Mapping from input name to example tensor. The values are
            used (in mapping order) as trace inputs and the keys are passed
            to the converter through ``-d``.
        quantized: When true, every input is dumped to ``<name>.raw`` and an
            ``input_list.txt`` is handed to the converter.
        artifact_dir: Directory receiving every intermediate and final file.

    Raises:
        AssertionError: If ``QNN_SDK_ROOT`` or ``ANDROID_NDK_ROOT`` is unset,
            or if ``model_ctx.bin`` was not produced. In the latter case the
            message carries the captured stderr of the tool chain.
    """
    # we also expect clang showing in PATH or context may fail to generate
    qnn_sdk = os.environ.get("QNN_SDK_ROOT", None)
    # ANDROID_NDK_ROOT is only checked for presence; it is not referenced
    # again in this function.
    ndk = os.environ.get("ANDROID_NDK_ROOT", None)
    assert qnn_sdk, "QNN_SDK_ROOT was not found in environment variable"
    assert ndk, "ANDROID_NDK_ROOT was not found in environment variable"

    inputs_tup = tuple(inputs.values())
    jit_module = torch.jit.trace(module, inputs_tup)
    torch.jit.save(jit_module, f"{artifact_dir}/jit_module.pt")

    # input data
    if quantized:
        input_list = []
        for name, data in inputs.items():
            file_name = f"{artifact_dir}/{name}.raw"
            data.detach().numpy().tofile(file_name)
            input_list.append(file_name)

        with open(f"{artifact_dir}/input_list.txt", "w") as f:
            f.write(" ".join(input_list))

    # flow of qnn tools
    target = "x86_64-linux-clang"
    # Each entry is "-d '<name>' <dims>", where <dims> is the tuple repr of
    # the shape with spaces and the surrounding parentheses removed.
    inputs_str = [
        f"-d '{k}' {str(tuple(v.shape)).replace(' ', '')[1:-1]}"
        for k, v in inputs.items()
    ]
    cmds = [
        # setup qnn env
        f"source {qnn_sdk}/bin/envsetup.sh;"
        # qnn-pytorch-converter
        f"{qnn_sdk}/bin/{target}/qnn-pytorch-converter",
        f"-i {artifact_dir}/jit_module.pt",
        *inputs_str,
        f"--input_list {artifact_dir}/input_list.txt" if quantized else "",
        "--preserve_io",
        f"-o {artifact_dir}/model.cpp;",
        # qnn-model-lib-generator
        f"{qnn_sdk}/bin/{target}/qnn-model-lib-generator",
        f"-c {artifact_dir}/model.cpp",
        f"-t {target}",
        "-l model",
        f"-o {artifact_dir}/model_libs;",
        # qnn-context-binary-generator
        f"{qnn_sdk}/bin/{target}/qnn-context-binary-generator",
        f"--model {artifact_dir}/model_libs/{target}/libmodel.so",
        f"--backend {qnn_sdk}/lib/{target}/libQnnHtp.so",
        "--binary_file model_ctx",
        f"--output_dir {artifact_dir};",
    ]
    # A shell is required here because the command line sources envsetup.sh
    # and chains several tools with ';'.
    result = subprocess.run(
        " ".join(cmds),
        shell=True,
        executable="/bin/bash",
        capture_output=True,
    )
    # Put the captured stderr into the assertion message itself so a failure
    # is diagnosable from the test report.
    assert os.path.isfile(
        f"{artifact_dir}/model_ctx.bin"
    ), result.stderr.decode("utf-8", errors="replace")


class TestQNN(unittest.TestCase):
    """Shared helpers for QNN backend tests.

    The class attributes below are configuration knobs read by the helper
    methods; they are expected to be assigned before the tests run.
    """

    # Tolerances forwarded to torch.allclose in _assert_outputs_equal.
    rtol: float = 0
    atol: float = 0
    # Forwarded to SimpleADB as host_id / device_id.
    host: str = ""
    device: str = ""
    # Build directory containing qnn_executor_runner; a relative path is
    # resolved against the current working directory on the x86_64 path.
    build_folder: str = ""
    # Forwarded to SimpleADB as soc_model.
    model: Optional[QcomChipset] = None
    # Compile specs handed to QnnPartitioner.
    compiler_specs: Optional[List[CompileSpec]] = None
    chipset_table = get_soc_to_chipset_map()
    # Forwarded to SimpleADB as error_only.
    error_only = False
    ip = "localhost"
    port = 8080
    executorch_root: str = ""
    artifact_dir: str = ""
    image_dataset: str = ""
    pretrained_weight: str = ""
    # When true, lower_module_and_test_output writes an ETRecord.
    enable_profile: bool = False
    online_prepare: bool = False
    use_8a8w: str = "8a8w"
    use_16a16w: str = "16a16w"
    use_16a4w: str = "16a4w"
    # When true, graph inputs/outputs are not pre-allocated by memory planning.
    shared_buffer: bool = False
    # When true, verify_output runs the runner on the host instead of via adb.
    enable_x86_64: bool = False

    def _assert_outputs_equal(self, model_output, ref_output):
        """Assert both output sequences have equal length and close values."""
        self.assertEqual(len(ref_output), len(model_output))
        for actual, expected in zip(model_output, ref_output):
            self.assertTrue(
                torch.allclose(actual, expected, atol=self.atol, rtol=self.rtol),
                msg=f"ref_output:\n{expected}\n\nmodel_output:\n{actual}",
            )

    def _save_model_and_expected_output(
        self,
        module: torch.nn.Module,
        buffer: bytes,
        inputs: Tuple[torch.Tensor],
        dir_name: str,
    ) -> Tuple[str, List[torch.Tensor], str]:
        """Write the program to disk and compute the reference outputs.

        Args:
            module: Eager module used to produce the reference outputs.
            buffer: Serialized program, written to
                ``<dir_name>/qnn_executorch_test.pte`` in binary mode.
            inputs: Positional inputs for ``module``.
            dir_name: Directory that receives the ``.pte`` file.

        Returns:
            A tuple ``(input_list, ref_outputs, pte_fname)``: the
            newline-terminated, space-separated list of ``input_0_<idx>.raw``
            names, the detached reference tensors, and the ``.pte`` path.
        """
        # Save the input data list to be executed
        input_list = (
            " ".join(f"input_0_{idx}.raw" for idx, _ in enumerate(inputs)) + "\n"
        )

        ref_output = module(*inputs)

        # Save the expected output data to be verified
        if isinstance(ref_output, collections.OrderedDict):
            # Only the "out" entry of an OrderedDict result is compared.
            ref_outputs = [ref_output["out"].detach()]
        elif isinstance(ref_output, (list, tuple)):
            ref_outputs = [output.detach() for output in ref_output]
        else:
            ref_outputs = [ref_output.detach()]

        pte_fname = f"{dir_name}/qnn_executorch_test.pte"
        with open(pte_fname, "wb") as file:
            file.write(buffer)

        return input_list, ref_outputs, pte_fname

    def verify_output(  # noqa: C901
        self,
        module: torch.nn.Module,
        sample_inputs: Tuple[torch.Tensor],
        executorch_prog: ExecutorchProgram | ExecutorchProgramManager,
        etrecord_path: str = "etrecord.bin",
        expected_profile_events: int = -1,
        expected_intermediate_events: int = -1,
        method_index: int = 0,
    ):
        """Run the lowered program and compare it with the eager module.

        Depending on ``self.enable_x86_64`` the program is executed either by
        the host ``qnn_executor_runner`` binary or through ``SimpleADB``.

        Args:
            module: Eager module producing the reference outputs.
            sample_inputs: Inputs fed to both the module and the runner.
            executorch_prog: Program whose ``buffer`` is written to disk.
            etrecord_path: ETRecord path given to the Inspector when profile
                events are validated.
            expected_profile_events: Expected number of rows in the Inspector
                dataframe; ``-1`` skips the check.
            expected_intermediate_events: Expected number of events in the
                "Execute" event block; ``-1`` skips the check and disables
                dumping of intermediate outputs.
            method_index: Method index passed to the runner.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            (
                input_list,
                ref_outputs,
                pte_fname,
            ) = self._save_model_and_expected_output(
                module,
                executorch_prog.buffer,
                sample_inputs,
                tmp_dir,
            )

            output_dir = f"{tmp_dir}/outputs"
            outputs = []
            etdump_path = f"{tmp_dir}/etdump.etdp"
            debug_output_path = f"{tmp_dir}/debug_output.bin"
            dump_intermediate = expected_intermediate_events != -1

            def post_process():
                # Load each raw output file using the dtype and shape of the
                # reference tensor at the same position.
                # NOTE(review): the ordering is lexicographic; with ten or
                # more output files this may not match the runner's output
                # index order - confirm against the runner's file naming.
                for i, f in enumerate(sorted(os.listdir(output_dir))):
                    filename = os.path.join(output_dir, f)
                    output = np.fromfile(filename, dtype=ref_outputs[i].numpy().dtype)
                    output = torch.from_numpy(output).reshape(ref_outputs[i].shape)
                    outputs.append(output)

            def validate_profile():
                inspector = Inspector(etdump_path=etdump_path, etrecord=etrecord_path)
                self.assertEqual(
                    len(inspector.to_dataframe().index), expected_profile_events
                )

            def validate_intermediate_tensor():
                inspector = Inspector(
                    etdump_path=etdump_path, debug_buffer_path=debug_output_path
                )
                for event_block in inspector.event_blocks:
                    if event_block.name == "Execute":
                        self.assertEqual(
                            len(event_block.events), expected_intermediate_events
                        )

            if self.enable_x86_64:
                generate_inputs(tmp_dir, "input_list.txt", [sample_inputs], input_list)
                make_output_dir(output_dir)

                target = "x86_64-linux-clang"
                qnn_sdk = os.environ.get("QNN_SDK_ROOT", None)
                assert qnn_sdk, "QNN_SDK_ROOT was not found in environment variable"

                build_folder = self.build_folder
                if not os.path.isabs(build_folder):
                    # ok, assuming the user give a relative path to cwd
                    build_folder = os.path.join(os.getcwd(), self.build_folder)

                cmd = [
                    # qnn_executor_runner
                    f"{build_folder}/examples/qualcomm/executor_runner/qnn_executor_runner",
                    "--model_path",
                    pte_fname,
                    "--input_list_path",
                    f"{tmp_dir}/input_list.txt",
                    "--output_folder_path",
                    output_dir,
                    "--method_index",
                    str(method_index),
                ]
                if dump_intermediate:
                    cmd.append("--dump_intermediate_outputs")

                env = dict(os.environ)
                env["LD_LIBRARY_PATH"] = f"{qnn_sdk}/lib/{target}/:{build_folder}/lib"
                proc = subprocess.run(
                    cmd,
                    stdout=subprocess.PIPE,
                    stderr=subprocess.STDOUT,
                    env=env,
                    cwd=tmp_dir,
                )

                self.assertEqual(
                    proc.returncode,
                    0,
                    f"The process running qnn_executorch_runner return {proc.returncode}, "
                    "STDOUT=\n"
                    f"{proc.stdout.decode('utf-8')}",
                )

                # Verify the outputs
                post_process()
                self._assert_outputs_equal(outputs, ref_outputs)

                # Verify the etdump
                if expected_profile_events != -1:
                    validate_profile()

                if dump_intermediate:
                    validate_intermediate_tensor()
            else:
                adb = SimpleADB(
                    qnn_sdk=os.getenv("QNN_SDK_ROOT"),
                    build_path=self.build_folder,
                    pte_path=pte_fname,
                    workspace="/data/local/tmp/qnn_executorch_test",
                    device_id=self.device,
                    host_id=self.host,
                    soc_model=self.model,
                    error_only=self.error_only,
                    dump_intermediate_outputs=dump_intermediate,
                )
                adb.push(inputs=[sample_inputs], input_list=input_list)
                adb.execute(method_index=method_index)
                adb.pull(output_path=tmp_dir, callback=post_process)
                self._assert_outputs_equal(outputs, ref_outputs)

                if expected_profile_events != -1:
                    adb.pull_etdump(etdump_path, callback=validate_profile)

                if dump_intermediate:
                    adb.pull_debug_output(
                        etdump_path,
                        debug_output_path,
                        callback=validate_intermediate_tensor,
                    )

    def lower_module_and_test_output(
        self,
        module: torch.nn.Module,
        sample_inputs: Tuple[torch.Tensor],
        expected_partitions: int = 1,
        expected_profile_events: int = -1,
        expected_intermediate_events: int = -1,
        assert_output_equal: bool = True,
        skip_node_id_set: Optional[set] = None,
        skip_node_op_set: Optional[set] = None,
    ):
        """Lower ``module`` to the QNN backend and optionally verify outputs.

        Args:
            module: Module to capture and lower.
            sample_inputs: Example inputs for capture and for verification.
            expected_partitions: Expected number of delegates in the first
                execution plan; each one must be identified as the QNN backend.
            expected_profile_events: Forwarded to ``verify_output``.
            expected_intermediate_events: Forwarded to ``verify_output``.
            assert_output_equal: When false, ``verify_output`` only runs if
                one of the two expected-event counts is set.
            skip_node_id_set: Forwarded to ``QnnPartitioner``.
            skip_node_op_set: Forwarded to ``QnnPartitioner``.
        """
        qnn_partitioner = QnnPartitioner(
            self.compiler_specs, skip_node_id_set, skip_node_op_set
        )
        delegated_program = capture_program(module, sample_inputs)

        # this is needed for the ETRecord as lowering modifies the graph in-place
        edge_copy = copy.deepcopy(delegated_program)

        delegated_program.exported_program = to_backend(
            delegated_program.exported_program, qnn_partitioner
        )
        exec_prog = delegated_program.to_executorch(
            exir.ExecutorchBackendConfig(
                # For shared buffer, user must pass the memory address
                # which is allocated by RPC memory to executor runner.
                # Therefore, won't want to pre-allocate
                # by memory manager in runtime.
                memory_planning_pass=MemoryPlanningPass(
                    alloc_graph_input=not self.shared_buffer,
                    alloc_graph_output=not self.shared_buffer,
                ),
            )
        )

        # Assert the backend name is qnn
        delegates = exec_prog.program.execution_plan[0].delegates
        self.assertEqual(len(delegates), expected_partitions)
        for delegate in delegates:
            self.assertEqual(delegate.id, QnnBackend.__name__)

        etrecord_path = "etrecord.bin"
        if self.enable_profile:
            generate_etrecord(etrecord_path, edge_copy, exec_prog)
        # Check numerics
        if (
            assert_output_equal
            or expected_profile_events != -1
            or expected_intermediate_events != -1
        ):
            self.verify_output(
                module,
                sample_inputs,
                exec_prog,
                etrecord_path,
                expected_profile_events,
                expected_intermediate_events,
            )

    def get_qdq_module(
        self,
        module: torch.nn.Module,
        inputs: Tuple[torch.Tensor],
        is_conv_per_channel: Optional[bool] = True,
        is_linear_per_channel: Optional[bool] = False,
        custom_quant_annotations: Tuple[Callable] = (),
        quant_dtype: QuantDtype = QuantDtype.use_8a8w,
    ) -> torch.fx.GraphModule:
        """Export, calibrate once with ``inputs`` and convert ``module``.

        Asserts that the converted graph contains at least one decomposed
        quantize/dequantize (per-tensor or per-channel) node.

        Returns:
            The converted graph module.
        """
        m = torch.export.export(module, inputs).module()

        quantizer = QnnQuantizer()
        quantizer.add_custom_quant_annotations(custom_quant_annotations)
        quantizer.set_per_channel_conv_quant(is_conv_per_channel)
        quantizer.set_per_channel_linear_quant(is_linear_per_channel)
        quantizer.set_quant_config(quant_dtype)

        prepared = prepare_pt2e(m, quantizer)
        # Single forward pass over the sample inputs before conversion.
        prepared(*inputs)
        quantized_module = convert_pt2e(prepared)
        nodes = {node.target for node in quantized_module.graph.nodes}
        q_and_dq = {
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.quantized_decomposed.quantize_per_channel.default,
            torch.ops.quantized_decomposed.dequantize_per_channel.default,
        }
        self.assertTrue(nodes.intersection(q_and_dq))
        return quantized_module

    def get_prepared_qat_module(
        self,
        module: torch.nn.Module,
        inputs: Tuple[torch.Tensor],
        is_conv_per_channel: Optional[bool] = True,
        is_linear_per_channel: Optional[bool] = False,
        custom_quant_annotations: Tuple[Callable] = (),
        quant_dtype: QuantDtype = QuantDtype.use_8a8w,
    ) -> torch.fx.GraphModule:
        """Export ``module`` for training and prepare it for QAT.

        Raises:
            RuntimeError: If ``quant_dtype`` is not ``QuantDtype.use_8a8w``,
                the only configuration this helper accepts.

        Returns:
            The prepared module after ``move_exported_model_to_train``.
        """
        m = torch.export.export_for_training(module, inputs).module()

        quantizer = QnnQuantizer()
        quantizer.add_custom_quant_annotations(custom_quant_annotations)
        quantizer.set_per_channel_conv_quant(is_conv_per_channel)
        quantizer.set_per_channel_linear_quant(is_linear_per_channel)

        if quant_dtype != QuantDtype.use_8a8w:
            raise RuntimeError(
                f"QAT preparation only supports {QuantDtype.use_8a8w}, "
                f"got {quant_dtype}"
            )
        quantizer.set_quant_config(quant_dtype, is_qat=True)

        prepared = prepare_qat_pt2e(m, quantizer)
        return torch.ao.quantization.move_exported_model_to_train(prepared)

    def get_converted_sgd_trained_module(
        self,
        ori_module: torch.nn.Module,
        prepared: torch.nn.Module,
        inputs: Tuple[torch.Tensor],
    ) -> torch.fx.GraphModule:
        """Run one SGD step on ``prepared`` and convert it.

        The loss is the cross entropy between the output of ``prepared`` and
        the output of ``ori_module`` on the same ``inputs``.

        Returns:
            The result of ``convert_pt2e`` on the updated module.
        """
        optimizer = torch.optim.SGD(prepared.parameters(), lr=0.0001)
        criterion = torch.nn.CrossEntropyLoss()
        output = prepared(*inputs)
        loss = criterion(output, ori_module(*inputs))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return convert_pt2e(prepared)

    def split_graph(self, graph_module: torch.fx.GraphModule, division: int):
        """Insert clone nodes into ``graph_module`` at regular intervals.

        The interval is ``ceil(n / division)``, where ``n`` is the number of
        ``call_function`` nodes in the graph. The graph is modified in place.
        """

        class SplitGraph(ExportPass):
            """
            Split graph based on number of nodes.
            """

            def __init__(self, shares):
                super().__init__()
                self.shares = shares

            def _insert_clone(self, graph_module: torch.fx.GraphModule) -> None:
                num_graph_nodes = 0
                # Nodes are inserted while the graph is being iterated, so
                # the statement order below is kept exactly as it was.
                for node in graph_module.graph.nodes:
                    num_graph_nodes += 1 if node.op == "call_function" else 0

                    if num_graph_nodes % self.shares != 0 or node.op != "call_function":
                        continue

                    with graph_module.graph.inserting_after(node):
                        # Snapshot the users before the clone itself becomes
                        # a user of `node`.
                        users = list(node.users.keys())
                        inserted_node = graph_module.graph.create_node(
                            "call_function",
                            exir_ops.edge.aten.clone.default,
                            (node,),
                        )
                        inserted_node.meta["val"] = node.meta["val"]
                        if "quant_attrs" in node.meta:
                            inserted_node.meta["quant_attrs"] = node.meta["quant_attrs"]
                        for user in users:
                            user.replace_input_with(node, inserted_node)

            def call(self, graph_module: torch.fx.GraphModule):
                self._insert_clone(graph_module)
                graph_module.recompile()

        num_graph_nodes = sum(
            1 for node in graph_module.graph.nodes if node.op == "call_function"
        )

        # -(a // -b) is ceiling division.
        SplitGraph(-(num_graph_nodes // -division))(graph_module)