from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass
from typing import Sequence, TYPE_CHECKING

from torchgen import dest


# disable import sorting to avoid circular dependency.
from torchgen.api.types import DispatcherSignature  # usort: skip
from torchgen.context import method_with_native_function
from torchgen.model import BaseTy, BaseType, DispatchKey, NativeFunction, Variant
from torchgen.utils import concatMap, Target


if TYPE_CHECKING:
    from torchgen.executorch.model import ETKernelIndex
    from torchgen.selective_build.selector import SelectiveBuilder


# Generates the placeholder kernels for custom operators that go into
# RegisterKernelStub.cpp. This will be used on the model authoring side.
@dataclass(frozen=True)
class ComputeNativeFunctionStub:
    """Callable that emits a C++ stub kernel definition for one NativeFunction."""

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        """Return a C++ stub definition for ``f``, or None if it has no function variant.

        The stub uses the dispatcher signature of ``f`` with the name prefix
        ``wrapper_CPU_<overload_name>_``. Its body only returns a placeholder:

        - nothing, when ``f`` has no returns;
        - for a single return: the first out argument; otherwise the first
          non-out argument whose type equals the return type; otherwise
          ``at::Tensor()`` when the return type is Tensor;
        - for several returns, when their count equals the number of out
          arguments: a ``::std::tuple`` of ``at::Tensor &`` over the out arguments;
        - otherwise: a ``::std::tuple`` of empty ``at::Tensor()`` values
          (every return must be a Tensor).

        :raises Exception: for a single non-Tensor return that no argument supplies.
        :raises AssertionError: for multiple returns that are not all Tensor.
        """
        if Variant.function not in f.variants:
            return None

        sig = DispatcherSignature.from_schema(
            f.func, prefix=f"wrapper_CPU_{f.func.name.overload_name}_", symint=False
        )
        assert sig is not None
        if len(f.func.returns) == 0:
            ret_name = ""
        elif len(f.func.returns) == 1:
            if f.func.arguments.out:
                ret_name = f.func.arguments.out[0].name
            else:
                # Reuse an input argument of the same type as the return value.
                ret_name = next(
                    (
                        a.name
                        for a in f.func.arguments.flat_non_out
                        if a.type == f.func.returns[0].type
                    ),
                    "",
                )
            if not ret_name:
                # if return type is tensor
                if f.func.returns[0].type == BaseType(BaseTy.Tensor):
                    # Returns an empty tensor
                    ret_name = "at::Tensor()"
                else:
                    raise Exception(f"Can't handle this return type {f.func}")  # noqa: TRY002
        elif len(f.func.arguments.out) == len(f.func.returns):
            # Returns a tuple of out arguments
            tensor_type = "at::Tensor &"
            comma = ", "
            ret_name = f"""::std::tuple<{comma.join([tensor_type] * len(f.func.returns))}>(
                {comma.join([r.name for r in f.func.arguments.out])}
            )"""
        else:
            assert all(
                a.type == BaseType(BaseTy.Tensor) for a in f.func.returns
            ), f"Only support tensor returns but got {f.func.returns}"
            # Returns a tuple of empty tensors
            tensor_type = "at::Tensor"
            comma = ", "
            ret_name = f"""::std::tuple<{comma.join([tensor_type] * len(f.func.returns))}>(
                {comma.join(["at::Tensor()" for _ in f.func.returns])}
            )"""
        ret_str = f"return {ret_name};" if len(f.func.returns) > 0 else ""
        return f"""
{sig.defn()} {{
    {ret_str}
}}
    """


def gen_custom_ops_registration(
    *,
    native_functions: Sequence[NativeFunction],
    selector: SelectiveBuilder,
    kernel_index: ETKernelIndex,
    rocm: bool,
) -> tuple[str, str]:
    """
    Generate custom ops registration code for dest.RegisterDispatchKey.

    :param native_functions: a sequence of `NativeFunction`
    :param selector: for selective build.
    :param kernel_index: kernels for all the ops.
    :param rocm: bool for dest.RegisterDispatchKey.
    :return: a tuple ``(anonymous_definition, static_init_dispatch_registrations)``.
        ``anonymous_definition`` is the C++ produced by
        ``Target.ANONYMOUS_DEFINITION`` for every function in ``native_functions``.
        ``static_init_dispatch_registrations`` holds one
        ``TORCH_LIBRARY_IMPL(<namespace>, CPU, m)`` block per namespace, wrapping
        the ``Target.REGISTRATION`` output for that namespace's functions.
    """

    # convert kernel index to BackendIndex. This is because we can't handle ETKernelIndex yet.
    # TODO(larryliu): evaluate if this code is still needed. If yes let it handle ETKernelIndex.

    dispatch_key = DispatchKey.CPU
    backend_index = kernel_index._to_backend_index()

    def make_generator(target: Target) -> dest.RegisterDispatchKey:
        # Single source of truth for the generator configuration: the
        # registration pass and the definition pass differ only by target.
        return dest.RegisterDispatchKey(
            backend_index,
            target,
            selector,
            rocm=rocm,
            symint=False,
            class_method_name=None,
            skip_dispatcher_op_registration=False,
        )

    # TORCH_LIBRARY_IMPL takes a single namespace, so group functions by namespace.
    # Every group created here is non-empty because keys only appear via append().
    ns_grouped_native_functions: dict[str, list[NativeFunction]] = defaultdict(list)
    for native_function in native_functions:
        ns_grouped_native_functions[native_function.namespace].append(native_function)

    registration_blocks: list[str] = []
    for namespace, functions in ns_grouped_native_functions.items():
        dispatch_registrations_body = "\n".join(
            concatMap(make_generator(Target.REGISTRATION), functions)
        )
        registration_blocks.append(
            f"""
TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
{dispatch_registrations_body}
}};"""
        )
    static_init_dispatch_registrations = "".join(registration_blocks)

    anonymous_definition = "\n".join(
        concatMap(make_generator(Target.ANONYMOUS_DEFINITION), native_functions)
    )
    return anonymous_definition, static_init_dispatch_registrations