from __future__ import annotations

from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass
from typing import TYPE_CHECKING

import yaml

from torchgen.selective_build.operator import (
    merge_debug_info,
    merge_operator_dicts,
    SelectiveBuildOperator,
    strip_operator_overload_name,
)


if TYPE_CHECKING:
    from torchgen.model import NativeFunction


# A SelectiveBuilder holds information extracted from the selective build
# YAML specification.
#
# It includes information about the build's selectivity, the debug_info
# associated with this selective build (opaque string), and the set of
# operators that should be included in the build.
#
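# An illustrative (hypothetical) YAML of the kind parsed by from_yaml_dict /
# from_yaml_str is sketched below. All names and values are placeholders, not
# a real trace; the per-operator fields mirror what
# from_legacy_op_registration_allow_list emits:
#
#   include_all_operators: False
#   include_all_non_op_selectives: False
#   debug_info:
#     - model1@v100
#   operators:
#     aten::add.Tensor:
#       name: aten::add.Tensor
#       is_root_operator: True
#       is_used_for_training: False
#       include_all_overloads: False
#   kernel_metadata:
#     add_kernel:
#       - Float
#       - Int
#   custom_classes:
#     - MyCustomClass
#   build_features:
#     - my_build_feature
#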
@dataclass(frozen=True)
class SelectiveBuilder:
    # If true, then the build is not selective, and includes all
    # operators.
    include_all_operators: bool

    # Debug Information at the selective/custom build level.
    _debug_info: tuple[str, ...] | None

    # A dictionary of operator -> operator metadata.
    operators: dict[str, SelectiveBuildOperator]

    # A dictionary of selected kernel tags and dtypes. Typically a
    # PyTorch Operator Kernel (function) may have many code paths
    # that are specialized for many Tensor dtypes, so it's not
    # one per kernel function, but there could be many per kernel
    # function. The tag isn't a kernel function name, but some fragment
    # of the kernel function implementation itself.
    kernel_metadata: dict[str, list[str]]

    # ExecuTorch only. A dictionary of operator name -> list of kernel keys
    # (each key encodes the input dtypes for tensor-like input args).
    # This comes from selective.yaml.
    et_kernel_metadata: dict[str, list[str]]

    # A set of all the custom torch bind classes used by the selected models.
    # Stored as a set internally to remove duplicates proactively, but written
    # as a list in YAML output.
    custom_classes: set[str]

    # A set of all the build features used by the selected models.
    # Stored as a set internally to remove duplicates proactively, but written
    # as a list in YAML output.
    build_features: set[str]

    # If true, then fragments for all dtypes for all kernel functions
    # are included, as well as all custom classes. This is typically set when
    # any one of the operator lists is generated from a mechanism other than
    # tracing-based selective build.
    include_all_non_op_selectives: bool

    @staticmethod
    def get_nop_selector() -> SelectiveBuilder:
        return SelectiveBuilder.from_yaml_dict({"include_all_operators": True})

    @staticmethod
    def from_yaml_dict(data: dict[str, object]) -> SelectiveBuilder:
        valid_top_level_keys = {
            "include_all_non_op_selectives",
            "include_all_operators",
            "debug_info",
            "operators",
            "kernel_metadata",
            "et_kernel_metadata",
            "custom_classes",
            "build_features",
        }
        top_level_keys = set(data.keys())
        if len(top_level_keys - valid_top_level_keys) > 0:
            raise Exception(  # noqa: TRY002
                "Got unexpected top level keys: {}".format(
                    ",".join(top_level_keys - valid_top_level_keys),
                )
            )
        include_all_operators = data.get("include_all_operators", False)
        assert isinstance(include_all_operators, bool)

        debug_info = None
        if "debug_info" in data:
            di_list = data["debug_info"]
            assert isinstance(di_list, list)

            debug_info = tuple(str(x) for x in di_list)

        operators = {}
        operators_dict = data.get("operators", {})
        assert isinstance(operators_dict, dict)

        for k, v in operators_dict.items():
            operators[k] = SelectiveBuildOperator.from_yaml_dict(k, v)

        kernel_metadata = {}
        kernel_metadata_dict = data.get("kernel_metadata", {})
        assert isinstance(kernel_metadata_dict, dict)

        for k, v in kernel_metadata_dict.items():
            kernel_metadata[str(k)] = [str(dtype) for dtype in v]

        et_kernel_metadata = data.get("et_kernel_metadata", {})
        assert isinstance(et_kernel_metadata, dict)

        custom_classes = data.get("custom_classes", [])
        assert isinstance(custom_classes, Iterable)
        custom_classes = set(custom_classes)

        build_features = data.get("build_features", [])
        assert isinstance(build_features, Iterable)
        build_features = set(build_features)

        include_all_non_op_selectives = data.get("include_all_non_op_selectives", False)
        assert isinstance(include_all_non_op_selectives, bool)

        return SelectiveBuilder(
            include_all_operators,
            debug_info,
            operators,
            kernel_metadata,
            et_kernel_metadata,
            custom_classes,  # type: ignore[arg-type]
            build_features,  # type: ignore[arg-type]
            include_all_non_op_selectives,
        )

    @staticmethod
    def from_yaml_str(config_contents: str) -> SelectiveBuilder:
        contents = yaml.safe_load(config_contents)
        return SelectiveBuilder.from_yaml_dict(contents)

    @staticmethod
    def from_yaml_path(config_path: str) -> SelectiveBuilder:
        with open(config_path) as f:
            contents = yaml.safe_load(f)
            return SelectiveBuilder.from_yaml_dict(contents)

    @staticmethod
    def from_legacy_op_registration_allow_list(
        allow_list: set[str], is_root_operator: bool, is_used_for_training: bool
    ) -> SelectiveBuilder:
        operators = {}
        for op in allow_list:
            operators[op] = {
                "name": op,
                "is_root_operator": is_root_operator,
                "is_used_for_training": is_used_for_training,
                "include_all_overloads": True,
            }
        return SelectiveBuilder.from_yaml_dict(
            {
                "operators": operators,
                "include_all_non_op_selectives": True,
            }
        )
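
    # Illustrative usage (hypothetical op names): the legacy allow-list path
    # marks every listed op with include_all_overloads=True, so overload
    # queries resolve against the base operator name:
    #
    #   selector = SelectiveBuilder.from_legacy_op_registration_allow_list(
    #       {"aten::add", "aten::mul"},
    #       is_root_operator=True,
    #       is_used_for_training=False,
    #   )
    #   assert selector.is_operator_selected("aten::add.Tensor")
    #   assert selector.is_root_operator("aten::mul")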

    def is_operator_selected(self, name: str) -> bool:
        if self.include_all_operators:
            return True

        if name in self.operators:
            return True
        name = strip_operator_overload_name(name)
        return name in self.operators and self.operators[name].include_all_overloads

    def is_native_function_selected(self, func: NativeFunction) -> bool:
        op_name = op_name_from_native_function(func)
        return self.is_operator_selected(op_name)

    def is_operator_selected_for_training(self, name: str) -> bool:
        if not self.is_operator_selected(name):
            return False
        if self.include_all_operators:
            return True

        not_training_op = SelectiveBuildOperator(
            name="",
            is_root_operator=False,
            is_used_for_training=False,
            include_all_overloads=False,
            _debug_info=None,
        )
        op = not_training_op
        if name in self.operators:
            op = self.operators[name]

        name = strip_operator_overload_name(name)
        base_op = not_training_op
        if name in self.operators:
            base_op = self.operators[name]

        return op.is_used_for_training or (
            base_op.include_all_overloads and base_op.is_used_for_training
        )

    def is_native_function_selected_for_training(self, func: NativeFunction) -> bool:
        op_name = op_name_from_native_function(func)
        return self.is_operator_selected_for_training(op_name)

    def is_root_operator(self, name: str) -> bool:
        if not self.is_operator_selected(name):
            return False
        if self.include_all_operators:
            return True

        if name in self.operators:
            op: SelectiveBuildOperator = self.operators[name]
            return op.is_root_operator
        name = strip_operator_overload_name(name)
        if name not in self.operators:
            return False
        base_op: SelectiveBuildOperator = self.operators[name]
        return base_op.include_all_overloads and base_op.is_root_operator

    def is_kernel_dtype_selected(self, kernel_tag: str, dtype: str) -> bool:
        if self.include_all_operators or self.include_all_non_op_selectives:
            return True

        return (
            kernel_tag in self.kernel_metadata
            and dtype in self.kernel_metadata[kernel_tag]
        )

    def et_get_selected_kernels(self, op_name: str, kernel_key: list[str]) -> list[str]:
        """
        Return a list of kernel keys that cover the used ops
        """
        # If no kernel metadata, either it's implied by include_all_operators=True or the op is not used.
        if op_name not in self.et_kernel_metadata:
            return kernel_key if self.include_all_operators else []
        # Otherwise, only return the specific kernel keys.

        result_set = set()

        for model_kernel_keys in self.et_kernel_metadata[op_name]:
            key_found = False
            for key in kernel_key:
                # Don't compare the version for now
                if (
                    key != "default"
                    and key.split("/")[1] == model_kernel_keys.split("/")[1]
                ):
                    result_set.add(key)
                    key_found = True
                    break
            if not key_found:
                if "default" not in kernel_key:
                    raise Exception("Missing kernel for the model")  # noqa: TRY002
                else:
                    result_set.add("default")

        return list(result_set)
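
    # Illustrative example (hypothetical kernel keys): only the portion after
    # the first "/" is compared, so the version prefix is ignored for now.
    # With et_kernel_metadata == {"aten::add.out": ["v1/spec_a"]}:
    #
    #   et_get_selected_kernels("aten::add.out", ["v1/spec_a", "default"])
    #       -> ["v1/spec_a"]
    #   et_get_selected_kernels("aten::add.out", ["v1/spec_b", "default"])
    #       -> ["default"]   # no specialized match, fall back to "default"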

    def to_dict(self) -> dict[str, object]:
        ret: dict[str, object] = {
            "include_all_non_op_selectives": self.include_all_non_op_selectives,
            "include_all_operators": self.include_all_operators,
        }
        operators = {}
        for op_name, op in self.operators.items():
            operators[op_name] = op.to_dict()
        ret["operators"] = operators

        if self._debug_info is not None:
            ret["debug_info"] = sorted(self._debug_info)

        ret["kernel_metadata"] = {
            k: sorted(v) for (k, v) in self.kernel_metadata.items()
        }

        ret["et_kernel_metadata"] = self.et_kernel_metadata

        ret["custom_classes"] = sorted(self.custom_classes)

        ret["build_features"] = sorted(self.build_features)

        return ret


def merge_kernel_metadata(
    lhs: dict[str, list[str]],
    rhs: dict[str, list[str]],
) -> dict[str, list[str]]:
    kernel_metadata: dict[str, list[str]] = {}
    for tag_name, dtypes in list(lhs.items()) + list(rhs.items()):
        dtypes_copy = set(dtypes)
        if tag_name in kernel_metadata:
            dtypes_copy |= set(kernel_metadata[tag_name])

        kernel_metadata[tag_name] = list(dtypes_copy)

    return kernel_metadata


def merge_et_kernel_metadata(
    lhs: dict[str, list[str]],
    rhs: dict[str, list[str]],
) -> dict[str, list[str]]:
    merged: dict[str, set[str]] = defaultdict(set)
    for op in list(lhs.keys()) + list(rhs.keys()):
        merged[op].update(lhs.get(op, []))
        merged[op].update(rhs.get(op, []))

    return {op: sorted(val) for op, val in merged.items()}


def combine_selective_builders(
    lhs: SelectiveBuilder, rhs: SelectiveBuilder
) -> SelectiveBuilder:
    include_all_operators = lhs.include_all_operators or rhs.include_all_operators
    debug_info = merge_debug_info(lhs._debug_info, rhs._debug_info)
    operators = merge_operator_dicts(lhs.operators, rhs.operators)
    kernel_metadata = merge_kernel_metadata(lhs.kernel_metadata, rhs.kernel_metadata)
    et_kernel_metadata = merge_et_kernel_metadata(
        lhs.et_kernel_metadata, rhs.et_kernel_metadata
    )
    include_all_non_op_selectives = (
        lhs.include_all_non_op_selectives or rhs.include_all_non_op_selectives
    )
    custom_classes = lhs.custom_classes.union(rhs.custom_classes)
    build_features = lhs.build_features.union(rhs.build_features)
    return SelectiveBuilder(
        include_all_operators,
        debug_info,
        operators,
        kernel_metadata,
        et_kernel_metadata,
        custom_classes,
        build_features,
        include_all_non_op_selectives,
    )


def op_name_from_native_function(f: NativeFunction) -> str:
    # This was originally read from the 'operator_name_with_overload' field in the
    # declaration dict, which was the part before the first '(' in 'schema_string'.
    return f"{f.namespace}::{f.func.name}"
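

# Illustrative example (hypothetical op names): combining two selectors unions
# their operators, kernel metadata, custom classes, and build features, and
# ORs the include_all_* flags:
#
#   lhs = SelectiveBuilder.from_legacy_op_registration_allow_list(
#       {"aten::add"}, is_root_operator=True, is_used_for_training=False
#   )
#   rhs = SelectiveBuilder.from_legacy_op_registration_allow_list(
#       {"aten::mul"}, is_root_operator=False, is_used_for_training=False
#   )
#   combined = combine_selective_builders(lhs, rhs)
#   assert combined.is_operator_selected("aten::add")
#   assert combined.is_operator_selected("aten::mul")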