"""
This module implements generation of type stubs for PyTorch,
enabling use of autocomplete in IDEs like PyCharm, which otherwise
don't understand C extension modules.

At the moment, this module only handles type stubs for torch and
torch.Tensor. It should eventually be expanded to cover all functions
which are autogenerated.

Here's our general strategy:

- We start off with a hand-written __init__.pyi.in file. This
  file contains type definitions for everything we cannot automatically
  generate, including pure Python definitions directly in __init__.py
  (the latter case should be pretty rare).

- We go through automatically bound functions based on the
  type information recorded in native_functions.yaml and
  generate type hints for them (generate_type_hints)

There are a number of type hints which we've special-cased;
read gen_pyi for the gory details.
"""

from __future__ import annotations

import argparse
import collections
import importlib
import sys
from pprint import pformat
from typing import Sequence
from unittest.mock import Mock, patch
from warnings import warn

from tools.autograd.gen_python_functions import (
    group_overloads,
    load_signatures,
    should_generate_py_binding,
)

from torchgen.api.python import (
    PythonSignatureGroup,
    PythonSignatureNativeFunctionPair,
    returns_structseq_pyi,
)
from torchgen.gen import parse_native_yaml, parse_tags_yaml
from torchgen.model import _TorchDispatchModeKey, DispatchKey, Variant
from torchgen.utils import FileManager


def get_py_torch_functions(
    python_funcs: Sequence[PythonSignatureNativeFunctionPair],
    method: bool = False,
) -> Sequence[PythonSignatureGroup]:
    """
    Get declarations (grouped by name) which should be generated
    as either functions in the "torch" module or methods on Tensor.

    Args:
        python_funcs: candidate (Python signature, native function) pairs.
        method: if True, keep entries whose native function lists
            ``Variant.method`` (Tensor methods); otherwise keep entries
            listing ``Variant.function`` (``torch.*`` functions).

    Returns:
        The kept entries grouped by ``group_overloads``. An entry is kept
        only if ``should_generate_py_binding`` accepts its native function
        and the function has no ``python_module``.
    """
    # Both the function and the method cases apply the same filter; they
    # differ only in which variant must be present.
    wanted_variant = Variant.method if method else Variant.function

    def should_bind(python_func: PythonSignatureNativeFunctionPair) -> bool:
        native_func = python_func.function
        return (
            should_generate_py_binding(native_func)
            and not native_func.python_module
            and wanted_variant in native_func.variants
        )

    return group_overloads([f for f in python_funcs if should_bind(f)])


# TODO: Consider defining some aliases for our Union[...] types, to make
# the stubs easier to read.

DEVICE_PARAM = "device: Optional[DeviceLikeType] = None"
FACTORY_PARAMS = f"dtype: Optional[_dtype] = None, {DEVICE_PARAM}, requires_grad: _bool = False, pin_memory: _bool = False"

# NOTE: specifying indices for Tensor.__getitem__
# We can imitate numpy's definition of ndarray.__getitem__ found in numpy/__init__.pyi:
#
# key: (
#     None
#     | slice
#     | ellipsis
#     | SupportsIndex
#     | _ArrayLikeInt_co
#     | tuple[None | slice | ellipsis | _ArrayLikeInt_co | SupportsIndex, ...]
# )
#
# where:
#
# _ArrayLikeInt_co = _DualArrayLike[
#     dtype[Union[bool_, integer[Any]]],
#     Union[bool, int],
# ]
#
# and
#
# _DualArrayLike = Union[
#     _SupportsArray[_DType],
#     _NestedSequence[_SupportsArray[_DType]],
#     _T,
#     _NestedSequence[_T],
# ]
#
# Moreover, _NestedSequence is a Protocol that matches arbitrary nesting of list/tuple.
# We can substitute and simplify:
# _SupportsArray -> Tensor
# _ArrayLikeInt_co -> [bool | int | Tensor | NestedSequence[bool | int] | NestedSequence[Tensor]]
# which leaves us with key: T | tuple[T, ...], where T is:
# T = (
#     None | bool | int | slice | ellipsis | SupportsIndex
#     | Tensor | _NestedSequence[Tensor] | _NestedSequence[bool | int]
# )

# NOTE: ellipsis is equal to type[Ellipsis] in stub files.
_leaf_types = "Union[None, _bool, _int, slice, ellipsis, Tensor]"  # not SupportsIndex!
_index = f"Union[SupportsIndex, {_leaf_types}, _NestedSequence[{_leaf_types}]]"
INDICES = f"indices: Union[{_index}, tuple[{_index}, ...]]"

# Op names for which generate_type_hints emits nothing unless the signature is
# deprecated. Several of these are hand-written in gen_pyi or the .pyi.in
# templates instead.
blocklist = [
    "__init_subclass__",
    "__new__",
    "__subclasshook__",
    "cdist",
    "device",
    "grad",
    "requires_grad",
    "range",
    # defined in functional
    "einsum",
    # Somehow, these are defined in both _C and in functional. Ick!
    "broadcast_tensors",
    # Manually define named tensor type stubs in __init__.pyi.in
    "align_tensors",
    "meshgrid",
    "cartesian_prod",
    "block_diag",
    "norm",
    "chain_matmul",
    "stft",
    "tensordot",
    "split",
    "unique_consecutive",
    "atleast_1d",
    "atleast_2d",
    "atleast_3d",
    # These are handled specially by python_arg_parser.cpp
    "add",
    "add_",
    "add_out",
    "sub",
    "sub_",
    "sub_out",
    "mul",
    "mul_",
    "mul_out",
    "div",
    "div_",
    "div_out",
    "true_divide",
    "true_divide_",
    "true_divide_out",
    "floor_divide",
    "floor_divide_",
    "floor_divide_out",
    "to",
    "_to_copy",
    "copy_",
]

binary_ops = (
    "add",
    "sub",
    "mul",
    "div",
    "pow",
    "lshift",
    "rshift",
    "mod",
    "truediv",
    "matmul",
    "floordiv",
    "radd",
    "rsub",
    "rmul",
    "rtruediv",
    "rfloordiv",
    "rpow",  # reverse arithmetic
    "and",
    "or",
    "xor",
    "rand",
    "ror",
    "rxor",  # logic
    "iadd",
    "iand",
    "idiv",
    "ilshift",
    "imul",
    "ior",
    "irshift",
    "isub",
    "ixor",
    "ifloordiv",
    "imod",  # inplace ops
)
symmetric_comparison_ops = ("eq", "ne")
asymmetric_comparison_ops = ("ge", "gt", "lt", "le")
comparison_ops = symmetric_comparison_ops + asymmetric_comparison_ops

unary_ops = ("neg", "abs", "invert")
to_py_type_ops = ("bool", "float", "complex", "long", "index", "int", "nonzero")
# Every op name (without the surrounding "__") that sig_for_ops understands.
all_ops = binary_ops + comparison_ops + unary_ops + to_py_type_ops


