import os
import sys

# Make the current working directory importable so sibling packages
# (e.g. `models`) resolve when this file is run as a script.
if os.getcwd() not in sys.path:
    sys.path.append(os.getcwd())
import warnings

from models.llm_models.configuration_base import BaseConfig


# flake8: noqa: E721


def warning_on_one_line(message, category, filename, lineno, file=None, line=None):
    """Render a warning as a single `Category: message` line.

    Installed as ``warnings.formatwarning`` below so warnings omit the
    file/line source context that the default formatter prints.
    """
    return f"{category.__name__}: {message}\n"


warnings.formatwarning = warning_on_one_line


def check_all_chunks_same_num_layer(num_blocks_per_chunk):
    """Raise if the chunks do not all share a single decoder-layer count.

    :param num_blocks_per_chunk: sequence of decoder-layer counts, one per chunk.
    :raises RuntimeError: if any chunk's count differs from the first chunk's.
    """
    if any(count != num_blocks_per_chunk[0] for count in num_blocks_per_chunk):
        print("num_blocks_per_chunk:", num_blocks_per_chunk)
        raise RuntimeError(
            "This version of the sdk doesn't support different number of "
            "decoder layers per chunk, as shape fixer stage will fail. If you require this support,"
            " please contact Mediatek sdk owner."
        )


def check_between_exclusive(num, min_, max_, message=None):
    """Validate that ``min_ < num < max_`` (exclusive bounds).

    :param num: value under test.
    :param min_: exclusive lower bound; must be the same exact type as num.
    :param max_: exclusive upper bound; must be the same exact type as num.
    :param message: optional name for the value, used in the error text.
    :raises TypeError: if num, min_, and max_ are not all the exact same type.
    :raises ValueError: if num lies outside the open interval (min_, max_).
    """
    # Deliberate exact-type comparison (covered by the E721 noqa at module
    # top): mixed int/float arguments are rejected rather than compared.
    if not (type(num) == type(min_) == type(max_)):
        raise TypeError(
            f"Got different types for num ({type(num)}), min ({type(min_)}), and max ({type(max_)})"
        )
    if not (min_ < num < max_):
        if message is None:
            raise ValueError(
                f"Expected number between {min_} and {max_} exclusive, but got: {num}"
            )
        raise ValueError(
            f"{message} must be between {min_} and {max_} exclusive, but got: {num}"
        )


def check_between_inclusive(num, min_, max_, message=None):
    """Validate that ``min_ <= num <= max_`` (inclusive bounds).

    :param num: value under test.
    :param min_: inclusive lower bound; must be the same exact type as num.
    :param max_: inclusive upper bound; must be the same exact type as num.
    :param message: optional name for the value, used in the error text.
    :raises TypeError: if num, min_, and max_ are not all the exact same type.
    :raises ValueError: if num lies outside the closed interval [min_, max_].
    """
    # Deliberate exact-type comparison, same rationale as check_between_exclusive.
    if not (type(num) == type(min_) == type(max_)):
        raise TypeError(
            f"Got different types for num ({type(num)}), min ({type(min_)}), and max ({type(max_)})"
        )
    if not (min_ <= num <= max_):
        if message is None:
            raise ValueError(
                f"Expected number between {min_} and {max_} inclusive, but got: {num}"
            )
        raise ValueError(
            f"{message} must be between {min_} and {max_} inclusive, but got: {num}"
        )


def check_exist(file_or_folder, message=None):
    """Raise FileNotFoundError if the given path does not exist.

    :param file_or_folder: filesystem path to check.
    :param message: optional description used in the error text.
    """
    if not os.path.exists(file_or_folder):
        if message is None:
            raise FileNotFoundError(f"{file_or_folder} does not exist.")
        raise FileNotFoundError(f"{message} does not exist: {file_or_folder}")


def check_ext(file, ext, message=None):
    """Raise RuntimeError if ``file`` does not end with the extension ``ext``.

    :param file: file path (string) to check.
    :param ext: expected suffix, e.g. ``".json"``.
    :param message: optional description used in the error text.
    """
    if not file.endswith(ext):
        if message is None:
            raise RuntimeError(f"Expected {ext} file, but got: {file}")
        raise RuntimeError(f"Expected {ext} file for {message}, but got: {file}")


def check_isdir(folder, message=None):
    """Raise if ``folder`` is not an existing directory.

    :param folder: filesystem path to check.
    :param message: optional description used in the error text.
    :raises FileNotFoundError: when message is None.
    :raises RuntimeError: when message is given.
    """
    # NOTE(review): the exception type depends on whether `message` is
    # provided (FileNotFoundError vs RuntimeError). Kept as-is since callers
    # may rely on it, but worth unifying in a future breaking release.
    if not os.path.isdir(folder):
        if message is None:
            raise FileNotFoundError(f"{folder} is not a directory.")
        raise RuntimeError(f"Expected directory for {message}, but got: {folder}")


def check_old_arg(path):
    """Reject the pre-v0.8.0 CLI usage where a weight directory was passed.

    :param path: main CLI argument; must be a config.json, not a directory.
    :raises RuntimeError: if ``path`` is a directory.
    """
    if os.path.isdir(path):
        raise RuntimeError(
            "This package's main usage has changed starting from v0.8.0. Please use"
            " model's config.json as main argument instead of weight directory."
        )


