xref: /aosp_15_r20/external/pytorch/tools/code_analyzer/gen_op_registration_allowlist.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1"""
2This util is invoked from cmake to produce the op registration allowlist param
3for `ATen/gen.py` for custom mobile build.
4For custom build with dynamic dispatch, it takes the op dependency graph of ATen
5and the list of root ops, and outputs all transitive dependencies of the root
6ops as the allowlist.
7For custom build with static dispatch, the op dependency graph will be omitted,
8and it will directly output root ops as the allowlist.
9"""
10
11from __future__ import annotations
12
13import argparse
14from collections import defaultdict
15from typing import Dict, Set
16
17import yaml
18
19
# Op dependency graph: maps an op name to the set of op names it depends on.
DepGraph = Dict[str, Set[str]]
21
22
def canonical_name(opname: str) -> str:
    """Return ``opname`` with any overload suffix removed.

    Everything from the first ``.`` onward is dropped because the code
    analyzer does not support overload names yet. A name that contains no
    ``.`` is returned unchanged.
    """
    base, _dot, _overload = opname.partition(".")
    return base
26
27
def load_op_dep_graph(fname: str) -> DepGraph:
    """Load the op dependency graph from the YAML file ``fname``.

    The document is read as a sequence of mappings. Each mapping has a
    ``name`` and an optional ``depends`` sequence whose items are mappings
    with their own ``name``. All names are passed through ``canonical_name``,
    so different overloads of one op collapse into a single node.

    Args:
        fname: Path of the YAML dependency-graph file.

    Returns:
        A plain dict mapping each op that has at least one dependency to the
        set of ops it depends on. Ops without dependencies get no entry. An
        empty document yields an empty graph.
    """
    with open(fname, encoding="utf-8") as stream:
        loaded = yaml.safe_load(stream)

    # `yaml.safe_load` returns None for an empty document; treat that as a
    # graph without entries instead of failing on iteration.
    entries = [] if loaded is None else loaded

    graph = defaultdict(set)
    for op in entries:
        op_name = canonical_name(op["name"])
        # `depends` may be missing or explicitly null; both mean "no deps".
        for dep in op.get("depends") or []:
            graph[op_name].add(canonical_name(dep["name"]))

    # Hand back a regular dict so lookups of unknown ops never insert keys.
    return dict(graph)
37
38
def load_root_ops(fname: str) -> list[str]:
    """Load the list of root operators from the YAML file ``fname``.

    The document is read as a sequence of op names. Each name is passed
    through ``canonical_name``; order and duplicates are preserved.

    Args:
        fname: Path of the YAML file listing the root operators.

    Returns:
        The canonical root op names, or an empty list when the document is
        empty.
    """
    with open(fname, encoding="utf-8") as stream:
        loaded = yaml.safe_load(stream)

    # `yaml.safe_load` returns None for an empty document; report no roots
    # rather than raising a TypeError while iterating.
    if loaded is None:
        return []
    return [canonical_name(op) for op in loaded]
45
46
def gen_transitive_closure(
    dep_graph: DepGraph,
    root_ops: list[str],
    train: bool = False,
) -> list[str]:
    """Return the sorted ops reachable from ``root_ops`` in ``dep_graph``.

    Args:
        dep_graph: Mapping from an op name to the ops it depends on.
        root_ops: Ops to start from; all of them appear in the result.
        train: When true, ops listed under the special ``__ROOT__`` key are
            traversed as well.

    Returns:
        A sorted list holding the root ops plus every op reachable from
        them and from the special keys that were traversed.
    """
    reached = set(root_ops)

    # Work on a copy so the caller's list is left untouched.
    pending = list(root_ops)

    # `__BASE__` is an optional graph key whose value is the set of `base`
    # ops that are always included in a custom build.
    pending.append("__BASE__")

    # `__ROOT__` is an optional graph key whose value is the set of ops
    # reachable from C++ functions, i.e. ops that can be called from C++
    # directly rather than from a TorchScript model. It is only needed for
    # full-jit, so it is traversed for training only.
    # TODO: when FL is migrated from full-jit to lite trainer, remove '__ROOT__'
    if train:
        pending.append("__ROOT__")

    # Worklist traversal: take a node, record the deps that are new, and
    # schedule those deps for expansion.
    while pending:
        node = pending.pop()
        unseen = set(dep_graph.get(node, ())) - reached
        reached |= unseen
        pending.extend(unseen)

    return sorted(reached)
76
77
def gen_transitive_closure_str(
    dep_graph: DepGraph,
    root_ops: list[str],
    train: bool = False,
) -> str:
    """Return the transitive closure of ``root_ops`` as one string.

    Args:
        dep_graph: Mapping from an op name to the ops it depends on.
        root_ops: Ops to start the closure from.
        train: Forwarded to ``gen_transitive_closure``. Defaults to False,
            which matches the behavior of calling this function with two
            arguments.

    Returns:
        The sorted op names of the closure joined by single spaces.
    """
    return " ".join(gen_transitive_closure(dep_graph, root_ops, train))
80
81
if __name__ == "__main__":
    # Command-line entry point: read the root ops (and, if given, the op
    # dependency graph) and print the resulting allowlist on one line.
    cli = argparse.ArgumentParser(
        description="Util to produce transitive dependencies for custom build"
    )
    cli.add_argument(
        "--op-dependency",
        help=(
            "input yaml file of op dependency graph "
            "- can be omitted for custom build with static dispatch"
        ),
    )
    cli.add_argument(
        "--root-ops",
        required=True,
        help="input yaml file of root (directly used) operators",
    )
    options = cli.parse_args()

    # Without a dependency graph the closure is computed over an empty graph.
    if options.op_dependency:
        dep_graph = load_op_dep_graph(options.op_dependency)
    else:
        dep_graph = {}
    roots = load_root_ops(options.root_ops)
    allowlist = gen_transitive_closure_str(dep_graph, roots)
    print(allowlist)
101