# xref: /aosp_15_r20/external/pytorch/functorch/op_analysis/gen_data.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
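#
# Generates op-coverage CSVs for functorch analysis: it reads
# ../../aten/src/ATen/native/native_functions.yaml and a local "annotated_ops"
# file, plus the run_ops.txt, count_ops.txt, run_decompositions.txt, and
# public_api dumps, and writes vmap.txt and decompositions.txt.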
import csv
from collections import defaultdict

import yaml

import torch


def get_ops_for_key(key):
    # Needs modified PyTorch C++ code to work
    if key is None:
        ops = torch._C._dispatch_get_registrations_for_dispatch_key()
    else:
        ops = torch._C._dispatch_get_registrations_for_dispatch_key(key)
    cleaned_ops = []
    for i in ops:
        if "aten::" not in i:
            continue
        cleaned_ops.append(i[6:].strip())
    return set(cleaned_ops)
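
# Example of what get_ops_for_key returns (illustrative, not from a real run):
# the raw registrations look like "aten::add.Tensor"; the "aten::" prefix is
# stripped above, so the returned set contains entries such as "add.Tensor".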


def gen_data(special_op_lists, analysis_name):
    all_ops = get_ops_for_key(None)
    composite_ops = get_ops_for_key("CompositeImplicitAutograd")
    noncomposite_ops = all_ops - composite_ops

    with open("../../aten/src/ATen/native/native_functions.yaml") as f:
        ops = yaml.load(f.read(), Loader=yaml.CLoader)

    with open("annotated_ops") as f:
        annotated_ops = {a.strip(): b.strip() for a, b in csv.reader(f)}

    uniq_ops = []
    uniq_names = set()
    overload_types = defaultdict(list)
    for op in ops:
        func_str = op["func"]
        # The schema is "name.overload(args) -> ret", e.g.
        # "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
        # yields full_name "add.Tensor", name "add", and ret_type "Tensor".
        full_name = func_str[: func_str.index("(")]
        if "." in full_name:
            uniq_name = full_name[: full_name.index(".")]
            overload_types[full_name[full_name.index(".") + 1 :]].append(full_name)
        else:
            uniq_name = full_name
        op["name"] = uniq_name
        op["full_name"] = full_name
        op["ret_type"] = func_str[func_str.index("->") + 3 :]
        if uniq_name in uniq_names:
            continue
        uniq_names.add(uniq_name)
        uniq_ops.append(op)

    def annotate_ops(ops, is_unique):
        # Buckets each op into a single category; the first matching rule wins.
        categorization = defaultdict(int)
        for op in ops:
            if op["name"][-1] == "_":
                categorization["inplace"] += 1
                op["meta"] = "inplace"
                continue
            if not is_unique and "a!" in op["func"].lower():
                categorization["out"] += 1
                op["meta"] = "out"
                continue
            if "conv" in op["name"]:
                categorization["conv"] += 1
                op["meta"] = "conv"
                continue
            if "pool" in op["name"]:
                categorization["pool"] += 1
                op["meta"] = "pool"
                continue
            if "backward" in op["name"]:
                categorization["backward"] += 1
                op["meta"] = "backward"
                continue
            if op["name"][0] == "_" and op["name"][1] != "_":
                categorization["private"] += 1
                op["meta"] = "private"
                continue
            if "batch_norm" in op["name"]:
                categorization["batch_norm"] += 1
                op["meta"] = "batch_norm"
                continue
            if "Tensor" not in op["func"] or "Tensor" not in op["ret_type"]:
                categorization["non_tensor"] += 1
                op["meta"] = "non_tensor"
                continue
            if (
                "cudnn" in op["name"]
                or "mkldnn" in op["name"]
                or "miopen" in op["name"]
                or "native" in op["name"]
                or "thnn" in op["name"]
                or "slow" in op["name"]
            ):
                categorization["backend"] += 1
                op["meta"] = "backend"
                continue
            if op["name"] in annotated_ops:
                categorization["core"] += 1
                op["meta"] = "core " + annotated_ops[op["name"]]
                continue
            categorization["core"] += 1
            op["meta"] = "core unknown"
        return categorization

    annotate_ops(ops, is_unique=False)
    with open(analysis_name, "w") as f:
        for op in ops:
            # Row: full name, category, "is composite", then one column per check.
            info = [
                op["full_name"],
                op["meta"],
                op["full_name"] not in noncomposite_ops,
            ] + [check(op) for check in special_op_lists]
            f.write(",".join(str(i) for i in info) + "\n")


def name_check(lst):
    return lambda x: x["name"] in lst


def full_name_check(lst):
    return lambda x: x["full_name"] in lst


# Generates batching rule data
gen_data([full_name_check(get_ops_for_key("FuncTorchBatched"))], "vmap.txt")
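# Each line of vmap.txt then has the form
# "<full_name>,<category>,<is_composite>,<has_batching_rule>", e.g. a row like
# "add.Tensor,core unknown,True,True" (this example row is illustrative only).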


def remove_suffix(input_string, suffix):
    if suffix and input_string.endswith(suffix):
        return input_string[: -len(suffix)]
    return input_string


def remove_prefix(input_string, prefix):
    if prefix and input_string.startswith(prefix):
        return input_string[len(prefix) :]
    return input_string
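
# Note: on Python 3.9+ these helpers match the built-in str.removeprefix /
# str.removesuffix (e.g. "aten::add".removeprefix("aten::") == "add"); keeping
# local copies lets the script run on older interpreters.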


# Generates decomposition coverage data; requires the run_ops.txt,
# count_ops.txt, run_decompositions.txt, and public_api dumps to exist.
if True:
    with open("run_ops.txt") as f:
        opinfo_ops = [remove_suffix(i.strip(), ".default") for i in f]
    with open("count_ops.txt") as f:
        opinfo_counts = [i.strip() for i in f]
        opinfo_counts = defaultdict(int, dict(zip(opinfo_ops, opinfo_counts)))

    def count_fn(x):
        return opinfo_counts[x["full_name"]]

    with open("run_decompositions.txt") as f:
        decomposed_ops = [remove_suffix(i.strip(), ".default") for i in f]

    with open("public_api") as f:
        ref_api = [i.strip() for i in f]

    def has_ref_impl(x):
        # Checks whether the op appears in the public_api dump, either bare or
        # under one of the namespaced prefixes below.
        name = x["name"]
        for prefix in ["linalg_", "special_"]:
            name = remove_prefix(name, prefix)
        prefixes = ["nn.functional", "fft", "special", "linalg"]
        return (
            any(f"{prefix}.{name}" in ref_api for prefix in prefixes) or name in ref_api
        )
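    # Illustrative example: for an op named "linalg_svd" the check above strips
    # the "linalg_" prefix and then looks for "linalg.svd", "nn.functional.svd",
    # "fft.svd", "special.svd", or plain "svd" in the public_api dump.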

    gen_data(
        [
            full_name_check(opinfo_ops),
            full_name_check(decomposed_ops),
            count_fn,
            has_ref_impl,
        ],
        "decompositions.txt",
    )