xref: /aosp_15_r20/external/executorch/.ci/scripts/gather_test_models.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1#!/usr/bin/env python
2# Copyright (c) Meta Platforms, Inc. and affiliates.
3# All rights reserved.
4#
5# This source code is licensed under the BSD-style license found in the
6# LICENSE file in the root directory of this source tree.
7
8import itertools
9import json
10import os
11from typing import Any
12
13from examples.models import MODEL_NAME_TO_MODEL
14from examples.xnnpack import MODEL_NAME_TO_OPTIONS
15
# Runner label used for a target OS when a model has no override below.
DEFAULT_RUNNERS = dict(
    linux="linux.2xlarge",
    macos="macos-m1-stable",
)

# Per-OS overrides that send specific models to a bigger runner.
CUSTOM_RUNNERS = {
    "linux": dict.fromkeys(
        [
            # This one runs OOM on smaller runner, the root cause is unclear (T163016365)
            "w2l",
            "ic4",
            "resnet50",
            "llava",
            "llama3_2_vision_encoder",
            # TODO: re-enable "llama3_2_text_decoder" (linux.12xlarge) when Huy's
            # change is in / model gets smaller.
            # This one causes timeout on smaller runner, the root cause is unclear (T161064121)
            "dl3",
            "emformer_join",
            "emformer_predict",
        ],
        "linux.12xlarge",
    ),
}

# Job timeout used when a model has no per-OS override below.
DEFAULT_TIMEOUT = 90

# Per-OS timeout overrides. These are just examples of how a custom timeout can
# be set; every OS gets its own (independent) copy of the same mapping.
CUSTOM_TIMEOUT = {
    target: {"mobilebert": 90, "emformer_predict": 360}
    for target in ("linux", "macos")
}
48
49
def parse_args(argv: "list[str] | None" = None) -> Any:
    """
    Parse this script's command-line options.

    Args:
        argv: Arguments to parse, without the program name. ``None`` (the
            default) parses ``sys.argv[1:]``; an explicit list makes the
            parser usable from tests.

    Returns:
        An ``argparse.Namespace`` with ``target_os`` (defaults to "linux") and
        ``event`` (one of "pull_request", "push", "schedule").

    Raises:
        SystemExit: raised by argparse on ``--help`` or invalid arguments.
    """
    from argparse import ArgumentParser

    # The first positional parameter of ArgumentParser is ``prog``; pass the
    # sentence as ``description`` so it is not shown as the program name.
    parser = ArgumentParser(
        description="Gather all models to test on CI for the target OS"
    )
    parser.add_argument(
        "--target-os",
        type=str,
        default="linux",
        help="the target OS",
    )
    parser.add_argument(
        "-e",
        "--event",
        type=str,
        choices=["pull_request", "push", "schedule"],
        required=True,
        help="GitHub CI Event. See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#on",
    )

    return parser.parse_args(argv)
70
71
def set_output(name: str, val: Any) -> None:
    """
    Set a GitHub Actions output so that it can be accessed by other jobs.

    When the ``GITHUB_OUTPUT`` environment variable names a file, the output is
    appended to it. Single-line values are written as ``name=value``. Values
    containing a newline use the multi-line ``name<<DELIMITER`` form, because
    a raw newline would otherwise end the value early and corrupt the file.
    Without ``GITHUB_OUTPUT``, it falls back to printing the (deprecated)
    ``::set-output`` workflow command on stdout.

    Args:
        name: Output name.
        val: Output value; formatted the same way as in an f-string.
    """
    import uuid

    print(f"Setting {val} to GitHub output")

    output_path = os.getenv("GITHUB_OUTPUT")
    if output_path:
        text = f"{val}"
        with open(output_path, "a", encoding="utf-8") as env:
            if "\n" in text:
                # Random delimiter so it cannot collide with a line of the value.
                delimiter = f"ghadelimiter_{uuid.uuid4()}"
                print(f"{name}<<{delimiter}", file=env)
                print(text, file=env)
                print(delimiter, file=env)
            else:
                print(f"{name}={text}", file=env)
    else:
        print(f"::set-output name={name}::{val}")
