xref: /aosp_15_r20/external/executorch/backends/arm/test/runner_utils.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2024 Arm Limited and/or its affiliates.
2#
3# This source code is licensed under the BSD-style license found in the
4# LICENSE file in the root directory of this source tree.
5
6import json
7import logging
8import os
9import re
10import shutil
11import subprocess
12import tempfile
13
14from pathlib import Path
15from typing import Dict, List, Optional, Tuple
16
17import numpy as np
18import torch
19
20from executorch.backends.arm.test.common import arm_test_options, is_option_enabled
21
22from torch.export import ExportedProgram
23from torch.fx.node import Node
24
# Module-level logger; its level is also used to pick the verbosity passed to
# the TOSA reference model.
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
27
28
class QuantizationParams:
    """Plain record holding the quantization parameters of one tensor.

    Only the attributes listed in ``__slots__`` can be set on an instance.
    """

    __slots__ = ["node_name", "zp", "scale", "qmin", "qmax", "dtype"]

    # TODO: zero points and scales can be per tensor or per channel, which
    # may require storing lists here instead of scalars.
    def __init__(
        self,
        node_name: str,
        zp: int,
        scale: float,
        qmin: int,
        qmax: int,
        dtype: torch.dtype,
    ):
        # Kept mainly so callers can check the parameters belong to the
        # tensor they expect.
        self.node_name = node_name

        # Affine mapping between real and quantized values.
        self.scale = scale
        self.zp = zp

        # Allowed range and storage type of the quantized values.
        self.qmin = qmin
        self.qmax = qmax
        self.dtype = dtype
48
49
50def _get_input_names(program: ExportedProgram) -> list[str]:
51    """
52    Get a list[str] with the names of the inputs to this model.
53
54    Args:
55        program (ExportedProgram): The program to get input names from.
56    Returns:
57        A list of strings with the names of the model input.
58    """
59    input_names = []
60
61    # E.g. bias and weights are 'placeholders' as well. This is used to
62    # get only the use inputs.
63    usr_inputs = program.graph_signature.user_inputs
64    for node in program.graph.nodes:
65        if node.op == "placeholder" and node.name in usr_inputs:
66            input_names.append(node.name)
67
68    return input_names
69
70
def _get_input_quantization_params(
    program: ExportedProgram,
) -> list[QuantizationParams]:
    """
    Collect the quantization parameters applied to the inputs of a program.

    The graph is scanned for per-tensor quantize nodes whose first argument
    is one of the program inputs. Scanning stops as soon as as many
    parameter sets as there are inputs have been collected.

    Args:
        program (ExportedProgram): The program to get input quantization parameters from.
    Returns:
        list[QuantizationParams]: The found quantization parameters.
    Raises:
        RuntimeError if no quantization parameters are found.
    """
    input_names = _get_input_names(program)
    expected_count = len(input_names)
    found = []

    for node in program.graph.nodes:
        quantizes_an_input = (
            node.target == torch.ops.quantized_decomposed.quantize_per_tensor.default
            and node.args[0].name in input_names
        )
        if not quantizes_an_input:
            continue

        # Positional layout of the quantize node:
        # (input, scale, zero point, qmin, qmax, dtype).
        found.append(
            QuantizationParams(
                node_name=node.args[0].name,
                scale=node.args[1],
                zp=node.args[2],
                qmin=node.args[3],
                qmax=node.args[4],
                dtype=node.args[5],
            )
        )
        # Stop early once parameters for every input have been collected.
        if len(found) == expected_count:
            break

    if not found:
        raise RuntimeError("No Quantization parameters found in exported model.")
    return found
108
109
110def _get_output_node(program: ExportedProgram) -> Node:
111    """
112    Get output node to this model.
113
114    Args:
115        program (ExportedProgram): The program to get output node from.
116    Returns:
117        The node that is the output of 'program'.
118    """
119
120    for node in program.graph.nodes:
121        if node.op == "output":
122            return node
123    raise RuntimeError("No output node found.")
124
125
def _get_output_quantization_params(
    program: ExportedProgram, output_node: Node
) -> QuantizationParams:
    """
    Get output QuantizationParams from a program.

    The parameters are read from the per-tensor dequantize node that produces
    the first value returned by 'output_node'.

    Args:
        program (ExportedProgram): The program to get output quantization parameters from.
        output_node (Node): The output node of 'program'.
    Returns:
        QuantizationParams: The found quantization parameters. 'node_name' is
            the name of the dequantize node's first argument.
    Raises:
        RuntimeError if no output quantization parameters are found.
    """
    for node in program.graph.nodes:
        if (
            node.target == torch.ops.quantized_decomposed.dequantize_per_tensor.default
            and node == output_node.args[0][0]
        ):
            # Only the first returned value is inspected, so the first match
            # is the answer.
            return QuantizationParams(
                node_name=node.args[0].name,
                scale=node.args[1],
                zp=node.args[2],
                qmin=node.args[3],
                qmax=node.args[4],
                dtype=node.args[5],
            )
    raise RuntimeError("No output quantization parameters found in exported model.")
