# Generates Python bindings for ATen functions
#
# The bindings are generated as methods on python_variable or functions on the
# torch._C._nn, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._sparse
# or torch._C._special objects.
#

# Code tries to stick to the following rules:
#
# - templates should be colocated with the functions that use them.
#   no templates are currently shared between functions, but if that
#   happens, maybe put the template with the first one
#
# - don't use environment dictionaries when calling template.substitute().
#   pass named arguments directly for everything, otherwise it's much too
#   hard to track what's actually being used and by whom
#
# - colocate any new hacks/adjustments with existing ones of the same kind.
#   ideally in a data structure rather than code if possible. See e.g.
#   SCHEMA_DEFAULT_CONVERSION_HACKS, etc.
#
# - similarly, conversions from one format to another should ideally happen
#   all at once in a single place.
#
# - no nontrivial nested functions. couple-liners are ok but please no more.
#   especially avoid functions that read/write outer variables defined far away.
#
# - raise RuntimeError instead of asserting, and put as much
#   information as is available into the message. I.e.
no need to 30# plumb in new params whose only purpose is to fill out an error 31# message, but use what's there 32# 33 34from __future__ import annotations 35 36import itertools 37import re 38from collections import defaultdict 39from typing import Callable, Iterable, Sequence 40 41import yaml 42 43from torchgen.api import cpp 44from torchgen.api.python import ( 45 arg_parser_output_exprs, 46 cpp_dispatch_exprs, 47 cpp_dispatch_target, 48 dispatch_lambda_args, 49 dispatch_lambda_exprs, 50 dispatch_lambda_return_str, 51 has_tensor_options, 52 PythonSignature, 53 PythonSignatureDeprecated, 54 PythonSignatureGroup, 55 PythonSignatureNativeFunctionPair, 56 signature, 57 signature_from_schema, 58 structseq_fieldnames, 59) 60from torchgen.code_template import CodeTemplate 61from torchgen.context import with_native_function 62from torchgen.gen import cpp_string, parse_native_yaml, parse_tags_yaml 63from torchgen.model import ( 64 Argument, 65 BaseOperatorName, 66 FunctionSchema, 67 NativeFunction, 68 SchemaKind, 69 Type, 70 Variant, 71) 72from torchgen.utils import FileManager, split_name_params 73from torchgen.yaml_utils import YamlLoader 74 75from .gen_inplace_or_view_type import is_tensor_list_type 76from .gen_trace_type import should_trace 77 78 79# 80# declarations blocklist 81# We skip codegen for these functions, for various reasons. 82# Future PRs will categorize this list and eliminate or hoist 83# them out of eager-only codegen. 
# See https://github.com/pytorch/pytorch/issues/30788
#

# These functions require manual Python bindings or are not exposed to Python.
# Every entry is a regular expression; it is anchored with ^...$ when compiled
# into SKIP_PYTHON_BINDINGS below and matched against the operator's cpp name.
_SKIP_PYTHON_BINDINGS = [
    "alias",
    "contiguous",
    "is_cuda",
    "is_sparse",
    "is_sparse_csr",
    "size",
    "stride",
    "sym_size",
    "sym_stride",
    "sym_storage_offset",
    "sym_numel",
    ".*_backward",
    ".*_backward_(out|input|weight|bias)",
    ".*_forward",
    ".*_forward_out",
    ".*_jvp",
    "_unsafe_view",
    "tensor",
    "_?sparse_(coo|compressed|csr|csc|bsr|bsc)_tensor.*",
    "_range.*",
    "_sparse_add_out",
    "_sparse_div.*",
    "_sparse_mul.*",
    "_sparse_sub.*",
    "_sparse_dense_add_out",
    "index",
    "index_out",
    "unique_dim_consecutive",
    "_cumsum.*",
    "_cumprod.*",
    "_sum.*",
    "_prod.*",
    "_th_.*",
    "_thnn_.*",
    "range.*",
    "_solve.*",
    "_inverse.*",
    "_cholesky.*",
    "_triangular_solve.*",
    "_qr.*",
    "_svd.*",
    "slice",
    "item",
    "_local_scalar_dense",
    "to",
    "_to_copy",
    "_to_copy_out",
    "_reshape_copy",
    "_reshape_copy_out",
    "copy_sparse_to_sparse_",
    "copy_",
    "_foreach_copy",
    "numpy_T",
    "matrix_H",
    "mT",
    "mH",  # these need to be attributes in Python, not functions
    "nonzero(_(out|numpy))?",
    "set_data",
    ".*_overrideable",  # overrideable functions for backend extension
    "data",
    "is_leaf",
    "output_nr",
    "_version",
    "requires_grad_",
    "retains_grad",
    "set_",
    "_fw_primal",
    "fake_quantize_per_tensor_affine_cachemask",
    "fake_quantize_per_channel_affine_cachemask",
    "_new_zeros_with_same_feature_meta",
    "_has_same_storage_numel",  # used for forward AD internals
    "_reshape_alias",
    "replace_",  # only used by the functionalization pass, doesn't need to be exposed to python
    "copy",  # only used by the functionalization pass
    "fill.Tensor",  # only used by the functionalization pass
    "fill.Scalar",  # only used by the functionalization pass
    "lift.*",
    "normal_functional",  # only used by the functionalization pass
    "nbytes",
    "itemsize",
    "_batch_norm_with_update",
    "_batch_norm_with_update_out",
    "_batch_norm_no_update",
]

SKIP_PYTHON_BINDINGS = [
    re.compile(rf"^{pattern}$") for pattern in _SKIP_PYTHON_BINDINGS
]

# These function signatures are not exposed to Python. Note that this signature
# list does not support regex: entries are compared for exact string equality
# with str(FunctionSchema).
SKIP_PYTHON_BINDINGS_SIGNATURES = [
    "add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor",
    "add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)",
    "sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor",
    "sub_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)",
    "mul.Scalar(Tensor self, Scalar other) -> Tensor",
    "mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)",
    "div.Scalar(Tensor self, Scalar other) -> Tensor",
    "div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)",
]