83
84
def model_should_run_on_event(model: str, event: str) -> bool:
    """
    Decide whether ``model`` is tested for the given CI ``event``.

    Pull requests only get the high-priority, fast models. Pushes get every
    model except the slowest ones. Any other event (e.g. "schedule") runs
    every model.
    """
    fast_pr_models = ("mv3", "vit")
    # These are super slow. Only run it periodically
    too_slow_for_push = ("dl3", "edsr", "emformer_predict")

    if event == "push":
        return model not in too_slow_for_push
    if event == "pull_request":
        return model in fast_pr_models
    return True
97
98
def model_should_run_on_target_os(model: str, target_os: str) -> bool:
    """
    Decide whether ``model`` is tested on ``target_os`` (linux/macos).

    Some models are skipped on macOS, for example a big model that does not
    fit the limited macOS resources. Every model runs on every other OS.
    """
    skipped_on_macos = ("llava",)
    return target_os != "macos" or model not in skipped_on_macos
107
108
def export_models_for_ci() -> dict[str, list[dict[str, Any]]]:
    """
    Gather all the example models that we want to test on GitHub OSS CI.

    Reads ``--target-os`` and ``--event`` from the command line and builds a
    GitHub Actions matrix of the form ``{"include": [record, ...]}``. Each
    record has the keys ``build-tool``, ``model``, ``backend``, ``runner`` and
    ``timeout``. The matrix is published as the ``models`` output (JSON) and
    also returned.

    Returns:
        The matrix dict that was published.
    """
    args = parse_args()
    target_os = args.target_os
    event = args.event

    default_runner = DEFAULT_RUNNERS.get(target_os, "linux.2xlarge")
    # Per-OS overrides; an OS without an entry simply has no overrides.
    custom_runners = CUSTOM_RUNNERS.get(target_os, {})
    custom_timeouts = CUSTOM_TIMEOUT.get(target_os, {})

    # This is the JSON syntax for configuration matrix used by GitHub
    # https://docs.github.com/en/actions/using-jobs/using-a-matrix-for-your-jobs
    models: dict[str, list[dict[str, Any]]] = {"include": []}

    # Add MobileNet v3 for BUCK2 E2E validation (linux only)
    if target_os == "linux":
        for backend in ["portable", "xnnpack-quantization-delegation"]:
            models["include"].append(
                {
                    "build-tool": "buck2",
                    "model": "mv3",
                    "backend": backend,
                    "runner": default_runner,
                    "timeout": DEFAULT_TIMEOUT,
                }
            )

    # Add all models for CMake E2E validation
    # CMake supports both linux and macos
    for name, backend in itertools.product(
        MODEL_NAME_TO_MODEL.keys(), ["portable", "xnnpack"]
    ):
        if not model_should_run_on_event(name, event):
            continue

        if not model_should_run_on_target_os(name, target_os):
            continue

        if backend == "xnnpack":
            # Only models with XNNPACK options get an XNNPACK job; the backend
            # label records which of those options are enabled.
            if name not in MODEL_NAME_TO_OPTIONS:
                continue
            options = MODEL_NAME_TO_OPTIONS[name]
            if options.quantization:
                backend += "-quantization"
            if options.delegation:
                backend += "-delegation"

        models["include"].append(
            {
                "build-tool": "cmake",
                "model": name,
                "backend": backend,
                # NB: Some model requires much bigger Linux runner to avoid
                # running OOM. The team is investigating the root cause
                "runner": custom_runners.get(name, default_runner),
                "timeout": custom_timeouts.get(name, DEFAULT_TIMEOUT),
            }
        )

    set_output("models", json.dumps(models))
    return models
173
174
if __name__ == "__main__":
    # Script entry point: gather the CI model matrix and publish it as the
    # "models" GitHub output.
    export_models_for_ci()
177