# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
from typing import Any, Dict, IO, List, Optional, Set, Tuple

import ruamel.yaml

import torch
from executorch.exir.dialects.edge.dtype.runner import DtypeRunner
from executorch.exir.dialects.edge.dtype.supported import regular_tensor_dtypes_to_str
from executorch.exir.dialects.edge.op.api import get_callable
from executorch.exir.dialects.edge.op.sample_input import SAMPLE_INPUT
from executorch.exir.dialects.edge.spec.utils import (
    get_names_for_args_with_dtype,
    type_aggregrate,
)

from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.opinfo.core import (
    BinaryUfuncInfo,
    generate_elementwise_binary_with_scalar_samples,
)

# Ops that take too long to run
BLOCKLISTED_OPS = [
    "_native_batch_norm_legit_no_training.default",
]

name_to_opinfo = {
    op.aten_name if op.aten_name is not None else op.name: op for op in op_db
}

# pyre-ignore
yaml = ruamel.yaml.YAML()

dtr = DtypeRunner()


class TestCaseGenerator:
    def __init__(
        self,
        preset_types: Dict[torch.dtype, List[torch.dtype]],
        test_case_size: List[List[int]],
        *args,
        **kwargs,
    ):
        self.preset_types = preset_types
        self.test_case_size = test_case_size

        for preset_type in self.preset_types.values():
            if len(preset_type) != len(self.test_case_size):
                raise Exception(
                    "Preset type size does not match test case size: got {} and {}".format(
                        len(preset_type), len(self.test_case_size)
                    )
                )
        self.args = args
        self.kwargs = kwargs

    def get_sample_input(self, dtype: torch.dtype):
        if dtype not in self.preset_types:
            raise Exception(f"Unsupported type {dtype}")

        yield [
            torch.randn(tensor_size).to(preset_type)
            for preset_type, tensor_size in zip(
                self.preset_types[dtype], self.test_case_size
            )
        ] + list(self.args), self.kwargs


preset_test_case_generators: Dict[str, TestCaseGenerator] = {
    "aten::lift_fresh_copy": TestCaseGenerator(
        {
            k: [
                k,
            ]
            for k in regular_tensor_dtypes_to_str
        },
        [
            [3, 4, 10],
        ],
    ),
}


def get_func_schema(op_name_may_with_overload: str) -> torch._C.FunctionSchema:
    """Get the function schema given an op name that may or may not include an overload name."""
    if "." in op_name_may_with_overload:
        op_name, overload_name = op_name_may_with_overload.rsplit(".", 1)
    else:
        op_name, overload_name = (
            op_name_may_with_overload,
            "",
        )

    func_schemas = torch._C._jit_get_schemas_for_operator(op_name)
    found_overload_names = []
    for func_schema in func_schemas:
        found_overload_names.append(func_schema.overload_name)
        if overload_name == func_schema.overload_name:
            return func_schema

    raise ValueError(
        "Cannot find {} with specific overload {}. All overloads we can find are {}".format(
            op_name, overload_name, found_overload_names
        )
    )
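
# Illustrative lookups (hypothetical op names; actual availability depends on the
# registered aten schemas):
#   get_func_schema("aten::add.Tensor")  # matches the schema whose overload_name is "Tensor"
#   get_func_schema("aten::relu")        # matches the schema with an empty overload name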
def get_func_name_yaml(func_schema: torch._C.FunctionSchema) -> str:
    """Return the operator name for the yaml file given its function schema.
    It consists of the operator package name plus the operator overload name."""
    return func_schema.name + (
        ".{}".format(func_schema.overload_name) if func_schema.overload_name else ""
    )


def get_test_gen_key(op_name: str) -> str:
    """Map the operator name to the key of a test case generator.

    The test case generator here can be either one of the preset test case generators at the top of this file, or an entry of opdb.

    Raises an exception if the corresponding operator cannot be found in opdb.
    """
    if op_name in preset_test_case_generators:
        return op_name

    opdb_key = op_name.split("::")[-1].strip("_")
    if opdb_key.endswith("_copy"):
        opdb_key = opdb_key[:-5]
    elif opdb_key == "sym_size":
        opdb_key = "resize_"
    elif opdb_key == "sym_numel":
        opdb_key = "abs"
    elif opdb_key == "convolution":
        opdb_key = "conv_transpose2d"
    elif opdb_key == "embedding":
        opdb_key = "nn.functional.embedding"

    if opdb_key not in name_to_opinfo:
        # current function is unsupported: cannot find it in opdb
        raise Exception(
            "Cannot find operator {} in opdb using key {}".format(op_name, opdb_key)
        )
    return opdb_key


def get_sample_input(key: str, overload_name: str, edge_type: torch.dtype):
    """Given a test generator key and a specific edge_type,
    yield sample test cases for this operator at that dtype."""

    if key in preset_test_case_generators:
        yield next(preset_test_case_generators[key].get_sample_input(edge_type))
    else:
        opdb_key = key
        op_info = name_to_opinfo[opdb_key]
        if overload_name == "Scalar" and isinstance(op_info, BinaryUfuncInfo):
            sample_input = next(
                generate_elementwise_binary_with_scalar_samples(
                    op_info,
                    device=torch.device("cpu"),
                    dtype=edge_type,
                    requires_grad=False,
                )
            )
        else:
            sample_input = next(
                op_info.sample_inputs(
                    torch.device("cpu"), edge_type, requires_grad=False
                )
            )
        sample_args = [sample_input.input] + list(sample_input.args)
        sample_kwargs = sample_input.kwargs
        if opdb_key in ["log_softmax", "softmax"]:
            sample_args.append(False)
            sample_kwargs = {}
        elif opdb_key == "resize_":
            sample_args[-1] = 0
        elif opdb_key == "to":
            for dtype in regular_tensor_dtypes_to_str:
                sample_args = sample_args[:1]
                sample_kwargs = {"dtype": dtype}
                yield sample_args, sample_kwargs
            # All "to" cases have been yielded; return to avoid re-yielding the last one below.
            return
        elif opdb_key == "clamp":
            sample_args = sample_args[:1] + [1]
        elif opdb_key == "conv_transpose2d":
            sample_kwargs = {
                "stride": (2, 2),
                "padding": (2, 2),
                "output_padding": (1, 1),
                "groups": 1,
                "dilation": (1, 1),
                "transposed": True,
            }
        elif opdb_key == "split":
            sample_args[1] = 1
        elif opdb_key == "scalar_tensor":
            del sample_kwargs["requires_grad"]
        yield sample_args, sample_kwargs
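
# Illustrative usage (assuming "add" resolves to an opdb entry via get_test_gen_key):
#   args, kwargs = next(get_sample_input("add", "Tensor", torch.float32))
#   torch.ops.aten.add.Tensor(*args, **kwargs)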
def in_legal_edge_type(vals: List[Any]) -> bool:
    """Given a list of objects, check whether every tensor among them has an edge dtype.
    Return False if any tensor is not in an edge dtype; True otherwise."""
    return all(
        (not isinstance(val, torch.Tensor))
        or val.dtype in regular_tensor_dtypes_to_str
        for val in vals
    )


def seq(*args):
    """Convert the arguments into a flow-style yaml sequence to make the yaml file more structured."""
    s = ruamel.yaml.comments.CommentedSeq(args)
    s.fa.set_flow_style()
    return s


def print_error_msg(unsupported_funcs: List[str]):
    """Print the names of unsupported functions in the current model."""
    if unsupported_funcs:
        print("*********************************")
        print(
            "The following functions are unsupported; please read the error messages above for details:"
        )
        for f in unsupported_funcs:
            print(f)


def is_not_dtype_exception(exc: BaseException, dtype_str: str) -> bool:
    """Check whether an exception is NOT about an unsupported dtype."""

    # alias_dtype is the alias name of dtype_str, e.g. "Boolean" is the alias name of "Bool".
    # Default alias_dtype to twice str(exc) so that the default never matches a substring of str(exc).
    alias_dtype = 2 * str(exc)
    if dtype_str == "Bool":
        alias_dtype = "Boolean"

    return not (
        ("not supported" in str(exc) or "not implemented" in str(exc))
        and (
            dtype_str in str(exc)
            or alias_dtype in str(exc)
            or dtype_str.lower() in str(exc)
        )
    )
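
# Illustrative check (hypothetical error message):
#   is_not_dtype_exception(RuntimeError("\"abs_cpu\" not implemented for 'Bool'"), "Bool")
# evaluates to False, since the message contains both "not implemented" and the dtype name.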
class EdgeOpYamlInfo:
    def __init__(
        self,
        func_name: str,
        tensor_variable_names: List[str],
        allowed_types: Set[Tuple[str, ...]],
        inherits: str = "",
        custom: str = "",
    ) -> None:
        """
        Record all the information for a single function in the edge.yaml file.
        func_name: name of the current edge function (e.g. add.Tensor).
        tensor_variable_names: names of all the function's tensor-typed variables, including inputs and outputs
            (e.g. self, other, __ret: the first two are tensor inputs and the last one is the tensor output).
        allowed_types: all allowed combinations of tensor variable types. The length of each tuple in
            allowed_types must equal the number of tensor variables, and each element must be one of the allowed
            types in string form, i.e. one of the values in regular_tensor_dtypes_to_str.
        inherits/custom: where the function is implemented. To reuse an existing function, set inherits to the
            target function (e.g. aten::add.Tensor); otherwise, set custom to the target (e.g. edge::add.Tensor).
            Note that exactly one of the inherits and custom attributes must be set.
        """

        self.func_name = func_name
        self.tensor_variable_names = tensor_variable_names

        assert bool(inherits) ^ bool(
            custom
        ), "Must set one and only one of the inherits and custom attributes."
        self.inherits = inherits
        self.custom = custom

        assert all(
            len(self.tensor_variable_names) == len(type_combination)
            for type_combination in allowed_types
        ), "{}'s tensor_variable_names must have the same length as each allowed type combination, but got {} vs {}: {}.".format(
            self.inherits,
            self.tensor_variable_names,
            allowed_types,
            [
                len(self.tensor_variable_names) == len(type_combination)
                for type_combination in allowed_types
            ],
        )

        self.type_alias, self.type_constraint = type_aggregrate(allowed_types)

    def to_yaml(self) -> Dict[str, Any]:
        """Convert self to a dictionary for the yaml lib to dump."""
        try:
            impl_source_key = "inherits" if self.inherits else "custom"
            impl_source_value = self.inherits if self.inherits else self.custom

            type_alias_yaml = {
                "T{}".format(i): seq(*sorted(set(ts)))
                for i, ts in enumerate(self.type_alias)
            }
            type_constraint_yaml = [
                {
                    self.tensor_variable_names[tensor_idx]: "T{}".format(type_idx)
                    for tensor_idx, type_idx in enumerate(self.type_constraint[j])
                }
                for j in range(len(self.type_constraint))
            ]

            yaml_dict: Dict[str, Any] = {
                "func": self.func_name,
                "namespace": "edge",
                impl_source_key: impl_source_value,
                "type_alias": type_alias_yaml,
                "type_constraint": type_constraint_yaml,
            }
            return yaml_dict
        except BaseException:
            print(
                "Operator {} inherited from {} failed to convert to yaml".format(
                    self.func_name, self.inherits
                )
            )
            print(self)
            return {}

    def __str__(self) -> str:
        my_str: str = "\nop_yaml_info: \n"
        my_str += "name: {}\n".format(self.func_name)
        my_str += "tensor_variable_names: {}\n".format(self.tensor_variable_names)
        my_str += "inherits: {}\n".format(self.inherits)
        my_str += "custom: {}\n".format(self.custom)
        my_str += "type_alias: {}\n".format(self.type_alias)
        my_str += "type_constraint: {}\n".format(self.type_constraint)
        return my_str
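
# Illustrative shape of one dumped EdgeOpYamlInfo entry (dtype values are hypothetical;
# exact formatting depends on ruamel.yaml defaults):
#   - func: add.Tensor
#     namespace: edge
#     inherits: aten::add.Tensor
#     type_alias:
#       T0: [Double, Float]
#     type_constraint:
#     - self: T0
#       other: T0
#       __ret: T0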
class EdgeYamlInfo:
    def __init__(self):
        """
        All the info for a single edge dialect yaml file.
        """
        self.all_op_yaml_info: List[EdgeOpYamlInfo] = []

    def append(self, op_yaml_info: EdgeOpYamlInfo) -> None:
        self.all_op_yaml_info.append(op_yaml_info)

    def to_yaml(self, yaml_stream: IO) -> List[str]:
        tag = "generated"
        heading = f"# @{tag} by //executorch/exir/dialects/edge/spec/gen.py\n\n"

        yaml_stream.write(heading)
        yaml_stream.write(
            "# This yaml file is auto-generated by //executorch/exir/dialects/edge/spec/gen.py\n"
        )
        yaml_stream.write("# Please do not update it manually.\n")
        yaml_stream.write(
            "# If anything is not up-to-date, please rerun the binary target. Optional argument: --regenerate.\n"
        )

        yaml_list: List[Dict[str, Any]] = []
        failed_operators: List[str] = []
        for op_yaml_info in self.all_op_yaml_info:
            op_yaml = op_yaml_info.to_yaml()
            if op_yaml:
                yaml_list.append(op_yaml)
            else:
                failed_operators.append(op_yaml_info.inherits)

        yaml_list = sorted(yaml_list, key=lambda d: d["func"])

        for idx, op_yaml in enumerate(yaml_list):
            yaml.dump(
                [
                    op_yaml,
                ],
                yaml_stream,
            )
            if idx != len(yaml_list) - 1:
                yaml_stream.write("\n")

        return failed_operators

    def __str__(self) -> str:
        return "\n\n".join(list(map(str, self.all_op_yaml_info)))


def try_all_dtypes_input_samples(
    op_name: str,
) -> Set[Tuple[str, ...]]:
    """Run the sample inputs for the given operator across all possible dtypes and
    collect the dtype combinations that succeed."""
    valid_type_combinations: Set[Tuple[str, ...]] = set()
    assert (
        op_name in SAMPLE_INPUT
    ), f"{op_name} does not have a sample input in SAMPLE_INPUT."
    inputs = SAMPLE_INPUT[op_name]

    result = dtr.run(op_name, inputs)
    for success, _, valid_dtypes, _, _ in result:
        if success and not any(dtype is None for dtype in valid_dtypes):
            valid_type_combinations.add(
                tuple(regular_tensor_dtypes_to_str[t] for t in valid_dtypes)
            )
    if not valid_type_combinations:
        # current function is unsupported: no legal test case found by the runner
        print(
            f"{op_name} is unsupported: no legal test case has been found from runner.py"
        )
        print("The sample input tried was:", inputs)
    return valid_type_combinations


def gen_op_yaml(op_name: str) -> Optional[EdgeOpYamlInfo]:
    """Generate yaml info for the given operator.
    Arguments:
        op_name: The name of the operator. Needs to conform to the convention of
            "<name>.<overload_name>". If the operator has no overload name, use
            "default" as the overload name.
    Returns the yaml info for the given operator if generation succeeds; otherwise returns None.
    """

    try:
        func_schema: torch._C.FunctionSchema = get_callable(op_name)._schema
    except BaseException as e:
        # Cannot find the operator, or its schema, based on op_name;
        # re-raise with more context.
        raise RuntimeError(f"Can not find operator schema for {op_name}") from e

    valid_type_combinations = try_all_dtypes_input_samples(op_name)

    if not valid_type_combinations:
        return None

    func_name_yaml = get_func_name_yaml(func_schema)
    tensor_variable_names = get_names_for_args_with_dtype(op_name, func_schema)
    inherits = func_schema.name + (
        ".{}".format(func_schema.overload_name) if func_schema.overload_name else ""
    )

    try:
        op_yaml_info = EdgeOpYamlInfo(
            func_name=func_name_yaml,
            tensor_variable_names=tensor_variable_names,
            inherits=inherits,
            allowed_types=valid_type_combinations,
        )
    except BaseException as e:
        # Failed to create the yaml info for the current function;
        # the caller appends it to unsupported_funcs.
        print("Failed to create yaml info for current function:", op_name)
        print("Error msg:", str(e))
        return None

    return op_yaml_info
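
# Illustrative call (the op name is a hypothetical SAMPLE_INPUT key):
#   info = gen_op_yaml("add.Tensor")  # EdgeOpYamlInfo on success, None if no valid dtypes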
def gen_edge_yaml(op_names: List[str], yaml_out_stream: IO) -> List[str]:
    """Generate the yaml file of edge dialect operators for the target model.

    Given a list of operator names, generate a yaml file (edge.yaml) that describes all allowed tensor dtypes for those operators.

    Args:
        op_names: The list of operator names.
        yaml_out_stream: The stream the yaml file will be written to, e.g. a file.

    Returns:
        A list of incompatible operators that cannot be auto-generated.
    """

    print("************************************************************")
    print("These are the ops used by the current model: ")
    print(op_names)
    print("************************************************************")

    edge_yaml_info = EdgeYamlInfo()

    # Record all functions in the model whose yaml info cannot be auto-generated.
    unsupported_funcs: List[str] = []

    for i, op_name in enumerate(op_names):
        ret = gen_op_yaml(op_name)
        if ret is None:
            # Skip this op: returning None means it cannot be auto-generated.
            print(f"Skipping op ({i+1}/{len(op_names)}): {op_name}")
            unsupported_funcs.append(op_name)
        else:
            print(
                f"Generating dtype constraints for op ({i+1}/{len(op_names)}): {op_name}"
            )
            # Append the generated yaml info for the op to edge_yaml_info
            edge_yaml_info.append(ret)

    unsupported_funcs += edge_yaml_info.to_yaml(yaml_out_stream)
    return unsupported_funcs


def main():
    parser = argparse.ArgumentParser(
        description="Generate allowed tensor dtypes for core ATen ops"
    )
    parser.add_argument(
        "--regenerate",
        action="store_true",
        help="Whether to regenerate edge.yaml, based on all edge ops used in ASR models. By default, reuse the operators in the existing edge.yaml file.",
    )
    options = parser.parse_args()

    yaml_path = "executorch/exir/dialects/edge/edge.yaml"
    if options.regenerate:
        # TODO(larryliu0820): Use all core ATen ops here.
        op_names = [op for op in SAMPLE_INPUT.keys() if op not in BLOCKLISTED_OPS]
    else:
        with open(yaml_path, "r") as f:
            obj = yaml.load(f)
            if not obj:
                raise Exception("YAML file is empty!")
            op_names = [e["inherits"] for e in obj]

    with open(yaml_path, "w") as stream:
        unsupported_funcs = gen_edge_yaml(op_names, stream)
    print_error_msg(unsupported_funcs)


if __name__ == "__main__":
    main()