xref: /aosp_15_r20/external/pytorch/torchgen/gen_executorch.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3import argparse
4import os
5from collections import defaultdict
6from dataclasses import dataclass
7from pathlib import Path
8from typing import Any, Callable, Sequence, TextIO, TYPE_CHECKING
9
10import yaml
11
12# Parse native_functions.yaml into a sequence of NativeFunctions and Backend Indices.
13from torchgen import dest
14from torchgen.api import cpp as aten_cpp
15from torchgen.api.types import CppSignature, CppSignatureGroup, CType, NamedCType
16from torchgen.context import (
17    method_with_native_function,
18    method_with_nested_native_function,
19    with_native_function_and_index,
20)
21from torchgen.executorch.api import et_cpp
22from torchgen.executorch.api.custom_ops import (
23    ComputeNativeFunctionStub,
24    gen_custom_ops_registration,
25)
26from torchgen.executorch.api.types import contextArg, ExecutorchCppSignature
27from torchgen.executorch.api.unboxing import Unboxing
28from torchgen.executorch.model import ETKernelIndex, ETKernelKey, ETParsedYaml
29from torchgen.executorch.parse import ET_FIELDS, parse_et_yaml, parse_et_yaml_struct
30from torchgen.gen import (
31    get_custom_build_selector,
32    get_native_function_declarations,
33    get_native_function_declarations_from_ns_grouped_kernels,
34    get_native_function_schema_registrations,
35    LineLoader,
36    parse_native_yaml,
37)
38from torchgen.model import (
39    BackendIndex,
40    BackendMetadata,
41    DEFAULT_KERNEL_NAMESPACE,
42    DispatchKey,
43    FunctionSchema,
44    Location,
45    NativeFunction,
46    NativeFunctionsGroup,
47    OperatorName,
48    Variant,
49)
50from torchgen.utils import (
51    context,
52    FileManager,
53    make_file_manager,
54    mapMaybe,
55    NamespaceHelper,
56)
57
58
59if TYPE_CHECKING:
60    from torchgen.selective_build.selector import SelectiveBuilder
61
62
def _sig_decl_wrapper(sig: CppSignature | ExecutorchCppSignature) -> str:
    """Return a C++ declaration string for `sig`, with the ET context arg.

    ExecutorchCppSignature already emits its context argument from `decl()`,
    so it is returned unchanged. The ATen CppSignature knows nothing about
    the ET `contextArg`, so its declaration is rebuilt here with the context
    parameter prepended.
    """
    if isinstance(sig, ExecutorchCppSignature):
        return sig.decl()

    ret_type = aten_cpp.returns_type(sig.func.returns).cpp_type()
    all_args = ", ".join([contextArg.decl(), *(arg.decl() for arg in sig.arguments())])
    return f"{ret_type} {sig.name()}({all_args})"
77
78
def static_dispatch(
    sig: CppSignature | ExecutorchCppSignature,
    f: NativeFunction,
    backend_indices: list[BackendIndex],
) -> str:
    """
    For a given `NativeFunction`, find out the corresponding native function and dispatch to it. If zero or more than one
    native function exists, error out. A simplified version of register_dispatch_key.py
    Arguments:
        sig: A CppSignature for this native function we want to use.
        f: NativeFunction to generate static dispatch.
        backend_indices: All available backends.
    Return:
        C++ code to call backend-specific functions, e.g., "return at::native::add(self, other, scale);"
    """
    # Nothing to emit when no backends exist or registration is hand-written.
    if len(backend_indices) == 0 or f.manual_kernel_registration:
        return ""

    backends = [b for b in backend_indices if b.has_kernel(f)]
    static_block = None
    if len(backends) == 1:
        backend_metadata = backends[0].get_kernel(f)
        if backend_metadata:
            args = ", ".join(a.name for a in sig.arguments())
            # Here we are assuming there's no difference between CppSignature and NativeSignature for Executorch.
            static_block = f"return ::{backend_metadata.cpp_namespace}::{backend_metadata.kernel}({args});"
    if static_block is None:
        # Covers zero backends, multiple backends, AND the single-backend case
        # where get_kernel() returned no metadata. Previously that last case
        # left static_block as None and leaked the literal string "None" into
        # the generated C++; now it emits the runtime assertion instead.
        static_block = f"""
ET_ASSERT_UNREACHABLE_MSG("The number of native function(s) binding to {f.func.name} is {len(backends)}.");
    """
    return f"""
// {f.namespace}::{f.func}
TORCH_API inline {_sig_decl_wrapper(sig)} {{
    {static_block}
}}
"""
115
116
# Generates Functions.h, which provides the functional public C++ API,
# and the scaffolding to call into the dispatcher from these functions.
@dataclass(frozen=True)
class ComputeFunction:
    """Functor rendering the public C++ inline wrapper for one NativeFunction."""

    # Backends consulted when statically dispatching (non-ATen / custom-op path).
    static_dispatch_backend_indices: list[BackendIndex]

    # Selective-build filter; only root operators get an API entry.
    selector: SelectiveBuilder

    # True: generate against PyTorch (at::) types; False: Executorch types.
    use_aten_lib: bool

    # Predicate identifying custom (non-ATen) operators.
    is_custom_op: Callable[[NativeFunction], bool]

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        """Return the generated C++ wrapper for `f`, or None if not selected."""
        is_method_variant = False
        # Skip operators not selected for this build.
        if not self.selector.is_root_operator(f"{f.namespace}::{f.func.name}"):
            return None

        if Variant.function not in f.variants and Variant.method in f.variants:
            is_method_variant = True

        # only valid remaining case is only function is in f.variants
        elif not (Variant.function in f.variants and Variant.method not in f.variants):
            raise Exception(  # noqa: TRY002
                f"Can't handle native function {f.func} with the following variant specification {f.variants}."
            )

        # ATen mode uses the most faithful CppSignature; ET mode uses the
        # Executorch-native signature (which carries the context argument).
        sig: CppSignature | ExecutorchCppSignature = (
            CppSignatureGroup.from_native_function(
                f, method=False, fallback_binding=f.manual_cpp_binding
            ).most_faithful_signature()
            if self.use_aten_lib
            else ExecutorchCppSignature.from_native_function(f)
        )
        if self.use_aten_lib and not self.is_custom_op(f):
            comma = ", "

            if is_method_variant:
                # Method-only variant: invoke through the first argument
                # (the receiver), forwarding the remaining arguments.
                return f"""
// {f.namespace}::{f.func}
TORCH_API inline {_sig_decl_wrapper(sig)} {{
    return {sig.arguments()[0].name}.{sig.name()}({comma.join(e.name for e in sig.arguments()[1:])});
}}
"""
            else:
                # Function variant: forward directly into the at:: namespace.
                return f"""
// {f.namespace}::{f.func}
TORCH_API inline {_sig_decl_wrapper(sig)} {{
    return at::{sig.name()}({comma.join(e.name for e in sig.arguments())});
}}
"""

        else:
            # Custom ops and non-ATen mode: statically dispatch to the kernel
            # registered in the backend indices.
            return static_dispatch(
                sig,
                f,
                backend_indices=self.static_dispatch_backend_indices,
            )