@with_native_function
def should_generate_py_binding(f: NativeFunction) -> bool:
    """Return True if ``f`` should get a generated Python binding.

    ``f`` is skipped when it is fully code-generated (unless tagged
    ``view_copy``), when its cpp name matches any regex in
    ``SKIP_PYTHON_BINDINGS``, or when its schema string is listed verbatim in
    ``SKIP_PYTHON_BINDINGS_SIGNATURES``.
    """
    # NativeFunctions that are entirely code-generated should not get python bindings
    # because these codegen implementations are often inefficient. A handful of
    # view_copy style ops were exposed accidentally when they were handwritten and now
    # that we are moving them to codegen for bc reasons we need to keep them exposed in
    # python.
    if "generated" in f.tags and "view_copy" not in f.tags:
        return False

    name = cpp.name(f.func)
    if any(skip_regex.match(name) for skip_regex in SKIP_PYTHON_BINDINGS):
        return False

    # Named schema_str (not `signature`) so it does not shadow the imported
    # torchgen.api.python.signature helper.
    schema_str = str(f.func)
    return schema_str not in SKIP_PYTHON_BINDINGS_SIGNATURES


def get_pycname(name: BaseOperatorName) -> str:
    """Return the C++ identifier of the generated binding function for ``name``."""
    return f"THPVariable_{name}"


def is_noarg(overloads: Sequence[PythonSignatureNativeFunctionPair]) -> bool:
    """True when there is exactly one overload and it takes no arguments."""
    return len(overloads) == 1 and overloads[0].signature.arguments_count() == 0


def is_py_variable_method(f: NativeFunction) -> bool:
    """True if ``f`` is bound as a Tensor method (no python_module, method variant)."""
    return f.python_module is None and Variant.method in f.variants


def is_py_torch_function(f: NativeFunction) -> bool:
    """True if ``f`` is bound as a ``torch.*`` function (no python_module, function variant)."""
    return f.python_module is None and Variant.function in f.variants


def is_py_nn_function(f: NativeFunction) -> bool:
    """True if ``f`` belongs to ``torch.nn``."""
    return f.python_module == "nn"


def is_py_fft_function(f: NativeFunction) -> bool:
    """True if ``f`` belongs to ``torch.fft``."""
    return f.python_module == "fft"


def is_py_linalg_function(f: NativeFunction) -> bool:
    """True if ``f`` belongs to ``torch.linalg``."""
    return f.python_module == "linalg"


def is_py_nested_function(f: NativeFunction) -> bool:
    """True if ``f`` belongs to ``torch.nested``."""
    return f.python_module == "nested"


def is_py_sparse_function(f: NativeFunction) -> bool:
    """True if ``f`` belongs to ``torch.sparse``."""
    return f.python_module == "sparse"


def is_py_special_function(f: NativeFunction) -> bool:
    """True if ``f`` belongs to ``torch.special``."""
    return f.python_module == "special"


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                         Main Function
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def gen(
    out: str,
    native_yaml_path: str,
    tags_yaml_path: str,
    deprecated_yaml_path: str,
    template_path: str,
    *,
    symint: bool = True,
) -> None:
    """Generate every Python-binding source file for ATen operators.

    Writes into ``out``:
      * ``python_variable_methods.cpp`` (Tensor methods),
      * ``python_torch_functions.cpp`` in 3 shards (``torch.*`` functions),
      * one file each for ``torch.nn``, ``torch.fft``, ``torch.linalg``,
        ``torch.nested``, ``torch.sparse`` and ``torch.special``,
      * ``python_return_types.cpp`` / ``python_return_types.h``,
      * ``python_enum_tag.cpp`` (values of the ``at::Tag`` enum).

    Args:
        out: output directory.
        native_yaml_path: path to native_functions.yaml.
        tags_yaml_path: path to tags.yaml.
        deprecated_yaml_path: path to the deprecated-signature yaml.
        template_path: directory containing the .cpp/.h templates.
        symint: forwarded to every binding generator (and from there to
            ``signature_str`` / ``group_overloads``).
    """
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    native_functions = parse_native_yaml(
        native_yaml_path, tags_yaml_path
    ).native_functions
    native_functions = list(filter(should_generate_py_binding, native_functions))

    methods = load_signatures(native_functions, deprecated_yaml_path, method=True)
    create_python_bindings(
        fm,
        methods,
        is_py_variable_method,
        None,
        "python_variable_methods.cpp",
        method=True,
        symint=symint,
    )

    # NOTE: num_shards here must be synced with gatherTorchFunctions in
    #       torch/csrc/autograd/python_torch_functions_manual.cpp
    functions = load_signatures(native_functions, deprecated_yaml_path, method=False)
    create_python_bindings_sharded(
        fm,
        functions,
        is_py_torch_function,
        "torch",
        "python_torch_functions.cpp",
        method=False,
        num_shards=3,
        symint=symint,
    )

    create_python_bindings(
        fm,
        functions,
        is_py_nn_function,
        "torch.nn",
        "python_nn_functions.cpp",
        method=False,
        symint=symint,
    )

    create_python_bindings(
        fm,
        functions,
        is_py_fft_function,
        "torch.fft",
        "python_fft_functions.cpp",
        method=False,
        symint=symint,
    )

    create_python_bindings(
        fm,
        functions,
        is_py_linalg_function,
        "torch.linalg",
        "python_linalg_functions.cpp",
        method=False,
        symint=symint,
    )

    # symint is forwarded here too, consistent with every other module; it was
    # previously omitted, so torch.nested silently ignored symint=False.
    create_python_bindings(
        fm,
        functions,
        is_py_nested_function,
        "torch.nested",
        "python_nested_functions.cpp",
        method=False,
        symint=symint,
    )

    create_python_bindings(
        fm,
        functions,
        is_py_sparse_function,
        "torch.sparse",
        "python_sparse_functions.cpp",
        method=False,
        symint=symint,
    )

    create_python_bindings(
        fm,
        functions,
        is_py_special_function,
        "torch.special",
        "python_special_functions.cpp",
        method=False,
        symint=symint,
    )

    # Currently, we only use `functions` to generate `return_types` bindings.
    # All methods which return structseq have function variant at this point.
    # If any method only operator with structseq is added in the future,
    # we will have to address that.
    create_python_return_type_bindings(
        fm, functions, lambda fn: True, "python_return_types.cpp"
    )
    create_python_return_type_bindings_header(
        fm, functions, lambda fn: True, "python_return_types.h"
    )

    valid_tags = parse_tags_yaml(tags_yaml_path)

    def gen_tags_enum() -> dict[str, str]:
        return {
            "enum_of_valid_tags": (
                "".join(
                    [f'\n.value("{tag}", at::Tag::{tag})' for tag in sorted(valid_tags)]
                )
            )
        }

    fm.write("python_enum_tag.cpp", gen_tags_enum)