def sig_for_ops(opname: str) -> list[str]:
    """sig_for_ops(opname : str) -> list[str]

    Returns signatures for operator special functions (__add__ etc.)

    These are written by hand because the operators are hand-bound in Python.

    Raises:
        AssertionError: if ``opname`` is not of the form ``__name__``.
        ValueError: if ``name`` is not one of the ops in ``all_ops``.
    """
    assert opname.endswith("__") and opname.startswith("__"), f"Unexpected op {opname}"

    name = opname[2:-2]
    if name in binary_ops:
        return [f"def {opname}(self, other: Any) -> Tensor: ..."]
    if name in comparison_ops:
        sig = f"def {opname}(self, other: Any) -> Tensor: ..."
        if name in symmetric_comparison_ops:
            # unsafe override https://github.com/python/mypy/issues/5704
            sig += " # type: ignore[override]"
        return [sig]
    if name in unary_ops:
        return [f"def {opname}(self) -> Tensor: ..."]
    if name in to_py_type_ops:
        # __bool__/__float__/__complex__ return their own builtin type,
        # __nonzero__ returns bool, everything else here returns int.
        if name in ("bool", "float", "complex"):
            tname = name
        elif name == "nonzero":
            tname = "bool"
        else:
            tname = "int"
        # Every result is qualified with "builtins."; presumably so the name is
        # not resolved against Tensor members such as `Tensor.bool`.
        return [f"def {opname}(self) -> builtins.{tname}: ..."]
    raise ValueError("unknown op", opname)


def generate_type_hints(sig_group: PythonSignatureGroup) -> list[str]:
    """Return the ``.pyi`` signature line(s) for one group of overloads.

    Returns an empty list for names in ``blocklist`` unless the signature is
    deprecated. Otherwise the result holds, in order:

    * for a deprecated signature that has an out variant, the functional
      form (outputs skipped);
    * the main signature, with its outputs included only when an out variant
      exists;
    * the vararg form, when ``signature_str_pyi_vararg`` yields one.
    """
    signature = sig_group.signature
    type_hints: list[str] = []

    # Some deprecated ops that are on the blocklist are still included in pyi
    if signature.name in blocklist and not signature.deprecated:
        return type_hints

    # deprecated signatures have separate entries for their functional and out variants
    # (as opposed to the native ops, which fuse the two into a single signature).
    # generate the functional variant here, if an out variant exists.
    if signature.deprecated and sig_group.outplace is not None:
        type_hints.append(signature.signature_str_pyi(skip_outputs=True))

    # PythonSignatureGroups that have both a functional + out variant get a single
    # signature, with an optional out argument. Generates the out variant if one
    # exists. Otherwise, generate the functional variant.
    type_hints.append(
        signature.signature_str_pyi(skip_outputs=sig_group.outplace is None)
    )

    # Some operators also additionally have a vararg variant of their signature
    type_hint_vararg = signature.signature_str_pyi_vararg(
        skip_outputs=sig_group.outplace is None
    )
    if type_hint_vararg:
        type_hints.append(type_hint_vararg)

    return type_hints


def get_max_pool_dispatch(name: str, arg_list: list[str]) -> dict[str, list[str]]:
    """Build the three ``return_indices`` overloads of a pooling function.

    Args:
        name: the function name, e.g. ``"max_pool2d"``.
        arg_list: parameter strings in order. Exactly one entry must be the
            placeholder ``"{return_indices}"``. An entry may contain several
            ``", "``-separated parameters.

    Returns:
        ``{name: [hint_false, hint_positional, hint_keyword]}``:

        * ``hint_false`` has ``return_indices: Literal[False] = False`` and
          returns ``Tensor``;
        * ``hint_positional`` takes ``return_indices: Literal[True]``
          positionally. Every parameter up to and including it loses its
          default, and the prefix is made positional-only with ``/``;
        * ``hint_keyword`` makes ``return_indices: Literal[True]``
          keyword-only with ``*``.

        The last two return ``Tuple[Tensor, Tensor]``.
    """
    flag_pos = arg_list.index("{return_indices}")
    # If return_indices is positional arg, everything before should have no default
    # (a required positional parameter cannot follow defaulted ones). Defaults are
    # stripped per ", "-separated piece; this relies on no default value itself
    # containing ", ".
    arg_list_positional = (
        [
            ", ".join(single_arg.split(" = ")[0] for single_arg in arg.split(", "))
            for arg in arg_list[: flag_pos + 1]
        ]
        + ["/"]
        + arg_list[flag_pos + 1 :]
    )
    # Otherwise force return_indices to be kwarg
    arg_list_keyword = arg_list.copy()
    arg_list_keyword.insert(flag_pos, "*")
    # Two-stage formatting: the first .format fills name/args and leaves
    # {return_indices}/{return_type} for the second.
    tmpl = "def {name}({args}) -> {{return_type}}: ..."
    return {
        name: [
            tmpl.format(name=name, args=", ".join(arg_list)).format(
                return_indices="return_indices: Literal[False] = False",
                return_type="Tensor",
            ),
            tmpl.format(name=name, args=", ".join(arg_list_positional)).format(
                return_indices="return_indices: Literal[True]",
                return_type="Tuple[Tensor, Tensor]",
            ),
            tmpl.format(name=name, args=", ".join(arg_list_keyword)).format(
                return_indices="return_indices: Literal[True]",
                return_type="Tuple[Tensor, Tensor]",
            ),
        ]
    }


def _with_overload_decorators(unsorted_hints: dict[str, list[str]]) -> list[str]:
    """Flatten ``{name: [hint, ...]}`` into a single list ordered by name.

    When a name has more than one hint, each of its hints is prefixed with
    ``@overload`` on its own line.
    """
    flattened: list[str] = []
    for _name, hints in sorted(unsorted_hints.items()):
        if len(hints) > 1:
            hints = ["@overload\n" + hint for hint in hints]
        flattened += hints
    return flattened


