from __future__ import annotations

from dataclasses import dataclass
from typing import Iterator, Sequence, TYPE_CHECKING

from torchgen.api.types.types_base import Binding, CType, Expr


if TYPE_CHECKING:
    from torchgen.model import (
        BackendIndex,
        FunctionSchema,
        NativeFunction,
        NativeFunctionsGroup,
        NativeFunctionsViewGroup,
    )


# Extra leading C++ parameter emitted when rendering a redispatching function.
_DISPATCH_KEY_SET_ARG = "c10::DispatchKeySet dispatchKeySet"


@dataclass(frozen=True)
class CppSignature:
    """
    A CppSignature represents a single overload in the C++ API. For
    any given function schema, there may be multiple CppSignatures
    corresponding to it, based on how we desugar to C++. See also
    CppSignatureGroup.
    """

    # The schema this signature is derived from
    func: FunctionSchema

    # Is this a C++ signature for a method, i.e. Tensor::my_op(...)?
    method: bool

    # Is this a faithful C++ signature (i.e. following the JIT schema) or a convenience API
    # (i.e. with a potential TensorOptions argument and out arguments in the front)
    faithful: bool

    # Is this a symint C++ signature. For BC reasons, functions that take
    # SymInts still present as int64_t in C++, and the SymInt variant is
    # offered at a different overload name
    #
    # NB: If a function RETURNS a SymInt, this is ALWAYS false
    symint: bool

    # The set of C++ arguments which should not have defaults applied to them.
    # NOTE(review): being a mutable ``set``, this field makes the ``__hash__``
    # generated for this frozen dataclass raise TypeError if an instance is
    # ever hashed.
    cpp_no_default_args: set[str]

    # Is this a fallback C++ binding? Fallback bindings are enabled by
    # manual_cpp_binding: True and are alternate, non-public API that
    # lets manual C++ binding implementors access the binding that would
    # have been automatically generated
    fallback_binding: bool = False

    def arguments(self) -> Sequence[Binding]:
        """Return the unpacked argument structure of this signature.

        Information about which arguments are semantically related to each
        other is discarded.
        """
        return cpp.arguments(
            self.func.arguments,
            faithful=self.faithful,
            symint=self.symint,
            method=self.method,
            cpp_no_default_args=self.cpp_no_default_args,
        )

    def name(self, *, suppress_symint_suffix: bool = False) -> str:
        """Return the C++ name of this overload.

        Args:
            suppress_symint_suffix: if True, compute the name as though this
                were not a symint overload, even when ``self.symint`` is set.

        Fallback bindings get an extra ``__dispatch_`` prefix.
        """
        n = cpp.name(
            self.func,
            faithful_name_for_out_overloads=self.faithful,
            symint_overload=False if suppress_symint_suffix else self.symint,
        )
        if self.fallback_binding:
            n = f"__dispatch_{n}"
        return n

    def returns_type(self) -> CType:
        """Return the C++ return type of this signature (honoring ``symint``)."""
        return cpp.returns_type(self.func.returns, symint=self.symint)

    def _render(
        self,
        cpp_args: list[str],
        *,
        name: str | None,
        prefix: str,
        is_redispatching_fn: bool,
        suppress_symint_suffix: bool,
    ) -> str:
        # Shared formatting for decl()/defn(): "<ret> <name>(<args>)".
        if is_redispatching_fn:
            cpp_args = [_DISPATCH_KEY_SET_ARG, *cpp_args]
        if name is None:
            # `prefix` only applies to the computed name, not an explicit one.
            name = prefix + self.name(suppress_symint_suffix=suppress_symint_suffix)
        return f"{self.returns_type().cpp_type()} {name}({', '.join(cpp_args)})"

    def decl(
        self,
        *,
        name: str | None = None,
        prefix: str = "",
        is_redispatching_fn: bool = False,
        suppress_symint_suffix: bool = False,
    ) -> str:
        """Render the C++ declaration for this signature (built from ``Binding.decl()``).

        Args:
            name: explicit function name; when None, ``prefix + self.name(...)`` is used.
            prefix: prepended to the computed name (ignored when ``name`` is given).
            is_redispatching_fn: prepend a ``c10::DispatchKeySet dispatchKeySet`` parameter.
            suppress_symint_suffix: forwarded to :meth:`name`.
        """
        return self._render(
            [a.decl() for a in self.arguments()],
            name=name,
            prefix=prefix,
            is_redispatching_fn=is_redispatching_fn,
            suppress_symint_suffix=suppress_symint_suffix,
        )

    def defn(
        self,
        *,
        name: str | None = None,
        prefix: str = "",
        is_redispatching_fn: bool = False,
        suppress_symint_suffix: bool = False,
    ) -> str:
        """Render the C++ definition header for this signature, without the body.

        Built from ``Binding.defn()``. Parameters have the same meaning as in
        :meth:`decl`. ``suppress_symint_suffix`` defaults to False, which matches
        the previous behavior of this method.
        """
        return self._render(
            [a.defn() for a in self.arguments()],
            name=name,
            prefix=prefix,
            is_redispatching_fn=is_redispatching_fn,
            suppress_symint_suffix=suppress_symint_suffix,
        )

    def ptr_type(self) -> str:
        """Return the C++ function-pointer type, e.g. something like ``int (*)(bool)``."""
        args_types_str = ", ".join(a.type for a in self.arguments())
        return f"{self.returns_type().cpp_type()} (*)({args_types_str})"

    def type(self) -> str:
        """Return the C++ function type, e.g. something like ``int (bool)``."""
        args_types_str = ", ".join(a.type for a in self.arguments())
        return f"{self.returns_type().cpp_type()} ({args_types_str})"


