import csv
from collections import defaultdict

import yaml

import torch


def get_ops_for_key(key):
    """Return the set of "aten" operator names registered for a dispatch key.

    Needs modified PyTorch C++ code to work: relies on the private
    ``torch._C._dispatch_get_registrations_for_dispatch_key`` binding.

    Args:
        key: dispatch key name (e.g. "CompositeImplicitAutograd"), or
            None to fetch every registration.

    Returns:
        Set of operator names with the leading "aten::" stripped.
    """
    if key is None:
        ops = torch._C._dispatch_get_registrations_for_dispatch_key()
    else:
        ops = torch._C._dispatch_get_registrations_for_dispatch_key(key)
    # Keep only aten ops and drop the 6-character "aten::" prefix.
    return {op[6:].strip() for op in ops if "aten::" in op}


def gen_data(special_op_lists, analysis_name):
    """Write a comma-separated per-op analysis to ``analysis_name``.

    Each row is: full_name, meta category, a flag for whether the op is
    CompositeImplicitAutograd, followed by one column per predicate in
    ``special_op_lists`` (each predicate is called with the op dict).

    Args:
        special_op_lists: iterable of callables taking an op dict.
        analysis_name: path of the output file.
    """
    all_ops = get_ops_for_key(None)
    composite_ops = get_ops_for_key("CompositeImplicitAutograd")
    noncomposite_ops = all_ops - composite_ops

    # Use context managers so both input handles are closed deterministically
    # (the original leaked them).
    with open("../../aten/src/ATen/native/native_functions.yaml") as f:
        ops = yaml.load(f.read(), Loader=yaml.CLoader)

    # "annotated_ops" is a two-column CSV: op name -> human annotation.
    with open("annotated_ops") as f:
        annotated_ops = {a.strip(): b.strip() for a, b in csv.reader(f)}

    uniq_ops = []
    uniq_names = set()
    overload_types = defaultdict(list)
    cnt = 0
    for op in ops:
        func_str = op["func"]
        # Schema looks like "name.overload(args...) -> ret": the text before
        # "(" is the full name; the text before "." is the base name.
        full_name = func_str[: func_str.index("(")]
        if "." in full_name:
            uniq_name = full_name[: full_name.index(".")]
            overload_types[full_name[full_name.index(".") + 1 :]].append(full_name)
        else:
            uniq_name = full_name
        op["name"] = uniq_name
        op["full_name"] = full_name
        op["ret_type"] = func_str[func_str.index("->") + 3 :]
        cnt += 1
        if uniq_name in uniq_names:
            continue
        uniq_names.add(uniq_name)
        uniq_ops.append(op)

    def annotate_ops(ops, is_unique):
        # Bucket each op into exactly one category (first match wins), store
        # it on op["meta"], and return the per-category counts.
        categorization = defaultdict(int)
        for op in ops:
            if op["name"][-1] == "_":
                categorization["inplace"] += 1
                op["meta"] = "inplace"
                continue
            # "(a!)" schema annotations mark mutated (out=) arguments.
            if not is_unique and "a!" in op["func"].lower():
                categorization["out"] += 1
                op["meta"] = "out"
                continue
            if "conv" in op["name"]:
                categorization["conv"] += 1
                op["meta"] = "conv"
                continue
            if "pool" in op["name"]:
                categorization["pool"] += 1
                op["meta"] = "pool"
                continue
            if "backward" in op["name"]:
                categorization["backward"] += 1
                op["meta"] = "backward"
                continue
            # Single leading underscore (but not dunder-style) => private op.
            if op["name"][0] == "_" and op["name"][1] != "_":
                categorization["private"] += 1
                op["meta"] = "private"
                continue
            if "batch_norm" in op["name"]:
                categorization["batch_norm"] += 1
                op["meta"] = "batch_norm"
                continue
            if "Tensor" not in op["func"] or "Tensor" not in op["ret_type"]:
                categorization["non_tensor"] += 1
                op["meta"] = "non_tensor"
                continue
            if (
                "cudnn" in op["name"]
                or "mkldnn" in op["name"]
                or "miopen" in op["name"]
                or "native" in op["name"]
                or "thnn" in op["name"]
                or "slow" in op["name"]
            ):
                categorization["backend"] += 1
                op["meta"] = "backend"
                continue
            if op["name"] in annotated_ops:
                categorization["core"] += 1
                op["meta"] = "core " + annotated_ops[op["name"]]
                continue
            categorization["core"] += 1
            op["meta"] = "core unknown"
        return categorization

    annotate_ops(ops, is_unique=False)
    with open(analysis_name, "w") as f:
        for op in ops:
            info = [
                op["full_name"],
                op["meta"],
                op["full_name"] not in noncomposite_ops,
            ] + [check(op) for check in special_op_lists]
            f.write(",".join(str(i) for i in info) + "\n")


def name_check(lst):
    """Return a predicate matching ops whose base name is in ``lst``."""
    return lambda x: x["name"] in lst


def full_name_check(lst):
    """Return a predicate matching ops whose full (overload) name is in ``lst``."""
    return lambda x: x["full_name"] in lst


# Generates batching rule data
gen_data([full_name_check(get_ops_for_key("FuncTorchBatched"))], "vmap.txt")


def remove_suffix(input_string, suffix):
    """Backport of ``str.removesuffix`` (Python 3.9+)."""
    if suffix and input_string.endswith(suffix):
        return input_string[: -len(suffix)]
    return input_string
def remove_prefix(input_string, prefix):
    """Backport of ``str.removeprefix`` (Python 3.9+)."""
    if prefix and input_string.startswith(prefix):
        return input_string[len(prefix) :]
    return input_string


# --- Decomposition analysis: reads artifacts of previous test runs. ---
# (The dead `if True:` guard wrapping this script tail has been removed.)

# Ops exercised by OpInfo tests; ".default" overload suffixes are stripped so
# the names line up with native_functions.yaml full names.
with open("run_ops.txt") as f:
    opinfo_ops = [remove_suffix(i.strip(), ".default") for i in f]
# Parallel file of per-op invocation counts. NOTE: the stored counts remain
# strings while the defaultdict fallback is int 0 — both render identically
# through str() in gen_data's output, so this is preserved as-is.
with open("count_ops.txt") as f:
    opinfo_counts = [i.strip() for i in f]
opinfo_counts = defaultdict(int, zip(opinfo_ops, opinfo_counts))


def count_fn(x):
    """OpInfo invocation count for this op (0 if never run)."""
    return opinfo_counts[x["full_name"]]


with open("run_decompositions.txt") as f:
    decomposed_ops = [remove_suffix(i.strip(), ".default") for i in f]

# Set instead of list: has_ref_impl performs several membership tests per op.
with open("public_api") as f:
    ref_api = {i.strip() for i in f}


def has_ref_impl(x):
    """Whether the op appears in the public reference API list.

    Strips "linalg_"/"special_" namespace prefixes from the aten base name,
    then looks for either a namespaced or a top-level match in ``public_api``.
    """
    name = x["name"]
    for prefix in ("linalg_", "special_"):
        name = remove_prefix(name, prefix)
    prefixes = ("nn.functional", "fft", "special", "linalg")
    return (
        any(f"{prefix}.{name}" in ref_api for prefix in prefixes) or name in ref_api
    )


gen_data(
    [
        full_name_check(opinfo_ops),
        full_name_check(decomposed_ops),
        count_fn,
        has_ref_impl,
    ],
    "decompositions.txt",
)