from __future__ import annotations

import textwrap
from dataclasses import dataclass
from typing import Sequence

from torchgen.api.types import DispatcherSignature
from torchgen.api.types.signatures import CppSignature, CppSignatureGroup
from torchgen.context import method_with_native_function
from torchgen.model import (
    Argument,
    BackendIndex,
    BaseTy,
    BaseType,
    DispatchKey,
    FunctionSchema,
    ListType,
    NativeFunction,
    NativeFunctionsGroup,
    OperatorName,
    OptionalType,
    Type,
)
from torchgen.utils import mapMaybe


base_type_to_c_type = {
    BaseTy.Tensor: "AtenTensorHandle",
    BaseTy.bool: "int32_t",  # Use int to pass bool
    BaseTy.int: "int64_t",
    BaseTy.SymInt: "int64_t",  # Inductor-generated code won't see a SymInt
    BaseTy.Scalar: "double",  # Use double to pass both integer and floating point
    BaseTy.float: "double",  # TODO: how about other floating point types?
    BaseTy.str: "const char*",
    BaseTy.DeviceIndex: "int32_t",
    BaseTy.Layout: "int32_t",  # Represent enum as int
    BaseTy.MemoryFormat: "int32_t",  # Represent enum as int
    BaseTy.ScalarType: "int32_t",  # Represent enum as int
    BaseTy.Generator: "AtenGeneratorHandle",
}

base_type_to_aten_type = {
    BaseTy.Tensor: "at::Tensor",
    BaseTy.bool: "bool",
    BaseTy.int: "int64_t",
    BaseTy.SymInt: "c10::SymInt",
    BaseTy.Scalar: "c10::Scalar",
    BaseTy.float: "double",
    BaseTy.str: "c10::string_view",
    BaseTy.DeviceIndex: "c10::DeviceIndex",
    BaseTy.Layout: "c10::Layout",
    BaseTy.MemoryFormat: "c10::MemoryFormat",
    BaseTy.ScalarType: "c10::ScalarType",
    BaseTy.Generator: "at::Generator",
}

base_type_to_callsite_expr = {
    BaseTy.Tensor: "*tensor_handle_to_tensor_pointer",
    BaseTy.bool: "",
    BaseTy.int: "",
    BaseTy.SymInt: "",
    BaseTy.Scalar: "",
    BaseTy.float: "",
    BaseTy.str: "",
    BaseTy.DeviceIndex: "static_cast<c10::DeviceIndex>",
    BaseTy.Layout: "static_cast<c10::Layout>",
    BaseTy.MemoryFormat: "static_cast<c10::MemoryFormat>",
    BaseTy.ScalarType: "static_cast<c10::ScalarType>",
    BaseTy.Generator: "*generator_handle_to_generator_pointer",
}


# convert args to C types, names in declarations, and expressions in function bodies
def convert_arg_type_and_name(typ: Type, name: str) -> tuple[list[str], list[str], list[str], list[str]]:  # type: ignore[return]
    if isinstance(typ, BaseType):
        if typ.name in base_type_to_c_type:
            return (
                [base_type_to_c_type[typ.name]],
                [name],
                [base_type_to_aten_type[typ.name]],
                [
                    f"{base_type_to_callsite_expr[typ.name]}({name})"
                    if base_type_to_callsite_expr[typ.name]
                    else name
                ],
            )
        elif typ.name == BaseTy.Device:
            return (
                ["int32_t", "int32_t"],
                [name, name + "_index_"],
                ["c10::Device"],
                [
                    f"c10::Device(static_cast<c10::DeviceType>({name}), static_cast<c10::DeviceIndex>({name}_index_))"
                ],
            )
        else:
            # TODO: BaseTy.Dimname, etc.
            raise NotImplementedError(f"TODO: add support for arg type {repr(typ)}")
    elif isinstance(typ, OptionalType):
        c_types, names, aten_types, callsite_exprs = convert_arg_type_and_name(
            typ.elem, name
        )
        j = 0  # index for names
        new_aten_types = []
        new_callsite_exprs = []
        for aten_type in aten_types:
            # Use pointer to denote optional type
            c_types[j] = c_types[j] + "*"
            if aten_type.startswith("c10::ArrayRef<"):
                # ArrayRef is passed as pointer + size, but no need to add "*" to the size argument
                new_aten_types.append(f"::std::optional<{aten_type}>")
                base_type = aten_type[len("c10::ArrayRef<") : -1]
                new_callsite_exprs.append(
                    f"pointer_to_optional_list<{base_type}>({names[j]}, {names[j+1]})"
                )
                j += 2
            elif aten_type == "c10::Device":
                # Device is passed as device_type + device_index
                new_aten_types.append("::std::optional<c10::Device>")
                new_callsite_exprs.append(
                    f"pointer_to_optional_device({names[j]}, {names[j+1]})"
                )
                j += 2
            else:
                new_aten_types.append(f"::std::optional<{aten_type}>")
                new_callsite_exprs.append(
                    f"pointer_to_optional<{aten_type}>({names[j]})"
                )
                j += 1

        return (
            c_types,
            names,
            new_aten_types,
            new_callsite_exprs,
        )
    elif isinstance(typ, ListType):
        # Need to explicitly pass the list as pointer + length
        c_types, names, aten_types, _ = convert_arg_type_and_name(typ.elem, name)
        assert len(c_types) == 1, "ListType with unsupported element type " + repr(typ)

        # The list content should never be modified
        c_types[0] = f"const {c_types[0]}*"
        c_types.append("int64_t")
        name = names[0]
        names.append(name + "_len_")

        atype = aten_types[0]
        callsite_exprs = []
        if atype == "bool":
            # no converter from std::vector<bool> to c10::ArrayRef<bool>
            # construct std::array<bool, N> instead
            assert typ.size is not None
            callsite_exprs.append(f"pointer_to_list<{typ.size}>({name})")
        elif atype == "::std::optional<at::Tensor>":
            # convert from std::vector<::std::optional<at::Tensor>> to c10::List<::std::optional<at::Tensor>>
            callsite_exprs.append(
                f"c10::List<{atype}>(c10::ArrayRef<{atype}>(pointer_to_list<{atype}>({name}, {name}_len_)))"
            )
        else:
            callsite_exprs.append(f"pointer_to_list<{atype}>({name}, {name}_len_)")

        aten_types = [f"c10::ArrayRef<{t}>" for t in aten_types]
        return (
            c_types,
            names,
            aten_types,
            callsite_exprs,
        )
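
# Illustration only (hypothetical argument names, derived by tracing the rules above):
#   a "Tensor? weight" argument is expected to expand to
#     c_types ["AtenTensorHandle*"], names ["weight"],
#     aten_types ["::std::optional<at::Tensor>"],
#     callsite_exprs ["pointer_to_optional<at::Tensor>(weight)"]
#   an "int[] sizes" argument is expected to expand to
#     c_types ["const int64_t*", "int64_t"], names ["sizes", "sizes_len_"],
#     aten_types ["c10::ArrayRef<int64_t>"],
#     callsite_exprs ["pointer_to_list<int64_t>(sizes, sizes_len_)"]
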

def zip_type_and_name(types: list[str], names: list[str]) -> list[str]:
    return [typ + " " + name for typ, name in zip(types, names)]


# Generate argument declarations and callsite expressions
def gen_arguments(flat_arguments: Sequence[Argument]) -> tuple[list[str], list[str]]:
    types = []
    new_names = []
    callsite_exprs = []
    for arg in flat_arguments:
        new_types, names, _, new_callsite_exprs = convert_arg_type_and_name(
            arg.type, arg.name
        )
        types.extend(new_types)
        new_names.extend(names)
        callsite_exprs.extend(new_callsite_exprs)
    return zip_type_and_name(types, new_names), callsite_exprs
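
# Illustration only: for hypothetical flat arguments "Tensor self, int dim, bool keepdim",
# gen_arguments is expected to produce roughly
#   declarations  ["AtenTensorHandle self", "int64_t dim", "int32_t keepdim"]
#   callsite_exprs ["*tensor_handle_to_tensor_pointer(self)", "dim", "keepdim"]
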

# Return values are passed out as pointer arguments because all the C shim functions
# are expected to return AOTITorchError.
# Generate returns as declarations and callsite expressions
def gen_returns(schema: FunctionSchema) -> tuple[list[str], list[str]]:
    types = []
    names = []
    for idx, ret in enumerate(schema.returns):
        names.append(f"ret{idx}")
        if isinstance(ret.type, BaseType) and ret.type.name in base_type_to_c_type:
            types.append(base_type_to_c_type[ret.type.name] + "*")
        else:
            raise NotImplementedError(
                f"TODO: add support for return type {repr(ret.type)}"
            )

    def convert_return(typ: BaseType, val: str) -> str:
        if typ.name == BaseTy.Tensor:
            return f"new_tensor_handle(std::move({val}));"
        elif typ.name == BaseTy.SymInt:
            return f"{val}.expect_int()"
        elif typ.name == BaseTy.Scalar:
            return f"{val}.toDouble()"
        else:
            return val

    ret_pointer_can_be_null = False
    unambiguous_name = schema.name.unambiguous_name()
    for name in [
        "_scaled_dot_product_flash_attention",
        "_scaled_dot_product_efficient_attention",
        "_scaled_dot_product_cudnn_attention",
        "convolution_backward",
    ]:
        if name in unambiguous_name:
            ret_pointer_can_be_null = True
            break

    callsite_exprs: list[str] = []
    for idx, ret in enumerate(schema.returns):
        tmp = "tmp_result" if len(names) == 1 else f"std::get<{idx}>(tmp_result)"
        assert isinstance(ret.type, BaseType)
        rval = convert_return(ret.type, tmp)
        if ret_pointer_can_be_null:
            callsite_exprs.append(f"if ({names[idx]}) {{ *{names[idx]} = {rval}; }}")
        else:
            callsite_exprs.append(f"*{names[idx]} = {rval};")

    return zip_type_and_name(types, names), callsite_exprs


# gen.py generates header first and then src, so caching the result here to avoid duplicate work
declaration_definition_cache: dict[tuple[str, str, str], tuple[str, str]] = {}


def gen_declaration_and_definition(
    schema: FunctionSchema, device: str, backend_call: str
) -> tuple[str, str]:
    func_name = schema.name.unambiguous_name()

    global declaration_definition_cache
    if (func_name, device, backend_call) in declaration_definition_cache:
        return declaration_definition_cache[(func_name, device, backend_call)]

    if schema.is_out_fn():
        # out_variant has out arguments in the front, and it's ok to ignore return values
        # because C shim functions only return AOTITorchError
        args, callsite_exprs = gen_arguments(
            [*schema.arguments.out, *schema.arguments.flat_non_out]
        )
        ret_assignments: list[str] = []
    else:
        args, callsite_exprs = gen_arguments(schema.arguments.flat_all)
        # ignore return values for inplace ops
        ret_declarations, ret_assignments = (
            ([], []) if schema.name.name.inplace else gen_returns(schema)
        )
        args.extend(ret_declarations)

    declaration = f"AOTITorchError aoti_torch_{device}_{func_name}({', '.join(args)})"

    tmp_result = "auto tmp_result = " if ret_assignments else ""
    ret_assignments_str = "\n" + "\n".join(ret_assignments) if ret_assignments else ""
    definition = f"""
{declaration} {{
    AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({{
        {tmp_result}{backend_call}(
{textwrap.indent(', '.join(callsite_exprs), "            ")}
        );{textwrap.indent(ret_assignments_str, "        ")}
    }});
}}
"""
    declaration_definition_cache[(func_name, device, backend_call)] = (
        declaration,
        definition,
    )
    return declaration, definition
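
# Illustration only: for a hypothetical schema "foo(Tensor self) -> Tensor" with device "cpu",
# gen_declaration_and_definition is expected to produce a declaration roughly like
#   AOTITorchError aoti_torch_cpu_foo(AtenTensorHandle self, AtenTensorHandle* ret0)
# with a definition that wraps the backend call in AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE
# and writes the result back through *ret0 via new_tensor_handle().
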

def gen_static_dispatch_backend_call_signature(
    sig: CppSignature | DispatcherSignature,
    f: NativeFunction,
) -> CppSignature:
    sig = DispatcherSignature.from_schema(f.func)
    cpp_sigs = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False
    )
    if sig.symint and f.func.has_symint():
        cpp_sig = cpp_sigs.symint_signature
    else:
        cpp_sig = cpp_sigs.signature
    assert cpp_sig is not None
    return cpp_sig


def gen_static_dispatch_backend_call(
    f: NativeFunction,
    backend_index: BackendIndex,
) -> str:
    sig = DispatcherSignature.from_schema(f.func)
    cpp_sig = gen_static_dispatch_backend_call_signature(sig, f)
    return f"at::{backend_index.dispatch_key.lower()}::{cpp_sig.name()}"


def get_backend_index_for_aoti(
    func: NativeFunction,
    func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_indices: dict[DispatchKey, BackendIndex],
) -> BackendIndex | None:
    backend_index = None
    if backend_indices[dispatch_key].has_kernel(func) or (
        func.structured_delegate is not None
        and func.structured_delegate in func_group_mapping
        and backend_indices[dispatch_key].has_kernel(
            func_group_mapping[func.structured_delegate]
        )
    ):
        backend_index = backend_indices[dispatch_key]
    elif backend_indices[DispatchKey.CompositeExplicitAutograd].has_kernel(func):
        # We need to create C shim wrappers for CompositeExplicitAutograd kernels
        backend_index = backend_indices[DispatchKey.CompositeExplicitAutograd]
    elif backend_indices[DispatchKey.CompositeExplicitAutogradNonFunctional].has_kernel(
        func
    ):
        # We need to create C shim wrappers for CompositeExplicitAutogradNonFunctional kernels
        backend_index = backend_indices[
            DispatchKey.CompositeExplicitAutogradNonFunctional
        ]
    elif backend_indices[DispatchKey.CompositeImplicitAutograd].has_kernel(func):
        backend_index = backend_indices[DispatchKey.CompositeImplicitAutograd]

    return backend_index


def get_header_for_aoti(
    func: NativeFunction,
    func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_indices: dict[DispatchKey, BackendIndex],
) -> str | None:
    backend_index = get_backend_index_for_aoti(
        func, func_group_mapping, dispatch_key, backend_indices
    )
    return (
        None
        if backend_index is None
        else f"#include <ATen/ops/{func.root_name}_{backend_index.dispatch_key.lower()}_dispatch.h>"
    )


def get_fallback_op_name(func: NativeFunction) -> str:
    return (
        f"{func.namespace}.{func.func.name.name}.{func.func.name.overload_name}"
        if func.func.name.overload_name
        else f"{func.namespace}.{func.func.name.name}.default"
    )
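
# Illustration only: get_fallback_op_name maps e.g. aten::add.Tensor to "aten.add.Tensor",
# while an overload-less op such as aten::relu maps to "aten.relu.default".
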

def gen_c_shim(
    func: NativeFunction,
    func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_indices: dict[DispatchKey, BackendIndex],
    header: bool,
) -> str | None:
    backend_index = get_backend_index_for_aoti(
        func, func_group_mapping, dispatch_key, backend_indices
    )
    if backend_index is None:
        return None

    schema = func.func
    device = dispatch_key.lower()
    backend_call = gen_static_dispatch_backend_call(
        func,
        backend_index,
    )

    try:
        if header:
            declaration, _ = gen_declaration_and_definition(
                schema, device, backend_call
            )
            return f"AOTI_TORCH_EXPORT {declaration};"
        else:
            _, definition = gen_declaration_and_definition(schema, device, backend_call)
            return definition

    except NotImplementedError:
        return None


@dataclass(frozen=True)
class ShimGenerator:
    func_group_mapping: dict[OperatorName, NativeFunctionsGroup]
    dispatch_key: DispatchKey
    backend_indices: dict[DispatchKey, BackendIndex]
    header: bool  # True to generate .h and False to generate .cpp

    @method_with_native_function
    def __call__(
        self,
        func: NativeFunction,
    ) -> str | None:
        result = gen_c_shim(
            func,
            self.func_group_mapping,
            self.dispatch_key,
            self.backend_indices,
            self.header,
        )
        return result


def gen_aoti_c_shim(
    native_functions: Sequence[NativeFunction],
    func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_indices: dict[DispatchKey, BackendIndex],
    header: bool,
    includes: str = "",
) -> str:
    body = "\n".join(
        list(
            mapMaybe(
                ShimGenerator(
                    func_group_mapping, dispatch_key, backend_indices, header
                ),
                native_functions,
            )
        )
    )
    device = dispatch_key.lower()

    warning = """
// WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND.
// See https://github.com/pytorch/pytorch/blob/7e86a7c0155295539996e0cf422883571126073e/torchgen/gen.py#L2424-L2436 for details"""

    if header:
        return f"""
{warning}

#pragma once

#include <torch/csrc/inductor/aoti_torch/c/shim.h>

#ifdef __cplusplus
extern "C" {{
#endif

{body}

#ifdef __cplusplus
}} // extern "C"
#endif
"""

    else:
        return f"""
{warning}

#include <torch/csrc/inductor/aoti_torch/generated/c_shim_{device}.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/{str(dispatch_key)}Functions.h>
#include <ATen/CompositeExplicitAutogradFunctions.h>
#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
#include <ATen/CompositeImplicitAutogradFunctions.h>
#else
{includes}
#endif

using namespace torch::aot_inductor;

{body}"""