# Represents the group of all CppSignatures associated with a
# FunctionSchema: the regular, user-visible signature; a "faithful"
# signature which doesn't have grouping (when one differs); and, for
# schemas that involve SymInts, the symint variants of both.
@dataclass(frozen=True)
class CppSignatureGroup:
    func: FunctionSchema
    signature: CppSignature
    faithful_signature: CppSignature | None
    symint_signature: CppSignature | None
    symint_faithful_signature: CppSignature | None

    def most_faithful_signature(self) -> CppSignature:
        """Return the faithful signature if one exists, else the regular one."""
        if self.faithful_signature is not None:
            return self.faithful_signature
        return self.signature

    def signatures(self, *, symint: bool = True) -> Iterator[CppSignature]:
        """Yield every signature in this group.

        Order: regular, faithful (if any), then, when ``symint`` is True,
        the symint and symint-faithful variants (if any).
        """
        yield self.signature
        if self.faithful_signature is not None:
            yield self.faithful_signature
        if symint:
            if self.symint_signature is not None:
                yield self.symint_signature
            if self.symint_faithful_signature is not None:
                yield self.symint_faithful_signature

    @staticmethod
    def from_native_function(
        f: NativeFunction, *, method: bool, fallback_binding: bool = False
    ) -> CppSignatureGroup:
        """Build the signature group for ``f``.

        A separate faithful signature is only produced when the schema has
        TensorOptions or out arguments. Symint variants are only produced when
        ``f.func.has_symint()``.
        """
        func = f.func

        def make_sig(*, faithful: bool, symint: bool) -> CppSignature:
            return CppSignature(
                func=func,
                faithful=faithful,
                symint=symint,
                method=method,
                fallback_binding=fallback_binding,
                cpp_no_default_args=f.cpp_no_default_args,
            )

        def make_sigs(*, symint: bool) -> tuple[CppSignature, CppSignature | None]:
            faithful_signature: CppSignature | None = None
            if func.arguments.tensor_options is not None or len(func.arguments.out) > 0:
                faithful_signature = make_sig(faithful=True, symint=symint)
            signature = make_sig(faithful=False, symint=symint)
            return signature, faithful_signature

        signature, faithful_signature = make_sigs(symint=False)
        symint_signature: CppSignature | None = None
        symint_faithful_signature: CppSignature | None = None
        if func.has_symint():
            symint_signature, symint_faithful_signature = make_sigs(symint=True)

        return CppSignatureGroup(
            func=func,
            signature=signature,
            faithful_signature=faithful_signature,
            symint_signature=symint_signature,
            symint_faithful_signature=symint_faithful_signature,
        )


