xref: /aosp_15_r20/external/executorch/backends/arm/test/common.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6
7import logging
8import os
9import platform
10import shutil
11import subprocess
12import sys
13import tempfile
14from datetime import datetime
15from enum import auto, Enum
16from pathlib import Path
17from typing import Any
18
19import pytest
20
21import torch
22
23from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
24from executorch.exir.backend.compile_spec_schema import CompileSpec
25
26
class arm_test_options(Enum):
    """Keys of the Arm test-option registry ``_test_options``.

    Each member names one setting that ``pytest_configure`` may record from
    the pytest command line.
    """

    quantize_io = auto()  # enabled by --arm_quantize_io
    corstone300 = auto()  # enabled by --arm_run_corstone300 (FVP must be on PATH)
    dump_path = auto()  # --default_dump_path, stored as a user-expanded Path
    date_format = auto()  # --date_format, a strftime/strptime pattern
    fast_fvp = auto()  # --fast_fvp


# Registry of configured test options, filled in by pytest_configure and read
# through is_option_enabled() / get_option().
_test_options: dict[arm_test_options, Any] = dict()
36
37# ==== Pytest hooks ====
38
39
def pytest_addoption(parser):
    """Register the Arm backend's command-line options with pytest."""
    arm_options = (
        ("--arm_quantize_io", {"action": "store_true"}),
        ("--arm_run_corstone300", {"action": "store_true"}),
        ("--default_dump_path", {"default": None}),
        ("--date_format", {"default": "%d-%b-%H:%M:%S"}),
        ("--fast_fvp", {"action": "store_true"}),
    )
    # Registration order is preserved so `pytest --help` lists them as before.
    for flag, settings in arm_options:
        parser.addoption(flag, **settings)
46
47
def pytest_configure(config):
    """
    Validate the Arm command-line flags and record them in ``_test_options``.

    - ``--arm_quantize_io`` loads the quantized-ops AOT library and enables
      ``quantize_io``.
    - ``--arm_run_corstone300`` enables ``corstone300``; raises RuntimeError if
      ``FVP_Corstone_SSE-300_Ethos-U55`` cannot be found on PATH.
    - ``--default_dump_path`` (with ``~`` expanded) must name an existing
      directory, otherwise RuntimeError is raised.
    - ``--date_format`` and ``--fast_fvp`` are always stored.

    Finally the root logger is configured to emit INFO and above to stdout.
    """
    opts = config.option

    if opts.arm_quantize_io:
        load_libquantized_ops_aot_lib()
        _test_options[arm_test_options.quantize_io] = True

    if opts.arm_run_corstone300:
        fvp_binary = shutil.which("FVP_Corstone_SSE-300_Ethos-U55")
        if not fvp_binary:
            raise RuntimeError(
                "Tests are run with --arm_run_corstone300 but corstone300 FVP is not installed."
            )
        _test_options[arm_test_options.corstone300] = True

    if opts.default_dump_path:
        dump_path = Path(opts.default_dump_path).expanduser()
        # is_dir() is False both for missing paths and for non-directories.
        if not dump_path.is_dir():
            raise RuntimeError(
                f"Supplied argument 'default_dump_path={dump_path}' that does not exist or is not a directory."
            )
        _test_options[arm_test_options.dump_path] = dump_path

    _test_options.update(
        {
            arm_test_options.date_format: opts.date_format,
            arm_test_options.fast_fvp: opts.fast_fvp,
        }
    )
    logging.basicConfig(level=logging.INFO, stream=sys.stdout)
70
71
def pytest_collection_modifyitems(config, items):
    """
    Attach a skip marker to every collected test whose name contains "u55",
    unless pytest was run with ``--arm_quantize_io``.
    """
    if config.option.arm_quantize_io:
        return

    u55_skip = pytest.mark.skip("u55 tests can only run with quantize_io=True.")
    u55_items = [item for item in items if "u55" in item.name]
    for item in u55_items:
        item.add_marker(u55_skip)
81
82
def pytest_sessionstart(session):
    """Session-start hook; intentionally does nothing for the Arm tests."""
85
86
def pytest_sessionfinish(session, exitstatus):
    """
    When a dump directory was configured, prune old ``ArmTester_<date>.log``
    files from it via ``_clean_dir`` (using the configured date format).
    """
    dump_dir = get_option(arm_test_options.dump_path)
    if not dump_dir:
        return
    log_name_pattern = f"ArmTester_{get_option(arm_test_options.date_format)}.log"
    _clean_dir(dump_dir, log_name_pattern)
93
94
95# ==== End of Pytest hooks =====
96
97# ==== Custom Pytest decorators =====
98
99
def expectedFailureOnFVP(test_item):
    """
    Decorator: when the ``corstone300`` option is enabled, flag *test_item*
    as an expected failure by setting ``__unittest_expecting_failure__`` (the
    attribute ``unittest.expectedFailure`` uses). The item is returned as-is.
    """
    running_on_fvp = is_option_enabled("corstone300")
    if running_on_fvp:
        test_item.__unittest_expecting_failure__ = True
    return test_item