def check_shapes(shapes):
    """Validate that every entry of ``shapes`` matches ``<int>t<int>c``.

    :param shapes: list of shape strings, e.g. ``["32t512c"]``.
    :raises TypeError: if shapes is not a list.
    :raises RuntimeError: if any entry is malformed.
    """
    if not isinstance(shapes, list):
        raise TypeError(f"Expected shapes to be a list, but got {type(shapes)} instead")
    for shape in shapes:
        # BUG FIX: the original message concatenated "...needs to be of" with
        # "the format...", printing "ofthe format"; a space has been added.
        bad_format = RuntimeError(
            f"Shape {shape} is in the wrong format. Every shape needs to be of "
            "the format: xtyc where x and y are integers. (e.g. 32t512c)"
        )
        # Exactly one 't' separator and one 'c' suffix are required.
        if shape.count("t") != 1 or shape.count("c") != 1:
            raise bad_format
        num_tokens, _, remainder = shape.partition("t")
        cache_size = remainder.split("c")[0]
        try:
            int(num_tokens)
            int(cache_size)
        except ValueError:
            raise bad_format


def check_supported_model(config):
    """Raise RuntimeError unless ``config`` describes a supported model type.

    :param config: a BaseConfig subclass instance with a ``model_type`` field.
    """
    SUPPORTED_MODELS = [
        "llama",
        "bloom",
        "baichuan",
        "qwen",
        "qwen1.5",
        "qwen2",
        "milm",
    ]
    if not isinstance(config, BaseConfig):
        raise RuntimeError(
            f"Unsupported config class: {type(config)}. "
            "config needs to be subclassed from BaseConfig"
        )

    if config.model_type not in SUPPORTED_MODELS:
        raise RuntimeError(
            f"Unsupported model: {config.model_type}. Supported models: "
            f"{SUPPORTED_MODELS}"
        )


def check_supported_tokenizer(config):
    """Raise RuntimeError unless ``config`` names a supported tokenizer class.

    :param config: a BaseConfig subclass instance with a ``tokenizer`` field.
    """
    SUPPORTED_TOKENIZERS = [
        "default",
        "bloom",
        "baichuan",
        "gpt2",
        "gpt2_fast",
        "qwen",
        "qwen2",
        "qwen2_fast",
        "llama",
        "pretrained_fast",
    ]
    if not isinstance(config, BaseConfig):
        raise RuntimeError(
            f"Unsupported config class: {type(config)}. "
            "config needs to be subclassed from BaseConfig"
        )

    if config.tokenizer not in SUPPORTED_TOKENIZERS:
        raise RuntimeError(
            f"Unsupported tokenizer: {config.tokenizer}. Supported tokenizers: "
            f"{SUPPORTED_TOKENIZERS}"
        )


def check_tokenizer_exist(folder):
    """Raise FileNotFoundError unless ``folder`` holds tokenizer model + config.

    Accepts ``tokenizer.model``, ``tokenizer.json``, or any ``*.tiktoken`` file
    as the model, and requires ``tokenizer_config.json`` alongside it.
    """
    entries = os.listdir(folder)
    has_model = any(
        f in ("tokenizer.model", "tokenizer.json") or f.endswith(".tiktoken")
        for f in entries
    )
    has_config = "tokenizer_config.json" in entries
    if not has_model:
        raise FileNotFoundError(
            f"Tokenizer not found in {folder}. Expected tokenizer.model, "
            "tokenizer.json, or tokenizer.tiktoken"
        )
    if not has_config:
        raise FileNotFoundError(
            f"Tokenizer config not found in {folder}. Expected "
            "tokenizer_config.json"
        )


def check_weights_exist(weight_dir):
    """Validate that ``weight_dir`` holds weights in exactly one supported format.

    :param weight_dir: directory expected to contain ``pytorch_model*.bin`` or
        ``model*.safetensors`` files.
    :raises FileNotFoundError: if no weight files are present.
    :raises RuntimeError: if both .bin and .safetensors weights are present.
    """
    entries = os.listdir(weight_dir)
    weight_files = [
        f
        for f in entries
        if (f.startswith("pytorch_model") and f.endswith(".bin"))
        or (f.startswith("model") and f.endswith(".safetensors"))
    ]
    if not weight_files:
        raise FileNotFoundError(
            f"No weight files found in {weight_dir}! Weight files should be either .bin or .safetensors file types."
        )
    safetensors_l = [f for f in entries if f.endswith(".safetensors")]
    # Embedding-only .bin files may legitimately accompany safetensors weights.
    bin_l = [f for f in entries if f.endswith(".bin") and "embedding" not in f]
    # BUG FIX: the original tested `len(safetensors_l) & len(bin_l)` — a
    # bitwise AND of the lengths, which is 0 whenever the two counts share no
    # set bit (e.g. 1 & 2), silently accepting a mixed-format directory.
    if safetensors_l and bin_l:
        raise RuntimeError(
            "Weights should only be in either .bin or .safetensors format, not both."
        )