@dataclass(frozen=True)
class DispatcherSignature:
    """A C++ signature following the dispatcher API for a schema."""

    # The schema this signature is derived from
    func: FunctionSchema

    # Allows you to prepend an arbitrary prefix to the signature name.
    # This is useful for parts of the codegen that generate wrappers around kernels,
    # and need to avoid naming collisions.
    prefix: str = ""

    # Whether arguments/returns are rendered with SymInt types.
    symint: bool = True

    def arguments(self) -> list[Binding]:
        """Return the dispatcher-API bindings for this schema."""
        return dispatcher.arguments(self.func, symint=self.symint)

    def name(self) -> str:
        """Return ``prefix`` + the dispatcher name of the schema."""
        return self.prefix + dispatcher.name(self.func)

    def decl(self, name: str | None = None) -> str:
        """Render the C++ declaration; ``name`` overrides :meth:`name` when given."""
        args_str = ", ".join(a.decl() for a in self.arguments())
        if name is None:
            name = self.name()
        return f"{self.returns_type().cpp_type()} {name}({args_str})"

    def defn(
        self, name: str | None = None, *, is_redispatching_fn: bool = False
    ) -> str:
        """Render the C++ definition header (no body).

        Args:
            name: explicit function name; defaults to :meth:`name`.
            is_redispatching_fn: prepend a ``c10::DispatchKeySet dispatchKeySet`` parameter.
        """
        args = [a.defn() for a in self.arguments()]
        if is_redispatching_fn:
            args = [_DISPATCH_KEY_SET_ARG, *args]
        args_str = ", ".join(args)
        if name is None:
            name = self.name()
        return f"{self.returns_type().cpp_type()} {name}({args_str})"

    def exprs(self) -> list[Expr]:
        """Return one Expr per argument, referring to the argument by its name."""
        return [Expr(a.name, a.nctype) for a in self.arguments()]

    def returns_type(self) -> CType:
        """Return the dispatcher-API return type (honoring ``symint``)."""
        return dispatcher.returns_type(self.func.returns, symint=self.symint)

    def ptr_type(self) -> str:
        """Return the C++ function-pointer type, e.g. something like ``int (*)(bool)``."""
        dispatcher_args_types_str = ", ".join(a.type for a in self.arguments())
        return f"{self.returns_type().cpp_type()} (*)({dispatcher_args_types_str})"

    def type(self) -> str:
        """Return the C++ function type, e.g. something like ``int (bool)``."""
        dispatcher_args_types_str = ", ".join(a.type for a in self.arguments())
        return f"{self.returns_type().cpp_type()} ({dispatcher_args_types_str})"

    @staticmethod
    def from_schema(
        func: FunctionSchema, *, prefix: str = "", symint: bool = True
    ) -> DispatcherSignature:
        """Construct a DispatcherSignature for ``func``."""
        return DispatcherSignature(func=func, prefix=prefix, symint=symint)


