1# mypy: allow-untyped-defs 2import operator 3from enum import Enum 4from functools import partial, reduce 5from typing import Callable, List, Optional, Sequence, Tuple, Type, Union 6 7import torch 8import torch._prims_common as utils 9import torch.library 10from torch import sym_float, Tensor 11from torch._C import _get_default_device 12from torch._higher_order_ops.effects import new_token_tensor 13from torch._library.utils import is_functional_schema 14from torch._prims.debug_prims import register_debug_prims 15from torch._prims.rng_prims import register_rng_prims 16from torch._prims_common import ( 17 Dim, 18 DimsSequenceType, 19 DimsType, 20 IntLike, 21 Number, 22 NumberType, 23 RETURN_TYPE, 24 ShapeType, 25 StrideType, 26 TensorLike, 27 TensorLikeType, 28 type_to_dtype, 29) 30from torch._prims_common.wrappers import backwards_not_supported 31from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode 32from torch.overrides import handle_torch_function, has_torch_function 33from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten 34 35 36prim = torch.library.Library("prims", "DEF") 37prim_impl = torch.library.Library("prims", "IMPL", "CompositeExplicitAutograd") 38prim_backend_select_impl = torch.library.Library("prims", "IMPL", "BackendSelect") 39prim_autograd_impl = torch.library.Library("prims", "IMPL", "Autograd") 40prim_meta_impl = torch.library.Library("prims", "IMPL", "Meta") 41 42# Experimental module containing prototype "primitive" operations. 
__all__ = [
    #
    # Common datastructures and helpers
    #
    "RETURN_TYPE",
    #
    # Elementwise unary prims
    #
    "abs",
    "acos",
    "acosh",
    "asin",
    "asinh",
    "atan",
    "atanh",
    "cos",
    "cosh",
    "bessel_i0",
    "bessel_i0e",
    "bessel_i1",
    "bessel_i1e",
    "bessel_j0",
    "bessel_j1",
    "bitwise_not",
    "cbrt",
    "ceil",
    "conj_physical",
    "digamma",
    "erf",
    "erf_inv",
    "erfc",
    "erfcx",
    "exp",
    "expm1",
    "exp2",
    "fill",
    "floor",
    "imag",
    "isfinite",
    "lgamma",
    "log",
    "log1p",
    "log2",
    "log10",
    "ndtri",
    "neg",
    "real",
    "reciprocal",
    "round",
    "rsqrt",
    "sign",
    "signbit",
    "sin",
    "sinh",
    "spherical_bessel_j0",
    "sqrt",
    "tan",
    "tanh",
    "trunc",
    #
    # Elementwise binary prims
    #
    "add",
    "atan2",
    "bitwise_and",
    "bitwise_or",
    "bitwise_xor",
    # 'complex',  # needs custom meta
    "div",
    "eq",
    "fmax",
    "fmin",
    "fmod",
    "frexp",
    "gcd",
    "ge",
    "gt",
    "hypot",
    "igamma",
    "igammac",
    "le",
    "lt",
    "maximum",
    "minimum",
    "mul",
    "ne",
    "nextafter",
    "pow",
    "remainder",
    "shift_left",
    "shift_right_arithmetic",
    "shift_right_logical",  # not implemented
    "sub",
    "zeta",
    #
    # View prims
    #
    "as_strided",
    "broadcast_in_dim",
    "collapse_view",
    "conj",
    "expand_dims",
    "slice",
    "slice_in_dim",  # implemented using slice -- make this a ref?
    "split_dim",
    "squeeze",
    "transpose",
    "view_of",
    "view_element_type",
    #
    # Functionalized view mutations
    #
    "as_strided_scatter",
    #
    # Shape prims
    #
    "collapse",
    "cat",
    "reshape",
    "rev",
    #
    # Conditional prims
    #
    "where",
    #
    # Data conversion and movement prims
    #
    "clone",
    "convert_element_type",
    "device_put",
    "item",
    "maximum_value",
    "minimum_value",
    "copy_strided",
    #
    # Inplace prims
    #
    "copy_to",
    "resize",
    # "_set",  # Commented out, see note below
    #
    # Reduction prims
    #
    "amax",
    "amin",
    "prod",
    "sum",
    "xor_sum",
    "var",
    #
    # Tensor Creation Prims
    #
    "empty_strided",
    "empty_permuted",
    "scalar_tensor",
    "iota",
    #
    # Linear algebra (linalg) Prims
    #
    "svd",
    #
    # Randomness Prims
    #
    "normal",
    "_uniform_helper",
    #
    # FFT prims
    #
    "fft_r2c",
    "fft_c2c",
    "fft_c2r",
    #
    # prims for making/sinking tokens
    #
    "_make_token",
    "_sink_tokens",
]


def TensorMeta(
    tensorlike: Optional[Union[NumberType, torch.Tensor]] = None,
    *,
    shape: Optional[ShapeType] = None,
    strides: Optional[StrideType] = None,
    dtype: Optional[torch.dtype] = None,
    device: Optional[Union[torch.device, str]] = None,
):
    """
    Allocates an uninitialized tensor via ``torch.empty_strided``.

    The tensor's metadata is taken from an example ``tensorlike`` and/or from
    explicit keyword overrides.

    Args:
        tensorlike: example to copy metadata from. A Number yields a 0-dim CPU
            tensor whose dtype comes from the Number's Python type. A Tensor
            contributes its shape, strides, dtype and device. If None, every
            keyword argument below must be supplied.
        shape, strides, dtype, device: explicit overrides. Each one, when not
            None, replaces the value inferred from ``tensorlike``. ``device``
            may be given as a string and is converted to ``torch.device``.

    Returns:
        The result of ``torch.empty_strided(shape, strides, dtype=..., device=...)``.
    """
    if isinstance(tensorlike, Number):
        # A Number example must not be combined with non-empty shape/strides
        assert not shape and (shape is None or isinstance(shape, Sequence))
        assert not strides and (strides is None or isinstance(strides, Sequence))
        inferred_shape: Tuple[int, ...] = ()
        inferred_strides: Tuple[int, ...] = ()
        inferred_dtype = type_to_dtype(type(tensorlike))
        inferred_device = torch.device("cpu")
        # TODO: This looks wrong, a number that is wrapped into a tensor
        # needs to behave differently than a scalar tensor for type
        # promotion purposes
    elif tensorlike is not None:
        assert isinstance(tensorlike, torch.Tensor)
        inferred_shape = tuple(tensorlike.shape)
        inferred_strides = tuple(tensorlike.stride())
        inferred_dtype = tensorlike.dtype
        inferred_device = tensorlike.device
    else:
        # If no tensorlike "example" is given then all metadata
        # must be provided explicitly
        assert shape is not None
        assert strides is not None
        assert dtype is not None
        assert device is not None

    shape = inferred_shape if shape is None else tuple(shape)  # type: ignore[possibly-undefined]
    strides = inferred_strides if strides is None else tuple(strides)  # type: ignore[possibly-undefined]
    dtype = inferred_dtype if dtype is None else dtype  # type: ignore[possibly-undefined]
    device = inferred_device if device is None else device  # type: ignore[possibly-undefined]

    if isinstance(device, str):
        device = torch.device(device)

    return torch.empty_strided(shape, strides, dtype=dtype, device=device)


def _make_prim(
    *,
    schema: str,
    return_type: Union[RETURN_TYPE, Tuple[RETURN_TYPE, ...]],
    meta: Callable,
    impl_aten: Callable,
    doc: str,
    tags: Optional[Sequence[torch.Tag]] = None,
    use_old_custom_ops_api: bool = False,
    register_conj_neg_fallthrough: bool = False,
):
    """
    Creates a primitive operation.

    Registers ``schema`` in the ``prims`` namespace, installs its kernels and
    returns the ``torch.ops.prims.<name>.default`` overload.

    Args:
        schema: full schema string, e.g. ``"abs(Tensor self) -> Tensor"``. The
            text before the first ``(`` is used as the op name.
        return_type: stored on the op packet and overload as ``return_type``.
            VIEW prims registered through ``torch.library.custom_op`` also get
            Conjugate/Negative fallthrough kernels.
        meta: metadata function. It is registered as the Meta kernel (old
            custom-ops API) or as the fake impl (``custom_op`` API). It is also
            run before ``impl_aten`` on every eager call to validate the inputs.
        impl_aten: eager implementation built from ATen ops.
        doc: docstring attached to the op packet and overload.
        tags: explicit operator tags. If omitted and an ``aten`` op with the
            same name exists, the prim gets the tags shared by all of that op's
            overloads, minus ``core`` and ``data_dependent_output``.
        use_old_custom_ops_api: register through ``torch.library.Library``
            instead of ``torch.library.custom_op``. Non-functional schemas
            always use the old API.
        register_conj_neg_fallthrough: also register Conjugate/Negative
            fallthroughs for non-VIEW prims. Only takes effect on the
            ``custom_op`` path.
    """

    def _prim_impl(*args, **kwargs):
        # always run the meta function because aten implementation will
        # typically accept more inputs (e.g., it will do promotion and
        # broadcasting) which we want to reject
        meta(*args, **kwargs)
        return impl_aten(*args, **kwargs)

    # Right now prims don't support autograd (we can and should add an
    # argument that provides an implementation for backward here.) Because we
    # don't have derivative formulas, we must setup a custom autograd function
    # that raises an error if backwards is invoked
    def _autograd_impl(*args, **kwargs):
        return backwards_not_supported(_prim)(*args, **kwargs)

    def _backend_select_impl(*args, **kwargs):
        # A request for the meta device (as a kwarg or a positional device)
        # only needs metadata, so the meta function is returned directly.
        if kwargs.get("device") and kwargs["device"].type == "meta":
            return meta(*args, **kwargs)
        if any(isinstance(x, torch.device) and x.type == "meta" for x in args):
            return meta(*args, **kwargs)
        else:
            return _prim_impl(*args, **kwargs)

    name = schema.split("(")[0]
    schema = schema[len(name) :]

    # register non-functional ops with old custom ops API
    cpp_schema = torch._C.parse_schema(name + schema)
    if use_old_custom_ops_api or not is_functional_schema(cpp_schema):
        prim.define(name + schema, tags=torch.Tag.pt2_compliant_tag)
        prim_impl.impl(name, _prim_impl)
        prim_autograd_impl.impl(name, _autograd_impl)
        prim_meta_impl.impl(name, meta)
    else:
        # Arguments annotated as written-to (e.g. Tensor(a!)) are mutated args
        mutates_args = []
        for arg in cpp_schema.arguments:
            if arg.alias_info is not None and arg.alias_info.is_write:
                mutates_args.append(arg.name)
        prim_def = torch.library.custom_op(
            "prims::" + name,
            _prim_impl,
            mutates_args=tuple(mutates_args),
            schema=schema,
        )
        prim_def.register_fake(meta)

        # all view ops get conj/neg fallthroughs
        if return_type == RETURN_TYPE.VIEW or register_conj_neg_fallthrough:
            prim_def._lib.impl(name, torch.library.fallthrough_kernel, "Conjugate")
            prim_def._lib.impl(name, torch.library.fallthrough_kernel, "Negative")

    _prim_packet = getattr(torch._ops.ops.prims, name)
    _prim = _prim_packet.default
    if tags:
        _prim._tags = tags
    elif aten_packet := getattr(torch.ops.aten, name, None):
        overload_tags = [
            getattr(aten_packet, overload).tags for overload in aten_packet.overloads()
        ]
        tags_intersection = set(overload_tags[0])
        tags_intersection.intersection_update(*overload_tags[1:])

        # dont inadvertently add to prim ops
        tags_intersection.discard(torch.Tag.core)
        # causes errors with python ref executor tests, none of the
        # data dependent pytorch ops actually decompose to prims
        tags_intersection.discard(torch.Tag.data_dependent_output)

        # iter over first tags for determinism
        _prim._tags = tuple(t for t in overload_tags[0] if t in tags_intersection)

    from torch._subclasses.fake_tensor import contains_tensor_types

    # Ops without tensor arguments cannot infer a backend from their inputs,
    # so they get a BackendSelect kernel
    if not any(contains_tensor_types(a.type) for a in _prim._schema.arguments) or str(
        _prim
    ) in [
        # See https://github.com/pytorch/pytorch/issues/103532
        "prims.device_put.default"
    ]:
        prim_backend_select_impl.impl(name, _backend_select_impl)

    for p in (_prim_packet, _prim):
        p.__doc__ = doc
        p.return_type = return_type  # type: ignore[attr-defined]

        p.schema = schema
        p.prim_impl = _prim_impl
        p.prim_meta_impl = meta
        p.impl_aten = impl_aten

    return _prim


class ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND(Enum):
    """
    How ``_prim_elementwise_meta`` derives the output dtype from the input dtype.

    Members (applied only when a tensor argument is present):
        DEFAULT: the output keeps the input dtype.
        INT_TO_FLOAT: integer and bool dtypes become ``torch.get_default_dtype()``.
        ALWAYS_BOOL: the output is ``torch.bool``.
        COMPLEX_TO_FLOAT: complex dtypes become their corresponding real dtype.
    """

    DEFAULT = (0,)
    INT_TO_FLOAT = (2,)
    ALWAYS_BOOL = (3,)
    COMPLEX_TO_FLOAT = (4,)


# TODO: implement dtype validation here, too, or on the corresponding refs
def _prim_elementwise_meta(
    *args,
    type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND,
    args_with_fixed_dtypes: Optional[Tuple[TensorLikeType, ...]] = None,
) -> FakeTensor:
    """
    Meta function for elementwise prims.

    All ``args`` must share a dtype. ``args`` plus ``args_with_fixed_dtypes``
    must share a device and shape; CPU scalar tensors are exempt from both
    checks.

    If any tensor argument is present, returns an empty tensor laid out with
    the elementwise logical-to-physical permutation. Its dtype is derived from
    ``args`` according to ``type_promotion``. Otherwise (all-Number inputs)
    returns ``TensorMeta`` of the first Number. In that case ``type_promotion``
    is not applied (see the TODO below).

    NOTE: stride handling here was previously flagged as incorrect; verify
    before relying on exact output strides.
    """

    assert len(args) > 0

    utils.check_same_dtype(*args)

    args_ = list(args)
    if args_with_fixed_dtypes is not None:
        args_ = list(args_with_fixed_dtypes) + args_

    utils.check_same_device(*args_, allow_cpu_scalar_tensors=True)
    utils.check_same_shape(*args_, allow_cpu_scalar_tensors=True)

    l2p_perm = utils.compute_elementwise_output_logical_to_physical_perm(*args_)
    shape = utils.extract_shape(*args_, allow_cpu_scalar_tensors=True)

    # Acquires the dtype: the first non-CPU-scalar tensor wins; otherwise the
    # last CPU scalar tensor seen, otherwise the type of a Number argument
    dtype = None
    scalar_type = None
    for arg in args:
        if isinstance(arg, TensorLike):
            if not utils.is_cpu_scalar_tensor(arg):
                dtype = arg.dtype
                break
            else:
                dtype = arg.dtype
        elif isinstance(arg, Number):
            scalar_type = type(arg)

    if dtype is None and scalar_type is not None:
        dtype = utils.type_to_dtype(scalar_type)

    # Acquires the device (if it exists) or number
    device = None
    number = None
    for arg in args_:
        if isinstance(arg, TensorLike):
            if utils.is_cpu_scalar_tensor(arg):
                if device is None:
                    device = arg.device
                # keep going, in case there is a cuda tensor later
            else:
                device = arg.device
                break

        elif isinstance(arg, Number):
            if number is None:
                number = arg

    # NOTE: type promotion behavior here is mostly hidden from tests because
    # references will typically handle the type promotion properly even if this doesn't
    # (but getting it wrong will cause too many casts to be inserted in traces!)
    if device is not None:
        assert dtype is not None
        if type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT:
            dtype = dtype
        elif type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL:
            dtype = torch.bool
        elif type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.INT_TO_FLOAT:
            if utils.is_integer_dtype(dtype) or utils.is_boolean_dtype(dtype):
                dtype = torch.get_default_dtype()
        elif type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT:
            if utils.is_complex_dtype(dtype):
                dtype = utils.corresponding_real_dtype(dtype)
            else:
                dtype = dtype

        assert shape is not None
        return torch.empty_permuted(shape, l2p_perm, device=device, dtype=dtype)  # type: ignore[return-value]

    # Number case
    # TODO: fix number type promotion (bool, complex->float)

    # For now for symint/float, just implementing the common / simple cases of (int,float,symint,symfloat)
    seen_float = False
    if isinstance(number, (torch.SymInt, torch.SymFloat)):
        for a in args:
            assert isinstance(a, (int, float, torch.SymInt, torch.SymFloat)), "NYI"
            seen_float = seen_float or isinstance(a, (float, torch.SymFloat))
        if seen_float:
            number = sym_float(number)

    return TensorMeta(number)  # type: ignore[arg-type]


def _complex_only_elementwise_meta(*args, **kwargs):
    """Like ``_prim_elementwise_meta`` but rejects a non-complex first argument."""
    torch._check(
        utils.is_complex_dtype(args[0].dtype), lambda: "Only complex dtype is supported"
    )
    return _prim_elementwise_meta(*args, **kwargs)


def _make_elementwise_unary_prim(
    name: str, *, type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND, **kwargs
):
    """
    Creates an elementwise unary prim with schema ``name(Tensor self) -> Tensor``.

    Remaining ``kwargs`` (e.g. ``impl_aten``, ``doc``) are forwarded to ``_make_prim``.
    """

    return _make_prim(
        schema=f"{name}(Tensor self) -> Tensor",
        meta=partial(_prim_elementwise_meta, type_promotion=type_promotion),
        return_type=RETURN_TYPE.NEW,
        **kwargs,
    )


def _make_elementwise_binary_prim(
    name: str, *, type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND, **kwargs
):
    """
    Creates an elementwise binary prim.

    The prim has schema ``name(Tensor self, Tensor other) -> Tensor``.
    Remaining ``kwargs`` (e.g. ``impl_aten``, ``doc``) are forwarded to
    ``_make_prim``.
    """

    return _make_prim(
        schema=f"{name}(Tensor self, Tensor other) -> Tensor",
        meta=partial(_prim_elementwise_meta, type_promotion=type_promotion),
        return_type=RETURN_TYPE.NEW,
        **kwargs,
    )


def _not_impl(*args, **kwargs):
    """Placeholder for prims that are declared but not implemented."""
    raise NotImplementedError


#
# Elementwise unary operations
#


abs = _make_elementwise_unary_prim(
    "abs",
    impl_aten=torch.abs,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
)