175
176
# Generates RegisterCodegenUnboxedKernels.cpp.
@dataclass(frozen=True)
class ComputeCodegenUnboxedKernels:
    """Functor rendering the `Kernel(...)` registration entry for one
    (NativeFunction, (kernel key, backend metadata)) pair."""

    # Selective-build filter for operators and kernel keys.
    selector: SelectiveBuilder

    # True: generate against PyTorch (at::) types; False: Executorch types.
    use_aten_lib: bool

    @method_with_nested_native_function
    def __call__(
        self,
        unbox_kernel_entry: tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]],
    ) -> str:
        """Return C++ `Kernel(...)` entries for the selected kernel keys,
        or "" when the operator/kernels are not selected."""
        f: NativeFunction = unbox_kernel_entry[0]
        kernel_key: ETKernelKey | list[ETKernelKey] = unbox_kernel_entry[1][0]
        kernel_meta: BackendMetadata = unbox_kernel_entry[1][1]

        op_name = f"{f.namespace}::{f.func.name}"
        if not self.selector.is_root_operator(op_name):
            return ""

        # Normalize to a list, then keep only the kernel keys the selective
        # build actually asked for.
        if not isinstance(kernel_key, list):
            kernel_key = [kernel_key]
        used_kernel_keys = self.selector.et_get_selected_kernels(
            op_name, [k.to_native_string() for k in kernel_key]
        )
        if not used_kernel_keys:
            return ""
        sig: CppSignature | ExecutorchCppSignature
        argument_type_gen: Callable[..., NamedCType]
        return_type_gen: Callable[..., CType]
        # Pick signature/type generators and the target kernel symbol per mode.
        if self.use_aten_lib:
            sig = CppSignatureGroup.from_native_function(
                f, method=False, fallback_binding=f.manual_cpp_binding
            ).most_faithful_signature()
            argument_type_gen = aten_cpp.argumenttype_type
            return_type_gen = aten_cpp.returns_type
            arguments = sig.arguments()
            kernel_call = f"torch::executor::{f.namespace}::{sig.name()}"
        else:
            sig = ExecutorchCppSignature.from_native_function(f)
            argument_type_gen = et_cpp.argumenttype_type
            return_type_gen = et_cpp.returns_type
            arguments = sig.arguments(include_context=False)
            kernel_call = f"{kernel_meta.cpp_namespace}::{kernel_meta.kernel}"
        # parse arguments into C++ code
        binding_list, code_list = Unboxing(
            argument_type_gen=argument_type_gen
        ).convert_arguments(arguments)

        # for each C++ argument, generate the conversion code
        code_connector = "\n\t"
        arg_connector = ", "

        args_str = f"{arg_connector.join(e.name for e in binding_list)}"
        event_tracer_output_logging = ""
        output_ids = []

        # Decide how the kernel's result is written back onto the EValue stack:
        # (1) void return with an out argument, (2) plain return value, or
        # (3) out-variant where outputs already live on the stack.
        if len(f.func.returns) == 0:
            if len(f.func.arguments.out) == 0:
                raise Exception(  # noqa: TRY002
                    f"Can't handle native function {f.func} with no returns and no out yet."
                )
            out = f.func.arguments.out[0]
            return_assignment = f"""stack[{len(binding_list)}] = &{out.name};"""
            ret_prefix = ""
            output_ids = [len(binding_list)]
        else:
            if len(f.func.arguments.out) == 0:
                return_assignment = (
                    f"""*stack[{len(binding_list)}] = EValue(result_);"""
                )
                ret_prefix = return_type_gen(f.func.returns).cpp_type() + " result_ = "
                output_ids = [len(binding_list)]
            else:
                return_assignment = ""
                ret_prefix = ""
                # Out arguments occupy the last len(out) stack slots.
                output_ids = [
                    len(binding_list) - (i + 1)
                    for i in reversed(range(len(f.func.arguments.out)))
                ]

        # Emit one event-tracer logging call per output stack slot.
        for output_id in output_ids:
            event_tracer_output_logging += (
                f"internal::event_tracer_log_evalue("
                f"context.internal_event_tracer(), "
                f"*stack[{output_id}]);\n"
            )

        # One Kernel(...) entry per selected kernel key; the "default" key is
        # registered without an explicit key string.
        newline = "\n    "
        return "\n".join(
            [
                f"""
Kernel(
    "{f.namespace}::{f.func.name}",{newline + '"' + (k + '",') if k != 'default' else ''}
    []({contextArg.defn()}, EValue** stack) {{
        {code_connector.join(code_list)}

        internal::EventTracerProfileScope event_tracer_scope(context.internal_event_tracer(), "native_call_{f.func.name}");
        EXECUTORCH_SCOPE_PROF("native_call_{f.func.name}");
        {ret_prefix}{kernel_call}(context, {args_str});
        {event_tracer_output_logging}
        {return_assignment}
    }}
),
"""
                for k in used_kernel_keys
            ]
        )
