# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os
import sys
from collections import defaultdict
from typing import Any, Dict, List, Optional

import yaml

try:
    from yaml import CSafeLoader as Loader
except ImportError:
    from yaml import SafeLoader as Loader  # type: ignore[misc]


class BlankLineDumper(yaml.SafeDumper):
    """SafeDumper that separates top-level entries with a blank line."""

    def write_line_break(self, data=None):
        super().write_line_break(data)
        # Insert an extra line break between entries: a single element on the
        # indent stack means the emitter is at the outermost collection.
        if len(self.indents) == 1:
            super().write_line_break()


def merge(functions_yaml_path: str, fallback_yaml_path: Optional[str], output_dir: str):
    """Merge a functions yaml with an optional fallback yaml into ``merged.yaml``.

    Every entry of ``functions_yaml_path`` is kept. Entries of
    ``fallback_yaml_path`` are appended only when no entry with the same
    canonical operator name is already present. The result is written to
    ``<output_dir>/merged.yaml``.

    Args:
        functions_yaml_path: path to the primary yaml file; it must exist.
        fallback_yaml_path: path to the fallback yaml file. Ignored when it is
            ``None`` or does not exist on disk.
        output_dir: existing directory that receives ``merged.yaml``.

    Raises:
        OSError: if an input file cannot be read or the output file cannot be
            written.
        KeyError: if an entry has neither an ``op`` nor a ``func`` key.
    """
    output_file = os.path.join(output_dir, "merged.yaml")

    def get_canonical_opname(func: object) -> str:
        """get the canonical name of an operator
        "op" and "func" are two keywords we are supporting for yaml files.
        To give an example:
        - op: add.Tensor  # mostly used for binding ATen ops to kernels
        - func: add.Tensor(Tensor self, Tensor other, Scalar alpha)  # mostly used for
            defining custom ops.

        These two will be supported
        Args:
            func (object): yaml object

        Returns:
            str: canonical name of the operator
        """
        # pyre-ignore
        opname = func["op"] if "op" in func else func["func"].split("(")[0]
        # Names without an explicit namespace belong to the "aten" namespace.
        if "::" not in opname:
            opname = "aten::" + opname
        return opname

    # Loader is always one of the *safe* loaders (see the imports above), so
    # loading these files cannot construct arbitrary Python objects.
    with open(functions_yaml_path) as f:
        # An empty (or comment-only) document loads as None; treat it as "no entries".
        functions_obj = yaml.load(f, Loader=Loader) or []
    # A plain dict keeps insertion order, which determines the output order.
    # A later duplicate within functions.yaml replaces the earlier one in place.
    functions_dict: Dict[str, object] = {}
    for func in functions_obj:
        functions_dict[get_canonical_opname(func)] = func

    if fallback_yaml_path is not None and os.path.exists(fallback_yaml_path):
        with open(fallback_yaml_path) as f:
            fallback_obj = yaml.load(f, Loader=Loader) or []
        for func in fallback_obj:
            opname = get_canonical_opname(func)
            # Entries from functions.yaml take precedence over the fallback.
            if opname not in functions_dict:
                functions_dict[opname] = func

    with open(output_file, "w") as f:
        yaml.dump(
            list(functions_dict.values()),
            f,
            Dumper=BlankLineDumper,
            default_flow_style=False,
            sort_keys=False,
            width=1000,
        )


def main(argv: List[Any]) -> None:
    """Merge functions.yaml and fallback yaml.

    The output yaml will be a union of all entries in functions.yaml and
    fallback yaml, with operator entries in functions.yaml overriding entries
    with the same op name in fallback yaml.
    E.g.,
    functions.yaml:
    - op: add.Tensor
      kernel: add_impl

    fallback yaml:
    - op: add.Tensor
      kernel: add_fallback
    - op: relu
      kernel: relu_fallback

    Merged:
    - op: add.Tensor
      kernel: add_impl
    - op: relu
      kernel: relu_fallback

    Args:
        argv: command line arguments, without the program name.
    """
    parser = argparse.ArgumentParser(
        description="Merge functions.yaml, custom_ops.yaml with fallback yaml, for codegen to consume."
    )
    parser.add_argument(
        "--functions-yaml-path",
        "--functions_yaml_path",
        help="path to the functions.yaml file to use.",
        required=True,
    )
    parser.add_argument(
        "--fallback-yaml-path",
        "--fallback_yaml_path",
        help="path to fallback yaml file.",
        required=False,
    )
    parser.add_argument(
        "--output_dir",
        help=("The directory to store the output yaml file"),
        required=True,
    )

    options = parser.parse_args(argv)
    assert options.functions_yaml_path is not None and os.path.exists(
        options.functions_yaml_path
    ), f"functions yaml file not found: {options.functions_yaml_path}"
    merge(options.functions_yaml_path, options.fallback_yaml_path, options.output_dir)


if __name__ == "__main__":
    main(sys.argv[1:])