# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
from typing import Any, Dict, IO, List, Optional, Set, Tuple

import ruamel.yaml

import torch
from executorch.exir.dialects.edge.dtype.runner import DtypeRunner
from executorch.exir.dialects.edge.dtype.supported import regular_tensor_dtypes_to_str
from executorch.exir.dialects.edge.op.api import get_callable
from executorch.exir.dialects.edge.op.sample_input import SAMPLE_INPUT
from executorch.exir.dialects.edge.spec.utils import (
    get_names_for_args_with_dtype,
    type_aggregrate,
)

from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.opinfo.core import (
    BinaryUfuncInfo,
    generate_elementwise_binary_with_scalar_samples,
)
# Ops taking too long to run
BLOCKLISTED_OPS = [
    "_native_batch_norm_legit_no_training.default",
]

# Map each opdb entry to its ATen name, falling back to its plain name.
name_to_opinfo = {
    op.aten_name if op.aten_name is not None else op.name: op for op in op_db
}

# pyre-ignore
yaml = ruamel.yaml.YAML()

# Runs sample inputs across dtypes to discover which combinations an op supports.
dtr = DtypeRunner()


class test_case_generator:
    """Generate sample inputs for an operator from a preset dtype/shape table."""

    def __init__(
        self,
        preset_types: Dict[torch.dtype, List[torch.dtype]],
        test_case_size: List[List[int]],
        *args,
        **kwargs,
    ):
        self.preset_types = preset_types
        self.test_case_size = test_case_size

        for preset_type in self.preset_types.values():
            if len(preset_type) != len(self.test_case_size):
                raise Exception(
                    "Preset type size does not match test case size, got {} and {}".format(
                        len(preset_type), len(self.test_case_size)
                    )
                )
        self.args = args
        self.kwargs = kwargs

    def get_sample_input(self, dtype: torch.dtype):
        """Yield one (args, kwargs) pair of sample inputs for the given dtype."""
        if dtype not in self.preset_types:
            raise Exception(f"Unsupported type {dtype}")

        yield [
            torch.randn(tensor_size).to(preset_type)
            for preset_type, tensor_size in zip(
                self.preset_types[dtype], self.test_case_size
            )
        ] + list(self.args), self.kwargs

preset_test_case_generators: Dict[str, test_case_generator] = {
    "aten::lift_fresh_copy": test_case_generator(
        {
            k: [
                k,
            ]
            for k in regular_tensor_dtypes_to_str
        },
        [
            [3, 4, 10],
        ],
    ),
}

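# Illustrative usage (not part of the generation flow): each preset generator
# yields (args, kwargs) pairs for a requested dtype, e.g.
#
#     gen = preset_test_case_generators["aten::lift_fresh_copy"]
#     args, kwargs = next(gen.get_sample_input(torch.float))
#
# Here args holds a single [3, 4, 10] float32 tensor and kwargs is empty.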

def get_func_schema(op_name_may_with_overload: str) -> torch._C.FunctionSchema:
    """Get the function schema for an op name that may or may not include an overload name."""
    if "." in op_name_may_with_overload:
        op_name, overload_name = op_name_may_with_overload.rsplit(".", 1)
    else:
        op_name, overload_name = (
            op_name_may_with_overload,
            "",
        )

    func_schemas = torch._C._jit_get_schemas_for_operator(op_name)
    found_overload_names = []
    for func_schema in func_schemas:
        found_overload_names.append(func_schema.overload_name)
        if overload_name == func_schema.overload_name:
            return func_schema

    raise ValueError(
        "Cannot find {} with specific overload {}. All overloads we can find are {}".format(
            op_name, overload_name, found_overload_names
        )
    )

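# For example, get_func_schema("aten::add.Tensor") splits on the last "." and
# returns the schema whose overload_name is "Tensor"; a name without "." (e.g.
# "aten::relu") matches the default (empty) overload name.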

def get_func_name_yaml(func_schema: torch._C.FunctionSchema) -> str:
    """Return the operator name used in the yaml file, given its function schema.
    It consists of the operator package name plus the operator overload name."""
    return func_schema.name + (
        ".{}".format(func_schema.overload_name) if func_schema.overload_name else ""
    )

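# E.g. a schema named "aten::add" with overload name "Tensor" yields
# "aten::add.Tensor"; with an empty overload name it yields just "aten::add".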

def get_test_gen_key(op_name: str) -> str:
    """Map the operator name to the key of a test case generator.

    The test case generator here can be either one of the preset test case generators at the top of this file, or an entry of opdb.

    Raises an exception if the corresponding operator cannot be found in opdb.
    """
    if op_name in preset_test_case_generators:
        return op_name

    opdb_key = op_name.split("::")[-1].strip("_")
    if opdb_key.endswith("_copy"):
        opdb_key = opdb_key[:-5]
    elif opdb_key == "sym_size":
        opdb_key = "resize_"
    elif opdb_key == "sym_numel":
        opdb_key = "abs"
    elif opdb_key == "convolution":
        opdb_key = "conv_transpose2d"
    elif opdb_key == "embedding":
        opdb_key = "nn.functional.embedding"

    if opdb_key not in name_to_opinfo:
        # current function is unsupported: cannot find it in opdb
        raise Exception(
            "Cannot find operator {} in the opdb using key {}".format(
                op_name, opdb_key
            )
        )
    return opdb_key

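# Examples of the mapping (derived from the rules above): "aten::expand_copy"
# -> "expand" (trailing "_copy" dropped), "aten::sym_size" -> "resize_",
# "aten::embedding" -> "nn.functional.embedding".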

def get_sample_input(key: str, overload_name: str, edge_type: torch.dtype):
    """Given a test generator key and a specific edge_type,
    yield sets of test cases for this operator at that dtype."""

    if key in preset_test_case_generators:
        yield next(preset_test_case_generators[key].get_sample_input(edge_type))
    else:
        opdb_key = key
        op_info = name_to_opinfo[opdb_key]
        if overload_name == "Scalar" and isinstance(op_info, BinaryUfuncInfo):
            sample_input = next(
                generate_elementwise_binary_with_scalar_samples(
                    op_info,
                    device=torch.device("cpu"),
                    dtype=edge_type,
                    requires_grad=False,
                )
            )
        else:
            sample_input = next(
                op_info.sample_inputs(
                    torch.device("cpu"), edge_type, requires_grad=False
                )
            )
        sample_args = [sample_input.input] + list(sample_input.args)
        sample_kwargs = sample_input.kwargs
        if opdb_key in ["log_softmax", "softmax"]:
            sample_args.append(False)
            sample_kwargs = {}
        elif opdb_key == "resize_":
            sample_args[-1] = 0
        elif opdb_key == "to":
            for dtype in regular_tensor_dtypes_to_str:
                sample_args = sample_args[:1]
                sample_kwargs = {"dtype": dtype}
                yield sample_args, sample_kwargs
        elif opdb_key == "clamp":
            sample_args = sample_args[:1] + [1]
        elif opdb_key == "conv_transpose2d":
            sample_kwargs = {
                "stride": (2, 2),
                "padding": (2, 2),
                "output_padding": (1, 1),
                "groups": 1,
                "dilation": (1, 1),
                "transposed": True,
            }
        elif opdb_key == "split":
            sample_args[1] = 1
        elif opdb_key == "scalar_tensor":
            del sample_kwargs["requires_grad"]
        yield sample_args, sample_kwargs

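# Each yielded pair is (sample_args, sample_kwargs), ready to be splatted into
# the operator under test, e.g. op(*sample_args, **sample_kwargs). For "to",
# one pair is yielded per target dtype rather than a single pair.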

def in_legal_edge_type(vals: List[Any]) -> bool:
    """Given a list of objects, check whether every tensor among them has a legal edge dtype.
    Return False if any tensor does not; True otherwise."""
    return all(
        (not isinstance(val, torch.Tensor))
        or val.dtype in regular_tensor_dtypes_to_str
        for val in vals
    )

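# E.g. a list containing a float32 tensor and plain Python scalars passes
# (non-tensor entries are ignored); a tensor with a dtype outside
# regular_tensor_dtypes_to_str (e.g. a complex dtype, assuming it is not listed
# there) makes the check fail.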

def seq(*args):
    """Convert the arguments into a yaml sequence rendered in flow style, to keep the yaml file compact."""
    s = ruamel.yaml.comments.CommentedSeq(args)
    s.fa.set_flow_style()
    return s

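# With flow style the sequence dumps inline, e.g. "[Float, Double]", instead of
# a multi-line block list, which keeps the type_alias entries readable.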

def print_error_msg(unsupported_funcs: List[str]):
    """Print the names of unsupported functions in the current model."""
    if unsupported_funcs:
        print("*********************************")
        print(
            "The following functions are unsupported; please read the error messages above for details:"
        )
        for f in unsupported_funcs:
            print(f)


def is_not_dype_exception(exc: BaseException, dtype_str: str) -> bool:
    """Check whether an exception is NOT about an unsupported dtype."""

    # An alias dtype is an alternative name for the dtype string, e.g. "Boolean"
    # is an alias of "Bool". Default alias_dtype to twice str(exc) to guarantee
    # that the default alias is never a substring of str(exc).
    alias_dtype = 2 * str(exc)
    if dtype_str == "Bool":
        alias_dtype = "Boolean"

    return not (
        ("not supported" in str(exc) or "not implemented" in str(exc))
        and (
            dtype_str in str(exc)
            or alias_dtype in str(exc)
            or dtype_str.lower() in str(exc)
        )
    )

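# E.g. is_not_dype_exception(RuntimeError('"add" not implemented for Bool'), "Bool")
# returns False (it is a dtype exception), while an unrelated error returns True.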

class EdgeOpYamlInfo:
    def __init__(
        self,
        func_name: str,
        tensor_variable_names: List[str],
        allowed_types: Set[Tuple[str, ...]],
        inherits: str = "",
        custom: str = "",
    ) -> None:
        """
        Record all information for a single function in the edge.yaml file.
        func_name: name of the current Edge function (e.g. add.Tensor)
        tensor_variable_names: all names of the function's tensor-typed variables, including inputs and outputs
            (e.g. self, other, __ret; the first two are tensor inputs and the last one is the tensor output)
        allowed_types: all allowed combinations of tensor variable types. The length of each tuple in allowed_types
            must equal the number of tensor variables, and each element must be one of the allowed types in string
            form, i.e. one of the values in regular_tensor_dtypes_to_str.
        inherits/custom: where the function is implemented; to reuse an existing function,
            set inherits to the target function (e.g. aten::add.Tensor); otherwise, set custom to the target
            (e.g. edge::add.Tensor). Note that exactly one of the inherits and custom attributes must be set.
        """

        self.func_name = func_name
        self.tensor_variable_names = tensor_variable_names

        assert bool(inherits) ^ bool(
            custom
        ), "Must set exactly one of the inherits and custom attributes."
        self.inherits = inherits
        self.custom = custom

        assert all(
            len(self.tensor_variable_names) == len(type_combination)
            for type_combination in allowed_types
        ), "{}'s tensor_variable_names length must be the same as the number of allowed types, but got {} vs {}: {}.".format(
            self.inherits,
            self.tensor_variable_names,
            allowed_types,
            [
                len(self.tensor_variable_names) == len(type_combination)
                for type_combination in allowed_types
            ],
        )

        self.type_alias, self.type_constraint = type_aggregrate(allowed_types)

    def to_yaml(self) -> Dict[str, Any]:
        """Convert self to a dictionary for the yaml lib to dump."""
        try:
            impl_source_key = "inherits" if self.inherits else "custom"
            impl_source_value = self.inherits if self.inherits else self.custom

            type_alias_yaml = {
                "T{}".format(i): seq(*sorted(set(ts)))
                for i, ts in enumerate(self.type_alias)
            }
            type_constraint_yaml = [
                {
                    self.tensor_variable_names[tensor_idx]: "T{}".format(type_idx)
                    for tensor_idx, type_idx in enumerate(self.type_constraint[j])
                }
                for j in range(len(self.type_constraint))
            ]

            yaml_dict: Dict[str, Any] = {
                "func": self.func_name,
                "namespace": "edge",
                impl_source_key: impl_source_value,
                "type_alias": type_alias_yaml,
                "type_constraint": type_constraint_yaml,
            }
            return yaml_dict
        except BaseException:
            print(
                "Operator {} inherited from {} failed to convert to yaml".format(
                    self.func_name, self.inherits
                )
            )
            print(self)
            return {}

    def __str__(self) -> str:
        my_str: str = "\nop_yaml_info: \n"
        my_str += "name: {}\n".format(self.func_name)
        my_str += "tensor_variable_names: {}\n".format(self.tensor_variable_names)
        my_str += "inherits: {}\n".format(self.inherits)
        my_str += "custom: {}\n".format(self.custom)
        my_str += "type_alias: {}\n".format(self.type_alias)
        my_str += "type_constraint: {}\n".format(self.type_constraint)
        return my_str

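# A sketch of the YAML produced by to_yaml() for one op (illustrative values):
#
#   - func: aten::add.Tensor
#     namespace: edge
#     inherits: aten::add.Tensor
#     type_alias:
#       T0: [Double, Float]
#     type_constraint:
#     - self: T0
#       other: T0
#       __ret: T0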

class EdgeYamlInfo:
    def __init__(self):
        """
        All info for a single edge dialect yaml file.
        """
        self.all_op_yaml_info: List[EdgeOpYamlInfo] = []

    def append(self, op_yaml_info: EdgeOpYamlInfo) -> None:
        self.all_op_yaml_info.append(op_yaml_info)

    def to_yaml(self, yaml_stream: IO) -> List[str]:
        tag = "generated"
        heading = f"# @{tag} by //executorch/exir/dialects/edge/spec/gen.py\n\n"

        yaml_stream.write(heading)
        yaml_stream.write(
            "# This yaml file is auto-generated by //executorch/exir/dialects/edge/spec/gen.py\n"
        )
        yaml_stream.write("# Please do not update it manually.\n")
        yaml_stream.write(
            "# If anything is not up-to-date, please rerun the binary target. Optional argument: --regenerate.\n"
        )

        yaml_list: List[Dict[str, Any]] = []
        failed_operators: List[str] = []
        for op_yaml_info in self.all_op_yaml_info:
            op_yaml = op_yaml_info.to_yaml()
            if op_yaml:
                yaml_list.append(op_yaml)
            else:
                failed_operators.append(op_yaml_info.inherits)

        yaml_list = sorted(yaml_list, key=lambda d: d["func"])

        for idx, op_yaml in enumerate(yaml_list):
            yaml.dump(
                [
                    op_yaml,
                ],
                yaml_stream,
            )
            if idx != len(yaml_list) - 1:
                yaml_stream.write("\n")

        return failed_operators

    def __str__(self) -> str:
        return "\n\n".join(list(map(str, self.all_op_yaml_info)))


def try_all_dtypes_input_samples(
    op_name: str,
) -> Set[Tuple[str, ...]]:
    """Try input samples for the given operator across all possible dtypes,
    and collect the dtype combinations that run successfully."""
    valid_type_combinations: Set[Tuple[str, ...]] = set()
    assert (
        op_name in SAMPLE_INPUT
    ), f"{op_name} does not have a sample input in SAMPLE_INPUT."
    inputs = SAMPLE_INPUT[op_name]
    sample_args: List[Any] = []
    sample_kwargs: Dict[Any, Any] = {}

    result = dtr.run(op_name, inputs)
    for success, _, valid_dtypes, _, _ in result:
        if success and not any(dtype is None for dtype in valid_dtypes):
            valid_type_combinations.add(
                tuple(regular_tensor_dtypes_to_str[t] for t in valid_dtypes)
            )
    if not valid_type_combinations:
        # current function is unsupported: error test case from opdb
        print(
            f"{op_name} is unsupported: no legal test case has been found from runner.py"
        )
        if (not sample_args) and (not sample_kwargs):
            print("Cannot get sample input case.")
        else:
            print("One of the sample inputs is", sample_args, sample_kwargs)
    return valid_type_combinations

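# The returned set holds one tuple of dtype names per passing run, e.g.
# {("Float", "Float", "Float"), ("Int", "Int", "Int")} for a binary op whose
# inputs and output must share a dtype (illustrative values).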

def gen_op_yaml(op_name: str) -> Optional[EdgeOpYamlInfo]:
    """Generate yaml info for the given operator.
    Arguments:
        op_name: The name of the operator. Needs to conform to the convention of "<name>.<overload_name>".
                If the operator has no overload name, use "default" as the overload name.
    Return the yaml info for the given operator if generation succeeds; otherwise return None.
    """

    try:
        func_schema: torch._C.FunctionSchema = get_callable(op_name)._schema
    except BaseException as e:
        # Cannot find the operator schema, or cannot find the operator based on
        # op_name; surface the failure to the caller.
        raise RuntimeError(f"Cannot find operator schema for {op_name}") from e

    valid_type_combinations = try_all_dtypes_input_samples(op_name)

    if not valid_type_combinations:
        return None

    func_name_yaml = get_func_name_yaml(func_schema)
    tensor_variable_names = get_names_for_args_with_dtype(op_name, func_schema)
    inherits = func_schema.name + (
        ".{}".format(func_schema.overload_name) if func_schema.overload_name else ""
    )

    try:
        op_yaml_info = EdgeOpYamlInfo(
            func_name=func_name_yaml,
            tensor_variable_names=tensor_variable_names,
            inherits=inherits,
            allowed_types=valid_type_combinations,
        )
    except BaseException as e:
        # Failed to create yaml info for the current function.
        # The caller will append it to unsupported_funcs.
        print("Failed to create yaml file for current function:", op_name)
        print("Error msg:", str(e))
        return None

    return op_yaml_info


def gen_edge_yaml(op_names: List[str], yaml_out_stream: IO) -> List[str]:
    """Generate the yaml file of edge dialect operators for the target model.

    Given a list of operator names, generate a yaml file edge.yaml that describes all allowed tensor dtypes for those operators.

    Args:
        op_names: The list of operator names.
        yaml_out_stream: The place the yaml file will be stored, e.g. a file.

    Returns:
        A list of incompatible operators whose yaml info can not be auto-generated.
    """

    print("************************************************************")
    print("These are the ops used by the current model: ")
    print(op_names)
    print("************************************************************")

    edge_yaml_info = EdgeYamlInfo()

    # Record all functions in the model whose yaml info can not be auto-generated.
    unsupported_funcs: List[str] = []

    for i, op_name in enumerate(op_names):
        ret = gen_op_yaml(op_name)
        if ret is None:
            # Skip this op; None means its yaml info cannot be auto-generated.
            print(f"Skipping op ({i+1}/{len(op_names)}): {op_name}")
            unsupported_funcs.append(op_name)
        else:
            print(
                f"Generating dtype constraints for op ({i+1}/{len(op_names)}): {op_name}"
            )
            # Append the generated yaml info for the op to edge_yaml_info
            edge_yaml_info.append(ret)

    unsupported_funcs += edge_yaml_info.to_yaml(yaml_out_stream)
    return unsupported_funcs


def main():
    parser = argparse.ArgumentParser(
        description="Generate allowed tensor dtypes for core ATen ops"
    )
    parser.add_argument(
        "--regenerate",
        action="store_true",
        help="Whether to regenerate edge.yaml, based on all edge ops used in ASR models. By default we reuse the operators in the existing edge.yaml file.",
    )
    options = parser.parse_args()

    yaml_path = "executorch/exir/dialects/edge/edge.yaml"
    if options.regenerate:
        # TODO(larryliu0820): Use all core ATen ops here.
        op_names = [op for op in SAMPLE_INPUT.keys() if op not in BLOCKLISTED_OPS]
    else:
        # Reuse the operator list from the existing edge.yaml file.
        with open(yaml_path, "r") as f:
            obj = yaml.load(f)
            if not obj:
                raise Exception("YAML file is empty!")
            op_names = [e["inherits"] for e in obj]

    with open(yaml_path, "w") as stream:
        unsupported_funcs = gen_edge_yaml(op_names, stream)
    print_error_msg(unsupported_funcs)


if __name__ == "__main__":
    main()
543