acos = _make_elementwise_unary_prim(
    "acos",
    impl_aten=torch.acos,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

acosh = _make_elementwise_unary_prim(
    "acosh",
    impl_aten=torch.acosh,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

asin = _make_elementwise_unary_prim(
    "asin",
    impl_aten=torch.asin,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

asinh = _make_elementwise_unary_prim(
    "asinh",
    impl_aten=torch.asinh,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

atan = _make_elementwise_unary_prim(
    "atan",
    impl_aten=torch.atan,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

atanh = _make_elementwise_unary_prim(
    "atanh",
    impl_aten=torch.atanh,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

cos = _make_elementwise_unary_prim(
    "cos",
    impl_aten=torch.cos,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

cosh = _make_elementwise_unary_prim(
    "cosh",
    impl_aten=torch.cosh,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bessel_j0 = _make_elementwise_unary_prim(
    "bessel_j0",
    impl_aten=torch.special.bessel_j0,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bessel_j1 = _make_elementwise_unary_prim(
    "bessel_j1",
    impl_aten=torch.special.bessel_j1,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bessel_i0 = _make_elementwise_unary_prim(
    "bessel_i0",
    impl_aten=torch.i0,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bessel_i0e = _make_elementwise_unary_prim(
    "bessel_i0e",
    impl_aten=torch.special.i0e,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bessel_i1 = _make_elementwise_unary_prim(
    "bessel_i1",
    impl_aten=torch.special.i1,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bessel_i1e = _make_elementwise_unary_prim(
    "bessel_i1e",
    impl_aten=torch.special.i1e,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bitwise_not = _make_elementwise_unary_prim(
    "bitwise_not",
    impl_aten=torch.bitwise_not,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)


def _cbrt_aten(a: torch.Tensor) -> Tensor:
    """Real cube root of a non-complex tensor."""
    torch._check(
        not a.is_complex(),
        lambda: "cbrt: Complex inputs not supported. Consider calling torch.pow(a, 1.0/3.0)",
    )
    # Returns the real cubic root of the number.
    # Note that if a < 0, pow(a, (1. / 3.)) returns the complex number
    # exp(1/3 * log(a)) = exp(1/3 * (log(abs(a)) + pi*i)) = cbrt(abs(a)) * e^{pi/3*i}
    # which is a complex number.
    # For more info see the section Note in
    # https://en.cppreference.com/w/cpp/numeric/math/cbrt
    return torch.copysign(torch.pow(a.abs(), 1 / 3), a)


cbrt = _make_elementwise_unary_prim(
    "cbrt",
    impl_aten=_cbrt_aten,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

ceil = _make_elementwise_unary_prim(
    "ceil",
    impl_aten=torch.ceil,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)


def _conj_physical_meta(input: TensorLikeType) -> TensorLikeType:
    """Meta for conj_physical: complex input only; elementwise output strides."""
    if not input.dtype.is_complex:
        raise RuntimeError("prims.conj_physical is only defined for complex dtypes")

    strides = utils.compute_elementwise_output_strides(input)
    return TensorMeta(input, strides=strides)


conj_physical = _make_prim(
    schema="conj_physical(Tensor self) -> Tensor",
    meta=_conj_physical_meta,
    impl_aten=torch._conj_physical,
    doc="Returns the physical conjugation of a complex tensor",
    return_type=RETURN_TYPE.NEW,
)


def _clone_meta(
    input: TensorLikeType,
    *,
    memory_format: Optional[torch.memory_format] = torch.preserve_format,
) -> TensorLikeType:
    """
    Meta for clone.

    With ``preserve_format`` the elementwise output strides of ``input`` are
    kept. Any other memory format is forwarded to ``torch.empty``.
    ``memory_format=None`` (the prim schema's default) is treated as
    ``preserve_format``, matching ``torch.clone``.
    """
    if memory_format is None:
        memory_format = torch.preserve_format

    if memory_format != torch.preserve_format:
        return torch.empty(
            input.shape,
            dtype=input.dtype,
            layout=input.layout,
            device=input.device,
            memory_format=memory_format,
        )

    # memory_format == torch.preserve_format
    strides = utils.compute_elementwise_output_strides(input)
    return torch.empty_strided(
        input.shape,
        strides,
        dtype=input.dtype,
        layout=input.layout,
        device=input.device,
    )


clone = _make_prim(
    schema="clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor",
    meta=_clone_meta,
    impl_aten=torch.clone,
    doc="Returns the copy of a tensor",
    return_type=RETURN_TYPE.NEW,
    register_conj_neg_fallthrough=True,
)

digamma = _make_elementwise_unary_prim(
    "digamma",
    impl_aten=torch.digamma,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

erf = _make_elementwise_unary_prim(
    "erf",
    impl_aten=torch.erf,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

erf_inv = _make_elementwise_unary_prim(
    "erf_inv",
    impl_aten=torch.special.erfinv,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

erfc = _make_elementwise_unary_prim(
    "erfc",
    impl_aten=torch.special.erfc,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

erfcx = _make_elementwise_unary_prim(
    "erfcx",
    impl_aten=torch.special.erfcx,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

exp = _make_elementwise_unary_prim(
    "exp",
    impl_aten=torch.exp,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

expm1 = _make_elementwise_unary_prim(
    "expm1",
    impl_aten=torch.special.expm1,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

exp2 = _make_elementwise_unary_prim(
    "exp2",
    impl_aten=torch.special.exp2,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)


def _fill_meta(a: TensorLikeType, value: NumberType) -> TensorLikeType:
    """Meta for fill: same metadata as ``a``; ``value`` does not affect it."""
    return _prim_elementwise_meta(
        a, type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT
    )


# NOTE: fill uses _make_prim directly because it has a value parameter
fill = _make_prim(
    schema="fill(Tensor self, Scalar value) -> Tensor",
    return_type=RETURN_TYPE.NEW,
    meta=_fill_meta,
    impl_aten=torch.fill,
    doc="",
)

floor = _make_elementwise_unary_prim(
    "floor",
    impl_aten=torch.floor,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

imag = _make_prim(
    schema="imag(Tensor(a) self) -> Tensor(a)",
    meta=partial(
        _complex_only_elementwise_meta,
        type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
    ),
    return_type=RETURN_TYPE.VIEW,
    impl_aten=torch.imag,
    doc="",
)

isfinite = _make_elementwise_unary_prim(
    "isfinite",
    impl_aten=torch.isfinite,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)

lgamma = _make_elementwise_unary_prim(
    "lgamma",
    impl_aten=torch.lgamma,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

log = _make_elementwise_unary_prim(
    "log",
    impl_aten=torch.log,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

log1p = _make_elementwise_unary_prim(
    "log1p",
    impl_aten=torch.log1p,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

log2 = _make_elementwise_unary_prim(
    "log2",
    impl_aten=torch.log2,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

log10 = _make_elementwise_unary_prim(
    "log10",
    impl_aten=torch.log10,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

real = _make_prim(
    schema="real(Tensor(a) self) -> Tensor(a)",
    meta=partial(
        _complex_only_elementwise_meta,
        type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
    ),
    return_type=RETURN_TYPE.VIEW,
    impl_aten=torch.real,
    doc="",
)

reciprocal = _make_elementwise_unary_prim(
    "reciprocal",
    impl_aten=torch.reciprocal,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

ndtri = _make_elementwise_unary_prim(
    "ndtri",
    impl_aten=torch.special.ndtri,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

neg = _make_elementwise_unary_prim(
    "neg",
    impl_aten=torch.neg,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

round = _make_elementwise_unary_prim(
    "round",
    impl_aten=torch.round,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

rsqrt = _make_elementwise_unary_prim(
    "rsqrt",
    impl_aten=torch.rsqrt,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

sign = _make_elementwise_unary_prim(
    "sign",
    impl_aten=torch.sign,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

# torch.signbit always returns a bool tensor, so the meta must report bool
# (like isfinite) to agree with impl_aten.
signbit = _make_elementwise_unary_prim(
    "signbit",
    impl_aten=torch.signbit,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)

sin = _make_elementwise_unary_prim(
    "sin",
    impl_aten=torch.sin,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

sinh = _make_elementwise_unary_prim(
    "sinh",
    impl_aten=torch.sinh,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

spherical_bessel_j0 = _make_elementwise_unary_prim(
    "spherical_bessel_j0",
    impl_aten=torch.special.spherical_bessel_j0,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

sqrt = _make_elementwise_unary_prim(
    "sqrt",
    impl_aten=torch.sqrt,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

tan = _make_elementwise_unary_prim(
    "tan",
    impl_aten=torch.tan,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

tanh = _make_elementwise_unary_prim(
    "tanh",
    impl_aten=torch.tanh,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

trunc = _make_elementwise_unary_prim(
    name="trunc",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
    impl_aten=torch.trunc,
    doc="",
)

#
# Elementwise binary operations
#

add = _make_elementwise_binary_prim(
    name="add",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
    impl_aten=torch.add,
    doc="",
)

atan2 = _make_elementwise_binary_prim(
    name="atan2",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
    impl_aten=torch.atan2,
    doc="",
)

bitwise_and = _make_elementwise_binary_prim(
    name="bitwise_and",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
    impl_aten=torch.bitwise_and,
    doc="",
)

bitwise_or = _make_elementwise_binary_prim(
    name="bitwise_or",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
    impl_aten=torch.bitwise_or,
    doc="",
)

bitwise_xor = _make_elementwise_binary_prim(
    name="bitwise_xor",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
    impl_aten=torch.bitwise_xor,
    doc="",
)

# TODO: complex needs a special meta to account for its float -> complex behavior
# complex = _make_elementwise_binary_prim(
#   impl_aten=torch.complex,
#   doc="",
# )


def _div_aten(a, b):
    """
    Eager implementation of the div prim.

    Integral dividends (Python bool/int/SymInt, or a tensor with an integer
    dtype) use truncation division; everything else uses true division.
    """
    if isinstance(a, torch.Tensor):
        dividend_is_integral = utils.is_integer_dtype(a.dtype)
    else:
        dividend_is_integral = isinstance(a, (bool, int, torch.SymInt))

    if not dividend_is_integral:
        return torch.true_divide(a, b)
    return torch.div(a, b, rounding_mode="trunc")


div = _make_elementwise_binary_prim(
    name="div",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
    impl_aten=_div_aten,
    doc="",
)

eq = _make_elementwise_binary_prim(
    name="eq",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    impl_aten=torch.eq,
    doc="",
)

fmax = _make_elementwise_binary_prim(
    name="fmax",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
    impl_aten=torch.fmax,
    doc="",
)

fmin = _make_elementwise_binary_prim(
    name="fmin",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
    impl_aten=torch.fmin,
    doc="",
)

fmod = _make_elementwise_binary_prim(
    name="fmod",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
    impl_aten=torch.fmod,
    doc="",
)


gcd = _make_elementwise_binary_prim(
    name="gcd",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
    impl_aten=torch.gcd,
    doc="",
)


ge = _make_elementwise_binary_prim(
    name="ge",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    impl_aten=torch.ge,
    doc="",
)

gt = _make_elementwise_binary_prim(
    name="gt",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    impl_aten=torch.gt,
    doc="",
)

hypot = _make_elementwise_binary_prim(
    name="hypot",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
    impl_aten=torch.hypot,
    doc="",
)

igamma = _make_elementwise_binary_prim(
    name="igamma",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
    impl_aten=torch.special.gammainc,
    doc="",
)

igammac = _make_elementwise_binary_prim(
    name="igammac",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
    impl_aten=torch.special.gammaincc,
    doc="",
)

le = _make_elementwise_binary_prim(
    name="le",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    impl_aten=torch.le,
    doc="",
)

lt = _make_elementwise_binary_prim(
    name="lt",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    impl_aten=torch.lt,
    doc="",
)


# Note: the following impls exist because torch.maximum and torch.minimum do not
# support scalar inputs.
def _wrap_scalar_operand(
    a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
) -> Tuple[Union[TensorLikeType, NumberType], Union[TensorLikeType, NumberType]]:
    """
    Wraps a Number operand into a scalar tensor when the other operand is a tensor.

    The new tensor takes the dtype and device of the tensor operand. The operands
    are returned unchanged if neither (or both) of them is a Number.
    """
    if isinstance(a, TensorLike) and isinstance(b, Number):
        b = scalar_tensor(b, dtype=a.dtype, device=a.device)
    elif isinstance(b, TensorLike) and isinstance(a, Number):
        a = scalar_tensor(a, dtype=b.dtype, device=b.device)
    return a, b


def _maximum_aten(
    a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
) -> TensorLikeType:
    """``torch.maximum`` that also accepts one Number operand."""
    a, b = _wrap_scalar_operand(a, b)
    return torch.maximum(a, b)  # type: ignore[arg-type]


maximum = _make_elementwise_binary_prim(
    "maximum",
    impl_aten=_maximum_aten,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)


def _minimum_aten(
    a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
) -> TensorLikeType:
    """``torch.minimum`` that also accepts one Number operand."""
    a, b = _wrap_scalar_operand(a, b)
    return torch.minimum(a, b)  # type: ignore[arg-type]


minimum = _make_elementwise_binary_prim(
    "minimum",
    impl_aten=_minimum_aten,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

mul = _make_elementwise_binary_prim(
    "mul",
    impl_aten=torch.mul,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

ne = _make_elementwise_binary_prim(
    "ne",
    impl_aten=torch.ne,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)

nextafter = _make_elementwise_binary_prim(
    "nextafter",
    impl_aten=torch.nextafter,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

pow = _make_elementwise_binary_prim(
    "pow",
    impl_aten=torch.pow,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

remainder = _make_elementwise_binary_prim(
    "remainder",
    impl_aten=torch.remainder,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)


shift_left = _make_elementwise_binary_prim(
    "shift_left",
    impl_aten=torch.bitwise_left_shift,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

shift_right_arithmetic = _make_elementwise_binary_prim(
    "shift_right_arithmetic",
    impl_aten=torch.bitwise_right_shift,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

# Declared in __all__ but not implemented; calling it raises NotImplementedError.
shift_right_logical = _not_impl

sub = _make_elementwise_binary_prim(
    "sub",
    impl_aten=torch.sub,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

zeta = _make_elementwise_binary_prim(
    "zeta",
    impl_aten=torch.special.zeta,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)


#
# View operations
#


def _as_strided_meta(
    a: TensorLikeType, size: ShapeType, stride: StrideType, storage_offset: int
) -> TensorLikeType:
    """
    Meta for as_strided.

    Validates ``size``/``stride``/``storage_offset`` and, for non-empty
    results, checks that the view fits inside ``a``'s storage.

    Returns ``torch.as_strided(a, size, stride, storage_offset)``.
    """
    assert len(size) == len(stride)
    assert storage_offset >= 0
    utils.validate_strides(stride)
    utils.validate_shape(size)

    # The initializer 1 makes a 0-dim target (size == []) count as one element,
    # which is what it holds; without it reduce() raises TypeError on an empty
    # size.
    if reduce(operator.mul, size, 1) == 0:
        # NOTE: This special case is to avoid having to acquire the storage below
        # as_strided to shapes with no elements are trivially valid, so it's OK
        pass
    elif isinstance(a, torch.Tensor):
        utils.check_in_bounds_for_storage(
            a._typed_storage(), size, stride, storage_offset
        )

    return torch.as_strided(a, size, stride, storage_offset)


def _as_strided_aten(
    a: Tensor, size: ShapeType, stride: StrideType, storage_offset: int
) -> Tensor:
    """Eager implementation of the as_strided prim (delegates to torch.as_strided)."""
    return torch.as_strided(a, size, stride, storage_offset)



_as_strided_doc = """
    Creates a view of the tensor with the given shape (size), strides (stride) and
    storage offset (storage_offset).
"""

as_strided = _make_prim(
    schema="as_strided(Tensor(a!) a, SymInt[] size, SymInt[] stride, SymInt storage_offset) -> Tensor(a!)",
    meta=_as_strided_meta,
    impl_aten=_as_strided_aten,
    return_type=RETURN_TYPE.VIEW,
    doc=_as_strided_doc,
)


def _broadcast_in_dim_meta(
    a: TensorLikeType, shape: ShapeType, broadcast_dimensions: Sequence[int]
):
    """Meta function for prims.broadcast_in_dim.

    Dimension ``i`` of ``a`` maps to dimension ``broadcast_dimensions[i]`` of
    ``shape``. Broadcast or newly added dimensions of length != 1 get stride 0.
    """
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    # Type checks
    assert isinstance(a, TensorLike)
    assert isinstance(shape, Sequence)
    assert isinstance(broadcast_dimensions, Sequence)

    # every dimension must be accounted for
    assert a.ndim == len(broadcast_dimensions)

    # broadcast shape must have weakly more dimensions
    assert len(shape) >= a.ndim

    # broadcast_dimensions must be an ascending sequence
    # (no relative reordering of dims) of integers and
    # each dimension must be within the new shape
    def _greater_than_reduce(acc, x):
        assert isinstance(x, Dim)
        assert x > acc
        assert x < len(shape)

        return x

    reduce(_greater_than_reduce, broadcast_dimensions, -1)

    # shape must be broadcastable to
    for idx, new_idx in enumerate(broadcast_dimensions):
        if not guard_size_oblivious(a.shape[idx] == 1):
            torch._check(
                a.shape[idx] == shape[new_idx],
                lambda: f"{a.shape[idx]} must be broadcastable to {shape[new_idx]}",
            )

    new_strides = []
    original_idx = 0
    for idx in range(len(shape)):
        if idx in broadcast_dimensions:
            # Assigns a stride of zero to dimensions
            # which were actually broadcast
            if guard_size_oblivious(a.shape[original_idx] != shape[idx]):
                new_strides.append(0)
            else:
                new_strides.append(a.stride()[original_idx])
            original_idx = original_idx + 1
        else:
            if guard_size_oblivious(shape[idx] != 1):
                new_strides.append(0)
            elif original_idx == a.ndim:
                new_strides.append(1)
            else:
                new_strides.append(a.stride()[original_idx] * a.size()[original_idx])

    return a.as_strided(shape, new_strides, a.storage_offset())


def _broadcast_in_dim_aten(a, shape, broadcast_dimensions):
    """ATen impl of prims.broadcast_in_dim: unsqueeze the new dims, then expand."""
    s = list(shape)
    for broadcast_dimension in broadcast_dimensions:
        s[broadcast_dimension] = -1

    v = a
    for idx, x in enumerate(s):
        if x != -1:
            v = v.unsqueeze(idx)

    return v.expand(shape)


_broadcast_in_dim_doc = """
    Creates a view of a with the specified shape.

    Allows adding dimensions of any length and broadcasting
    dimensions of length one in a to any length.

    The location of the broadcast dimensions must be specified
    using the broadcast_dimensions argument. Changing the
    relative order of dimensions is not supported.
    """

broadcast_in_dim = _make_prim(
    schema="broadcast_in_dim(Tensor(a) a, SymInt[] shape, int[] broadcast_dimensions) -> Tensor(a)",
    meta=_broadcast_in_dim_meta,
    impl_aten=_broadcast_in_dim_aten,
    return_type=RETURN_TYPE.VIEW,
    doc=_broadcast_in_dim_doc,
)


def _validate_collapse_args(a: Tensor, start: int, end: int) -> None:
    """Validate the inclusive dimension interval [start, end] for a collapse.

    Raises:
        ValueError: if ``end < start`` (via torch._check_value).
    """
    # Special-case for zero dimensional tensors
    ndim = max(1, a.dim())
    utils.validate_idx(ndim, start)
    utils.validate_idx(ndim, end)

    # Verifies end is not less than start. Both ends are inclusive, so
    # end == start is a valid (single-dimension) interval.
    torch._check_value(
        end >= start,
        lambda: f"Attempting to collapse but end, {end}, is less than start, {start}!",
    )


def _collapsed_shape(shape: ShapeType, start: int, end: int) -> Tuple[int, ...]:
    """
    Returns the shape of a with dims in [start, end] (both inclusive) merged into a single dimension.
    """
    # Special-case for zero dimensional tensors
    shape = (1,) if len(shape) == 0 else tuple(shape)

    dim_length = 1
    for s in shape[start : end + 1]:
        dim_length = dim_length * s

    return shape[0:start] + (dim_length,) + shape[end + 1 :]


def _collapse_view_helper(
    a: TensorLikeType, start: int, end: int
) -> Tuple[Optional[ShapeType], Optional[StrideType]]:
    """Compute the shape and strides of collapsing dims [start, end] of ``a`` as a view.

    Returns:
        ``(new_shape, new_strides)``, or ``(None, None)`` when the dims are not
        nested (``strides[i] == strides[i + 1] * shape[i + 1]``), in which case
        no such view exists.
    """
    assert isinstance(a, TensorLike)

    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    _validate_collapse_args(a, start, end)

    # Special-case for zero dimensional tensors
    if a.ndim == 0:
        shape = (1,)
        strides = (1,)
    else:
        shape = a.shape  # type: ignore[assignment]
        strides = a.stride()  # type: ignore[assignment]

    if a.ndim == 0 or (end == start):
        return shape, strides

    length = shape[end]
    stride = strides[end]
    for idx in range(end - 1, start - 1, -1):
        if guard_size_oblivious(shape[idx] == 0) or guard_size_oblivious(
            shape[idx + 1] == 0
        ):
            length = 0
            stride = 0
            break

        if guard_size_oblivious(shape[idx] == 1):
            continue

        length = length * shape[idx]
        if guard_size_oblivious(stride < strides[idx]):
            stride = stride
        else:
            stride = strides[idx]

        if (
            guard_size_oblivious(a.numel() > 0)
            and guard_size_oblivious(shape[idx + 1] != 1)
            and not guard_size_oblivious(
                strides[idx] == strides[idx + 1] * shape[idx + 1]
            )
        ):
            return None, None

    new_shape = shape[:start] + (length,) + shape[end + 1 :]
    new_strides = strides[:start] + (stride,) + strides[end + 1 :]

    # NOTE: when the input has no elements it's restrided as if it were contiguous
    if guard_size_oblivious(a.numel() == 0):
        new_strides = utils.make_contiguous_strides_for(new_shape)

    return new_shape, new_strides


def _collapse_view_meta(a: TensorLikeType, start: int, end: int) -> TensorLikeType:
    """Meta function for prims.collapse_view.

    Raises:
        ValueError: if the requested dimensions cannot be collapsed as a view.
    """
    new_shape, new_strides = _collapse_view_helper(a, start, end)

    if new_shape is None:
        msg = "Attempting to view a collapsed tensor, but no such view exists!"
        raise ValueError(msg)

    assert new_strides is not None
    return a.as_strided(new_shape, new_strides, a.storage_offset())


def _collapse_view_aten(a: Tensor, start: int, end: int) -> Tensor:
    """ATen impl of prims.collapse_view via Tensor.view."""
    new_shape = _collapsed_shape(a.shape, start, end)
    return a.view(new_shape)


_collapse_view_doc = """
    Creates a view of a with the dimensions between
    start (inclusive) and end (inclusive) merged into a
    single dimension.

    If it's not possible to take such a view then an error
    is thrown. See collapse instead.

    The dimensions can be merged if and only if
    they are all "nested" with each other. That is, they all
    have the property that

    stride[i] = stride[i+1] * shape[i+1]

    for all i in [start, end).
    """

collapse_view = _make_prim(
    schema="collapse_view(Tensor(a) a, int start, int end) -> Tensor(a)",
    meta=_collapse_view_meta,
    impl_aten=_collapse_view_aten,
    return_type=RETURN_TYPE.VIEW,
    doc=_collapse_view_doc,
)


def _conj_meta(a: TensorLikeType) -> TensorLikeType:
    """Meta for prims.conj: an alias view with the conj bit flipped (complex only)."""
    if not a.dtype.is_complex:
        raise RuntimeError("Expected complex dtype in prims.conj")
    out = a.as_strided(a.shape, a.stride(), a.storage_offset())
    torch._C._set_conj(out, not a.is_conj())
    return out


_conj_doc = """
Returns a conjugated view of the original tensor
"""

conj = _make_prim(
    schema="conj(Tensor(a) a) -> Tensor(a)",
    meta=_conj_meta,
    impl_aten=torch.conj,
    return_type=RETURN_TYPE.VIEW,
    doc=_conj_doc,
)


def expand_dims(
    a: TensorLikeType, dimensions: DimsSequenceType, ndim=None
) -> TensorLikeType:
    """
    Creates a view of a with a.ndim + len(dimensions) dimensions, with new
    dimensions of length one at the dimensions specified by dimensions.

    Raises:
        ValueError: if ``dimensions`` contains duplicates after canonicalization.
    """
    if ndim is not None:
        # TODO: this is only here to support the unsqueeze ref
        dims = sorted(utils.canonicalize_dims(ndim, dimensions))  # type: ignore[arg-type]
    else:
        dims = sorted(utils.canonicalize_dims(a.ndim, dimensions))  # type: ignore[arg-type]
    if len(set(dims)) != len(dims):
        msg = f"Received duplicate dimensions to expand in {str(dimensions)}"
        raise ValueError(msg)

    new_shape = list(a.shape)
    for idx in dims:
        new_shape.insert(idx, 1)

    # Use the canonicalized dims: the raw `dimensions` may contain negative
    # indices that never match a position in range(len(new_shape)).
    new_dims = set(dims)
    broadcast_dimensions = [
        idx for idx in range(len(new_shape)) if idx not in new_dims
    ]
    return broadcast_in_dim(a, new_shape, broadcast_dimensions)


# Note: saves the Python slice object because we're about to clobber its name with the slice prim
pyslice: Type[slice] = slice  # type: ignore[has-type]


def _slice_meta(
    a: TensorLikeType,
    start_indices: DimsSequenceType,
    limit_indices: DimsSequenceType,
    strides: Optional[StrideType] = None,
) -> TensorLikeType:
    """Meta function for prims.slice.

    Raises:
        ValueError: if any index list has the wrong length, an index is negative
            or out of range, a stop index is less than its start index, or a
            step is non-positive.
    """
    _strides = strides if strides is not None else [1] * len(start_indices)

    if a.ndim != len(start_indices):
        msg = f"Attempting to slice tensor of rank {a.ndim} with start_indices of length {len(start_indices)}!"
        raise ValueError(msg)

    if a.ndim != len(limit_indices):
        msg = f"Attempting to slice tensor of rank {a.ndim} with limit_indices of length {len(limit_indices)}!"
        raise ValueError(msg)

    if a.ndim != len(_strides):
        msg = f"Attempting to slice tensor of rank {a.ndim} with strides of length {len(_strides)}!"
        raise ValueError(msg)

    for x, y in zip(start_indices, a.shape):
        if x < 0:
            msg = f"Attempting to slice a tensor with a negative start index of {x}!"
            raise ValueError(msg)
        if x > y:
            msg = (
                f"Attempting to slice a tensor but a start index in {start_indices} is greater than"
                f" the length of its corresponding dimension in shape {a.shape}"
            )
            raise ValueError(msg)

    for x, y, z in zip(limit_indices, a.shape, start_indices):
        if x < 0:
            msg = f"Attempting to slice a tensor with a negative stop index of {x}!"
            raise ValueError(msg)
        if x > y:
            msg = (
                f"Attempting to slice a tensor but a stop index in {limit_indices} is greater than the length of "
                f" its corresponding dimension in shape {a.shape}"
            )
            raise ValueError(msg)
        # Previously this message was built but never raised, letting
        # stop < start through (documented as unsupported below).
        if x < z:
            msg = (
                f"Attempting to slice a tensor but a start index {z} is greater than "
                f"its corresponding stop index {x}"
            )
            raise ValueError(msg)

    for x in _strides:
        if x <= 0:
            msg = f"Attempting to slice a tensor with a non-positive step of {x}!"
            raise ValueError(msg)

    new_shape = []
    for x, y, z in zip(start_indices, limit_indices, _strides):
        new_shape.append(1 + (y - x - 1) // z)

    new_strides = []
    for x, y in zip(a.stride(), _strides):
        new_strides.append(x * y)

    return a.as_strided(new_shape, new_strides, a.storage_offset())


def _slice_aten(
    a: Tensor,
    start_indices: DimsSequenceType,
    limit_indices: DimsSequenceType,
    strides: Optional[StrideType] = None,
) -> Tensor:
    """ATen impl of prims.slice via Python slice indexing."""
    _strides = strides if strides is not None else [1] * len(start_indices)

    slices = []
    for start, stop, step in zip(start_indices, limit_indices, _strides):
        slices.append(pyslice(start, stop, step))

    return operator.getitem(a, slices)  # type: ignore[call-overload]


_slice_doc = """
    Creates a view of a "bounding box" within the tensor.

    The bounding box is specified independently in each of the tensor's dimensions.
    start_indices and limit_indices describe the box's boundaries for their corresponding
    dimensions. If strides is specified then they specify the step size between elements
    in their corresponding dimension.

    This operation is analogous to slicing in NumPy, but does not permit slices where
    the stop indices are less than the start indices.
    """

slice = _make_prim(
    schema="slice(Tensor(a) a, SymInt[] start_indices, SymInt[] limit_indices, SymInt[]? strides=None) -> Tensor(a)",
    meta=_slice_meta,
    impl_aten=_slice_aten,
    return_type=RETURN_TYPE.VIEW,
    doc=_slice_doc,
)


def _slice_in_dim_meta(
    a: TensorLikeType,
    start_index: int,
    limit_index: int,
    stride: int = 1,
    axis: int = 0,
) -> TensorLikeType:
    """Meta function for prims.slice_in_dim: validates, then defers to _slice_meta.

    Raises:
        ValueError: on an out-of-range axis or index, or a non-positive stride.
    """
    if axis < 0:
        msg = f"slice_in_dim: received a negative axis {axis}"
        raise ValueError(msg)
    if axis >= a.ndim:
        msg = f"slice_in_dim: axis {axis} is greater or equal to the rank {a.ndim} of the tensor"
        raise ValueError(msg)

    if start_index < 0:
        msg = f"slice_in_dim: received a negative start_index {start_index}"
        raise ValueError(msg)

    if start_index > a.shape[axis]:
        msg = f"slice_in_dim: start_index {start_index} is greater than the length {a.shape[axis]} of dimension {axis}"
        raise ValueError(msg)

    if limit_index > a.shape[axis]:
        msg = f"slice_in_dim: limit_index {limit_index} is greater than the length {a.shape[axis]} of dimension {axis}"
        raise ValueError(msg)

    if limit_index < start_index:
        msg = f"slice_in_dim: received a limit_index {limit_index} less than the start_index {start_index}"
        raise ValueError(msg)

    # A zero stride is also invalid (matches _slice_meta's `x <= 0` check).
    if stride <= 0:
        msg = f"slice_in_dim: received a non-positive stride of {stride}!"
        raise ValueError(msg)

    start_indices = [0] * a.ndim
    limit_indices = list(a.shape)
    strides = [1] * a.ndim

    start_indices[axis] = start_index
    limit_indices[axis] = limit_index
    strides[axis] = stride

    return _slice_meta(a, start_indices, limit_indices, strides)


def _slice_in_dim_aten(
    a: Tensor,
    start_index: int,
    limit_index: int,
    stride: int = 1,
    axis: int = 0,
) -> Tensor:
    """ATen impl of prims.slice_in_dim, expressed via the slice prim."""
    start_indices = [0] * a.ndim
    limit_indices = list(a.shape)
    strides = [1] * a.ndim

    start_indices[axis] = start_index
    limit_indices[axis] = limit_index
    strides[axis] = stride

    return slice(a, start_indices, limit_indices, strides)


_slice_in_dim_doc = """
    Convenience wrapper for slicing just one dimension using slice.
    """

# TODO: make stride SymInt
slice_in_dim = _make_prim(
    schema="slice_in_dim(Tensor(a) a, SymInt start_index, SymInt limit_index, int stride=1, int axis=0) -> Tensor(a)",
    meta=_slice_in_dim_meta,
    impl_aten=_slice_in_dim_aten,
    return_type=RETURN_TYPE.VIEW,
    doc=_slice_in_dim_doc,
)


def _split_dim_meta(a: TensorLikeType, dim: int, outer_length: int) -> TensorLikeType:
    """Meta function for prims.split_dim.

    Raises:
        ValueError: if ``outer_length`` does not evenly divide ``a.shape[dim]``.
    """
    assert isinstance(a, TensorLike)
    utils.validate_idx(a.ndim, dim)
    utils.validate_dim_length(outer_length)

    # Verifies the dim can be split with the specified outer_length
    inner_length = a.shape[dim] // outer_length

    if (a.shape[dim] % outer_length) != 0:
        msg = (
            f"Attempting to split dimension of length {a.shape[dim]}, "
            f"but outer length of {outer_length} divides it with a remainder!"
        )
        raise ValueError(msg)

    new_shape: List[int] = []
    new_strides: List[int] = []
    for idx in range(a.ndim):
        if idx == dim:
            new_shape.extend((outer_length, inner_length))
            new_strides.extend((a.stride()[idx] * inner_length, a.stride()[idx]))
        else:
            new_shape.append(a.shape[idx])
            new_strides.append(a.stride()[idx])

    return a.as_strided(new_shape, new_strides, a.storage_offset())


def _split_dim_aten(a: Tensor, dim: int, outer_length: int) -> Tensor:
    """ATen impl of prims.split_dim via Tensor.view."""
    inner_length = a.shape[dim] // outer_length
    new_shape = a.shape[0:dim] + (outer_length, inner_length) + a.shape[dim + 1 :]

    return a.view(new_shape)


_split_dim_doc = """
    Creates a view of a with the given dimension (of length l) split
    into two dimensions, with the outer of the two having
    length outer_length and the inner of the two having computed
    length inner_length such outer_length * inner_length = l.
    """

# TODO: consider renaming split_dim_view
split_dim = _make_prim(
    schema="split_dim(Tensor(a) a, int dim, SymInt outer_length) -> Tensor(a)",
    meta=_split_dim_meta,
    impl_aten=_split_dim_aten,
    return_type=RETURN_TYPE.VIEW,
    doc=_split_dim_doc,
)


# Note: allows dimensions to be specified redundantly
def _squeeze_meta(a: TensorLikeType, dimensions: Sequence) -> TensorLikeType:
    """Meta function for prims.squeeze; each listed dimension must have length 1."""
    assert isinstance(a, TensorLike)

    for idx in dimensions:
        utils.validate_idx(a.ndim, idx)
        assert a.shape[idx] == 1

    new_shape = []
    new_strides = []
    for idx in range(len(a.shape)):
        if idx in dimensions:
            continue

        new_shape.append(a.shape[idx])
        new_strides.append(a.stride()[idx])

    return a.as_strided(new_shape, new_strides, a.storage_offset())


_squeeze_doc = """
    Creates a view of the tensor with the specified dimensions removed.

    The removed dimensions must each have length one.
    """

squeeze = _make_prim(
    schema="squeeze(Tensor(a) a, int[] dimensions) -> Tensor(a)",
    meta=_squeeze_meta,
    impl_aten=torch.squeeze,
    return_type=RETURN_TYPE.VIEW,
    doc=_squeeze_doc,
)


def _transpose_meta(a: TensorLikeType, permutation: DimsSequenceType) -> TensorLikeType:
    """Meta function for prims.transpose (a full permutation of the dims).

    Raises:
        ValueError: if ``permutation`` has the wrong length or is not a valid permutation.
    """
    if a.ndim != len(permutation):
        msg = f"Attempting to permute a tensor of rank {a.ndim}, but received a permutation of length {len(permutation)}!"
        raise ValueError(msg)

    if not utils.is_valid_permutation(a.ndim, permutation):
        msg = f"Received an invalid permutation, {permutation}!"
        raise ValueError(msg)

    new_shape = [0] * a.ndim
    new_strides = [0] * a.ndim
    for idx, dim in enumerate(permutation):
        new_shape[idx] = a.shape[dim]
        new_strides[idx] = a.stride()[dim]

    return a.as_strided(tuple(new_shape), tuple(new_strides), a.storage_offset())


def _transpose_aten(a: Tensor, permutation: DimsSequenceType) -> Tensor:
    """ATen impl of prims.transpose via torch.permute."""
    return torch.permute(a, permutation)


_transpose_doc = """
    Creates a view of the tensor with its dimensions permuted.

    The length of the permutation must be the rank of the tensor,
    and each element of the permutation specifies the new order
    for the corresponding dimension.
    """

transpose = _make_prim(
    schema="transpose(Tensor(a) a, int[] permutation) -> Tensor(a)",
    meta=_transpose_meta,
    impl_aten=_transpose_aten,
    return_type=RETURN_TYPE.VIEW,
    doc=_transpose_doc,
)


def _view_of_meta(a: TensorLikeType) -> TensorLikeType:
    """Meta for prims.view_of: an alias with identical shape, strides and offset."""
    return a.as_strided(a.shape, a.stride(), a.storage_offset())


def _view_of_aten(a: Tensor) -> Tensor:
    """ATen impl of prims.view_of."""
    return a.view(a.shape)


_view_of_doc = """
    Creates a view of the tensor.
    """

view_of = _make_prim(
    schema="view_of(Tensor(a) a) -> Tensor(a)",
    meta=_view_of_meta,
    impl_aten=_view_of_aten,
    return_type=RETURN_TYPE.VIEW,
    doc=_view_of_doc,
)


def _view_element_type_meta(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType:
    """Meta for prims.view_element_type (registered as view_of_dtype)."""
    return a.view(dtype)


def _view_element_type_aten(a: Tensor, dtype: torch.dtype) -> Tensor:
    """ATen impl of prims.view_element_type."""
    return a.view(dtype)


_view_element_type_doc = """
    Creates a view of the tensor with a different dtype.
    """

view_element_type = _make_prim(
    schema="view_of_dtype(Tensor(a) a, ScalarType dtype) -> Tensor(a)",
    meta=_view_element_type_meta,
    impl_aten=_view_element_type_aten,
    return_type=RETURN_TYPE.VIEW,
    doc=_view_element_type_doc,
)

#
# Functionalized view mutations
#


def _as_strided_scatter_meta(
    input: TensorLikeType,
    src: TensorLikeType,
    size: ShapeType,
    stride: StrideType,
    storage_offset: int,
) -> TensorLikeType:
    """Meta for prims.as_strided_scatter.

    Checks that the strided region fits in ``input``'s elements and that
    ``src`` matches the region's shape. Returns a stride-preserving clone of
    ``input``.
    """
    utils.validate_shape(size)
    utils.validate_strides(stride)

    required_size = utils.compute_required_storage_length(size, stride, storage_offset)
    torch._check(
        input.numel() >= required_size,
        lambda: (
            f"as_strided_scatter: sizes {size}, strides {stride}, storage offset {storage_offset} "
            f" and itemsize {input.element_size()} requiring a storage size of "
            f"{required_size * input.element_size()} are out of bounds "
            f"for storage of size {input.numel() * input.element_size()}"
        ),
    )
    torch._check(
        utils.is_same_shape(src.shape, size),
        lambda: f"expected src to have a size equal to the slice of self. src size = {src.shape}, slice size = {size}",
    )

    return utils.clone_preserve_strides(input)


_as_strided_scatter_doc = """
    Creates a new tensor equivalent to ``out = input.clone()`` after mutation by
    ``out.as_strided(size, stride, storage_offset).copy_(src)``.
"""

as_strided_scatter = _make_prim(
    schema="as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt storage_offset) -> Tensor",
    meta=_as_strided_scatter_meta,
    impl_aten=torch.as_strided_scatter,
    return_type=RETURN_TYPE.NEW,
    doc=_as_strided_scatter_doc,
)


#
# Shape operations
#


def _collapse_meta(a: Tensor, start: int, end: int) -> Tensor:
    """Meta for prims.collapse: a new tensor with dims [start, end] merged."""
    _validate_collapse_args(a, start, end)
    # _collapsed_shape special-cases zero dimensional tensors as shape (1,)
    new_shape = _collapsed_shape(a.shape, start, end)
    return a.new_empty(new_shape)


def _collapse_aten(a: Tensor, start: int, end: int) -> Tensor:
    """ATen impl of prims.collapse: copies ``a`` into a freshly allocated tensor."""
    new_shape = _collapsed_shape(a.shape, start, end)
    out = a.new_empty(new_shape)
    with torch.no_grad():
        out.view_as(a).copy_(a)
    return out


_collapse_doc = """
Collapse a span of neighboring dimensions into one.

See collapse_view for the corresponding view operation.
"""
collapse = _make_prim(
    schema="collapse(Tensor a, int start, int end) -> Tensor",
    meta=_collapse_meta,
    impl_aten=_collapse_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_collapse_doc,
)


# TODO: review stride logic
# NB: unlike torch.cat, this is more strict about empty tensors and dim is
# never negative
def _cat_meta(tensors: Sequence[TensorLikeType], dim: int) -> TensorLikeType:
    """Meta for prims.cat; the result has contiguous strides."""
    # Verifies same shape (except in the concat dimension)
    assert dim >= 0
    shape = tensors[0].shape
    concat_length = 0
    for tensor_idx, tensor in enumerate(tensors):
        assert len(shape) == len(tensor.shape)
        for idx, (common_length, length) in enumerate(zip(shape, tensor.shape)):
            if idx == dim:
                concat_length = concat_length + length
            else:
                torch._check(
                    length == common_length,
                    lambda: f"Sizes of tensors must match except in dimension {dim}. "
                    f"Expected {common_length} but got {length} for tensor number "
                    f"{tensor_idx} in the list",
                )

    new_shape = list(tensors[0].shape)
    new_shape[dim] = concat_length
    return TensorMeta(
        tensors[0],
        shape=new_shape,
        strides=utils.make_contiguous_strides_for(new_shape),
    )


def _cat_aten(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: int) -> Tensor:
    """ATen impl of prims.cat."""
    return torch.cat(tensors, dim)


_cat_doc = """
    Concatenates tensors along the specified dimension.

    The tensors' shapes must have the same rank and same length for other dimensions.
    """

cat = _make_prim(
    schema="cat(Tensor[] tensors, int dim) -> Tensor",
    meta=_cat_meta,
    impl_aten=_cat_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_cat_doc,
)


def _reshape_meta(a: TensorLikeType, shape: ShapeType):
    """Meta for prims.reshape; the result has contiguous strides.

    Raises:
        ValueError: if ``shape`` does not hold exactly ``a.numel()`` elements.
    """
    assert isinstance(a, TensorLike)
    utils.validate_shape(shape)

    # Validates the tensor and the requested shape have the
    # same number of elements. The initializer of 1 makes a 0-dim target
    # shape (shape == []) count as one element instead of raising TypeError.
    numel = reduce(operator.mul, shape, 1)
    if numel != a.numel():
        msg = f"Attempting to reshape a tensor with {a.numel()} elements to a shape with {numel} elements!"
        raise ValueError(msg)

    return TensorMeta(a, shape=shape, strides=utils.make_contiguous_strides_for(shape))


def _reshape_aten(a: Tensor, shape: ShapeType) -> Tensor:
    """ATen impl of prims.reshape; always returns a contiguous copy."""
    return a.reshape(shape).contiguous().clone()


_reshape_doc = """
    Creates a contiguous tensor with the specified shape
    containing a copy of the data in a.
    """
reshape = _make_prim(
    schema="reshape(Tensor a, SymInt[] shape) -> Tensor",
    meta=_reshape_meta,
    impl_aten=_reshape_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_reshape_doc,
)


def _rev_meta(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType:
    """Meta for prims.rev; the output layout follows ``a`` (preserve_format)."""
    utils.validate_dimension_indices(a.ndim, dims)
    return torch.empty_like(a, memory_format=torch.preserve_format)


_rev_doc = """
    Reverses the order of elements along the given dimensions.
    """

rev = _make_prim(
    schema="rev(Tensor a, int[] dims) -> Tensor",
    meta=_rev_meta,
    impl_aten=torch.flip,
    return_type=RETURN_TYPE.NEW,
    doc=_rev_doc,
)

#
# Conditional prims
#


def _where_meta(
    pred: TensorLikeType, a: TensorLikeType, b: TensorLikeType
) -> TensorLikeType:
    """Meta for prims.where; ``pred`` participates in shape checks but not dtype promotion."""
    return _prim_elementwise_meta(
        a,
        b,
        type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
        args_with_fixed_dtypes=(pred,),
    )


_where_doc = """
    Selects elements from a and b according to pred.

    Where pred is true the result contains the element from a, and
    where pred is false the result contains the element from b.
    """

where = _make_prim(
    schema="where(Tensor pred, Tensor a, Tensor b) -> Tensor",
    meta=_where_meta,
    impl_aten=torch.where,
    return_type=RETURN_TYPE.NEW,
    doc=_where_doc,
)


#
# Type conversions
#
def _convert_element_type_meta(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType:
    """Meta for prims.convert_element_type.

    Keeps ``a``'s strides when ``a`` is non-overlapping and dense. Otherwise it
    uses elementwise output strides.
    """
    # Type checks
    assert isinstance(a, TensorLike)
    assert isinstance(dtype, torch.dtype)

    # dtype conversion preserves dense strides
    if torch._prims_common.is_non_overlapping_and_dense(a):
        strides = a.stride()
    else:
        strides = utils.compute_elementwise_output_strides(a)

    return TensorMeta(a, strides=strides, dtype=dtype)


def _convert_element_type_aten(a: Tensor, dtype: torch.dtype) -> Tensor:
    """ATen impl of prims.convert_element_type: copies ``a`` into a new tensor of ``dtype``."""
    # Propagates requires grad when possible
    if not utils.is_grad_dtype(dtype):
        requires_grad = False
    else:
        # TODO: update meta objects so this can be acquired directly
        # Deliberately best-effort: fall back to False if the attribute is unavailable.
        try:
            requires_grad = a.requires_grad
        except Exception:
            requires_grad = False

    result = torch.empty_like(
        a, device=a.device, dtype=dtype, requires_grad=requires_grad
    )
    with torch.no_grad():
        return copy_to(result, a)


_convert_element_type_doc = """
    Creates a copy of a tensor with the given dtype.
    """

convert_element_type = _make_prim(
    schema="convert_element_type(Tensor a, ScalarType dtype) -> Tensor",
    meta=_convert_element_type_meta,
    impl_aten=_convert_element_type_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_convert_element_type_doc,
    tags=(torch.Tag.pointwise,),
)


def _device_put_meta(
    a: TensorLikeType, device: Union[str, torch.device]
) -> TensorLikeType:
    """Meta for prims.device_put: same metadata as ``a`` on the canonicalized device."""
    assert isinstance(a, TensorLike)
    assert isinstance(device, (str, torch.device))

    return TensorMeta(a, device=utils.canonicalize_device(device))


def _device_put_aten(a: Tensor, device: Union[str, torch.device]) -> Tensor:
    """ATen impl of prims.device_put via Tensor.to."""
    return a.to(device)


_device_put_doc = """
    Creates a copy of a tensor on the given device.
    """

device_put = _make_prim(
    schema="device_put(Tensor a, Device device) -> Tensor",
    meta=_device_put_meta,
    impl_aten=_device_put_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_device_put_doc,
)


# NOTE: need to model meta scalars
# See https://github.com/pytorch/pytorch/issues/78070
def _item_meta(a: TensorLikeType) -> FakeTensor:
    """Meta for prims.item: a placeholder built from -1 of ``a``'s Python number type."""
    number_type = utils.dtype_to_type(a.dtype)
    return TensorMeta(number_type(-1))


_item_doc = """
    Converts a tensor with one element to a Python number.
"""

# TODO: create a new return type for scalars?
# FIXME: currently returns integers for boolean tensors
# https://github.com/pytorch/pytorch/issues/78071
item = _make_prim(
    schema="item(Tensor a) -> Scalar",
    meta=_item_meta,
    impl_aten=torch.Tensor.item,
    return_type=RETURN_TYPE.NEW,
    doc=_item_doc,
)


def _finite_limits_info(dtype: torch.dtype):
    """Return torch.finfo for floating/complex dtypes, torch.iinfo otherwise."""
    uses_finfo = dtype.is_complex or dtype.is_floating_point
    return torch.finfo(dtype) if uses_finfo else torch.iinfo(dtype)


# NOTE: need to model meta scalars
# See https://github.com/pytorch/pytorch/issues/78070
def _maximum_value_meta(dtype: torch.dtype) -> FakeTensor:
    """Meta for prims.maximum_value: a placeholder built from -1 of dtype's Python type."""
    return TensorMeta(utils.dtype_to_type(dtype)(-1))


def _maximum_value_aten(dtype: torch.dtype):
    """Largest finite value of ``dtype``; ``True`` for torch.bool."""
    if dtype == torch.bool:
        return True
    return _finite_limits_info(dtype).max


_maximum_value_doc = """
    Return the maximum finite value for a dtype.
"""

# TODO: create a new return type for scalars?
# FIXME: currently returns integers for boolean tensors
# https://github.com/pytorch/pytorch/issues/78071
maximum_value = _make_prim(
    schema="maximum_value(ScalarType dtype) -> Scalar",
    meta=_maximum_value_meta,
    impl_aten=_maximum_value_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_maximum_value_doc,
)


# NOTE: need to model meta scalars
# See https://github.com/pytorch/pytorch/issues/78070
def _minimum_value_meta(dtype: torch.dtype) -> FakeTensor:
    """Meta for prims.minimum_value: a placeholder built from -1 of dtype's Python type."""
    return TensorMeta(utils.dtype_to_type(dtype)(-1))


def _minimum_value_aten(dtype: torch.dtype):
    """Smallest finite value of ``dtype``; ``False`` for torch.bool."""
    if dtype == torch.bool:
        return False
    return _finite_limits_info(dtype).min


_minimum_value_doc = """
    Return the minimum finite value for a dtype.
"""

# TODO: create a new return type for scalars?
# FIXME: currently returns integers for boolean tensors
# https://github.com/pytorch/pytorch/issues/78071
minimum_value = _make_prim(
    schema="minimum_value(ScalarType dtype) -> Scalar",
    meta=_minimum_value_meta,
    impl_aten=_minimum_value_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_minimum_value_doc,
)

#
# Inplace operators
#


def _copy_to_meta(a: TensorLikeType, b: TensorLikeType):
    """Meta function for ``copy_to``.

    Checks that both arguments are tensors with the same number of elements,
    then returns ``a``, the tensor being written to.

    Raises:
        RuntimeError: if ``a`` and ``b`` have different numbers of elements.
    """
    assert isinstance(a, TensorLike)
    assert isinstance(b, TensorLike)

    # Validates the cast is safe
    # TODO: move this as an option on the reference
    # a_typ = utils.dtype_to_type(a.dtype)
    # b_typ = utils.dtype_to_type(b.dtype)
    # if a_typ is not utils.get_higher_type(a_typ, b_typ):
    #     raise RuntimeError(str(b.dtype), " can't be cast safely to ", str(a.dtype), "!")

    # Validates the tensors have the same number of elements
    if a.numel() != b.numel():
        msg = f"Attempting to copy {b.numel()} elements to a tensor with {a.numel()} elements!"
        raise RuntimeError(msg)

    return a


def _copy_to_aten(a: Tensor, b: Tensor) -> Tensor:
    """Eager implementation of ``copy_to``: copies ``b`` into ``a`` in place and returns ``a``."""
    return a.copy_(b)


_copy_to_doc = """
    Copies the data in b to a and returns the modified a.
    """

# TODO: Remove safe casting and implement on reference instead
copy_to = _make_prim(
    schema="copy_to(Tensor(a!) a, Tensor b) -> Tensor(a!)",
    meta=_copy_to_meta,
    impl_aten=_copy_to_aten,
    return_type=RETURN_TYPE.INPLACE,
    doc=_copy_to_doc,
    register_conj_neg_fallthrough=True,
)


def _copy_strided_meta(a: TensorLikeType, stride: ShapeType):
    """Meta function for ``copy_strided``.

    Returns an uninitialized tensor with ``a``'s shape, dtype, layout, device
    and ``requires_grad`` flag, but with the given ``stride``.
    """
    assert isinstance(a, TensorLike)
    return torch.empty_strided(
        a.shape,
        stride,
        dtype=a.dtype,
        layout=a.layout,
        device=a.device,
        requires_grad=a.requires_grad,
    )


def _copy_strided_aten(a: Tensor, stride: ShapeType) -> Tensor:
    """Eager implementation of ``copy_strided``.

    Allocates a tensor like ``a`` but with the given ``stride``, copies
    ``a``'s values into it and returns the new tensor.
    """
    out = torch.empty_strided(
        a.size(),
        stride=stride,
        dtype=a.dtype,
        layout=a.layout,
        device=a.device,
        requires_grad=a.requires_grad,
    )
    out.copy_(a)
    return out


_copy_strided_doc = """
    Copies the data in a to a new tensor that has the same shape as a but the given strides.
    """


copy_strided = _make_prim(
    schema="copy_strided(Tensor a, SymInt[] stride) -> Tensor",
    meta=_copy_strided_meta,
    impl_aten=_copy_strided_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_copy_strided_doc,
)


def _resize_meta(a: TensorLikeType, shape: ShapeType):
    """Meta function for ``resize``: delegates to ``a.resize_(shape)``."""
    return a.resize_(shape)


def _resize_aten(a: Tensor, shape: ShapeType) -> Tensor:
    """Eager implementation of ``resize``: delegates to ``a.resize_(shape)``."""
    return a.resize_(shape)


_resize_doc = """
    Gives a tensor with no elements a new shape, returning the modified tensor.

    The tensor's strides are contiguous and its values are uninitialized.
    """

# TODO: review support arbitrary resizes
resize = _make_prim(
    schema="resize(Tensor(a!) a, SymInt[] shape) -> Tensor(a!)",
    meta=_resize_meta,
    impl_aten=_resize_aten,
    return_type=RETURN_TYPE.INPLACE,
    doc=_resize_doc,
)


def _reduction_meta(inp, dims, *, output_dtype=None):
    """
    Meta function for single output reduction operations
    Stride logic is incorrect

    The output shape comes from ``utils.compute_reduction_output_shape(inp.shape, dims)``
    and gets contiguous strides. ``output_dtype`` defaults to ``inp.dtype``.
    """
    assert isinstance(inp, TensorLike)
    if output_dtype is None:
        output_dtype = inp.dtype
    output_shape = utils.compute_reduction_output_shape(inp.shape, dims)
    return TensorMeta(
        shape=output_shape,
        strides=utils.make_contiguous_strides_for(output_shape),
        dtype=output_dtype,
        device=inp.device,
    )


def _var_reduction_meta(inp, dims, correction=1):
    """Meta function for variance reductions.

    Complex inputs produce their corresponding real dtype. ``correction`` does
    not affect the output metadata; it defaults to 1, matching the prim schema
    and ``torch_var``.
    """
    if utils.is_complex_dtype(inp.dtype):
        output_dtype = utils.corresponding_real_dtype(inp.dtype)
    else:
        output_dtype = inp.dtype
    return _reduction_meta(inp, dims, output_dtype=output_dtype)


_sum_doc = """
    Computes the sum of elements in the input tensor over the list of dimensions
    specified in the dim argument
    """
_xor_sum_doc = """
    Computes the xor sum of elements in the input tensor over the list of dimensions
    specified in the dim argument
    """
_prod_doc = """
    Computes the product of elements in the input tensor over the list of dimensions
    specified in the dim argument
    """
_amax_doc = """
    Computes the maximum value of elements in the input tensor over the list of dimensions
    specified in the dim argument
    """
_amin_doc = """
    Computes the minimum value of elements in the input tensor over the list of dimensions
    specified in the dim argument
    """
_var_doc = """
    Computes the variance of x over the list of dimensions specified in the dim argument,
    where correction is the difference between the sample size and the divisor
    (correction=0 gives the biased variance; the default correction=1 gives the unbiased one)
    """


def _make_reduction_prim(name: str, impl_aten, doc):
    """Creates a reduction prim.

    The prim is named ``name``, uses ``_reduction_meta`` as its meta function
    and takes ``(inp, dims, *, output_dtype=None)``.
    """
    return _make_prim(
        schema=f"{name}(Tensor inp, int[]? dims, *, ScalarType? output_dtype=None) -> Tensor",
        meta=_reduction_meta,
        impl_aten=impl_aten,
        return_type=RETURN_TYPE.NEW,
        doc=doc,
    )


def _make_var_reduction_prim(name: str, impl_aten, doc):
    """Creates a variance reduction prim.

    Like ``_make_reduction_prim``, but the schema adds a positional
    ``correction`` argument (default 1) and the meta function is
    ``_var_reduction_meta``.
    """
    return _make_prim(
        schema=f"{name}(Tensor inp, int[]? dims, float? correction=1, *, ScalarType? output_dtype=None) -> Tensor",
        meta=_var_reduction_meta,
        impl_aten=impl_aten,
        return_type=RETURN_TYPE.NEW,
        doc=doc,
    )


sum = _make_reduction_prim(
    name="sum",
    impl_aten=torch.sum,
    doc=_sum_doc,
)


def _xor_sum_aten(
    inp: TensorLikeType,
    dims: Optional[DimsSequenceType],
    *,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    """Eager placeholder for ``xor_sum``: there is no eager kernel, so this always raises."""
    raise NotImplementedError("xor_sum only implemented with inductor")


xor_sum = _make_reduction_prim(
    name="xor_sum",
    impl_aten=_xor_sum_aten,
    doc=_xor_sum_doc,
)


def _prod_aten(
    inp: TensorLikeType,
    dims: Optional[DimsSequenceType],
    *,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    """Eager implementation of ``prod`` over a list of non-negative dims.

    An empty ``dims`` returns a clone of ``inp``. Otherwise dims are reduced one
    at a time from highest to lowest, so the indices still to be reduced stay
    valid after each reduction drops a dim.
    """
    if dims is not None:
        if len(dims) == 0:
            return inp.clone()
        for d in sorted(dims, reverse=True):
            assert d >= 0
            inp = torch.prod(inp, d, dtype=dtype)
        return inp
    else:
        # NOTE(review): torch.prod may not accept dim=None positionally; this
        # branch also looks unreachable through the prim, because
        # _reduction_meta runs first. Confirm before relying on it.
        return torch.prod(inp, dims, dtype=dtype)


prod = _make_reduction_prim(
    name="prod",
    impl_aten=_prod_aten,
    doc=_prod_doc,
)


# torch.var, but correction is not kwarg-only
def torch_var(input, dim=None, correction=1, **kwargs):
    return torch.var(input, dim=dim, correction=correction, **kwargs)


var = _make_var_reduction_prim(
    name="var",
    impl_aten=torch_var,
    doc=_var_doc,
)

amax = _make_reduction_prim(
    name="amax",
    impl_aten=torch.amax,
    doc=_amax_doc,
)

amin = _make_reduction_prim(
    name="amin",
    impl_aten=torch.amin,
    doc=_amin_doc,
)


_iota_doc = """
    Constructs a 1-D tensor t where ``t[i] == start + i * step``.
"""


# TODO: layout, pin_memory, memory_format
# TODO: model requires_grad on TensorMeta
def _iota_meta(
    length: int,
    *,
    start: int,
    step: int,
    dtype: torch.dtype,
    device: torch.device,
    requires_grad: bool,
) -> TensorLikeType:
    """Meta function for ``iota``.

    Requires an integer dtype and a nonzero step, and returns an empty 1-D
    tensor of ``length`` elements.
    """
    torch._check(
        utils.is_integer_dtype(dtype),
        lambda: "prims.iota only supports integer dtypes",
    )
    torch._check(step != 0, lambda: "step must be nonzero")
    return torch.empty(
        length,
        dtype=dtype,
        device=device,
        requires_grad=requires_grad,
    )


def _iota_aten(
    length: int,
    *,
    start: int,
    step: int,
    dtype: torch.dtype,
    device: torch.device,
    requires_grad: bool,
) -> TensorLikeType:
    """Eager implementation of ``iota`` via ``torch.arange``.

    The end bound ``start + length * step`` makes ``arange`` produce exactly
    ``length`` values.
    """
    end = start + length * step
    return torch.arange(
        start, end, step, dtype=dtype, device=device, requires_grad=requires_grad
    )


iota = _make_prim(
    schema="iota(SymInt length, *, SymInt start, SymInt step, ScalarType dtype, Device device, bool requires_grad) -> Tensor",  # noqa: B950
    return_type=RETURN_TYPE.NEW,
    meta=_iota_meta,
    impl_aten=_iota_aten,
    doc=_iota_doc,
)


# TODO: layout, pin_memory, memory_format
# TODO: model requires_grad on TensorMeta
def _empty_meta(
    shape: ShapeType, *, dtype: torch.dtype, device: torch.device, requires_grad: bool
) -> TensorLikeType:
    """Meta function for ``empty``: a tensor of ``shape`` with contiguous strides."""
    strides = utils.make_contiguous_strides_for(shape)
    return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)


def _empty_aten(
    shape: ShapeType, *, dtype: torch.dtype, device: torch.device, requires_grad: bool
) -> Tensor:
    """Eager implementation of ``empty`` via ``torch.empty``."""
    return torch.empty(shape, dtype=dtype, device=device, requires_grad=requires_grad)


_empty_doc = """
    Creates a tensor with uninitialized values and the specified shape, dtype, and device.
"""

empty = _make_prim(
    schema="empty(SymInt[] shape, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",
    meta=_empty_meta,
    impl_aten=_empty_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_empty_doc,
)


def _empty_strided_meta(
    shape: ShapeType,
    strides: StrideType,
    *,
    dtype: torch.dtype,
    device: torch.device,
    requires_grad: bool,
) -> TensorLikeType:
    """Meta function for ``empty_strided``: a tensor with the given shape and strides."""
    return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)


_empty_strided_doc = """
    Creates a tensor with uninitialized values.
"""

# TODO: add layout, pin_memory
empty_strided = _make_prim(
    schema="empty_strided(SymInt[] shape, SymInt[] strides, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",
    return_type=RETURN_TYPE.NEW,
    meta=_empty_strided_meta,
    impl_aten=torch.empty_strided,
    doc=_empty_strided_doc,
)


def _empty_permuted_meta(
    shape: ShapeType,
    physical_layout: DimsSequenceType,
    *,
    dtype: torch.dtype,
    device: torch.device,
    requires_grad: bool,
) -> TensorLikeType:
    """Meta function for ``empty_permuted``.

    ``physical_layout[p]`` is the logical dim placed at physical position ``p``.
    The strides are the contiguous strides of ``shape`` permuted into that
    physical order, then scattered back to the logical dims.

    ``physical_layout`` is validated (length, range, no duplicates) BEFORE it is
    used to index ``shape``. This way a bad layout raises the descriptive errors
    below instead of an ``IndexError`` or a silently wrapped negative index.
    """
    dim = len(shape)
    torch._check(
        len(physical_layout) == dim,
        lambda: (
            "Number of dimensions in the tensor input does not match the "
            f"length of the physical layout; i.e. len(size) = {dim} "
            f"is not equal to len(physical_layout) = {len(physical_layout)}"
        ),
    )
    seen_dims = set()
    for p, d in enumerate(physical_layout):
        torch._check(
            0 <= d < dim,
            lambda: (
                f"Dimension out of range (expected to be between 0 and {dim - 1}, but got "
                f"{d} at index {p}). NB: negative dims "
                "not currently supported; file an issue if you want it."
            ),
        )
        torch._check(d not in seen_dims, lambda: "Duplicate dim not allowed")
        seen_dims.add(d)

    # Safe to index shape now: every entry is a distinct dim in [0, dim).
    p_strides = utils.make_contiguous_strides_for([shape[d] for d in physical_layout])
    strides = [0] * dim
    for p, d in enumerate(physical_layout):
        strides[d] = p_strides[p]
    return TensorMeta(
        shape=shape,
        strides=strides,
        dtype=dtype,
        device=device,
    )


_empty_permuted_doc = """
    Creates a tensor with uninitialized values according to some physical layout,
    that is guaranteed to be non-overlapping and dense.
"""

# TODO: add layout, pin_memory
empty_permuted = _make_prim(
    schema="empty_permuted(SymInt[] shape, int[] physical_layout, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",  # noqa: B950
    return_type=RETURN_TYPE.NEW,
    meta=_empty_permuted_meta,
    impl_aten=torch.empty_permuted,
    doc=_empty_permuted_doc,
)


def _full_meta(
    shape: ShapeType,
    fill_value: NumberType,
    *,
    dtype: torch.dtype,
    device: torch.device,
    requires_grad: bool,
) -> TensorLikeType:
    """Meta function for ``full``: a tensor of ``shape`` with contiguous strides."""
    strides = utils.make_contiguous_strides_for(shape)
    return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)


def _full_aten(
    shape: ShapeType,
    fill_value: NumberType,
    *,
    dtype: torch.dtype,
    device: torch.device,
    requires_grad: bool,
) -> Tensor:
    """Eager implementation of ``full`` via ``torch.full``."""
    # Note that Mypy thinks torch.full can't accept a complex fill_value
    return torch.full(
        shape, fill_value, dtype=dtype, device=device, requires_grad=requires_grad  # type: ignore[arg-type]
    )


_full_doc = """
    Creates a tensor filled with the given fill value, and with the specified shape, dtype, and device.
"""

# TODO: add layout
full = _make_prim(
    schema="full(SymInt[] shape, Scalar fill_value, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",
    meta=_full_meta,
    impl_aten=_full_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_full_doc,
)


def _full_like_meta(
    a: TensorLikeType,
    fill_value: NumberType,
    *,
    dtype: torch.dtype,
    device: torch.device,
    requires_grad: bool,
) -> TensorLikeType:
    """Meta function for ``full_like``.

    The strides follow ``utils.compute_elementwise_output_strides(a)``, except
    that an empty ``a`` keeps its own strides.
    """
    if a.numel() == 0:
        strides = a.stride()
    else:
        strides = utils.compute_elementwise_output_strides(a)

    return TensorMeta(a, strides=strides, dtype=dtype, device=device)


def _full_like_aten(
    a: Tensor,
    fill_value: NumberType,
    *,
    dtype: torch.dtype,
    device: torch.device,
    requires_grad: bool,
) -> Tensor:
    """Eager implementation of ``full_like`` via ``torch.full_like``."""
    # Note that Mypy thinks torch.full can't accept a complex fill_value
    return torch.full_like(
        a, fill_value, dtype=dtype, device=device, requires_grad=requires_grad  # type: ignore[arg-type]
    )


_full_like_doc = """
    Creates a tensor filled with the given fill value, and the same shape, dtype, and device as the
    given tensor by default. The dtype and device settings can be overridden
    by specifying them explicitly.
"""

full_like = _make_prim(
    schema="full_like(Tensor a, Scalar fill_value, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",
    meta=_full_like_meta,
    impl_aten=_full_like_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_full_like_doc,
)


def _scalar_tensor_meta(
    scalar: NumberType,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
) -> TensorLikeType:
    """Meta function for ``scalar_tensor``: a 0-dim tensor wrapping ``scalar``.

    ``dtype`` and ``device`` default to None, matching the prim schema
    (``ScalarType? dtype=None, Device? device=None``).
    """
    shape: ShapeType = []
    strides = utils.make_contiguous_strides_for(shape)
    return TensorMeta(scalar, shape=shape, strides=strides, dtype=dtype, device=device)


def _scalar_tensor_aten(
    scalar: NumberType,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
) -> Tensor:
    """Eager implementation of ``scalar_tensor`` via ``torch.scalar_tensor``.

    Raises:
        TypeError: if ``scalar`` is complex and ``dtype`` is None or not complex.
    """
    if isinstance(scalar, complex) and (
        dtype is None or not utils.is_complex_dtype(dtype)
    ):
        raise TypeError("Complex scalar requires complex tensor dtype.")
    # Note that Mypy thinks torch.scalar can't accept a complex scalar
    return torch.scalar_tensor(scalar, dtype=dtype, device=device)  # type: ignore[arg-type]


_scalar_tensor_doc = """
    Wraps a Number into a Tensor with the specified dtype and device.
"""

# TODO: add layout and pin_memory support
scalar_tensor = _make_prim(
    schema="scalar_tensor(Scalar s, *, ScalarType? dtype=None, Device? device=None) -> Tensor",
    meta=_scalar_tensor_meta,
    impl_aten=_scalar_tensor_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_scalar_tensor_doc,
)


#
# Linear algebra (linalg) prims
#


def _svd_meta(
    A: TensorLikeType, *, full_matrices: bool
) -> Tuple[TensorLikeType, TensorLikeType, TensorLikeType]:
    """Meta function for ``svd``; returns metadata for ``(U, S, Vh)``.

    For a ``(..., m, n)`` input with ``k = min(m, n)``:
    - U is ``(..., m, m or k)`` with column-major strides;
    - S is ``(..., k)`` with A's real dtype;
    - Vh is ``(..., n or k, n)``, row-major on CUDA and column-major otherwise.
    """
    utils.check_is_matrix(A, "linalg.svd")
    utils.check_fp_or_complex(A.dtype, "linalg.svd", allow_low_precision_dtypes=False)

    A_shape = A.shape
    batch = A_shape[:-2]
    m, n = A_shape[-2:]
    k = min(m, n)

    shape_U = batch + (m, m if full_matrices else k)
    strides_U = utils.make_contiguous_strides_for(shape_U, row_major=False)
    U = TensorMeta(shape=shape_U, strides=strides_U, dtype=A.dtype, device=A.device)

    shape_S = batch + (k,)
    strides_S = utils.make_contiguous_strides_for(shape_S)
    S = TensorMeta(
        shape=shape_S,
        strides=strides_S,
        dtype=utils.corresponding_real_dtype(A.dtype) if A.is_complex() else A.dtype,
        device=A.device,
    )

    shape_Vh = batch + (n if full_matrices else k, n)
    # The CPU backend returns V, but the cuSolver backend returns V^H
    # TODO The MAGMA backend returns V, so this is wrong if used with the MAGMA backend
    is_cuda = A.device.type == "cuda"
    strides_Vh = utils.make_contiguous_strides_for(shape_Vh, row_major=is_cuda)
    Vh = TensorMeta(shape=shape_Vh, strides=strides_Vh, dtype=A.dtype, device=A.device)
    # Also makes sure this is CUDA or HIP:
    # https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
    # NOTE(review): this gates on torch.cuda.is_available() rather than on
    # is_cuda above. A CPU input on a CUDA-capable machine therefore also gets
    # a conjugated Vh; confirm this is intended.
    if A.numel() != 0 and Vh.is_complex() and torch.cuda.is_available():
        Vh = Vh.conj()
    return U, S, Vh


def _svd_aten(
    A: TensorLikeType, *, full_matrices: bool
) -> Tuple[Tensor, Tensor, Tensor]:
    """Eager implementation of ``svd`` via ``torch.linalg.svd``."""
    return torch.linalg.svd(A, full_matrices=full_matrices)


_svd_doc = """
    Returns the SVD of a matrix or batch of matrices.

    The `full_matrices` flag controls whether the full or reduced SVD decomposition is returned.
"""

svd = _make_prim(
    schema="svd(Tensor A, *, bool full_matrices) -> (Tensor U, Tensor S, Tensor Vh)",
    meta=_svd_meta,
    impl_aten=_svd_aten,
    return_type=(RETURN_TYPE.NEW, RETURN_TYPE.NEW, RETURN_TYPE.NEW),
    doc=_svd_doc,
)


#
# Randomness Prims
#


def _normal_meta(
    shape: ShapeType,
    *,
    mean: Union[float, complex],
    std: float,
    dtype: torch.dtype,
    device: torch.device,
    requires_grad: bool,
    generator: Optional[torch.Generator] = None,
) -> TensorLikeType:
    """Meta function for ``normal``.

    Requires ``std >= 0`` and a floating-point or complex ``dtype``; returns a
    tensor of ``shape`` with contiguous strides.
    """
    torch._check(
        std >= 0.0,
        lambda: f"expected non-negative standard deviation, but got std={std}",
    )

    torch._check(
        utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
        lambda: f"expected a floating-point or complex dtype, but got dtype={dtype}",
    )

    strides = utils.make_contiguous_strides_for(shape)
    return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)


def _normal_aten(
    shape: ShapeType,
    *,
    mean: Union[float, complex],
    std: float,
    dtype: torch.dtype,
    device: torch.device,
    requires_grad: bool,
    generator: Optional[torch.Generator] = None,
) -> Tensor:
    """Eager implementation of ``normal``: allocates, then fills in place with ``normal_``."""
    a = torch.empty(shape, dtype=dtype, device=device, requires_grad=requires_grad)
    # The in-place fill of a (possibly requires_grad) leaf must not be recorded by autograd.
    with torch.no_grad():
        # NOTE: normal_ is incorrectly annotated to expect mean to be a float
        a.normal_(mean, std, generator=generator)  # type: ignore[arg-type]
    return a


_normal_doc = """
    Constructs a tensor filled with values drawn from a normal distribution with the specified mean
    and standard deviation.

    Only supports floating-point and complex types.
"""

normal = _make_prim(
    schema=(
        "normal(SymInt[] shape, *, Scalar mean, Scalar std, ScalarType dtype, Device device, bool requires_grad, Generator? generator=None) -> Tensor"  # noqa: B950
    ),
    return_type=RETURN_TYPE.NEW,
    meta=_normal_meta,
    impl_aten=_normal_aten,
    doc=_normal_doc,
)


def _uniform_meta(
    shape: ShapeType,
    *,
    low: float,
    high: float,
    dtype: torch.dtype,
    device: torch.device,
    generator: Optional[torch.Generator] = None,
) -> TensorLikeType:
    """Meta function for ``uniform``: a tensor of ``shape`` with contiguous strides."""
    strides = utils.make_contiguous_strides_for(shape)
    return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)


def _uniform_aten(
    shape: ShapeType,
    *,
    low: float,
    high: float,
    dtype: torch.dtype,
    device: torch.device,
    generator: Optional[torch.Generator] = None,
) -> Tensor:
    """Eager implementation of ``uniform``: allocates, then fills in place with ``uniform_``."""
    a = torch.empty(shape, dtype=dtype, device=device)
    a.uniform_(low, high, generator=generator)
    return a


_uniform_doc = """
    Constructs a tensor filled with values drawn uniformly from low to high.
"""

# TODO: we should more seriously review randomness modeling and prims
_uniform_helper = _make_prim(
    schema=(
        "uniform(SymInt[] shape, *, Scalar low, Scalar high, ScalarType dtype, Device device, Generator? generator=None) -> Tensor"
    ),
    return_type=RETURN_TYPE.NEW,
    meta=_uniform_meta,
    impl_aten=_uniform_aten,
    doc=_uniform_doc,
)

#
# FFT prims
#


def _fft_r2c_meta(
    input: TensorLike,
    *,
    dim: DimsSequenceType,
    onesided: bool,
) -> TensorLikeType:
    """Meta function for ``fft_r2c``.

    Uses the complex dtype corresponding to the input dtype. When ``onesided``
    is set, the last transformed dim shrinks to ``size // 2 + 1``.
    """
    dim = utils.canonicalize_dims(input.ndim, dim)
    utils.validate_no_repeating_dims(dim)

    shape = list(input.shape)
    if onesided:
        last_dim = dim[-1]
        shape[last_dim] = shape[last_dim] // 2 + 1

    dtype = utils.corresponding_complex_dtype(input.dtype)
    strides = utils.make_contiguous_strides_for(shape)
    return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=input.device)


def _fft_r2c_aten(
    input: TensorLike,
    *,
    dim: DimsSequenceType,
    onesided: bool,
) -> TensorLikeType:
    """Eager implementation of ``fft_r2c`` via ``torch._fft_r2c`` without normalization."""
    normalization = 0  # No normalization
    return torch._fft_r2c(input, dim, normalization, onesided)


_fft_r2c_doc = """
    Performs a real to complex Fast Fourier Transform
"""


fft_r2c = _make_prim(
    schema="fft_r2c(Tensor self, *, int[] dim, bool onesided) -> Tensor",
    meta=_fft_r2c_meta,
    impl_aten=_fft_r2c_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_fft_r2c_doc,
)


def _fft_c2c_meta(
    input: TensorLike,
    *,
    dim: DimsSequenceType,
    forward: bool,
) -> TensorLikeType:
    """Meta function for ``fft_c2c``: same shape and dtype as the input, contiguous strides."""
    dim = utils.canonicalize_dims(input.ndim, dim)
    utils.validate_no_repeating_dims(dim)

    shape = input.shape
    strides = utils.make_contiguous_strides_for(shape)
    return TensorMeta(
        shape=shape, strides=strides, dtype=input.dtype, device=input.device
    )


def _fft_c2c_aten(
    input: TensorLike,
    *,
    dim: DimsSequenceType,
    forward: bool,
) -> TensorLikeType:
    """Eager implementation of ``fft_c2c`` via ``torch._fft_c2c`` without normalization."""
    normalization = 0  # No normalization
    return torch._fft_c2c(input, dim, normalization, forward)
_fft_c2c_doc = """
    Performs either a Fast Fourier Transform, or its inverse
"""


fft_c2c = _make_prim(
    schema="fft_c2c(Tensor self, *, int[] dim, bool forward) -> Tensor",
    meta=_fft_c2c_meta,
    impl_aten=_fft_c2c_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_fft_c2c_doc,
)


def _fft_c2r_meta(
    input: TensorLike,
    *,
    dim: DimsSequenceType,
    last_dim_size: int,
) -> TensorLikeType:
    """Meta function for ``fft_c2r``.

    Uses the real dtype corresponding to the input dtype. The last transformed
    dim is resized to ``last_dim_size``.
    """
    dim = utils.canonicalize_dims(input.ndim, dim)
    utils.validate_no_repeating_dims(dim)

    shape = list(input.shape)
    shape[dim[-1]] = last_dim_size
    dtype = utils.corresponding_real_dtype(input.dtype)
    strides = utils.make_contiguous_strides_for(shape)
    return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=input.device)


def _fft_c2r_aten(
    input: TensorLike,
    *,
    dim: DimsSequenceType,
    last_dim_size: int,
) -> TensorLikeType:
    """Eager implementation of ``fft_c2r`` via ``torch._fft_c2r`` without normalization."""
    normalization = 0  # No normalization
    return torch._fft_c2r(input, dim, normalization, last_dim_size)


_fft_c2r_doc = """
    Performs a complex to real Inverse Fast Fourier Transform
"""


fft_c2r = _make_prim(
    schema="fft_c2r(Tensor self, *, int[] dim, SymInt last_dim_size) -> Tensor",
    meta=_fft_c2r_meta,
    impl_aten=_fft_c2r_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_fft_c2r_doc,
)


def _frexp_meta(self: TensorLikeType) -> Tuple[TensorLikeType, TensorLikeType]:
    """Meta function for ``frexp``.

    Requires a floating-point input. Returns a mantissa shaped and typed like
    ``self`` and an int32 exponent shaped like ``self``.
    """
    torch._check(
        self.dtype.is_floating_point,
        lambda: "torch.frexp() only supports floating-point dtypes",
    )
    return torch.empty_like(self), torch.empty_like(self, dtype=torch.int32)


_frexp_doc = """
    Decomposes self into a mantissa and an exponent such that ``self == mantissa * 2 ** exponent``.

    Only supports floating-point inputs; the mantissa has the dtype of self and the exponent is int32.
"""

frexp = _make_prim(
    schema="frexp(Tensor self) -> (Tensor mantissa, Tensor exponent)",
    meta=_frexp_meta,
    return_type=(RETURN_TYPE.NEW, RETURN_TYPE.NEW),
    impl_aten=torch.frexp,
    doc=_frexp_doc,
)


def _make_token_aten() -> TensorLikeType:
    """Return a fresh side-effect token (used as both the meta and eager implementation)."""
    return new_token_tensor()


_make_token = _make_prim(
    schema="_make_token() -> Tensor",
    meta=_make_token_aten,
    return_type=RETURN_TYPE.NEW,
    impl_aten=_make_token_aten,
    doc="Creates a token used for keeping track of side effects.",
)


def _sink_tokens_aten(tokens) -> None:
    """No-op implementation (meta and eager) of ``_sink_tokens``; the tokens are ignored."""
    pass


_sink_tokens = _make_prim(
    schema="_sink_tokens(Tensor[] tokens) -> ()",
    meta=_sink_tokens_aten,
    return_type=RETURN_TYPE.NONE,
    impl_aten=_sink_tokens_aten,
    doc="Sink all of the tokens which were previously used for keeping track of side effects.",
)


register_rng_prims()
register_debug_prims()