285
286
def gen_unboxing(
    *,
    native_functions: Sequence[NativeFunction],
    cpu_fm: FileManager,
    selector: SelectiveBuilder,
    use_aten_lib: bool,
    kernel_index: ETKernelIndex,
    manual_registration: bool,
) -> None:
    """Write the unboxed-kernel registration source file.

    Emits RegisterKernels.cpp (manual registration) or
    RegisterCodegenUnboxedKernels.cpp via the sharded file writer; each item
    is one (native_function, (kernel_key, metadata)) pair.
    """

    # Shard key for write_sharded: "<root_name>:<kernel key string>".
    def shard_key(
        entry: tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]]
    ) -> str:
        fn, (kernel_key, _meta) = entry
        return f"{fn.root_name}:{kernel_key.to_native_string()}"

    # Flatten the kernel index into one entry per (function, kernel) pair.
    items: list[tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]]] = []
    for fn in native_functions:
        for kernel_key, metadata in kernel_index.get_kernels(fn).items():
            items.append((fn, (kernel_key, metadata)))

    header = ["Functions.h"] if use_aten_lib else ["NativeFunctions.h"]
    if manual_registration:
        filename = "RegisterKernels.cpp"
    else:
        filename = "RegisterCodegenUnboxedKernels.cpp"

    cpu_fm.write_sharded(
        filename,
        items,
        key_fn=shard_key,
        env_callable=lambda unbox_kernel_entry: {
            "unboxed_kernels": [
                ComputeCodegenUnboxedKernels(selector, use_aten_lib)(unbox_kernel_entry)
            ],
            # Emit the header include only for the very first entry.
            "fn_header": header if unbox_kernel_entry == items[0] else [],
        },
        num_shards=1,
        sharded_keys={"unboxed_kernels", "fn_header"},
    )
329
330
@with_native_function_and_index  # type: ignore[arg-type]
def compute_native_function_declaration(
    g: NativeFunctionsGroup | NativeFunction, kernel_index: ETKernelIndex
) -> list[str]:
    """Return C++ declarations for every kernel registered for `g`.

    For each kernel's metadata, two declarations are emitted: one taking the
    runtime context and one without it.

    Args:
        g: the native function (grouped functions are not supported here).
        kernel_index: kernel metadata lookup for the operator.
    """
    assert isinstance(g, NativeFunction)
    sig = ExecutorchCppSignature.from_native_function(f=g)
    metadata_list = kernel_index.get_kernels(g).values()
    # dict.values() is never None, so the previous `is None` check was dead;
    # the meaningful guard is "no kernels registered at all".
    if not metadata_list:
        return []

    # for kernels in lean mode, we declare two versions, one with context and one without.
    # In the end we will cleanup the unused one.
    def gen_decl(metadata: BackendMetadata, include_context: bool) -> str:
        return f"{sig.decl(name=metadata.kernel, include_context=include_context)};"

    return [
        gen_decl(metadata, include_context)
        for include_context in [False, True]
        for metadata in metadata_list
    ]
