# Generates Python bindings for ATen functions
#
# The bindings are generated as methods on python_variable or functions on the
# torch._C._nn, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._sparse
# or torch._C._special objects.
#

# Code tries to stick to the following rules:
#
# - templates should be colocated with the functions that use them.
#   no templates are currently shared between functions, but if that
#   happens, maybe put the template with the first one
#
# - don't use environment dictionaries when calling template.substitute().
#   pass named arguments directly for everything, otherwise it's much too
#   hard to track what's actually being used and by whom
#
# - colocate any new hacks/adjustments with existing ones of the same kind.
#   ideally in a data structure rather than code if possible. See e.g.
#   SCHEMA_DEFAULT_CONVERSION_HACKS, etc.
#
# - similarly, conversions from one format to another should ideally happen
#   all at once in a single place.
#
# - no nontrivial nested functions. couple-liners are ok but please no more.
#   especially avoid functions that read/write outer variables defined far away.
#
# - raise RuntimeError instead of asserting, and put as much
#   information as is available into the message. I.e. no need to
#   plumb in new params whose only purpose is to fill out an error
#   message, but use what's there
#

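# As an illustrative sketch of the template.substitute() rule above (the
# template name here is hypothetical, not one used in this file), prefer
#
#   EXAMPLE_CASE = CodeTemplate("case ${overload_index}: { ${body} }")
#   EXAMPLE_CASE.substitute(overload_index=0, body="...")
#
# over accumulating an environment dict and passing it to substitute().
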
from __future__ import annotations

import itertools
import re
from collections import defaultdict
from typing import Callable, Iterable, Sequence

import yaml

from torchgen.api import cpp
from torchgen.api.python import (
    arg_parser_output_exprs,
    cpp_dispatch_exprs,
    cpp_dispatch_target,
    dispatch_lambda_args,
    dispatch_lambda_exprs,
    dispatch_lambda_return_str,
    has_tensor_options,
    PythonSignature,
    PythonSignatureDeprecated,
    PythonSignatureGroup,
    PythonSignatureNativeFunctionPair,
    signature,
    signature_from_schema,
    structseq_fieldnames,
)
from torchgen.code_template import CodeTemplate
from torchgen.context import with_native_function
from torchgen.gen import cpp_string, parse_native_yaml, parse_tags_yaml
from torchgen.model import (
    Argument,
    BaseOperatorName,
    FunctionSchema,
    NativeFunction,
    SchemaKind,
    Type,
    Variant,
)
from torchgen.utils import FileManager, split_name_params
from torchgen.yaml_utils import YamlLoader

from .gen_inplace_or_view_type import is_tensor_list_type
from .gen_trace_type import should_trace


#
# declarations blocklist
# We skip codegen for these functions, for various reasons.
# Future PRs will categorize this list and eliminate or hoist
# them out of eager-only codegen.
# See https://github.com/pytorch/pytorch/issues/30788
#

# These functions require manual Python bindings or are not exposed to Python
_SKIP_PYTHON_BINDINGS = [
    "alias",
    "contiguous",
    "is_cuda",
    "is_sparse",
    "is_sparse_csr",
    "size",
    "stride",
    "sym_size",
    "sym_stride",
    "sym_storage_offset",
    "sym_numel",
    ".*_backward",
    ".*_backward_(out|input|weight|bias)",
    ".*_forward",
    ".*_forward_out",
    ".*_jvp",
    "_unsafe_view",
    "tensor",
    "_?sparse_(coo|compressed|csr|csc|bsr|bsc)_tensor.*",
    "_range.*",
    "_sparse_add_out",
    "_sparse_div.*",
    "_sparse_mul.*",
    "_sparse_sub.*",
    "_sparse_dense_add_out",
    "index",
    "index_out",
    "unique_dim_consecutive",
    "_cumsum.*",
    "_cumprod.*",
    "_sum.*",
    "_prod.*",
    "_th_.*",
    "_thnn_.*",
    "range.*",
    "_solve.*",
    "_inverse.*",
    "_cholesky.*",
    "_triangular_solve.*",
    "_qr.*",
    "_svd.*",
    "slice",
    "item",
    "_local_scalar_dense",
    "to",
    "_to_copy",
    "_to_copy_out",
    "_reshape_copy",
    "_reshape_copy_out",
    "copy_sparse_to_sparse_",
    "copy_",
    "_foreach_copy",
    "numpy_T",
    "matrix_H",
    "mT",
144    "mH",  # these need to be an attributes in Python, not functions
145    "nonzero(_(out|numpy))?",
146    "set_data",
147    ".*_overrideable",  # overrideable functions for backend extension
148    "data",
149    "is_leaf",
150    "output_nr",
151    "_version",
152    "requires_grad_",
153    "retains_grad",
154    "set_",
155    "_fw_primal",
156    "fake_quantize_per_tensor_affine_cachemask",
157    "fake_quantize_per_channel_affine_cachemask",
158    "_new_zeros_with_same_feature_meta",
159    "_has_same_storage_numel",  # used for forward AD internals
160    "_reshape_alias",
161    "replace_",  # only used by the functionalization pass, doesn't need to be exposed to python
162    "copy",  # only used by the functionalization pass
163    "fill.Tensor",  # only used by the functionalization pass
164    "fill.Scalar",  # only used by the functionalization pass
165    "lift.*",
166    "normal_functional",  # only used by the functionalization pass
167    "nbytes",
168    "itemsize",
169    "_batch_norm_with_update",
170    "_batch_norm_with_update_out",
171    "_batch_norm_no_update",
172]
173
174SKIP_PYTHON_BINDINGS = [
175    re.compile(rf"^{pattern}$") for pattern in _SKIP_PYTHON_BINDINGS
176]
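
# Note that each pattern above is compiled anchored as ^pattern$, so, for
# example (illustrative), "range.*" skips "range" and "range_out" but does not
# skip "arange".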

# These function signatures are not exposed to Python. Note that this signature
# list does not support regex.
SKIP_PYTHON_BINDINGS_SIGNATURES = [
    "add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor",
    "add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)",
    "sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor",
    "sub_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)",
    "mul.Scalar(Tensor self, Scalar other) -> Tensor",
    "mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)",
    "div.Scalar(Tensor self, Scalar other) -> Tensor",
    "div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)",
]


@with_native_function
def should_generate_py_binding(f: NativeFunction) -> bool:
    # NativeFunctions that are entirely code-generated should not get python bindings
    # because these codegen implementations are often inefficient. A handful of
    # view_copy style ops were exposed accidentally when they were handwritten; now
    # that we are moving them to codegen, we need to keep them exposed in python
    # for BC reasons.
199    if "generated" in f.tags and "view_copy" not in f.tags:
200        return False
201
202    name = cpp.name(f.func)
203    for skip_regex in SKIP_PYTHON_BINDINGS:
204        if skip_regex.match(name):
205            return False
206
207    signature = str(f.func)
208    for pattern in SKIP_PYTHON_BINDINGS_SIGNATURES:
209        if pattern == signature:
210            return False
211    return True
212
213
214def get_pycname(name: BaseOperatorName) -> str:
215    return f"THPVariable_{name}"
216
217
218def is_noarg(overloads: Sequence[PythonSignatureNativeFunctionPair]) -> bool:
219    return len(overloads) == 1 and overloads[0].signature.arguments_count() == 0
220
221
222def is_py_variable_method(f: NativeFunction) -> bool:
223    return f.python_module is None and Variant.method in f.variants
224
225
226def is_py_torch_function(f: NativeFunction) -> bool:
227    return f.python_module is None and Variant.function in f.variants
228
229
230def is_py_nn_function(f: NativeFunction) -> bool:
231    return f.python_module == "nn"
232
233
234def is_py_fft_function(f: NativeFunction) -> bool:
235    return f.python_module == "fft"
236
237
238def is_py_linalg_function(f: NativeFunction) -> bool:
239    return f.python_module == "linalg"
240
241
242def is_py_nested_function(f: NativeFunction) -> bool:
243    return f.python_module == "nested"
244
245
246def is_py_sparse_function(f: NativeFunction) -> bool:
247    return f.python_module == "sparse"
248
249
250def is_py_special_function(f: NativeFunction) -> bool:
251    return f.python_module == "special"
252
253
254# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
255#
256#                            Main Function
257#
258# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
259
260
def gen(
    out: str,
    native_yaml_path: str,
    tags_yaml_path: str,
    deprecated_yaml_path: str,
    template_path: str,
    *,
    symint: bool = True,
) -> None:
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    native_functions = parse_native_yaml(
        native_yaml_path, tags_yaml_path
    ).native_functions
    native_functions = list(filter(should_generate_py_binding, native_functions))

    methods = load_signatures(native_functions, deprecated_yaml_path, method=True)
    create_python_bindings(
        fm,
        methods,
        is_py_variable_method,
        None,
        "python_variable_methods.cpp",
        method=True,
        symint=symint,
    )

    # NOTE: num_shards here must be synced with gatherTorchFunctions in
    #       torch/csrc/autograd/python_torch_functions_manual.cpp
    functions = load_signatures(native_functions, deprecated_yaml_path, method=False)
    create_python_bindings_sharded(
        fm,
        functions,
        is_py_torch_function,
        "torch",
        "python_torch_functions.cpp",
        method=False,
        num_shards=3,
        symint=symint,
    )

    create_python_bindings(
        fm,
        functions,
        is_py_nn_function,
        "torch.nn",
        "python_nn_functions.cpp",
        method=False,
        symint=symint,
    )

    create_python_bindings(
        fm,
        functions,
        is_py_fft_function,
        "torch.fft",
        "python_fft_functions.cpp",
        method=False,
        symint=symint,
    )

    create_python_bindings(
        fm,
        functions,
        is_py_linalg_function,
        "torch.linalg",
        "python_linalg_functions.cpp",
        method=False,
        symint=symint,
    )

    create_python_bindings(
        fm,
        functions,
        is_py_nested_function,
        "torch.nested",
        "python_nested_functions.cpp",
        method=False,
    )

    create_python_bindings(
        fm,
        functions,
        is_py_sparse_function,
        "torch.sparse",
        "python_sparse_functions.cpp",
        method=False,
        symint=symint,
    )

    create_python_bindings(
        fm,
        functions,
        is_py_special_function,
        "torch.special",
        "python_special_functions.cpp",
        method=False,
        symint=symint,
    )

    # Currently, we only use `functions` to generate `return_types` bindings.
    # All methods which return a structseq have a function variant at this point.
    # If a method-only operator returning a structseq is added in the future,
    # we will have to address that.
    create_python_return_type_bindings(
        fm, functions, lambda fn: True, "python_return_types.cpp"
    )
    create_python_return_type_bindings_header(
        fm, functions, lambda fn: True, "python_return_types.h"
    )

    valid_tags = parse_tags_yaml(tags_yaml_path)

    def gen_tags_enum() -> dict[str, str]:
        return {
            "enum_of_valid_tags": (
                "".join(
                    [f'\n.value("{tag}", at::Tag::{tag})' for tag in sorted(valid_tags)]
                )
            )
        }

    fm.write("python_enum_tag.cpp", gen_tags_enum)


def group_filter_overloads(
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
) -> dict[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]:
    grouped: dict[
        BaseOperatorName, list[PythonSignatureNativeFunctionPair]
    ] = defaultdict(list)
    for pair in pairs:
        if pred(pair.function):
            grouped[pair.function.func.name.name].append(pair)
    return grouped


def create_python_bindings(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    module: str | None,
    filename: str,
    *,
    method: bool,
    symint: bool = True,
) -> None:
    """Generates Python bindings to ATen functions"""
    py_methods: list[str] = []
    ops_headers: list[str] = []
    py_method_defs: list[str] = []
    py_forwards: list[str] = []

    grouped = group_filter_overloads(pairs, pred)

    for name in sorted(grouped.keys(), key=str):
        overloads = grouped[name]
        py_methods.append(
            method_impl(name, module, overloads, method=method, symint=symint)
        )
        py_method_defs.append(method_def(name, module, overloads, method=method))
        py_forwards.extend(forward_decls(name, overloads, method=method))
        ops_headers.append(f"#include <ATen/ops/{name.base}.h>")

    fm.write_with_template(
        filename,
        filename,
        lambda: {
            "generated_comment": "@"
            + f"generated from {fm.template_dir_for_comments()}/{filename}",
            "ops_headers": ops_headers,
            "py_forwards": py_forwards,
            "py_methods": py_methods,
            "py_method_defs": py_method_defs,
        },
    )


def create_python_return_type_bindings(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    filename: str,
) -> None:
445    """
446    Generate function to initialize and return named tuple for native functions
447    which returns named tuple and registration invocations in `python_return_types.cpp`.
448    """
    py_return_types_definition: list[str] = []
    py_return_types_registrations: list[str] = []

    grouped = group_filter_overloads(pairs, pred)

    for name in sorted(grouped.keys(), key=str):
        overloads = grouped[name]
        definitions, registrations = generate_return_type_definition_and_registrations(
            overloads
        )
        py_return_types_definition.append(
            "" if not definitions else "\n".join(definitions)
        )
        py_return_types_registrations.append(
            "" if not registrations else "\n".join(registrations)
        )

    fm.write_with_template(
        filename,
        filename,
        lambda: {
            "generated_comment": "@"
            + f"generated from {fm.template_dir_for_comments()}/{filename}",
            "py_return_types": py_return_types_definition,
            "py_return_types_registrations": py_return_types_registrations,
        },
    )


def create_python_return_type_bindings_header(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    filename: str,
) -> None:
484    """
485    Generate function to initialize and return named tuple for native functions
486    which returns named tuple and relevant entry for the map in `python_return_types.cpp`.
487    """
    py_return_types_declarations: list[str] = []

    grouped = group_filter_overloads(pairs, pred)

    for name in sorted(grouped.keys(), key=str):
        overloads = grouped[name]
        declarations = generate_return_type_declarations(overloads)
        py_return_types_declarations.append(
            "" if not declarations else "\n".join(declarations)
        )

    fm.write_with_template(
        filename,
        filename,
        lambda: {
            "generated_comment": "@"
            + f"generated from {fm.template_dir_for_comments()}/{filename}",
            "py_return_types_declarations": py_return_types_declarations,
        },
    )


def create_python_bindings_sharded(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    module: str | None,
    filename: str,
    *,
    method: bool,
    num_shards: int,
    symint: bool = True,
) -> None:
    """Generates Python bindings to ATen functions"""
    grouped = group_filter_overloads(pairs, pred)

    def key_func(
        kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]
    ) -> str:
        return kv[0].base

    def env_func(
        kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]
    ) -> dict[str, list[str]]:
        name, fn_pairs = kv
        return {
            "ops_headers": [f"#include <ATen/ops/{name.base}.h>"],
            "py_forwards": list(forward_decls(name, fn_pairs, method=method)),
            "py_methods": [
                method_impl(name, module, fn_pairs, method=method, symint=symint)
            ],
            "py_method_defs": [method_def(name, module, fn_pairs, method=method)],
        }

    fm.write_sharded(
        filename,
        grouped.items(),
        base_env={
            "generated_comment": "@"
            + f"generated from {fm.template_dir_for_comments()}/{filename}",
        },
        key_fn=key_func,
        env_callable=env_func,
        num_shards=num_shards,
        sharded_keys={"ops_headers", "py_forwards", "py_methods", "py_method_defs"},
    )


def load_signatures(
    native_functions: list[NativeFunction],
    deprecated_yaml_path: str,
    *,
    method: bool,
    skip_deprecated: bool = False,
    pyi: bool = False,
) -> Sequence[PythonSignatureNativeFunctionPair]:
    @with_native_function
    def gen_signature_pairs(f: NativeFunction) -> PythonSignatureNativeFunctionPair:
        return PythonSignatureNativeFunctionPair(
            signature=signature(f, method=method, pyi=pyi),
            function=f,
        )

    pairs = list(map(gen_signature_pairs, native_functions))
    deprecated = load_deprecated_signatures(
        pairs, deprecated_yaml_path, method=method, pyi=pyi
    )
    return pairs if skip_deprecated else pairs + deprecated


def load_deprecated_signatures(
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    deprecated_yaml_path: str,
    *,
    method: bool,
    pyi: bool,
) -> list[PythonSignatureNativeFunctionPair]:
    # The deprecated.yaml doesn't have complete type information, so we need to
    # find and leverage the original ATen signature (to which it delegates the
    # call) to generate the full python signature.
    # We join the deprecated and the original signatures using the type-only form.

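    # For reference, an entry in deprecated.yaml looks roughly like the
    # following (an illustrative sketch, not an exhaustive description of the
    # format):
    #
    #   - name: add(Tensor self, Scalar alpha, Tensor other) -> Tensor
    #     aten: add(self, other, alpha)
    #
    # i.e. the deprecated Python signature plus the ATen call it forwards to.
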
    # group the original ATen signatures by name
    grouped: dict[str, list[PythonSignatureNativeFunctionPair]] = defaultdict(list)
    for pair in pairs:
        grouped[pair.signature.name].append(pair)

    # find matching original signatures for each deprecated signature
    results: list[PythonSignatureNativeFunctionPair] = []

    with open(deprecated_yaml_path) as f:
        deprecated_defs = yaml.load(f, Loader=YamlLoader)

    for deprecated in deprecated_defs:
        schema = FunctionSchema.parse(deprecated["name"])
        aten_name, call_args = split_name_params(deprecated["aten"])
        is_out = aten_name.endswith("_out")
        if is_out:
            aten_name = aten_name.replace("_out", "")

        # HACK: these are fixed constants used to pass to the aten function.
        # The type must be known ahead of time.
        known_constants = {
            "1": Type.parse("Scalar"),
        }
        schema_args_by_name = {a.name: a for a in schema.arguments.flat_all}
        for name in call_args:
            assert (
                name in schema_args_by_name or name in known_constants
            ), f"deprecation definition: Unrecognized value {name}"

        # Map deprecated signature arguments to their aten signature and test
        # if the types and alias annotation match.
        def is_schema_compatible(
            aten_schema: FunctionSchema,
        ) -> bool:
            arguments: Iterable[Argument]
            if is_out:
                arguments = itertools.chain(
                    aten_schema.arguments.out, aten_schema.arguments.flat_non_out
                )
            else:
                arguments = aten_schema.arguments.flat_all

            for i, arg in enumerate(arguments):
                if i < len(call_args):
                    arg_name = call_args[i]
                    if arg_name in known_constants:
                        schema_type = known_constants[arg_name]
                        schema_annotation = None
                    else:
                        schema_arg = schema_args_by_name[arg_name]
                        schema_type = schema_arg.type
                        schema_annotation = schema_arg.annotation

                    if schema_type != arg.type or schema_annotation != arg.annotation:
                        return False
                else:
                    if arg.default is None:
                        return False

            return len(schema.returns) == len(aten_schema.returns) and all(
                a == b for a, b in zip(schema.returns, aten_schema.returns)
            )

        any_schema_found = False
        for pair in grouped[aten_name]:
            if not is_schema_compatible(pair.function.func):
                continue
            any_schema_found = True

            python_sig = signature_from_schema(
                schema,
                category_override=pair.function.category_override,
                method=method,
                pyi=pyi,
            )

            results.append(
                PythonSignatureNativeFunctionPair(
                    signature=PythonSignatureDeprecated(
                        name=python_sig.name,
                        input_args=python_sig.input_args,
                        input_kwargs=python_sig.input_kwargs,
                        output_args=python_sig.output_args,
                        tensor_options_args=python_sig.tensor_options_args,
                        method=python_sig.method,
                        deprecated_schema=schema,
                        deprecated_args_exprs=tuple(call_args),
                        returns=python_sig.returns,
                    ),
                    function=pair.function,
                )
            )
        assert (
            any_schema_found
        ), f"No native function with name {aten_name} matched signature:\n  {str(schema)}"

    return results


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                         Named Tuple Codegen
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


@with_native_function
def gen_structseq_typename_key(f: NativeFunction) -> str:
    name = cpp.name(f.func)
    fieldnames = structseq_fieldnames(f.func.returns)
    return "_".join([name] + fieldnames)


def emit_structseq_call(
    overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> tuple[list[str], dict[str, str]]:
    """
    Generate block of named tuple type def inits, and add typeref snippets
    to declarations that use them
    """
    typenames: dict[
        str, str
    ] = {}  # map from unique name + field name lists to typedef name
    typedefs: list[str] = []  # typedef declarations and init code

    for overload in overloads:
        fieldnames = structseq_fieldnames(overload.function.func.returns)
        if not fieldnames:
            continue

        name = cpp.name(overload.function.func)  # use @with_native_function?
        tn_key = gen_structseq_typename_key(overload.function)
        typename = typenames.get(tn_key)
        if typename is None:
            typename = f'NamedTuple{"" if not typedefs else len(typedefs)}'
            typenames[tn_key] = typename
            typedefs.append(
                f"""\
static PyTypeObject* {typename} = generated::get_{name}_structseq();"""
            )

    return typedefs, typenames


def generate_return_type_definition_and_registrations(
    overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> tuple[list[str], list[str]]:
737    """
738    Generate block of function in `python_return_types.cpp` to initialize
739    and return named tuple for a native function which returns named tuple
740    and registration invocations in same file.
741    """
    typenames: dict[
        str, str
    ] = {}  # map from unique name + field name lists to typedef name
    definitions: list[str] = []  # function definition to register the typedef
    registrations: list[str] = []  # register call for the typedef

    for overload in overloads:
        fieldnames = structseq_fieldnames(overload.function.func.returns)
        if not fieldnames:
            continue

        fields = ", ".join(f'{{"{fn}", ""}}' for fn in fieldnames)

        name = cpp.name(overload.function.func)  # use @with_native_function?
        tn_key = gen_structseq_typename_key(overload.function)
        typename = typenames.get(tn_key)

        if typename is None:
            typename = f'{name}NamedTuple{"" if not definitions else len(definitions)}'
            typenames[tn_key] = typename
            definitions.append(
                f"""\
PyTypeObject* get_{name}_structseq() {{
    static PyStructSequence_Field NamedTuple_fields[] = {{ {fields},  {{nullptr}} }};
    static PyTypeObject {typename};
    static bool is_initialized = false;
    static PyStructSequence_Desc desc = {{ "torch.return_types.{name}", nullptr, NamedTuple_fields, {len(fieldnames)} }};
    if (!is_initialized) {{
        PyStructSequence_InitType(&{typename}, &desc);
        {typename}.tp_repr = (reprfunc)torch::utils::returned_structseq_repr;
        is_initialized = true;
    }}
    return &{typename};
}}
"""
            )
            registrations.append(
                f'addReturnType(return_types_module, "{name}", generated::get_{name}_structseq());'
            )

    return definitions, registrations


def generate_return_type_declarations(
    overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> list[str]:
    """
    Generate block of function declarations in `python_return_types.h` to initialize
    and return named tuple for a native function.
    """
    typenames: dict[
        str, str
    ] = {}  # map from unique name + field name lists to typedef name
    declarations: list[str] = []  # function declaration to register the typedef

    for overload in overloads:
        fieldnames = structseq_fieldnames(overload.function.func.returns)
        if not fieldnames:
            continue

        name = cpp.name(overload.function.func)  # use @with_native_function?
        tn_key = gen_structseq_typename_key(overload.function)
        typename = typenames.get(tn_key)

        if typename is None:
            typename = (
                f'{name}NamedTuple{"" if not declarations else len(declarations)}'
            )
            typenames[tn_key] = typename
            declarations.append(f"PyTypeObject* get_{name}_structseq();")

    return declarations


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                         Method Impl Codegen
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

# python binding for all overloads of a particular function/method
PY_VARIABLE_METHOD_VARARGS = CodeTemplate(
    r"""\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  ${method_header}
  static PythonArgParser parser({
    ${signatures}
  }, /*traceable=*/${traceable});

  ParsedArgs<${max_args}> parsed_args;
  auto _r = parser.parse(${self_}, args, kwargs, parsed_args);
  ${check_has_torch_function}
  switch (_r.idx) {
    ${dispatch}
  }
  ${method_footer}
}

"""
)

# handler for a single parsed signature - may be a single overload or
# a pair of overloads whose signatures differ only in output params
# (plugged into PY_VARIABLE_METHOD_VARARGS as an item in ${dispatch})
PY_VARIABLE_CASE = CodeTemplate(
    """\
case ${overload_index}: {
  ${body}
}
"""
)

# python binding for single-overload function/method
PY_VARIABLE_METHOD_VARARGS_SINGLETON = CodeTemplate(
    """\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  ${method_header}
  static PythonArgParser parser({
    ${signatures}
  }, /*traceable=*/${traceable});

  ParsedArgs<${max_args}> parsed_args;
  auto _r = parser.parse(${self_}, args, kwargs, parsed_args);
  ${check_has_torch_function}
  ${dispatch}
  ${method_footer}
}

"""
)

# python binding for a method with no args, shortcuts parsing
PY_VARIABLE_METHOD_NOARGS = CodeTemplate(
    """\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args)
{
  ${method_header}
  ${check_has_torch_function}
  ${dispatch}
  ${method_footer}
}

"""
)


def method_impl(
    name: BaseOperatorName,
    module: str | None,
    overloads: Sequence[PythonSignatureNativeFunctionPair],
    *,
    method: bool,
    symint: bool = True,
) -> str:
    """
    Generate a python binding for all overloads of an op.
    """
    pycname = get_pycname(name)
    noarg = is_noarg(overloads)
    structseq_inits, structseq_typenames = emit_structseq_call(overloads)

    method_header = ["HANDLE_TH_ERRORS"]
    method_header += structseq_inits
    method_header += (
        ["const Tensor& self = THPVariable_Unpack(self_);"] if method else []
    )

    method_footer = ([] if noarg else ["Py_RETURN_NONE;"]) + ["END_HANDLE_TH_ERRORS"]

    traceable = "true" if all(should_trace(o.function) for o in overloads) else "false"

    grouped_overloads: Sequence[PythonSignatureGroup] = group_overloads(
        overloads, symint=symint
    )
    is_singleton = len(grouped_overloads) == 1
    signatures: list[str] = []
    dispatch: list[str] = []
    for overload_index, overload in enumerate(grouped_overloads):
        signature = overload.signature.signature_str(symint=symint)
        signatures.append(f"{cpp_string(str(signature))},")
        dispatch_body = emit_dispatch_case(overload, structseq_typenames, symint=symint)
        dispatch.append(
            PY_VARIABLE_CASE.substitute(
                overload_index=overload_index, body=dispatch_body
            )
            if not is_singleton
            else dispatch_body
        )

    if noarg:
        template = PY_VARIABLE_METHOD_NOARGS
    elif is_singleton:
        template = PY_VARIABLE_METHOD_VARARGS_SINGLETON
    else:
        template = PY_VARIABLE_METHOD_VARARGS

    return template.substitute(
        name=name,
        pycname=pycname,
        method_header=method_header,
        max_args=max(o.signature.arguments_count() for o in overloads),
        signatures=signatures,
        traceable=traceable,
        check_has_torch_function=gen_has_torch_function_check(
            name=name,
            module=module,
            noarg=noarg,
            method=method,
        ),
        dispatch=dispatch,
        method_footer=method_footer,
        self_="self_" if method else "nullptr",
    )


def gen_has_torch_function_check(
    name: BaseOperatorName, module: str | None, *, noarg: bool, method: bool
) -> str:
    if noarg:
        if method:
            return f"""\
if(check_has_torch_function(self_)) {{
  return handle_torch_function(self_, "{name}");
}}
"""
        else:
            return ""

    self_ = "self_" if method else "nullptr"
    namespace = (
        {
            "torch": "THPVariableFunctionsModule",
            "torch.nn": "THPNNVariableFunctionsModule",
            "torch.fft": "THPFFTVariableFunctionsModule",
            "torch.linalg": "THPLinalgVariableFunctionsModule",
            "torch.nested": "THPNestedVariableFunctionsModule",
            "torch.sparse": "THPSparseVariableFunctionsModule",
            "torch.special": "THPSpecialVariableFunctionsModule",
        }[module]
        if module
        else "THPVariableClass"
    )

    return f"""\
if(_r.has_torch_function()) {{
  return handle_torch_function(_r, {self_}, args, kwargs, {namespace}, "{module or "torch.Tensor"}");
}}
"""


# handler for output/no-output overload pair
PY_VARIABLE_OUT = CodeTemplate(
    """\
if (_r.isNone(${out_idx})) {
  ${call_dispatch}
} else {
  ${call_dispatch_out}
}
"""
)


def emit_dispatch_case(
    overload: PythonSignatureGroup,
    structseq_typenames: dict[str, str],
    *,
    symint: bool = True,
) -> str:
    """
    Emit dispatch code for a single parsed signature. This corresponds to either
    a single native function, or a pair that differ only in output params. In the
    latter case, a single python signature is used for both and dispatching
    switches on the presence/absence of passed output args.
    """
    if overload.outplace is not None:
        # dispatch output and no-output variants, branch on _r.isNone(<out_idx>)
        return PY_VARIABLE_OUT.substitute(
            out_idx=overload.signature.output_idx(),
            call_dispatch=emit_single_dispatch(
                overload.signature, overload.base, structseq_typenames, symint=symint
            ),
            call_dispatch_out=emit_single_dispatch(
                overload.signature,
                overload.outplace,
                structseq_typenames,
                symint=symint,
            ),
        )
    else:
        # no-output version only
        return emit_single_dispatch(
            overload.signature, overload.base, structseq_typenames, symint=symint
        )


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                    Forward Declarations Codegen
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def forward_decls(
    name: BaseOperatorName,
    overloads: Sequence[PythonSignatureNativeFunctionPair],
    *,
    method: bool,
) -> tuple[str, ...]:
    if method:
        return ()

    pycname = get_pycname(name)
    if is_noarg(overloads):
        return (
            f"""\
static PyObject * {pycname}(PyObject* self_, PyObject* args);
""",
        )
    else:
        return (
            f"""\
static PyObject * {pycname}(PyObject* self_, PyObject* args, PyObject* kwargs);
""",
        )


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#              Method Def (Binding Table Entry) Codegen
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def method_def(
    name: BaseOperatorName,
    module: str | None,
    overloads: Sequence[PythonSignatureNativeFunctionPair],
    *,
    method: bool,
) -> str:
    """
    Generate method def entry.
    """
    pycname = get_pycname(name)

    if name.dunder_method:
        # PyMethodDef entry for binary op, throws not implemented error
        pycname = f"TypeError_to_NotImplemented_<{pycname}>"

    if is_noarg(overloads):
        flags = "METH_NOARGS" if method else "METH_VARARGS | METH_KEYWORDS"
    else:
        pycname = f"castPyCFunctionWithKeywords({pycname})"
        flags = "METH_VARARGS | METH_KEYWORDS"

    if module == "torch":
        flags += " | METH_STATIC"

    return f'{{"{name}", {pycname}, {flags}, NULL}},'


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                   Overload Sorting and Grouping
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def group_overloads(
    overloads: Sequence[PythonSignatureNativeFunctionPair], *, symint: bool = True
) -> Sequence[PythonSignatureGroup]:
    bases: dict[str, PythonSignatureNativeFunctionPair] = {}
    outplaces: dict[str, PythonSignatureNativeFunctionPair] = {}

    # first group by signature ignoring out arguments
    for overload in overloads:
        sig = overload.signature.signature_str(skip_outputs=True, symint=symint)
        if overload.function.func.is_out_fn():
            if sig in outplaces:
                raise RuntimeError(
                    f"Found duplicated function definition:\n- {overload.function.func}.\n"
                    f"Existing definition:\n- {outplaces[sig].function.func}."
                )
            outplaces[sig] = overload
        else:
            if sig in bases:
                raise RuntimeError(
                    f"Found duplicated function definition:\n- {overload.function.func}.\n"
                    f"Existing definition:\n- {bases[sig].function.func}."
                )
            bases[sig] = overload

    for sig, out in outplaces.items():
        if sig not in bases:
            candidates: list[str] = []
            for overload in overloads:
                if (
                    str(overload.function.func.name.name)
                    == str(out.function.func.name.name)
                    and not overload.function.func.is_out_fn()
                    and not overload.signature.deprecated
                ):
                    candidates.append(
                        overload.signature.signature_str(
                            skip_outputs=True, symint=symint
                        )
                    )
            out_sig = out.signature.signature_str(symint=symint)
            raise RuntimeError(
                f"While identifying overloads, we found an out schema {out_sig} without a corresponding non-out variant. "
                f"We expected the non-out variant to have schema: \n- {sig}\nPlease check that you spelled the schema "
                "correctly in native_functions.yaml. We discovered the following candidate(s): \n"
                + "\n".join(f"- {candidate}" for candidate in candidates)
            )

    grouped = [
        PythonSignatureGroup.from_pairs(
            functional=base,
            out=outplaces.get(sig),
        )
        for sig, base in bases.items()
    ]
    return sort_overloads(grouped, symint=symint)


# This function declares a partial order on declarations, and sorts them according
# to its linear extension. This is necessary, because there's some ambiguity in the
# choice of overload, and we want a particular order.
#
# See Note[Order of overloads matters]
#
# A few examples of ambiguous python signature pairs.
#
#   All parameters have the same type, except one taking Tensor and the other taking
#   Scalar. A numeric PyObject can be cast to a Tensor, and a zero-dim Tensor
#   object can be accepted as a Scalar type parameter (see python_arg_parser.cpp).
#   Therefore, the same input arguments might be accepted by either python signature.
#   We want to always parse the one taking Tensor first.
#
#     bitwise_and(Tensor input, Tensor other, *, Tensor out=None)
#     bitwise_and(Tensor input, Scalar other, *, Tensor out=None)
#
#   If they have a different number of parameters then they are not ambiguous - but
#   the difference on the output param can be ignored as it's optional.
#
#     multiply(Tensor input, Tensor other, *, Tensor out=None)
#     multiply(Tensor input, Scalar other)
#
#   Both positional args and keyword-only args are considered together.
#
#     subtract(Tensor other, *, Scalar alpha=1)
#     subtract(Scalar other, Scalar alpha=1)
#
# A few ambiguous cases which it does NOT handle yet.
#
#   If there is any difference in other parameters besides the Tensor/Scalar
#   difference, then they are not considered ambiguous by this method anymore.
#   However, the difference could be too trivial to disambiguate.
#
#     foo(Tensor input, Scalar other, Scalar bar)
#     foo(Tensor input, Tensor other, double bar)
#
#   If they are taking a different number of parameters then they are not considered
#   ambiguous anymore, even if the difference is only on optional kwargs.
#
#     foo(Scalar other, Scalar alpha=1)
#     foo(Tensor other, *, Scalar alpha=1, Scalar beta=1)
#


def sort_overloads(
    grouped_overloads: Sequence[PythonSignatureGroup], *, symint: bool = True
) -> Sequence[PythonSignatureGroup]:
    # NB: Smaller here means lower priority

    def is_arg_smaller(t1: Type, t2: Type) -> bool:
        return (
            str(t1) == "Scalar"
            and str(t2) == "Tensor"
            or str(t1) == "Scalar?"
            and str(t2) == "Tensor?"
            or "Dimname" in str(t1)
            and "Dimname" not in str(t2)
            or
            # In the discussion https://github.com/pytorch/pytorch/issues/54555 it has been
            # discussed why it is important to prioritize int/int? over int[]
            str(t1) == "int[]"
            and (str(t2) == "int" or str(t2) == "int?")
            or
            # TensorList currently throws an error during argument parsing, that's why it needs to be
            # last in signature ordering. See discussion: https://github.com/pytorch/pytorch/issues/58087
            str(t1) == "Tensor[]"
            and str(t2).find("[]") != -1
            or
            # Prioritize IntArrayRef overload over SymIntArrayRef
            str(t1) == "SymInt[]"
            and str(t2) == "int[]"
            or
            # Make sure both int and SymInt are sorted consistently w.r.t. Tensor since Tensor can be implicitly
            # converted to either int or SymInt.  Prioritize the Tensor overload since it otherwise gets shadowed.
            (str(t1) == "SymInt" or str(t1) == "int")
            and str(t2) == "Tensor"
        )

    def is_smaller(s1: PythonSignature, s2: PythonSignature) -> bool:
        """Returns True if s1 < s2 in the partial order."""
        args1, args2 = s1.arguments(skip_outputs=True), s2.arguments(skip_outputs=True)
        if len(args1) != len(args2):
            return False
        # TODO: should use some canonical form instead of 'str(arg.type)' - see comments
        # above. The old codegen used the deprecated 'dynamic_type(arg.type)', which
        # ignores the optional annotation, i.e. 'Scalar' and 'Scalar?'.
        equal = all(arg1.type == arg2.type for arg1, arg2 in zip(args1, args2))
        smaller_or_equal = all(
            str(arg1.type) == str(arg2.type) or is_arg_smaller(arg1.type, arg2.type)
            for arg1, arg2 in zip(args1, args2)
        )
        return smaller_or_equal and not equal

    # First sort by signature
    grouped_overloads = sorted(
        grouped_overloads, key=lambda x: x.signature.signature_str(symint=symint)
    )

    # Construct the relation graph
    larger_than: dict[int, set[int]] = defaultdict(set)
    for i1, overload1 in enumerate(grouped_overloads):
        for i2, overload2 in enumerate(grouped_overloads):
            if is_smaller(overload1.signature, overload2.signature):
                larger_than[i1].add(i2)

    if not larger_than:
        return list(grouped_overloads)

    # Use a topological sort to sort overloads according to the partial order.
    N = len(grouped_overloads)
    sorted_ids: list[int] = list(filter(lambda x: x not in larger_than, range(N)))

    for idx in range(N):
        # The size of sorted_ids will grow to N eventually.
        i = sorted_ids[idx]
        for j in sorted(larger_than.keys()):
            larger = larger_than[j]
            larger.discard(i)
            if not larger:
                del larger_than[j]
                sorted_ids.append(j)

    return [grouped_overloads[x] for x in sorted_ids]


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                       Codegen API Integration
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def emit_single_dispatch(
    ps: PythonSignature,
    f: NativeFunction,
    structseq_typenames: dict[str, str],
    *,
    symint: bool = True,
) -> str:
    """
    Emit dispatch code for a single native function.
    """

    @with_native_function
    def go(f: NativeFunction) -> str:
        # header comments
        if isinstance(ps, PythonSignatureDeprecated):
            schema_comment = f"// [deprecated] aten::{ps.deprecated_schema}"
        else:
            schema_comment = f"// aten::{f.func}"

        deprecated = "[deprecated] " if ps.deprecated else ""

        # dispatch lambda signature
        name = cpp.name(f.func)
        lambda_formals = ", ".join(
            f"{a.type_str} {a.name}" for a in dispatch_lambda_args(ps, f, symint=symint)
        )
        lambda_return = dispatch_lambda_return_str(f)

        # dispatch lambda body
        dispatch_callee = cpp_dispatch_target(f)
        dispatch_args = ", ".join(cpp_dispatch_exprs(f, python_signature=ps))

        # from arg parser outputs to dispatch lambda arguments
        parser_outputs = arg_parser_output_exprs(ps, f, symint=symint)
        lambda_arg_exprs = dispatch_lambda_exprs(ps, f, symint=symint)
        inits = "\n".join(lambda_arg_exprs.inits)
        lambda_args = ", ".join(lambda_arg_exprs.exprs)

        # scatter fields
        # TODO: Checking `ps.method and ('requires_grad' in parser_outputs)` is a hacky
        #       solution for enabling the 'requires_grad' argument for tensor methods
        #       new_full, new_empty, and new_zeros. A much better but more difficult to
        #       implement solution involves refactoring according to Ed's description here:
        #       https://github.com/pytorch/pytorch/issues/36455#issuecomment-614767589
        need_set_requires_grad = ps.tensor_options_args and (
            not has_tensor_options(f)
            or (ps.method and ("requires_grad" in parser_outputs))
        )
        set_requires_grad = (
            f'.set_requires_grad({parser_outputs["requires_grad"].expr})'
            if need_set_requires_grad
            else ""
        )

        if lambda_return == "void":
            # Make in-place foreach return `self` at python-binding level.
            # ref: https://github.com/pytorch/pytorch/pull/118622#pullrequestreview-1904804954
            self_arg = f.func.arguments.self_arg
            return_stmt: str
            if (
                str(f.func.name).startswith("_foreach_")
                and f.func.kind() == SchemaKind.inplace
            ):
                # note(crcrpar): `_foreach_pow.ScalarAndTensor` does NOT have its in-place
                # variant and is unlikely to have one in the future. Thus it's safe to have the following assert.
                assert self_arg is not None and is_tensor_list_type(
                    self_arg.argument.type
                )
                return_stmt = """PyObject* self_tensorlist = _r.args[0];
Py_INCREF(self_tensorlist);
return self_tensorlist;
"""
            else:
                return_stmt = "Py_RETURN_NONE;"
            return f"""\
{schema_comment}
{inits}
auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{
  pybind11::gil_scoped_release no_gil;
  {dispatch_callee}({dispatch_args});
}};
dispatch_{name}({lambda_args}){set_requires_grad};
{return_stmt}
"""
        else:
            typename = structseq_typenames.get(gen_structseq_typename_key(f))
            structseq_typeref = f"{typename}, " if typename is not None else ""
            return f"""\
{schema_comment}
{inits}
auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{
  pybind11::gil_scoped_release no_gil;
  return {dispatch_callee}({dispatch_args});
}};
return wrap({structseq_typeref}dispatch_{name}({lambda_args}){set_requires_grad});
"""

    return go(f)