@dataclass(frozen=True)
class NativeSignature:
    """A C++ signature following the "native" API, used by in-tree kernels."""

    # The schema this signature is derived from
    func: FunctionSchema

    # Whether arguments/returns are rendered with SymInt types.
    symint: bool

    # Prepended to the native name of the schema by name().
    prefix: str = ""

    def name(self) -> str:
        """Return ``prefix`` + the native name of the schema."""
        return self.prefix + native.name(self.func)

    def decl(self, name: str | None = None) -> str:
        """Render the C++ declaration; ``name`` overrides :meth:`name` when given."""
        args_str = ", ".join(a.decl() for a in self.arguments())
        if name is None:
            name = self.name()
        return f"{self.returns_type().cpp_type()} {name}({args_str})"

    def defn(self, name: str | None = None) -> str:
        """Render the C++ definition header (no body); ``name`` overrides :meth:`name`."""
        args_str = ", ".join(a.defn() for a in self.arguments())
        if name is None:
            name = self.name()
        return f"{self.returns_type().cpp_type()} {name}({args_str})"

    def ptr_type(self) -> str:
        """Return the C++ function-pointer type for this signature."""
        # don't include defaults in type signature!
        args_str = ", ".join(a.defn() for a in self.arguments())
        return f"{self.returns_type().cpp_type()} (*)({args_str})"

    def arguments(self) -> list[Binding]:
        """Return the native-API bindings for this schema."""
        return native.arguments(self.func, symint=self.symint)

    def returns_type(self) -> CType:
        """Return the native-API return type (honoring ``symint``)."""
        return native.returns_type(self.func.returns, symint=self.symint)

    def dispatcher_exprs(self) -> list[Expr]:
        """Translate this signature's arguments into dispatcher-API argument expressions."""
        return translate.translate(
            self.arguments(), dispatcher.arguments(self.func), method=False
        )


@dataclass(frozen=True)
class ViewInverseSignature:
    """Signature of the reverse (inverse) function for a view op group."""

    g: NativeFunctionsViewGroup

    def name(self) -> str:
        """Return the un-namespaced reverse name of the view op."""
        return functionalization.reverse_name(self.g.view, include_namespace=False)

    def decl(self) -> str:
        """Render a ``static`` C++ declaration (terminated by ``;``) for the inverse."""
        return_type = functionalization.returns_type(self.g.view.func)
        decls = [
            a.decl()
            for a in functionalization.inner_arguments(
                self.g.view.func, is_reverse=True
            )
        ]
        return f"static {return_type.cpp_type()} {self.name()}({', '.join(decls)});"


@dataclass(frozen=True)
class FunctionalizationLambda:
    """Renders the forward or reverse C++ lambda used by functionalization for a view."""

    g: NativeFunctionsViewGroup

    # are we generating the forward lambda or the reverse lambda?
    is_reverse: bool

    def captures(self) -> list[Expr]:
        """Return the lambda's capture expressions, derived from the outer kernel context."""
        # The lambda lives inside of a kernel following the dispatcher API, so its outer context is the dispatcher arguments
        # We also need to read the "reapply views" TLS at the time that the functionalization kernel was executed,
        # and plumb it into the lambda.
        outer_ctx = dispatcher.arguments(self.g.view.func) + [
            functionalization.reapply_views_binding,
            functionalization.inverse_return_mode_binding,
        ]
        capture_bindings = functionalization.capture_arguments(
            self.g.view.func, is_reverse=self.is_reverse
        )
        # allow_expensive_conversions is set because we want to convert
        # some reference types (IntArrayRef) to value types (vector<int64_t>).
        capture_exprs = translate.translate(
            outer_ctx, capture_bindings, method=False, allow_expensive_conversions=True
        )
        return capture_exprs

    def decl(self) -> str:
        """Render the lambda header: ``[captures](args) -> return_type``."""
        return_type = functionalization.returns_type(self.g.view.func)
        capture_str = ", ".join(
            f"{val.type.name} = {val.expr}" for val in self.captures()
        )
        decls = [
            a.decl()
            for a in functionalization.outer_arguments(is_reverse=self.is_reverse)
        ]
        return f"[{capture_str}]({', '.join(decls)}) -> {return_type.cpp_type()}"

    def inner_call(self, *, reapply_views: bool | None = None) -> str:
        """Render the C++ call statement (ending in ``;``) made inside the lambda body.

        Arguments are translated from the lambda's parameters plus its captures.
        For the forward lambda, when ``inner_call_index`` reports an index, the
        call result is subscripted with it.
        """
        inner_call_name = functionalization.name(
            self.g,
            is_reverse=self.is_reverse,
            include_namespace=True,
            reapply_views=reapply_views,
        )

        arg_ctx = functionalization.outer_arguments(is_reverse=self.is_reverse)
        capture_ctx = functionalization.capture_arguments(
            self.g.view.func, is_reverse=self.is_reverse
        )
        full_ctx = arg_ctx + capture_ctx

        assert self.g.view_copy is not None
        call_bindings = functionalization.inner_arguments(
            self.g.view_copy.func, is_reverse=self.is_reverse
        )
        maybe_index = functionalization.inner_call_index(self.g.view_copy.func)
        call_exprs = [
            e.expr for e in translate.translate(full_ctx, call_bindings, method=False)
        ]
        if not self.is_reverse and maybe_index is not None:
            return f'{inner_call_name}({", ".join(call_exprs)})[{maybe_index.name}];'
        else:
            return f'{inner_call_name}({", ".join(call_exprs)});'

    @staticmethod
    def from_func(
        g: NativeFunctionsViewGroup, *, is_reverse: bool
    ) -> FunctionalizationLambda:
        """Construct the forward (``is_reverse=False``) or reverse lambda for ``g``."""
        return FunctionalizationLambda(g=g, is_reverse=is_reverse)