351
352
def gen_functions_declarations(
    *,
    native_functions: Sequence[NativeFunction],
    kernel_index: ETKernelIndex,
    selector: SelectiveBuilder,
    use_aten_lib: bool,
    custom_ops_native_functions: Sequence[NativeFunction] | None = None,
) -> str:
    """
    Generates namespace separated C++ function API inline declaration/definitions.
    Native functions are grouped by namespaces and the generated code is wrapped inside
    namespace blocks.

    E.g., for `custom_1::foo.out` in yaml file we will generate a C++ API as a symbol
    in `torch::executor::custom_1::foo_out`. This way we avoid symbol conflict when
    the other `custom_2::foo.out` is available.
    """

    # convert kernel index to BackendIndex. This is because we can't handle ETKernelIndex yet.
    # TODO larryliu: evaluate if this code is still needed. If yes let it handle ETKernelIndex.

    backend_index = kernel_index._to_backend_index()

    # Bucket functions by namespace so each bucket can be wrapped in its own
    # C++ namespace block.
    ns_grouped_functions = defaultdict(list)
    for native_function in native_functions:
        ns_grouped_functions[native_function.namespace].append(native_function)
    functions_declarations = ""
    newline = "\n"
    for namespace in ns_grouped_functions:
        ns_helper = NamespaceHelper(
            namespace_str=namespace,
            entity_name="",
            max_level=3,
        )
        # ComputeFunction may return None for unselected ops; mapMaybe drops those.
        declarations = list(
            mapMaybe(
                ComputeFunction(
                    static_dispatch_backend_indices=[backend_index],
                    selector=selector,
                    use_aten_lib=use_aten_lib,
                    # Closing over the parameter is safe: it is fixed for the
                    # lifetime of this call.
                    is_custom_op=lambda f: custom_ops_native_functions is not None
                    and f in custom_ops_native_functions,
                ),
                ns_grouped_functions[namespace],
            )
        )
        functions_declarations += f"""
{ns_helper.prologue}
{newline.join(declarations)}
{ns_helper.epilogue}
        """
    return functions_declarations
405
406
def get_ns_grouped_kernels(
    *,
    native_functions: Sequence[NativeFunction],
    kernel_index: ETKernelIndex,
    native_function_decl_gen: Callable[
        [
            NativeFunctionsGroup | NativeFunction,
            ETKernelIndex,
        ],
        list[str],
    ],
) -> dict[str, list[str]]:
    """Group generated kernel declarations by their C++ namespace.

    Asserts that all kernels of a single operator share one namespace;
    kernels with no metadata fall back to DEFAULT_KERNEL_NAMESPACE.
    """
    grouped: dict[str, list[str]] = defaultdict(list)
    for fn in native_functions:
        seen_namespaces: set[str] = set()
        for metadata in kernel_index.get_kernels(fn).values():
            if metadata:
                namespace = metadata.cpp_namespace
                seen_namespaces.add(namespace)
            else:
                namespace = DEFAULT_KERNEL_NAMESPACE
            assert (
                len(seen_namespaces) <= 1
            ), f"Codegen only supports one namespace per operator, got {seen_namespaces}"
            # NOTE: declarations are appended once per kernel entry, matching
            # the original per-metadata extension behavior.
            grouped[namespace].extend(native_function_decl_gen(fn, kernel_index))
    return grouped
436
437
def gen_headers(
    *,
    native_functions: Sequence[NativeFunction],
    gen_custom_ops_header: bool,
    custom_ops_native_functions: Sequence[NativeFunction],
    selector: SelectiveBuilder,
    kernel_index: ETKernelIndex,
    cpu_fm: FileManager,
    use_aten_lib: bool,
) -> None:
    """Generate headers.

    Args:
        native_functions (Sequence[NativeFunction]): a collection of NativeFunction for ATen ops.
        gen_custom_ops_header (bool): whether we should generate CustomOpsNativeFunctions.h
        custom_ops_native_functions (Sequence[NativeFunction]): a collection of NativeFunction for custom ops.
        selector (SelectiveBuilder): selective-build filter forwarded to declaration generation.
        kernel_index (ETKernelIndex): kernel collection
        cpu_fm (FileManager): file manager manages output stream
        use_aten_lib (bool): whether we are generating for PyTorch types or Executorch types.
    """
    aten_headers = ["#include <ATen/Functions.h>"]
    # The ATen declaration generator only understands BackendIndex, so flatten
    # the ET kernel index into a single CPU-keyed backend index.
    backend_indices = {DispatchKey.CPU: kernel_index._to_backend_index()}
    if gen_custom_ops_header:
        cpu_fm.write_with_template(
            "CustomOpsNativeFunctions.h",
            "NativeFunctions.h",
            lambda: {
                "nativeFunctions_declarations": get_native_function_declarations(
                    grouped_native_functions=custom_ops_native_functions,
                    backend_indices=backend_indices,
                    native_function_decl_gen=dest.compute_native_function_declaration,
                ),
                "headers": [
                    "#include <ATen/ATen.h>",
                    "#include <torch/torch.h>",
                ],
            },
        )
        aten_headers.append('#include "CustomOpsNativeFunctions.h"')
    # Functions.h: the public functional C++ API for all selected operators.
    cpu_fm.write(
        "Functions.h",
        lambda: {
            "static_dispatch_extra_headers": aten_headers
            if use_aten_lib
            else ['#include "NativeFunctions.h"'],
            "Functions_declarations": gen_functions_declarations(
                native_functions=native_functions,
                kernel_index=kernel_index,
                selector=selector,
                use_aten_lib=use_aten_lib,
                custom_ops_native_functions=custom_ops_native_functions,
            ),
        },
    )
    cpu_fm.write(
        "RegisterKernels.h",
        lambda: {
            # Split "@" + "generated" so THIS source file is not itself marked
            # as a generated file by tooling.
            "generated_comment": "@" + "generated by torchgen/gen_executorch.py",
        },
    )
    headers = {
        "headers": [
            "#include <executorch/runtime/core/exec_aten/exec_aten.h> // at::Tensor etc.",
            "#include <executorch/runtime/kernel/kernel_runtime_context.h>",
        ],
    }
    # NativeFunctions.h: kernel declarations, generated differently per mode.
    if use_aten_lib:
        headers["headers"].append("#include <executorch/codegen/macros.h> // TORCH_API")
        cpu_fm.write(
            "NativeFunctions.h",
            lambda: dict(
                {
                    "nativeFunctions_declarations": get_native_function_declarations(
                        grouped_native_functions=native_functions,
                        backend_indices=backend_indices,
                        native_function_decl_gen=dest.compute_native_function_declaration,
                    ),
                },
                **headers,
            ),
        )
    else:
        ns_grouped_kernels = get_ns_grouped_kernels(
            native_functions=native_functions,
            kernel_index=kernel_index,
            native_function_decl_gen=compute_native_function_declaration,  # type: ignore[arg-type]
        )
        cpu_fm.write(
            "NativeFunctions.h",
            lambda: dict(
                {
                    "nativeFunctions_declarations": get_native_function_declarations_from_ns_grouped_kernels(
                        ns_grouped_kernels=ns_grouped_kernels,
                    ),
                },
                **headers,
            ),
        )
