xref: /aosp_15_r20/external/executorch/backends/arm/test/tester/arm_tester.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2024 Arm Limited and/or its affiliates.
2#
3# This source code is licensed under the BSD-style license found in the
4# LICENSE file in the root directory of this source tree.
5
6import logging
7
8from collections import Counter
9from pprint import pformat
10from typing import Any, Iterable, List, Literal, Optional, Tuple, Union
11
12import executorch.backends.xnnpack.test.tester.tester as tester
13
14import numpy as np
15import serializer.tosa_serializer as ts
16
17import torch.fx
18
19from executorch.backends.arm.arm_backend import get_intermediate_path, is_permute_memory
20from executorch.backends.arm.arm_partitioner import ArmPartitioner
21from executorch.backends.arm.quantizer.arm_quantizer import (
22    ArmQuantizer,
23    get_symmetric_quantization_config,
24)
25from executorch.backends.arm.test.common import (
26    arm_test_options,
27    current_time_formated,
28    get_option,
29)
30
31from executorch.backends.arm.test.runner_utils import (
32    _get_input_quantization_params,
33    _get_output_node,
34    _get_output_quantization_params,
35    dbg_tosa_fb_to_json,
36    RunnerUtil,
37)
38from executorch.backends.arm.tosa_mapping import extract_tensor_meta
39
40from executorch.backends.xnnpack.test.tester import Tester
41from executorch.devtools.backend_debug import get_delegation_info
42from executorch.exir import EdgeCompileConfig, ExecutorchProgramManager
43from executorch.exir.backend.compile_spec_schema import CompileSpec
44from executorch.exir.backend.partitioner import Partitioner
45from executorch.exir.lowered_backend_module import LoweredBackendModule
46
47from tabulate import tabulate
48from torch.export.graph_signature import ExportGraphSignature, InputSpec, OutputSpec
49from torch.fx import Graph
50
51logger = logging.getLogger(__name__)
52
53
54def _dump_lowered_modules_artifact(
55    path_to_dump: Optional[str],
56    artifact: ExecutorchProgramManager,
57    graph_module: torch.fx.GraphModule,
58):
59    output = "Formated Graph Signature:\n"
60    output += _format_export_graph_signature(
61        artifact.exported_program().graph_signature
62    )
63
64    def get_output_format(lowered_module) -> str | None:
65        for spec in lowered_module.compile_specs:
66            if spec.key == "output_format":
67                return spec.value.decode()
68        return None
69
70    for node in graph_module.graph.nodes:
71        if node.op == "get_attr" and node.name.startswith("lowered_module_"):
72            lowered_module = getattr(graph_module, node.name)
73            assert isinstance(
74                lowered_module, LoweredBackendModule
75            ), f"Attribute {node.name} must be of type LoweredBackendModule."
76
77            output_format = get_output_format(lowered_module)
78            if output_format == "tosa":
79                tosa_fb = lowered_module.processed_bytes
80                to_print = dbg_tosa_fb_to_json(tosa_fb)
81                to_print = pformat(to_print, compact=True, indent=1)
82                output += f"\nTOSA deserialized {node.name}: \n{to_print}\n"
83            elif output_format == "vela":
84                vela_cmd_stream = lowered_module.processed_bytes
85                output += f"\nVela command stream {node.name}: \n{vela_cmd_stream}\n"
86            else:
87                logger.warning(
88                    f"No TOSA nor Vela compile spec found in compile specs of {node.name}."
89                )
90                continue
91
92    if not output:
93        logger.warning("No output to print generated from artifact.")
94        return
95
96    _dump_str(output, path_to_dump)
97
98
99class Partition(tester.Partition):
100    def dump_artifact(self, path_to_dump: Optional[str]):
101        super().dump_artifact(path_to_dump)
102        _dump_lowered_modules_artifact(path_to_dump, self.artifact, self.graph_module)
103
104
105class ToEdgeTransformAndLower(tester.ToEdgeTransformAndLower):
106    def dump_artifact(self, path_to_dump: Optional[str]):
107        super().dump_artifact(path_to_dump)
108        _dump_lowered_modules_artifact(path_to_dump, self.artifact, self.graph_module)
109
110
111class Serialize(tester.Serialize):
112    def __init__(self, runner_util: RunnerUtil, timeout: int = 1):
113        super().__init__()
114        self.runner = runner_util
115        self.runner.set_timeout(timeout)
116
117    def run_artifact(self, inputs):
118        return self.runner.run_corstone(inputs)
119
120    def dump_artifact(self, path_to_dump: Optional[str]):
121        if not path_to_dump:
122            path_to_dump = self.path + "/program.pte"
123        super().dump_artifact(path_to_dump)
124
125
126class ToExecutorch(tester.ToExecutorch):
127    def __init__(
128        self,
129        tosa_test_util: RunnerUtil,
130        dynamic_shapes: Optional[Tuple[Any]] = None,
131    ):
132        super().__init__(dynamic_shapes)
133        self.tosa_test_util = tosa_test_util
134
135    def run_artifact(self, inputs):
136        tosa_output = self.tosa_test_util.run_tosa_ref_model(
137            inputs=inputs,
138        )
139        return tosa_output
140
141
142class InitialModel(tester.Stage):
143    def __init__(self, model: torch.nn.Module):
144        self.model = model
145
146    def run(self, artifact, inputs=None) -> None:
147        pass
148
149    @property
150    def artifact(self) -> torch.nn.Module:
151        return self.model
152
153    @property
154    def graph_module(self) -> None:
155        return None
156
157    def artifact_str(self) -> str:
158        return str(self.model)
159
160    def run_artifact(self, inputs):
161        return self.model.forward(*inputs)
162
163
164class ArmTester(Tester):
165    def __init__(
166        self,
167        model: torch.nn.Module,
168        example_inputs: Tuple[torch.Tensor],
169        compile_spec: List[CompileSpec] = None,
170        tosa_ref_model_path: str | None = None,
171    ):
172        """
173        Args:
174            model (torch.nn.Module): The model to test
175            example_inputs (Tuple[torch.Tensor]): Example inputs to the model
176            compile_spec (List[CompileSpec]): The compile spec to use
177        """
178
179        # Initiate runner_util
180        intermediate_path = get_intermediate_path(compile_spec)
181        self.runner_util = RunnerUtil(
182            intermediate_path=intermediate_path,
183            tosa_ref_model_path=tosa_ref_model_path,
184        )
185
186        self.compile_spec = compile_spec
187        super().__init__(model, example_inputs)
188        self.pipeline[self.stage_name(InitialModel)] = [
189            self.stage_name(tester.Quantize),
190            self.stage_name(tester.Export),
191        ]
192
193        # Initial model needs to be set as a *possible* but not yet added Stage, therefore add None entry.
194        self.stages[self.stage_name(InitialModel)] = None
195        self._run_stage(InitialModel(self.original_module))
196
197    def quantize(self, quantize_stage: Optional[tester.Quantize] = None):
198        if quantize_stage is None:
199            quantize_stage = tester.Quantize(
200                ArmQuantizer(),
201                get_symmetric_quantization_config(is_per_channel=False),
202            )
203        return super().quantize(quantize_stage)
204
205    def to_edge(
206        self,
207        to_edge_stage: Optional[tester.ToEdge] = None,
208        config: Optional[EdgeCompileConfig] = None,
209    ):
210        if to_edge_stage is None:
211            to_edge_stage = tester.ToEdge(config)
212        else:
213            if config is not None:
214                to_edge_stage.edge_compile_conf = config
215
216        # TODO(T182928844): Delegate dim order op to backend.
217        to_edge_stage.edge_compile_conf._skip_dim_order = True
218        return super().to_edge(to_edge_stage)
219
220    def partition(self, partition_stage: Optional[Partition] = None):
221        if partition_stage is None:
222            arm_partitioner = ArmPartitioner(compile_spec=self.compile_spec)
223            partition_stage = Partition(arm_partitioner)
224        return super().partition(partition_stage)
225
226    def to_edge_transform_and_lower(
227        self,
228        to_edge_and_lower_stage: Optional[ToEdgeTransformAndLower] = None,
229        partitioners: Optional[List[Partitioner]] = None,
230        edge_compile_config: Optional[EdgeCompileConfig] = None,
231    ):
232        if to_edge_and_lower_stage is None:
233            if partitioners is None:
234                partitioners = [ArmPartitioner(compile_spec=self.compile_spec)]
235            to_edge_and_lower_stage = ToEdgeTransformAndLower(
236                partitioners, edge_compile_config
237            )
238        else:
239            if partitioners is not None:
240                to_edge_and_lower_stage.partitioners = partitioners
241            if edge_compile_config is not None:
242                to_edge_and_lower_stage.edge_compile_conf = edge_compile_config
243        to_edge_and_lower_stage.edge_compile_conf._skip_dim_order = True
244        return super().to_edge_transform_and_lower(to_edge_and_lower_stage)
245
246    def to_executorch(self, to_executorch_stage: Optional[ToExecutorch] | None = None):
247        if to_executorch_stage is None:
248            to_executorch_stage = ToExecutorch(self.runner_util)
249        return super().to_executorch(to_executorch_stage)
250
251    def serialize(
252        self, serialize_stage: Optional[Serialize] = None, timeout: int = 120
253    ):
254        if serialize_stage is None:
255            serialize_stage = Serialize(self.runner_util, timeout=timeout)
256        assert (
257            get_intermediate_path(self.compile_spec) is not None
258        ), "Can't dump serialized file when compile specs do not contain an artifact path."
259
260        return (
261            super()
262            .serialize(serialize_stage)
263            .dump_artifact(get_intermediate_path(self.compile_spec) + "/program.pte")
264        )
265
266    def run_method_and_compare_outputs(
267        self,
268        inputs: Optional[Tuple[torch.Tensor]] = None,
269        stage: Optional[str] = None,
270        target_board: Optional[str] = "corstone-300",
271        num_runs=1,
272        atol=1e-03,
273        rtol=1e-03,
274        qtol=0,
275    ):
276        """
277        Compares the run_artifact output of 'stage' with the output of a reference stage.
278        If the model is quantized, the reference stage is the Quantize stage output.
279        Otherwise, the reference stage is the initial pytorch module.
280
281        Asserts that the outputs are equal (within tolerances).
282        Returns self to allow the function to be run in a test chain.
283
284        Args:
285            stage: (Optional[str]): The name of the stage to compare.
286                The default is the latest run stage.
287            inputs (Optional[Tuple[torch.Tensor]]): Allows you to input custom input data.
288                The default is random data.
289        """
290        edge_stage = self.stages[self.stage_name(tester.ToEdge)]
291        if edge_stage is None:
292            edge_stage = self.stages[self.stage_name(tester.ToEdgeTransformAndLower)]
293        assert (
294            self.runner_util is not None
295        ), "self.tosa_test_util is not initialized, cannot use run_method()"
296        assert (
297            edge_stage is not None
298        ), "To compare outputs, at least the ToEdge or ToEdgeTransformAndLower stage needs to be run."
299
300        stage = stage or self.cur
301        test_stage = self.stages[stage]
302        is_quantized = self.stages[self.stage_name(tester.Quantize)] is not None
303
304        exported_program = self.stages[self.stage_name(tester.Export)].artifact
305        edge_program = edge_stage.artifact.exported_program()
306        self.runner_util.init_run(
307            exported_program,
308            edge_program,
309            is_quantized,
310            target_board,
311        )
312
313        if is_quantized:
314            reference_stage = self.stages[self.stage_name(tester.Quantize)]
315            quantization_scale = self.runner_util.qp_output.scale
316        else:
317            reference_stage = self.stages[self.stage_name(InitialModel)]
318            quantization_scale = None
319
320        logger.info(
321            f"Comparing Stage '{self.stage_name(test_stage)}' with Stage '{self.stage_name(reference_stage)}'"
322        )
323        is_nhwc = is_permute_memory(self.compile_spec)
324
325        # Loop inputs and compare reference stage with the compared stage.
326        for run_iteration in range(num_runs):
327            reference_input = inputs if inputs else next(self.generate_random_inputs())
328
329            # Test parameters can include constants that are used in eager mode but are already set as attributes
330            # in TOSA. Therefore, only accept torch.Tensor inputs.
331            test_input: list[torch.Tensor] = []
332            for arg in reference_input:
333                if isinstance(arg, torch.Tensor):
334                    test_input.append(arg.clone())
335                if isinstance(arg, tuple) and isinstance(arg[0], torch.Tensor):
336                    test_input.extend([tensor.clone() for tensor in arg])
337
338            if (
339                is_nhwc
340                and test_stage == self.stages[self.stage_name(tester.ToExecutorch)]
341            ):
342                test_input = self.transpose_data_format(test_input, "NHWC")
343
344            input_shapes = [
345                generated_input.shape if hasattr(generated_input, "shape") else (1,)
346                for generated_input in reference_input
347            ]
348            input_shape_str = ", ".join([str(list(i)) for i in input_shapes])
349            logger.info(f"Run #{run_iteration}, input shapes: {input_shape_str}")
350
351            reference_output = reference_stage.run_artifact(reference_input)
352            test_output = tuple(test_stage.run_artifact(test_input))
353            if (
354                is_nhwc
355                and test_stage == self.stages[self.stage_name(tester.ToExecutorch)]
356            ):
357                test_output = self.transpose_data_format(test_output, "NCHW")
358
359            self._compare_outputs(
360                reference_output, test_output, quantization_scale, atol, rtol, qtol
361            )
362
363        return self
364
365    def get_graph(self, stage: str | None = None) -> Graph:
366        if stage is None:
367            stage = self.cur
368        artifact = self.get_artifact(stage)
369        if (
370            self.cur == self.stage_name(tester.ToEdge)
371            or self.cur == self.stage_name(Partition)
372            or self.cur == self.stage_name(ToEdgeTransformAndLower)
373        ):
374            graph = artifact.exported_program().graph
375        elif self.cur == self.stage_name(tester.Export) or self.cur == self.stage_name(
376            tester.Quantize
377        ):
378            graph = artifact.graph
379        else:
380            raise RuntimeError(
381                "Can only get a graph from Quantize, ToEdge, Export, and Partition stages."
382            )
383
384        return graph
385
386    def dump_operator_distribution(
387        self, path_to_dump: Optional[str] = None, print_table: bool = True
388    ):
389        """Dump the distribution of operators in the current stage.
390        In the partition stage, additional information is included such as the number of
391        delegates and the distribution of TOSA operators.
392        Set parameter print_table to False to dump in a parseable format.
393
394
395        Returns self for daisy-chaining.
396        """
397        line = "#" * 10
398        to_print = f"{line} {self.cur.capitalize()} Operator Distribution {line}\n"
399
400        if (
401            self.cur
402            in (
403                self.stage_name(tester.Partition),
404                self.stage_name(ToEdgeTransformAndLower),
405            )
406            and print_table
407        ):
408            graph_module = self.get_artifact().exported_program().graph_module
409            if print_table:
410                delegation_info = get_delegation_info(graph_module)
411                op_dist = delegation_info.get_operator_delegation_dataframe()
412            else:
413                op_dist = dict(_get_operator_distribution(graph_module.graph))
414            to_print += _format_dict(op_dist, print_table)
415            to_print += "\n" + _get_tosa_operator_distribution(
416                graph_module, print_table
417            )
418            to_print += "\n"
419            to_print += delegation_info.get_summary()
420        else:
421            graph = self.get_graph(self.cur)
422            op_dist = dict(_get_operator_distribution(graph))
423            if print_table:
424                op_dist = {
425                    "Operator": list(op_dist),
426                    "Count": [op_dist[key] for key in op_dist],
427                }
428            to_print += _format_dict(op_dist, print_table) + "\n"
429
430        _dump_str(to_print, path_to_dump)
431
432        return self
433
434    def dump_dtype_distribution(
435        self, path_to_dump: Optional[str] = None, print_table: bool = True
436    ):
437        """Dump a the distributions of dtypes of nodes and placeholders in the current stage.
438        Set parameter print_table to False to dump in a parseable format.
439
440        Returns self for daisy-chaining.
441        """
442
443        line = "#" * 10
444        to_print = (
445            f"{line} {self.cur.capitalize()} Placeholder Dtype Distribution {line}\n"
446        )
447
448        graph = self.get_graph(self.cur)
449        dtype_dist_placeholders, dtype_dirst_tensors = _get_dtype_distribution(graph)
450        all_dtypes = set(dtype_dist_placeholders.keys()) | set(
451            dtype_dirst_tensors.keys()
452        )
453        if print_table:
454            dtype_dist = {
455                "Dtype": all_dtypes,
456                "Placeholder Count": [
457                    (
458                        dtype_dist_placeholders[key]
459                        if key in dtype_dist_placeholders
460                        else 0
461                    )
462                    for key in all_dtypes
463                ],
464                "Tensor Count": [
465                    (dtype_dirst_tensors[key] if key in dtype_dirst_tensors else 0)
466                    for key in all_dtypes
467                ],
468            }
469        else:
470            dtype_dist = dict(dtype_dist_placeholders + dtype_dirst_tensors)
471        to_print += _format_dict(dtype_dist, print_table) + "\n"
472        _dump_str(to_print, path_to_dump)
473        return self
474
475    @staticmethod
476    def _calculate_reference_output(
477        module: Union[torch.fx.GraphModule, torch.nn.Module], inputs
478    ) -> torch.Tensor:
479        """
480        Note: I'd prefer to use the base class method here, but since it use the
481        exported program, I can't. The partitioner stage clears the state_dict
482        of the exported program, which causes an issue when evaluating the
483        module.
484        """
485
486        return module.forward(*inputs)
487
488    def transpose_data_format(
489        self, data: Tuple[torch.Tensor], to: Literal["NHWC", "NCHW"]
490    ):
491        if to == "NCHW":
492            dim_order = (0, 3, 1, 2)
493        if to == "NHWC":
494            dim_order = (0, 2, 3, 1)
495        inputs_transposed = list(data)
496        for i in range(len(data)):
497            if hasattr(data[i], "shape") and len(data[i].shape) == 4:
498                inputs_transposed[i] = np.transpose(data[i], dim_order)
499        return tuple(inputs_transposed)
500
501    def _compare_outputs(
502        self,
503        reference_output,
504        stage_output,
505        quantization_scale=None,
506        atol=1e-03,
507        rtol=1e-03,
508        qtol=0,
509    ):
510        try:
511            super()._compare_outputs(
512                reference_output, stage_output, quantization_scale, atol, rtol, qtol
513            )
514        except AssertionError as e:
515            # Capture assertion error and print more info
516            banner = "=" * 40 + "TOSA debug info" + "=" * 40
517            logger.error(banner)
518            path_to_tosa_files = self.runner_util.intermediate_path
519
520            export_stage = self.stages.get(self.stage_name(tester.Export), None)
521            quantize_stage = self.stages.get(self.stage_name(tester.Quantize), None)
522            if export_stage is not None and quantize_stage is not None:
523                output_node = _get_output_node(export_stage.artifact)
524                qp_input = _get_input_quantization_params(export_stage.artifact)
525                qp_output = _get_output_quantization_params(
526                    export_stage.artifact, output_node
527                )
528                logger.error(f"{qp_input=}")
529                logger.error(f"{qp_output=}")
530
531            logger.error(f"{path_to_tosa_files=}")
532            import os
533
534            torch.save(
535                stage_output,
536                os.path.join(path_to_tosa_files, "torch_tosa_output.pt"),
537            )
538            torch.save(
539                reference_output,
540                os.path.join(path_to_tosa_files, "torch_ref_output.pt"),
541            )
542            logger.error(f"{atol=}, {rtol=}, {qtol=}")
543            raise e
544
545
546def _get_dtype_distribution(graph: Graph) -> tuple[dict, dict]:
547    """Counts the occurences of placeholder and call_function dtypes in a graph.
548    The result is a tuple of Counters (placeholder_distribution, call_function_distribution)
549    """
550    placeholder_dtypes = []
551    call_function_dtypes = []
552    for node in graph.nodes:
553        if node.op == "placeholder":
554            placeholder_dtypes.append(str(node.meta["val"].dtype))
555        if node.op == "call_function":
556            if "val" in node.meta:
557                dtype, _, _ = extract_tensor_meta(node.meta)
558                call_function_dtypes.append(ts.DTypeNames[dtype])
559    return Counter(placeholder_dtypes), Counter(call_function_dtypes)
560
561
562def _get_operator_distribution(graph: Graph) -> dict[str, int]:
563    """Counts the occurences of operator names in a graph.
564    The result is a dict {'operator name':'number of nodes'}
565    """
566    return Counter(
567        [str(node.target) for node in list(graph.nodes) if node.op == "call_function"]
568    )
569
570
571def _format_export_graph_signature(signature: ExportGraphSignature) -> str:
572    def specs_dict(specs: list[InputSpec | OutputSpec], title: str):
573        _dict: dict[str, list] = {title: [], "arg": [], "kind": [], "target": []}
574        for i, spec in enumerate(specs):
575            _dict[title].append(i)
576            _dict["arg"].append(spec.arg)
577            _dict["kind"].append(spec.kind)
578            _dict["target"].append(spec.target if spec.target else "-")
579        return _dict
580
581    input_dict = specs_dict(signature.input_specs, "Inputs")
582    output_dict = specs_dict(signature.output_specs, "Outputs")
583
584    return f"{_format_dict(input_dict)}\n{_format_dict(output_dict)}"
585
586
587def _get_tosa_operator_distribution(
588    graph_module: torch.fx.GraphModule, print_table=False
589) -> str:
590    """Counts the occurences of operator names of all lowered modules containing
591    a TOSA flatbuffer.
592    The result is a string with the operator distribution or an error message.
593    """
594    op_list = []
595    id = 0
596    while lowered_module := getattr(graph_module, f"lowered_module_{id}", None):
597        for spec in lowered_module.compile_specs:
598            if spec.key != "output_format":
599                continue
600            if spec.value == b"tosa":
601                tosa_fb = lowered_module.processed_bytes
602                tosa_json = dbg_tosa_fb_to_json(tosa_fb)
603                for region in tosa_json["regions"]:
604                    for block in region["blocks"]:
605                        op_list.extend(
606                            [operator["op"] for operator in block["operators"]]
607                        )
608                break
609            elif spec.value == b"vela":
610                return "Can not get operator distribution for Vela command stream."
611            else:
612                return f"Unknown output format '{spec.value}'."
613        id += 1
614    if id == 0:
615        return "No delegate with name 'lowered_module_0 found in graph module."
616    op_dist = dict(Counter(op_list))
617    op_dist = {
618        "Operator": list(op_dist.keys()),
619        "Count": [item[1] for item in op_dist.items()],
620    }
621    return "TOSA operators:\n" + _format_dict(dict(op_dist), print_table)
622
623
624def _dump_str(to_print: str, path_to_dump: Optional[str] = None):
625    default_dump_path = get_option(arm_test_options.dump_path)
626    if not path_to_dump and default_dump_path:
627        path_to_dump = default_dump_path / f"ArmTester_{current_time_formated()}.log"
628    if path_to_dump:
629        with open(path_to_dump, "a") as fp:
630            fp.write(to_print)
631    else:
632        logger.info(to_print)
633
634
635def _format_dict(to_print: dict, print_table: bool = True) -> str:
636    if isinstance(list(to_print.items())[0], Iterable) and print_table:
637        return tabulate(
638            to_print, headers="keys", tablefmt="fancy_grid", maxcolwidths=35
639        )
640    else:
641        return pformat(to_print, compact=True, indent=1)
642