104
105
106# ==== End of Custom Pytest decorators =====
107
108
109def load_libquantized_ops_aot_lib():
110    so_ext = {
111        "Darwin": "dylib",
112        "Linux": "so",
113        "Windows": "dll",
114    }.get(platform.system(), None)
115
116    find_lib_cmd = [
117        "find",
118        "cmake-out-aot-lib",
119        "-name",
120        f"libquantized_ops_aot_lib.{so_ext}",
121    ]
122    res = subprocess.run(find_lib_cmd, capture_output=True)
123    if res.returncode == 0:
124        library_path = res.stdout.decode().strip()
125        torch.ops.load_library(library_path)
126
127
def is_option_enabled(
    option: str | arm_test_options, fail_if_not_enabled: bool = False
) -> bool:
    """
    Report whether a test option is enabled, i.e. whether a truthy value has
    been recorded for it in ``_test_options`` (normally by pytest_configure,
    after the flag was given and its requirements were checked).

    ``option`` is either an ``arm_test_options`` member or the
    case-insensitive name of one: "quantize_io", "corstone300", "dump_path",
    "date_format" or "fast_fvp". An unknown name raises KeyError.

    With ``fail_if_not_enabled=True`` a disabled option raises RuntimeError
    instead of returning False.
    """
    option = arm_test_options[option.lower()] if isinstance(option, str) else option

    if _test_options.get(option):
        return True
    if not fail_if_not_enabled:
        return False
    raise RuntimeError(f"Required option '{option}' for test is not enabled")
151
152
def get_option(option: arm_test_options) -> Any | None:
    """Return the value recorded for *option*, or None if it was never set."""
    return _test_options.get(option)
157
158
159def maybe_get_tosa_collate_path() -> str | None:
160    """
161    Checks the environment variable TOSA_TESTCASES_BASE_PATH and returns the
162    path to the where to store the current tests if it is set.
163    """
164    tosa_test_base = os.environ.get("TOSA_TESTCASES_BASE_PATH")
165    if tosa_test_base:
166        current_test = os.environ.get("PYTEST_CURRENT_TEST")
167        #'backends/arm/test/ops/test_mean_dim.py::TestMeanDim::test_meandim_tosa_BI_0_zeros (call)'
168        test_class = current_test.split("::")[1]
169        test_name = current_test.split("::")[-1].split(" ")[0]
170        if "BI" in test_name:
171            tosa_test_base = os.path.join(tosa_test_base, "tosa-bi")
172        elif "MI" in test_name:
173            tosa_test_base = os.path.join(tosa_test_base, "tosa-mi")
174        else:
175            tosa_test_base = os.path.join(tosa_test_base, "other")
176
177        return os.path.join(tosa_test_base, test_class, test_name)
178
179    return None
180
181
def get_tosa_compile_spec(
    tosa_version: str, permute_memory_to_nhwc=True, custom_path=None
) -> list[CompileSpec]:
    """
    Build the default, finalized compile spec for TOSA tests.

    NOTE: permute_memory_to_nhwc defaults to True here, whereas
    get_tosa_compile_spec_unbuilt() defaults it to False.
    """
    builder = get_tosa_compile_spec_unbuilt(
        tosa_version, permute_memory_to_nhwc, custom_path
    )
    return builder.build()
191
192
def get_tosa_compile_spec_unbuilt(
    tosa_version: str, permute_memory_to_nhwc=False, custom_path=None
) -> ArmCompileSpecBuilder:
    """
    Return an un-built ArmCompileSpecBuilder for the default TOSA tests, so
    the spec can be adjusted before ``.build()`` is called.

    Intermediate artifacts go to, in order of preference: *custom_path*, the
    collate directory from maybe_get_tosa_collate_path(), or a fresh temporary
    directory prefixed ``arm_tosa_``. The directory is created if missing.
    """
    intermediate_path = (
        custom_path
        or maybe_get_tosa_collate_path()
        or tempfile.mkdtemp(prefix="arm_tosa_")
    )

    if not os.path.exists(intermediate_path):
        os.makedirs(intermediate_path, exist_ok=True)

    builder = ArmCompileSpecBuilder()
    builder = builder.tosa_compile_spec(tosa_version)
    builder = builder.set_permute_memory_format(permute_memory_to_nhwc)
    builder = builder.dump_intermediate_artifacts_to(intermediate_path)
    return builder
216
217
def get_u55_compile_spec(
    permute_memory_to_nhwc=True, quantize_io=False, custom_path=None
) -> list[CompileSpec]:
    """
    Build the default, finalized compile spec for Ethos-U55 tests by calling
    ``.build()`` on get_u55_compile_spec_unbuilt()'s builder.
    """
    builder = get_u55_compile_spec_unbuilt(
        permute_memory_to_nhwc, quantize_io=quantize_io, custom_path=custom_path
    )
    return builder.build()