536
537
def gen_custom_ops(
    *,
    native_functions: Sequence[NativeFunction],
    selector: SelectiveBuilder,
    kernel_index: ETKernelIndex,
    cpu_fm: FileManager,
    rocm: bool,
) -> None:
    """Generate C++ sources registering custom ops with the PyTorch dispatcher.

    Writes three files via `cpu_fm`:
      * RegisterCPUCustomOps.cpp - registers the real custom-op kernels.
      * RegisterCPUStub.cpp      - registers stub implementations.
      * RegisterSchema.cpp       - registers operator schemas.

    Args:
        native_functions: custom-op NativeFunctions to register.
        selector: selective-build filter.
        kernel_index: kernel metadata for the custom ops.
        cpu_fm: file manager used to emit the generated sources.
        rocm: forwarded to gen_custom_ops_registration
            (presumably toggles ROCm-specific codegen - confirm there).
    """
    dispatch_key = DispatchKey.CPU
    (
        anonymous_definition,
        static_init_dispatch_registrations,
    ) = gen_custom_ops_registration(
        native_functions=native_functions,
        selector=selector,
        kernel_index=kernel_index,
        rocm=rocm,
    )
    # Real kernel registrations, pulling declarations from the custom-ops header.
    cpu_fm.write_with_template(
        f"Register{dispatch_key}CustomOps.cpp",
        "RegisterDispatchKeyCustomOps.cpp",
        lambda: {
            "ops_headers": '#include "CustomOpsNativeFunctions.h"',
            "DispatchKey": dispatch_key,
            "dispatch_namespace": dispatch_key.lower(),
            "dispatch_namespaced_definitions": "",
            "dispatch_anonymous_definitions": anonymous_definition,
            "static_init_dispatch_registrations": static_init_dispatch_registrations,
        },
    )
    # Stub registrations generated from ComputeNativeFunctionStub (no headers).
    cpu_fm.write_with_template(
        f"Register{dispatch_key}Stub.cpp",
        "RegisterDispatchKeyCustomOps.cpp",
        lambda: {
            "ops_headers": "",
            "DispatchKey": dispatch_key,
            "dispatch_namespace": dispatch_key.lower(),
            "dispatch_namespaced_definitions": "",
            "dispatch_anonymous_definitions": list(
                mapMaybe(ComputeNativeFunctionStub(), native_functions)
            ),
            "static_init_dispatch_registrations": static_init_dispatch_registrations,
        },
    )

    # Schema registrations for the custom ops.
    (
        aten_schema_registrations,
        schema_registrations,
    ) = get_native_function_schema_registrations(
        native_functions=native_functions,
        schema_selector=selector,
    )
    cpu_fm.write(
        "RegisterSchema.cpp",
        lambda: {
            "schema_registrations": schema_registrations,
            "aten_schema_registrations": aten_schema_registrations,
        },
    )