def gen_nn_functional(fm: FileManager) -> None:
    """Write ``torch/nn/functional.pyi`` and ``torch/_C/_nn.pyi`` via ``fm``.

    ``torch/nn/functional.pyi`` is rendered from its ``.pyi.in`` template with
    two entries. ``imported_hints`` holds re-export lines. ``dispatched_hints``
    holds the ``@overload`` stubs for the ``return_indices`` pooling functions.

    ``torch/_C/_nn.pyi`` is rendered with ``c_nn_function_hints``.
    """
    INPUT = "input: Tensor"
    KERNEL_SIZE = "kernel_size: Union[_int, _size]"
    STRIDE_PADDING = ", ".join(
        [
            "stride: Optional[Union[_int, _size]] = None",
            "padding: Union[_int, _size] = 0",
        ]
    )

    # TODO the list for `torch._C._nn` is nonexhaustive
    unsorted_c_nn_function_hints: dict[str, list[str]] = {}

    for d in (2, 3):
        unsorted_c_nn_function_hints.update(
            {
                f"avg_pool{d}d": [
                    f"def avg_pool{d}d({{}}) -> Tensor: ...".format(
                        ", ".join(
                            [
                                INPUT,
                                KERNEL_SIZE,
                                STRIDE_PADDING,
                                "ceil_mode: bool = False",
                                "count_include_pad: bool = True",
                                "divisor_override: Optional[int] = None",
                            ]
                        )
                    )
                ],
                f"fractional_max_pool{d}d": [
                    f"def fractional_max_pool{d}d({{}}) -> {{}}: ...".format(
                        ", ".join(
                            [
                                INPUT,
                                KERNEL_SIZE,
                                "output_size: Union[_int, _size]",
                                "_random_samples: Tensor",
                            ]
                        ),
                        "Tuple[Tensor, Tensor]",
                    )
                ],
                f"adaptive_max_pool{d}d": [
                    f"def adaptive_max_pool{d}d({{}}) -> {{}}: ...".format(
                        ", ".join([INPUT, "output_size: Union[_int, _size]"]),
                        "Tuple[Tensor, Tensor]",
                    )
                ],
            }
        )

    unsorted_c_nn_function_hints.update(
        {
            "hardtanh": [
                "def hardtanh({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "input: Tensor",
                            "min_val: float = ...",
                            "max_val: float = ...",
                            "*",
                            "out: Optional[Tensor] = None",
                        ]
                    )
                )
            ],
            "hardtanh_": [
                "def hardtanh_({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "input: Tensor",
                            "min_val: float = ...",
                            "max_val: float = ...",
                        ]
                    )
                )
            ],
            "elu_": ["def elu_(input: Tensor, alpha: float = ...) -> Tensor: ..."],
            "leaky_relu": [
                "def leaky_relu({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "input: Tensor",
                            "negative_slope: float = ...",
                            "*",
                            "out: Optional[Tensor] = None",
                        ]
                    )
                )
            ],
            "leaky_relu_": [
                f"def leaky_relu_({', '.join(['input: Tensor', 'negative_slope: float = ...'])}) -> Tensor: ..."
            ],
            "log_sigmoid": ["def log_sigmoid(input: Tensor) -> Tensor: ..."],
            "gelu": ["def gelu(input: Tensor, approximate: str = ...) -> Tensor: ..."],
            "softplus": [
                "def softplus({}) -> Tensor: ...".format(
                    ", ".join(
                        ["input: Tensor", "beta: float = ...", "threshold: float = ..."]
                    )
                )
            ],
            "softshrink": [
                "def softshrink(input: Tensor, lambd: float = ...) -> Tensor: ..."
            ],
            "hardsigmoid": [
                f"def hardsigmoid({', '.join(['input: Tensor', '*', 'out: Optional[Tensor] = None'])}) -> Tensor: ..."
            ],
            "linear": [
                "def linear({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "input: Tensor",
                            "weight: Tensor",
                            "bias: Optional[Tensor] = None",
                        ]
                    )
                )
            ],
            "pad": [
                "def pad({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "input: Tensor",
                            "pad: Sequence[int]",
                            "mode: str = ...",
                            "value: Optional[float] = None",
                        ]
                    )
                )
            ],
            "one_hot": [
                "def one_hot(tensor: Tensor, num_classes: int = ...) -> Tensor: ..."
            ],
            "scaled_dot_product_attention": [
                "def scaled_dot_product_attention({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "query: Tensor",
                            "key: Tensor",
                            "value: Tensor",
                            "attn_mask: Optional[Tensor] = None",
                            "dropout_p: float = 0.0",
                            "is_causal: bool = False",
                            "scale: Optional[float] = None",
                            "enable_gqa: bool = False",
                        ]
                    )
                )
            ],
        }
    )

    c_nn_function_hints = _with_overload_decorators(unsorted_c_nn_function_hints)

    # Functions imported into `torch.nn.functional` from `torch`, perhaps being filtered
    # through an `_add_docstr` call
    torch_imports = [
        "conv1d",
        "conv2d",
        "conv3d",
        "conv_transpose1d",
        "conv_transpose2d",
        "conv_transpose3d",
        "conv_tbc",
        "avg_pool1d",
        "adaptive_avg_pool1d",
        "relu_",
        "selu_",
        "celu_",
        "prelu",
        "rrelu_",
        "hardshrink",
        "bilinear",
        "pixel_shuffle",
        "pixel_unshuffle",
        "channel_shuffle",
        "native_channel_shuffle",
        "pairwise_distance",
        "pdist",
        "cosine_similarity",
    ]
    imported_hints = [f"from torch import {fn} as {fn}" for fn in torch_imports]

    # Functions imported into `torch.nn.functional` from `torch._C._nn`
    c_nn_imports = [
        "avg_pool2d",
        "avg_pool3d",
        "hardtanh_",
        "elu_",
        "leaky_relu_",
        "gelu",
        "softplus",
        "softshrink",
        "linear",
        "pad",
        "one_hot",
        "scaled_dot_product_attention",
    ]
    imported_hints += [f"from torch._C._nn import {fn} as {fn}" for fn in c_nn_imports]
    # This is from `torch._C._nn` but renamed
    imported_hints.append(
        "from torch._C._nn import log_sigmoid\nlogsigmoid = log_sigmoid"
    )

    # Functions generated by `torch._jit_internal.boolean_dispatch` in `nn.functional`
    unsorted_dispatched_hints: dict[str, list[str]] = {}

    for d in (1, 2, 3):
        unsorted_dispatched_hints.update(
            **get_max_pool_dispatch(
                f"max_pool{d}d",
                [
                    INPUT,
                    KERNEL_SIZE,
                    STRIDE_PADDING,
                    "dilation: Union[_int, _size] = 1",
                    "ceil_mode: bool = False",
                    "{return_indices}",
                ],
            ),
            **get_max_pool_dispatch(
                f"fractional_max_pool{d}d",
                [
                    INPUT,
                    KERNEL_SIZE,
                    "output_size: Optional[Union[_int, _size]] = None",
                    "output_ratio: Optional[_ratio_any_t] = None",
                    "{return_indices}",
                    "_random_samples: Optional[Tensor] = None",
                ],
            ),
            **get_max_pool_dispatch(
                f"adaptive_max_pool{d}d",
                [INPUT, "output_size: Union[_int, _size]", "{return_indices}"],
            ),
        )

    # There's no fractional_max_pool1d
    del unsorted_dispatched_hints["fractional_max_pool1d"]

    dispatched_hints = _with_overload_decorators(unsorted_dispatched_hints)

    fm.write_with_template(
        "torch/nn/functional.pyi",
        "torch/nn/functional.pyi.in",
        lambda: {
            "imported_hints": imported_hints,
            "dispatched_hints": dispatched_hints,
        },
    )
    fm.write_with_template(
        "torch/_C/_nn.pyi",
        "torch/_C/_nn.pyi.in",
        lambda: {
            "c_nn_function_hints": c_nn_function_hints,
        },
    )


# We gather the docstrings for torch with the following steps:
# 1. Mock torch and torch._C, which are the only dependencies of the docs files
# 2. Mock the _add_docstr function to save the docstrings
# 3. Import the docs files to trigger mocked _add_docstr and collect docstrings