157
158
159"""
160A class to store parameters needed for running programs, either in tosa or .pte format.
161"""
162
163
164class RunnerUtil:
165    def __init__(
166        self,
167        intermediate_path: str,
168        tosa_ref_model_path: Optional[str] = None,
169    ):
170        self.intermediate_path = intermediate_path
171        self.tosa_ref_model_path = tosa_ref_model_path or "tosa_reference_model"
172        assert os.path.exists(
173            self.intermediate_path
174        ), f"TOSA artifact path don't exist! Path: {self.intermediate_path}"
175
176        self.is_quantized: bool = False
177        self.input_names: list[str] = None
178        self.output_name: str = None
179        self.qp_input: list[QuantizationParams] = None
180        self.qp_output: QuantizationParams = None
181        self.timeout = 120
182        self.target_board: str = None
183
184        self._has_init_run = False
185
186    def init_run(
187        self,
188        exported_program: ExportedProgram,
189        edge_program: ExportedProgram,
190        is_quantized: bool,
191        target_board: str,
192    ):
193
194        if target_board not in ["corstone-300", "corstone-320"]:
195            raise RuntimeError(f"Unknown target board: {target_board}")
196
197        self.input_names = _get_input_names(edge_program)
198        self.output_node = _get_output_node(exported_program)
199        self.output_name = self.output_node.name
200        self.is_quantized = is_quantized
201        self.target_board = target_board
202
203        if is_quantized:
204            self.qp_input = _get_input_quantization_params(exported_program)
205            self.qp_output = _get_output_quantization_params(
206                exported_program, self.output_node
207            )
208        else:
209            self.qp_input = [None] * len(self.input_names)
210            self.qp_output = None
211
212        self._has_init_run = True
213
214    def set_timeout(self, timeout: int):
215        self.timeout = timeout
216
217    def run_corstone(
218        self,
219        inputs: Tuple[torch.Tensor],
220    ) -> list[torch.Tensor]:
221
222        assert (
223            self._has_init_run
224        ), "RunnerUtil needs to be initialized using init_run() before running Corstone300."
225
226        pte_path = os.path.join(self.intermediate_path, "program.pte")
227        assert os.path.exists(pte_path), f"Pte path '{pte_path}' not found."
228
229        for input_name, quant_param, data in zip(
230            self.input_names, self.qp_input, inputs
231        ):
232            save_bytes(self.intermediate_path, data, False, input_name, quant_param)
233
234        out_path = os.path.join(self.intermediate_path, "out")
235        out_path_with_suffix = out_path + "-0.bin"
236        input_paths = []
237        for name in self.input_names:
238            input_paths.append(
239                os.path.join(self.intermediate_path, f"{name}.bin"),
240            )
241        elf_path = os.path.join(
242            "cmake-out",
243            f"arm_semihosting_executor_runner_{self.target_board}",
244            "arm_executor_runner",
245        )
246        assert os.path.exists(
247            elf_path
248        ), f"Did not find build arm_executor_runner in path {elf_path}, run setup_testing.sh?"
249
250        cmd_line = f"executor_runner -m {pte_path} -o {out_path}"
251        for input_path in input_paths:
252            cmd_line += f" -i {input_path}"
253
254        ethos_u_extra_args = ""
255        if is_option_enabled(arm_test_options.fast_fvp):
256            ethos_u_extra_args = ethos_u_extra_args + "--fast"
257
258        command_args = {
259            "corstone-300": [
260                "FVP_Corstone_SSE-300_Ethos-U55",
261                "-C",
262                "ethosu.num_macs=128",
263                "-C",
264                "mps3_board.visualisation.disable-visualisation=1",
265                "-C",
266                "mps3_board.telnetterminal0.start_telnet=0",
267                "-C",
268                "mps3_board.uart0.out_file='-'",
269                "-C",
270                "cpu0.CFGITCMSZ=11",
271                "-C",
272                "cpu0.semihosting-enable=1",
273                "-C",
274                "cpu0.semihosting-stack_base=0",
275                "-C",
276                f"ethosu.extra_args='{ethos_u_extra_args}'",
277                "-C",
278                "cpu0.semihosting-heap_limit=0",
279                "-C",
280                f"cpu0.semihosting-cmd_line='{cmd_line}'",
281                "-a",
282                elf_path,
283                "--timelimit",
284                f"{self.timeout}",
285            ],
286            "corstone-320": [
287                "FVP_Corstone_SSE-320",
288                "-C",
289                "mps4_board.subsystem.ethosu.num_macs=128",
290                "-C",
291                "mps4_board.visualisation.disable-visualisation=1",
292                "-C",
293                "vis_hdlcd.disable_visualisation=1",
294                "-C",
295                "mps4_board.telnetterminal0.start_telnet=0",
296                "-C",
297                "mps4_board.uart0.out_file='-'",
298                "-C",
299                "mps4_board.uart0.unbuffered_output=1",
300                "-C",
301                "mps4_board.uart0.shutdown_on_eot=1",
302                "-C",
303                "mps4_board.subsystem.cpu0.semihosting-enable=1",
304                "-C",
305                "mps4_board.subsystem.cpu0.semihosting-stack_base=0",
306                "-C",
307                "mps4_board.subsystem.cpu0.semihosting-heap_limit=0",
308                "-C",
309                f"mps4_board.subsystem.ethosu.extra_args='{ethos_u_extra_args}'",
310                "-C",
311                f"mps4_board.subsystem.cpu0.semihosting-cmd_line='{cmd_line}'",
312                "-a",
313                elf_path,
314                "--timelimit",
315                f"{self.timeout}",
316            ],
317        }
318
319        result = _run_cmd(command_args[self.target_board], check=False)
320        if result.returncode != 0:
321            raise RuntimeError(
322                f"Failed to run {command_args[self.target_board]}\nError: {result.stderr.decode()}"
323            )
324        result_stdout = result.stdout.decode()
325
326        error_regex = r"(^[EF][: ].*$)|(^.*Hard fault.*$)|(^.*Assertion.*$)"
327
328        # Check for errors in the output
329        # regex to check for error or fault messages in stdout from FVP
330        if re.compile(error_regex, re.MULTILINE).search(result_stdout):
331            raise RuntimeError(
332                f"Corstone simulation failed:\ncmd: {command_args[self.target_board]}\n, log: \n {result_stdout}\n{result.stderr.decode()}"
333            )
334
335        tosa_ref_output = np.fromfile(out_path_with_suffix, dtype=np.float32)
336        output_shape = self.output_node.args[0][0].meta["val"].shape
337        tosa_ref_output = torch.from_numpy(tosa_ref_output).reshape(output_shape)
338        return [tosa_ref_output]
339
340    def run_tosa_ref_model(
341        self,
342        inputs: Tuple[torch.Tensor],
343    ) -> list[torch.Tensor]:
344        """
345        Run TOSA reference model using the tosa_reference_model program.
346
347        In order to do that we need:
348        1. desc.json, which points to files needed by tosa_reference_model.
349        2. output.tosa, which is the TOSA buffer that describes the model we're
350           trying to run.
351
352        These two files are created by arm_backend.py as part of partition stage
353
354        All these files are saved on disk in self.intermediate_path.
355
356        Args:
357            inputs (Tuple[torch.Tensor]): The input data to run the TOSA
358
359        Returns:
360            torch.Tensor: The output of the TOSA reference model, as a torch
361                tensor.
362
363        Here's a sample desc.json file:
364        {
365            "tosa_file": "output.tosa",
366            "ifm_name": [
367                "arg0_1"
368            ],
369            "ifm_file": [
370                "arg0_1.npy"
371            ],
372            "ofm_name": [
373                "quantized_decomposed_dequantize_per_tensor_default_1"
374            ],
375            "ofm_file": [
376                "ref-quantized_decomposed_dequantize_per_tensor_default_1.npy"
377            ],
378            "expected_return_code": 0,
379            "expected_failure": false
380        }
381
382        Todo:
383            * It would be nice to not rely on files on disk. Should be possible
384              as a next step. See:
385              https://review.mlplatform.org/plugins/gitiles/tosa/reference_model/#executable-usage
386        """
387
388        assert (
389            self._has_init_run
390        ), "RunnerUtil needs to be initialized using init_run() before running tosa reference."
391
392        all_desc_file_paths = [
393            str(path) for path in Path(self.intermediate_path).glob("desc*.json")
394        ]
395        assert (
396            all_desc_file_paths
397        ), f"No TOSA description file found in '{self.intermediate_path}'."
398        if len(all_desc_file_paths) != 1:
399            raise NotImplementedError(
400                "Graphs with more than one partition are currently not supported."
401            )
402
403        desc_file_path = all_desc_file_paths[0]
404        assert os.path.exists(
405            desc_file_path
406        ), f"desc_file_path: {desc_file_path} does not exist"
407
408        # Save the input data to disk as a .npy file, since that's what the TOSA
409        # reference model expects. Name of the file must match the name in
410        # desc.json, which is the tensor name from the graph + .npy
411        for input_name, quant_param, data in zip(
412            self.input_names, self.qp_input, inputs, strict=True
413        ):
414            save_npy(
415                self.intermediate_path, data, self.is_quantized, input_name, quant_param
416            )
417
418        # Run the TOSA reference model via command line, this will produce a
419        # .npy file with the result (aka OFM).
420        assert (
421            shutil.which(self.tosa_ref_model_path) is not None
422        ), f"tosa_reference_model tool not found, did you run examples/arm/setup.sh? Path: {self.tosa_ref_model_path}"
423        loglevel_map = {
424            logging.INFO: "INFO",
425            logging.CRITICAL: "LOW",
426            logging.ERROR: "LOW",
427            logging.WARNING: "MED",
428            logging.DEBUG: "HIGH",
429            logging.NOTSET: "MED",
430        }
431        clamped_logging_level = max(min(logger.level // 10 * 10, 50), 0)
432        cmd_ref_model = [
433            self.tosa_ref_model_path,
434            "--test_desc",
435            desc_file_path,
436            "-l",
437            loglevel_map[clamped_logging_level],
438        ]
439        _run_cmd(cmd_ref_model)
440
441        # Load desc.json, just to get the name of the output file above
442        with open(desc_file_path) as f:
443            desc_json = json.load(f)
444
445        tosa_ref_outputs = []
446        for ofm_file in desc_json["ofm_file"]:
447            ofm_file_npy = os.path.join(self.intermediate_path, ofm_file)
448
449            # Load the output file (OFM) and return it as a numpy array
450            tosa_ref_output = np.load(ofm_file_npy)
451
452            if self.is_quantized:
453                # Need to dequant back to FP32 for comparison with torch output
454                # Convert to int32 prior to dequantize the output
455                if tosa_ref_output.dtype == np.int8:
456                    tosa_ref_output = tosa_ref_output.astype(np.int32)
457                quant_param = self.qp_output
458                assert (
459                    quant_param is not None
460                ), "There are no quantization parameters, check output parameters"
461                tosa_ref_output = (tosa_ref_output - quant_param.zp) * quant_param.scale
462
463            if tosa_ref_output.dtype == np.double:
464                tosa_ref_output = tosa_ref_output.astype("float32")
465
466            # tosa_output is a numpy array, convert to torch tensor for comparison
467            tosa_ref_outputs.append(torch.from_numpy(tosa_ref_output))
468
469        return tosa_ref_outputs
470
471
def prep_data_for_save(
    data, is_quantized: bool, input_name: str, quant_param: QuantizationParams
):
    """Converts 'data' to a C-ordered numpy array, possibly quantizing it.

    Parameters:
        data: the tensor to convert.
        is_quantized: whether to quantize the data.
        input_name: the name of the input the data belongs to.
        quant_param: the parameters to use for quantization.
    Returns:
        the converted data as a numpy array.
    """
    as_array = np.array(data.detach(), order="C")
    # The string form of the torch dtype without its "torch." prefix is used
    # as the numpy dtype name.
    data_np = as_array.astype(f"{data.dtype}".replace("torch.", ""))

    if not is_quantized:
        return data_np

    assert quant_param.node_name in input_name, (
        f"The quantization params name '{quant_param.node_name}' does not "
        f"match the input tensor name '{input_name}'."
    )
    # Scale and shift, then round and limit to the quantized range.
    shifted = (data_np / np.float32(quant_param.scale)) + quant_param.zp
    limited = shifted.round().clip(quant_param.qmin, quant_param.qmax)
    # Use string format of dtype to convert to numpy dtype
    return limited.astype(f"{quant_param.dtype}".replace("torch.", ""))
493
494
def save_npy(
    path: str,
    data,
    is_quantized: bool,
    input_name: str,
    quant_param: QuantizationParams,
) -> str:
    """Writes 'data' to '<input_name>.npy' in 'path', possibly quantizing it first.

    Parameters:
        path: the directory where to save the data.
        data: the data to save.
        is_quantized: whether to quantize the data before saving it.
        input_name: the name of the file, without file-ending.
        quant_param: the parameters to use for quantization.
    Returns:
        the full file path of the output.
    """
    target_file = os.path.join(path, input_name + ".npy")
    prepared = prep_data_for_save(data, is_quantized, input_name, quant_param)
    np.save(target_file, prepared, allow_pickle=False)
    return target_file
518
519
def save_bytes(
    path: str,
    data,
    is_quantized: bool,
    input_name: str,
    quant_param: QuantizationParams,
) -> str:
    """Writes the raw bytes of 'data' to '<input_name>.bin' in 'path', possibly quantizing it first.

    Parameters:
        path: the directory where to save the data.
        data: the data to save.
        is_quantized: whether to quantize the data before saving it.
        input_name: the name of the file, without file-ending.
        quant_param: the parameters to use for quantization.
    Returns:
        the full file path of the output.
    """
    prepared = prep_data_for_save(data, is_quantized, input_name, quant_param)
    target_file = os.path.join(path, input_name + ".bin")
    with open(target_file, "w+b") as out_file:
        out_file.write(prepared.tobytes())
    return target_file
545
546
def _run_cmd(cmd: List[str], check=True) -> subprocess.CompletedProcess[bytes]:
    """
    Run a command and check for errors.

    Args:
        cmd (List[str]): The command to run as a list.
        check (bool): If True, a non-zero exit code raises RuntimeError. If
            False, the caller has to inspect the return code.
    Returns:
        The completed process, with stdout and stderr captured as bytes.
    Raises:
        RuntimeError if 'check' is True and the command exits with a
            non-zero code.
    """
    try:
        return subprocess.run(cmd, check=check, capture_output=True)
    except subprocess.CalledProcessError as e:
        arg_string = " ".join(cmd)
        # Decode leniently so that undecodable tool output cannot mask the
        # actual failure with a UnicodeDecodeError.
        stderr = e.stderr.decode(errors="replace")
        stdout = e.stdout.decode(errors="replace")
        raise RuntimeError(
            f"Failed running command {arg_string}\nStderr: {stderr}\nStdout: {stdout}"
        ) from e
562
563
def dbg_tosa_fb_to_json(tosa_fb: bytes) -> Dict:
    """
    This function is used to dump the TOSA flatbuffer to a human readable
    format, using flatc. It is used for debugging purposes.

    Args:
        tosa_fb (bytes): The serialized TOSA flatbuffer.
    Returns:
        The JSON dump produced by flatc, loaded as a dict. For tensors of
        type "FP32" that carry "data", the byte list is replaced, on a
        best-effort basis, by a float32 numpy array with the tensor's shape.
        The data of all other tensors is left as produced by flatc.
    """
    arm_backend_path = os.path.realpath(os.path.dirname(__file__) + "/..")
    tosa_schema_file = os.path.join(
        arm_backend_path, "third-party/serialization_lib/schema/tosa.fbs"
    )
    assert os.path.exists(
        tosa_schema_file
    ), f"tosa_schema_file: {tosa_schema_file} does not exist"

    assert shutil.which("flatc") is not None

    # The scratch directory only has to live until the JSON has been loaded;
    # the context manager removes it afterwards, also on errors.
    with tempfile.TemporaryDirectory() as tmp:
        tosa_input_file = os.path.join(tmp, "output.tosa")
        with open(tosa_input_file, "wb") as f:
            f.write(tosa_fb)

        cmd_flatc = [
            "flatc",
            "--json",
            "--strict-json",
            "-o",
            tmp,
            "--raw-binary",
            "-t",
            tosa_schema_file,
            "--",
            tosa_input_file,
        ]
        _run_cmd(cmd_flatc)
        with open(os.path.join(tmp, "output.json"), "r") as f:
            json_out = json.load(f)

    # Cast float tensors to proper dtype.
    try:
        for region in json_out["regions"]:
            for block in region["blocks"]:
                for tensor in block["tensors"]:
                    # Only FP32 tensors are reinterpreted. Other tensors must
                    # not be touched, otherwise they would be overwritten
                    # with the values of a previously converted tensor.
                    if "data" in tensor and tensor["type"] == "FP32":
                        # The dump lists the payload byte by byte; rebuild
                        # the byte buffer and view it as float32.
                        raw_bytes = np.array(tensor["data"]).astype(np.int8)
                        values = np.frombuffer(raw_bytes, dtype=np.float32)
                        tensor["data"] = values.reshape(tensor["shape"])
    except Exception:
        # This is just nice-to-have if it works, don't care if it fails.
        pass

    return json_out
617