xref: /aosp_15_r20/external/pytorch/test/mobile/model_test/update_production_ops.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1"""
This is a script to aggregate production ops from xplat/pytorch_models/build/all_mobile_model_configs.yaml.
Pass the path to that YAML file as the first argument. The results are dumped to model_ops.yaml.
4"""
5
import os
import sys

import yaml
9
10
# Only test these built-in ops. No custom ops or non-CPU ops.
NAMESPACES = frozenset({"aten", "prepacked", "prim", "quantized"})

# The output file lives next to this script (test/mobile/model_test/). It is
# resolved from __file__ rather than the current working directory so the
# script works no matter where it is launched from.
OUT_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "model_ops.yaml"
)


def aggregate_model_infos(model_infos):
    """Aggregate per-model op information across all models.

    Args:
        model_infos: the parsed input YAML; an iterable of per-model mappings,
            each holding "root_operators", "traced_operators" (iterables of op
            names) and "kernel_metadata" (a mapping of kernel name -> list of
            dtypes).

    Returns:
        A ``(root_operators, traced_operators, kernel_metadata)`` tuple. The
        first two map op name -> number of times the op was seen across all
        models. The last maps kernel name -> de-duplicated list of dtypes.

    Raises:
        KeyError: if a model entry lacks one of the three keys above.
    """
    root_operators = {}
    traced_operators = {}
    kernel_metadata = {}
    for info in model_infos:
        # Aggregate occurrences per op.
        for op in info["root_operators"]:
            root_operators[op] = root_operators.get(op, 0) + 1
        for op in info["traced_operators"]:
            traced_operators[op] = traced_operators.get(op, 0) + 1
        # Merge dtypes for each kernel. dict.fromkeys de-duplicates while
        # keeping first-seen order, so the result is deterministic. The
        # iteration order of a set of strings may change between runs.
        for kernel, dtypes in info["kernel_metadata"].items():
            merged = [*kernel_metadata.get(kernel, []), *dtypes]
            kernel_metadata[kernel] = list(dict.fromkeys(merged))
    return root_operators, traced_operators, kernel_metadata


def filter_builtin_ops(op_counts, namespaces=NAMESPACES):
    """Return the entries of *op_counts* whose namespace is in *namespaces*.

    The namespace is the part of the op name before the first "::" (e.g.
    "aten" for "aten::add.Tensor"). A name without "::" is treated as its own
    namespace.
    """
    return {
        op: count
        for op, count in op_counts.items()
        if op.partition("::")[0] in namespaces
    }


def main(argv=None):
    """Aggregate the model-config YAML named on the command line into OUT_PATH.

    Args:
        argv: argument list without the program name; defaults to
            ``sys.argv[1:]``. Only the first argument (the input YAML path)
            is used.

    Raises:
        SystemExit: with a usage message when no input path is given.
    """
    args = sys.argv[1:] if argv is None else argv
    if not args:
        sys.exit(
            f"usage: {os.path.basename(__file__)} "
            "<path/to/all_mobile_model_configs.yaml>"
        )

    with open(args[0], encoding="utf-8") as input_yaml_file:
        # safe_load: the input is plain data and must not construct objects.
        model_infos = yaml.safe_load(input_yaml_file)

    root_operators, _traced_operators, _kernel_metadata = aggregate_model_infos(
        model_infos
    )
    # Only root_operators is written to the output file. The traced-op and
    # kernel aggregates are computed but not currently emitted.
    root_operators = filter_builtin_ops(root_operators)

    with open(OUT_PATH, "w", encoding="utf-8") as f:
        yaml.safe_dump({"root_operators": root_operators}, f)


if __name__ == "__main__":
    main()
46