597
598
def translate_native_yaml(
    tags_yaml_path: str,
    aten_yaml_path: str,
    native_yaml_path: str | None,
    use_aten_lib: bool,
    out_file: TextIO,
) -> None:
    """Translates Executorch DSL dialect to use the same syntax as
    native_functions.yaml. The major difference is that Executorch DSL dialect
    supports "op" key, where it refers to the operator name in native_functions.yaml.

    For example, a functions.yaml may have the following entry:

    - op: add.out
      ...

    It needs to be translated to the following:

    - func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
      ...

    We go in aten_yaml_path and find the operator schema for "add.out" and add it
    to the original functions.yaml. We also add required field "variants", where for
    Executorch it will always be "function".

    For ATen mode we don't have to do the translation because native_yaml_path is
    the same as native_functions.yaml.

    Args:
        tags_yaml_path: Path to a tags.yaml file to satisfy codegen parsing.
            It is not optional.
        aten_yaml_path: Path to ATen operator yaml file native_functions.yaml.
        native_yaml_path: Path to a functions.yaml file to parse.
            If the path does not exist in the filesystem, it is treated as an
            empty file. If `custom_ops_yaml_path` exists, the contents of that
            file are appended to the yaml input to be parsed.
        use_aten_lib: We use this flag to determine if we want to generate native
            functions. In ATen mode we should generate out= variants.
        out_file: The IO object that we are writing into.
    Returns:
        None
    """
    # ATen mode: functions.yaml is already in native_functions.yaml syntax,
    # so copy it through verbatim.
    if use_aten_lib:
        with open(aten_yaml_path) as aten_yaml:
            out_file.writelines(aten_yaml.readlines())
        return

    # Parse ATen's native_functions.yaml to build schema lookup tables and to
    # collect the ET-specific fields persisted per operator.
    native_functions, persisted_fields = parse_et_yaml(
        aten_yaml_path,
        tags_yaml_path,
        None,
        skip_native_fns_gen=False,
    )

    # FunctionSchema -> "namespace::opname" lookup.
    func_to_scoped_name: dict[FunctionSchema, str] = {
        f.func: f"{f.namespace}::{f.func.name}" for f in native_functions
    }
    # OperatorName -> "namespace::opname" lookup.
    op_to_scoped_name: dict[OperatorName, str] = {
        func.name: name for func, name in func_to_scoped_name.items()
    }

    # "namespace::opname" -> full schema string, used to fill in "func" keys.
    schema_dict = {name: str(func) for func, name in func_to_scoped_name.items()}
    # "namespace::opname" -> persisted ET fields to copy onto entries.
    kernel_persist_dict: dict[str, dict[str, Any]] = {
        op_to_scoped_name[op]: v for op, v in persisted_fields.items()
    }

    # Nothing to translate if functions.yaml is absent or empty.
    if (
        not native_yaml_path
        or not os.path.exists(native_yaml_path)
        or os.stat(native_yaml_path).st_size == 0
    ):
        return
    with open(native_yaml_path) as native_yaml:
        native_es = yaml.load(native_yaml, Loader=LineLoader)
        if not native_es:
            return
        for e in native_es:
            # LineLoader attaches "__line__" to every entry; use it for
            # error-context reporting.
            assert isinstance(e.get("__line__"), int), e
            loc = Location(native_yaml_path, e.pop("__line__"))
            with context(lambda: f"in {loc}:\n  "):
                # Executorch variants are always "function".
                if "variants" not in e:
                    e["variants"] = "function"
                # Entries that already carry a full "func" schema need no translation.
                if "func" in e:
                    continue
                assert isinstance(e.get("op"), str), e
                opname = e.pop("op")
                # Unqualified op names default to the aten namespace.
                if "::" not in opname:
                    opname = "aten::" + opname
                assert opname in schema_dict
                e["func"] = schema_dict.get(opname)

                # Write out persisted kernel information
                if opname in kernel_persist_dict:
                    for k, v in kernel_persist_dict[opname].items():
                        e[k] = v

        yaml.dump(native_es, out_file, width=1000)
696
697
def parse_yaml(
    path: str | None,
    tags_yaml_path: str,
    function_filter: Callable[[NativeFunction], bool],
    skip_native_fns_gen: bool = False,
) -> tuple[
    list[NativeFunction],
    dict[DispatchKey, dict[OperatorName, BackendMetadata]] | ETKernelIndex,
]:
    """Parse a (translated) operator yaml file.

    Returns the filtered NativeFunctions plus either an ETKernelIndex (when
    the yaml carries "kernels" entries) or per-dispatch-key backend indices.
    A missing/empty file yields ([], {}).
    """
    # Missing or empty file: nothing to parse.
    if not (path and os.path.exists(path) and os.stat(path).st_size > 0):
        return [], {}

    with open(path) as f:
        es = yaml.load(f, Loader=LineLoader)

    # Detect the ET kernel-index structure before stripping ET fields.
    has_kernels = any("kernels" in entry for entry in es)
    kernel_index = parse_et_yaml_struct(es) if has_kernels else None

    # Remove ET specific fields from entries for BC compatibility
    for entry in es:
        for field in ET_FIELDS:
            entry.pop(field, None)

    parsed_yaml = parse_native_yaml(
        path,
        tags_yaml_path,
        None,
        skip_native_fns_gen=skip_native_fns_gen,
        loaded_yaml=es,
    )
    native_functions = [fn for fn in parsed_yaml.native_functions if function_filter(fn)]
    op_names = {fn.func.name for fn in native_functions}

    # (1) A kernel index was present: filter it down to the selected ops.
    if kernel_index is not None:
        filtered = {
            op: kernel_mapping
            for op, kernel_mapping in kernel_index.index.items()
            if op in op_names
        }
        return native_functions, ETKernelIndex(index=filtered)

    # (2) Otherwise fall back to filtered per-dispatch-key BackendIndices.
    def _filter_ops(
        m: dict[OperatorName, BackendMetadata]
    ) -> dict[OperatorName, BackendMetadata]:
        return {op: meta for op, meta in m.items() if op in op_names}

    return native_functions, {
        key: _filter_ops(backend.index)
        for key, backend in parsed_yaml.backend_indices.items()
    }
