xref: /aosp_15_r20/external/executorch/codegen/tools/merge_yaml.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Workerimport argparse
8*523fa7a6SAndroid Build Coastguard Workerimport os
9*523fa7a6SAndroid Build Coastguard Workerimport sys
10*523fa7a6SAndroid Build Coastguard Workerfrom collections import defaultdict
11*523fa7a6SAndroid Build Coastguard Workerfrom typing import Any, Dict, List, Optional
12*523fa7a6SAndroid Build Coastguard Worker
13*523fa7a6SAndroid Build Coastguard Workerimport yaml
14*523fa7a6SAndroid Build Coastguard Worker
15*523fa7a6SAndroid Build Coastguard Workertry:
16*523fa7a6SAndroid Build Coastguard Worker    from yaml import CSafeLoader as Loader
17*523fa7a6SAndroid Build Coastguard Workerexcept ImportError:
18*523fa7a6SAndroid Build Coastguard Worker    from yaml import SafeLoader as Loader  # type: ignore[misc]
19*523fa7a6SAndroid Build Coastguard Worker
20*523fa7a6SAndroid Build Coastguard Worker
class BlankLineDumper(yaml.SafeDumper):
    """``yaml.SafeDumper`` subclass that spaces out entries with blank lines."""

    def write_line_break(self, data=None):
        """Write the requested line break, plus an extra one between entries.

        Args:
            data: forwarded unchanged to ``SafeDumper.write_line_break``.
        """
        emit_break = super().write_line_break
        emit_break(data)
        # A single item on the indent stack presumably means the emitter is
        # at the outermost collection, i.e. between entries -- TODO confirm
        # against the PyYAML emitter internals.
        between_entries = len(self.indents) == 1
        if between_entries:
            emit_break()
27*523fa7a6SAndroid Build Coastguard Worker
28*523fa7a6SAndroid Build Coastguard Worker
def merge(functions_yaml_path: str, fallback_yaml_path: Optional[str], output_dir: str):
    """Merge a functions yaml with an optional fallback yaml into ``merged.yaml``.

    Entries are keyed by their canonical operator name. Every entry of the
    functions yaml is kept (a later duplicate replaces an earlier one); an
    entry of the fallback yaml is only added when no entry with the same
    canonical name is already present. The result is written to
    ``<output_dir>/merged.yaml`` as a list, in insertion order.

    Args:
        functions_yaml_path: path of the primary yaml file; must be readable.
        fallback_yaml_path: path of the fallback yaml file. It is ignored when
            it is ``None`` or does not exist on disk.
        output_dir: directory in which ``merged.yaml`` is written.

    Raises:
        KeyError: if a mapping entry has neither an ``op`` nor a ``func`` key.
    """
    output_file = os.path.join(output_dir, "merged.yaml")

    def get_canonical_opname(func: object) -> str:
        """Return the canonical (namespace-qualified) name of an operator entry.

        Two keys are supported in the yaml files:
        - op: add.Tensor  # mostly used for binding ATen ops to kernels
        - func: add.Tensor(Tensor self, Tensor other, Scalar alpha)  # mostly
            used for defining custom ops.

        For ``func`` entries the name is the text before the first ``(``.
        Names without a ``::`` namespace separator get the ``aten::`` prefix.

        Args:
            func (object): yaml object

        Returns:
            str: canonical name of the operator
        """
        # pyre-ignore
        opname = func["op"] if "op" in func else func["func"].split("(")[0]
        if "::" not in opname:
            opname = "aten::" + opname
        return opname

    def load_entries(path: str) -> List[Any]:
        """Load the entries of one yaml file, treating an empty file as no entries."""
        with open(path) as f:
            loaded = yaml.load(f, Loader=Loader)
        # PyYAML returns None for an empty document; iterating None would
        # raise TypeError, so normalize it to an empty list.
        return [] if loaded is None else loaded

    # A plain dict keeps insertion order, which determines the output order.
    functions_dict: Dict[str, object] = {
        get_canonical_opname(func): func
        for func in load_entries(functions_yaml_path)
    }

    if fallback_yaml_path is not None and os.path.exists(fallback_yaml_path):
        for func in load_entries(fallback_yaml_path):
            # Entries already present always win over fallback entries.
            functions_dict.setdefault(get_canonical_opname(func), func)

    with open(output_file, "w") as f:
        yaml.dump(
            list(functions_dict.values()),
            f,
            Dumper=BlankLineDumper,
            default_flow_style=False,
            sort_keys=False,
            width=1000,
        )
75*523fa7a6SAndroid Build Coastguard Worker
76*523fa7a6SAndroid Build Coastguard Worker
def main(argv: List[Any]) -> None:
    """Parse command-line arguments and merge functions.yaml with a fallback yaml.

    The output yaml will be a union of all entries in functions.yaml and
    fallback yaml, with operator entries in functions.yaml overriding entries
    with the same op name in fallback yaml.

    E.g. (simplified, illustrative entries),
    functions.yaml:
    - op: add.Tensor
      kernel: add_impl

    fallback yaml:
    - op: add.Tensor
      kernel: add_fallback
    - op: relu
      kernel: relu_fallback

    Merged:
    - op: add.Tensor
      kernel: add_impl
    - op: relu
      kernel: relu_fallback

    Args:
        argv: command-line arguments, without the program name.

    Raises:
        SystemExit: with status 2 when the arguments are invalid or the
            functions yaml file does not exist.
    """
    parser = argparse.ArgumentParser(
        description="Merge functions.yaml, custom_ops.yaml with fallback yaml, for codegen to consume."
    )
    parser.add_argument(
        "--functions-yaml-path",
        "--functions_yaml_path",
        help="path to the functions.yaml file to use.",
        required=True,
    )
    parser.add_argument(
        "--fallback-yaml-path",
        "--fallback_yaml_path",
        help="path to fallback yaml file.",
        required=False,
    )
    parser.add_argument(
        "--output_dir",
        help=("The directory to store the output yaml file"),
        required=True,
    )

    options = parser.parse_args(argv)
    # Validate explicitly instead of using `assert`: assertions are stripped
    # under `python -O`, and parser.error reports the problem the same way as
    # every other argument error (usage message on stderr, exit status 2).
    if options.functions_yaml_path is None or not os.path.exists(
        options.functions_yaml_path
    ):
        parser.error(
            f"--functions-yaml-path does not exist: {options.functions_yaml_path}"
        )
    merge(options.functions_yaml_path, options.fallback_yaml_path, options.output_dir)
123*523fa7a6SAndroid Build Coastguard Worker
124*523fa7a6SAndroid Build Coastguard Worker
# Script entry point: hand the command-line arguments (without the program
# name) to main() when this file is executed directly.
if __name__ == "__main__":
    cli_args = sys.argv[1:]
    main(cli_args)
127