"""
Aggregate production ops from xplat/pytorch_models/build/all_mobile_model_configs.yaml.

Usage: pass the path of that YAML file as the first command-line argument.

The input is expected to be a YAML sequence of model entries, each providing
the keys "root_operators", "traced_operators" and "kernel_metadata".

The script counts, per operator, how many model entries list it, keeps only the
operators in the built-in namespaces, and writes the root operator counts to
test/mobile/model_test/model_ops.yaml. That output path is relative to the
current working directory.
"""

import sys
from collections import Counter

import yaml


if len(sys.argv) < 2:
    # Fail with a usage message rather than a bare IndexError.
    sys.exit(
        "usage: {} <path to all_mobile_model_configs.yaml>".format(sys.argv[0])
    )

# Operator name -> number of occurrences across all model entries.
root_operators = Counter()
traced_operators = Counter()
# Kernel name -> de-duplicated list of dtypes seen across all model entries.
kernel_metadata = {}

with open(sys.argv[1], encoding="utf-8") as input_yaml_file:
    model_infos = yaml.safe_load(input_yaml_file)

if model_infos is None:
    # yaml.safe_load returns None for an empty document; stop before the
    # output file is touched instead of failing with an opaque TypeError.
    sys.exit("{}: empty YAML document, nothing to aggregate".format(sys.argv[1]))

for info in model_infos:
    # Aggregate occurrences per op.
    for op in info["root_operators"]:
        root_operators[op] += 1
    for op in info["traced_operators"]:
        traced_operators[op] += 1
    # Merge dtypes for each kernel, dropping duplicates.
    for kernel, dtypes in info["kernel_metadata"].items():
        merged_dtypes = set(dtypes)
        merged_dtypes.update(kernel_metadata.get(kernel, []))
        kernel_metadata[kernel] = list(merged_dtypes)


# Only test these built-in ops. No custom ops or non-CPU ops.
namespaces = ["aten", "prepacked", "prim", "quantized"]
# The comprehensions below also turn the Counters back into plain dicts,
# which is what yaml.safe_dump knows how to serialize.
root_operators = {
    op: count
    for op, count in root_operators.items()
    if op.split("::")[0] in namespaces
}
traced_operators = {
    op: count
    for op, count in traced_operators.items()
    if op.split("::")[0] in namespaces
}

# NOTE: traced_operators and kernel_metadata are aggregated above, but only
# root_operators is written to the output file.
out_path = "test/mobile/model_test/model_ops.yaml"
with open(out_path, "w", encoding="utf-8") as f:
    yaml.safe_dump({"root_operators": root_operators}, f)