1from __future__ import annotations 2 3import argparse 4import os 5import re 6from collections import Counter, defaultdict, namedtuple 7from pathlib import Path 8from typing import Sequence 9 10import yaml 11 12import torchgen.api.dispatcher as dispatcher 13import torchgen.dest as dest 14from torchgen.api.types import DispatcherSignature 15from torchgen.code_template import CodeTemplate 16from torchgen.context import native_function_manager 17from torchgen.gen import get_grouped_native_functions, parse_native_yaml 18from torchgen.model import ( 19 BackendIndex, 20 BackendMetadata, 21 DispatchKey, 22 NativeFunction, 23 NativeFunctionsGroup, 24 OperatorName, 25) 26from torchgen.selective_build.selector import SelectiveBuilder 27from torchgen.utils import concatMap, context, FileManager, NamespaceHelper, Target 28from torchgen.yaml_utils import YamlLoader 29 30 31# Parses the external backend's yaml, and adds a new BackendIndex for the backend's dispatch key. 32# Returns a Tuple of (backend_key, autograd_key, cpp_namespace, updated BackendIndex mapping) 33ParsedExternalYaml = namedtuple( 34 "ParsedExternalYaml", 35 ["backend_key", "autograd_key", "class_name", "cpp_namespace", "backend_indices"], 36) 37 38 39def parse_backend_yaml( 40 backend_yaml_path: str, 41 grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup], 42 backend_indices: dict[DispatchKey, BackendIndex], 43) -> ParsedExternalYaml: 44 native_functions_map: dict[OperatorName, NativeFunction] = { 45 f.func.name: f 46 for f in concatMap( 47 lambda f: [f] if isinstance(f, NativeFunction) else list(f.functions()), 48 grouped_native_functions, 49 ) 50 } 51 52 with open(backend_yaml_path) as f: 53 yaml_values = yaml.load(f, Loader=YamlLoader) 54 assert isinstance(yaml_values, dict) 55 56 valid_keys = [ 57 "backend", 58 "class_name", 59 "cpp_namespace", 60 "extra_headers", 61 "supported", 62 "autograd", 63 "full_codegen", 64 "non_native", 65 "ir_gen", 66 "symint", 67 ] 68 69 backend = yaml_values.pop("backend", None) 70 assert backend is not None, 'You must provide a value for "backend"' 71 72 class_name = yaml_values.pop("class_name", None) 73 74 cpp_namespace = yaml_values.pop("cpp_namespace", None) 75 assert cpp_namespace is not None, 'You must provide a value for "cpp_namespace"' 76 77 # Mostly just defaulting to false to stick with LazyTensor convention. 78 use_out_as_primary = yaml_values.pop("use_out_as_primary", False) 79 assert isinstance( 80 use_out_as_primary, bool 81 ), f"You must provide either True or False for use_out_as_primary. Provided: {use_out_as_primary}" 82 83 use_device_guard = yaml_values.pop("device_guard", False) 84 assert isinstance( 85 use_device_guard, bool 86 ), f"You must provide either True or False for device_guard. Provided: {use_device_guard}" 87 88 supported = yaml_values.pop("supported", []) 89 if supported is None: 90 supported = [] # Allow an empty list of supported ops 91 assert isinstance( 92 supported, list 93 ), f'expected "supported" to be a list, but got: {supported} (of type {type(supported)})' 94 95 symint = yaml_values.pop("symint", []) 96 if symint is None: 97 symint = [] # Allow an empty list of symint ops 98 assert isinstance( 99 symint, list 100 ), f'expected "symint" to be a list, but got: {supported} (of type {type(supported)})' 101 symint_set = set(symint) 102 103 supported_autograd = yaml_values.pop("autograd", []) 104 assert isinstance( 105 supported_autograd, list 106 ), f'expected "autograd" to be a list, but got: {supported_autograd}' 107 108 # full_codegen is ignored by parse_backend_yaml, and re-parsed in gen_lazy_tensor.py 109 full_codegen = yaml_values.pop("full_codegen", []) 110 supported.extend(full_codegen) 111 112 # non_native is ignored by parse_backend_yaml, and re-parsed in gen_lazy_tensor.py 113 yaml_values.pop("non_native", {}) 114 115 # ir_gen is ignored by parse_backend_yaml, and re-parsed in gen_lazy_tensor.py 116 yaml_values.pop("ir_gen", {}) 117 118 assert ( 119 len(yaml_values.keys()) == 0 120 ), f'{backend_yaml_path} contains unexpected keys: {", ".join(yaml_values.keys())}. \ 121Only the following keys are supported: {", ".join(valid_keys)}' 122 123 def create_backend_index( 124 backend_ops: list[str], 125 symint_ops: set[str], 126 dispatch_key: DispatchKey, 127 *, 128 use_out_as_primary: bool, 129 use_device_guard: bool, 130 ) -> BackendIndex: 131 metadata: dict[OperatorName, BackendMetadata] = {} 132 for op in backend_ops: 133 op_name = OperatorName.parse(op) 134 assert ( 135 op_name in native_functions_map 136 ), f"Found an invalid operator name: {op_name}" 137 # See Note [External Backends Follow Dispatcher API] 138 kernel_name = dispatcher.name(native_functions_map[op_name].func) 139 if op in symint_ops: 140 kernel_name += "_symint" 141 # TODO: allow structured external backends later. 142 m = BackendMetadata( 143 kernel=kernel_name, structured=False, cpp_namespace=cpp_namespace 144 ) 145 metadata[op_name] = m 146 return BackendIndex( 147 dispatch_key=dispatch_key, 148 use_out_as_primary=use_out_as_primary, 149 external=True, 150 device_guard=use_device_guard, 151 index=metadata, 152 ) 153 154 backend_key: DispatchKey | None = None 155 if len(supported) > 0: 156 with context( 157 lambda: f'The provided value for "backend" must be a valid DispatchKey, but got {backend}.' 158 ): 159 backend_key = DispatchKey.parse(backend) 160 161 backend_idx = create_backend_index( 162 supported, 163 symint_set, 164 backend_key, 165 use_out_as_primary=use_out_as_primary, 166 use_device_guard=use_device_guard, 167 ) 168 assert backend_key not in backend_indices 169 backend_indices[backend_key] = backend_idx 170 171 autograd_key: DispatchKey | None = None 172 if len(supported_autograd) > 0: 173 with context( 174 lambda: f'The "autograd" key was specified, which indicates that you would like to override \ 175the behavior of autograd for some operators on your backend. However "Autograd{backend}" is not a valid DispatchKey.' 176 ): 177 autograd_key = DispatchKey.parse(f"Autograd{backend}") 178 179 autograd_idx = create_backend_index( 180 supported_autograd, 181 symint_set, 182 autograd_key, 183 use_out_as_primary=use_out_as_primary, 184 use_device_guard=use_device_guard, 185 ) 186 assert autograd_key not in backend_indices 187 backend_indices[autograd_key] = autograd_idx 188 189 for g in grouped_native_functions: 190 if isinstance(g, NativeFunction): 191 forward_kernels = ( 192 [] 193 if backend_key is None 194 else [ 195 m 196 for m in [backend_indices[backend_key].get_kernel(g)] 197 if m is not None 198 ] 199 ) 200 backward_kernels = ( 201 [] 202 if autograd_key is None 203 else [ 204 m 205 for m in [backend_indices[autograd_key].get_kernel(g)] 206 if m is not None 207 ] 208 ) 209 else: 210 forward_kernels = ( 211 [] 212 if backend_key is None 213 else [ 214 m 215 for m in [ 216 backend_indices[backend_key].get_kernel(f) 217 for f in g.functions() 218 ] 219 if m is not None 220 ] 221 ) 222 backward_kernels = ( 223 [] 224 if autograd_key is None 225 else [ 226 m 227 for m in [ 228 backend_indices[autograd_key].get_kernel(f) 229 for f in g.functions() 230 ] 231 if m is not None 232 ] 233 ) 234 235 forward_kernels = [f for f in forward_kernels if f is not None] 236 backward_kernels = [f for f in backward_kernels if f is not None] 237 assert ( 238 len(forward_kernels) == 0 or len(backward_kernels) == 0 239 ), f'Currently, all variants of an op must either be registered to a backend key, or to a backend\'s \ 240autograd key. They cannot be mix and matched. If this is something you need, feel free to create an issue! \ 241{forward_kernels[0].kernel} is listed under "supported", but {backward_kernels[0].kernel} is listed under "autograd".' 242 243 return ParsedExternalYaml( 244 backend_key, autograd_key, class_name, cpp_namespace, backend_indices 245 ) 246 247 248def error_on_missing_kernels( 249 native_functions: Sequence[NativeFunction], 250 backend_indices: dict[DispatchKey, BackendIndex], 251 backend_key: DispatchKey, 252 autograd_key: DispatchKey | None, 253 class_name: str, 254 kernel_defn_file_path: str, 255 full_codegen: list[OperatorName] | None = None, 256) -> None: 257 try: 258 with open(kernel_defn_file_path) as f: 259 backend_defns = f.read() 260 except OSError as e: 261 raise AssertionError( 262 f"Unable to read from the specified impl_path file: {kernel_defn_file_path}" 263 ) from e 264 265 if full_codegen is None: 266 full_codegen = [] 267 268 indices = [backend_indices[backend_key].index] + ( 269 [] if autograd_key is None else [backend_indices[autograd_key].index] 270 ) 271 # Quick mapping from each OperatorName used by the external backend 272 # to its backend kernel name 273 expected_backend_op_names: dict[OperatorName, str] = dict( 274 list( 275 concatMap( 276 lambda index: [ 277 (op_name, metadata.kernel) for op_name, metadata in index.items() 278 ], 279 indices, 280 ) 281 ) 282 ) 283 expected_backend_native_funcs: list[NativeFunction] = [ 284 f 285 for f in native_functions 286 if f.func.name in expected_backend_op_names.keys() 287 and f.func.name not in full_codegen 288 ] 289 expected_backend_kernel_name_counts: dict[str, list[NativeFunction]] = defaultdict( 290 list 291 ) 292 for native_f in expected_backend_native_funcs: 293 expected_backend_kernel_name_counts[ 294 expected_backend_op_names[native_f.func.name] 295 ].append(native_f) 296 297 # This just looks for lines containing "foo(", and assumes that the kernel foo has been implemented. 298 # It might cause false negatives (we won't catch all cases), but that's ok - if we catch a missing kernel 299 # here, then we get a nicer error message. If we miss it, you get a linker error. 300 kernel_defn_regex = rf"(.*){class_name}::\s*([\w\d]*)\(" 301 actual_backend_kernel_name_counts = Counter( 302 # A bit unwieldy (this could probably be moved into regex), 303 # but we don't want to include kernel names that come from function calls, 304 # like "return torch_xla::XLANativeFunctions::empty_strided_symint(...)". 305 # Easy check is to ignore any lines with colons before the class name. 306 [ 307 y 308 for (x, y) in re.findall(kernel_defn_regex, backend_defns) 309 if not x.endswith(":") 310 ] 311 ) 312 313 missing_kernels_err_msg = "" 314 for expected_name, funcs in expected_backend_kernel_name_counts.items(): 315 expected_overload_count = len(funcs) 316 actual_overload_count = actual_backend_kernel_name_counts[expected_name] 317 if expected_overload_count != actual_overload_count: 318 319 def create_decl(f: NativeFunction) -> str: 320 with native_function_manager(f): 321 return DispatcherSignature.from_schema(f.func).decl() 322 323 expected_schemas_str = "\n".join([create_decl(f) for f in funcs]) 324 missing_kernels_err_msg += f""" 325{class_name} is missing a kernel definition for {expected_name}. We found {actual_overload_count} kernel(s) with that name, 326but expected {expected_overload_count} kernel(s). The expected function schemas for the missing operator are: 327{expected_schemas_str} 328 329""" 330 assert missing_kernels_err_msg == "", missing_kernels_err_msg 331 332 333def main() -> None: 334 parser = argparse.ArgumentParser(description="Generate backend stub files") 335 parser.add_argument( 336 "-s", 337 "--source-yaml", 338 "--source_yaml", 339 help="path to source yaml file containing operator external definitions", 340 ) 341 parser.add_argument("-o", "--output-dir", "--output_dir", help="output directory") 342 parser.add_argument( 343 "--dry-run", "--dry_run", type=bool, default=False, help="output directory" 344 ) 345 parser.add_argument( 346 "--impl-path", 347 "--impl_path", 348 type=str, 349 default=None, 350 help="path to the source C++ file containing kernel definitions", 351 ) 352 options = parser.parse_args() 353 354 run(options.source_yaml, options.output_dir, options.dry_run, options.impl_path) 355 356 357def gen_dispatchkey_nativefunc_headers( 358 fm: FileManager, 359 class_name: str, 360 cpp_namespace: str, 361 backend_indices: dict[DispatchKey, BackendIndex], 362 grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup], 363 backend_dispatch_key: DispatchKey, 364 autograd_dispatch_key: DispatchKey | None, 365 backend_name: str = "", 366) -> None: 367 assert class_name is not None 368 generated_comment = ( 369 "Autogenerated file by gen_backend_stubs.py. Do not edit directly!" 370 ) 371 372 # Convert to a set first to remove duplicate kernel names. 373 # Backends are allowed to repeat kernel names; only generate the declaration once! 374 # Sort for deterministic output. 375 backend_declarations = sorted( 376 set( 377 concatMap( 378 lambda f: dest.compute_native_function_declaration( 379 f, backend_indices[backend_dispatch_key] 380 ), 381 grouped_native_functions, 382 ) 383 ) 384 ) 385 autograd_declarations = sorted( 386 set( 387 concatMap( 388 lambda f: [] 389 if autograd_dispatch_key is None 390 else dest.compute_native_function_declaration( 391 f, backend_indices[autograd_dispatch_key] 392 ), 393 grouped_native_functions, 394 ) 395 ) 396 ) 397 398 ns_helper = NamespaceHelper(cpp_namespace) 399 fm.write_with_template( 400 f"{backend_dispatch_key}NativeFunctions.h", 401 "DispatchKeyNativeFunctions.h", 402 lambda: { 403 "generated_comment": generated_comment, 404 "namespace_prologue": ns_helper.prologue, 405 "class_name": class_name, 406 "namespace_epilogue": ns_helper.epilogue, 407 "dispatch_declarations": backend_declarations + autograd_declarations, 408 "BackendName": backend_name, 409 "DispatchKey": backend_dispatch_key, 410 }, 411 ) 412 413 414def gen_dispatcher_registrations( 415 fm: FileManager, 416 output_dir: str, 417 class_name: str, 418 backend_indices: dict[DispatchKey, BackendIndex], 419 grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup], 420 backend_dispatch_key: DispatchKey, 421 dispatch_key: DispatchKey, 422 selector: SelectiveBuilder, 423 # build_in_tree is true for lazy TS backend and affects include paths, not used for external backends 424 build_in_tree: bool = False, 425 per_operator_headers: bool = False, 426 backend_name: str = "", 427 eager_registration: bool = True, 428) -> None: 429 headers = [ 430 f"{output_dir}/{backend_dispatch_key}NativeFunctions.h", 431 ] 432 if build_in_tree: 433 external_backend_headers_str = "\n".join(f"#include <{h}>" for h in headers) 434 else: 435 external_backend_headers_str = "\n".join(f'#include "{h}"' for h in headers) 436 437 assert class_name is not None 438 backend_index = backend_indices[dispatch_key] 439 440 dispatch_registrations_body = list( 441 concatMap( 442 dest.RegisterDispatchKey( 443 backend_index, 444 Target.REGISTRATION, 445 selector, 446 rocm=False, 447 symint=True, 448 class_method_name=f"{class_name}", 449 skip_dispatcher_op_registration=False, 450 ), 451 grouped_native_functions, 452 ) 453 ) 454 newline = "\n" 455 ns_helper = NamespaceHelper(namespace_str="at") 456 deferred_dispatch_registrations = "" 457 static_init_dispatch_registrations = "" 458 if eager_registration: 459 static_template = CodeTemplate( 460 """\ 461TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) { 462 $dispatch_registrations_body 463};""" 464 ) 465 static_init_dispatch_registrations = static_template.substitute( 466 dispatch_key=dispatch_key, 467 dispatch_registrations_body=dispatch_registrations_body, 468 ) 469 else: 470 deferred_template = CodeTemplate( 471 """\ 472TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions(); 473TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions() { 474 static auto m = MAKE_TORCH_LIBRARY_IMPL(aten, $dispatch_key); 475 $dispatch_registrations_body 476}""" 477 ) 478 deferred_dispatch_registrations = deferred_template.substitute( 479 backend_name=backend_name, 480 dispatch_key=dispatch_key, 481 dispatch_registrations_body=dispatch_registrations_body, 482 ) 483 484 fm.write_with_template( 485 f"Register{dispatch_key}.cpp", 486 "RegisterDispatchKey.cpp", 487 lambda: { 488 "extra_cuda_headers": "", 489 "external_backend_headers": external_backend_headers_str, 490 "ops_headers": "#include <ATen/Functions.h>" 491 if not per_operator_headers 492 else "", 493 "DispatchKey": dispatch_key, 494 "dispatch_namespace": dispatch_key.lower(), 495 "dispatch_headers": dest.gen_registration_headers( 496 backend_index, per_operator_headers=per_operator_headers, rocm=False 497 ), 498 "dispatch_definitions": fm.substitute_with_template( 499 "RegisterDispatchDefinitions.ini", 500 lambda: { 501 "ns_prologue": ns_helper.prologue, 502 "ns_epilogue": ns_helper.epilogue, 503 "static_init_dispatch_registrations": static_init_dispatch_registrations, 504 "deferred_dispatch_registrations": deferred_dispatch_registrations, 505 "dispatch_helpers": dest.gen_registration_helpers(backend_index), 506 "dispatch_namespace": dispatch_key.lower(), 507 "dispatch_namespaced_definitions": "", 508 "dispatch_anonymous_definitions": list( 509 concatMap( 510 dest.RegisterDispatchKey( 511 backend_index, 512 Target.ANONYMOUS_DEFINITION, 513 selector, 514 rocm=False, 515 symint=True, 516 class_method_name=f"{class_name}", 517 skip_dispatcher_op_registration=False, 518 ), 519 grouped_native_functions, 520 ) 521 ), 522 }, 523 ).split(newline), 524 }, 525 ) 526 527 528def run( 529 source_yaml: str, output_dir: str, dry_run: bool, impl_path: str | None = None 530) -> None: 531 # Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py 532 pytorch_root = Path(__file__).parent.parent.absolute() 533 template_dir = os.path.join(pytorch_root, "aten/src/ATen/templates") 534 535 def make_file_manager(install_dir: str) -> FileManager: 536 return FileManager( 537 install_dir=install_dir, template_dir=template_dir, dry_run=dry_run 538 ) 539 540 fm = make_file_manager(output_dir) 541 542 native_yaml_path = os.path.join( 543 pytorch_root, "aten/src/ATen/native/native_functions.yaml" 544 ) 545 tags_yaml_path = os.path.join(pytorch_root, "aten/src/ATen/native/tags.yaml") 546 parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path) 547 native_functions, backend_indices = ( 548 parsed_yaml.native_functions, 549 parsed_yaml.backend_indices, 550 ) 551 grouped_native_functions = get_grouped_native_functions(native_functions) 552 parsed_backend_yaml = parse_backend_yaml( 553 source_yaml, grouped_native_functions, backend_indices 554 ) 555 backend_key = parsed_backend_yaml.backend_key 556 autograd_key = parsed_backend_yaml.autograd_key 557 cpp_namespace = parsed_backend_yaml.cpp_namespace 558 class_name = parsed_backend_yaml.class_name 559 backend_indices = parsed_backend_yaml.backend_indices 560 561 selector = SelectiveBuilder.get_nop_selector() 562 563 if backend_key is None: 564 # This could be useful if a backend wants to quickly set up a noop yaml file but doesn't have any kernels ready yet. 565 return 566 567 if class_name is None: 568 # class_name is an optional argument to backend yaml file. 569 # if specified it allows an external backend to override 570 # the name of the class that all generated kernel definitions live under. 571 # if not specified, its value is given as native_function_class_name. 572 class_name = backend_indices[backend_key].native_function_class_name() 573 assert class_name is not None 574 575 if impl_path is not None: 576 error_on_missing_kernels( 577 native_functions, 578 backend_indices, 579 backend_key, 580 autograd_key, 581 class_name, 582 impl_path, 583 ) 584 585 gen_dispatchkey_nativefunc_headers( 586 fm, 587 class_name, 588 cpp_namespace, 589 backend_indices, 590 grouped_native_functions, 591 backend_key, 592 autograd_key, 593 ) 594 595 for dispatch_key in ( 596 [backend_key] if autograd_key is None else [backend_key, autograd_key] 597 ): 598 gen_dispatcher_registrations( 599 fm, 600 output_dir, 601 class_name, 602 backend_indices, 603 grouped_native_functions, 604 backend_key, 605 dispatch_key, 606 selector, 607 ) 608 609 610if __name__ == "__main__": 611 main() 612