xref: /aosp_15_r20/external/executorch/examples/mediatek/aot_utils/llm_utils/sanity_checks.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1import os
2import sys
3
4if os.getcwd() not in sys.path:
5    sys.path.append(os.getcwd())
6import warnings
7
8from models.llm_models.configuration_base import BaseConfig
9
10
11# flake8: noqa: E721
12
13
def warning_on_one_line(message, category, filename, lineno, file=None, line=None):
    """Render a warning as a single line: '<CategoryName>: <message>'.

    Signature matches ``warnings.formatwarning``; the location arguments
    (filename, lineno, file, line) are accepted but intentionally unused.
    """
    category_name = category.__name__
    return f"{category_name}: {message}\n"


# Install the compact one-line formatter for all subsequent warnings.
warnings.formatwarning = warning_on_one_line
19
20
def check_all_chunks_same_num_layer(num_blocks_per_chunk):
    """Ensure every chunk is assigned the same number of decoder layers.

    Raises RuntimeError (after printing the offending list) if any entry
    differs from the first one. Empty and single-element lists pass.
    """
    mismatched = any(
        count != num_blocks_per_chunk[0] for count in num_blocks_per_chunk[1:]
    )
    if mismatched:
        print("num_blocks_per_chunk:", num_blocks_per_chunk)
        raise RuntimeError(
            "This version of the sdk doesn't support different number of "
            "decoder layers per chunk, as shape fixer stage will fail. If you require this support,"
            " please contact Mediatek sdk owner."
        )
30
31
def check_between_exclusive(num, min_, max_, message=None):
    """Validate that ``min_ < num < max_`` with all three of the same exact type.

    The exact-type comparison (rather than isinstance) is intentional — it
    rejects e.g. a float bound against an int value; hence the file-level
    E721 suppression.

    Raises TypeError on a type mismatch, ValueError when out of range.
    """
    types_agree = type(num) == type(min_) == type(max_)
    if not types_agree:
        raise TypeError(
            f"Got different types for num ({type(num)}), min ({type(min_)}), and max ({type(max_)})"
        )
    if min_ < num < max_:
        return
    if message is None:
        raise ValueError(
            f"Expected number between {min_} and {max_} exclusive, but got: {num}"
        )
    raise ValueError(
        f"{message} must be between {min_} and {max_} exclusive, but got: {num}"
    )
46
47
def check_between_inclusive(num, min_, max_, message=None):
    """Validate that ``min_ <= num <= max_`` with all three of the same exact type.

    The exact-type comparison (rather than isinstance) is intentional,
    mirroring check_between_exclusive; hence the file-level E721 suppression.

    Raises TypeError on a type mismatch, ValueError when out of range.
    """
    types_agree = type(num) == type(min_) == type(max_)
    if not types_agree:
        raise TypeError(
            f"Got different types for num ({type(num)}), min ({type(min_)}), and max ({type(max_)})"
        )
    if min_ <= num <= max_:
        return
    if message is None:
        raise ValueError(
            f"Expected number between {min_} and {max_} inclusive, but got: {num}"
        )
    raise ValueError(
        f"{message} must be between {min_} and {max_} inclusive, but got: {num}"
    )
62
63
def check_exist(file_or_folder, message=None):
    """Raise FileNotFoundError unless the given path exists on disk."""
    if os.path.exists(file_or_folder):
        return
    description = (
        f"{file_or_folder} does not exist."
        if message is None
        else f"{message} does not exist: {file_or_folder}"
    )
    raise FileNotFoundError(description)
70
71
def check_ext(file, ext, message=None):
    """Raise RuntimeError unless *file* ends with the expected extension."""
    if file.endswith(ext):
        return
    if message is None:
        raise RuntimeError(f"Expected {ext} file, but got: {file}")
    raise RuntimeError(f"Expected {ext} file for {message}, but got: {file}")
78
79
def check_isdir(folder, message=None):
    """Raise if *folder* is not an existing directory.

    NOTE(review): the two branches raise different exception types
    (FileNotFoundError without a message, RuntimeError with one); preserved
    as-is since callers may catch these types specifically.
    """
    if os.path.isdir(folder):
        return
    if message is None:
        raise FileNotFoundError(f"{folder} is not a directory.")
    raise RuntimeError(f"Expected directory for {message}, but got: {folder}")
86
87
def check_old_arg(path):
    """Reject the pre-v0.8.0 calling convention of passing a weight directory.

    The main argument must now be the model's config.json; a directory path
    indicates the old usage and is rejected with RuntimeError.
    """
    if not os.path.isdir(path):
        return
    raise RuntimeError(
        "This package's main usage has changed starting from v0.8.0. Please use"
        " model's config.json as main argument instead of weight directory."
    )