@dataclass(frozen=True)
class StructuredImplSignature:
    """Signature of a structured kernel's ``TORCH_IMPL_FUNC`` implementation."""

    g: NativeFunctionsGroup
    name: str

    def defn(self, name: str | None = None) -> str:
        """Render ``TORCH_IMPL_FUNC(<name>)(<args>)``.

        Args:
            name: explicit impl name; defaults to the ``name`` field. (Previously
                this argument was accepted but silently ignored.)
        """
        args_str = ", ".join(a.defn() for a in self.arguments())
        impl_name = self.name if name is None else name
        return f"TORCH_IMPL_FUNC({impl_name})({args_str})"

    def arguments(self) -> list[Binding]:
        """Return the structured impl-function bindings for the group."""
        return structured.impl_arguments(self.g)


# Helper functions


def kernel_signature(
    f: NativeFunction, backend_index: BackendIndex, *, prefix: str = ""
) -> NativeSignature | DispatcherSignature:
    """Return the kernel signature ``f`` should be implemented with for a backend.

    External backends get a DispatcherSignature; in-tree backends get a
    NativeSignature. SymInt is used only when the backend's kernel metadata
    reports ``supports_symint()``.
    """
    # Note [External Backends Follow Dispatcher API]
    # Kernel signatures for in-tree backends follow the "native" API,
    # while kernels for out-of-tree backends follow the dispatcher API.
    # See the comments in `native.py` for details, but historically there have been
    # some small differences in schema convention between them and the Dispatcher API.
    # Any differences that require translating between the two will result in a runtime cost,
    # so we'd like to keep the differences as small as possible.
    # With external backends, we'd like to enforce that they write their kernels with schemas
    # that match the Dispatcher API directly, if they can.
    meta = backend_index.get_kernel(f)
    symint = meta is not None and meta.supports_symint()
    if symint:
        assert (
            f.func.has_symint()
        ), f"attempted to define symint kernel for {backend_index.dispatch_key} without SymInt in schema"
    if backend_index.external:
        return DispatcherSignature.from_schema(f.func, prefix=prefix, symint=symint)
    else:
        return NativeSignature(f.func, prefix=prefix, symint=symint)


# Functions only, no types.
# NOTE(review): these modules are imported at the bottom, after the signature
# types above are defined, presumably to break an import cycle with the
# torchgen.api.* modules -- confirm before moving this to the top of the file.
from torchgen.api import (
    cpp,
    dispatcher,
    functionalization,
    native,
    structured,
    translate,
)