def group_filter_overloads(
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
) -> dict[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]:
    """Group the pairs whose function satisfies ``pred`` by base operator name.

    Within each group the input order of ``pairs`` is preserved.
    """
    grouped: dict[
        BaseOperatorName, list[PythonSignatureNativeFunctionPair]
    ] = defaultdict(list)
    for pair in pairs:
        if pred(pair.function):
            grouped[pair.function.func.name.name].append(pair)
    return grouped


def create_python_bindings(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    module: str | None,
    filename: str,
    *,
    method: bool,
    symint: bool = True,
) -> None:
    """Generates Python bindings to ATen functions.

    Every operator selected by ``pred`` contributes one binding function, one
    PyMethodDef entry, its forward declarations and an ATen ops header include.
    Operators are emitted in ``str`` order of their name, and the result is
    written to ``filename`` using the template of the same name.
    """
    py_methods: list[str] = []
    ops_headers: list[str] = []
    py_method_defs: list[str] = []
    py_forwards: list[str] = []

    grouped = group_filter_overloads(pairs, pred)

    for name in sorted(grouped.keys(), key=str):
        overloads = grouped[name]
        py_methods.append(
            method_impl(name, module, overloads, method=method, symint=symint)
        )
        py_method_defs.append(method_def(name, module, overloads, method=method))
        py_forwards.extend(forward_decls(name, overloads, method=method))
        ops_headers.append(f"#include <ATen/ops/{name.base}.h>")

    fm.write_with_template(
        filename,
        filename,
        lambda: {
            "generated_comment": "@"
            + f"generated from {fm.template_dir_for_comments()}/{filename}",
            "ops_headers": ops_headers,
            "py_forwards": py_forwards,
            "py_methods": py_methods,
            "py_method_defs": py_method_defs,
        },
    )


def create_python_return_type_bindings(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    filename: str,
) -> None:
    """
    Generate function to initialize and return named tuple for native functions
    which returns named tuple and registration invocations in `python_return_types.cpp`.
    """
    py_return_types_definition: list[str] = []
    py_return_types_registrations: list[str] = []

    grouped = group_filter_overloads(pairs, pred)

    for name in sorted(grouped.keys(), key=str):
        overloads = grouped[name]
        definitions, registrations = generate_return_type_definition_and_registrations(
            overloads
        )
        # "\n".join of an empty list is already "", so no special case needed.
        py_return_types_definition.append("\n".join(definitions))
        py_return_types_registrations.append("\n".join(registrations))

    fm.write_with_template(
        filename,
        filename,
        lambda: {
            "generated_comment": "@"
            + f"generated from {fm.template_dir_for_comments()}/{filename}",
            "py_return_types": py_return_types_definition,
            "py_return_types_registrations": py_return_types_registrations,
        },
    )


def create_python_return_type_bindings_header(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    filename: str,
) -> None:
    """
    Generate function to initialize and return named tuple for native functions
    which returns named tuple and relevant entry for the map in `python_return_types.cpp`.
    """
    py_return_types_declarations: list[str] = []

    grouped = group_filter_overloads(pairs, pred)

    for name in sorted(grouped.keys(), key=str):
        overloads = grouped[name]
        declarations = generate_return_type_declarations(overloads)
        # "\n".join of an empty list is already "", so no special case needed.
        py_return_types_declarations.append("\n".join(declarations))

    fm.write_with_template(
        filename,
        filename,
        lambda: {
            "generated_comment": "@"
            + f"generated from {fm.template_dir_for_comments()}/{filename}",
            "py_return_types_declarations": py_return_types_declarations,
        },
    )


def create_python_bindings_sharded(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    module: str | None,
    filename: str,
    *,
    method: bool,
    num_shards: int,
    symint: bool = True,
) -> None:
    """Generates Python bindings to ATen functions, split over ``num_shards`` files.

    Same per-operator content as ``create_python_bindings``; operators are
    distributed across shards by ``FileManager.write_sharded`` keyed on the
    operator's base name.
    """
    grouped = group_filter_overloads(pairs, pred)

    def key_func(
        kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]
    ) -> str:
        return kv[0].base

    def env_func(
        kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]
    ) -> dict[str, list[str]]:
        name, fn_pairs = kv
        return {
            "ops_headers": [f"#include <ATen/ops/{name.base}.h>"],
            "py_forwards": list(forward_decls(name, fn_pairs, method=method)),
            "py_methods": [
                method_impl(name, module, fn_pairs, method=method, symint=symint)
            ],
            "py_method_defs": [method_def(name, module, fn_pairs, method=method)],
        }

    fm.write_sharded(
        filename,
        grouped.items(),
        base_env={
            "generated_comment": "@"
            + f"generated from {fm.template_dir_for_comments()}/{filename}",
        },
        key_fn=key_func,
        env_callable=env_func,
        num_shards=num_shards,
        sharded_keys={"ops_headers", "py_forwards", "py_methods", "py_method_defs"},
    )


def load_signatures(
    native_functions: list[NativeFunction],
    deprecated_yaml_path: str,
    *,
    method: bool,
    skip_deprecated: bool = False,
    pyi: bool = False,
) -> Sequence[PythonSignatureNativeFunctionPair]:
    """Pair every native function with its Python signature.

    The deprecated signatures from ``deprecated_yaml_path`` are always loaded
    (and validated), and appended to the result unless ``skip_deprecated``.
    """

    @with_native_function
    def gen_signature_pairs(f: NativeFunction) -> PythonSignatureNativeFunctionPair:
        return PythonSignatureNativeFunctionPair(
            signature=signature(f, method=method, pyi=pyi),
            function=f,
        )

    pairs = list(map(gen_signature_pairs, native_functions))
    deprecated = load_deprecated_signatures(
        pairs, deprecated_yaml_path, method=method, pyi=pyi
    )
    return pairs if skip_deprecated else pairs + deprecated


