"""
For procedural tests needed for __torch_function__, this script exports
method names and signatures as needed by the tests in
test/test_overrides.py.

Usage:

python -m tools.autograd.gen_annotated_fn_args \
    aten/src/ATen/native/native_functions.yaml \
    aten/src/ATen/native/tags.yaml \
    $OUTPUT_DIR \
    tools/autograd

$OUTPUT_DIR is the directory the generated files are written to. In the
full build system, OUTPUT_DIR is torch/testing/_internal/generated
"""

from __future__ import annotations

import argparse
import os
import textwrap
from collections import defaultdict
from typing import Any, Sequence, TYPE_CHECKING

import torchgen.api.python as python
from torchgen.context import with_native_function
from torchgen.gen import parse_native_yaml
from torchgen.utils import FileManager

from .gen_python_functions import (
    is_py_fft_function,
    is_py_linalg_function,
    is_py_nn_function,
    is_py_special_function,
    is_py_torch_function,
    is_py_variable_method,
    should_generate_py_binding,
)


if TYPE_CHECKING:
    from torchgen.model import Argument, BaseOperatorName, NativeFunction


def gen_annotated(
    native_yaml_path: str, tags_yaml_path: str, out: str, autograd_dir: str
) -> None:
    """Write ``annotated_fn_args.py`` into ``out``.

    Args:
        native_yaml_path: path handed to ``parse_native_yaml`` as the native
            functions YAML.
        tags_yaml_path: path handed to ``parse_native_yaml`` as the tags YAML.
        out: install directory given to ``FileManager``.
        autograd_dir: directory whose ``templates`` subdirectory holds
            ``annotated_fn_args.py.in``.
    """
    parsed = parse_native_yaml(native_yaml_path, tags_yaml_path)
    native_functions = parsed.native_functions

    # Each predicate selects the functions that are emitted under the paired
    # namespace prefix.
    namespace_predicates = (
        (is_py_torch_function, "torch._C._VariableFunctions"),
        (is_py_nn_function, "torch._C._nn"),
        (is_py_linalg_function, "torch._C._linalg"),
        (is_py_special_function, "torch._C._special"),
        (is_py_fft_function, "torch._C._fft"),
        (is_py_variable_method, "torch.Tensor"),
    )

    entries: list[str] = []
    for belongs_to_namespace, namespace in namespace_predicates:
        # Bucket by operator name so that all functions sharing a name are
        # emitted next to each other, in order of first appearance.
        by_name: dict[BaseOperatorName, list[NativeFunction]] = defaultdict(list)
        for fn in native_functions:
            if should_generate_py_binding(fn) and belongs_to_namespace(fn):
                by_name[fn.func.name.name].append(fn)
        entries.extend(
            f"{namespace}.{gen_annotated_args(fn)}"
            for overloads in by_name.values()
            for fn in overloads
        )

    def _template_env() -> dict[str, str]:
        # NOTE(review): the indent prefix is reproduced exactly as it appears
        # in the reviewed (whitespace-collapsed) text; confirm the intended
        # width against the checked-in file.
        return {
            "annotated_args": textwrap.indent("\n".join(entries), " "),
        }

    file_manager = FileManager(
        install_dir=out,
        template_dir=os.path.join(autograd_dir, "templates"),
        dry_run=False,
    )
    file_manager.write_with_template(
        "annotated_fn_args.py",
        "annotated_fn_args.py.in",
        _template_env,
    )


@with_native_function
def gen_annotated_args(f: NativeFunction) -> str:
    """Render one ``<name>: [<argument dicts>],`` entry for ``f``.

    Every argument without a default contributes a dict with the keys
    ``is_kwarg_only``, ``name`` and ``simple_type``, plus ``size`` when
    ``python.argument_type_size`` returns a truthy value. Arguments that
    carry a default are left out.
    """
    # functions that currently don't work with kwargs in test_overrides.py
    kwargs_unsupported = (
        "diagonal",
        "round_",
        "round",
        "scatter_",
    )

    def _describe(
        arguments: Sequence[Argument], *, is_kwarg_only: bool
    ) -> list[dict[str, Any]]:
        described: list[dict[str, Any]] = []
        for argument in arguments:
            if argument.default is not None:
                continue
            # Key order matters: it is what repr() emits into the output.
            entry: dict[str, Any] = {
                "is_kwarg_only": str(is_kwarg_only),
                "name": argument.name,
                "simple_type": python.argument_type_str(
                    argument.type, simple_type=True
                ),
            }
            size = python.argument_type_size(argument.type)
            if size:
                entry["size"] = size
            described.append(entry)
        return described

    op_name = f"{f.func.name.name}"
    collected = _describe(f.func.arguments.flat_positional, is_kwarg_only=False)
    if op_name not in kwargs_unsupported:
        collected.extend(
            _describe(f.func.arguments.flat_kwarg_only, is_kwarg_only=True)
        )

    return f"{op_name}: {collected!r},"


def main() -> None:
    """Parse the four positional command-line paths and run the generator."""
    parser = argparse.ArgumentParser(description="Generate annotated_fn_args script")
    positionals = (
        ("native_functions", "NATIVE", "path to native_functions.yaml"),
        ("tags", "TAGS", "path to tags.yaml"),
        ("out", "OUT", "path to output directory"),
        ("autograd", "AUTOGRAD", "path to template directory"),
    )
    for dest, metavar, help_text in positionals:
        parser.add_argument(dest, metavar=metavar, help=help_text)

    cli = parser.parse_args()
    gen_annotated(cli.native_functions, cli.tags, cli.out, cli.autograd)


if __name__ == "__main__":
    main()