753
754
def parse_yaml_files(
    tags_yaml_path: str,
    aten_yaml_path: str,
    native_yaml_path: str | None,
    custom_ops_yaml_path: str | None,
    selector: SelectiveBuilder,
    use_aten_lib: bool,
) -> tuple[ETParsedYaml, ETParsedYaml | None]:
    """Parses functions.yaml and custom_ops.yaml files.

    Args:
        tags_yaml_path: Path to a tags.yaml file to satisfy codegen parsing.
            It is not optional.
        aten_yaml_path: Path to ATen operator yaml file native_functions.yaml.
        native_yaml_path: Path to a functions.yaml file to parse.
            If the path does not exist in the filesystem, it is treated as an
            empty file. If `custom_ops_yaml_path` exists, the contents of that
            file are appended to the yaml input to be parsed.
        custom_ops_yaml_path: Path to a custom_ops.yaml file to parse. If
            the path does not exist in the filesystem, it is ignored.
        selector: For selective build.
        use_aten_lib: We use this flag to determine if we want to generate native
            functions. In ATen mode we should generate out= variants.
    Returns:
        A tuple with two elements:
        [0]: The parsed results of concatenating the contents of
             `native_yaml_path` and `custom_ops_yaml_path`.
        [1]: The parsed results of the contents of `custom_ops_yaml_path`, if
             present. If not present, None.
    """
    import tempfile

    # Only include ops selected for this build, to avoid parsing/generating
    # for operators the final binary will never use.
    def function_filter(f: NativeFunction) -> bool:
        return selector.is_native_function_selected(f)

    with tempfile.TemporaryDirectory() as tmpdirname:
        # First translate the ET DSL ("op:" entries) into native_functions.yaml
        # syntax ("func:" entries) in a temp file, then parse that.
        translated_yaml_path = os.path.join(tmpdirname, "translated.yaml")
        with open(translated_yaml_path, "w") as translated:
            translate_native_yaml(
                tags_yaml_path,
                aten_yaml_path,
                native_yaml_path,
                use_aten_lib,
                translated,
            )

        translated_functions, translated_indices = parse_yaml(
            translated_yaml_path, tags_yaml_path, function_filter, not use_aten_lib
        )
        custom_ops_functions, custom_ops_indices = parse_yaml(
            custom_ops_yaml_path, tags_yaml_path, function_filter, True
        )

        # Convert BackendIndices to ETKernelIndex
        if not isinstance(translated_indices, ETKernelIndex):
            translated_indices = ETKernelIndex.from_backend_indices(translated_indices)
        if not isinstance(custom_ops_indices, ETKernelIndex):
            custom_ops_indices = ETKernelIndex.from_backend_indices(custom_ops_indices)

        # Merge the translated functions/kernels with the custom-ops ones.
        combined_functions = translated_functions + custom_ops_functions
        combined_kernel_index = ETKernelIndex.merge_indices(
            translated_indices, custom_ops_indices
        )
        combined_yaml = ETParsedYaml(combined_functions, combined_kernel_index)
        custom_ops_parsed_yaml = ETParsedYaml(custom_ops_functions, custom_ops_indices)

    return combined_yaml, custom_ops_parsed_yaml