94
95
def check_shapes(shapes):
    """Validate that every entry of *shapes* is a string of the form '<int>t<int>c'.

    Each shape must contain exactly one 't' and one 'c', with an integer
    before the 't' and an integer between 't' and 'c' (e.g. "32t512c").

    Args:
        shapes: list of shape strings to validate.

    Raises:
        TypeError: if *shapes* is not a list.
        RuntimeError: if any shape string is malformed.
    """
    if not isinstance(shapes, list):
        raise TypeError(f"Expected shapes to be a list, but got {type(shapes)} instead")
    for shape in shapes:
        # Single shared error; the original repeated this message three times
        # and its implicit string concatenation rendered "ofthe format"
        # (missing space) — fixed here.
        format_error = RuntimeError(
            f"Shape {shape} is in the wrong format. Every shape needs to be of "
            "the format: xtyc where x and y are integers. (e.g. 32t512c)"
        )
        if shape.count("t") != 1 or shape.count("c") != 1:
            raise format_error
        token_part, rest = shape.split("t")
        try:
            int(token_part)
            int(rest.split("c")[0])
        except ValueError:
            # The ValueError itself carries no extra context; suppress chaining.
            raise format_error from None
120
121
def check_supported_model(config):
    """Verify *config* is a BaseConfig instance for a supported model family.

    Raises:
        RuntimeError: if *config* is not subclassed from BaseConfig, or if
            its ``model_type`` is not one of the supported model families.
    """
    supported_models = [
        "llama",
        "bloom",
        "baichuan",
        "qwen",
        "qwen1.5",
        "qwen2",
        "milm",
    ]
    if not isinstance(config, BaseConfig):
        raise RuntimeError(
            f"Unsupported config class: {type(config)}. "
            "config needs to be subclassed from BaseConfig"
        )
    if config.model_type not in supported_models:
        raise RuntimeError(
            f"Unsupported model: {config.model_type}. Supported models: "
            f"{supported_models}"
        )
143
144
def check_supported_tokenizer(config):
    """Verify *config* is a BaseConfig instance with a supported tokenizer type.

    Raises:
        RuntimeError: if *config* is not subclassed from BaseConfig, or if
            its ``tokenizer`` is not one of the supported tokenizer types.
    """
    supported_tokenizers = [
        "default",
        "bloom",
        "baichuan",
        "gpt2",
        "gpt2_fast",
        "qwen",
        "qwen2",
        "qwen2_fast",
        "llama",
        "pretrained_fast",
    ]
    if not isinstance(config, BaseConfig):
        raise RuntimeError(
            f"Unsupported config class: {type(config)}. "
            "config needs to be subclassed from BaseConfig"
        )
    if config.tokenizer not in supported_tokenizers:
        raise RuntimeError(
            f"Unsupported tokenizer: {config.tokenizer}. Supported tokenizers: "
            f"{supported_tokenizers}"
        )
169
170
def check_tokenizer_exist(folder):
    """Ensure *folder* contains both a tokenizer model and a tokenizer config.

    Accepts tokenizer.model, tokenizer.json, or any *.tiktoken file as the
    tokenizer model; tokenizer_config.json must be present as well.

    Raises:
        FileNotFoundError: if either the model or the config file is missing.
    """
    entries = os.listdir(folder)
    has_model = any(
        name in ("tokenizer.model", "tokenizer.json") or name.endswith(".tiktoken")
        for name in entries
    )
    has_config = "tokenizer_config.json" in entries
    if not has_model:
        raise FileNotFoundError(
            f"Tokenizer not found in {folder}. Expected tokenizer.model, "
            "tokenizer.json, or tokenizer.tiktoken"
        )
    if not has_config:
        raise FileNotFoundError(
            f"Tokenizer config not found in {folder}. Expected " "tokenizer_config.json"
        )
187
188
def check_weights_exist(weight_dir):
    """Validate that *weight_dir* holds model weights in exactly one format.

    Weight files are pytorch_model*.bin or model*.safetensors. Additionally,
    .bin and .safetensors weights must not be mixed (embedding .bin files are
    exempt from the mixing check).

    Args:
        weight_dir: directory to scan for weight files.

    Raises:
        FileNotFoundError: if no weight files are found at all.
        RuntimeError: if both .bin and .safetensors weights are present.
    """
    # Single directory scan instead of the original's three listdir calls.
    entries = os.listdir(weight_dir)
    weight_files = [
        f
        for f in entries
        if (f.startswith("pytorch_model") and f.endswith(".bin"))
        or (f.startswith("model") and f.endswith(".safetensors"))
    ]
    if not weight_files:
        raise FileNotFoundError(
            f"No weight files found in {weight_dir}! Weight files should be either .bin or .safetensors file types."
        )
    safetensors_files = [f for f in entries if f.endswith(".safetensors")]
    # Embedding .bin files may legitimately coexist with .safetensors weights.
    bin_files = [f for f in entries if f.endswith(".bin") and "embedding" not in f]
    # Fix: the original used bitwise `&` on the two lengths, so mixed formats
    # escaped detection whenever the counts shared no set bits (e.g. 2 & 1 == 0).
    if safetensors_files and bin_files:
        raise RuntimeError(
            "Weights should only be in either .bin or .safetensors format, not both."
        )
214