def _is_deprecated_schema_compatible(
    deprecated_schema: FunctionSchema,
    aten_schema: FunctionSchema,
    *,
    call_args: Sequence[str],
    is_out: bool,
    known_constants: dict[str, Type],
    schema_args_by_name: dict[str, Argument],
) -> bool:
    """Return True if the deprecated call ``aten(call_args...)`` fits ``aten_schema``.

    Argument ``i`` of ``aten_schema`` (its out arguments first when ``is_out``)
    is paired with ``call_args[i]``; the type and alias annotation of the
    deprecated-schema argument of that name (or of the known constant) must
    equal the aten argument's. Aten arguments past ``len(call_args)`` must have
    a default. Finally, both schemas must have identical return lists.
    """
    arguments: Iterable[Argument]
    if is_out:
        # For out variants the call lists the out arguments first.
        arguments = itertools.chain(
            aten_schema.arguments.out, aten_schema.arguments.flat_non_out
        )
    else:
        arguments = aten_schema.arguments.flat_all

    for i, arg in enumerate(arguments):
        if i >= len(call_args):
            # Not supplied by the deprecated call: only OK if aten has a default.
            if arg.default is None:
                return False
            continue

        arg_name = call_args[i]
        if arg_name in known_constants:
            schema_type = known_constants[arg_name]
            schema_annotation = None
        else:
            schema_arg = schema_args_by_name[arg_name]
            schema_type = schema_arg.type
            schema_annotation = schema_arg.annotation

        if schema_type != arg.type or schema_annotation != arg.annotation:
            return False

    return len(deprecated_schema.returns) == len(aten_schema.returns) and all(
        a == b for a, b in zip(deprecated_schema.returns, aten_schema.returns)
    )


def load_deprecated_signatures(
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    deprecated_yaml_path: str,
    *,
    method: bool,
    pyi: bool,
) -> list[PythonSignatureNativeFunctionPair]:
    """Build signature pairs for the entries of the deprecated-signature yaml.

    Each entry has a ``name`` (the deprecated schema) and an ``aten`` call
    expression. The deprecated entry is paired with every original ATen
    function (from ``pairs``) whose schema is compatible with that call.

    Raises:
        RuntimeError: if a call argument is neither a deprecated-schema
            argument nor a known constant, or if no native function matches
            a deprecated entry.
    """
    # The deprecated.yaml doesn't have complete type information, we need
    # find and leverage the original ATen signature (to which it delegates
    # the call) to generate the full python signature.
    # We join the deprecated and the original signatures using type-only form.

    # group the original ATen signatures by name
    grouped: dict[str, list[PythonSignatureNativeFunctionPair]] = defaultdict(list)
    for pair in pairs:
        grouped[pair.signature.name].append(pair)

    # find matching original signatures for each deprecated signature
    results: list[PythonSignatureNativeFunctionPair] = []

    with open(deprecated_yaml_path) as f:
        deprecated_defs = yaml.load(f, Loader=YamlLoader)

    # HACK: these are fixed constants used to pass the aten function.
    # The type must be known ahead of time
    known_constants = {
        "1": Type.parse("Scalar"),
    }

    for deprecated in deprecated_defs:
        schema = FunctionSchema.parse(deprecated["name"])
        aten_name, call_args = split_name_params(deprecated["aten"])
        is_out = aten_name.endswith("_out")
        if is_out:
            # Strip only the trailing suffix; str.replace would also remove an
            # "_out" occurring inside the base name.
            aten_name = aten_name[: -len("_out")]

        schema_args_by_name = {a.name: a for a in schema.arguments.flat_all}
        for name in call_args:
            if name not in schema_args_by_name and name not in known_constants:
                raise RuntimeError(
                    f"deprecation definition: Unrecognized value {name} in call "
                    f"to {aten_name} from deprecated schema:\n  {schema}"
                )

        # Map deprecated signature arguments to their aten signature and test
        # if the types and alias annotation match.
        any_schema_found = False
        for pair in grouped[aten_name]:
            if not _is_deprecated_schema_compatible(
                schema,
                pair.function.func,
                call_args=call_args,
                is_out=is_out,
                known_constants=known_constants,
                schema_args_by_name=schema_args_by_name,
            ):
                continue
            any_schema_found = True

            python_sig = signature_from_schema(
                schema,
                category_override=pair.function.category_override,
                method=method,
                pyi=pyi,
            )

            results.append(
                PythonSignatureNativeFunctionPair(
                    signature=PythonSignatureDeprecated(
                        name=python_sig.name,
                        input_args=python_sig.input_args,
                        input_kwargs=python_sig.input_kwargs,
                        output_args=python_sig.output_args,
                        tensor_options_args=python_sig.tensor_options_args,
                        method=python_sig.method,
                        deprecated_schema=schema,
                        deprecated_args_exprs=tuple(call_args),
                        returns=python_sig.returns,
                    ),
                    function=pair.function,
                )
            )

        if not any_schema_found:
            raise RuntimeError(
                f"No native function with name {aten_name} matched signature:\n  {schema}"
            )

    return results


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                         Named Tuple Codegen
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


@with_native_function
def gen_structseq_typename_key(f: NativeFunction) -> str:
    """Key identifying a structseq type: cpp name followed by its field names."""
    name = cpp.name(f.func)
    fieldnames = structseq_fieldnames(f.func.returns)
    return "_".join([name] + fieldnames)


