"""
This util is invoked from cmake to produce the op registration allowlist param
for `ATen/gen.py` for the custom mobile build.

For a custom build with dynamic dispatch, it takes the op dependency graph of
ATen and the list of root ops, and outputs all transitive dependencies of the
root ops as the allowlist.

For a custom build with static dispatch, the op dependency graph is omitted
and it directly outputs the root ops as the allowlist.
"""

from __future__ import annotations

import argparse
from collections import defaultdict
from typing import Dict, Set

import yaml


DepGraph = Dict[str, Set[str]]


def canonical_name(opname: str) -> str:
    """Return ``opname`` with everything from the first "." onward removed."""
    # Skip the overload name part as it's not supported by code analyzer yet.
    return opname.split(".", 1)[0]


def _load_yaml_list(fname: str) -> list:
    """Parse the YAML file ``fname``, returning ``[]`` for an empty document."""
    with open(fname, encoding="utf-8") as stream:
        # yaml.safe_load returns None for an empty file; normalize it so
        # callers can iterate unconditionally.
        return yaml.safe_load(stream) or []


def load_op_dep_graph(fname: str) -> DepGraph:
    """Load the op dependency graph from the YAML file ``fname``.

    The file is expected to hold a list of entries, each with a ``name`` key
    and an optional ``depends`` list whose items also carry a ``name`` key.
    Every name is passed through :func:`canonical_name`.

    Returns:
        A dict mapping each op name to the set of op names it depends on.
        Ops that list no dependencies do not appear as keys.
    """
    result: DepGraph = defaultdict(set)
    for op in _load_yaml_list(fname):
        op_name = canonical_name(op["name"])
        # A `depends:` key with no items parses as None, so fall back to [].
        for dep in op.get("depends") or []:
            result[op_name].add(canonical_name(dep["name"]))
    return dict(result)


def load_root_ops(fname: str) -> list[str]:
    """Load root (directly used) op names from the YAML file ``fname``.

    The file is expected to hold a list of op name strings. Each is passed
    through :func:`canonical_name`. An empty file yields ``[]``.
    """
    return [canonical_name(op) for op in _load_yaml_list(fname)]


def gen_transitive_closure(
    dep_graph: DepGraph,
    root_ops: list[str],
    train: bool = False,
) -> list[str]:
    """Return ``root_ops`` plus every op reachable from them, sorted.

    Args:
        dep_graph: Mapping from op name to the op names it depends on.
        root_ops: Directly used ops. They are always included in the result.
        train: When true, also expand the special ``__ROOT__`` entry.

    Returns:
        Sorted list of op names.
    """
    # Materialize once so tuples and other iterables work as well as lists.
    roots = list(root_ops)
    result = set(roots)
    pending = list(roots)

    # The dependency graph might contain a special entry with key = `__BASE__`
    # and value = (set of `base` ops to always include in custom build).
    # The sentinel key is only pushed onto the work list, not added as a root.
    pending.append("__BASE__")

    # The dependency graph might contain a special entry with key = `__ROOT__`
    # and value = (set of ops reachable from C++ functions). Insert the special
    # `__ROOT__` key to include ops which can be called from C++ code directly,
    # in addition to ops that are called from TorchScript model.
    # '__ROOT__' is only needed for full-jit. Keep it only for training.
    # TODO: when FL is migrated from full-jit to lite trainer, remove '__ROOT__'
    if train:
        pending.append("__ROOT__")

    # Depth-first traversal; `result` doubles as the visited set.
    while pending:
        cur = pending.pop()
        for dep in dep_graph.get(cur, []):
            if dep not in result:
                result.add(dep)
                pending.append(dep)

    return sorted(result)


def gen_transitive_closure_str(
    dep_graph: DepGraph,
    root_ops: list[str],
    train: bool = False,
) -> str:
    """Return the result of :func:`gen_transitive_closure` joined by spaces."""
    return " ".join(gen_transitive_closure(dep_graph, root_ops, train=train))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Util to produce transitive dependencies for custom build"
    )
    parser.add_argument(
        "--op-dependency",
        help="input yaml file of op dependency graph "
        "- can be omitted for custom build with static dispatch",
    )
    parser.add_argument(
        "--root-ops",
        required=True,
        help="input yaml file of root (directly used) operators",
    )
    args = parser.parse_args()

    # Without a dependency graph (static dispatch) the closure is just the roots
    # plus nothing else, since every lookup in the empty graph misses.
    deps = load_op_dep_graph(args.op_dependency) if args.op_dependency else {}
    root_ops = load_root_ops(args.root_ops)
    print(gen_transitive_closure_str(deps, root_ops))