227
228
def get_u85_compile_spec(
    permute_memory_to_nhwc=True, quantize_io=False, custom_path=None
) -> list[CompileSpec]:
    """
    Build the default, finalized compile spec for Ethos-U85 tests by calling
    ``.build()`` on get_u85_compile_spec_unbuilt()'s builder.
    """
    builder = get_u85_compile_spec_unbuilt(
        permute_memory_to_nhwc, quantize_io=quantize_io, custom_path=custom_path
    )
    return builder.build()
238
239
def get_u55_compile_spec_unbuilt(
    permute_memory_to_nhwc=True, quantize_io=False, custom_path=None
) -> ArmCompileSpecBuilder:
    """
    Return an un-built ArmCompileSpecBuilder for the Ethos-U55 tests, so the
    spec can be adjusted before ``.build()`` is called.

    Artifacts are dumped to *custom_path* (created if missing) or to a fresh
    temporary directory prefixed ``arm_u55_``. Quantized I/O is enabled when
    *quantize_io* is True or the ``quantize_io`` test option is enabled.
    """
    artifact_path = custom_path or tempfile.mkdtemp(prefix="arm_u55_")
    if not os.path.exists(artifact_path):
        os.makedirs(artifact_path, exist_ok=True)

    builder = ArmCompileSpecBuilder().ethosu_compile_spec(
        "ethos-u55-128",
        system_config="Ethos_U55_High_End_Embedded",
        memory_mode="Shared_Sram",
        extra_flags="--debug-force-regor --output-format=raw",
    )
    builder = builder.set_quantize_io(is_option_enabled("quantize_io") or quantize_io)
    builder = builder.set_permute_memory_format(permute_memory_to_nhwc)
    return builder.dump_intermediate_artifacts_to(artifact_path)
262
263
def get_u85_compile_spec_unbuilt(
    permute_memory_to_nhwc=True, quantize_io=False, custom_path=None
) -> ArmCompileSpecBuilder:
    """Get the ArmCompileSpecBuilder for the Ethos-U85 tests, to modify
    the compile spec before calling .build() to finalize it.

    Args:
        permute_memory_to_nhwc: forwarded to set_permute_memory_format().
        quantize_io: request quantized I/O; it is also enabled whenever the
            ``quantize_io`` test option is enabled.
        custom_path: directory for intermediate artifacts, created if it does
            not exist. Defaults to a fresh temporary directory prefixed
            ``arm_u85_``.

    Returns:
        The un-built ArmCompileSpecBuilder (callers invoke ``.build()``).
    """
    artifact_path = custom_path or tempfile.mkdtemp(prefix="arm_u85_")
    # Same as the U55 and TOSA helpers: make sure a caller-supplied directory
    # exists before artifacts are dumped into it.
    if not os.path.exists(artifact_path):
        os.makedirs(artifact_path, exist_ok=True)
    compile_spec = (
        ArmCompileSpecBuilder()
        .ethosu_compile_spec(
            "ethos-u85-128",
            system_config="Ethos_U85_SYS_DRAM_Mid",
            memory_mode="Shared_Sram",
            extra_flags="--output-format=raw",
        )
        .set_quantize_io(is_option_enabled("quantize_io") or quantize_io)
        .set_permute_memory_format(permute_memory_to_nhwc)
        .dump_intermediate_artifacts_to(artifact_path)
    )
    return compile_spec
284
285
def current_time_formated() -> str:
    """Return the current local time formatted with the configured
    ``date_format`` test option."""
    now = datetime.now()
    return now.strftime(get_option(arm_test_options.date_format))
289
290
def _clean_dir(dir: Path, filter: str, num_save=10):
    """
    Delete all but the ``num_save`` newest timestamp-named files in ``dir``.

    A file takes part only if ``datetime.strptime(file.name, filter)``
    succeeds; the parsed value is its timestamp. Files whose names do not
    match ``filter``, and sub-directories, are left untouched.

    Args:
        dir: Directory to prune.
        filter: ``strptime`` format a file name must match exactly, e.g.
            ``"ArmTester_%d-%b-%H:%M:%S.log"``.
        num_save: Number of newest matching files to keep; a value <= 0
            removes every matching file.
    """
    dated_files: list[tuple[datetime, Path]] = []
    for file in dir.iterdir():
        try:
            timestamp = datetime.strptime(file.name, filter)
        except ValueError:
            # Name does not follow the pattern: not one of ours, keep it.
            continue
        if not file.is_file():
            # A matching directory cannot be unlink()ed; leave it alone.
            continue
        dated_files.append((timestamp, file))

    # Oldest first; ties broken by name so the outcome is deterministic.
    dated_files.sort(key=lambda entry: (entry[0], entry[1].name))
    num_to_remove = max(len(dated_files) - num_save, 0)
    for _, file in dated_files[:num_to_remove]:
        file.unlink()
313