def gather_docstrs() -> dict[str, str]:
    """Collect the docstrings registered by the ``_torch_docs``/``_tensor_docs`` modules.

    Returns:
        A mapping from the mock name of each documented object to its
        stripped docstring. gen_pyi looks entries up as ``"torch.<name>"``.
        If either docs module cannot be found, a warning is issued and
        whatever was collected so far (possibly nothing) is returned.
    """
    docstrs = {}

    def mock_add_docstr(func: Mock, docstr: str) -> None:
        docstrs[func._extract_mock_name()] = docstr.strip()

    # sys.modules and sys.path are restored after the context manager exits.
    # NOTE(review): "torch" is a relative path entry, so this assumes the
    # generator runs from the repository root — confirm.
    with patch.dict(sys.modules), patch.object(sys, "path", sys.path + ["torch"]):
        # mock the torch module and torch._C._add_docstr
        sys.modules["torch"] = Mock(name="torch")
        sys.modules["torch._C"] = Mock(_add_docstr=mock_add_docstr)

        try:
            # manually import torch._torch_docs and torch._tensor_docs to trigger
            # the mocked _add_docstr and collect docstrings
            sys.modules["torch._torch_docs"] = importlib.import_module("_torch_docs")
            sys.modules["torch._tensor_docs"] = importlib.import_module("_tensor_docs")
        except ModuleNotFoundError:
            # Gracefully fail if these modules are not importable
            warn(
                "Failed to import _torch_docs/_tensor_docs, skipping docstring in pyi files."
            )

    return docstrs


def add_docstr_to_hint(docstr: str, hint: str) -> str:
    """Attach ``docstr`` to a single stub line ``hint``.

    A hint containing ``...`` is treated as a function and must end with
    ``...``. That trailing ``...`` is replaced by an indented raw-string
    docstring followed by ``...``.

    Any other hint is treated as an attribute or property. The raw-string
    docstring is placed on the line after it.
    """
    if "..." in hint:  # function or method
        assert hint.endswith("..."), f"Hint `{hint}` does not end with '...'"
        hint = hint[:-3]  # remove "..."
        return "\n    ".join([hint, 'r"""'] + docstr.split("\n") + ['"""', "..."])
    else:  # attribute or property
        return f'{hint}\nr"""{docstr}"""\n'


def gen_pyi(
    native_yaml_path: str,
    tags_yaml_path: str,
    deprecated_yaml_path: str,
    fm: FileManager,
) -> None:
    """gen_pyi()

    This function generates a pyi file for torch.
    """

    # Some of this logic overlaps with generate_python_signature in
    # tools/autograd/gen_python_functions.py; however, this
    # function is all about generating mypy type signatures, whereas
    # the other function generates are custom format for argument
    # checking.