def emit_structseq_call(
    overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> tuple[list[str], dict[str, str]]:
    """
    Generate block of named tuple type def inits, and add typeref snippets
    to declarations that use them.

    Returns the C++ typedef lines and a map from typename key
    (``gen_structseq_typename_key``) to the local C++ variable name.
    """
    typenames: dict[
        str, str
    ] = {}  # map from unique name + field name lists to typedef name
    typedefs: list[str] = []  # typedef declarations and init code

    for overload in overloads:
        fieldnames = structseq_fieldnames(overload.function.func.returns)
        if not fieldnames:
            continue

        name = cpp.name(overload.function.func)  # use @with_native_function?
        tn_key = gen_structseq_typename_key(overload.function)
        typename = typenames.get(tn_key)
        if typename is None:
            typename = f'NamedTuple{"" if not typedefs else len(typedefs)}'
            typenames[tn_key] = typename
            typedefs.append(
                f"""\
static PyTypeObject* {typename} = generated::get_{name}_structseq();"""
            )

    return typedefs, typenames


def generate_return_type_definition_and_registrations(
    overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> tuple[list[str], list[str]]:
    """
    Generate block of function in `python_return_types.cpp` to initialize
    and return named tuple for a native function which returns named tuple
    and registration invocations in same file.
    """
    typenames: dict[
        str, str
    ] = {}  # map from unique name + field name lists to typedef name
    definitions: list[str] = []  # function definition to register the typedef
    registrations: list[str] = []  # register call for the typedef

    for overload in overloads:
        fieldnames = structseq_fieldnames(overload.function.func.returns)
        if not fieldnames:
            continue

        fields = ", ".join(f'{{"{fn}", ""}}' for fn in fieldnames)

        name = cpp.name(overload.function.func)  # use @with_native_function?
        tn_key = gen_structseq_typename_key(overload.function)
        typename = typenames.get(tn_key)

        if typename is None:
            typename = f'{name}NamedTuple{"" if not definitions else len(definitions)}'
            typenames[tn_key] = typename
            definitions.append(
                f"""\
PyTypeObject* get_{name}_structseq() {{
    static PyStructSequence_Field NamedTuple_fields[] = {{ {fields},  {{nullptr}} }};
    static PyTypeObject {typename};
    static bool is_initialized = false;
    static PyStructSequence_Desc desc = {{ "torch.return_types.{name}", nullptr, NamedTuple_fields, {len(fieldnames)} }};
    if (!is_initialized) {{
        PyStructSequence_InitType(&{typename}, &desc);
        {typename}.tp_repr = (reprfunc)torch::utils::returned_structseq_repr;
        is_initialized = true;
    }}
    return &{typename};
}}
"""
            )
            registrations.append(
                f'addReturnType(return_types_module, "{name}", generated::get_{name}_structseq());'
            )

    return definitions, registrations


def generate_return_type_declarations(
    overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> list[str]:
    """
    Generate block of function declarations in `python_return_types.h` to initialize
    and return named tuple for a native function.

    One declaration is emitted per distinct typename key
    (``gen_structseq_typename_key``); overloads without named returns are skipped.
    """
    seen_keys: set[str] = set()  # unique name + field name lists already declared
    declarations: list[str] = []  # function declaration to register the typedef

    for overload in overloads:
        fieldnames = structseq_fieldnames(overload.function.func.returns)
        if not fieldnames:
            continue

        name = cpp.name(overload.function.func)  # use @with_native_function?
        tn_key = gen_structseq_typename_key(overload.function)
        if tn_key in seen_keys:
            continue
        seen_keys.add(tn_key)
        declarations.append(f"PyTypeObject* get_{name}_structseq();")

    return declarations


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                         Method Impl Codegen
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

# python binding for all overloads of a particular function/method
PY_VARIABLE_METHOD_VARARGS = CodeTemplate(
    r"""\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  ${method_header}
  static PythonArgParser parser({
    ${signatures}
  }, /*traceable=*/${traceable});

  ParsedArgs<${max_args}> parsed_args;
  auto _r = parser.parse(${self_}, args, kwargs, parsed_args);
  ${check_has_torch_function}
  switch (_r.idx) {
    ${dispatch}
  }
  ${method_footer}
}

"""
)

# handler for a single parsed signature - may be a single overload or
# a pair of overloads whose signatures only differ in output params
# (plugged into PY_VARIABLE_METHOD_VARARGS as an item in ${dispatch})
PY_VARIABLE_CASE = CodeTemplate(
    """\
case ${overload_index}: {
  ${body}
}
"""
)

# python binding for single-overload function/method
PY_VARIABLE_METHOD_VARARGS_SINGLETON = CodeTemplate(
    """\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  ${method_header}
  static PythonArgParser parser({
    ${signatures}
  }, /*traceable=*/${traceable});

  ParsedArgs<${max_args}> parsed_args;
  auto _r = parser.parse(${self_}, args, kwargs, parsed_args);
  ${check_has_torch_function}
  ${dispatch}
  ${method_footer}
}

"""
)
# python binding for a method with no args, shortcuts parsing
PY_VARIABLE_METHOD_NOARGS = CodeTemplate(
    """\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args)
{
  ${method_header}
  ${check_has_torch_function}
  ${dispatch}
  ${method_footer}
}

"""
)


def method_impl(
    name: BaseOperatorName,
    module: str | None,
    overloads: Sequence[PythonSignatureNativeFunctionPair],
    *,
    method: bool,
    symint: bool = True,
) -> str:
    """
    Build the C++ binding function that serves every overload of one op.

    Overloads are grouped into parser signatures by ``group_overloads``. One of
    three templates is chosen: the no-arg template when the only overload
    takes no arguments, the singleton template for a single parser signature,
    and otherwise the ``switch``-based varargs template with one ``case`` per
    signature.
    """
    takes_no_args = is_noarg(overloads)
    structseq_inits, structseq_typenames = emit_structseq_call(overloads)

    header_lines = ["HANDLE_TH_ERRORS", *structseq_inits]
    if method:
        # Methods receive the tensor as `self_`; expose it as a Tensor ref.
        header_lines.append("const Tensor& self = THPVariable_Unpack(self_);")

    footer_lines = ["END_HANDLE_TH_ERRORS"]
    if not takes_no_args:
        footer_lines.insert(0, "Py_RETURN_NONE;")

    all_traceable = all(should_trace(pair.function) for pair in overloads)

    signature_groups: Sequence[PythonSignatureGroup] = group_overloads(
        overloads, symint=symint
    )
    single_group = len(signature_groups) == 1

    parser_signatures: list[str] = []
    dispatch_cases: list[str] = []
    for case_idx, group in enumerate(signature_groups):
        sig_text = group.signature.signature_str(symint=symint)
        parser_signatures.append(f"{cpp_string(str(sig_text))},")
        case_body = emit_dispatch_case(group, structseq_typenames, symint=symint)
        if single_group:
            # No switch in the singleton template: emit the body directly.
            dispatch_cases.append(case_body)
        else:
            dispatch_cases.append(
                PY_VARIABLE_CASE.substitute(overload_index=case_idx, body=case_body)
            )

    if takes_no_args:
        chosen_template = PY_VARIABLE_METHOD_NOARGS
    elif single_group:
        chosen_template = PY_VARIABLE_METHOD_VARARGS_SINGLETON
    else:
        chosen_template = PY_VARIABLE_METHOD_VARARGS

    largest_arity = max(pair.signature.arguments_count() for pair in overloads)
    torch_function_check = gen_has_torch_function_check(
        name=name,
        module=module,
        noarg=takes_no_args,
        method=method,
    )

    return chosen_template.substitute(
        name=name,
        pycname=get_pycname(name),
        method_header=header_lines,
        max_args=largest_arity,
        signatures=parser_signatures,
        traceable="true" if all_traceable else "false",
        check_has_torch_function=torch_function_check,
        dispatch=dispatch_cases,
        method_footer=footer_lines,
        self_="self_" if method else "nullptr",
    )


def gen_has_torch_function_check(
    name: BaseOperatorName, module: str | None, *, noarg: bool, method: bool
) -> str:
    """
    Emit the C++ snippet that defers to ``__torch_function__`` overrides.

    No-arg methods check ``self_`` directly; no-arg non-methods get no check.
    Otherwise the parsed result ``_r`` is checked, and the handler receives the
    Python module object for ``module`` (``THPVariableClass`` when ``module``
    is None). An unknown ``module`` raises KeyError.
    """
    if noarg:
        if not method:
            return ""
        return f"""\
if(check_has_torch_function(self_)) {{
  return handle_torch_function(self_, "{name}");
}}
"""

    namespace_by_module: dict[str, str] = {
        "torch": "THPVariableFunctionsModule",
        "torch.nn": "THPNNVariableFunctionsModule",
        "torch.fft": "THPFFTVariableFunctionsModule",
        "torch.linalg": "THPLinalgVariableFunctionsModule",
        "torch.nested": "THPNestedVariableFunctionsModule",
        "torch.sparse": "THPSparseVariableFunctionsModule",
        "torch.special": "THPSpecialVariableFunctionsModule",
    }
    namespace = namespace_by_module[module] if module else "THPVariableClass"
    self_arg = "self_" if method else "nullptr"
    public_module = module or "torch.Tensor"

    return f"""\
if(_r.has_torch_function()) {{
  return handle_torch_function(_r, {self_arg}, args, kwargs, {namespace}, "{public_module}");
}}
"""


# handler for output/no-output overload pair
PY_VARIABLE_OUT = CodeTemplate(
    """\
if (_r.isNone(${out_idx})) {
  ${call_dispatch}
} else {
  ${call_dispatch_out}
}
"""
)


def emit_dispatch_case(
    overload: PythonSignatureGroup,
    structseq_typenames: dict[str, str],
    *,
    symint: bool = True,
) -> str:
    """
    Emit the dispatch code for one parsed Python signature.

    A group with only a functional variant dispatches straight to it. A group
    that also has an out= variant shares one Python signature for both and
    branches at runtime on whether the output argument was passed.
    """
    if overload.outplace is None:
        # functional (no-output) variant only
        return emit_single_dispatch(
            overload.signature, overload.base, structseq_typenames, symint=symint
        )

    # both variants: select on _r.isNone(<out_idx>)
    out_idx = overload.signature.output_idx()
    functional_call = emit_single_dispatch(
        overload.signature, overload.base, structseq_typenames, symint=symint
    )
    out_call = emit_single_dispatch(
        overload.signature,
        overload.outplace,
        structseq_typenames,
        symint=symint,
    )
    return PY_VARIABLE_OUT.substitute(
        out_idx=out_idx,
        call_dispatch=functional_call,
        call_dispatch_out=out_call,
    )


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                    Forward Declarations Codegen
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def forward_decls(
    name: BaseOperatorName,
    overloads: Sequence[PythonSignatureNativeFunctionPair],
    *,
    method: bool,
) -> tuple[str, ...]:
    """
    Return the C++ forward declaration(s) for an op's binding function.

    Methods need none (empty tuple). Otherwise exactly one declaration is
    returned; its parameter list omits ``kwargs`` for no-arg functions.
    """
    if method:
        return ()

    if is_noarg(overloads):
        params = "PyObject* self_, PyObject* args"
    else:
        params = "PyObject* self_, PyObject* args, PyObject* kwargs"
    return (f"static PyObject * {get_pycname(name)}({params});\n",)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#              Method Def (Binding Table Entry) Codegen
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def method_def(
    name: BaseOperatorName,
    module: str | None,
    overloads: Sequence[PythonSignatureNativeFunctionPair],
    *,
    method: bool,
) -> str:
    """
    Build the PyMethodDef table entry for an op's binding function.

    Dunder ops are wrapped so a TypeError becomes NotImplemented. Anything
    other than a no-arg method is registered with varargs+keywords (and cast
    accordingly when it actually takes arguments). Entries in the ``torch``
    module are additionally marked static.
    """
    entry_fn = get_pycname(name)

    if name.dunder_method:
        # PyMethodDef entry for binary op, throws not implemented error
        entry_fn = f"TypeError_to_NotImplemented_<{entry_fn}>"

    if not is_noarg(overloads):
        entry_fn = f"castPyCFunctionWithKeywords({entry_fn})"
        flags = "METH_VARARGS | METH_KEYWORDS"
    elif method:
        flags = "METH_NOARGS"
    else:
        flags = "METH_VARARGS | METH_KEYWORDS"

    if module == "torch":
        flags = f"{flags} | METH_STATIC"

    return f'{{"{name}", {entry_fn}, {flags}, NULL}},'


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                   Overload Sorting and Grouping
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def group_overloads(
    overloads: Sequence[PythonSignatureNativeFunctionPair], *, symint: bool = True
) -> Sequence[PythonSignatureGroup]:
    """
    Pair every out= overload with the functional overload sharing its
    signature (outputs ignored), then return the groups in parse order.

    Raises RuntimeError on two overloads of the same kind with the same
    signature, or on an out= overload that has no functional counterpart.
    """
    functional_by_sig: dict[str, PythonSignatureNativeFunctionPair] = {}
    out_by_sig: dict[str, PythonSignatureNativeFunctionPair] = {}

    # first group by signature ignoring out arguments
    for pair in overloads:
        key = pair.signature.signature_str(skip_outputs=True, symint=symint)
        bucket = out_by_sig if pair.function.func.is_out_fn() else functional_by_sig
        existing = bucket.get(key)
        if existing is not None:
            raise RuntimeError(
                f"Found duplicated function definition:\n- {pair.function.func}.\n"
                f"Existing definition:\n- {existing.function.func}."
            )
        bucket[key] = pair

    # every out= variant needs a matching functional variant
    for key, out_pair in out_by_sig.items():
        if key in functional_by_sig:
            continue
        out_op_name = str(out_pair.function.func.name.name)
        candidates = [
            pair.signature.signature_str(skip_outputs=True, symint=symint)
            for pair in overloads
            if str(pair.function.func.name.name) == out_op_name
            and not pair.function.func.is_out_fn()
            and not pair.signature.deprecated
        ]
        out_sig = out_pair.signature.signature_str(symint=symint)
        raise RuntimeError(
            f"While identifying overloads, we found an out schema {out_sig} without a corresponding non-out variant. "
            f"We expected the non-out variant to have schema: \n- {key}\nPlease check that you spelled the schema "
            "correctly in native_functions.yaml. We discovered the following candidate(s): \n"
            + "\n".join(f"- {candidate}" for candidate in candidates)
        )

    grouped = [
        PythonSignatureGroup.from_pairs(functional=base, out=out_by_sig.get(key))
        for key, base in functional_by_sig.items()
    ]
    return sort_overloads(grouped, symint=symint)


# This function declares a partial order on declarations, and sorts them according
# to its linear extension. This is necessary, because there's some ambiguity in the
# choice of overload, and we want a different order.
#
# See Note[Order of overloads matters]
#
# A few examples of ambiguous python signature pairs.
#
#   All parameters have the same type, except one taking Tensor the other taking
#   Scalar. A numeric PyObject can be casted into Tensor, and a zero-dim Tensor
#   object can be accepted as Scalar type parameter (see python_arg_parser.cpp).
#   Therefore, same input arguments might be accepted by either python signature.
#   We want to always parse the one taking Tensor first.
1185# 1186# bitwise_and(Tensor input, Tensor other, *, Tensor out=None) 1187# bitwise_and(Tensor input, Scalar other, *, Tensor out=None) 1188# 1189# If they have different number of parameters then they are not ambiguous - but 1190# the difference on output param can be ignored as it's optional. 1191# 1192# multiply(Tensor input, Tensor other, *, Tensor out=None) 1193# multiply(Tensor input, Scalar other) 1194# 1195# Both positional args and keyword-only args are considered together. 1196# 1197# subtract(Tensor other, *, Scalar alpha=1) 1198# subtract(Scalar other, Scalar alpha=1) 1199# 1200# A few ambiguous cases which it does NOT handle yet. 1201# 1202# If there is any difference in other parameters besides the Tensor/Scalar 1203# difference, then they are not considered ambiguous by this method anymore. 1204# However, the difference could be too trivial to disambiguate. 1205# 1206# foo(Tensor input, Scalar other, Scalar bar) 1207# foo(Tensor input, Tensor other, double bar) 1208# 1209# If they are taking different number of parameters then they are not considered 1210# ambiguous anymore, even if the difference is only on optional kwargs. 1211# 1212# foo(Scalar other, Scalar alpha=1) 1213# foo(Tensor other, *, Scalar alpha=1, Scalar beta=1) 1214# 1215 1216 1217def sort_overloads( 1218 grouped_overloads: Sequence[PythonSignatureGroup], *, symint: bool = True 1219) -> Sequence[PythonSignatureGroup]: 1220 # NB: Smaller here means lower priority 1221 1222 def is_arg_smaller(t1: Type, t2: Type) -> bool: 1223 return ( 1224 str(t1) == "Scalar" 1225 and str(t2) == "Tensor" 1226 or str(t1) == "Scalar?" 1227 and str(t2) == "Tensor?" 1228 or "Dimname" in str(t1) 1229 and "Dimname" not in str(t2) 1230 or 1231 # In the discussion https://github.com/pytorch/pytorch/issues/54555 it has been 1232 # discussed why it is important to prioritize int/int? 
over int[] 1233 str(t1) == "int[]" 1234 and (str(t2) == "int" or str(t2) == "int?") 1235 or 1236 # TensorList currently throws an error during argument parsing, that's why it needs to be 1237 # last in signature ordering. See discussion: https://github.com/pytorch/pytorch/issues/58087 1238 str(t1) == "Tensor[]" 1239 and str(t2).find("[]") != -1 1240 or 1241 # Prioritize IntArrayRef overload over SymIntArrayRef 1242 str(t1) == "SymInt[]" 1243 and str(t2) == "int[]" 1244 or 1245 # Make sure both in, SymInt are sorted consistently w.r.t. Tensor since Tensor can be implicitly 1246 # converted to either int or SymInt. Prioritize the Tensor overload since it otherwise gets shadowed. 1247 (str(t1) == "SymInt" or str(t1) == "int") 1248 and str(t2) == "Tensor" 1249 ) 1250 1251 def is_smaller(s1: PythonSignature, s2: PythonSignature) -> bool: 1252 """Returns True if s1 < s2 in the partial order.""" 1253 args1, args2 = s1.arguments(skip_outputs=True), s2.arguments(skip_outputs=True) 1254 if len(args1) != len(args2): 1255 return False 1256 # TODO: should use some canonical form instead of 'str(arg.type)' - see comments 1257 # above. The old codegen used the deprecated 'dynamic_type(arg.type)', which 1258 # ignores the optional annotation, i.e. 'Scalar' and 'Scalar?'. 
1259 equal = all(arg1.type == arg2.type for arg1, arg2 in zip(args1, args2)) 1260 smaller_or_equal = all( 1261 str(arg1.type) == str(arg2.type) or is_arg_smaller(arg1.type, arg2.type) 1262 for arg1, arg2 in zip(args1, args2) 1263 ) 1264 return smaller_or_equal and not equal 1265 1266 # First sort by signature 1267 grouped_overloads = sorted( 1268 grouped_overloads, key=lambda x: x.signature.signature_str(symint=symint) 1269 ) 1270 1271 # Construct the relation graph 1272 larger_than: dict[int, set[int]] = defaultdict(set) 1273 for i1, overload1 in enumerate(grouped_overloads): 1274 for i2, overload2 in enumerate(grouped_overloads): 1275 if is_smaller(overload1.signature, overload2.signature): 1276 larger_than[i1].add(i2) 1277 1278 if not larger_than: 1279 return list(grouped_overloads) 1280 1281 # Use a topological sort to sort overloads according to the partial order. 1282 N = len(grouped_overloads) 1283 sorted_ids: list[int] = list(filter(lambda x: x not in larger_than, range(N))) 1284 1285 for idx in range(N): 1286 # The size of sorted_ids will grow to N eventually. 1287 i = sorted_ids[idx] 1288 for j in sorted(larger_than.keys()): 1289 larger = larger_than[j] 1290 larger.discard(i) 1291 if not larger: 1292 del larger_than[j] 1293 sorted_ids.append(j) 1294 1295 return [grouped_overloads[x] for x in sorted_ids] 1296 1297 1298# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 1299# 1300# Codegen API Integration 1301# 1302# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 1303 1304 1305def emit_single_dispatch( 1306 ps: PythonSignature, 1307 f: NativeFunction, 1308 structseq_typenames: dict[str, str], 1309 *, 1310 symint: bool = True, 1311) -> str: 1312 """ 1313 Emit dispatch code for a single native function. 
1314 """ 1315 1316 @with_native_function 1317 def go(f: NativeFunction) -> str: 1318 # header comments 1319 if isinstance(ps, PythonSignatureDeprecated): 1320 schema_comment = f"// [deprecated] aten::{ps.deprecated_schema}" 1321 else: 1322 schema_comment = f"// aten::{f.func}" 1323 1324 deprecated = "[deprecated] " if ps.deprecated else "" 1325 1326 # dispatch lambda signature 1327 name = cpp.name(f.func) 1328 lambda_formals = ", ".join( 1329 f"{a.type_str} {a.name}" for a in dispatch_lambda_args(ps, f, symint=symint) 1330 ) 1331 lambda_return = dispatch_lambda_return_str(f) 1332 1333 # dispatch lambda body 1334 dispatch_callee = cpp_dispatch_target(f) 1335 dispatch_args = ", ".join(cpp_dispatch_exprs(f, python_signature=ps)) 1336 1337 # from arg parser outputs to dispatch lambda arguments 1338 parser_outputs = arg_parser_output_exprs(ps, f, symint=symint) 1339 lambda_arg_exprs = dispatch_lambda_exprs(ps, f, symint=symint) 1340 inits = "\n".join(lambda_arg_exprs.inits) 1341 lambda_args = ", ".join(lambda_arg_exprs.exprs) 1342 1343 # scatter fields 1344 # TODO: Checking `ps.method and ('requires_grad' in parser_outputs)` is a hacky 1345 # solution for enabling the 'requires_grad' argument for tensor methods 1346 # new_full, new_empty, and new_zeros. A much better but more difficult to 1347 # implement solution involves refactoring according to Ed's description here: 1348 # https://github.com/pytorch/pytorch/issues/36455#issuecomment-614767589 1349 need_set_requires_grad = ps.tensor_options_args and ( 1350 not has_tensor_options(f) 1351 or (ps.method and ("requires_grad" in parser_outputs)) 1352 ) 1353 set_requires_grad = ( 1354 f'.set_requires_grad({parser_outputs["requires_grad"].expr})' 1355 if need_set_requires_grad 1356 else "" 1357 ) 1358 1359 if lambda_return == "void": 1360 # Make in-place foreach return `self` at python-binding level. 
1361 # ref: https://github.com/pytorch/pytorch/pull/118622#pullrequestreview-1904804954 1362 self_arg = f.func.arguments.self_arg 1363 return_stmt: str 1364 if ( 1365 str(f.func.name).startswith("_foreach_") 1366 and f.func.kind() == SchemaKind.inplace 1367 ): 1368 # note(crcrpar): `_foreach_pow.ScalarAndTensor` does NOT have its in-place 1369 # variant and it unlikely to have it in the future. Thus it's safe to have the following assert. 1370 assert self_arg is not None and is_tensor_list_type( 1371 self_arg.argument.type 1372 ) 1373 return_stmt = """PyObject* self_tensorlist = _r.args[0]; 1374Py_INCREF(self_tensorlist); 1375return self_tensorlist; 1376""" 1377 else: 1378 return_stmt = "Py_RETURN_NONE;" 1379 return f"""\ 1380{schema_comment} 1381{inits} 1382auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{ 1383 pybind11::gil_scoped_release no_gil; 1384 {dispatch_callee}({dispatch_args}); 1385}}; 1386dispatch_{name}({lambda_args}){set_requires_grad}; 1387{return_stmt} 1388""" 1389 else: 1390 typename = structseq_typenames.get(gen_structseq_typename_key(f)) 1391 structseq_typeref = f"{typename}, " if typename is not None else "" 1392 return f"""\ 1393{schema_comment} 1394{inits} 1395auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{ 1396 pybind11::gil_scoped_release no_gil; 1397 return {dispatch_callee}({dispatch_args}); 1398}}; 1399return wrap({structseq_typeref}dispatch_{name}({lambda_args}){set_requires_grad}); 1400""" 1401 1402 return go(f) 1403