823
824
def main() -> None:
    """Command-line entry point for the ExecuTorch operator code generator.

    Parses the operator/kernel YAML inputs (native_functions.yaml plus the
    optional functions.yaml / custom_ops.yaml), applies the selective-build
    operator selector, then emits the requested generated headers and/or
    sources through a FileManager rooted at --install-dir.
    """
    parser = argparse.ArgumentParser(description="Generate operator source files")
    # Although we don't refer to --source-path directly, make_file_manager()
    # expects it to point to a directory that contains a templates/ subdirectory
    # containing the file templates.
    parser.add_argument(
        "-s",
        "--source-path",
        help="path to source directory for kernel templates",
    )
    parser.add_argument(
        "--functions-yaml-path",
        "--functions_yaml_path",
        help="path to the functions.yaml file to use. Optional, but at least "
        "one of --functions-yaml-path and --custom-ops-yaml-path must be "
        "specified.",
    )
    parser.add_argument(
        "--custom-ops-yaml-path",
        "--custom_ops_yaml_path",
        help="path to the custom_ops.yaml file to use. Optional, but at least "
        "one of --functions-yaml-path and --custom-ops-yaml-path must be "
        "specified.",
    )
    parser.add_argument(
        "--aten-yaml-path",
        "--aten_yaml_path",
        help="path to native_functions.yaml file.",
    )
    # Note that make_file_manager() also looks at --install-dir.
    parser.add_argument(
        "-d",
        "--install-dir",
        "--install_dir",
        help="output directory",
        default="build/generated",
    )
    parser.add_argument(
        "-o",
        "--output-dependencies",
        help="output a list of dependencies into the given file and exit",
    )
    # Although we don't refer to --dry-run directly, make_file_manager() looks
    # for it.
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="run without writing any files (still updates outputs)",
    )
    parser.add_argument(
        "--static-dispatch-backend",
        "--static_dispatch_backend",
        nargs="*",
        help="generate static dispatch code for the specific backend (if set)",
    )
    parser.add_argument(
        "--op-registration-whitelist",
        "--op_registration_whitelist",
        nargs="*",
        help="filter op registrations by the whitelist (if set); "
        "each item is `namespace`::`operator name` without overload name; "
        "e.g.: aten::empty aten::conv2d ...",
    )
    parser.add_argument(
        "--op-selection-yaml-path",
        "--op_selection_yaml_path",
        help="Provide a path to the operator selection (for custom build) YAML "
        "that contains the information about the set of selected operators "
        "and their categories (training, ...). Each operator is either a "
        "full operator name with overload or just a bare operator name. "
        "The operator names also contain the namespace prefix (e.g. aten::)",
    )
    parser.add_argument(
        "--tags-path",
        help="Path to tags.yaml. Required by yaml parsing in codegen system.",
    )
    parser.add_argument(
        "--rocm",
        action="store_true",
        help="reinterpret CUDA as ROCm/HIP and adjust filepaths accordingly",
    )
    parser.add_argument(
        "--use-aten-lib",
        "--use_aten_lib",
        action="store_true",
        help="a boolean flag to indicate whether we use ATen kernels or not, in the future this flag will be per "
        "operator",
    )
    # Dashed spelling listed first for consistency with the other flags above;
    # argparse derives the same dest (manual_registration) either way.
    parser.add_argument(
        "--manual-registration",
        "--manual_registration",
        action="store_true",
        help="a boolean flag to indicate whether we want to manually call"
        "register_kernels() or rely on static init. ",
    )
    parser.add_argument(
        "--generate",
        type=str,
        nargs="*",
        choices=["headers", "sources"],
        default=["headers", "sources"],
        help="Generate only a subset of files",
    )
    options = parser.parse_args()
    # Validate via parser.error() rather than `assert`: asserts are stripped
    # when Python runs with -O, which would silently skip this check.
    if not options.tags_path:
        parser.error("tags.yaml is required by codegen yaml parsing.")

    # Build the selective-build operator filter from the whitelist and/or the
    # selection YAML (both optional).
    selector = get_custom_build_selector(
        options.op_registration_whitelist,
        options.op_selection_yaml_path,
    )

    parsed_yaml, custom_ops_parsed_yaml = parse_yaml_files(
        aten_yaml_path=options.aten_yaml_path,
        tags_yaml_path=options.tags_path,
        native_yaml_path=options.functions_yaml_path,
        custom_ops_yaml_path=options.custom_ops_yaml_path,
        selector=selector,
        use_aten_lib=options.use_aten_lib,
    )
    native_functions, kernel_index = (
        parsed_yaml.native_functions,
        parsed_yaml.kernel_index,
    )
    custom_ops_native_functions = (
        custom_ops_parsed_yaml.native_functions if custom_ops_parsed_yaml else []
    )

    cpu_fm = make_file_manager(options=options)

    if "headers" in options.generate:
        # generate CustomOpsNativeFunctions.h when custom_ops.yaml is present, to match the build system.
        gen_headers(
            native_functions=native_functions,
            gen_custom_ops_header=options.custom_ops_yaml_path,
            custom_ops_native_functions=custom_ops_native_functions,
            selector=selector,
            kernel_index=kernel_index,
            cpu_fm=cpu_fm,
            use_aten_lib=options.use_aten_lib,
        )

    if "sources" in options.generate:
        gen_unboxing(
            native_functions=native_functions,
            cpu_fm=cpu_fm,
            selector=selector,
            use_aten_lib=options.use_aten_lib,
            kernel_index=kernel_index,
            manual_registration=options.manual_registration,
        )
        if custom_ops_native_functions:
            gen_custom_ops(
                native_functions=custom_ops_native_functions,
                selector=selector,
                kernel_index=kernel_index,
                cpu_fm=cpu_fm,
                rocm=options.rocm,
            )

    if options.output_dependencies:
        # Emit make-style dependency lists for everything the FileManager
        # produced. The single-element loop mirrors sibling torchgen
        # generators that iterate over multiple file managers.
        depfile_path = Path(options.output_dependencies).resolve()
        depfile_name = depfile_path.name
        depfile_stem = depfile_path.stem

        for fm, prefix in [
            (cpu_fm, ""),
        ]:
            varname = prefix + depfile_stem
            path = depfile_path.parent / (prefix + depfile_name)
            fm.write_outputs(varname, str(path))
995
996
# Entry point when executed as a script (e.g. invoked by the build system).
if __name__ == "__main__":
    main()
999