If you are update this, consider if your change 651 # also needs to update the other file. 652 653 # Dictionary for NamedTuple definitions 654 structseqs: dict[str, str] = {} 655 656 # Generate type signatures for top-level functions 657 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 658 659 unsorted_function_hints: dict[str, list[str]] = collections.defaultdict(list) 660 661 for n, n1, n2 in [ 662 ("csr", "crow", "col"), 663 ("csc", "ccol", "row"), 664 ("bsr", "crow", "col"), 665 ("bsc", "ccol", "row"), 666 ]: 667 unsorted_function_hints.update( 668 { 669 f"sparse_{n}_tensor": [ 670 f"def sparse_{n}_tensor({{}}) -> Tensor: ...".format( 671 ", ".join( 672 [ 673 f"{n1}_indices: Union[Tensor, List]", 674 f"{n2}_indices: Union[Tensor, List]", 675 "values: Union[Tensor, List]", 676 "size: Optional[_size] = None", 677 "*", 678 "dtype: Optional[_dtype] = None", 679 "device: Optional[DeviceLikeType] = None", 680 "requires_grad: _bool = False", 681 "check_invariants: Optional[_bool] = None", 682 ] 683 ), 684 ) 685 ], 686 } 687 ) 688 689 unsorted_function_hints.update( 690 { 691 "set_flush_denormal": ["def set_flush_denormal(mode: _bool) -> _bool: ..."], 692 "get_default_dtype": ["def get_default_dtype() -> _dtype: ..."], 693 "asarray": [ 694 "def asarray({}) -> Tensor: ...".format( 695 ", ".join( 696 [ 697 "obj: Any", 698 "*", 699 "dtype: Optional[_dtype] = None", 700 "device: Optional[DeviceLikeType] = None", 701 "copy: Optional[_bool] = None", 702 "requires_grad: _bool = False", 703 ] 704 ) 705 ) 706 ], 707 "from_numpy": ["def from_numpy(ndarray) -> Tensor: ..."], 708 "frombuffer": [ 709 "def frombuffer({}) -> Tensor: ...".format( 710 ", ".join( 711 [ 712 "buffer: Any", 713 "*", 714 "dtype: _dtype", 715 "count: int = -1", 716 "offset: int = 0", 717 "requires_grad: _bool = False", 718 ] 719 ) 720 ) 721 ], 722 "numel": ["def numel(self: Tensor) -> _int: ..."], 723 "as_tensor": [ 724 "def as_tensor({}) -> Tensor: ...".format( 725 ", ".join( 726 [ 727 "data: Any", 728 
"dtype: Optional[_dtype] = None", 729 DEVICE_PARAM, 730 ] 731 ) 732 ) 733 ], 734 "get_num_threads": ["def get_num_threads() -> _int: ..."], 735 "set_num_threads": ["def set_num_threads(num: _int) -> None: ..."], 736 "init_num_threads": ["def init_num_threads() -> None: ..."], 737 "get_num_interop_threads": ["def get_num_interop_threads() -> _int: ..."], 738 "set_num_interop_threads": [ 739 "def set_num_interop_threads(num: _int) -> None: ..." 740 ], 741 # These functions are explicitly disabled by 742 # SKIP_PYTHON_BINDINGS because they are hand bound. 743 # Correspondingly, we must hand-write their signatures. 744 "tensor": [f"def tensor(data: Any, {FACTORY_PARAMS}) -> Tensor: ..."], 745 "sparse_coo_tensor": [ 746 "def sparse_coo_tensor({}) -> Tensor: ...".format( 747 ", ".join( 748 [ 749 "indices: Tensor", 750 "values: Union[Tensor, List]", 751 "size: Optional[_size] = None", 752 "*", 753 "dtype: Optional[_dtype] = None", 754 "device: Optional[DeviceLikeType] = None", 755 "requires_grad: _bool = False", 756 "check_invariants: Optional[_bool] = None", 757 "is_coalesced: Optional[_bool] = None", 758 ] 759 ) 760 ) 761 ], 762 "sparse_compressed_tensor": [ 763 "def sparse_compressed_tensor({}) -> Tensor: ...".format( 764 ", ".join( 765 [ 766 "compressed_indices: Union[Tensor, List]", 767 "plain_indices: Union[Tensor, List]", 768 "values: Union[Tensor, List]", 769 "size: Optional[_size] = None", 770 "*", 771 "dtype: Optional[_dtype] = None", 772 "layout: Optional[_layout] = None", 773 "device: Optional[DeviceLikeType] = None", 774 "requires_grad: _bool = False", 775 "check_invariants: Optional[_bool] = None", 776 ] 777 ) 778 ) 779 ], 780 "_sync": ["def _sync(t: Tensor) -> None: ..."], 781 "_is_functional_tensor": [ 782 "def _is_functional_tensor(t: Tensor) -> _bool: ..." 783 ], 784 "_is_functional_tensor_base": [ 785 "def _is_functional_tensor_base(t: Tensor) -> _bool: ..." 
786 ], 787 "_from_functional_tensor": [ 788 "def _from_functional_tensor(t: Tensor) -> Tensor: ..." 789 ], 790 "_to_functional_tensor": [ 791 "def _to_functional_tensor(t: Tensor) -> Tensor: ..." 792 ], 793 "_functionalize_replace": [ 794 "def _functionalize_replace(self_: Tensor, other: Tensor) -> None: ..." 795 ], 796 "_functionalize_commit_update": [ 797 "def _functionalize_commit_update(t: Tensor) -> None: ..." 798 ], 799 "_functionalize_unsafe_set": [ 800 "def _functionalize_unsafe_set(dst: Tensor, src: Tensor) -> None: ..." 801 ], 802 "_functionalize_mark_mutation_hidden_from_autograd": [ 803 "def _functionalize_mark_mutation_hidden_from_autograd(t: Tensor) -> None: ..." 804 ], 805 "_functionalize_are_all_mutations_hidden_from_autograd": [ 806 "def _functionalize_are_all_mutations_hidden_from_autograd(t: Tensor) -> _bool: ..." 807 ], 808 "_functionalize_are_all_mutations_under_no_grad_or_inference_mode": [ 809 "def _functionalize_are_all_mutations_under_no_grad_or_inference_mode(t: Tensor) -> _bool: ..." 810 ], 811 "_functionalize_was_inductor_storage_resized": [ 812 "def _functionalize_was_inductor_storage_resized(t: Tensor) -> _bool: ..." 813 ], 814 "_functionalize_sync": ["def _functionalize_sync(t: Tensor) -> None: ..."], 815 "_functionalize_was_storage_changed": [ 816 "def _functionalize_was_storage_changed(tensor: Tensor) -> _bool: ..." 817 ], 818 "_functionalize_set_storage_changed": [ 819 "def _functionalize_set_storage_changed(tensor: Tensor) -> _bool: ..." 820 ], 821 "_functionalize_has_metadata_mutation": [ 822 "def _functionalize_has_metadata_mutation(tensor: Tensor) -> _bool: ..." 823 ], 824 "_functionalize_apply_view_metas": [ 825 "def _functionalize_apply_view_metas(tensor: Tensor, base: Tensor) -> Tensor: ..." 826 ], 827 "_functionalize_is_symbolic": [ 828 "def _functionalize_is_symbolic(tensor: Tensor) -> _bool: ..." 829 ], 830 "_enable_functionalization": [ 831 "def _enable_functionalization(*, reapply_views: _bool = False): ..." 
832 ], 833 "_disable_functionalization": ["def _disable_functionalization(): ..."], 834 "range": [ 835 "def range({}) -> Tensor: ...".format( 836 ", ".join( 837 [ 838 "start: Number", 839 "end: Number", 840 "step: Number = 1", 841 "*", 842 "out: Optional[Tensor] = None", 843 FACTORY_PARAMS, 844 ] 845 ) 846 ) 847 ], 848 "arange": [ 849 "def arange({}) -> Tensor: ...".format( 850 ", ".join( 851 [ 852 "start: Number", 853 "end: Number", 854 "step: Number", 855 "*", 856 "out: Optional[Tensor] = None", 857 FACTORY_PARAMS, 858 ] 859 ) 860 ), 861 "def arange({}) -> Tensor: ...".format( 862 ", ".join( 863 [ 864 "start: Number", 865 "end: Number", 866 "*", 867 "out: Optional[Tensor] = None", 868 FACTORY_PARAMS, 869 ] 870 ) 871 ), 872 "def arange({}) -> Tensor: ...".format( 873 ", ".join( 874 [ 875 "end: Number", 876 "*", 877 "out: Optional[Tensor] = None", 878 FACTORY_PARAMS, 879 ] 880 ) 881 ), 882 ], 883 "linspace": [ 884 "def linspace({}) -> Tensor: ...".format( 885 ", ".join( 886 [ 887 "start: Number", 888 "end: Number", 889 "steps: Optional[_int] = None", 890 "*", 891 "out: Optional[Tensor] = None", 892 FACTORY_PARAMS, 893 ] 894 ) 895 ) 896 ], 897 "logspace": [ 898 "def logspace({}) -> Tensor: ...".format( 899 ", ".join( 900 [ 901 "start: Number", 902 "end: Number", 903 "steps: Optional[_int] = None", 904 "base: _float = 10.0", 905 "*", 906 "out: Optional[Tensor] = None", 907 FACTORY_PARAMS, 908 ] 909 ) 910 ) 911 ], 912 "randint": [ 913 "def randint({}) -> Tensor: ...".format( 914 ", ".join( 915 [ 916 "low: _int", 917 "high: _int", 918 "size: _size", 919 "*", 920 "generator: Optional[Generator] = None", 921 FACTORY_PARAMS, 922 ] 923 ) 924 ), 925 "def randint({}) -> Tensor: ...".format( 926 ", ".join( 927 [ 928 "high: _int", 929 "size: _size", 930 "*", 931 "generator: Optional[Generator] = None", 932 FACTORY_PARAMS, 933 ] 934 ) 935 ), 936 ], 937 "full": [ 938 "def full({}) -> Tensor: ...".format( 939 ", ".join( 940 [ 941 "size: _size", 942 "fill_value: Union[Number, 
_complex]", 943 "*", 944 "out: Optional[Tensor] = None", 945 "layout: _layout = strided", 946 FACTORY_PARAMS, 947 ] 948 ) 949 ), 950 "def full({}) -> Tensor: ...".format( 951 ", ".join( 952 [ 953 "size: _size", 954 "fill_value: Union[Number, _complex]", 955 "*", 956 "names: List[Union[str, None]]", 957 "layout: _layout = strided", 958 FACTORY_PARAMS, 959 ] 960 ) 961 ), 962 ], 963 "is_grad_enabled": ["def is_grad_enabled() -> _bool: ..."], 964 "is_inference_mode_enabled": [ 965 "def is_inference_mode_enabled() -> _bool: ..." 966 ], 967 "nonzero": [ 968 "def nonzero(input: Tensor, *, as_tuple: Literal[False] = False, out: Optional[Tensor] = None) -> Tensor: ...", 969 "def nonzero(input: Tensor, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...", 970 ], 971 "dsmm": ["def dsmm(input: Tensor, mat2: Tensor) -> Tensor: ..."], 972 "hsmm": ["def hsmm(input: Tensor, mat2: Tensor) -> Tensor: ..."], 973 "saddmm": [ 974 "def saddmm({}) -> Tensor: ...".format( 975 ", ".join( 976 [ 977 "input: Tensor", 978 "mat1: Tensor", 979 "mat2: Tensor", 980 "*", 981 "beta: Number = 1", 982 "alpha: Number = 1", 983 "out: Optional[Tensor] = None", 984 ] 985 ) 986 ) 987 ], 988 "spmm": ["def spmm(input: Tensor, mat2: Tensor) -> Tensor: ..."], 989 "div": [ 990 "def div({}) -> Tensor: ...".format( 991 ", ".join( 992 [ 993 "input: Union[Tensor, Number]", 994 "other: Union[Tensor, Number]", 995 "*", 996 "rounding_mode: Optional[str] = None", 997 "out: Optional[Tensor] = None", 998 ] 999 ) 1000 ) 1001 ], 1002 } 1003 ) 1004 for binop in ["true_divide", "floor_divide"]: 1005 unsorted_function_hints[binop].append( 1006 f"def {binop}(input: Union[Tensor, Number], other: Union[Tensor, Number], " 1007 "*, out: Optional[Tensor] = None) -> Tensor: ..." 1008 ) 1009 for binop in ["mul"]: 1010 unsorted_function_hints[binop].append( 1011 f"def {binop}(input: Union[Tensor, Number, _complex], other: Union[Tensor, Number, _complex], " 1012 "*, out: Optional[Tensor] = None) -> Tensor: ..." 
1013 ) 1014 for binop in ["add", "sub"]: 1015 unsorted_function_hints[binop].append( 1016 f"def {binop}(input: Union[Tensor, Number, _complex], other: Union[Tensor, Number, _complex], " 1017 "*, alpha: Optional[Union[Number, _complex]] = 1, out: Optional[Tensor] = None) -> Tensor: ..." 1018 ) 1019 1020 native_functions = parse_native_yaml( 1021 native_yaml_path, tags_yaml_path 1022 ).native_functions 1023 native_functions = list(filter(should_generate_py_binding, native_functions)) 1024 1025 function_signatures = load_signatures( 1026 native_functions, deprecated_yaml_path, method=False, pyi=True 1027 ) 1028 sig_groups = get_py_torch_functions(function_signatures) 1029 for group in sorted(sig_groups, key=lambda g: g.signature.name): 1030 name = group.signature.name 1031 unsorted_function_hints[name] += generate_type_hints(group) 1032 1033 structseq = returns_structseq_pyi(group.signature) 1034 if structseq is not None and not group.signature.deprecated: 1035 # deprecated structseqs are currently not included for torch functions 1036 tuple_name, tuple_def = structseq 1037 if tuple_name in structseqs: 1038 assert structseqs[tuple_name] == tuple_def 1039 else: 1040 structseqs[tuple_name] = tuple_def 1041 1042 def replace_special_case(hint: str) -> str: 1043 # NB: Keep this in sync with enum in aten/src/ATen/core/Reduction.h 1044 hint = hint.replace("at::Reduction::Mean", "1") 1045 hint = hint.replace(": Tensor = None", ": Optional[Tensor] = None") 1046 # Match both: 1047 # ": Union[Tensor, Tuple[Tensor, ...], List[Tensor]] = None" 1048 # ": Union[Tuple[Tensor, ...], List[Tensor]] = None" 1049 hint = hint.replace( 1050 "Tuple[Tensor, ...], List[Tensor]] = None", 1051 "Tuple[Tensor, ...], List[Tensor], None] = None", 1052 ) 1053 return hint 1054 1055 docstrs = gather_docstrs() 1056 function_hints = [] 1057 for name, hints in sorted(unsorted_function_hints.items()): 1058 hints = [replace_special_case(h) for h in hints] 1059 if len(hints) > 1: 1060 hints = ["@overload\n" 
+ h for h in hints] 1061 docstr = docstrs.get(f"torch.{name}") 1062 if docstr is not None: 1063 hints = [add_docstr_to_hint(docstr, h) for h in hints] 1064 function_hints += hints 1065 1066 # Generate type signatures for Tensor methods 1067 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 1068 1069 unsorted_tensor_method_hints: dict[str, list[str]] = collections.defaultdict(list) 1070 unsorted_tensor_method_hints.update( 1071 { 1072 "size": [ 1073 "def size(self, dim: None = None) -> Size: ...", 1074 "def size(self, dim: _int) -> _int: ...", 1075 ], 1076 "stride": [ 1077 "def stride(self, dim: None = None) -> Tuple[_int, ...]: ...", 1078 "def stride(self, dim: _int) -> _int: ...", 1079 ], 1080 "new_ones": [ 1081 f"def new_ones(self, size: _size, {FACTORY_PARAMS}) -> Tensor: ..." 1082 ], 1083 "new_tensor": [ 1084 f"def new_tensor(self, data: Any, {FACTORY_PARAMS}) -> Tensor: ..." 1085 ], 1086 "__new__": ["def __new__(cls, *args, **kwargs) -> Self: ..."], 1087 # new and __init__ have the same signatures differ only in return type 1088 # Adapted from legacy_tensor_ctor and legacy_tensor_new 1089 "new": [ 1090 f"def new(cls, *args: Any, {DEVICE_PARAM}) -> Self: ...", 1091 "def new(cls, storage: Storage) -> Self: ...", 1092 "def new(cls, other: Tensor) -> Self: ...", 1093 f"def new(cls, size: _size, *, {DEVICE_PARAM}) -> Self: ...", 1094 ], 1095 "__init__": [ 1096 f"def __init__(self, *args: Any, {DEVICE_PARAM}) -> None: ...", 1097 "def __init__(self, storage: Storage) -> None: ...", 1098 "def __init__(self, other: Tensor) -> None: ...", 1099 f"def __init__(self, size: _size, *, {DEVICE_PARAM}) -> None: ...", 1100 ], 1101 "as_subclass": ["def as_subclass(self, cls: _Type[S]) -> S: ..."], 1102 "_make_subclass": [ 1103 "@staticmethod \ndef _make_subclass({}) -> S: ...".format( 1104 ", ".join( 1105 [ 1106 "cls: _Type[S]", 1107 "data: Tensor", 1108 "require_grad: _bool = False", 1109 "dispatch_strides: _bool = False", 1110 "dispatch_device: _bool = False", 1111 
"device_for_backend_keys: Optional[_device] = None", 1112 ] 1113 ) 1114 ) 1115 ], 1116 "__contains__": ["def __contains__(self, other: Any, /) -> _bool: ..."], 1117 "__getitem__": [f"def __getitem__(self, {INDICES}) -> Tensor: ..."], 1118 "__setitem__": [ 1119 f"def __setitem__(self, {INDICES}, val: Union[Tensor, Number]) -> None: ..." 1120 ], 1121 "tolist": ["def tolist(self) -> List: ..."], 1122 "requires_grad_": [ 1123 "def requires_grad_(self, mode: _bool = True) -> Tensor: ..." 1124 ], 1125 "element_size": ["def element_size(self) -> _int: ..."], 1126 "data_ptr": ["def data_ptr(self) -> _int: ..."], 1127 "dim": ["def dim(self) -> _int: ..."], 1128 "nonzero": [ 1129 "def nonzero(self, *, as_tuple: Literal[False] = False) -> Tensor: ...", 1130 "def nonzero(self, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...", 1131 ], 1132 "numel": ["def numel(self) -> _int: ..."], 1133 "ndimension": ["def ndimension(self) -> _int: ..."], 1134 "nelement": ["def nelement(self) -> _int: ..."], 1135 "cuda": [ 1136 "def cuda({}) -> Tensor: ...".format( 1137 ", ".join( 1138 [ 1139 "self", 1140 "device: Optional[Union[_device, _int, str]] = None", 1141 "non_blocking: _bool = False", 1142 "memory_format: torch.memory_format = torch.preserve_format", 1143 ] 1144 ) 1145 ) 1146 ], 1147 "xpu": [ 1148 "def xpu({}) -> Tensor: ...".format( 1149 ", ".join( 1150 [ 1151 "self", 1152 "device: Optional[Union[_device, _int, str]] = None", 1153 "non_blocking: _bool = False", 1154 "memory_format: torch.memory_format = torch.preserve_format", 1155 ] 1156 ) 1157 ) 1158 ], 1159 "cpu": [ 1160 "def cpu(self, memory_format: torch.memory_format = torch.preserve_format) -> Tensor: ..." 1161 ], 1162 "numpy": ["def numpy(self, *, force: _bool = False) -> numpy.ndarray: ..."], 1163 "apply_": ["def apply_(self, callable: Callable) -> Tensor: ..."], 1164 "map_": [ 1165 "def map_(self, tensor: Tensor, callable: Callable) -> Tensor: ..." 
1166 ], 1167 "map2_": [ 1168 "def map2_(self, x: Tensor, y: Tensor, callable: Callable) -> Tensor: ..." 1169 ], 1170 "storage": ["def untyped_storage(self) -> UntypedStorage: ..."], 1171 "storage_type": ["def storage_type(self) -> Storage: ..."], 1172 "type": [ 1173 "def type(self, dtype: None = None, non_blocking: _bool = False) -> str: ...", 1174 "def type(self, dtype: Union[str, _dtype], non_blocking: _bool = False) -> Tensor: ...", 1175 ], 1176 "get_device": ["def get_device(self) -> _int: ..."], 1177 "contiguous": [ 1178 "def contiguous(self, memory_format=torch.contiguous_format) -> Tensor: ..." 1179 ], 1180 "has_names": ["def has_names(self) -> _bool: ..."], 1181 "is_contiguous": [ 1182 "def is_contiguous(self, memory_format=torch.contiguous_format) -> _bool: ..." 1183 ], 1184 "_is_view": ["def _is_view(self) -> _bool: ..."], 1185 "is_cpu": ["is_cpu: _bool"], 1186 "is_cuda": ["is_cuda: _bool"], 1187 "is_leaf": ["is_leaf: _bool"], 1188 "is_nested": ["is_nested: _bool"], 1189 "is_sparse": ["is_sparse: _bool"], 1190 "is_sparse_csr": ["is_sparse_csr: _bool"], 1191 "is_quantized": ["is_quantized: _bool"], 1192 "is_meta": ["is_meta: _bool"], 1193 "is_mps": ["is_mps: _bool"], 1194 "is_mtia": ["is_mtia: _bool"], 1195 "is_maia": ["is_maia: _bool"], 1196 "is_mkldnn": ["is_mkldnn: _bool"], 1197 "is_vulkan": ["is_vulkan: _bool"], 1198 "is_ipu": ["is_ipu: _bool"], 1199 "storage_offset": ["def storage_offset(self) -> Union[_int, SymInt]: ..."], 1200 "to": [ 1201 ( 1202 f"def to(self, {args}, non_blocking: _bool = False, copy: _bool = False, *, " 1203 "memory_format: Optional[torch.memory_format] = None) -> Tensor: ..." 1204 ) 1205 for args in [ 1206 "dtype: _dtype", 1207 "device: Optional[DeviceLikeType] = None, dtype: Optional[_dtype] = None", 1208 "other: Tensor", 1209 ] 1210 ], 1211 "item": ["def item(self) -> Number: ..."], 1212 "copy_": [ 1213 "def copy_(self, src: Tensor, non_blocking: _bool = False) -> Tensor: ..." 
1214 ], 1215 "set_": [ 1216 "def set_(self, storage: Union[Storage, TypedStorage, UntypedStorage], " 1217 "offset: IntLikeType, size: _symsize, stride: _symsize) -> Tensor: ...", 1218 "def set_(self, storage: Union[Storage, TypedStorage, UntypedStorage]) -> Tensor: ...", 1219 ], 1220 "split": [ 1221 "def split(self, split_size: _int, dim: _int = 0) -> Sequence[Tensor]: ...", 1222 "def split(self, split_size: Tuple[_int, ...], dim: _int = 0) -> Sequence[Tensor]: ...", 1223 ], 1224 "div": [ 1225 "def div(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ..." 1226 ], 1227 "div_": [ 1228 "def div_(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ..." 1229 ], 1230 } 1231 ) 1232 for binop in ["true_divide", "floor_divide"]: 1233 for inplace in [False, True]: 1234 out_suffix = ", *, out: Optional[Tensor] = None" 1235 if inplace: 1236 binop += "_" 1237 out_suffix = "" 1238 unsorted_tensor_method_hints[binop].append( 1239 f"def {binop}(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat]{out_suffix})" 1240 " -> Tensor: ..." 1241 ) 1242 for binop in ["mul"]: 1243 for inplace in [False, True]: 1244 out_suffix = ", *, out: Optional[Tensor] = None" 1245 if inplace: 1246 binop += "_" 1247 out_suffix = "" 1248 unsorted_tensor_method_hints[binop].append( 1249 f"def {binop}(self, other: Union[Tensor, Number, _complex, torch.SymInt, torch.SymFloat]{out_suffix})" 1250 " -> Tensor: ..." 1251 ) 1252 for binop in ["add", "sub"]: 1253 for inplace in [False, True]: 1254 out_suffix = ", out: Optional[Tensor] = None" 1255 if inplace: 1256 binop += "_" 1257 out_suffix = "" 1258 unsorted_tensor_method_hints[binop].append( 1259 f"def {binop}(self, other: Union[Tensor, Number, _complex, torch.SymInt, torch.SymFloat], " 1260 f"*, alpha: Optional[Union[Number, _complex]] = 1{out_suffix})" 1261 " -> Tensor: ..." 
1262 ) 1263 simple_conversions = [ 1264 "byte", 1265 "char", 1266 "double", 1267 "float", 1268 "half", 1269 "int", 1270 "long", 1271 "short", 1272 "bool", 1273 "bfloat16", 1274 ] 1275 for name in simple_conversions: 1276 unsorted_tensor_method_hints[name].append(f"def {name}(self) -> Tensor: ...") 1277 1278 # pyi tensor methods don't currently include deprecated signatures for some reason 1279 # TODO: we should probably add them in 1280 tensor_method_signatures = load_signatures( 1281 native_functions, 1282 deprecated_yaml_path, 1283 method=True, 1284 skip_deprecated=True, 1285 pyi=True, 1286 ) 1287 tensor_method_sig_groups = get_py_torch_functions( 1288 tensor_method_signatures, method=True 1289 ) 1290 1291 for group in sorted(tensor_method_sig_groups, key=lambda g: g.signature.name): 1292 name = group.signature.name 1293 unsorted_tensor_method_hints[name] += generate_type_hints(group) 1294 1295 structseq = returns_structseq_pyi(group.signature) 1296 if structseq is not None and not group.signature.deprecated: 1297 # deprecated structseqs are currently not included for torch functions 1298 tuple_name, tuple_def = structseq 1299 if tuple_name in structseqs: 1300 assert structseqs[tuple_name] == tuple_def 1301 else: 1302 structseqs[tuple_name] = tuple_def 1303 1304 for op in all_ops: 1305 name = f"__{op}__" 1306 unsorted_tensor_method_hints[name] += sig_for_ops(name) 1307 1308 tensor_method_hints = [] 1309 for name, hints in sorted(unsorted_tensor_method_hints.items()): 1310 if len(hints) > 1: 1311 hints = ["@overload\n" + h for h in hints] 1312 docstr = docstrs.get(f"torch._C.TensorBase.{name}") 1313 if docstr is not None: 1314 hints = [add_docstr_to_hint(docstr, h) for h in hints] 1315 tensor_method_hints += hints 1316 1317 # TODO: Missing type hints for nn 1318 1319 # Generate structseq definitions 1320 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 1321 1322 structseq_defs = [f"{defn}\n" for defn in structseqs.values()] 1323 1324 # Generate type signatures for 
legacy classes 1325 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 1326 1327 legacy_storage_base_hints = ["class StorageBase(object): ..."] 1328 1329 legacy_class_hints = [] 1330 for c in ( 1331 "DoubleTensor", 1332 "FloatTensor", 1333 "BFloat16Tensor", 1334 "LongTensor", 1335 "IntTensor", 1336 "ShortTensor", 1337 "HalfTensor", 1338 "CharTensor", 1339 "ByteTensor", 1340 "BoolTensor", 1341 ): 1342 legacy_class_hints.append(f"class {c}(Tensor): ...") 1343 1344 # Generate type signatures for dtype classes 1345 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 1346 1347 # TODO: don't explicitly list dtypes here; get it from canonical 1348 # source 1349 dtype_class_hints = [ 1350 f"{n}: dtype = ..." 1351 for n in [ 1352 "float32", 1353 "float", 1354 "float64", 1355 "double", 1356 "float16", 1357 "bfloat16", 1358 "float8_e4m3fn", 1359 "float8_e4m3fnuz", 1360 "float8_e5m2", 1361 "float8_e5m2fnuz", 1362 "half", 1363 "uint8", 1364 "uint16", 1365 "uint32", 1366 "uint64", 1367 "int8", 1368 "int16", 1369 "short", 1370 "int32", 1371 "int", 1372 "int64", 1373 "long", 1374 "complex32", 1375 "complex64", 1376 "chalf", 1377 "cfloat", 1378 "complex128", 1379 "cdouble", 1380 "quint8", 1381 "qint8", 1382 "qint32", 1383 "bool", 1384 "quint4x2", 1385 "quint2x4", 1386 "bits1x8", 1387 "bits2x4", 1388 "bits4x2", 1389 "bits8", 1390 "bits16", 1391 ] 1392 ] 1393 1394 # Generate __all__ directive 1395 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 1396 1397 # Include only the functions that contain hints, to prevent undefined 1398 # symbols to be included in the `__all__` directive. 
1399 hinted_function_names = [ 1400 name for name, hint in unsorted_function_hints.items() if hint 1401 ] 1402 all_symbols = sorted(list(structseqs.keys()) + hinted_function_names) 1403 all_directive = pformat(all_symbols, width=100, compact=True).split("\n") 1404 all_directive[0] = f"__all__ = {all_directive[0]}" 1405 1406 # Dispatch key hints 1407 # ~~~~~~~~~~~~~~~~~~ 1408 dispatch_key_hints = [f"{d.name}: DispatchKey = ..." for d in DispatchKey] 1409 torch_dispatch_mode_key_hints = [ 1410 f"{k.name}: _TorchDispatchModeKey = ..." for k in _TorchDispatchModeKey 1411 ] 1412 1413 # Tags Enum type hints 1414 # ~~~~~~~~~~~~~~~~~~~~ 1415 1416 tag_names = sorted(parse_tags_yaml(tags_yaml_path)) 1417 tag_attributes = "\n".join( 1418 f"{name}: _int = {index}" for index, name in enumerate(tag_names) 1419 ) 1420 1421 # Write out the stub 1422 # ~~~~~~~~~~~~~~~~~~ 1423 1424 env = { 1425 "structseq_defs": structseq_defs, 1426 "function_hints": function_hints, 1427 "tensor_method_hints": tensor_method_hints, 1428 "legacy_class_hints": legacy_class_hints, 1429 "legacy_storage_base_hints": legacy_storage_base_hints, 1430 "dtype_class_hints": dtype_class_hints, 1431 "dispatch_key_hints": dispatch_key_hints, 1432 "torch_dispatch_mode_key_hints": torch_dispatch_mode_key_hints, 1433 "all_directive": all_directive, 1434 "tag_attributes": tag_attributes, 1435 } 1436 fm.write_with_template( 1437 "torch/_C/__init__.pyi", 1438 "torch/_C/__init__.pyi.in", 1439 lambda: env, 1440 ) 1441 fm.write_with_template( 1442 "torch/_C/_VariableFunctions.pyi", 1443 "torch/_C/_VariableFunctions.pyi.in", 1444 lambda: env, 1445 ) 1446 fm.write_with_template( 1447 "torch/_VF.pyi", 1448 "torch/_C/_VariableFunctions.pyi.in", 1449 lambda: env, 1450 ) 1451 fm.write_with_template( 1452 "torch/return_types.pyi", 1453 "torch/_C/return_types.pyi.in", 1454 lambda: env, 1455 ) 1456 gen_nn_functional(fm) 1457 1458 1459def main() -> None: 1460 parser = argparse.ArgumentParser(description="Generate type stubs for 
PyTorch") 1461 parser.add_argument( 1462 "--native-functions-path", 1463 metavar="NATIVE", 1464 default="aten/src/ATen/native/native_functions.yaml", 1465 help="path to native_functions.yaml", 1466 ) 1467 parser.add_argument( 1468 "--tags-path", 1469 metavar="TAGS", 1470 default="aten/src/ATen/native/tags.yaml", 1471 help="path to tags.yaml", 1472 ) 1473 parser.add_argument( 1474 "--deprecated-functions-path", 1475 metavar="DEPRECATED", 1476 default="tools/autograd/deprecated.yaml", 1477 help="path to deprecated.yaml", 1478 ) 1479 parser.add_argument( 1480 "--out", metavar="OUT", default=".", help="path to output directory" 1481 ) 1482 args = parser.parse_args() 1483 fm = FileManager(install_dir=args.out, template_dir=".", dry_run=False) 1484 gen_pyi( 1485 args.native_functions_path, args.tags_path, args.deprecated_functions_path, fm 1486 ) 1487 1488 1489if __name__ == "__main__": 1490 main() 1491