1from __future__ import annotations 2 3import argparse 4import os 5from collections import defaultdict 6from dataclasses import dataclass 7from pathlib import Path 8from typing import Any, Callable, Sequence, TextIO, TYPE_CHECKING 9 10import yaml 11 12# Parse native_functions.yaml into a sequence of NativeFunctions and Backend Indices. 13from torchgen import dest 14from torchgen.api import cpp as aten_cpp 15from torchgen.api.types import CppSignature, CppSignatureGroup, CType, NamedCType 16from torchgen.context import ( 17 method_with_native_function, 18 method_with_nested_native_function, 19 with_native_function_and_index, 20) 21from torchgen.executorch.api import et_cpp 22from torchgen.executorch.api.custom_ops import ( 23 ComputeNativeFunctionStub, 24 gen_custom_ops_registration, 25) 26from torchgen.executorch.api.types import contextArg, ExecutorchCppSignature 27from torchgen.executorch.api.unboxing import Unboxing 28from torchgen.executorch.model import ETKernelIndex, ETKernelKey, ETParsedYaml 29from torchgen.executorch.parse import ET_FIELDS, parse_et_yaml, parse_et_yaml_struct 30from torchgen.gen import ( 31 get_custom_build_selector, 32 get_native_function_declarations, 33 get_native_function_declarations_from_ns_grouped_kernels, 34 get_native_function_schema_registrations, 35 LineLoader, 36 parse_native_yaml, 37) 38from torchgen.model import ( 39 BackendIndex, 40 BackendMetadata, 41 DEFAULT_KERNEL_NAMESPACE, 42 DispatchKey, 43 FunctionSchema, 44 Location, 45 NativeFunction, 46 NativeFunctionsGroup, 47 OperatorName, 48 Variant, 49) 50from torchgen.utils import ( 51 context, 52 FileManager, 53 make_file_manager, 54 mapMaybe, 55 NamespaceHelper, 56) 57 58 59if TYPE_CHECKING: 60 from torchgen.selective_build.selector import SelectiveBuilder 61 62 63def _sig_decl_wrapper(sig: CppSignature | ExecutorchCppSignature) -> str: 64 """ 65 A wrapper function to basically get `sig.decl(include_context=True)`. 66 For ATen kernel, the codegen has no idea about ET contextArg, so we 67 use this wrapper to add it. 68 """ 69 if isinstance(sig, ExecutorchCppSignature): 70 return sig.decl() 71 72 returns_type = aten_cpp.returns_type(sig.func.returns).cpp_type() 73 cpp_args = [a.decl() for a in sig.arguments()] 74 cpp_args_str = ", ".join([contextArg.decl()] + cpp_args) 75 sig_decl = f"{returns_type} {sig.name()}({cpp_args_str})" 76 return sig_decl 77 78 79def static_dispatch( 80 sig: CppSignature | ExecutorchCppSignature, 81 f: NativeFunction, 82 backend_indices: list[BackendIndex], 83) -> str: 84 """ 85 For a given `NativeFunction`, find out the corresponding native function and dispatch to it. If zero or more than one 86 native function exists, error out. A simplified version of register_dispatch_key.py 87 Arguments: 88 sig: A CppSignature for this native function we want to use. 89 f: NativeFunction to generate static dispatch. 90 backend_indices: All available backends. 91 Return: 92 C++ code to call backend-specific functions, e.g., "return at::native::add(self, other, scale);" 93 """ 94 if len(backend_indices) == 0 or f.manual_kernel_registration: 95 return "" 96 97 backends = [b for b in backend_indices if b.has_kernel(f)] 98 static_block = None 99 if len(backends) == 1: 100 backend_metadata = backends[0].get_kernel(f) 101 if backend_metadata: 102 args = ", ".join(a.name for a in sig.arguments()) 103 # Here we are assuming there's no difference between CppSignature and NativeSignature for Executorch. 104 static_block = f"return ::{backend_metadata.cpp_namespace}::{backend_metadata.kernel}({args});" 105 else: 106 static_block = f""" 107ET_ASSERT_UNREACHABLE_MSG("The number of native function(s) binding to {f.func.name} is {len(backends)}."); 108 """ 109 return f""" 110// {f.namespace}::{f.func} 111TORCH_API inline {_sig_decl_wrapper(sig)} {{ 112 {static_block} 113}} 114""" 115 116 117# Generates Functions.h, which provides the functional public C++ API, 118# and the scaffolding to call into the dispatcher from these functions. 119@dataclass(frozen=True) 120class ComputeFunction: 121 static_dispatch_backend_indices: list[BackendIndex] 122 123 selector: SelectiveBuilder 124 125 use_aten_lib: bool 126 127 is_custom_op: Callable[[NativeFunction], bool] 128 129 @method_with_native_function 130 def __call__(self, f: NativeFunction) -> str | None: 131 is_method_variant = False 132 if not self.selector.is_root_operator(f"{f.namespace}::{f.func.name}"): 133 return None 134 135 if Variant.function not in f.variants and Variant.method in f.variants: 136 is_method_variant = True 137 138 # only valid remaining case is only function is in f.variants 139 elif not (Variant.function in f.variants and Variant.method not in f.variants): 140 raise Exception( # noqa: TRY002 141 f"Can't handle native function {f.func} with the following variant specification {f.variants}." 142 ) 143 144 sig: CppSignature | ExecutorchCppSignature = ( 145 CppSignatureGroup.from_native_function( 146 f, method=False, fallback_binding=f.manual_cpp_binding 147 ).most_faithful_signature() 148 if self.use_aten_lib 149 else ExecutorchCppSignature.from_native_function(f) 150 ) 151 if self.use_aten_lib and not self.is_custom_op(f): 152 comma = ", " 153 154 if is_method_variant: 155 return f""" 156// {f.namespace}::{f.func} 157TORCH_API inline {_sig_decl_wrapper(sig)} {{ 158 return {sig.arguments()[0].name}.{sig.name()}({comma.join(e.name for e in sig.arguments()[1:])}); 159}} 160""" 161 else: 162 return f""" 163// {f.namespace}::{f.func} 164TORCH_API inline {_sig_decl_wrapper(sig)} {{ 165 return at::{sig.name()}({comma.join(e.name for e in sig.arguments())}); 166}} 167""" 168 169 else: 170 return static_dispatch( 171 sig, 172 f, 173 backend_indices=self.static_dispatch_backend_indices, 174 ) 175 176 177# Generates RegisterCodegenUnboxedKernels.cpp. 178@dataclass(frozen=True) 179class ComputeCodegenUnboxedKernels: 180 selector: SelectiveBuilder 181 182 use_aten_lib: bool 183 184 @method_with_nested_native_function 185 def __call__( 186 self, 187 unbox_kernel_entry: tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]], 188 ) -> str: 189 f: NativeFunction = unbox_kernel_entry[0] 190 kernel_key: ETKernelKey | list[ETKernelKey] = unbox_kernel_entry[1][0] 191 kernel_meta: BackendMetadata = unbox_kernel_entry[1][1] 192 193 op_name = f"{f.namespace}::{f.func.name}" 194 if not self.selector.is_root_operator(op_name): 195 return "" 196 197 if not isinstance(kernel_key, list): 198 kernel_key = [kernel_key] 199 used_kernel_keys = self.selector.et_get_selected_kernels( 200 op_name, [k.to_native_string() for k in kernel_key] 201 ) 202 if not used_kernel_keys: 203 return "" 204 sig: CppSignature | ExecutorchCppSignature 205 argument_type_gen: Callable[..., NamedCType] 206 return_type_gen: Callable[..., CType] 207 if self.use_aten_lib: 208 sig = CppSignatureGroup.from_native_function( 209 f, method=False, fallback_binding=f.manual_cpp_binding 210 ).most_faithful_signature() 211 argument_type_gen = aten_cpp.argumenttype_type 212 return_type_gen = aten_cpp.returns_type 213 arguments = sig.arguments() 214 kernel_call = f"torch::executor::{f.namespace}::{sig.name()}" 215 else: 216 sig = ExecutorchCppSignature.from_native_function(f) 217 argument_type_gen = et_cpp.argumenttype_type 218 return_type_gen = et_cpp.returns_type 219 arguments = sig.arguments(include_context=False) 220 kernel_call = f"{kernel_meta.cpp_namespace}::{kernel_meta.kernel}" 221 # parse arguments into C++ code 222 binding_list, code_list = Unboxing( 223 argument_type_gen=argument_type_gen 224 ).convert_arguments(arguments) 225 226 # for each C++ argument, generate the conversion code 227 code_connector = "\n\t" 228 arg_connector = ", " 229 230 args_str = f"{arg_connector.join(e.name for e in binding_list)}" 231 event_tracer_output_logging = "" 232 output_ids = [] 233 234 if len(f.func.returns) == 0: 235 if len(f.func.arguments.out) == 0: 236 raise Exception( # noqa: TRY002 237 f"Can't handle native function {f.func} with no returns and no out yet." 238 ) 239 out = f.func.arguments.out[0] 240 return_assignment = f"""stack[{len(binding_list)}] = &{out.name};""" 241 ret_prefix = "" 242 output_ids = [len(binding_list)] 243 else: 244 if len(f.func.arguments.out) == 0: 245 return_assignment = ( 246 f"""*stack[{len(binding_list)}] = EValue(result_);""" 247 ) 248 ret_prefix = return_type_gen(f.func.returns).cpp_type() + " result_ = " 249 output_ids = [len(binding_list)] 250 else: 251 return_assignment = "" 252 ret_prefix = "" 253 output_ids = [ 254 len(binding_list) - (i + 1) 255 for i in reversed(range(len(f.func.arguments.out))) 256 ] 257 258 for output_id in output_ids: 259 event_tracer_output_logging += ( 260 f"internal::event_tracer_log_evalue(" 261 f"context.internal_event_tracer(), " 262 f"*stack[{output_id}]);\n" 263 ) 264 265 newline = "\n " 266 return "\n".join( 267 [ 268 f""" 269Kernel( 270 "{f.namespace}::{f.func.name}",{newline + '"' + (k + '",') if k != 'default' else ''} 271 []({contextArg.defn()}, EValue** stack) {{ 272 {code_connector.join(code_list)} 273 274 internal::EventTracerProfileScope event_tracer_scope(context.internal_event_tracer(), "native_call_{f.func.name}"); 275 EXECUTORCH_SCOPE_PROF("native_call_{f.func.name}"); 276 {ret_prefix}{kernel_call}(context, {args_str}); 277 {event_tracer_output_logging} 278 {return_assignment} 279 }} 280), 281""" 282 for k in used_kernel_keys 283 ] 284 ) 285 286 287def gen_unboxing( 288 *, 289 native_functions: Sequence[NativeFunction], 290 cpu_fm: FileManager, 291 selector: SelectiveBuilder, 292 use_aten_lib: bool, 293 kernel_index: ETKernelIndex, 294 manual_registration: bool, 295) -> None: 296 # Iterable type for write_sharded is a Tuple of (native_function, (kernel_key, metadata)) 297 def key_func( 298 item: tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]] 299 ) -> str: 300 return item[0].root_name + ":" + item[1][0].to_native_string() 301 302 items: list[tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]]] = [ 303 (native_function, (kernel_key, metadata)) 304 for native_function in native_functions 305 for kernel_key, metadata in kernel_index.get_kernels(native_function).items() 306 ] 307 308 header = ["Functions.h" if use_aten_lib else "NativeFunctions.h"] 309 filename = ( 310 "RegisterKernels.cpp" 311 if manual_registration 312 else "RegisterCodegenUnboxedKernels.cpp" 313 ) 314 cpu_fm.write_sharded( 315 filename, 316 items, 317 key_fn=key_func, 318 env_callable=lambda unbox_kernel_entry: { 319 "unboxed_kernels": [ 320 ComputeCodegenUnboxedKernels(selector, use_aten_lib)(unbox_kernel_entry) 321 ], 322 "fn_header": header 323 if unbox_kernel_entry == items[0] 324 else [], # Only write header once 325 }, 326 num_shards=1, 327 sharded_keys={"unboxed_kernels", "fn_header"}, 328 ) 329 330 331@with_native_function_and_index # type: ignore[arg-type] 332def compute_native_function_declaration( 333 g: NativeFunctionsGroup | NativeFunction, kernel_index: ETKernelIndex 334) -> list[str]: 335 assert isinstance(g, NativeFunction) 336 sig = ExecutorchCppSignature.from_native_function(f=g) 337 metadata_list = kernel_index.get_kernels(g).values() 338 if metadata_list is None: 339 return [] 340 341 # for kernels in lean mode, we declare two versions, one with context and one without. 342 # In the end we will cleanup the unused one. 343 def gen_decl(metadata: BackendMetadata, include_context: bool) -> str: 344 return f"{sig.decl(name=metadata.kernel, include_context=include_context)};" 345 346 return [ 347 gen_decl(metadata, include_context) 348 for include_context in [False, True] 349 for metadata in metadata_list 350 ] 351 352 353def gen_functions_declarations( 354 *, 355 native_functions: Sequence[NativeFunction], 356 kernel_index: ETKernelIndex, 357 selector: SelectiveBuilder, 358 use_aten_lib: bool, 359 custom_ops_native_functions: Sequence[NativeFunction] | None = None, 360) -> str: 361 """ 362 Generates namespace separated C++ function API inline declaration/definitions. 363 Native functions are grouped by namespaces and the generated code is wrapped inside 364 namespace blocks. 365 366 E.g., for `custom_1::foo.out` in yaml file we will generate a C++ API as a symbol 367 in `torch::executor::custom_1::foo_out`. This way we avoid symbol conflict when 368 the other `custom_2::foo.out` is available. 369 """ 370 371 # convert kernel index to BackendIndex. This is because we can't handle ETKernelIndex yet. 372 # TODO larryliu: evaluate if this code is still needed. If yes let it handle ETKernelIndex. 373 374 backend_index = kernel_index._to_backend_index() 375 376 ns_grouped_functions = defaultdict(list) 377 for native_function in native_functions: 378 ns_grouped_functions[native_function.namespace].append(native_function) 379 functions_declarations = "" 380 newline = "\n" 381 for namespace in ns_grouped_functions: 382 ns_helper = NamespaceHelper( 383 namespace_str=namespace, 384 entity_name="", 385 max_level=3, 386 ) 387 declarations = list( 388 mapMaybe( 389 ComputeFunction( 390 static_dispatch_backend_indices=[backend_index], 391 selector=selector, 392 use_aten_lib=use_aten_lib, 393 is_custom_op=lambda f: custom_ops_native_functions is not None 394 and f in custom_ops_native_functions, 395 ), 396 ns_grouped_functions[namespace], 397 ) 398 ) 399 functions_declarations += f""" 400{ns_helper.prologue} 401{newline.join(declarations)} 402{ns_helper.epilogue} 403 """ 404 return functions_declarations 405 406 407def get_ns_grouped_kernels( 408 *, 409 native_functions: Sequence[NativeFunction], 410 kernel_index: ETKernelIndex, 411 native_function_decl_gen: Callable[ 412 [ 413 NativeFunctionsGroup | NativeFunction, 414 ETKernelIndex, 415 ], 416 list[str], 417 ], 418) -> dict[str, list[str]]: 419 ns_grouped_kernels: dict[str, list[str]] = defaultdict(list) 420 for f in native_functions: 421 native_function_namespaces = set() 422 op_kernels = kernel_index.get_kernels(f) 423 for backend_metadata in op_kernels.values(): 424 if backend_metadata: 425 namespace = backend_metadata.cpp_namespace 426 native_function_namespaces.add(namespace) 427 else: 428 namespace = DEFAULT_KERNEL_NAMESPACE 429 assert ( 430 len(native_function_namespaces) <= 1 431 ), f"Codegen only supports one namespace per operator, got {native_function_namespaces}" 432 ns_grouped_kernels[namespace].extend( 433 native_function_decl_gen(f, kernel_index) 434 ) 435 return ns_grouped_kernels 436 437 438def gen_headers( 439 *, 440 native_functions: Sequence[NativeFunction], 441 gen_custom_ops_header: bool, 442 custom_ops_native_functions: Sequence[NativeFunction], 443 selector: SelectiveBuilder, 444 kernel_index: ETKernelIndex, 445 cpu_fm: FileManager, 446 use_aten_lib: bool, 447) -> None: 448 """Generate headers. 449 450 Args: 451 native_functions (Sequence[NativeFunction]): a collection of NativeFunction for ATen ops. 452 gen_custom_ops_header (bool): whether we should generate CustomOpsNativeFunctions.h 453 custom_ops_native_functions (Sequence[NativeFunction]): a collection of NativeFunction for custom ops. 454 kernel_index (ETKernelIndex): kernel collection 455 cpu_fm (FileManager): file manager manages output stream 456 use_aten_lib (bool): whether we are generating for PyTorch types or Executorch types. 457 """ 458 aten_headers = ["#include <ATen/Functions.h>"] 459 backend_indices = {DispatchKey.CPU: kernel_index._to_backend_index()} 460 if gen_custom_ops_header: 461 cpu_fm.write_with_template( 462 "CustomOpsNativeFunctions.h", 463 "NativeFunctions.h", 464 lambda: { 465 "nativeFunctions_declarations": get_native_function_declarations( 466 grouped_native_functions=custom_ops_native_functions, 467 backend_indices=backend_indices, 468 native_function_decl_gen=dest.compute_native_function_declaration, 469 ), 470 "headers": [ 471 "#include <ATen/ATen.h>", 472 "#include <torch/torch.h>", 473 ], 474 }, 475 ) 476 aten_headers.append('#include "CustomOpsNativeFunctions.h"') 477 cpu_fm.write( 478 "Functions.h", 479 lambda: { 480 "static_dispatch_extra_headers": aten_headers 481 if use_aten_lib 482 else ['#include "NativeFunctions.h"'], 483 "Functions_declarations": gen_functions_declarations( 484 native_functions=native_functions, 485 kernel_index=kernel_index, 486 selector=selector, 487 use_aten_lib=use_aten_lib, 488 custom_ops_native_functions=custom_ops_native_functions, 489 ), 490 }, 491 ) 492 cpu_fm.write( 493 "RegisterKernels.h", 494 lambda: { 495 "generated_comment": "@" + "generated by torchgen/gen_executorch.py", 496 }, 497 ) 498 headers = { 499 "headers": [ 500 "#include <executorch/runtime/core/exec_aten/exec_aten.h> // at::Tensor etc.", 501 "#include <executorch/runtime/kernel/kernel_runtime_context.h>", 502 ], 503 } 504 if use_aten_lib: 505 headers["headers"].append("#include <executorch/codegen/macros.h> // TORCH_API") 506 cpu_fm.write( 507 "NativeFunctions.h", 508 lambda: dict( 509 { 510 "nativeFunctions_declarations": get_native_function_declarations( 511 grouped_native_functions=native_functions, 512 backend_indices=backend_indices, 513 native_function_decl_gen=dest.compute_native_function_declaration, 514 ), 515 }, 516 **headers, 517 ), 518 ) 519 else: 520 ns_grouped_kernels = get_ns_grouped_kernels( 521 native_functions=native_functions, 522 kernel_index=kernel_index, 523 native_function_decl_gen=compute_native_function_declaration, # type: ignore[arg-type] 524 ) 525 cpu_fm.write( 526 "NativeFunctions.h", 527 lambda: dict( 528 { 529 "nativeFunctions_declarations": get_native_function_declarations_from_ns_grouped_kernels( 530 ns_grouped_kernels=ns_grouped_kernels, 531 ), 532 }, 533 **headers, 534 ), 535 ) 536 537 538def gen_custom_ops( 539 *, 540 native_functions: Sequence[NativeFunction], 541 selector: SelectiveBuilder, 542 kernel_index: ETKernelIndex, 543 cpu_fm: FileManager, 544 rocm: bool, 545) -> None: 546 dispatch_key = DispatchKey.CPU 547 ( 548 anonymous_definition, 549 static_init_dispatch_registrations, 550 ) = gen_custom_ops_registration( 551 native_functions=native_functions, 552 selector=selector, 553 kernel_index=kernel_index, 554 rocm=rocm, 555 ) 556 cpu_fm.write_with_template( 557 f"Register{dispatch_key}CustomOps.cpp", 558 "RegisterDispatchKeyCustomOps.cpp", 559 lambda: { 560 "ops_headers": '#include "CustomOpsNativeFunctions.h"', 561 "DispatchKey": dispatch_key, 562 "dispatch_namespace": dispatch_key.lower(), 563 "dispatch_namespaced_definitions": "", 564 "dispatch_anonymous_definitions": anonymous_definition, 565 "static_init_dispatch_registrations": static_init_dispatch_registrations, 566 }, 567 ) 568 cpu_fm.write_with_template( 569 f"Register{dispatch_key}Stub.cpp", 570 "RegisterDispatchKeyCustomOps.cpp", 571 lambda: { 572 "ops_headers": "", 573 "DispatchKey": dispatch_key, 574 "dispatch_namespace": dispatch_key.lower(), 575 "dispatch_namespaced_definitions": "", 576 "dispatch_anonymous_definitions": list( 577 mapMaybe(ComputeNativeFunctionStub(), native_functions) 578 ), 579 "static_init_dispatch_registrations": static_init_dispatch_registrations, 580 }, 581 ) 582 583 ( 584 aten_schema_registrations, 585 schema_registrations, 586 ) = get_native_function_schema_registrations( 587 native_functions=native_functions, 588 schema_selector=selector, 589 ) 590 cpu_fm.write( 591 "RegisterSchema.cpp", 592 lambda: { 593 "schema_registrations": schema_registrations, 594 "aten_schema_registrations": aten_schema_registrations, 595 }, 596 ) 597 598 599def translate_native_yaml( 600 tags_yaml_path: str, 601 aten_yaml_path: str, 602 native_yaml_path: str | None, 603 use_aten_lib: bool, 604 out_file: TextIO, 605) -> None: 606 """Translates Executorch DSL dialect to use the same syntax as 607 native_functions.yaml. The major difference is that Executorch DSL dialect 608 supports "op" key, where it refers to the operator name in native_functions.yaml. 609 610 For example, a functions.yaml may have the following entry: 611 612 - op: add.out 613 ... 614 615 It needs to be translated to the following: 616 617 - func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) 618 ... 619 620 We go in aten_yaml_path and find the operator schema for "add.out" and add it 621 to the original functions.yaml. We also add required field "variants", where for 622 Executorch it will always be "function". 623 624 For ATen mode we don't have to do the translation because native_yaml_path is 625 the same as native_functions.yaml. 626 627 Args: 628 tags_yaml_path: Path to a tags.yaml file to satisfy codegen parsing. 629 It is not optional. 630 aten_yaml_path: Path to ATen operator yaml file native_functions.yaml. 631 native_yaml_path: Path to a functions.yaml file to parse. 632 If the path does not exist in the filesystem, it is treated as an 633 empty file. If `custom_ops_yaml_path` exists, the contents of that 634 file are appended to the yaml input to be parsed. 635 use_aten_lib: We use this flag to determine if we want to generate native 636 functions. In ATen mode we should generate out= variants. 637 out_file: The IO object that we are writing into. 638 Returns: 639 None 640 """ 641 if use_aten_lib: 642 with open(aten_yaml_path) as aten_yaml: 643 out_file.writelines(aten_yaml.readlines()) 644 return 645 646 native_functions, persisted_fields = parse_et_yaml( 647 aten_yaml_path, 648 tags_yaml_path, 649 None, 650 skip_native_fns_gen=False, 651 ) 652 653 func_to_scoped_name: dict[FunctionSchema, str] = { 654 f.func: f"{f.namespace}::{f.func.name}" for f in native_functions 655 } 656 op_to_scoped_name: dict[OperatorName, str] = { 657 func.name: name for func, name in func_to_scoped_name.items() 658 } 659 660 schema_dict = {name: str(func) for func, name in func_to_scoped_name.items()} 661 kernel_persist_dict: dict[str, dict[str, Any]] = { 662 op_to_scoped_name[op]: v for op, v in persisted_fields.items() 663 } 664 665 if ( 666 not native_yaml_path 667 or not os.path.exists(native_yaml_path) 668 or os.stat(native_yaml_path).st_size == 0 669 ): 670 return 671 with open(native_yaml_path) as native_yaml: 672 native_es = yaml.load(native_yaml, Loader=LineLoader) 673 if not native_es: 674 return 675 for e in native_es: 676 assert isinstance(e.get("__line__"), int), e 677 loc = Location(native_yaml_path, e.pop("__line__")) 678 with context(lambda: f"in {loc}:\n "): 679 if "variants" not in e: 680 e["variants"] = "function" 681 if "func" in e: 682 continue 683 assert isinstance(e.get("op"), str), e 684 opname = e.pop("op") 685 if "::" not in opname: 686 opname = "aten::" + opname 687 assert opname in schema_dict 688 e["func"] = schema_dict.get(opname) 689 690 # Write out persisted kernel information 691 if opname in kernel_persist_dict: 692 for k, v in kernel_persist_dict[opname].items(): 693 e[k] = v 694 695 yaml.dump(native_es, out_file, width=1000) 696 697 698def parse_yaml( 699 path: str | None, 700 tags_yaml_path: str, 701 function_filter: Callable[[NativeFunction], bool], 702 skip_native_fns_gen: bool = False, 703) -> tuple[ 704 list[NativeFunction], 705 dict[DispatchKey, dict[OperatorName, BackendMetadata]] | ETKernelIndex, 706]: 707 if path and os.path.exists(path) and os.stat(path).st_size > 0: 708 with open(path) as f: 709 es = yaml.load(f, Loader=LineLoader) 710 711 # Check for kernel index structure 712 kernel_index = ( 713 parse_et_yaml_struct(es) if any("kernels" in e for e in es) else None 714 ) 715 716 # Remove ET specific fields from entries for BC compatibility 717 for entry in es: 718 for field in ET_FIELDS: 719 entry.pop(field, None) 720 721 parsed_yaml = parse_native_yaml( 722 path, 723 tags_yaml_path, 724 None, 725 skip_native_fns_gen=skip_native_fns_gen, 726 loaded_yaml=es, 727 ) 728 native_functions = list(filter(function_filter, parsed_yaml.native_functions)) 729 op_names = [f.func.name for f in native_functions] 730 731 # (1) Return ETKernelIndex if kernel index is present 732 if kernel_index is not None: 733 filtered_index = { 734 op_name: kernel_mapping 735 for op_name, kernel_mapping in kernel_index.index.items() 736 if op_name in op_names 737 } 738 return native_functions, ETKernelIndex(index=filtered_index) 739 740 # (2) Return BackendIndices if kernel index is absent 741 def map_index( 742 m: dict[OperatorName, BackendMetadata] 743 ) -> dict[OperatorName, BackendMetadata]: 744 return {op: m[op] for op in m if op in op_names} 745 746 backend_indices = { 747 k: map_index(b.index) for (k, b) in parsed_yaml.backend_indices.items() 748 } 749 750 return native_functions, backend_indices 751 else: 752 return [], {} 753 754 755def parse_yaml_files( 756 tags_yaml_path: str, 757 aten_yaml_path: str, 758 native_yaml_path: str | None, 759 custom_ops_yaml_path: str | None, 760 selector: SelectiveBuilder, 761 use_aten_lib: bool, 762) -> tuple[ETParsedYaml, ETParsedYaml | None]: 763 """Parses functions.yaml and custom_ops.yaml files. 764 765 Args: 766 tags_yaml_path: Path to a tags.yaml file to satisfy codegen parsing. 767 It is not optional. 768 aten_yaml_path: Path to ATen operator yaml file native_functions.yaml. 769 native_yaml_path: Path to a functions.yaml file to parse. 770 If the path does not exist in the filesystem, it is treated as an 771 empty file. If `custom_ops_yaml_path` exists, the contents of that 772 file are appended to the yaml input to be parsed. 773 custom_ops_yaml_path: Path to a custom_ops.yaml file to parse. If 774 the path does not exist in the filesystem, it is ignored. 775 selector: For selective build. 776 use_aten_lib: We use this flag to determine if we want to generate native 777 functions. In ATen mode we should generate out= variants. 778 Returns: 779 A tuple with two elements: 780 [0]: The parsed results of concatenating the contents of 781 `native_yaml_path` and `custom_ops_yaml_path`. 782 [1]: The parsed results of the contents of `custom_ops_yaml_path`, if 783 present. If not present, None. 784 """ 785 import tempfile 786 787 # only include selected ops, this is because we want to avoid 788 def function_filter(f: NativeFunction) -> bool: 789 return selector.is_native_function_selected(f) 790 791 with tempfile.TemporaryDirectory() as tmpdirname: 792 translated_yaml_path = os.path.join(tmpdirname, "translated.yaml") 793 with open(translated_yaml_path, "w") as translated: 794 translate_native_yaml( 795 tags_yaml_path, 796 aten_yaml_path, 797 native_yaml_path, 798 use_aten_lib, 799 translated, 800 ) 801 802 translated_functions, translated_indices = parse_yaml( 803 translated_yaml_path, tags_yaml_path, function_filter, not use_aten_lib 804 ) 805 custom_ops_functions, custom_ops_indices = parse_yaml( 806 custom_ops_yaml_path, tags_yaml_path, function_filter, True 807 ) 808 809 # Convert BackendIndices to ETKernelIndex 810 if not isinstance(translated_indices, ETKernelIndex): 811 translated_indices = ETKernelIndex.from_backend_indices(translated_indices) 812 if not isinstance(custom_ops_indices, ETKernelIndex): 813 custom_ops_indices = ETKernelIndex.from_backend_indices(custom_ops_indices) 814 815 combined_functions = translated_functions + custom_ops_functions 816 combined_kernel_index = ETKernelIndex.merge_indices( 817 translated_indices, custom_ops_indices 818 ) 819 combined_yaml = ETParsedYaml(combined_functions, combined_kernel_index) 820 custom_ops_parsed_yaml = ETParsedYaml(custom_ops_functions, custom_ops_indices) 821 822 return combined_yaml, custom_ops_parsed_yaml 823 824 825def main() -> None: 826 parser = argparse.ArgumentParser(description="Generate operator source files") 827 # Although we don't refer to --source-path directly, make_file_manager() 828 # expects it to point to a directory that contains a templates/ subdirectory 829 # containing the file templates. 830 parser.add_argument( 831 "-s", 832 "--source-path", 833 help="path to source directory for kernel templates", 834 ) 835 parser.add_argument( 836 "--functions-yaml-path", 837 "--functions_yaml_path", 838 help="path to the functions.yaml file to use. Optional, but at least " 839 "one of --functions-yaml-path and --custom-ops-yaml-path must be " 840 "specified.", 841 ) 842 parser.add_argument( 843 "--custom-ops-yaml-path", 844 "--custom_ops_yaml_path", 845 help="path to the custom_ops.yaml file to use. Optional, but at least " 846 "one of --functions-yaml-path and --custom-ops-yaml-path must be " 847 "specified.", 848 ) 849 parser.add_argument( 850 "--aten-yaml-path", 851 "--aten_yaml_path", 852 help="path to native_functions.yaml file.", 853 ) 854 # Note that make_file_manager() also looks at --install-dir. 855 parser.add_argument( 856 "-d", 857 "--install-dir", 858 "--install_dir", 859 help="output directory", 860 default="build/generated", 861 ) 862 parser.add_argument( 863 "-o", 864 "--output-dependencies", 865 help="output a list of dependencies into the given file and exit", 866 ) 867 # Although we don't refer to --dry-run directly, make_file_manager() looks 868 # for it. 869 parser.add_argument( 870 "--dry-run", 871 action="store_true", 872 help="run without writing any files (still updates outputs)", 873 ) 874 parser.add_argument( 875 "--static-dispatch-backend", 876 "--static_dispatch_backend", 877 nargs="*", 878 help="generate static dispatch code for the specific backend (if set)", 879 ) 880 parser.add_argument( 881 "--op-registration-whitelist", 882 "--op_registration_whitelist", 883 nargs="*", 884 help="filter op registrations by the whitelist (if set); " 885 "each item is `namespace`::`operator name` without overload name; " 886 "e.g.: aten::empty aten::conv2d ...", 887 ) 888 parser.add_argument( 889 "--op-selection-yaml-path", 890 "--op_selection_yaml_path", 891 help="Provide a path to the operator selection (for custom build) YAML " 892 "that contains the information about the set of selected operators " 893 "and their categories (training, ...). Each operator is either a " 894 "full operator name with overload or just a bare operator name. " 895 "The operator names also contain the namespace prefix (e.g. aten::)", 896 ) 897 parser.add_argument( 898 "--tags-path", 899 help="Path to tags.yaml. Required by yaml parsing in codegen system.", 900 ) 901 parser.add_argument( 902 "--rocm", 903 action="store_true", 904 help="reinterpret CUDA as ROCm/HIP and adjust filepaths accordingly", 905 ) 906 parser.add_argument( 907 "--use-aten-lib", 908 "--use_aten_lib", 909 action="store_true", 910 help="a boolean flag to indicate whether we use ATen kernels or not, in the future this flag will be per " 911 "operator", 912 ) 913 parser.add_argument( 914 "--manual_registration", 915 "--manual-registration", 916 action="store_true", 917 help="a boolean flag to indicate whether we want to manually call" 918 "register_kernels() or rely on static init. ", 919 ) 920 parser.add_argument( 921 "--generate", 922 type=str, 923 nargs="*", 924 choices=["headers", "sources"], 925 default=["headers", "sources"], 926 help="Generate only a subset of files", 927 ) 928 options = parser.parse_args() 929 assert options.tags_path, "tags.yaml is required by codegen yaml parsing." 930 931 selector = get_custom_build_selector( 932 options.op_registration_whitelist, 933 options.op_selection_yaml_path, 934 ) 935 936 parsed_yaml, custom_ops_parsed_yaml = parse_yaml_files( 937 aten_yaml_path=options.aten_yaml_path, 938 tags_yaml_path=options.tags_path, 939 native_yaml_path=options.functions_yaml_path, 940 custom_ops_yaml_path=options.custom_ops_yaml_path, 941 selector=selector, 942 use_aten_lib=options.use_aten_lib, 943 ) 944 native_functions, kernel_index = ( 945 parsed_yaml.native_functions, 946 parsed_yaml.kernel_index, 947 ) 948 custom_ops_native_functions = ( 949 custom_ops_parsed_yaml.native_functions if custom_ops_parsed_yaml else [] 950 ) 951 952 cpu_fm = make_file_manager(options=options) 953 954 if "headers" in options.generate: 955 # generate CustomOpsNativeFunctions.h when custom_ops.yaml is present, to match the build system. 956 gen_headers( 957 native_functions=native_functions, 958 gen_custom_ops_header=options.custom_ops_yaml_path, 959 custom_ops_native_functions=custom_ops_native_functions, 960 selector=selector, 961 kernel_index=kernel_index, 962 cpu_fm=cpu_fm, 963 use_aten_lib=options.use_aten_lib, 964 ) 965 966 if "sources" in options.generate: 967 gen_unboxing( 968 native_functions=native_functions, 969 cpu_fm=cpu_fm, 970 selector=selector, 971 use_aten_lib=options.use_aten_lib, 972 kernel_index=kernel_index, 973 manual_registration=options.manual_registration, 974 ) 975 if custom_ops_native_functions: 976 gen_custom_ops( 977 native_functions=custom_ops_native_functions, 978 selector=selector, 979 kernel_index=kernel_index, 980 cpu_fm=cpu_fm, 981 rocm=options.rocm, 982 ) 983 984 if options.output_dependencies: 985 depfile_path = Path(options.output_dependencies).resolve() 986 depfile_name = depfile_path.name 987 depfile_stem = depfile_path.stem 988 989 for fm, prefix in [ 990 (cpu_fm, ""), 991 ]: 992 varname = prefix + depfile_stem 993 path = depfile_path.parent / (prefix + depfile_name) 994 fm.write_outputs(varname, str(path)) 995 996 997if __name__ == "__main__": 998 main() 999