from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass
from typing import Sequence, TYPE_CHECKING

from torchgen import dest


# disable import sorting to avoid circular dependency.
from torchgen.api.types import DispatcherSignature  # usort: skip
from torchgen.context import method_with_native_function
from torchgen.model import BaseTy, BaseType, DispatchKey, NativeFunction, Variant
from torchgen.utils import concatMap, Target


if TYPE_CHECKING:
    from torchgen.executorch.model import ETKernelIndex
    from torchgen.selective_build.selector import SelectiveBuilder


# Generates RegisterKernelStub.cpp, which provides placeholder kernels for custom
# operators. These stubs are used on the model authoring side.
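#
# For an operator like `custom::add_3(Tensor a, Tensor b, Tensor c) -> Tensor`,
# the emitted stub roughly looks like the following (illustration only; the exact
# signature comes from DispatcherSignature):
#
#   at::Tensor wrapper_CPU__add_3(const at::Tensor & a, const at::Tensor & b, const at::Tensor & c) {
#       return a;
#   }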
@dataclass(frozen=True)
class ComputeNativeFunctionStub:
    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        if Variant.function not in f.variants:
            return None

        sig = DispatcherSignature.from_schema(
            f.func, prefix=f"wrapper_CPU_{f.func.name.overload_name}_", symint=False
        )
        assert sig is not None
        if len(f.func.returns) == 0:
            ret_name = ""
        elif len(f.func.returns) == 1:
            if f.func.arguments.out:
                ret_name = f.func.arguments.out[0].name
            else:
                ret_name = next(
                    (
                        a.name
                        for a in f.func.arguments.flat_non_out
                        if a.type == f.func.returns[0].type
                    ),
                    "",
                )
            if not ret_name:
                # No matching input argument was found; if the return type is
                # Tensor, fall back to returning an empty tensor.
                if f.func.returns[0].type == BaseType(BaseTy.Tensor):
                    ret_name = "at::Tensor()"
                else:
                    raise Exception(  # noqa: TRY002
                        f"Can't handle this return type {f.func}"
                    )
        elif len(f.func.arguments.out) == len(f.func.returns):
            # Returns a tuple of out arguments
            tensor_type = "at::Tensor &"
            comma = ", "
            ret_name = f"""::std::tuple<{comma.join([tensor_type] * len(f.func.returns))}>(
                {comma.join([r.name for r in f.func.arguments.out])}
            )"""
        else:
            assert all(
                a.type == BaseType(BaseTy.Tensor) for a in f.func.returns
            ), f"Only support tensor returns but got {f.func.returns}"
            # Returns a tuple of empty tensors
            tensor_type = "at::Tensor"
            comma = ", "
            ret_name = f"""::std::tuple<{comma.join([tensor_type] * len(f.func.returns))}>(
                {comma.join(["at::Tensor()" for _ in f.func.returns])}
            )"""
        ret_str = f"return {ret_name};" if len(f.func.returns) > 0 else ""
        return f"""
{sig.defn()} {{
    {ret_str}
}}
    """

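# A minimal usage sketch (illustration only): render stubs for a list of
# NativeFunctions. `load_custom_ops` is a hypothetical helper standing in for
# torchgen's YAML parsing of the custom ops definitions.
#
#     stub_gen = ComputeNativeFunctionStub()
#     for f in load_custom_ops("custom_ops.yaml"):  # hypothetical helper
#         stub = stub_gen(f)
#         if stub is not None:
#             print(stub)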

def gen_custom_ops_registration(
    *,
    native_functions: Sequence[NativeFunction],
    selector: SelectiveBuilder,
    kernel_index: ETKernelIndex,
    rocm: bool,
) -> tuple[str, str]:
90    """
91    Generate custom ops registration code for dest.RegisterDispatchKey.
92
93    :param native_functions: a sequence of `NativeFunction`
94    :param selector: for selective build.
95    :param kernel_index: kernels for all the ops.
96    :param rocm: bool for dest.RegisterDispatchKey.
97    :return: generated C++ code to register custom operators into PyTorch
98    """

    # Convert the kernel index to a BackendIndex, because we can't handle ETKernelIndex yet.
    # TODO(larryliu): evaluate whether this code is still needed. If yes, let it handle ETKernelIndex.

    dispatch_key = DispatchKey.CPU
    backend_index = kernel_index._to_backend_index()
    static_init_dispatch_registrations = ""
    ns_grouped_native_functions: dict[str, list[NativeFunction]] = defaultdict(list)
    for native_function in native_functions:
        ns_grouped_native_functions[native_function.namespace].append(native_function)

    for namespace, functions in ns_grouped_native_functions.items():
        if len(functions) == 0:
            continue
        dispatch_registrations_body = "\n".join(
            list(
                concatMap(
                    dest.RegisterDispatchKey(
                        backend_index,
                        Target.REGISTRATION,
                        selector,
                        rocm=rocm,
                        symint=False,
                        class_method_name=None,
                        skip_dispatcher_op_registration=False,
                    ),
                    functions,
                )
            )
        )
        static_init_dispatch_registrations += f"""
TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
{dispatch_registrations_body}
}};"""
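    # Illustration only: each block accumulated into static_init_dispatch_registrations
    # above roughly has the shape
    #
    #   TORCH_LIBRARY_IMPL(custom, CPU, m) {
    #       m.impl("add_3", TORCH_FN(wrapper_CPU__add_3));
    #   };
    #
    # with the m.impl(...) lines produced by dest.RegisterDispatchKey.
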
    anonymous_definition = "\n".join(
        list(
            concatMap(
                dest.RegisterDispatchKey(
                    backend_index,
                    Target.ANONYMOUS_DEFINITION,
                    selector,
                    rocm=rocm,
                    symint=False,
                    class_method_name=None,
                    skip_dispatcher_op_registration=False,
                ),
                native_functions,
            )
        )
    )
    return anonymous_definition, static_init_dispatch_registrations
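

# A minimal usage sketch (illustration only). `native_functions` and `kernel_index`
# would typically come from torchgen's ExecuTorch YAML parsing (see gen_executorch.py);
# SelectiveBuilder.get_nop_selector() selects everything.
#
#     anonymous_definition, static_init_dispatch_registrations = gen_custom_ops_registration(
#         native_functions=native_functions,
#         selector=SelectiveBuilder.get_nop_selector(),
#         kernel_index=kernel_index,
#         rocm=False,
#     )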