# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import builtins
import collections
import inspect
import itertools
import math
import operator
import warnings
from collections.abc import Iterable
from enum import Enum
from functools import partial, reduce, singledispatch, wraps
from typing import Any, Callable, Dict, List, Optional, overload, Sequence, Tuple, Union

import torch
import torch._prims as prims
import torch._prims_common as utils
import torch.utils._pytree as pytree
from torch import sym_float, sym_int
from torch._prims_common import (
    BoolLike,
    DeviceLikeType,
    Dim,
    DimsSequenceType,
    DimsType,
    dtype_to_type,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    FloatLike,
    FloatWithoutSymFloat,
    IntLike,
    is_weakly_lesser_type,
    Number,
    NumberType,
    RealNumberType,
    REDUCTION_OUTPUT_TYPE_KIND,
    ShapeType,
    StrideType,
    TensorLike,
    TensorLikeType,
    TensorOrNumberLikeType,
    TensorSequenceType,
)
from torch._prims_common.wrappers import (
    _maybe_convert_to_dtype,
    _maybe_resize_out,
    _safe_copy_out,
    elementwise_type_promotion_wrapper,
    elementwise_unary_scalar_wrapper,
    out_wrapper,
)


# Experimental module containing prototype Python references for existing
# PyTorch operations.

__all__ = [
    #
    # Elementwise Unary References
    #
    "abs",
    "acos",
    "acosh",
    "asinh",
    "asin",
    "atan",
    "atanh",
    "bitwise_not",
    # "cbrt",  # No corresponding torch operation
    "ceil",
    "conj_physical",
    "cos",
    "cosh",
    "count_nonzero",
    "deg2rad",
    "digamma",
    "erf",
    "erfinv",
    "erfc",
    "exp",
    "expm1",
    "exponential",
    "exp2",
    "fill",
    "fill_",
    "floor",
    "frac",
    "geometric",
    "index_add",
    "index_copy",
    "index_copy_",
    "index_select",
    "index_fill",
    "index_fill_",
    "isfinite",
    "isinf",
    "isposinf",
    "isneginf",
    "isnan",
    "isreal",
    "i0",
    "lerp",
    "lgamma",
    "log",
    "log1p",
    "log2",
    "log10",
    "log_normal",
    "log_softmax",
    "mvlgamma",
    "norm",
    "normal",
    "nan_to_num",
    "neg",
    "positive",
    "rad2deg",
    "reciprocal",
    "round",  # TODO: model kwargs
    "sigmoid",
    "sgn",
    "sign",
    "signbit",
    "sin",
    "sinc",
    "sinh",
    "softmax",
    "sqrt",
    "square",
    "tan",
    "tanh",
    "trace",
    "trunc",
    #
    # Elementwise Binary References
    #
    "add",
    "atan2",
    "bitwise_and",
    "bitwise_left_shift",
    "bitwise_or",
    "bitwise_right_shift",
    "bitwise_xor",
    "clamp_min",
    "clamp_max",
    "copysign",
    "div",
    "eq",
    "float_power",
    "floor_divide",
    "fmax",
    "fmin",
    "fmod",
    "gcd",
    "ge",
    "gt",
    "heaviside",
    "hypot",
    "igamma",
    "igammac",
    "imag",
    "isclose",
    "lcm",
    # 'ldexp',
    "le",
    "logaddexp",
    "logaddexp2",
    "logical_and",
    "logical_not",
    "logical_or",
    "logical_xor",
    "logsumexp",
    "lt",
    # 'max',  # implement with reductions
    "maximum",
    # 'min',  # implement with reductions
    "minimum",
    "mul",
    "ne",
    "nextafter",
    # 'polar',  # abs, cos, sin
    "pow",
    "real",
    "rpow",
    "remainder",
    "rsub",
    "rtruediv",
    "rfloordiv",
    "sub",
    "true_divide",
    "trunc_divide",
    "xlogy",
    #
    # Elementwise Ternary References
    #
    "addcdiv",
    "addcmul",
    "clamp",
    #
    # Conditional references
    #
    "masked_fill",
    "masked_fill_",
    "where",
    #
    # Data conversion and movement references
    #
    "clone",
    "copy_to",  # TODO: add OpInfo (or implement .to)
    "item",
    "to",
    #
    # Reduction ops
    #
    "all",
    "amax",
    "amin",
    "any",
    "cumsum",
    "cumprod",
    "mean",
    "dot",
    "vdot",
    "std",
    "std_mean",
    "sum",
    "sum_to_size",
    "prod",
    "var",
    "var_mean",
    #
    # Linear algebra ops
    #
    "addr",
    #
    # View & Shape Ops
    #
    "alias",
    "alias_copy",
    "atleast_1d",
    "atleast_2d",
    "atleast_3d",
    "as_strided",
    "as_strided_copy",
    "as_strided_scatter",
    "block_diag",
    "broadcast_shapes",
    "broadcast_tensors",
    "broadcast_to",
    "cat",
    "chunk",
    "column_stack",
    "conj",
    "constant_pad_nd",
    "contiguous",
    "diag_embed",
    "diag",
    "diagonal",
    "diagonal_copy",
    "diagonal_scatter",
    "dsplit",
    "dstack",
    "expand",
    "expand_as",
    "expand_copy",
    "flatten",
    "flip",
    "fliplr",
    "flipud",
    "hsplit",
    "hstack",
    "meshgrid",
    "movedim",
    "narrow",
    "narrow_copy",
    "native_group_norm",
    "native_layer_norm",
    "permute",
    "ravel",
    "repeat",
    "reshape",
    "reshape_as",
    "roll",
    "rot90",
    "rsqrt",
    "stack",
    "swap_axes",  # alias for transpose
    "squeeze",
    "t",
    "t_copy",
    "T",
    "take_along_dim",
    "tensor_split",
    "transpose",
    "unfold",
    "unfold_copy",
    "unsqueeze",
    "unsqueeze_copy",
    "view",
    "view_as",
    "view_copy",
    "vsplit",
    "vstack",
    "view_as_complex",
    "unflatten",
    "unbind",
    "triu",
    "tril",
    "triu_indices",
    "tril_indices",
    #
    # Tensor Creation
    #
    "arange",
    "cauchy",
    "empty",
    "empty_like",
    "empty_permuted",
    "empty_strided",
    "eye",
    "full",
    "full_like",
    "linspace",
    "logspace",
    "new_empty",
    "new_empty_strided",
    "new_full",
    "new_ones",
    "new_zeros",
    "ones",
    "ones_like",
    "randn",
    "scalar_tensor",
    "zero",
    "zeros",
    "zeros_like",
    #
    # Test-related functions
    #
    "allclose",
    "equal",
    #
    # Statistical operations
    #
    "bucketize",
    #
    # Misc
    #
    "is_complex",
    "renorm",
    "stft",
    "istft",
]

Tensor = torch.Tensor
DispatchKey = torch._C.DispatchKey  # type: ignore[attr-defined]
aten = torch._ops.ops.aten

# Note that the docstrings for the public methods from this file are in
# torch/_torch_docs.py


def is_noncontiguous_supported(device):
    return device is None or device.type != "hpu"


def handle_noncontiguous_outputs(input_tlist, output):
    device = None
    from torch._subclasses.fake_tensor import FakeTensor

    for t in input_tlist:
        if isinstance(t, FakeTensor):
            device = t.fake_device
            break

    if not is_noncontiguous_supported(device):
        output = output.contiguous()

    return output


def _broadcast_shapes(*_shapes):
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    shapes = tuple(
        (x,) if isinstance(x, IntLike) else x
        for x in filter(lambda x: x is not None, _shapes)
    )

    # Short-circuits on no input
    if len(shapes) == 0:
        return None

    # Type checking
    # TODO: make common validations available as utils
    for shape in shapes:
        assert isinstance(shape, Sequence)

    # Computes common shape
    common_shape = [
        1,
    ] * reduce(max, (len(shape) for shape in shapes))
    for arg_idx, shape in enumerate(shapes):
        for idx in range(-1, -1 - len(shape), -1):
            if guard_size_oblivious(common_shape[idx] == 1):
                if shape[idx] < 0:
                    raise ValueError(
                        "Attempting to broadcast a dimension with negative length!"
                    )
                common_shape[idx] = shape[idx]
            elif guard_size_oblivious(shape[idx] != 1):
                if common_shape[idx] != shape[idx]:
                    raise RuntimeError(
                        f"Attempting to broadcast a dimension of length {shape[idx]} at {idx}! "
                        f"Mismatching argument at index {arg_idx} had {shape}; but expected shape "
                        f"should be broadcastable to {common_shape}"
                    )

    return common_shape


def _maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
    # Computes common shape
    common_shape = _broadcast_shapes(
        *(t.shape if isinstance(t, TensorLike) else None for t in args)
    )

    def __maybe_broadcast(x, shape):
        if x is None:
            return None
        elif isinstance(x, Number):
            return x
        elif isinstance(x, TensorLike):
            if preserve_cpu_scalar_tensors and utils.is_cpu_scalar_tensor(x):
                return x

            if not utils.same_shape(x.shape, common_shape):
                return x.expand(common_shape)

            return x
        else:
            raise RuntimeError(
                "Unexpected type when broadcasting: " + str(type(x)) + "!"
            )

    return tuple(__maybe_broadcast(x, common_shape) for x in args)
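

# Illustrative example (hypothetical helper, not part of the module API and
# never called): documents the expected behaviour of the broadcasting helpers
# above. Shapes are aligned from the right and size-1 dimensions stretch to the
# common shape, while Python numbers (and, by default, CPU scalar tensors) pass
# through _maybe_broadcast untouched.
def _broadcast_helpers_example() -> None:
    assert _broadcast_shapes((3, 1, 5), (4, 5)) == [3, 4, 5]
    assert _broadcast_shapes((), (2, 3)) == [2, 3]

    a = torch.ones(3, 1)
    b = torch.ones(1, 4)
    a_b, b_b, n = _maybe_broadcast(a, b, 2.0)
    assert a_b.shape == b_b.shape == (3, 4)
    assert n == 2.0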


# Utilities should come BEFORE this import
from torch._decomp import register_decomposition


#
# Elementwise unary references
#

infer_aten_op = object()


# TODO: add type promotion support
def _make_elementwise_unary_reference(
    type_promotion_kind,
    *,
    aten_op=infer_aten_op,
    extra_meta=None,
) -> Callable:
    def inner(prim: Callable):
        nonlocal aten_op

        @wraps(prim)
        @out_wrapper()
        @elementwise_unary_scalar_wrapper
        @elementwise_type_promotion_wrapper(
            type_promoting_args=("a",),
            type_promotion_kind=type_promotion_kind,
        )
        def _ref(a: TensorLikeType) -> TensorLikeType:
            if extra_meta is not None:
                extra_meta(a)

            output = prim(a)
            return handle_noncontiguous_outputs([a], output)

        if aten_op is infer_aten_op:
            aten_op = utils.get_aten_op(prim, prim.__name__)
        if aten_op is not None:
            register_decomposition(aten_op)(_ref)

        return _ref

    return inner


def _make_alias(fn, name):
    """
    This function defines an alias of another function and sets its __name__ attribute.
    It also sets its __module__ attribute to the module of the caller.
    Note that when naively doing `alias = fn`, we have that `alias.__name__ == "fn"` and
    `alias.__module__ == fn.__module__`.
    """

    def _fn(*args, **kwargs):
        return fn(*args, **kwargs)

    _fn.__name__ = name
    _fn.__module__ = inspect.currentframe().f_back.f_globals["__name__"]  # type: ignore[union-attr]
    return _fn


def _make_inplace(fn):
    """
    Given a function with an out variant (i.e. using `out_wrapper()`), returns its in-place variant.
    See https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-do-in-place-operations-work-in-pytorch
    """

    # nb. We use the name of the first argument used in the unary references
    @wraps(fn)
    def _fn(a, *args, **kwargs):
        return fn(a, *args, out=a, **kwargs)

    inplace_name = f"{fn.__name__}_"
    _fn.__name__ = inplace_name
    _fn = register_decomposition(getattr(aten, inplace_name))(_fn)  # type: ignore[assignment]

    # We access the __all__ attribute of the module where fn is defined
    # There may be a cleaner way of doing this...
    from inspect import getmodule

    _all = getmodule(fn).__all__  # type: ignore[union-attr]
    if inplace_name not in _all:
        _all.append(inplace_name)
    return _fn
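

# Illustrative sketch (hypothetical, nothing here is registered): how the
# factory helpers above are meant to compose. A reference is written against
# prims and wrapped by _make_elementwise_unary_reference; _make_inplace then
# derives the trailing-underscore variant by routing out= back to the first
# argument (it requires a matching aten overload, e.g. aten.my_unary_, to
# exist):
#
#   @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
#   def my_unary(a):
#       return prims.neg(a)
#
#   my_unary_ = _make_inplace(my_unary)
#
# The decorator stack supplies out= support, scalar handling and elementwise
# type promotion; _make_inplace also appends the in-place name to __all__.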


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT)
def abs(a):
    return prims.abs(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def acos(a):
    return prims.acos(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def acosh(a):
    return prims.acosh(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def asin(a):
    return prims.asin(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def asinh(a):
    return prims.asinh(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def atan(a):
    return prims.atan(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def atanh(a):
    return prims.atanh(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
def bitwise_not(a):
    return prims.bitwise_not(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
def ceil(a):
    return prims.ceil(a)


@register_decomposition(aten.is_complex)
def is_complex(input: TensorLikeType):
    return utils.is_complex_dtype(input.dtype)


@register_decomposition(aten.conj_physical)
@out_wrapper()
def conj_physical(input: TensorLikeType):
    if not utils.is_complex_dtype(input.dtype):
        return input
    return prims.conj_physical(input)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def cos(a):
    return prims.cos(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def cosh(a):
    return prims.cosh(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def digamma(a):
    return prims.digamma(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def erf(a):
    return prims.erf(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def erfinv(a):
    return prims.erf_inv(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def erfc(a):
    return prims.erfc(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def exp(a):
    return prims.exp(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def expm1(a):
    return prims.expm1(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def exp2(a):
    return prims.exp2(a)


# Fill has its own implementation because it has a value parameter
# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a,"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
)
def fill(a: TensorLikeType, value: NumberType) -> TensorLikeType:
    assert isinstance(a, TensorLike)
    assert isinstance(value, Number)

    python_type = utils.dtype_to_type(a.dtype)
    if not utils.is_weakly_lesser_type(type(value), python_type):
        msg = f"value argument of type {type(value)} cannot be safely cast to type {python_type}!"
        raise ValueError(msg)

    return prims.fill(a, value)


def fill_(a: TensorLikeType, value: NumberType) -> TensorLikeType:
    r = prims.fill(a, value)
    prims.copy_to(a, r)
    return a


@register_decomposition(aten.zero)
@out_wrapper()
def zero(input: TensorLikeType) -> TensorLikeType:
    return torch.zeros_like(input)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
def floor(a):
    return prims.floor(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
def frac(x: TensorLikeType) -> TensorLikeType:
    trunc_x = torch.mul(torch.floor(torch.abs(x)), torch.sign(x))
    return torch.sub(x, trunc_x)


# imag does not use _make_elementwise_unary_reference because it does not support out
def imag(a: TensorLikeType) -> TensorLikeType:
    assert isinstance(a, TensorLike)
    torch._check(
        utils.is_complex_dtype(a.dtype), lambda: "imag only supports complex tensors."
    )
    return prims.imag(a)


@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    aten_op=None,  # CompositeImplicitAutograd
)
def isfinite(a: TensorLikeType) -> TensorLikeType:
    if utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype):
        return prims.isfinite(a)

    return ones_like(a, dtype=torch.bool)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
def isinf(a: TensorLikeType) -> TensorLikeType:
    if utils.is_complex_dtype(a.dtype):
        return torch.logical_or(isinf(torch.real(a)), isinf(torch.imag(a)))
    if utils.is_float_dtype(a.dtype):
        return torch.abs(a) == float("inf")
    return torch.zeros_like(a, dtype=torch.bool)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
def isposinf(a: TensorLikeType) -> TensorLikeType:
    torch._check(
        not utils.is_complex_dtype(a.dtype),
        lambda: f"Complex dtype is not supported for isposinf, got dtype {a.dtype}",
    )
    if utils.is_float_dtype(a.dtype):
        return a == float("inf")
    return torch.zeros_like(a, dtype=torch.bool)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
def isneginf(a: TensorLikeType) -> TensorLikeType:
    torch._check(
        not utils.is_complex_dtype(a.dtype),
        lambda: f"Complex dtype is not supported for isneginf, got dtype {a.dtype}",
    )
    if utils.is_float_dtype(a.dtype):
        return a == float("-inf")
    return torch.zeros_like(a, dtype=torch.bool)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
def isnan(a: TensorLikeType) -> TensorLikeType:
    return prims.ne(a, a)


# alias
mvlgamma = _make_alias(torch.special.multigammaln, "mvlgamma")  # type: ignore[has-type]


@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    aten_op=None,  # CompositeImplicitAutograd
)
def isreal(a: TensorLikeType) -> TensorLikeType:
    if utils.is_complex_dtype(a.dtype):
        return torch.imag(a) == 0
    return torch.ones_like(a, dtype=torch.bool)


# TODO: if this is special maybe it should be defined there and imported here?
@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, aten_op=aten.i0
)
def i0(a):
    return prims.bessel_i0(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def lgamma(a):
    return prims.lgamma(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def log(a):
    return prims.log(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def log1p(a):
    return prims.log1p(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def log2(a):
    return prims.log2(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def log10(a):
    return prims.log10(a)


# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
def log_softmax(
    a: TensorLikeType,
    dim: int,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    result_dtype = dtype or a.dtype
    computation_dtype = utils.get_computation_dtype(result_dtype)
    a_ = _maybe_convert_to_dtype(a, computation_dtype)
    return _maybe_convert_to_dtype(a_ - logsumexp(a_, dim, keepdim=True), result_dtype)  # type: ignore[return-value]


@register_decomposition(aten.logsumexp)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def logsumexp(
    self: TensorLikeType, dim: DimsType, keepdim: bool = False
) -> TensorLikeType:
    if not isinstance(dim, Iterable):
        dim = (dim,)
    if self.numel() == 0:
        return torch.sum(torch.exp(self), dim, keepdim).log()
    maxes = torch.amax(torch.real(self), dim, keepdim=True)
    maxes = torch.masked_fill(maxes, maxes.abs() == float("inf"), 0)
    maxes_squeezed = maxes if keepdim else torch.squeeze(maxes, dim)
    result = torch.sum(torch.exp(self - maxes), dim, keepdim)
    return result.log().add(maxes_squeezed)
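

# Illustrative example (hypothetical helper, never called): the max-shift in
# logsumexp above keeps the reduction finite for large inputs. Exponentiating
# 1000.0 directly overflows float32 to inf, but exp(x - max(x)) stays in range
# and the max is added back in log space.
def _logsumexp_stability_example() -> None:
    x = torch.tensor([1000.0, 1000.0])
    assert torch.isinf(torch.exp(x)).all()  # the naive formula overflows
    expected = torch.tensor(1000.0 + math.log(2.0))
    assert torch.allclose(torch.logsumexp(x, dim=0), expected)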


@register_decomposition(aten.nan_to_num)
@out_wrapper()
def nan_to_num(
    a: TensorLikeType,
    nan: Optional[NumberType] = 0.0,
    posinf: Optional[NumberType] = None,
    neginf: Optional[NumberType] = None,
) -> TensorLikeType:
    assert isinstance(a, TensorLike)

    if utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype):
        return a.clone()

    if nan is None:
        nan = 0.0

    if posinf is None:
        posinf = torch.finfo(a.dtype).max

    if neginf is None:
        neginf = torch.finfo(a.dtype).min

    result = torch.where(torch.isnan(a), nan, a)  # type: ignore[call-overload]
    result = torch.where(torch.isneginf(a), neginf, result)  # type: ignore[call-overload]
    result = torch.where(torch.isposinf(a), posinf, result)  # type: ignore[call-overload]
    return result


def _neg_meta(a: TensorLikeType):
    torch._check(
        a.dtype is not torch.bool,
        lambda: (
            "Negation, the `-` operator, on a bool tensor is not supported. "
            "If you are trying to invert a mask, use the `~` or `logical_not()` "
            "operator instead."
        ),
    )


@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, extra_meta=_neg_meta
)
def neg(a):
    return prims.neg(a)


# positive does not use _make_elementwise_unary_reference because it does not support out
# CompositeImplicitAutograd - don't register decomp
def positive(a: TensorLikeType) -> TensorLikeType:
    assert isinstance(a, TensorLike)
    if a.dtype is torch.bool:
        msg = "positive does not support bool tensors."
        raise RuntimeError(msg)
    return a


# real does not use _make_elementwise_unary_reference because it does not support out
def real(a: TensorLikeType) -> TensorLikeType:
    assert isinstance(a, TensorLike)
    if utils.is_complex_dtype(a.dtype):
        return prims.real(a)
    return a


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def reciprocal(a):
    return prims.reciprocal(a)


@register_decomposition(aten.round)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def round(a: TensorLikeType, *, decimals: int = 0) -> TensorLikeType:
    if decimals == 0:
        return prims.round(a)
    else:
        ten_pow = 10**decimals
        ten_neg_pow = 10 ** (-decimals)
        return prims.mul(prims.round(prims.mul(a, ten_pow)), ten_neg_pow)
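

# Illustrative example (hypothetical helper, never called): with a nonzero
# `decimals`, round above scales by 10**decimals, rounds, and scales back;
# with decimals == 0 it defers to prims.round, which rounds half to even.
def _round_decimals_example() -> None:
    assert torch.allclose(
        torch.round(torch.tensor([1.2345, 2.5]), decimals=2),
        torch.tensor([1.23, 2.5]),
    )
    assert torch.equal(torch.round(torch.tensor([0.5, 1.5])), torch.tensor([0.0, 2.0]))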


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def rsqrt(a):
    return prims.rsqrt(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def sigmoid(a: TensorLikeType) -> TensorLikeType:
    return true_divide(1, add(1, exp(neg(a))))


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
def sgn(a):
    if utils.is_complex_dtype(a.dtype):
        a_abs = a.abs()
        return torch.where(a_abs == 0, 0, a / a_abs)
    else:
        return a.sign()


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
def sign(a):
    return prims.sign(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
def signbit(a):
    return prims.signbit(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def sin(a):
    return prims.sin(a)


# Autograd note: This will give the right first derivative at zero (by chance),
# but not the right second derivative
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def sinc(a):
    a = math.pi * a
    return torch.where(a == 0, 1, torch.sin(a) / a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def sinh(a):
    return prims.sinh(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def sqrt(a):
    return prims.sqrt(a)


@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG,
    aten_op=None,  # CompositeImplicitAutograd
)
def square(a: TensorLikeType) -> TensorLikeType:
    return mul(a, a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def tan(a):
    return prims.tan(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def tanh(a):
    return prims.tanh(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
def trunc(a):
    return prims.trunc(a)


# TODO: register this as a real ref/decomposition once TorchInductor supports complex!
def view_as_complex(self: TensorLikeType) -> TensorLikeType:
    input_dtype = self.dtype
    torch._check(
        utils.is_float_dtype(input_dtype),
        lambda: f"view_as_complex is only supported for floating point "
        f"tensors, but got a tensor of scalar type: {input_dtype}",
    )
    sizes = self.size()
    torch._check(
        len(sizes) != 0,
        lambda: "Input tensor must have one or more dimensions",
    )
    torch._check(
        sizes[-1] == 2,
        lambda: "Tensor must have a last dimension of size 2",
    )

    old_strides = self.stride()
    torch._check(
        old_strides[-1] == 1,
        lambda: "Tensor must have a last dimension with stride 1",
    )
    dims = old_strides[:-1]
    torch._check(
        builtins.all(stride % 2 == 0 for stride in dims),
        lambda: "Tensor must have a stride divisible by 2 for all but last dimension",
    )
    torch._check(
        self.storage_offset() % 2 == 0,
        lambda: "Tensor must have a storage_offset divisible by 2",
    )
    return prims.view_element_type(
        self, utils.corresponding_complex_dtype(input_dtype)
    ).squeeze(-1)


def _make_elementwise_binary_reference(
    type_promotion_kind,
    aten_op=infer_aten_op,
    name=None,
    has_out=True,
    supports_lhs_python_scalar=True,
    supports_rhs_python_scalar=True,
    supports_two_python_scalars=False,
    should_register_decomposition=True,
) -> Callable:
    def inner(prim: Callable):
        nonlocal aten_op, name
        if name is None:
            name = prim.__name__

        @wraps(prim)
        @elementwise_type_promotion_wrapper(
            type_promoting_args=("a", "b"),
            type_promotion_kind=type_promotion_kind,
        )
        def _ref(
            a: Union[Tensor, NumberType],
            b: Union[Tensor, NumberType],
        ) -> Tensor:
            torch._check_value(
                supports_lhs_python_scalar or not isinstance(a, Number),
                lambda: f"{name}: Received a lhs Python scalar to an elementwise binary "
                "operation that does not accept lhs scalars!",
            )
            torch._check_value(
                supports_rhs_python_scalar or not isinstance(b, Number),
                lambda: f"{name}: Received a rhs Python scalar to an elementwise binary "
                "operation that does not accept rhs scalars!",
            )
            torch._check_value(
                supports_two_python_scalars
                or not (isinstance(a, Number) and isinstance(b, Number)),
                lambda: f"{name}: Receive two Number inputs to an elementwise binary operation!",
            )
            a, b = _maybe_broadcast(a, b)
            output = prim(a, b)
            return handle_noncontiguous_outputs([a, b], output)

        if has_out:
            _ref = out_wrapper()(_ref)  # type: ignore[assignment]

        _ref.__name__ = name
        if aten_op is infer_aten_op:
            aten_op = utils.get_aten_op(prim, name)
        if aten_op is not None and should_register_decomposition:
            register_decomposition(aten_op)(_ref)

        return _ref

    return inner


# Add has its own implementation because it has an alpha argument
@register_decomposition(aten.add)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "b"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def add(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
    *,
    alpha: Optional[NumberType] = None,
):
    """
    Reference implementation of torch.add
    """

    a, b = _maybe_broadcast(a, b)

    if alpha is not None:
        dtype = a.dtype if isinstance(a, TensorLike) else b.dtype  # type: ignore[union-attr]
        python_type = utils.dtype_to_type(dtype)
        if python_type != bool and not utils.is_weakly_lesser_type(
            type(alpha), python_type
        ):
            msg = f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!"
            raise ValueError(msg)
        if isinstance(b, TensorLike):
            b = prims.mul(b, alpha)
        else:
            b = b * alpha

    output = prims.add(a, b)
    return handle_noncontiguous_outputs([a, b], output)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def atan2(a, b):
    return prims.atan2(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def bitwise_and(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.bitwise_and(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def bitwise_left_shift(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.shift_left(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def bitwise_or(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.bitwise_or(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def bitwise_right_shift(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.shift_right_arithmetic(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def bitwise_xor(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.bitwise_xor(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    supports_lhs_python_scalar=False,
)
def copysign(
    a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
):
    if isinstance(b, Number) and isinstance(a, Tensor):
        b = scalar_tensor(b, dtype=a.dtype, device=a.device)
    elif isinstance(a, Tensor) and isinstance(b, Tensor) and a.device != b.device:
        msg = f"Expected divisor (b) to be on the same device ({a.device}) as dividend (a), but it is found on {b.device}!"
        raise RuntimeError(msg)
    return where(signbit(b), neg(abs(a)), abs(a))


# complex = _make_elementwise_binary_reference(prims.complex, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)


@register_decomposition(aten.div)
@out_wrapper()
def div(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
    *,
    rounding_mode: Optional[str] = None,
):
    """
    Reference implementation of torch.div
    """
    if rounding_mode is None:
        return true_divide(a, b)
    elif rounding_mode == "trunc":
        return trunc_divide(a, b)
    elif rounding_mode == "floor":
        return floor_divide(a, b)
    else:
        msg = f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
        raise ValueError(msg)
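

# Illustrative example (hypothetical helper, never called): the three branches
# of div above correspond to true division, C-style truncation toward zero and
# Python-style flooring, matching torch.div's documented rounding_mode values.
def _div_rounding_modes_example() -> None:
    a = torch.tensor([-7.0, 7.0])
    b = torch.tensor([2.0, 2.0])
    assert torch.equal(torch.div(a, b), torch.tensor([-3.5, 3.5]))
    assert torch.equal(torch.div(a, b, rounding_mode="trunc"), torch.tensor([-3.0, 3.0]))
    assert torch.equal(torch.div(a, b, rounding_mode="floor"), torch.tensor([-4.0, 3.0]))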


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    supports_lhs_python_scalar=False,
)
def eq(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.eq(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG,
)
def pow(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
) -> TensorLikeType:
    assert isinstance(a, TensorLikeType) or isinstance(b, TensorLikeType)

    if isinstance(b, Number):
        if b == 1.0:
            return a.clone()  # type: ignore[return-value,union-attr]
        elif b == 2.0:
            return a * a  # type: ignore[return-value]
        elif b == 0.5:
            return torch.sqrt(a)  # type: ignore[arg-type]
    elif isinstance(a, Number):
        if a == 1.0:
            return torch.fill(b, True)
        if a == 2.0 and (
            utils.is_float_dtype(b.dtype) or utils.is_complex_dtype(b.dtype)
        ):
            return torch.exp2(b)

    return prims.pow(a, b)


# Float power has its own implementation because it has unique type promotion.
# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
def float_power(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
) -> Tensor:
    if isinstance(a, Number) and isinstance(b, Number):
        raise ValueError(
            "Receive two Number inputs to an elementwise binary operation!"
        )

    # Handles type promotion
    dtype = utils.get_higher_dtype(a, b)
    assert dtype is not None
    if utils.is_complex_dtype(dtype):
        dtype = torch.complex128
    else:
        dtype = torch.float64

    # Float power has the following contiguous cast behavior to be
    # consistent with its C++ impl
    a = _maybe_convert_to_dtype(a, dtype)
    b = _maybe_convert_to_dtype(b, dtype)

    a, b = _maybe_broadcast(a, b)
    return pow(a, b)


# >>> a = torch.tensor(-0.2500, dtype=torch.float64)
# tensor(-0.250000000000000, dtype=torch.float64)
#
# >>> b = torch.tensor(-0.0010, dtype=torch.float64)
# tensor(-0.001000000000000, dtype=torch.float64)
#
# Note: In this case, casting float to double will expand the float mantissa with zeros,
# while creating a double generates a distinct mantissa.
# >>> torch.tensor(-0.001).to(dtype=torch.float64)
# tensor(-0.001000000047497, dtype=torch.float64)
#
# Floor Division
# The difference is caused because torch.remainder(a, b) = -0.001.
#
# >>> torch.floor(torch.true_divide(a, b))
# tensor(250., dtype=torch.float64)
#
# >>> torch.div(a, b, rounding_mode='floor')
# tensor(249., dtype=torch.float64)
#
# Definition: a // b = (a - remainder(a, b)) / b
# >>> torch.true_divide(torch.sub(a, torch.remainder(a, b)), b)
# tensor(249., dtype=torch.float64)
#
# For reference, see CPython's implementation:
# https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636


@_make_elementwise_binary_reference(
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_two_python_scalars=True,
    should_register_decomposition=False,
)
def floor_divide(
    a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
):
    # Wrap scalars because some references only accept tensor arguments.
    if isinstance(a, Number) and isinstance(b, Number):
        a = scalar_tensor(a)
        b = scalar_tensor(b)
    elif isinstance(b, Number) and isinstance(a, Tensor):
        b = scalar_tensor(b, dtype=a.dtype, device=a.device)
    elif isinstance(a, Number) and isinstance(b, Tensor):
        a = scalar_tensor(a, dtype=b.dtype, device=b.device)
    elif isinstance(a, Tensor) and isinstance(b, Tensor) and a.device != b.device:
        if a.device == torch.device("cpu"):
            msg = f"Expected divisor (b) to be on the same device ({a.device}) as dividend (a), but it is found on {b.device}!"
            raise RuntimeError(msg)
        else:
            b = prims.device_put(b, device=a.device)

    assert isinstance(a, Tensor) and isinstance(b, Tensor)
    dtype = a.dtype
    if utils.is_float_dtype(dtype):
        return _floor_divide_float(a, b)
    elif utils.is_integer_dtype(dtype):
        return _floor_divide_integer(a, b)
    else:
        torch._check(False, lambda: f"{dtype} not supported for floor_divide")


def _floor_divide_integer(a: Tensor, b: Tensor) -> Tensor:
    a, b = _maybe_broadcast(a, b)

    if not a.dtype.is_signed:
        return prims.div(a, b)

    # Convert truncation to flooring:
    offset = (torch.signbit(a) != torch.signbit(b)).logical_and(torch.fmod(a, b) != 0)
    return prims.div(a, b) - _maybe_convert_to_dtype(offset, a.dtype)


def _floor_divide_float(a: Tensor, b: Tensor) -> Tensor:
    mod = fmod(a, b)
    div = true_divide(sub(a, mod), b)

    # Ensure that the remainder has the same sign as denominator
    different_signed_inputs = bitwise_xor(lt(a, 0), lt(b, 0))
    non_zero_remainder = ne(mod, 0)
    mask = bitwise_and(non_zero_remainder, different_signed_inputs)
    div = where(mask, sub(div, 1), div)

    # Map quotient to nearest integer value
    floor_div = floor(div)
    mask = gt(sub(div, floor_div), 0.5)
    floor_div = where(mask, add(floor_div, 1), floor_div)

    basic_div = true_divide(a, b)
    zero_tensor = scalar_tensor(0, dtype=basic_div.dtype, device=basic_div.device)

    # If quotient is zero, copy signbit from true_divide quotient
    floor_div = where(ne(div, 0), floor_div, copysign(zero_tensor, basic_div))

    # If denominator is zero, then follow true_divide behavior
    return where(ne(b, 0), floor_div, basic_div)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def fmax(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.fmax(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def fmin(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.fmin(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=True,
)
def fmod(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.fmod(a, b)


@register_decomposition(aten.frexp)
@out_wrapper("mantissa", "exponent")
def frexp(self: TensorLikeType) -> Tuple[TensorLikeType, TensorLikeType]:
    return torch.return_types.frexp(prims.frexp(self))


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def gcd(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.gcd(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    supports_lhs_python_scalar=False,
)
def ge(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.ge(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    supports_lhs_python_scalar=False,
)
def gt(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.gt(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def heaviside(input: TensorLikeType, values: TensorLikeType) -> TensorLikeType:
    input_eq_zero = torch.eq(input, 0)
    input_lt_zero = torch.logical_or(torch.lt(input, 0), torch.isnan(input))
    zeros_and_ones = torch.where(input_lt_zero, 0, 1)
    output = torch.where(input_eq_zero, values, zeros_and_ones)
    return output


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def hypot(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.hypot(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def igamma(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.igamma(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def igammac(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.igammac(a, b)


def _check_close_args(
    name: str,
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float,
    atol: float,
) -> None:
    torch._check_value(
        a.dtype == b.dtype,
        lambda: f"{name}: Attempting to compare tensors of different dtypes {a.dtype} and {b.dtype}!",
    )
    torch._check(
        rtol >= 0,
        lambda: f"{name}: rtol must be greater than or equal to zero, but got {rtol}!",
    )
    torch._check(
        atol >= 0,
        lambda: f"{name}: atol must be greater than or equal to zero, but got {atol}!",
    )


# CompositeImplicitAutograd - don't register decomp
def isclose(
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> TensorLikeType:
    _check_close_args(name="torch.isclose", a=a, b=b, rtol=rtol, atol=atol)

    close = eq(a, b)
    if equal_nan and (utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype)):
        close = logical_or(close, logical_and(isnan(a), isnan(b)))

    # Note: In case of zero tolerances the closeness inequality degenerates to an equality check.
    # In this case, the short-circuit prevents false positives as detailed in the paragraph below.
    if atol == 0 and rtol == 0:
        return close

    # Note [closeness error computation]
    # atol and rtol are provided as doubles, so the computation
    # rtol * other will produce a float or complex tensor.
    # When the difference (self - other) is compared to it then the
    # tensor representing the difference will also be cast to float or complex.
    # However, since (self - other) in uint8 is very likely to produce a
    # negative value, this moves the cast forward so the difference is
    # always computed in a float or complex type.
    # If the values of the integer tensors cannot be exactly represented
    # by the default scalar type then this may cause an incorrect result.
    if not utils.is_float_dtype(a.dtype) and not utils.is_complex_dtype(a.dtype):
        a = prims.convert_element_type(a, torch.get_default_dtype())
        b = prims.convert_element_type(b, torch.get_default_dtype())

    allowed_error = add(atol, abs(mul(b, rtol)))
    actual_error = abs(sub(a, b))

    # Computes finite closeness
    result = logical_or(
        close, logical_and(isfinite(actual_error), le(actual_error, allowed_error))
    )

    return result
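

# Illustrative example (hypothetical helper, never called): isclose above
# implements |a - b| <= atol + rtol * |b|, with NaNs compared equal only when
# equal_nan=True.
def _isclose_example() -> None:
    a = torch.tensor([1.0, float("nan")])
    b = torch.tensor([1.0 + 1e-9, float("nan")])
    assert torch.isclose(a, b).tolist() == [True, False]
    assert torch.isclose(a, b, equal_nan=True).tolist() == [True, True]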


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def lcm(a: TensorLikeType, b: TensorLikeType):
    dtype = a.dtype
    # promoting to int32 to maintain 100% consistency with C++ and to
    # prevent overflow in case of int8 and int16
    promote_to_int = dtype in (torch.int8, torch.int16)
    if promote_to_int:
        a = prims.convert_element_type(a, torch.int32)
        b = prims.convert_element_type(b, torch.int32)

    g = torch.gcd(a, b)
    # Avoid division by zero in case gcd(0, 0) == 0
    g = torch.where(g == 0, 1, g)
    res = torch.abs(prims.div(a, g) * b)
    return res if not promote_to_int else prims.convert_element_type(res, dtype)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    supports_lhs_python_scalar=False,
)
def le(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.le(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def logaddexp(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    # Nb. this implementation does not distribute the gradients evenly when a == b
    mask = torch.real(a) >= torch.real(b)
    max_ = torch.where(mask, a, b)
    min_ = torch.where(mask, b, a)
    inf_mask = torch.logical_and(
        torch.logical_not(torch.isfinite(torch.real(a))), torch.real(a) == torch.real(b)
    )
    if utils.is_complex_dtype(a.dtype) or utils.is_complex_dtype(b.dtype):
        # are you wondering what this bunch of codes are for? edge cases!
        neg_min_mask = torch.real(min_) < 0
        inf_vals = torch.where(
            neg_min_mask, min_, torch.log(torch.exp(min_) + torch.exp(max_))
        )
        non_nan_vals = torch.where(
            inf_mask, inf_vals, max_ + torch.log1p(torch.exp(min_ - max_))
        )
        # the type for full_like does not include tensor yet
        nan_mask = torch.isnan(min_)
        return torch.where(nan_mask, complex(float("nan"), float("nan")), non_nan_vals)  # type: ignore[call-overload]
    else:
        return torch.where(inf_mask, a, max_ + torch.log1p(torch.exp(min_ - max_)))


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def logaddexp2(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    torch._check(
        not (utils.is_complex_dtype(a.dtype) or utils.is_complex_dtype(b.dtype)),
        lambda: "logaddexp2 doesn't support complex dtypes",
    )
    # Nb. this implementation does not distribute the gradients evenly when a == b
    mask = a >= b
    max_ = torch.where(mask, a, b)
    min_ = torch.where(mask, b, a)
    inf_mask = torch.logical_and(torch.isinf(a), a == b)
    inv_log_2 = 1.0 / math.log(2)
    result = max_ + torch.log1p(torch.exp2(min_ - max_)) * inv_log_2
    return torch.where(inf_mask, a, result)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)
def logical_and(a: TensorLikeType, b: TensorLikeType):
    if not utils.is_boolean_dtype(a.dtype):
        a = a != 0
    if not utils.is_boolean_dtype(b.dtype):
        b = b != 0
    return a & b


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
def logical_not(a: TensorLikeType):
    if not utils.is_boolean_dtype(a.dtype):
        return a == 0
    return ~a


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)
def logical_or(a: TensorLikeType, b: TensorLikeType):
    if not utils.is_boolean_dtype(a.dtype):
        a = a != 0
    if not utils.is_boolean_dtype(b.dtype):
        b = b != 0
    return bitwise_or(a, b)


# TODO: skip unnecessary conversion of long to float
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)
def logical_xor(a: TensorLikeType, b: TensorLikeType):
    if not utils.is_boolean_dtype(a.dtype):
        a = a != 0
    if not utils.is_boolean_dtype(b.dtype):
        b = b != 0
    return a ^ b
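

# Illustrative example (hypothetical helper, never called): unlike the
# bitwise_* references, the logical_* references above first map non-boolean
# inputs through `!= 0`, so any nonzero element counts as True.
def _logical_vs_bitwise_example() -> None:
    a = torch.tensor([0, 1, 2])
    b = torch.tensor([2, 0, 2])
    assert torch.logical_and(a, b).tolist() == [False, False, True]
    assert torch.bitwise_and(a, b).tolist() == [0, 0, 2]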


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    supports_lhs_python_scalar=False,
)
def lt(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.lt(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def maximum(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.maximum(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def minimum(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.minimum(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_two_python_scalars=True,
)
def mul(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.mul(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    supports_lhs_python_scalar=False,
)
def ne(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.ne(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def nextafter(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.nextafter(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def remainder(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.remainder(a, b)


# reverse sub
@register_decomposition(aten.rsub)
@out_wrapper()
def rsub(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
    alpha: NumberType = 1,
):
    if isinstance(a, Number):
        msg = "Received a Number for the first argument, but expected a Tensor"
        raise ValueError(msg)

    return torch.sub(b, a, alpha=alpha)


# TODO: consider refactoring this with add impl
# sub has its own implementation because it has an alpha argument
@register_decomposition(aten.sub)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "b"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def sub(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
    *,
    alpha: NumberType = 1,
):
    """
    Reference implementation of torch.sub
    """

    a, b = _maybe_broadcast(a, b)

    if isinstance(a, TensorLike) and isinstance(b, TensorLike):
        torch._check(
            not utils.is_boolean_dtype(a.dtype) and not utils.is_boolean_dtype(b.dtype),
            lambda: (
                "Subtraction, the `-` operator, with two bool tensors is not supported. "
                "Use the `^` or `logical_xor()` operator instead."
            ),
        )

    if alpha != 1:
        dtype = a.dtype if isinstance(a, TensorLike) else b.dtype  # type: ignore[union-attr]
        python_type = utils.dtype_to_type(dtype)
        if not utils.is_weakly_lesser_type(type(alpha), python_type):
            msg = f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!"
            raise ValueError(msg)
        if isinstance(b, torch.Tensor):
            b = prims.mul(b, alpha)
        else:
            # Careful not to use prims.mul if b is a scalar / symint:
            # prims.mul always returns a tensor,
            # which will mess with type promotion.
            b = b * alpha

    output = prims.sub(a, b)
    return handle_noncontiguous_outputs([a, b], output)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    name="true_divide",
    aten_op=None,  # CompositeImplicitAutograd
    supports_two_python_scalars=True,
)
def true_divide(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.div(a, b)


@register_decomposition(aten.xlogy)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "b"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def xlogy(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]):
    torch._check(
        isinstance(a, TensorLike) or isinstance(b, TensorLike),
        lambda: "Expected either argument a or b to be a Tensor",
    )

    # Operations like eq and log do not handle scalar values, so we convert them to scalar_tensors.
    if isinstance(b, TensorLike) and isinstance(a, Number):
        a = scalar_tensor(a, dtype=b.dtype, device=b.device)
    elif isinstance(a, TensorLike) and isinstance(b, Number):
        b = scalar_tensor(b, dtype=a.dtype, device=a.device)

    # mypy: expected "Tensor"
    assert isinstance(a, TensorLike)
    assert isinstance(b, TensorLike)
    rhs = torch.where(torch.eq(a, 0), 0, torch.mul(a, torch.log(b)))
    return torch.where(torch.isnan(b), float("nan"), rhs)
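

# Illustrative example (hypothetical helper, never called): xlogy above returns
# 0 wherever a == 0 (even where log(b) would be -inf or nan for b <= 0), while
# NaNs in b propagate to the output.
def _xlogy_example() -> None:
    a = torch.tensor([0.0, 0.0, 2.0])
    b = torch.tensor([0.0, float("nan"), 3.0])
    out = torch.xlogy(a, b)
    assert out[0] == 0.0
    assert torch.isnan(out[1])
    assert torch.isclose(out[2], torch.tensor(2.0 * math.log(3.0)))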


@_make_elementwise_binary_reference(
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    aten_op=None,  # CompositeImplicitAutograd
    supports_two_python_scalars=True,
)
def trunc_divide(
    a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
):
    dtype = utils.get_dtype(a)
    if utils.is_integer_dtype(dtype):
        return prims.div(a, b)

    return trunc(prims.div(a, b))


#
# Elementwise Ternary References
#


@register_decomposition(aten.addcdiv)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self", "tensor1", "tensor2"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def addcdiv(
    self: TensorLikeType,
    tensor1: TensorLikeType,
    tensor2: TensorLikeType,
    *,
    value: NumberType = 1,
) -> TensorLikeType:
    """
    Reference implementation of torch.addcdiv
    """
    if value is not None:
        dtype = self.dtype  # no scalars allowed, see add
        python_type = utils.dtype_to_type(dtype)
        torch._check_value(
            utils.is_weakly_lesser_type(type(value), python_type),
            lambda: f"value argument of type {type(value)} cannot be safely cast to type {python_type}!",
        )

    return self + value * tensor1 / tensor2


@register_decomposition(aten.addcmul)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self", "tensor1", "tensor2"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def addcmul(
    self: TensorLikeType,
    tensor1: TensorLikeType,
    tensor2: TensorLikeType,
    *,
    value: NumberType = 1,
) -> TensorLikeType:
    """
    Reference implementation of torch.addcmul
    """
    if value is not None:
        dtype = self.dtype  # no scalars allowed, see add
        python_type = utils.dtype_to_type(dtype)
        torch._check_value(
            utils.is_weakly_lesser_type(type(value), python_type),
            lambda: f"value argument of type {type(value)} cannot be safely cast to type {python_type}!",
        )

    return self + value * tensor1 * tensor2


@register_decomposition(aten.clamp)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "min", "max"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def clamp(
    a: TensorLikeType,
    min: Optional[TensorOrNumberLikeType] = None,
    max: Optional[TensorOrNumberLikeType] = None,
) -> TensorLikeType:
    # NOTE: grad behavior with implementation `where` is not consistent on `nan`
    if min is None and max is None:
        msg = "clamp called but both min and max are none!"
        raise ValueError(msg)
    if min is not None:
        a_isnan = torch.isnan(a)
        condition = torch.bitwise_or(torch.ge(a, min), a_isnan)  # type: ignore[arg-type]
        # we should also propagate `nan` coming from boundaries. However, that's
        # not necessary since `ge` already returns `False` when either operand is
        # `nan`. So the line below is redundant:
        #   `condition = bitwise_and(condition, bitwise_not(isnan(min)))`
        a = torch.where(condition, a, min)  # type: ignore[arg-type]
    if max is not None:
        a_isnan = torch.isnan(a)
        # same as above, no need to adjust `nan` from `max`
        condition = torch.bitwise_or(torch.le(a, max), a_isnan)  # type: ignore[arg-type]
        a = torch.where(condition, a, max)  # type: ignore[arg-type]

    return a


@register_decomposition(aten.clamp_min)
@out_wrapper()
def clamp_min(
    self: TensorLikeType,
    min: Optional[TensorOrNumberLikeType] = None,
) -> TensorLikeType:
    return torch.clamp(self, min=min)  # type: ignore[arg-type]


@register_decomposition(aten.clamp_max)
@out_wrapper()
def clamp_max(
    self: TensorLikeType,
    max: Optional[TensorOrNumberLikeType] = None,
) -> TensorLikeType:
    return torch.clamp(self, max=max)  # type: ignore[arg-type]
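

# Illustrative example (hypothetical helper, never called): clamp above keeps
# NaNs from the input (the a_isnan branch) and otherwise pins values into
# [min, max]; min and max may also be tensors that broadcast against a.
def _clamp_example() -> None:
    x = torch.tensor([-2.0, 0.5, float("nan"), 3.0])
    out = torch.clamp(x, min=0.0, max=1.0)
    assert out[:2].tolist() == [0.0, 0.5]
    assert torch.isnan(out[2])
    assert out[3] == 1.0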
1953 raise RuntimeError(msg) 1954 1955 return prims.copy_to(a, b) 1956 1957 1958@register_decomposition(aten.item) 1959def item(a: TensorLikeType) -> NumberType: 1960 if a.numel() != 1: 1961 msg = f"Can't convert a tensor with {a.numel()} elements to a number!" 1962 raise ValueError(msg) 1963 1964 # NOTE: explicit conversion is necessary for bool! 1965 # See https://github.com/pytorch/pytorch/issues/78071 1966 number_type = utils.dtype_to_type(a.dtype) 1967 return number_type(prims.item(a)) 1968 1969 1970# fast path when `to` returns an alias to input. This mimics the same function in aten 1971def _to_will_alias( 1972 a: TensorLikeType, 1973 device: Optional[DeviceLikeType] = None, 1974 dtype: Optional[torch.dtype] = None, 1975 copy: Optional[bool] = None, 1976 layout: Optional[torch.layout] = None, 1977 memory_format: Optional[torch.memory_format] = None, 1978 pin_memory: Optional[bool] = False, 1979 non_blocking: bool = False, # not using non_blocking 1980) -> bool: 1981 return ( 1982 not copy 1983 and (device is None or a.device == device) 1984 and (dtype is None or a.dtype == dtype) 1985 and (layout is None or a.layout == layout) 1986 # is_pinned issue #84925 1987 # and (pin_memory is None or pin_memory == a.is_pinned()) 1988 and ( 1989 memory_format is None 1990 or memory_format == torch.preserve_format 1991 or utils.is_contiguous_for_memory_format(a, memory_format=memory_format) 1992 ) 1993 ) 1994 1995 1996@singledispatch 1997def _to_dispatch(*args, **kwargs): 1998 raise NotImplementedError 1999 2000 2001@_to_dispatch.register 2002def _to_device( 2003 device: torch.device, 2004 dtype: torch.dtype, 2005 non_blocking: bool = False, 2006 copy: bool = False, 2007 memory_format: Optional[torch.memory_format] = None, 2008) -> Dict[str, Any]: 2009 kwargs = { 2010 "device": device, 2011 "dtype": dtype, 2012 "non_blocking": non_blocking, 2013 "copy": copy, 2014 "memory_format": memory_format, 2015 } 2016 return kwargs 2017 2018 2019@_to_dispatch.register 2020def _to_device_str( 2021 device: str, 2022 dtype: torch.dtype, 2023 non_blocking: bool = False, 2024 copy: bool = False, 2025 memory_format: Optional[torch.memory_format] = None, 2026) -> Dict[str, Any]: 2027 kwargs = { 2028 "device": torch.device(device), 2029 "dtype": dtype, 2030 "non_blocking": non_blocking, 2031 "copy": copy, 2032 "memory_format": memory_format, 2033 } 2034 return kwargs 2035 2036 2037@_to_dispatch.register 2038def _to_dtype( 2039 dtype: torch.dtype, 2040 non_blocking: bool = False, 2041 copy: bool = False, 2042 memory_format: Optional[torch.memory_format] = None, 2043) -> Dict[str, Any]: 2044 kwargs = { 2045 "dtype": dtype, 2046 "non_blocking": non_blocking, 2047 "copy": copy, 2048 "memory_format": memory_format, 2049 } 2050 return kwargs 2051 2052 2053@_to_dispatch.register 2054def _to_other( 2055 other: Tensor, 2056 non_blocking: bool = False, 2057 copy: bool = False, 2058 memory_format: Optional[torch.memory_format] = None, 2059) -> Dict[str, Any]: 2060 device = other.device 2061 dtype = other.dtype 2062 layout = other.layout 2063 # is_pinned issue #84925 2064 # pin_memory = other.is_pinned() 2065 kwargs = { 2066 "device": device, 2067 "dtype": dtype, 2068 "layout": layout, 2069 "non_blocking": non_blocking, 2070 "copy": copy, 2071 "memory_format": memory_format, 2072 } 2073 return kwargs 2074 2075 2076# remove to_kwargs that is already present in `a` 2077def _canonicalize_to_arguments(a: Tensor, to_kwargs: dict): 2078 options_to_check = ["dtype", "device", "layout", "memory_format"] 2079 # "device" option could 
be passed a str instead torch.device 2080 if "device" in to_kwargs and isinstance(to_kwargs["device"], str): 2081 to_kwargs["device"] = torch.device(to_kwargs["device"]) 2082 2083 for kw in options_to_check: 2084 if kw in to_kwargs: 2085 if ( 2086 (kw == "memory_format" and to_kwargs[kw] is torch.preserve_format) 2087 or ( 2088 kw == "device" 2089 and to_kwargs[kw].type == a.device.type 2090 and ( 2091 not to_kwargs[kw].index or to_kwargs[kw].index == a.device.index 2092 ) 2093 ) 2094 or ( 2095 getattr(a, kw, None) == to_kwargs[kw] 2096 ) # this also handles {"memory_format": None} 2097 ): 2098 to_kwargs.pop(kw) 2099 2100 2101def to(a: TensorLikeType, *args, **kwargs) -> TensorLikeType: 2102 # handled dispatch via positional arguments 2103 if len(args) != 0: 2104 kwargs = _to_dispatch(*args, **kwargs) 2105 2106 # TODO: is_pinned is not currently supported in refs or fake_tensor 2107 # https://github.com/pytorch/pytorch/issues/84925 2108 assert "pin_memory" not in kwargs 2109 _canonicalize_to_arguments(a, kwargs) 2110 2111 if _to_will_alias(a, **kwargs): 2112 return a 2113 2114 copy = kwargs.pop("copy") if "copy" in kwargs else False 2115 non_blocking = kwargs.pop("non_blocking") if "non_blocking" in kwargs else False 2116 2117 # short-circuit to `prims.convert_element_type` when `to` is just a dtype change 2118 if ( 2119 (copy or (kwargs.get("dtype", a.dtype) != a.dtype)) 2120 and (not non_blocking) 2121 and ("memory_format" not in kwargs) 2122 and ("device" not in kwargs) 2123 and ("layout" not in kwargs) 2124 # is_pinned issue #84925 2125 # and ("pin_memory" not in kwargs) 2126 ): 2127 return prims.convert_element_type(a, kwargs.get("dtype", a.dtype)) 2128 2129 result = torch.empty_like(a, **kwargs) 2130 # TODO: non_blocking should be handled by `copy_to` 2131 copy_to(result, a) 2132 return result 2133 2134 2135# 2136# Reduction references 2137# 2138 2139 2140def _reduction( 2141 a: TensorLikeType, 2142 prim: Callable, 2143 *, 2144 has_identity: bool = True, 2145 accepts_dim_tuple: bool = True, # to handle min/argmin that accept single dim only 2146 dims: Optional[DimsType] = None, 2147 keepdims: bool = False, 2148 dtype: Optional[torch.dtype] = None, # should be specified for ops that support it 2149 out: Optional[Tensor] = None, 2150 output_dtype_kind: REDUCTION_OUTPUT_TYPE_KIND, 2151) -> TensorLikeType: # it is usually SAME, but I want 2152 # ref writers to actually think about what to put here 2153 assert isinstance(a, TensorLike) 2154 if a.ndim > 64: 2155 raise RuntimeError( 2156 f"Received a tensor with {a.ndim} dimensions, but only tensors with up to 64 dims are supported!" 
2157 ) 2158 2159 if out is not None: 2160 assert isinstance(out, TensorLike) 2161 if dtype is not None: 2162 # TODO - this is true for eager mode currently, but it's wrong behavior for complex norms 2163 if dtype != out.dtype: 2164 raise RuntimeError( 2165 "dtype argument and out dtype must match in reduction" 2166 ) 2167 if not accepts_dim_tuple: 2168 assert dims is None or isinstance(dims, Dim) 2169 if isinstance(dims, Dim): 2170 dims = (dims,) # type: ignore[assignment] 2171 dims = utils.reduction_dims(a.shape, dims) 2172 if not has_identity: 2173 valid_shape = a.ndim == 0 or builtins.all(a.shape[i] for i in dims) 2174 if not valid_shape: 2175 raise RuntimeError( 2176 "reducing over zero-size dimension for reduction operation without identity" 2177 ) 2178 computation_dtype, result_dtype = utils.reduction_dtypes( 2179 a, output_dtype_kind, dtype 2180 ) 2181 a = _maybe_convert_to_dtype(a, computation_dtype) # type: ignore[method-assign] 2182 result = prim(a, dims) 2183 if keepdims: 2184 output_shape = [a.shape[i] if i not in dims else 1 for i in range(a.ndim)] 2185 broadcast_dims = [i for i in range(a.ndim) if i not in dims] 2186 result = prims.broadcast_in_dim(result, output_shape, broadcast_dims) 2187 2188 if out is not None: 2189 assert result_dtype is not None 2190 if dtype is not None and result_dtype != out.dtype: 2191 raise RuntimeError( 2192 "Expected the dtype of reduction result and out to match" 2193 ) 2194 out = _maybe_resize_out(out, result.shape) 2195 return _safe_copy_out(copy_from=result, copy_to=out) # type: ignore[arg-type] 2196 2197 if result.dtype != result_dtype and result_dtype is not None: 2198 result = prims.convert_element_type(result, result_dtype) 2199 2200 return result 2201 2202 2203def _make_copy_from_view(fn): 2204 """ 2205 Given a view function (e.g. torch.diagonal) generates its copy variant (e.g. 
torch.diagonal_copy) 2206 """ 2207 aten_fn = getattr(aten, fn.__name__) 2208 annotations = getattr(fn, "__annotations__", {}) 2209 fn = out_wrapper()(aten_fn) 2210 2211 @wraps(fn) 2212 def _fn(*args, out=None, **kwargs): 2213 result = fn(*args, out=out, **kwargs) 2214 if out is not None: 2215 return result 2216 2217 return pytree.tree_map( 2218 lambda x: x.clone(memory_format=torch.contiguous_format), 2219 result, 2220 ) 2221 2222 copy_name = f"{fn.__name__}_copy" 2223 _fn.__name__ = copy_name 2224 _fn.__annotations__.update(annotations) 2225 register_decomposition(getattr(aten, copy_name))(_fn) 2226 return _fn 2227 2228 2229@register_decomposition(aten.all) 2230@out_wrapper() 2231def all( 2232 a: TensorLikeType, 2233 dim: Optional[DimsType] = None, 2234 keepdim: bool = False, 2235) -> TensorLikeType: 2236 result = torch.logical_not(torch.any(torch.logical_not(a), dim, keepdim=keepdim)) 2237 2238 if a.dtype == torch.uint8: 2239 result = result.to(dtype=torch.uint8) 2240 2241 return result 2242 2243 2244@register_decomposition(aten.any) 2245@out_wrapper() 2246def any( 2247 a: TensorLikeType, 2248 dim: Optional[DimsType] = None, 2249 keepdim: bool = False, 2250) -> TensorLikeType: 2251 a_ = _maybe_convert_to_dtype(a, torch.bool) 2252 if isinstance(dim, (list, tuple)) and len(dim) == 0: 2253 result = a_.clone() 2254 else: 2255 result = a_.sum(dim=dim, keepdim=keepdim).ne(False) 2256 2257 # Preserves uint8 -- probably a legacy mask thing 2258 if a.dtype is torch.uint8: 2259 return prims.convert_element_type(result, torch.uint8) 2260 2261 return result 2262 2263 2264@register_decomposition([aten.sum.dim_IntList, aten.sum.IntList_out]) 2265def sum( 2266 a: TensorLikeType, 2267 dim: Union[Optional[int], Optional[List[int]]] = None, 2268 keepdim: bool = False, 2269 *, 2270 dtype: Optional[torch.dtype] = None, 2271 out: Optional[Tensor] = None, 2272) -> TensorLikeType: 2273 if dtype is None: 2274 if out is not None: 2275 dtype = out.dtype 2276 elif utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype): 2277 dtype = torch.int64 2278 else: 2279 dtype = a.dtype 2280 # reduces over all dimensions if dim=() is passed 2281 if dim == () or dim == []: 2282 dim = None 2283 return _reduction( 2284 a, 2285 prims.sum, 2286 dims=dim, 2287 keepdims=keepdim, 2288 dtype=dtype, 2289 out=out, 2290 output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME, 2291 ) 2292 2293 2294def sum_to_size( 2295 a: Tensor, 2296 *shape, 2297) -> Tensor: 2298 shape = utils.extract_shape_from_varargs(shape, validate=False) 2299 torch._check( 2300 utils.is_expandable_to(shape, a.shape), 2301 lambda: f'sum_to_size: size "{shape}" is not expandable to size "{a.shape}"', 2302 ) 2303 # In ATen scalar tensors are sent through sum and the result is returned as 2304 # type promoted 2305 if utils.is_same_shape(shape, a.shape) and len(shape) > 0: 2306 return prims.view_of(a) 2307 leading_dims = a.ndim - len(shape) 2308 reduce_dims = tuple(range(leading_dims)) + tuple( 2309 i 2310 for i in range(leading_dims, len(shape)) 2311 if shape[i - leading_dims] == 1 and a.shape[i] != 1 2312 ) 2313 return torch.sum(a, dim=reduce_dims, keepdim=True, dtype=None) 2314 2315 2316@register_decomposition(aten.prod) 2317def prod( 2318 a: TensorLikeType, 2319 dim: Union[Optional[int], Optional[List[int]]] = None, 2320 keepdim: bool = False, 2321 *, 2322 dtype=None, 2323 out: Optional[Tensor] = None, 2324) -> TensorLikeType: 2325 if dtype is None: 2326 if out is not None: 2327 dtype = out.dtype 2328 elif utils.is_boolean_dtype(a.dtype) or 
utils.is_integer_dtype(a.dtype): 2329 dtype = torch.int64 2330 else: 2331 dtype = a.dtype 2332 # reduces over all dimensions if dim=() is passed 2333 if dim == () or dim == []: 2334 dim = None 2335 return _reduction( 2336 a, 2337 prims.prod, 2338 dims=dim, 2339 keepdims=keepdim, 2340 dtype=dtype, 2341 out=out, 2342 output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME, 2343 ) 2344 2345 2346@register_decomposition(aten.amin) 2347def amin( 2348 a: TensorLikeType, 2349 dim: Optional[DimsType] = None, 2350 keepdim: bool = False, 2351 *, 2352 out: Optional[Tensor] = None, 2353) -> TensorLikeType: 2354 # reduces over all dimensions if dim=() is passed 2355 if dim == () or dim == []: 2356 dim = None 2357 2358 return _reduction( 2359 a, 2360 prims.amin, 2361 dims=dim, 2362 keepdims=keepdim, 2363 dtype=None, 2364 out=out, 2365 has_identity=False, 2366 output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME, 2367 ) 2368 2369 2370@register_decomposition(aten.amax) 2371def amax( 2372 a: TensorLikeType, 2373 dim: Optional[DimsType] = None, 2374 keepdim: bool = False, 2375 *, 2376 out: Optional[Tensor] = None, 2377) -> TensorLikeType: 2378 # reduces over all dimensions if dim=() is passed 2379 if dim == () or dim == []: 2380 dim = None 2381 2382 return _reduction( 2383 a, 2384 prims.amax, 2385 dims=dim, 2386 keepdims=keepdim, 2387 dtype=None, 2388 out=out, 2389 has_identity=False, 2390 output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME, 2391 ) 2392 2393 2394def _dim_var_dispatch(dim=None, unbiased=None): 2395 # There's the following overload of torch.var: 2396 # var(Tensor self, bool unbiased=True) -> (Tensor, Tensor) 2397 # We need to explicitly convert bool dims to unbiased arg 2398 if unbiased is None and isinstance(dim, bool): 2399 unbiased = dim 2400 dim = None 2401 return dim, unbiased 2402 2403 2404@register_decomposition(aten.var) 2405@out_wrapper() 2406def var( 2407 a: TensorLikeType, 2408 dim: Optional[DimsType] = None, 2409 unbiased: Optional[bool] = None, 2410 keepdim: bool = False, 2411 *, 2412 correction: Optional[NumberType] = None, 2413) -> TensorLikeType: 2414 dim, unbiased = _dim_var_dispatch(dim, unbiased) 2415 correction = utils.set_correction(unbiased, correction) 2416 # reduces over all dimensions if dim=() is passed 2417 if dim == () or dim == []: 2418 dim = None 2419 2420 result = _reduction( 2421 a, 2422 partial(prims.var, correction=correction), 2423 dims=dim, 2424 keepdims=keepdim, 2425 dtype=None, 2426 out=None, 2427 has_identity=True, 2428 output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT, 2429 ) 2430 return result 2431 2432 2433@register_decomposition(aten.std) 2434@out_wrapper() 2435def std( 2436 a: TensorLikeType, 2437 dim: Union[Optional[int], Optional[List[int]]] = None, 2438 unbiased: Optional[bool] = None, 2439 keepdim: bool = False, 2440 *, 2441 correction: Optional[NumberType] = None, 2442) -> TensorLikeType: 2443 dim, unbiased = _dim_var_dispatch(dim, unbiased) 2444 correction = utils.set_correction(unbiased, correction) 2445 2446 opmath_dtype, dtype = utils.reduction_dtypes( 2447 a, REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT 2448 ) 2449 a = _maybe_convert_to_dtype(a, opmath_dtype) 2450 a_var = torch.var(a, dim, correction=correction, keepdim=keepdim) 2451 a_std = torch.sqrt(a_var) 2452 assert dtype is not None 2453 return _maybe_convert_to_dtype(a_std, dtype) 2454 2455 2456@register_decomposition(aten.mean) 2457def mean( 2458 a: TensorLikeType, 2459 dim: Optional[DimsType] = None, 2460 keepdim: bool = False, 2461 *, 2462 dtype=None, 2463 out=None, 2464) -> 
TensorLikeType: 2465 # reduces over all dimensions if dim=() is passed 2466 if dim == () or dim == []: 2467 dim = None 2468 orig_dtype = dtype 2469 if dtype is None: 2470 dtype = a.dtype 2471 # can't use out wrapper because of this argument 2472 torch._check( 2473 out is None or out.dtype == dtype, 2474 lambda: f"Expected out tensor to have dtype {dtype}, but got {out.dtype} instead", 2475 ) 2476 result = _reduction( 2477 a, 2478 prims.sum, 2479 dims=dim, 2480 keepdims=keepdim, 2481 dtype=dtype, 2482 out=None, 2483 output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE, 2484 ) 2485 torch._check( 2486 utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype), 2487 lambda: ( 2488 f"mean(): could not infer output dtype. " 2489 f"{'Input' if orig_dtype is None else 'Optional'} dtype must be either " 2490 f"a floating point or complex dtype. Got: {dtype}" 2491 ), 2492 ) 2493 if isinstance(dim, Dim): 2494 dim = (dim,) # type: ignore[assignment] 2495 dims = utils.reduction_dims(a.shape, dim) # type: ignore[arg-type] 2496 nelem = 1 if a.ndim == 0 else reduce(operator.mul, (a.shape[i] for i in dims), 1) 2497 result = true_divide(result, nelem) 2498 result_dtype = a.dtype if dtype is None else dtype 2499 result = _maybe_convert_to_dtype(result, result_dtype) # type: ignore[method-assign] 2500 if out is not None: 2501 assert isinstance(out, TensorLike) 2502 out = _maybe_resize_out(out, result.shape) 2503 return _safe_copy_out(copy_from=result, copy_to=out) # type: ignore[arg-type] 2504 return result 2505 2506 2507@register_decomposition(aten.std_mean) 2508@out_wrapper("out0", "out1") 2509def std_mean( 2510 a: TensorLikeType, 2511 dim: Optional[DimsType] = None, 2512 *, 2513 unbiased: Optional[bool] = None, 2514 keepdim: bool = False, 2515 correction: Optional[NumberType] = None, 2516): 2517 dim, unbiased = _dim_var_dispatch(dim, unbiased) 2518 correction = utils.set_correction(unbiased, correction) 2519 opmath_dtype, dtype = utils.reduction_dtypes( 2520 a, REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT 2521 ) 2522 original_dtype = a.dtype 2523 a = _maybe_convert_to_dtype(a, opmath_dtype) 2524 a_var, a_mean = torch.var_mean(a, dim, correction=correction, keepdim=keepdim) 2525 a_std = torch.sqrt(a_var) 2526 assert dtype is not None 2527 return ( 2528 _maybe_convert_to_dtype(a_std, dtype), 2529 _maybe_convert_to_dtype(a_mean, original_dtype), 2530 ) 2531 2532 2533@register_decomposition(aten.var_mean) 2534@out_wrapper("out0", "out1") 2535def var_mean( 2536 a: TensorLikeType, 2537 dim: Optional[DimsType] = None, 2538 unbiased: Optional[bool] = None, 2539 keepdim: bool = False, 2540 *, 2541 correction: Optional[NumberType] = None, 2542): 2543 dim, unbiased = _dim_var_dispatch(dim, unbiased) 2544 v = var(a, dim, unbiased, keepdim, correction=correction) 2545 m = mean(a, dim, keepdim) 2546 return v, m 2547 2548 2549@register_decomposition(aten.addr) 2550@out_wrapper() 2551@elementwise_type_promotion_wrapper( 2552 type_promoting_args=("self", "vec1", "vec2"), 2553 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, 2554) 2555def addr( 2556 self: TensorLikeType, 2557 vec1: TensorLikeType, 2558 vec2: TensorLikeType, 2559 *, 2560 beta: NumberType = 1, 2561 alpha: NumberType = 1, 2562) -> TensorLikeType: 2563 torch._check( 2564 vec1.ndim == 1, 2565 lambda: f"addr: Expected 1-D argument vec1, but got {vec1.ndim}-D", 2566 ) 2567 torch._check( 2568 vec2.ndim == 1, 2569 lambda: f"addr: Expected 1-D argument vec2, but got {vec2.ndim}-D", 2570 ) 2571 for arg, arg_name in ((alpha, "alpha"), (beta, 
"beta")): 2572 if isinstance(arg, bool): 2573 torch._check( 2574 utils.is_boolean_dtype(self.dtype) 2575 and utils.is_boolean_dtype(vec1.dtype) 2576 and utils.is_boolean_dtype(vec2.dtype), 2577 lambda: f"Boolean {arg_name} only supported for Boolean results.", 2578 ) 2579 self = self.expand(vec1.shape[0], vec2.shape[0]) 2580 if utils.is_boolean_dtype(self.dtype): 2581 # Integers are accepted for booleans 2582 torch._check( 2583 is_weakly_lesser_type(type(beta), int), 2584 lambda: f"expected bool/int beta but got {type(beta)}", 2585 ) 2586 torch._check( 2587 is_weakly_lesser_type(type(alpha), int), 2588 lambda: f"expected bool/int alpha but got {type(beta)}", 2589 ) 2590 if not beta: 2591 return torch.outer(vec1, vec2) if alpha else torch.full_like(self, False) 2592 else: 2593 return torch.logical_or( 2594 self, 2595 torch.outer(vec1, vec2) if alpha else torch.full_like(self, False), 2596 ) 2597 else: 2598 torch._check( 2599 is_weakly_lesser_type(type(beta), dtype_to_type(self.dtype)), 2600 lambda: f"cannot safely convert {type(beta)} to {self.dtype}", 2601 ) 2602 torch._check( 2603 is_weakly_lesser_type(type(alpha), dtype_to_type(self.dtype)), 2604 lambda: f"cannot safely convert {type(alpha)} to {self.dtype}", 2605 ) 2606 if beta == 0: 2607 # This means NaNs from self are dropped if beta is zero 2608 return alpha * torch.outer(vec1, vec2) 2609 else: 2610 return beta * self + alpha * torch.outer(vec1, vec2) 2611 2612 2613# CompositeImplicitAutograd - don't register decomp 2614def atleast_1d( 2615 arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType 2616) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]: 2617 """Reference implementation of :func:`torch.atleast_1d`.""" 2618 if not args and isinstance(arg, collections.abc.Sequence): 2619 args_ = arg 2620 else: 2621 assert not isinstance(arg, collections.abc.Sequence) 2622 args_ = (arg,) + args 2623 res = tuple(a if a.ndim >= 1 else unsqueeze(a, 0) for a in args_) 2624 return res if len(res) > 1 else res[0] 2625 2626 2627# Helper function with assert to avoid MyPy error 2628# of incompatible type passed to unsqueeze 2629def _unsqueeze_atleast( 2630 at_least_fn: Callable, dim: int, arg: TensorLikeType 2631) -> TensorLikeType: 2632 arg_ = at_least_fn(arg) 2633 assert isinstance(arg_, TensorLike) 2634 return unsqueeze(arg_, dim) 2635 2636 2637# CompositeImplicitAutograd - don't register decomp 2638def atleast_2d( 2639 arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType 2640) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]: 2641 """Reference implementation of :func:`torch.atleast_2d`.""" 2642 if not args and isinstance(arg, collections.abc.Sequence): 2643 args_ = arg 2644 else: 2645 assert not isinstance(arg, collections.abc.Sequence) 2646 args_ = (arg,) + args 2647 unsqueeze_atleast_1d = partial(_unsqueeze_atleast, atleast_1d, 0) 2648 res = tuple(a if a.ndim >= 2 else unsqueeze_atleast_1d(a) for a in args_) 2649 return res if len(res) > 1 else res[0] 2650 2651 2652# CompositeImplicitAutograd - don't register decomp 2653def atleast_3d( 2654 arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType 2655) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]: 2656 """Reference implementation of :func:`torch.atleast_3d`.""" 2657 if not args and isinstance(arg, collections.abc.Sequence): 2658 args_ = arg 2659 else: 2660 assert not isinstance(arg, collections.abc.Sequence) 2661 args_ = (arg,) + args 2662 unsqueeze_atleast_2d = partial(_unsqueeze_atleast, atleast_2d, -1) 2663 res = 
tuple(a if a.ndim >= 3 else unsqueeze_atleast_2d(a) for a in args_) 2664 return res if len(res) > 1 else res[0] 2665 2666 2667def as_strided( 2668 a: TensorLikeType, 2669 size: ShapeType, 2670 stride: StrideType, 2671 storage_offset: Optional[int] = None, 2672) -> TensorLikeType: 2673 storage_offset_int = ( 2674 storage_offset if storage_offset is not None else a.storage_offset() 2675 ) 2676 return prims.as_strided(a, size, stride, storage_offset_int) 2677 2678 2679@register_decomposition(aten.as_strided_scatter) 2680@out_wrapper() 2681def as_strided_scatter( 2682 input: TensorLikeType, 2683 src: TensorLikeType, 2684 size: ShapeType, 2685 stride: StrideType, 2686 storage_offset: Optional[int] = None, 2687) -> TensorLikeType: 2688 storage_offset_int = 0 if storage_offset is None else storage_offset 2689 return prims.as_strided_scatter(input, src, size, stride, storage_offset_int) 2690 2691 2692def broadcast_shapes(*shapes) -> ShapeType: 2693 return torch.Size(_broadcast_shapes(*shapes)) 2694 2695 2696@aten.broadcast_tensors.default.py_impl(DispatchKey.CompositeImplicitAutograd) 2697@aten.broadcast_tensors.default.py_impl(DispatchKey.Meta) 2698def broadcast_tensors(*tensors) -> List[TensorLikeType]: 2699 if len(tensors) == 1 and not isinstance(tensors[0], Tensor): 2700 tensors = tensors[0] 2701 return list(_maybe_broadcast(*tensors, preserve_cpu_scalar_tensors=False)) 2702 2703 2704# CompositeImplicitAutograd - don't register decomp 2705def broadcast_to(a: TensorLikeType, size: ShapeType) -> TensorLikeType: 2706 start = len(size) - len(a.shape) 2707 dims = tuple(range(start, len(a.shape) + start)) 2708 return prims.broadcast_in_dim(a, size, dims) 2709 2710 2711@register_decomposition(aten.cat) 2712@out_wrapper() 2713@elementwise_type_promotion_wrapper( 2714 type_promoting_args=("tensors",), 2715 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH, 2716) 2717def cat(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType: 2718 def cat_compute_output_memory_format(inputs): 2719 format = None 2720 for t in inputs: 2721 f = utils.suggest_memory_format(t) 2722 if f == torch.contiguous_format: 2723 return f 2724 if format is not None and format != f: 2725 return torch.contiguous_format 2726 format = f 2727 assert format is not None 2728 return format 2729 2730 if len(tensors) == 0: 2731 msg = "cat expects at least one tensor, but received zero!" 2732 raise ValueError(msg) 2733 2734 for tensor in tensors: 2735 assert isinstance(tensor, TensorLike) 2736 2737 utils.check_same_device(*tensors, allow_cpu_scalar_tensors=False) 2738 2739 from torch.fx.experimental.symbolic_shapes import guard_size_oblivious 2740 2741 # This is a bit tricky. Naively, you would expect to just pick one 2742 # arbitrary tensor and check that all tensors match this tensor. However, 2743 # there is legacy behavior which says that if you have a 1-D empty tensor 2744 # (0,), this is permissible. So you can't assume that all the tensors 2745 # have same dimensionality, and you can't assume that the first tensor is 2746 # the correct stencil. 2747 # 2748 # We'll implement this in a few passes. First, we will try to infer the 2749 # ndim of the cat output. If this ndim != 1, then we know that all ndim = 2750 # 1 inputs must be empty, or are errors. If this ndim == 1, then life 2751 # is easy (the legacy special case coincides with regular handling). 
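    # Editor's note (illustrative addition, not from the original source): the
    # legacy rule described above means a 1-D zero-element tensor is accepted
    # regardless of the other inputs' dimensionality. For example,
    # torch.cat([torch.ones(2, 3), torch.empty(0)]) is valid and simply ignores
    # the empty tensor, while torch.cat([torch.ones(2, 3), torch.ones(1)])
    # raises a dimension-mismatch error.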
2752 # 2753 # NB: The regular implementation of cat just filters out empty inputs, 2754 # but we do it slightly different here for better handling for unbacked 2755 # SymInts 2756 2757 example = None 2758 for i, t in enumerate(tensors): 2759 if example is None: 2760 if t.ndim != 1: 2761 example = t 2762 else: 2763 if t.ndim != 1: 2764 torch._check( 2765 t.ndim == example.ndim, 2766 lambda: "Number of dimensions of tensors must match. " 2767 f"Expected {example.ndim}-D tensors, but got {t.ndim}-D for " 2768 f"tensor number {i} in the list", 2769 ) 2770 2771 if example is None: 2772 # example is None if everything is 1-D. If so, just arbitrarily pick 2773 # the first one 2774 example = tensors[0] 2775 2776 shape = example.shape 2777 filtered = [] 2778 for tensor_idx, tensor in enumerate(tensors): 2779 if len(shape) != len(tensor.shape): 2780 assert tensor.ndim == 1 # we've already checked this above 2781 # Don't suggest the legacy behavior in the error message 2782 torch._check( 2783 # NB: it is not enough to simply assert that tensor.shape[0] == 0; 2784 # this MUST be true even under guard size oblivious. 2785 # Effectively, we must actually know that the shape is zero, 2786 # passing an unbacked SymInt which we will defer a runtime 2787 # assert on won't cut it. This is a policy decision (size 2788 # oblivious semantics say that u0 tensors never are inferred 2789 # to be zero size, even if they must be that for the cat to go 2790 # through), and is load bearing for our Inductor lowerings 2791 # (which assume that size oblivious tests are OK to determine 2792 # if a shape is permissibly zero.) 2793 guard_size_oblivious(tensor.shape[0] == 0), 2794 lambda: f"Number of dimensions of tensors must match. " 2795 f"Expected {example.ndim}-D tensors, but got 1-D for " 2796 f"tensor number {tensor_idx} in the list", 2797 ) 2798 else: 2799 # Remove inputs that are 1-D, zero size 2800 if tensor.ndim == 1 and guard_size_oblivious(tensor.shape[0] == 0): 2801 continue 2802 # Don't bother checking size match, prims.cat will handle it 2803 filtered.append(tensor) 2804 2805 memory_format = cat_compute_output_memory_format(tensors) 2806 2807 if len(filtered) == 0: 2808 t = tensors[0] 2809 2810 # TODO: fix this to work with meta tensors 2811 try: 2812 # BUG? This looks like it wants to call builtins.any() but is 2813 # actually calling .any() (in this file). 
Changing to builtins.any() 2814 # causes tests to fail: 2815 # PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=4 python test/test_ops.py -k \ 2816 # TestFakeTensorCUDA.test_fake_crossref_backward_amp_cat_cuda_float32 2817 requires_grad = bool(any(x.requires_grad for x in tensors)) # type: ignore[arg-type] 2818 except Exception: 2819 requires_grad = False # type: ignore[assignment] 2820 2821 return empty( 2822 (0,), 2823 dtype=t.dtype, 2824 device=t.device, 2825 requires_grad=requires_grad, 2826 memory_format=memory_format, 2827 ) 2828 2829 dim = utils.canonicalize_dim(filtered[0].ndim, dim) 2830 utils.validate_idx(filtered[0].ndim, dim) 2831 2832 return prims.cat(filtered, dim).clone(memory_format=memory_format) 2833 2834 2835# CompositeImplicitAutograd - don't register decomp 2836@out_wrapper() 2837def column_stack(tensors: TensorSequenceType) -> TensorLikeType: 2838 aligned_tensors = tuple( 2839 x if x.ndim > 1 else x.reshape((x.numel(), 1)) for x in tensors 2840 ) 2841 return cat(aligned_tensors, 1) 2842 2843 2844def conj(input: TensorLikeType) -> TensorLikeType: 2845 if not utils.is_complex_dtype(input.dtype): 2846 return input 2847 if input.is_sparse: 2848 return torch.conj_physical(input) 2849 return prims.conj(input) 2850 2851 2852# This replicates at::constant_pad_nd, defined in ATen/native/PadNd.cpp 2853@register_decomposition(aten.constant_pad_nd) 2854@out_wrapper() 2855def constant_pad_nd( 2856 input: TensorLikeType, pad: List[int], value: NumberType = 0 2857) -> TensorLikeType: 2858 torch._check( 2859 len(pad) % 2 == 0, 2860 lambda: f"Length of pad must be even but instead it equals {len(pad)}", 2861 ) 2862 2863 input_sizes = input.shape 2864 l_inp = len(input_sizes) 2865 2866 l_pad = len(pad) // 2 2867 l_diff = l_inp - l_pad 2868 2869 torch._check( 2870 l_inp >= l_pad, 2871 lambda: "Length of pad should be no more than twice the number of " 2872 f"dimensions of the input. Pad length is {len(pad)} while the input has " 2873 f"{l_inp} dimensions.", 2874 ) 2875 2876 c_input = input 2877 for i in range(l_diff, l_inp): 2878 pad_idx = 2 * (l_inp - i - 1) 2879 if pad[pad_idx] < 0: 2880 c_input = c_input.narrow(i, -pad[pad_idx], c_input.shape[i] + pad[pad_idx]) 2881 2882 if pad[pad_idx + 1] < 0: 2883 c_input = c_input.narrow(i, 0, c_input.shape[i] + pad[pad_idx + 1]) 2884 2885 # If all the pads are negative we can return the result. 2886 # Avoid early exiting if all pads = 0 to prevent specialization on export. 2887 # During export, raw if statements are specialized on the input, meaning 2888 # that we lose a branch depending on the example input used to export. 2889 # Here, this is either the case where all pads = 0, or the case where at 2890 # least one pad > 0 and the rest are >= 0. 2891 # Avoiding the early exit when all pads = 0 ensures we can export 2892 # constant_pad_nd for cases when all pads >= 0. 2893 # Note: if any pads are negative, this code specializes due to the if statements above. 2894 if builtins.all(p < 0 for p in pad): 2895 return c_input.clone() 2896 2897 new_shape = list(input_sizes[:l_diff]) 2898 2899 for i in range(l_pad): 2900 pad_idx = len(pad) - ((i + 1) * 2) 2901 new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1] 2902 torch._check( 2903 new_dim > 0, 2904 lambda: f"The input size {input_sizes[l_diff + i]}, plus negative padding " 2905 f"{pad[pad_idx]} and {pad[pad_idx + 1]} resulted in a negative output size, " 2906 f"which is invalid. 
Check dimension {l_diff + i} of your input.", 2907 ) 2908 new_shape.append(new_dim) 2909 2910 memory_format = utils.suggest_memory_format(input) 2911 output = torch.empty( 2912 new_shape, 2913 dtype=input.dtype, 2914 device=input.device, 2915 requires_grad=input.requires_grad, 2916 memory_format=memory_format, 2917 ) 2918 2919 if value == 0 and input.dtype == torch.bool: 2920 value = False 2921 # torch.fill isn't typed to allow complex values 2922 output = torch.fill(output, value) # type: ignore[arg-type] 2923 2924 c_output = output 2925 for i in range(l_diff, l_inp): 2926 pad_idx = 2 * (l_inp - i - 1) 2927 if pad[pad_idx] >= 0: 2928 c_output = c_output.narrow( 2929 i, pad[pad_idx], c_output.shape[i] - pad[pad_idx] 2930 ) 2931 if pad[pad_idx + 1] >= 0: 2932 c_output = c_output.narrow(i, 0, c_output.shape[i] - pad[pad_idx + 1]) 2933 2934 prims.copy_to(c_output, c_input) 2935 return output 2936 2937 2938def contiguous( 2939 a: Tensor, *, memory_format: torch.memory_format = torch.contiguous_format 2940) -> Tensor: 2941 torch._check( 2942 memory_format != torch.preserve_format, 2943 lambda: "preserve memory format is unsupported by the contiguous operator", 2944 ) 2945 2946 if utils.is_contiguous_for_memory_format(a, memory_format=memory_format): 2947 return a 2948 2949 return torch.clone(a, memory_format=memory_format) 2950 2951 2952@out_wrapper() 2953def dstack(tensors: TensorSequenceType) -> TensorLikeType: 2954 torch._check(len(tensors) > 0, lambda: "dstack expects a non-empty TensorList") 2955 aligned_tensors = atleast_3d(*tensors) 2956 return cat(aligned_tensors, 2) 2957 2958 2959@register_decomposition(aten.expand) 2960def expand(a: Tensor, *shape) -> Tensor: 2961 from torch.fx.experimental.symbolic_shapes import guard_size_oblivious 2962 2963 # NOTE: cannot use utils.extract_shape_from_varargs here 2964 # because that also validates the shape, but the shape 2965 # given to expand may be "invalid" 2966 if len(shape) == 1 and isinstance(shape[0], Sequence): 2967 shape = tuple(shape[0]) 2968 2969 torch._check( 2970 len(shape) >= len(a.shape), 2971 lambda: "expand: the requested shape has too few dimensions!", 2972 ) 2973 2974 offset = len(shape) - len(a.shape) 2975 shape_ = list(shape) 2976 for idx, x in enumerate(a.shape): 2977 offset_idx = idx + offset 2978 requested_length = shape[offset_idx] 2979 torch._check( 2980 guard_size_oblivious(requested_length == x) 2981 or guard_size_oblivious(x == 1) 2982 or requested_length == -1, 2983 lambda: f"expand: attempting to expand a dimension of length {x}!", 2984 ) 2985 2986 shape_[offset_idx] = requested_length if requested_length != -1 else x 2987 2988 # At this point shape must be valid 2989 utils.validate_shape(shape_) 2990 2991 return prims.broadcast_in_dim( 2992 a, shape_, tuple(range(offset, len(a.shape) + offset)) 2993 ) 2994 2995 2996# CompositeImplicitAutograd - don't register decomp 2997def expand_as(a: Tensor, b: Tensor) -> Tensor: 2998 return a.expand(b.shape) 2999 3000 3001def chunk(a: TensorLikeType, chunks: int, dim: int = 0) -> Tuple[TensorLikeType, ...]: 3002 if chunks <= 0: 3003 msg = f"Expected at least one chunk, but got {chunks}!" 
3004 raise ValueError(msg) 3005 3006 dim = utils.canonicalize_dim(a.ndim, dim) 3007 length = a.shape[dim] 3008 chunk_size = math.ceil(length / chunks) 3009 full_chunks = math.floor(length / chunk_size) 3010 tail_chunk_size = length % chunk_size 3011 3012 result = [] 3013 for i in range(full_chunks): 3014 result.append(narrow(a, dim, i * chunk_size, chunk_size)) 3015 3016 if tail_chunk_size != 0: 3017 result.append(narrow(a, dim, full_chunks * chunk_size, tail_chunk_size)) 3018 3019 return tuple(result) 3020 3021 3022# Note: flatten, unlike other shape operators, returns the input tensor on a no-op (unless 3023# a 0D tensor is flattened, in which case it's returned in 1D) 3024# CompositeImplicitAutograd - don't register decomp 3025def flatten(a: TensorLikeType, start_dim: int = 0, end_dim: int = -1) -> TensorLikeType: 3026 start_dim = utils.canonicalize_dim(a.ndim, start_dim) 3027 end_dim = utils.canonicalize_dim(a.ndim, end_dim) 3028 3029 # Short-circuits on no-op 3030 if start_dim == end_dim and a.ndim != 0: 3031 return a 3032 3033 # Tries to take a view 3034 # TODO: we could look at directing collapse_view to skip its meta function here (unsafe_collapse_view) 3035 new_shape, new_strides = prims._collapse_view_helper(a, start_dim, end_dim) 3036 if new_shape is not None: 3037 return prims.collapse_view(a, start_dim, end_dim) 3038 3039 # Makes a copy if it can't make a view 3040 return prims.collapse(a, start_dim, end_dim) 3041 3042 3043@register_decomposition(aten.flip) 3044@out_wrapper() 3045def flip(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType: 3046 if not isinstance(dims, tuple) and not isinstance(dims, list): 3047 raise ValueError("dims has to be a sequence of ints") 3048 dims = utils.canonicalize_dims(a.ndim, dims) # type: ignore[assignment] 3049 utils.validate_no_repeating_dims(dims) 3050 return prims.rev(a, dims) 3051 3052 3053# CompositeImplicitAutograd - don't register decomp 3054def fliplr(a: TensorLikeType) -> TensorLikeType: 3055 if a.ndim < 2: 3056 raise RuntimeError("Input must be >= 2-d.") 3057 3058 return flip(a, (1,)) 3059 3060 3061# CompositeImplicitAutograd - don't register decomp 3062def flipud(a: TensorLikeType) -> TensorLikeType: 3063 if a.ndim < 1: 3064 raise RuntimeError("Input must be >= 1-d.") 3065 3066 return flip(a, (0,)) 3067 3068 3069# CompositeImplicitAutograd - don't register decomp 3070def narrow( 3071 a: TensorLikeType, dim: int, start: Union[int, TensorLikeType], length: int 3072) -> TensorLikeType: 3073 # Supports Tensor overload that was added for XLA: 3074 # https://github.com/pytorch/pytorch/issues/31558 3075 if isinstance(start, TensorLike): 3076 torch._check( 3077 start.dim() == 0 and utils.is_integer_dtype(start.dtype), 3078 lambda: "start must be an 0-dim integral Tensor.", 3079 ) 3080 start = start.item() # type: ignore[assignment] 3081 torch._check(a.dim() > 0, lambda: "narrow() cannot be applied to a 0-dim tensor.") 3082 torch._check(length >= 0, lambda: "narrow(): length must be non-negative.") 3083 dim = utils.canonicalize_dim(a.ndim, dim) 3084 dim_length = a.size(dim) 3085 torch._check_with( 3086 IndexError, 3087 -dim_length <= start and start <= dim_length, # type: ignore[arg-type] 3088 lambda: f"start out of range (expected to be in range of [{-dim_length}, {dim_length}], but got {start})", 3089 ) 3090 if start < 0: 3091 start = start + dim_length 3092 torch._check( 3093 start <= dim_length - length, # type: ignore[arg-type] 3094 lambda: f"start ({start}) + length ({length}) exceeds dimension size ({dim_length}).", 3095 ) 
3096 return prims.slice_in_dim(a, start, start + length, axis=dim) 3097 3098 3099def _normalize( 3100 a: Tensor, norm_dims: DimsType, eps: float 3101) -> Tuple[Tensor, Tensor, Tensor]: 3102 """Computes mean and 1/std of a tensor along norm_dims. 3103 3104 Used as a helper function for normalization layers. 3105 3106 Args: 3107 a (Tensor): input tensor 3108 norm_dims (DimsType): dimensions to normalize over 3109 eps (float): epsilon for numerical stability 3110 3111 Returns: 3112 out (Tensor): normalized tensor. 3113 mean (Tensor): mean of the tensor along norm_dims. 3114 rstd (Tensor): 1/std of the tensor along norm_dims. 3115 """ 3116 norm_dims = utils.canonicalize_dims(a.ndim, norm_dims) 3117 computation_dtype = utils.get_computation_dtype(a.dtype) 3118 a_acc = _maybe_convert_to_dtype(a, computation_dtype) 3119 assert isinstance(a_acc, TensorLike) # to avoid mypy error for var_mean 3120 biased_var, mean = torch.var_mean( 3121 a_acc, dim=norm_dims, unbiased=False, keepdim=True 3122 ) 3123 rstd = torch.rsqrt(biased_var + eps) 3124 out = (a - mean) * rstd 3125 return out, mean, rstd 3126 3127 3128# add all specified dimensions 3129def _unsqueeze_multiple(x: TensorLikeType, dimensions: List[int]) -> TensorLikeType: 3130 for dim in sorted(dimensions): 3131 x = torch.unsqueeze(x, dim) 3132 return x 3133 3134 3135@register_decomposition(aten.native_group_norm.default) 3136def native_group_norm( 3137 input: Tensor, 3138 weight: Optional[Tensor], 3139 bias: Optional[Tensor], 3140 batch_size: int, 3141 num_channels: int, 3142 flattened_inner_size: int, 3143 num_groups: int, 3144 eps: float, 3145) -> Tuple[Tensor, Tensor, Tensor]: 3146 torch._check( 3147 input.ndim >= 2, 3148 lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}", 3149 ) 3150 torch._check( 3151 num_channels % num_groups == 0, 3152 lambda: "Expected number of channels in input to be divisible by num_groups, " 3153 + f"but got input of shape {input.shape} and num_groups = {num_groups}", 3154 ) 3155 3156 # num_channels / num_groups and flattened inner dimension are the reduction axes 3157 reduction_dims = [2, 3] 3158 input_reshaped = torch.reshape( 3159 input, 3160 [batch_size, num_groups, num_channels // num_groups, flattened_inner_size], 3161 ) 3162 out, mean, rstd = _normalize(input_reshaped, reduction_dims, eps) 3163 out = out.view(input.shape) 3164 3165 broadcast_dims = [0] + list(range(2, input.ndim)) 3166 unsqueeze_bias = None 3167 if bias is not None: 3168 unsqueeze_bias = _unsqueeze_multiple(bias, broadcast_dims) 3169 unsqueeze_weight = None 3170 if weight is not None: 3171 unsqueeze_weight = _unsqueeze_multiple(weight, broadcast_dims) 3172 3173 if unsqueeze_weight is not None: 3174 out = out * unsqueeze_weight 3175 if unsqueeze_bias is not None: 3176 out = out + unsqueeze_bias 3177 3178 out = _maybe_convert_to_dtype(out, input.dtype) # type: ignore[assignment] 3179 mean = _maybe_convert_to_dtype(mean, input.dtype) # type: ignore[assignment] 3180 rstd = _maybe_convert_to_dtype(rstd, input.dtype) # type: ignore[assignment] 3181 3182 # remove broadcast dimensions from mean and rstd 3183 mean = torch.squeeze(mean, reduction_dims) 3184 rstd = torch.squeeze(rstd, reduction_dims) 3185 return (out, mean, rstd) 3186 3187 3188@register_decomposition(aten.native_layer_norm) 3189@out_wrapper("out0", "out1", "out2") 3190def native_layer_norm( 3191 input: Tensor, 3192 normalized_shape: ShapeType, 3193 weight: Optional[Tensor], 3194 bias: Optional[Tensor], 3195 eps: float, 3196) -> Tuple[Tensor, Tensor, 
Tensor]: 3197 normalized_ndim = len(normalized_shape) 3198 torch._check( 3199 normalized_ndim >= 1, 3200 lambda: "Expected normalized_shape to be at least 1-dimensional, i.e., " 3201 + "containing at least one element, but got normalized_shape = " 3202 + str(normalized_shape), 3203 ) 3204 # torch.Size([1, 2, 3]) == [1, 2, 3] evaluates to False 3205 # while torch.Size([1, 2, 3]) == (1, 2, 3) is True 3206 # therefore we use tuple(normalized_shape) 3207 torch._check( 3208 weight is None or weight.shape == tuple(normalized_shape), 3209 lambda: "Expected weight to be of same shape as normalized_shape, but got " 3210 + "weight of shape " 3211 + str(weight.shape) # type: ignore[union-attr] 3212 + " and normalized_shape = " 3213 + str(normalized_shape), 3214 ) 3215 torch._check( 3216 bias is None or bias.shape == tuple(normalized_shape), 3217 lambda: "Expected bias to be of same shape as normalized_shape, but got " 3218 + "bias of shape " 3219 + str(bias.shape) # type: ignore[union-attr] 3220 + " and normalized_shape = " 3221 + str(normalized_shape), 3222 ) 3223 torch._check( 3224 input.ndim >= normalized_ndim 3225 and input.shape[(input.ndim - normalized_ndim) :] == tuple(normalized_shape), 3226 lambda: "Given normalized_shape=" 3227 + str(normalized_shape) 3228 + ", expected input with shape " 3229 + str(normalized_shape) 3230 + ", but got input of size " 3231 + str(input.shape), 3232 ) 3233 3234 input = input.contiguous() 3235 if weight is not None: 3236 weight = weight.contiguous() 3237 if bias is not None: 3238 bias = bias.contiguous() 3239 3240 axis = input.ndim - normalized_ndim 3241 reduction_dims = list(range(axis, input.ndim)) 3242 out, mean, rstd = _normalize(input, reduction_dims, eps) 3243 3244 if weight is None and bias is not None: 3245 out = out + bias 3246 elif weight is not None and bias is None: 3247 out = out * weight 3248 elif weight is not None and bias is not None: 3249 out = out * weight + bias 3250 3251 out = _maybe_convert_to_dtype(out, input.dtype) # type: ignore[assignment] 3252 if input.device.type == "cpu": 3253 mean = _maybe_convert_to_dtype(mean, input.dtype) # type: ignore[assignment] 3254 rstd = _maybe_convert_to_dtype(rstd, input.dtype) # type: ignore[assignment] 3255 return (out, mean, rstd) 3256 3257 3258# TODO: Adding this as a meta function causes functorch tests to fail when compiled with debug mode. 
3259# test/test_eager_transforms.py::TestFunctionalizeCPU::test_functionalize_fx_transpose_simple_cpu 3260@register_decomposition(aten.permute) 3261def permute(a: TensorLikeType, *dims) -> TensorLikeType: 3262 _permutation = utils.canonicalize_dims( 3263 a.ndim, utils.extract_dims_from_varargs(dims) 3264 ) 3265 return prims.transpose(a, _permutation) 3266 3267 3268@register_decomposition(aten.renorm) 3269@out_wrapper() 3270def renorm( 3271 input: TensorLikeType, p: RealNumberType, dim: int, maxnorm: RealNumberType 3272) -> TensorLikeType: 3273 torch._check(not isinstance(p, complex), lambda: "renorm: p must be real-valued") 3274 torch._check(p > 0, lambda: "renorm: non-positive norm not supported") 3275 torch._check( 3276 not isinstance(maxnorm, complex), lambda: "renorm: maxnorm must be real-valued" 3277 ) 3278 torch._check( 3279 maxnorm >= 0, lambda: f"renorm: expected maxnorm to be >= 0 but got {maxnorm}" 3280 ) 3281 ndim = input.ndim 3282 torch._check( 3283 ndim > 1, 3284 lambda: f"renorm: input needs at least 2 dimensions, got {ndim} dimensions", 3285 ) 3286 3287 dim = utils.canonicalize_dim(ndim, dim) 3288 reduce_dims = list(range(ndim)) 3289 del reduce_dims[dim] 3290 3291 # For half and bfloat16, calculate norm in float precision then cast 3292 # normalization factor to half 3293 acc_type = utils.get_computation_dtype(input.dtype) 3294 if acc_type != input.dtype: 3295 norm = torch.linalg.vector_norm( 3296 input, p, reduce_dims, keepdim=True, dtype=acc_type 3297 ) 3298 else: 3299 norm = torch.linalg.vector_norm(input, p, reduce_dims, keepdim=True) 3300 3301 eps = 1e-7 3302 norm_factor = torch.where(norm > maxnorm, maxnorm / (norm + eps), 1.0) 3303 if acc_type != input.dtype: 3304 norm_factor = prims.convert_element_type(norm_factor, input.dtype) 3305 return (input * norm_factor).contiguous() 3306 3307 3308# CompositeImplicitAutograd - don't register decomp 3309@aten.stft.center.py_impl(DispatchKey.CompositeImplicitAutograd) 3310def stft( 3311 input: Tensor, 3312 n_fft: int, 3313 hop_length: Optional[int] = None, 3314 win_length: Optional[int] = None, 3315 window: Optional[Tensor] = None, 3316 center: bool = True, 3317 pad_mode: str = "reflect", 3318 normalized: bool = False, 3319 onesided: Optional[bool] = None, 3320 return_complex: Optional[bool] = None, 3321) -> Tensor: 3322 torch._check( 3323 window is None or window.device == input.device, 3324 lambda: ( 3325 f"stft input and window must be on the same device but got self on {input.device}" 3326 + f" and window on {window.device}" # type: ignore[union-attr] 3327 ), 3328 ) 3329 3330 hop_length_ = hop_length if hop_length is not None else n_fft // 4 3331 win_length_ = win_length if win_length is not None else n_fft 3332 3333 if return_complex is None: 3334 return_complex_ = input.is_complex() or ( 3335 window is not None and utils.is_complex_dtype(window.dtype) 3336 ) 3337 torch._check( 3338 return_complex_, 3339 ( 3340 "stft requires the return_complex parameter be given for real inputs, " 3341 + "and will further require that return_complex=True in a future PyTorch release." 
3342 ), 3343 ) 3344 else: 3345 return_complex_ = return_complex 3346 3347 torch._check( 3348 utils.is_float_dtype(input.dtype) or utils.is_complex_dtype(input.dtype), 3349 lambda: "stft expected a tensor of floating point or complex values", 3350 ) 3351 torch._check(1 <= input.ndim <= 2, lambda: "stft expected a 1D or 2D tensor") 3352 3353 original_ndim = input.ndim 3354 if original_ndim == 1: 3355 input = input.unsqueeze(0) 3356 3357 if center: 3358 extra_dims = 3 - input.ndim 3359 pad_amount = n_fft // 2 3360 extended_shape = [*itertools.repeat(1, extra_dims), *input.shape] 3361 input = aten.pad(input.view(extended_shape), [pad_amount, pad_amount], pad_mode) 3362 input = input.view(input.size()[extra_dims:]) 3363 3364 batch = input.size(0) 3365 length = input.size(1) 3366 torch._check( 3367 0 < n_fft <= length, 3368 lambda: f"stft expected 0 < n_fft <= {length}, but got n_fft={n_fft}", 3369 ) 3370 torch._check( 3371 hop_length_ > 0, 3372 lambda: f"stft expected hop_length > 0 but got hop_length={hop_length_}", 3373 ) 3374 torch._check( 3375 0 < win_length_ <= n_fft, 3376 lambda: f"stft expected 0 < win_length <= n_fft but got win_length={win_length_}", 3377 ) 3378 torch._check( 3379 window is None or window.shape == (win_length_,), 3380 lambda: ( 3381 f"expected a 1D window tensor of size equal to win_length={win_length_}, " 3382 + f"but got window with size {window.shape}" # type: ignore[union-attr] 3383 ), 3384 ) 3385 3386 if win_length_ < n_fft: 3387 if window is None: 3388 window = torch.ones(win_length_, dtype=input.dtype, device=input.device) 3389 left = (n_fft - win_length_) // 2 3390 window = aten.constant_pad_nd(window, [left, n_fft - win_length_ - left]) 3391 3392 input = input.unfold(dimension=-1, size=n_fft, step=hop_length_) 3393 if window is not None: 3394 input = input * window 3395 3396 complex_fft = utils.is_complex_dtype(input.dtype) 3397 onesided = onesided if onesided is not None else not complex_fft 3398 norm = "ortho" if normalized else None 3399 if onesided: 3400 torch._check( 3401 not complex_fft, 3402 lambda: "Cannot have onesided output if window or input is complex", 3403 ) 3404 out = torch.fft.rfft(input, dim=-1, norm=norm) 3405 else: 3406 out = torch.fft.fft(input, dim=-1, norm=norm) 3407 3408 out.transpose_(1, 2) 3409 3410 if original_ndim == 1: 3411 out = out.squeeze_(0) 3412 3413 return out if return_complex_ else torch.view_as_real(out) 3414 3415 3416# CompositeImplicitAutograd - don't register decomp 3417@aten.istft.default.py_impl(DispatchKey.CompositeImplicitAutograd) 3418def istft( 3419 input: Tensor, 3420 n_fft: int, 3421 hop_length: Optional[int] = None, 3422 win_length: Optional[int] = None, 3423 window: Optional[Tensor] = None, 3424 center: bool = True, 3425 normalized: bool = False, 3426 onesided: Optional[bool] = None, 3427 length: Optional[int] = None, 3428 return_complex=False, 3429) -> Tensor: 3430 torch._check( 3431 window is None or window.device == input.device, 3432 lambda: ( 3433 f"istft input and window must be on the same device but got self on {input.device}" 3434 + f" and window on {window.device}" # type: ignore[union-attr] 3435 ), 3436 ) 3437 3438 hop_length_ = hop_length if hop_length is not None else n_fft // 4 3439 win_length_ = win_length if win_length is not None else n_fft 3440 3441 torch._check( 3442 utils.is_complex_dtype(input.dtype), 3443 lambda: ( 3444 "istft expected a complex-valued input tensor, as produced by " 3445 + "stft(..., return_complex=True)" 3446 ), 3447 )
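    # Editor's note (illustrative addition, not from the original source): the
    # frame count and FFT size read below fix the overlap-add output length as
    # n_fft + hop_length_ * (n_frames - 1). For example, with n_fft=4,
    # hop_length_=2 and n_frames=3 the reconstruction spans 4 + 2 * (3 - 1) = 8
    # samples before the center/length trimming applied further down.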
3448 n_frames = input.size(-1) 3449 fft_size = input.size(-2) 3450 3451 expected_output_signal_len = n_fft + hop_length_ * (n_frames - 1) 3452 torch._check(input.numel() > 0, lambda: "istft input tensor cannot be empty") 3453 torch._check( 3454 2 <= input.ndim <= 3, 3455 lambda: f"istft expected a tensor with 2 or 3 dimensions, but got {input.ndim}", 3456 ) 3457 onesided_ = onesided if onesided is not None else fft_size != n_fft 3458 3459 if onesided_: 3460 torch._check( 3461 n_fft // 2 + 1 == fft_size, 3462 lambda: ( 3463 "istft expected the frequency dimension (3rd to the last) of the input tensor " 3464 + f"to match n_fft / 2 + 1 when onesided=True, but got {fft_size}" 3465 ), 3466 ) 3467 else: 3468 torch._check( 3469 n_fft == fft_size, 3470 lambda: ( 3471 "istft expected the frequency dimension (3rd to the last) of the input tensor " 3472 + f"to match n_fft when onesided=False, but got {fft_size}" 3473 ), 3474 ) 3475 3476 torch._check( 3477 0 < hop_length_ <= win_length_, 3478 lambda: "istft expected 0 < hop_length <= win_length", 3479 ) 3480 torch._check( 3481 0 < win_length_ <= n_fft, lambda: "istft expected 0 < win_length <= n_fft" 3482 ) 3483 torch._check( 3484 window is None or window.shape == (win_length_,), 3485 lambda: "Invalid window shape. window has to be 1D and length of `win_length`", 3486 ) 3487 3488 if window is None: 3489 real_dtype = utils.corresponding_real_dtype(input.dtype) 3490 window_ = torch.ones(win_length_, dtype=real_dtype, device=input.device) 3491 else: 3492 window_ = window 3493 3494 if win_length_ != n_fft: 3495 left = (n_fft - win_length_) // 2 3496 window_ = aten.constant_pad_nd(window_, (left, n_fft - win_length_ - left), 0) 3497 3498 original_ndim = input.ndim 3499 if input.ndim == 2: 3500 input = input.unsqueeze(0) 3501 3502 input = input.transpose(1, 2) 3503 norm = "ortho" if normalized else None 3504 if return_complex: 3505 torch._check( 3506 not onesided_, 3507 lambda: "cannot have onesided output if window or input is complex", 3508 ) 3509 input = torch.fft.ifft(input, dim=-1, norm=norm) 3510 else: 3511 torch._check( 3512 window is None or not utils.is_complex_dtype(window.dtype), 3513 lambda: "Complex windows are incompatible with return_complex=False", 3514 ) 3515 if not onesided_: 3516 input = input.narrow(dim=-1, start=0, length=n_fft // 2 + 1) 3517 input = torch.fft.irfft(input, dim=-1, norm=norm) 3518 3519 assert input.size(2) == n_fft 3520 3521 y_tmp = input * window_.view([1, 1, n_fft]) 3522 y = aten.unfold_backward( 3523 y_tmp, 3524 input_sizes=(y_tmp.size(0), expected_output_signal_len), 3525 dim=1, 3526 size=n_fft, 3527 step=hop_length_, 3528 ) 3529 window_envelop = aten.unfold_backward( 3530 window_.pow(2).expand((1, n_frames, n_fft)), 3531 input_sizes=(y_tmp.size(0), expected_output_signal_len), 3532 dim=1, 3533 size=n_fft, 3534 step=hop_length_, 3535 ) 3536 3537 assert expected_output_signal_len == y.size(1) 3538 assert expected_output_signal_len == window_envelop.size(1) 3539 3540 start = n_fft // 2 if center else 0 3541 if length is not None: 3542 end = start + length 3543 elif center: 3544 end = expected_output_signal_len - n_fft // 2 3545 else: 3546 end = expected_output_signal_len 3547 3548 length = max(0, end - start) 3549 y = y.narrow(dim=1, start=start, length=length) 3550 window_envelop = window_envelop.narrow(dim=1, start=start, length=length) 3551 3552 y = y / window_envelop 3553 if original_ndim == 2: 3554 y = y.squeeze(0) 3555 3556 if end > expected_output_signal_len: 3557 warnings.warn( 3558 "The length of signal is shorter
than the length parameter. Result is being " 3559 + "padded with zeros in the tail. Please check your center and hop_length settings" 3560 ) 3561 y = aten.constant_pad_nd(y, (0, end - expected_output_signal_len), 0) 3562 return y 3563 3564 3565# Get the new shape and stride after applying unfold to an input tensor 3566def _get_unfold_shape_stride( 3567 a_shape: ShapeType, a_stride: StrideType, dimension: int, size: int, step: int 3568): 3569 a_ndim = len(a_shape) 3570 dim = utils.canonicalize_dim(a_ndim, dimension, wrap_scalar=True) 3571 max_size = 1 if a_ndim == 0 else a_shape[dim] 3572 last_stride = 1 if a_ndim == 0 else a_stride[dim] 3573 3574 torch._check( 3575 size <= max_size, 3576 lambda: f"Maximum size for tensor at dimension {dim} is {max_size} but size is {size}", 3577 ) 3578 3579 torch._check( 3580 step > 0, 3581 lambda: f"Step is {step} but must be > 0", 3582 ) 3583 3584 shape = list(a_shape) 3585 strides = list(a_stride) 3586 shape.append(size) 3587 strides.append(last_stride) 3588 if dim < a_ndim: 3589 shape[dim] = (shape[dim] - size) // step + 1 3590 strides[dim] *= step 3591 return shape, strides 3592 3593 3594@register_decomposition(aten.repeat) 3595@out_wrapper() 3596def repeat(a: Tensor, *repeat_shape) -> Tensor: 3597 repeat_shape = utils.extract_shape_from_varargs(repeat_shape, validate=False) 3598 torch._check( 3599 len(repeat_shape) >= len(a.shape), 3600 lambda: "repeat: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor", 3601 ) 3602 3603 if len(repeat_shape) == 0: 3604 return torch.clone(a) 3605 3606 num_new_dimensions = len(repeat_shape) - a.ndim 3607 padded_shape = [1] * num_new_dimensions 3608 for dim_size in a.shape: 3609 padded_shape.append(dim_size) 3610 3611 target_shape = tuple( 3612 padded_size * repeat_size 3613 for padded_size, repeat_size in zip(padded_shape, repeat_shape) 3614 ) 3615 3616 # return an empty tensor if one of the repeat_shape dimensions is zero 3617 if 0 in repeat_shape: 3618 return torch.empty( 3619 target_shape, 3620 dtype=a.dtype, 3621 device=a.device, 3622 requires_grad=a.requires_grad, 3623 memory_format=utils.suggest_memory_format(a), 3624 ) 3625 3626 urtensor_shape = target_shape 3627 urtensor_stride = utils.make_contiguous_strides_for(target_shape) 3628 for dim, dim_size in enumerate(padded_shape): 3629 # repeat each dimension by using unfold_copy operation 3630 urtensor_shape, urtensor_stride = _get_unfold_shape_stride( 3631 urtensor_shape, urtensor_stride, dim, dim_size, max(dim_size, 1) 3632 ) 3633 3634 # derive permute order by sorting urtensor strides 3635 enumerated_stride = list(enumerate(urtensor_stride)) 3636 enumerated_stride.sort(key=operator.itemgetter(1), reverse=True) 3637 permute_order, sorted_stride = zip(*enumerated_stride) 3638 3639 # add new and expand dimensions according to urtensor 3640 repeat_xtensor = a.expand(urtensor_shape) 3641 3642 # clone tensor to concretize expanded dimensions 3643 cloned_result = torch.clone(repeat_xtensor) 3644 3645 # transpose axis so strides are in sorted order 3646 permuted_result = cloned_result.permute(permute_order) 3647 3648 # reshape to get contiguous tensor with correct target shape 3649 return permuted_result.reshape(target_shape) 3650 3651 3652def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType: 3653 from torch.fx.experimental.symbolic_shapes import guard_size_oblivious, sym_eq 3654 3655 # Creates a valid shape 3656 shape = utils.extract_shape_from_varargs(shape, validate=False) 3657 # Reshape may 
be given a shape with a -1 length 3658 # This indicates that the dimension's length should be inferred 3659 shape = utils.infer_size(shape, a.numel()) 3660 3661 # Special-cases tensors with no elements 3662 if guard_size_oblivious(a.numel() == 0): 3663 return as_strided(a, shape, utils.make_contiguous_strides_for(shape)) 3664 3665 # Special-cases reshaping zero dim tensors 3666 if a.ndim == 0: 3667 _a = a 3668 for length in shape: 3669 assert length == 1 3670 _a = unsqueeze(_a, -1) 3671 if _a is a: 3672 return prims.view_of(a) 3673 else: 3674 return _a 3675 3676 # Special-cases reshaping to zero dim tensors 3677 if len(shape) == 0: 3678 _a = a 3679 for length in a.shape: 3680 assert length == 1 3681 _a = squeeze(_a, -1) 3682 if _a is a: 3683 return prims.view_of(a) 3684 else: 3685 return _a 3686 3687 if a.is_contiguous(): 3688 # Special-cases for nd_to_1d 3689 if len(shape) == 1 and a.ndim > 1: 3690 return torch.as_strided(a, [a.numel()], [1]) 3691 # Special-cases for 1d_to_2d 3692 if len(shape) == 2 and a.ndim == 1: 3693 dim0 = shape[0] 3694 dim1 = shape[1] 3695 return torch.as_strided(a, [dim0, dim1], [dim1, 1]) 3696 3697 # Handles general case: a 1+D tensor reshaped into a distinct 1+D shape 3698 3699 # NOTE [Reshape Algorithm] 3700 # This algorithm works by attempting to greedily construct the desired dimensions in 3701 # the output shape, left to right. It does this by, conceptually, accumulating 3702 # dimensions of the original tensor, also left to right, until the dimension 3703 # can be constructed using prims.split_dim. 3704 # The algorithm also has special handling for tail squeezes/unsqueezes, like 3705 # if a reshape from (5, 5) to (5, 5, 1) or vice versa. 3706 # 3707 # This algorithm does not flatten the original tensor and then split dims as appropriate 3708 # because that would create copies more often than this algorithm. flatten is the only 3709 # operation below which can create a view or a copy, and while it prefers creating 3710 # views it may sometimes create a copy if the tensor's strides do not permit a view. 3711 # As a result, this algorithm tries to minimize flattening. 3712 # 3713 # Note that a better version of this algorithm may exist. Regions which could be 3714 # flattened without creating a copy can be identified in advance, and that might 3715 # allow fewer flatten calls or faster short-circuiting to make a copy. 
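    # NOTE (added commentary, illustration only; the concrete sizes are an assumed
    # example, not taken from the original): reshaping a contiguous (4, 6) tensor to
    # (8, 3) walks the loop below as follows:
    #   - target length 8: 8 does not divide the running size 4, so dim 1 is
    #     accumulated (4 * 6 = 24), dims 0..1 are flattened to (24,), and
    #     prims.split_dim produces (8, 3)
    #   - target length 3: the remaining dimension already has length 3, so it is skipped
    # No tail squeezes or unsqueezes are needed in this example.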
    idx = 0
    a_ = a
    for length in shape:
        # Handles tail unsqueezes
        if idx >= a_.ndim:
            assert length == 1
            last_dim = a_.ndim - 1
            # NOTE: using split_dim instead of unsqueeze may seem silly here,
            # but it's necessary to get the strides correct
            a_ = prims.split_dim(a_, last_dim, a_.shape[last_dim])
            idx = idx + 1
            continue

        # Skips dimensions that are already the correct length
        if guard_size_oblivious(length == a_.shape[idx]):
            idx = idx + 1
            continue

        # Gathers enough original dimensions such that this new dimension can be created
        # Note that this accumulation will terminate because we've verified a and the shape
        # specify the same number of elements above
        accum = a_.shape[idx]
        end = idx
        while guard_size_oblivious(accum % length != 0):
            end = end + 1
            accum = accum * a_.shape[end]
        if end != idx:
            # NOTE: in this case multiple dimensions must be flattened to create the desired dimension
            # This flattening is why reshape sometimes creates a copy -- because flattening
            # may return a view of a copy

            # Checks if collapse can be a view and short-circuits to copying reshape if it can't
            new_shape, new_strides = prims._collapse_view_helper(a_, idx, end)
            if new_shape is None:
                if allow_copy:
                    return prims.reshape(a, shape)

                msg = f"Cannot view a tensor with shape {a.shape} and strides {a.stride()} as a tensor with shape {shape}!"
                raise ValueError(msg)

            a_ = flatten(a_, idx, end)

        # Splits the (possibly flattened) dimension to create the desired dim length
        if guard_size_oblivious(accum != length):
            a_ = prims.split_dim(a_, idx, length)

        idx = idx + 1

    # Squeezes tail
    while idx < a_.ndim:
        torch._check(
            a_.shape[idx] == 1,
            lambda: f"a.size({idx}) expected to be 1 but got {a_.shape[idx]}",
        )
        a_ = squeeze(a_, idx)

    if a_ is a:
        return prims.view_of(a)
    else:
        return a_


# CompositeImplicitAutograd - don't register decomp
# NOTE: shape is a vararg because Tensor.reshape can be called as either
# Tensor.reshape(a, b, c) or Tensor.reshape((a, b, c)); the function call
# torch.reshape doesn't support unpacked shapes
def reshape(a: TensorLikeType, *shape: ShapeType) -> TensorLikeType:
    return _reshape_view_helper(a, *shape, allow_copy=True)


# CompositeImplicitAutograd - don't register decomp
def reshape_as(self: TensorLikeType, other: TensorLikeType) -> TensorLikeType:
    return self.reshape(other.size())


@register_decomposition(aten.roll)
@out_wrapper()
def roll(a: TensorLikeType, shifts: DimsType, dims: DimsType = ()) -> TensorLikeType:
    """Reference implementation of :func:`torch.roll`."""
    dims = utils.canonicalize_dims(a.ndim, dims)
    # ATen specifies int[1] type for shifts and dims which expands integers to tuples of length 1
    if not isinstance(shifts, Iterable):
        shifts = (shifts,)
    if not isinstance(dims, Iterable):
        dims = (dims,)

    # Avoid modulo by zero
    if a.numel() == 0:
        # Keeping this as ref for now as FakeTensor runs into some issues with complex tensors
        return a.clone()

    if a.dim() == 0 and len(dims) > 0:
        raise IndexError(
            f"Dimension specified as {dims[0]} but tensor has no dimensions"
        )

    len_shifts = len(shifts)
    len_dims = len(dims)
    if len_shifts != 1 or len_dims != 1:
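        # NOTE (added commentary, illustration only): when more than one dimension is
        # rolled, the branch below peels off the first (shift, dim) pair and recurses
        # on the rest. For example (assuming a 2D input), torch.roll(a, (1, 2), (0, 1))
        # is computed as torch.roll(torch.roll(a, (1,), 0), (2,), (1,)).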
3815 if len_shifts == 0: 3816 raise RuntimeError("`shifts` required") 3817 # Takes care of the case when dims is not specified (default) 3818 # By default, the tensor is flattened before shifting, after which the original shape is restored 3819 if len_dims == 0 and len_shifts == 1: 3820 return torch.roll(torch.flatten(a), shifts, 0).view(a.shape) 3821 if len_shifts != len_dims: 3822 raise RuntimeError( 3823 f"shifts and dimensions must align. shifts: {len_shifts}, dims: {len_dims}" 3824 ) 3825 assert len_dims > 1 3826 tail_shifts = shifts[1:] 3827 tail_dims = dims[1:] 3828 first_dim_rolled = torch.roll(a, (shifts[0],), dims[0]) 3829 return torch.roll(first_dim_rolled, tail_shifts, tail_dims) 3830 3831 # This path is taken when only one dimension is rolled 3832 # For example to get `first_dim_rolled` above 3833 dim = dims[0] 3834 size = a.shape[dim] 3835 start = (size - shifts[0]) % size 3836 idx = torch.arange(size, device=a.device) 3837 return a.index_select(dim, torch.fmod(start + idx, size)) 3838 3839 3840@register_decomposition(aten.rot90) 3841@out_wrapper() 3842def rot90( 3843 a: TensorLikeType, k: int = 1, dims: DimsSequenceType = (0, 1) 3844) -> TensorLikeType: 3845 """Reference implementation of :func:`torch.rot90`.""" 3846 if len(dims) != 2: 3847 raise RuntimeError( 3848 f"expected total rotation dims == 2, but got dims = {len(dims)}" 3849 ) 3850 if a.ndim < 2: 3851 raise RuntimeError(f"expected total dims >= 2, but got total dims = {a.ndim}") 3852 3853 # Do this after the initial checks to be compatible with the behavior in 3854 # core. 3855 dims = utils.canonicalize_dims(a.ndim, dims) 3856 3857 if dims[0] == dims[1]: 3858 raise RuntimeError( 3859 f"expected rotation dims to be different, but got dim0 = {dims[0]} and dim1 = {dims[1]}" 3860 ) 3861 k = k % 4 # Rotation direction is from the second towards the first axis for k < 0 3862 if k == 1: 3863 return torch.transpose(torch.flip(a, (dims[1],)), dims[0], dims[1]) 3864 elif k == 2: 3865 return torch.flip(a, dims) 3866 elif k == 3: 3867 return torch.transpose(torch.flip(a, (dims[0],)), dims[0], dims[1]) 3868 else: 3869 return a.clone(memory_format=torch.contiguous_format) 3870 3871 3872def _check_stack_inputs(tensors: TensorSequenceType) -> None: 3873 entry_shape = tensors[0].shape 3874 for i in range(1, len(tensors)): 3875 assert tensors[i].shape == entry_shape, ( 3876 f"stack expects each tensor to be equal size, but got {entry_shape} at entry 0 " 3877 f"and {tensors[i].shape} at entry {i}" 3878 ) 3879 3880 3881@register_decomposition(aten.stack) 3882@out_wrapper() 3883def stack(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType: 3884 assert len(tensors) > 0, "stack expects a non-empty TensorList" 3885 wrapped_dim = utils.canonicalize_dim(tensors[0].ndim + 1, dim) 3886 # Refs need sparse support to check other condition 3887 if wrapped_dim < tensors[0].ndim: # and not tensors[0].is_sparse: 3888 _check_stack_inputs(tensors) 3889 result_sizes = list(tensors[0].shape) 3890 result_sizes.insert(wrapped_dim, len(tensors)) 3891 out = torch.cat(tensors, wrapped_dim) 3892 return out.view(result_sizes) 3893 3894 # If dim == tensors[0].ndim, view cannot efficiently handle it 3895 return torch.cat([t.unsqueeze(wrapped_dim) for t in tensors], dim) 3896 3897 3898# CompositeImplicitAutograd - don't register decomp 3899@out_wrapper() 3900def softmax( 3901 a: TensorLikeType, 3902 dim: int, 3903 dtype: Optional[torch.dtype] = None, 3904) -> TensorLikeType: 3905 result_dtype = dtype or a.dtype 3906 computation_dtype = 
utils.get_computation_dtype(result_dtype) 3907 a_ = _maybe_convert_to_dtype(a, computation_dtype) 3908 if a.numel() == 0: 3909 a_exp = exp(a_) 3910 else: 3911 a_max = amax(a_, dim, keepdim=True) 3912 a_exp = exp(a_ - a_max) 3913 return _maybe_convert_to_dtype( 3914 true_divide(a_exp, sum(a_exp, dim, keepdim=True)), result_dtype 3915 ) # type: ignore[return-value] 3916 3917 3918# CompositeImplicitAutograd - don't register decomp 3919@out_wrapper() 3920def hstack(tensors: TensorSequenceType) -> TensorLikeType: 3921 torch._check(len(tensors) > 0, lambda: "hstack expects a non-empty TensorList") 3922 aligned_tensors = atleast_1d(*tensors) 3923 if aligned_tensors[0].ndim == 1: 3924 return cat(aligned_tensors, 0) 3925 return cat(aligned_tensors, 1) 3926 3927 3928# CompositeImplicitAutograd - don't register decomp 3929@out_wrapper() 3930def vstack(tensors: TensorSequenceType) -> TensorLikeType: 3931 torch._check(len(tensors) > 0, lambda: "vstack expects a non-empty TensorList") 3932 aligned_tensors = atleast_2d(*tensors) 3933 return cat(aligned_tensors, 0) 3934 3935 3936# CompositeImplicitAutograd - don't register decomp 3937def unflatten(a: TensorLikeType, dim: int, sizes: ShapeType) -> TensorLikeType: 3938 dim = utils.canonicalize_dim(a.ndim, dim) 3939 torch._check(len(sizes) != 0, lambda: "unflatten: sizes must be non-empty") 3940 return a.view(tuple(a.shape[:dim]) + tuple(sizes) + tuple(a.shape[dim + 1 :])) 3941 3942 3943@register_decomposition(aten.unbind) 3944def unbind(t: TensorLikeType, dim: int = 0) -> TensorSequenceType: 3945 from torch.fx.experimental.symbolic_shapes import guard_size_oblivious 3946 3947 dim = utils.canonicalize_dim(t.ndim, dim) 3948 torch._check_index( 3949 len(t.shape) > 0, 3950 lambda: "Dimension specified as 0 but tensor has no dimensions", 3951 ) 3952 if guard_size_oblivious(t.shape[dim] == 0): 3953 return () 3954 else: 3955 return tuple( 3956 torch.squeeze(s, dim) for s in torch.tensor_split(t, t.shape[dim], dim) 3957 ) 3958 3959 3960@out_wrapper() 3961def index_copy(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike): 3962 return x.clone(memory_format=torch.contiguous_format).index_copy_( 3963 dim, index, tensor 3964 ) 3965 3966 3967def index_copy_(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike): 3968 dim = utils.canonicalize_dims(x.ndim, dim) 3969 torch._check( 3970 index.ndim <= 1, 3971 lambda: f"Index should have dimension 1 or 0 (got {index.ndim})", 3972 ) 3973 # Treat scalars as elements of \R^1 3974 y = x.unsqueeze(0) if x.ndim == 0 else x 3975 idx = (slice(None),) * dim + (index,) 3976 y[idx] = tensor 3977 return x 3978 3979 3980@register_decomposition(aten.index_fill) 3981@out_wrapper() 3982def index_fill( 3983 x: TensorLike, dim: int, index: TensorLike, value: Union[NumberType, TensorLike] 3984): 3985 return _index_fill(x, dim, index, value, inplace=False) 3986 3987 3988@register_decomposition(aten.index_fill_) 3989def index_fill_( 3990 x: TensorLike, dim: int, index: TensorLike, value: Union[NumberType, TensorLike] 3991): 3992 return _index_fill(x, dim, index, value, inplace=True) 3993 3994 3995def _index_fill( 3996 x: TensorLike, 3997 dim: int, 3998 index: TensorLike, 3999 value: Union[NumberType, TensorLike], 4000 *, 4001 inplace: bool, 4002): 4003 torch._check( 4004 index.ndim <= 1, 4005 lambda: f"Index should have dimension 1 or 0 (got {index.ndim})", 4006 ) 4007 if isinstance(value, TensorLike): 4008 torch._check( 4009 value.ndim == 0, 4010 lambda: "Only supports 0-dimensional value tensor. 
" # type: ignore[union-attr] 4011 f"Got a tensor with {value.ndim} dimensions.", 4012 ) # type: ignore[arg-type] 4013 else: 4014 value = torch.scalar_tensor( 4015 value, dtype=x.dtype, layout=x.layout, device=x.device # type: ignore[arg-type] 4016 ) 4017 4018 # index_copy has some unnecessary preconditions when x is a scalar. We do this to work through them 4019 zero_dim = x.ndim == 0 4020 y = x.unsqueeze(0) if zero_dim else x 4021 # index_copy does not broadcast on value so we have to do it manually 4022 shape = list(y.shape) 4023 shape[dim] = index.numel() 4024 value = value.expand(shape) 4025 index_copy = Tensor.index_copy_ if inplace else torch.index_copy 4026 out = index_copy(y, dim, index, value) # type: ignore[operator] 4027 if inplace: 4028 return x 4029 else: 4030 if zero_dim: 4031 # The clone is necessary so that it returns a fresh tensor rather than a view 4032 out = out.squeeze(0).clone() 4033 # index_fill preserves the strides. index_copy always returns contiguous tensors 4034 if out.stride() != x.stride(): 4035 new_out = torch.empty_like(x) 4036 new_out.copy_(out) 4037 out = new_out 4038 return out 4039 4040 4041@out_wrapper() 4042def index_add( 4043 x: TensorLike, 4044 dim: int, 4045 index: TensorLike, 4046 tensor: TensorLike, 4047 *, 4048 alpha: NumberType = 1, 4049): 4050 # index_add always returns a new contiguous tensor 4051 return x.clone(memory_format=torch.contiguous_format).index_add_( 4052 dim, index, tensor, alpha=alpha # type: ignore[arg-type] 4053 ) 4054 4055 4056@register_decomposition(aten.index_select) 4057@out_wrapper() 4058def index_select(x: TensorLike, dim: int, index: TensorLike): 4059 dim = utils.canonicalize_dims(x.ndim, dim) 4060 torch._check( 4061 index.ndim <= 1, 4062 lambda: f"Index should have dimension 1 or 0 (got {index.ndim})", 4063 ) 4064 if index.ndim == 0: 4065 index = index.unsqueeze(0) 4066 if x.ndim == 0: 4067 # Treat scalars as elements of \R^1 4068 # We cannot use x[idx] here as it accesses item() (??), hence this awkward construction 4069 return torch.empty_like(x).index_copy(0, index, x.expand_as(index)) 4070 4071 idx = (slice(None),) * dim + (index,) 4072 return x[idx] 4073 4074 4075@register_decomposition(aten.squeeze.dims) 4076def squeeze(a: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType: 4077 from torch.fx.experimental.symbolic_shapes import guard_size_oblivious 4078 4079 if dim is None: 4080 dims = tuple(idx for idx, size in enumerate(a.shape) if size == 1) 4081 return prims.squeeze(a, dims) if dims else prims.view_of(a) 4082 4083 ndim = a.ndim 4084 dim = utils.canonicalize_dims(ndim, dim) 4085 dims = (dim,) if isinstance(dim, Dim) else dim 4086 # Short-circuits if the tensor has no dimensions 4087 if ndim == 0: 4088 assert len(dims) == 0 or dims == (0,) 4089 return prims.view_of(a) 4090 4091 # Note: squeeze does not modify tensors when the given dim is not a dimension of length 1 4092 dims = tuple(d for d in dims if guard_size_oblivious(a.shape[d] == 1)) 4093 if len(dims) == 0: 4094 return prims.view_of(a) 4095 if len(dims) == 1: 4096 return prims.squeeze(a, dims) 4097 dims_list = list(dims) 4098 dims_list = sorted(dims_list, reverse=True) 4099 for i in dims_list: 4100 a = squeeze(a, i) 4101 return a 4102 4103 4104# Note: does not work with TensorMetas because of data-dependent control-flow 4105# CompositeImplicitAutograd - don't register decomp 4106def tensor_split( 4107 a: TensorLikeType, 4108 indices_or_sections: Union[Tensor, DimsType], 4109 dim: int = 0, 4110) -> Tuple[TensorLikeType, ...]: 4111 _dim = 
utils.canonicalize_dim(a.ndim, dim)
    if a.ndim == 0:
        msg = "tensor_split: received a rank zero tensor, but expected a tensor of rank one or greater!"
        raise ValueError(msg)

    # If indices_or_sections is a tensor, it must be a CPU Long tensor
    if isinstance(indices_or_sections, TensorLike):
        if not indices_or_sections.device.type == "cpu":
            msg = (
                f"tensor_split: if indices_or_sections is a tensor it must be on the CPU, "
                f"but received one on {indices_or_sections.device}"
            )
            raise ValueError(msg)
        if indices_or_sections.dtype != torch.long:
            msg = (
                "tensor_split: if indices_or_sections is a tensor it must have long dtype, "
                f"but received one with dtype {indices_or_sections.dtype}"
            )
            raise ValueError(msg)

    # Case 0 -- indices_or_sections is an integer or a scalar tensor n and a is split along dim into n parts of equal-ish length
    if isinstance(indices_or_sections, IntLike) or (
        isinstance(indices_or_sections, TensorLike) and indices_or_sections.ndim == 0
    ):
        sections: int = (
            indices_or_sections  # type: ignore[assignment]
            if isinstance(indices_or_sections, Number)
            else indices_or_sections.item()
        )

        if sections <= 0:
            msg = f"tensor_split: number of sections must be greater than 0, but was {sections}"
            raise ValueError(msg)

        splits = []
        dim_size = a.shape[_dim]
        min_split_size = math.floor(dim_size / sections)
        num_splits_one_extra = dim_size % sections
        start_idx = 0
        for split_idx in range(sections):
            split_size = (
                min_split_size + 1
                if (split_idx < num_splits_one_extra)
                else min_split_size
            )
            s = prims.slice_in_dim(a, start_idx, start_idx + split_size, axis=_dim)
            splits.append(s)
            start_idx = start_idx + split_size

        return tuple(splits)
    # Case 1 -- indices_or_sections is a sequence of integers or a 1D tensor describing the splits
    else:
        indices = indices_or_sections
        if isinstance(indices_or_sections, TensorLike):
            if indices_or_sections.ndim != 1:
                msg = (
                    "tensor_split: non-scalar indices_or_sections tensors must have only one dimension, "
                    f"but received a tensor with {indices_or_sections.ndim} dimensions"
                )
                raise ValueError(msg)

            indices = indices_or_sections.tolist()

        splits = []
        start_idx = 0
        for x in indices:
            splits.append(prims.slice_in_dim(a, start_idx, x, axis=_dim))
            start_idx = x
        splits.append(prims.slice_in_dim(a, start_idx, a.shape[_dim], axis=_dim))
        return tuple(splits)


# CompositeImplicitAutograd - don't register decomp
def hsplit(
    a: TensorLikeType, indices_or_sections: DimsType
) -> Tuple[TensorLikeType, ...]:
    torch._check(
        a.ndim >= 1,
        lambda: (
            "torch.hsplit requires a tensor with at least 1 dimension, but got a tensor with "
            + str(a.ndim)
            + " dimensions!"
        ),
    )
    dim = 0 if a.ndim == 1 else 1
    if isinstance(indices_or_sections, IntLike):
        split_size = indices_or_sections
        torch._check(
            (split_size != 0 and a.shape[dim] % split_size == 0),
            lambda: (
                "torch.hsplit attempted to split along dimension "
                + str(dim)
                + ", but the size of the dimension "
                + str(a.shape[dim])
                + " is not divisible by the split_size "
                + str(split_size)
                + "!"
4204 ), 4205 ) 4206 return tensor_split(a, split_size, dim) 4207 4208 torch._check_type( 4209 isinstance(indices_or_sections, (list, tuple)), 4210 lambda: ( 4211 "hsplit(): received an invalid combination of arguments. " 4212 "Expected indices_or_sections to be of type int, list of ints or tuple of ints " 4213 f"but got type {type(indices_or_sections)}" 4214 ), 4215 ) 4216 4217 split_sizes = indices_or_sections 4218 return tensor_split(a, split_sizes, dim) 4219 4220 4221# CompositeImplicitAutograd - don't register decomp 4222def vsplit( 4223 a: TensorLikeType, indices_or_sections: DimsType 4224) -> Tuple[TensorLikeType, ...]: 4225 torch._check( 4226 a.ndim >= 2, 4227 lambda: ( 4228 "torch.vsplit requires a tensor with at least 2 dimension, but got a tensor with " 4229 + str(a.ndim) 4230 + " dimensions!" 4231 ), 4232 ) 4233 if isinstance(indices_or_sections, IntLike): 4234 split_size = indices_or_sections 4235 torch._check( 4236 (split_size != 0 and a.shape[0] % split_size == 0), 4237 lambda: ( 4238 f"torch.vsplit attempted to split along dimension 0" 4239 f", but the size of the dimension " 4240 f"{a.shape[0]}" 4241 f" is not divisible by the split_size " 4242 f"{split_size}" 4243 f"!" 4244 ), 4245 ) 4246 return tensor_split(a, split_size, 0) 4247 4248 torch._check_type( 4249 isinstance(indices_or_sections, (list, tuple)), 4250 lambda: ( 4251 "vsplit(): received an invalid combination of arguments. " 4252 "Expected indices_or_sections to be of type int, list of ints or tuple of ints " 4253 f"but got type {type(indices_or_sections)}" 4254 ), 4255 ) 4256 4257 split_sizes = indices_or_sections 4258 return tensor_split(a, split_sizes, 0) 4259 4260 4261@register_decomposition(aten.diag.out) 4262@out_wrapper() 4263def diag( 4264 self: TensorLikeType, 4265 offset: int = 0, 4266) -> TensorLikeType: 4267 ndim = self.dim() 4268 torch._check( 4269 ndim in (1, 2), lambda: f"diag(): Supports 1D or 2D tensors. Got {ndim}D" 4270 ) 4271 if ndim == 1: 4272 return torch.diag_embed(self, offset) 4273 else: 4274 return torch.diagonal_copy(self, offset) 4275 4276 4277@register_decomposition(aten.diagonal_scatter) 4278@out_wrapper() 4279def diagonal_scatter( 4280 input: TensorLikeType, 4281 src: TensorLikeType, 4282 offset: int = 0, 4283 dim1: int = 0, 4284 dim2: int = 1, 4285) -> TensorLikeType: 4286 out = utils.clone_preserve_strides(input) 4287 diag = out.diagonal(offset, dim1, dim2) 4288 torch._check( 4289 diag.shape == src.shape, 4290 lambda: "expected src to have a size equal to the diagonal of the input." 
4291 f"Got {src.shape} for a diagonal of shape {diag.shape}", 4292 ) 4293 copy_to(diag, src) 4294 return out 4295 4296 4297@register_decomposition(aten.diagonal) 4298def diagonal( 4299 self: TensorLikeType, 4300 offset: int = 0, 4301 dim1: int = 0, 4302 dim2: int = 1, 4303) -> TensorLikeType: 4304 """ 4305 Reference implementation of torch.diagonal 4306 """ 4307 num_dims = self.dim() 4308 dim1 = utils.canonicalize_dim(idx=dim1, rank=num_dims) 4309 dim2 = utils.canonicalize_dim(idx=dim2, rank=num_dims) 4310 4311 torch._check( 4312 dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}" 4313 ) 4314 4315 storage_offset = self.storage_offset() 4316 4317 if offset >= 0: 4318 diag_size = max(min(self.size()[dim1], self.size()[dim2] - offset), 0) 4319 else: 4320 diag_size = max(min(self.size()[dim1] + offset, self.size()[dim2]), 0) 4321 4322 if diag_size > 0: 4323 if offset >= 0: 4324 storage_offset += offset * self.stride()[dim2] 4325 else: 4326 storage_offset -= offset * self.stride()[dim1] 4327 4328 sizes = [s for i, s in enumerate(self.size()) if i not in (dim1, dim2)] 4329 sizes.append(diag_size) 4330 4331 strides = [s for i, s in enumerate(self.stride()) if i not in (dim1, dim2)] 4332 strides.append(self.stride()[dim1] + self.stride()[dim2]) 4333 4334 result = self.as_strided(size=sizes, stride=strides, storage_offset=storage_offset) 4335 4336 return result 4337 4338 4339@register_decomposition(aten.diag_embed) 4340@out_wrapper() 4341def diag_embed( 4342 t: TensorLikeType, 4343 offset: int = 0, 4344 dim1: int = -2, 4345 dim2: int = -1, 4346) -> TensorLikeType: 4347 """ 4348 Reference implementation of torch.diag_embed 4349 """ 4350 # convert from negative dims 4351 rank = t.ndim + 1 4352 dim1 = utils.canonicalize_dim(rank=rank, idx=dim1) 4353 dim2 = utils.canonicalize_dim(rank=rank, idx=dim2) 4354 4355 # as per the docs, exchanging dims is equivalent to changing the sign of 4356 # offset 4357 if dim1 > dim2: 4358 dim1, dim2 = dim2, dim1 4359 offset = -offset 4360 4361 torch._check( 4362 dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}" 4363 ) 4364 4365 # as per the docs, the size of last dim is placed at dim1 and dim2 4366 last_dim = t.size(-1) 4367 4368 if offset != 0: 4369 # add padding to match the new size 4370 t_shape = list(t.shape) 4371 t_shape[-1] = builtins.abs(offset) 4372 z = torch.zeros(t_shape, dtype=t.dtype, device=t.device, requires_grad=False) 4373 pair = (z, t) if offset > 0 else (t, z) 4374 t = torch.cat(pair, dim=-1) 4375 # make sure the diagonal always has the same size 4376 last_dim += builtins.abs(offset) 4377 4378 # preserve original data, but place 1 at dim1 and move last dim to dim2 4379 t = t.unsqueeze(dim1).movedim(-1, dim2) 4380 4381 # generate ranges shifting indices based on offset 4382 a_range = torch.arange(last_dim, device=t.device, dtype=torch.int64) 4383 b_range = torch.arange( 4384 offset, last_dim + offset, device=t.device, dtype=torch.int64 4385 ) 4386 4387 # broadcast 4388 cond = a_range == b_range.unsqueeze(-1) 4389 cond_shape = [last_dim if i in (dim1, dim2) else 1 for i in range(len(t.shape))] 4390 cond = cond.reshape(cond_shape) 4391 4392 # aten.diag_embed always returns a new contiguous tensor 4393 # contiguous() is needed to correctly model the output stride 4394 return utils.mask_tensor(cond, t).contiguous() 4395 4396 4397@register_decomposition(aten.block_diag) 4398@out_wrapper() 4399def _block_diag_iterable(tensors: List[TensorLikeType]) -> TensorLikeType: 4400 """ 4401 Reference 
implementation of torch.block_diag 4402 """ 4403 tensors_2d = [ 4404 tensor.view(1, -1) if tensor.dim() <= 1 else tensor for tensor in tensors 4405 ] 4406 4407 ncols = builtins.sum(tensor.shape[1] for tensor in tensors_2d) 4408 device = tensors_2d[0].device 4409 4410 result = [] 4411 4412 col_start = 0 4413 for i, tensor in enumerate(tensors_2d): 4414 torch._check( 4415 tensor.dim() == 2, 4416 lambda: "Input tensors must have 2 or fewer dimensions. " 4417 f"Input {i} has {tensor.dim()} dimensions", 4418 ) 4419 torch._check( 4420 tensor.device == device, 4421 lambda: "Input tensors must all be on the same device. " 4422 f"Input 0 is on device {device} and input {i} is on device {tensor.device}.", 4423 ) 4424 row, col = tensor.shape 4425 left = torch.zeros((row, col_start), device=device, dtype=tensor.dtype) 4426 right = torch.zeros( 4427 (row, ncols - col_start - col), device=device, dtype=tensor.dtype 4428 ) 4429 result += [torch.cat((left, tensor, right), dim=1)] 4430 col_start += col 4431 4432 return torch.cat(result, dim=0) 4433 4434 4435def block_diag(*tensors: List[TensorLikeType]) -> TensorLikeType: 4436 """ 4437 This is used as an input to PythonRefInfo. `torch.block_diag` 4438 expects arguments splatted, but `aten.block_diag` expects only 4439 one argument that is a list of Tensors. 4440 """ 4441 return _block_diag_iterable(tensors) # type: ignore[arg-type] 4442 4443 4444# CompositeImplicitAutograd - don't register decomp 4445def dsplit(a: TensorLikeType, sections: DimsType) -> TensorSequenceType: 4446 if a.ndim < 3: 4447 raise RuntimeError( 4448 f"torch.dsplit requires a tensor with at least 3 dimension, but got a tensor with {a.ndim} dimensions!" 4449 ) 4450 if isinstance(sections, IntLike) and (sections == 0 or a.shape[2] % sections != 0): 4451 raise RuntimeError( 4452 "torch.dsplit attempted to split along dimension 2, " 4453 + f"but the size of the dimension {a.shape[2]} is not divisible by the split_size {sections}!" 4454 ) 4455 return tensor_split(a, sections, 2) 4456 4457 4458@register_decomposition(aten.t.default) 4459def t(a: TensorLikeType): 4460 # TODO: Add sparse support 4461 # if a.is_sparse: 4462 # sparse_dim = a.sparse_dim() 4463 # dense_dim = a.dense_dim() 4464 # if not (sparse_dim <= 2 and dense_dim == 0): 4465 # raise RuntimeError( 4466 # f"t() expects a tensor with <= 2 sparse and 0 dense dimensions, but got {sparse_dim} sparse and" 4467 # f"{dense_dim} dense dimensions" 4468 # ) 4469 if a.ndim > 2: 4470 raise RuntimeError( 4471 f"t() expects a tensor with <= 2 dimensions, but self is {a.ndim}D" 4472 ) 4473 return torch.transpose(a, 0, 0 if a.ndim < 2 else 1) 4474 4475 4476# CompositeImplicitAutograd - don't register decomp 4477def T(a: TensorLikeType) -> TensorLikeType: 4478 # n != 2 && n != 0 is deprecated in regular PyTorch. 4479 torch._check( 4480 a.ndim in (0, 2), 4481 lambda: ( 4482 "The use of `x.T` on tensors of dimension other than 0 or 2 " 4483 "to reverse their shape is not supported." 
4484 ), 4485 ) 4486 return a.t() 4487 4488 4489@register_decomposition(aten.alias) 4490def alias(a: TensorLikeType) -> TensorLikeType: 4491 return prims.view_of(a) 4492 4493 4494@register_decomposition(aten.transpose) 4495def transpose(a: TensorLikeType, dim0: int, dim1: int) -> TensorLikeType: 4496 _dim0, _dim1 = utils.canonicalize_dims(a.ndim, (dim0, dim1)) # type: ignore[misc] 4497 4498 if a.ndim <= 1 or dim0 == dim1: 4499 return aten.alias.default(a) 4500 4501 _permutation = list(range(0, a.ndim)) 4502 _permutation[_dim0] = _dim1 4503 _permutation[_dim1] = _dim0 4504 return torch.permute(a, _permutation) 4505 4506 4507# Aliases for transpose 4508swap_axes = transpose 4509 4510 4511@register_decomposition(aten.unfold) 4512def unfold( 4513 self: TensorLikeType, dimension: int, size: int, step: int 4514) -> TensorLikeType: 4515 shape, strides = _get_unfold_shape_stride( 4516 self.shape, self.stride(), dimension, size, step 4517 ) 4518 return self.as_strided(shape, strides) 4519 4520 4521@register_decomposition(aten.unfold_copy) 4522@out_wrapper() 4523def unfold_copy(self: TensorLikeType, dimension: int, size: int, step: int): 4524 return self.unfold(dimension, size, step).clone( 4525 memory_format=torch.contiguous_format 4526 ) 4527 4528 4529def _cumsumprod_common( 4530 func, 4531 init, 4532 a: TensorLikeType, 4533 dim: int, 4534 *, 4535 dtype: Optional[torch.dtype] = None, 4536 out: Optional[Tensor] = None, 4537) -> TensorLikeType: 4538 # We implement all the kwargs of a reduction. ATen just handles dtype 4539 # nb. This decomposition may not be as efficient as a backend-specific implementation 4540 ndim = a.ndim 4541 dim = utils.canonicalize_dim(ndim, dim) 4542 if ndim == 0: 4543 return func(a.unsqueeze(0), dim=0, dtype=dtype, out=out) 4544 a = a.unsqueeze(dim + 1) 4545 rg = torch.arange(a.shape[dim], device=a.device) 4546 mask = rg.unsqueeze(1) <= rg 4547 for _ in range(ndim - dim - 1): 4548 mask = mask.unsqueeze(-1) 4549 masked_a = torch.where(mask, a, init) 4550 return func(masked_a, dim=dim, dtype=dtype, out=out) 4551 4552 4553@register_decomposition(aten.cumsum) 4554def cumsum( 4555 a: TensorLikeType, 4556 dim: int, 4557 *, 4558 dtype: Optional[torch.dtype] = None, 4559 out: Optional[Tensor] = None, 4560) -> TensorLikeType: 4561 return _cumsumprod_common(func=sum, init=0, a=a, dim=dim, dtype=dtype, out=out) 4562 4563 4564@register_decomposition(aten.cumprod) 4565def cumprod( 4566 a: TensorLikeType, 4567 dim: int, 4568 *, 4569 dtype: Optional[torch.dtype] = None, 4570 out: Optional[Tensor] = None, 4571) -> TensorLikeType: 4572 return _cumsumprod_common(func=prod, init=1, a=a, dim=dim, dtype=dtype, out=out) 4573 4574 4575# Note: although squeeze is documented as having the out= kwarg it doesn't 4576@register_decomposition(aten.unsqueeze) 4577def unsqueeze(a: TensorLikeType, dim: int) -> TensorLikeType: 4578 # Note that unsqueeze canonicalizes with rank + 1 because it allows 4579 # a new innermost dimension to be specified 4580 ndim = a.ndim + 1 4581 dim = utils.canonicalize_dim(ndim, dim) 4582 return prims.expand_dims(a, (dim,), ndim=ndim) 4583 4584 4585# NOTE: shape is a vararg because Tensor.reshape can be called with as 4586# Tensor.view(a, b, c) or Tensor.view((a, b, c)) Function call torch.view 4587# doesn't support unpacked shapes 4588# TODO: Turn this into a decomposition (currently fails on reshape meta tests) 4589@register_decomposition(aten.view.default) 4590def view(a: TensorLikeType, *shape: ShapeType) -> TensorLikeType: 4591 return _reshape_view_helper(a, *shape, 
allow_copy=False) 4592 4593 4594# CompositeImplicitAutograd - don't register decomp 4595def view_as(self: TensorLikeType, other: TensorLikeType) -> TensorLikeType: 4596 return self.view(other.size()) 4597 4598 4599# CompositeImplicitAutograd - don't register decomp 4600def ravel(a: TensorLikeType) -> TensorLikeType: 4601 return reshape(a, (-1,)) 4602 4603 4604# CompositeImplicitAutograd - don't register decomp 4605# missing ref impl. for aten.gather 4606@out_wrapper() 4607def take_along_dim( 4608 a: torch.Tensor, indices: torch.Tensor, dim: Optional[int] = None 4609) -> torch.Tensor: 4610 torch._check( 4611 a.ndim == indices.ndim, 4612 lambda: ( 4613 "torch.take_along_dim(): input and indices should have the same " 4614 f"number of dimensions, but got {a.ndim} dimensions for input, and " 4615 f"{indices.ndim} dimensions for indices" 4616 ), 4617 ) 4618 4619 torch._check( 4620 utils.is_integer_dtype(indices.dtype), 4621 lambda: ( 4622 "torch.take_along_dim(): dtype of indices should be int but got " 4623 f"{indices.dtype} instead" 4624 ), 4625 ) 4626 4627 if dim is None: 4628 return torch.gather(a.view(-1), 0, indices.view(-1)) 4629 else: 4630 self_sizes = list(a.shape) 4631 self_sizes[dim] = indices.size(dim) 4632 broadcast_shape = utils.infer_size_shapes(self_sizes, indices.size()) 4633 indices_broadcast = broadcast_to(indices, broadcast_shape) 4634 4635 indices_sizes = list(indices.shape) 4636 indices_sizes[dim] = a.size(dim) 4637 broadcast_shape = utils.infer_size_shapes(indices_sizes, a.size()) 4638 self_broadcast = broadcast_to(a, broadcast_shape) 4639 4640 return torch.gather(self_broadcast, dim, indices_broadcast) 4641 4642 4643@out_wrapper() 4644def empty( 4645 *shape, 4646 dtype: Optional[torch.dtype] = None, 4647 layout: torch.layout = torch.strided, 4648 device: Optional[DeviceLikeType] = None, 4649 requires_grad: bool = False, 4650 pin_memory: bool = False, 4651 memory_format: torch.memory_format = torch.contiguous_format, 4652) -> TensorLikeType: 4653 torch._check( 4654 memory_format != torch.preserve_format, 4655 lambda: "torch.empty: the Preserve memory format is not supported", 4656 ) 4657 4658 shape = utils.extract_shape_from_varargs(shape) 4659 4660 if memory_format == torch.contiguous_format: 4661 strides = utils.make_contiguous_strides_for(shape) 4662 elif memory_format == torch.channels_last_3d: 4663 strides = utils.make_channels_last_3d_strides_for(shape) 4664 else: # memory_format == torch.channels_last 4665 torch._check( 4666 memory_format == torch.channels_last, 4667 lambda: f"torch.empty: received an unknown memory format {memory_format}!", 4668 ) 4669 strides = utils.make_channels_last_2d_strides_for(shape) 4670 4671 return torch.empty_strided( 4672 shape, 4673 strides, 4674 dtype=dtype, 4675 layout=layout, 4676 device=device, 4677 pin_memory=pin_memory, 4678 requires_grad=requires_grad, 4679 ) 4680 4681 4682@out_wrapper() 4683def empty_permuted( 4684 shape, 4685 physical_layout, 4686 dtype: Optional[torch.dtype] = None, 4687 layout: torch.layout = torch.strided, 4688 device: Optional[DeviceLikeType] = None, 4689 requires_grad: bool = False, 4690 pin_memory: bool = False, 4691) -> TensorLikeType: 4692 return prims.empty_permuted( 4693 shape, 4694 physical_layout, 4695 dtype=dtype, 4696 device=device, 4697 requires_grad=requires_grad, 4698 ) 4699 4700 4701@register_decomposition(aten.new_empty) 4702@out_wrapper() 4703def new_empty( 4704 a: TensorLikeType, 4705 size: ShapeType, 4706 *, 4707 dtype: Optional[torch.dtype] = None, 4708 layout: Optional[torch.layout] = 
None, 4709 device: Optional[DeviceLikeType] = None, 4710 pin_memory: bool = False, 4711) -> TensorLikeType: 4712 dtype = a.dtype if dtype is None else dtype 4713 layout = a.layout if layout is None else layout 4714 device = a.device if device is None else device 4715 4716 return torch.empty( 4717 size, 4718 dtype=dtype, 4719 device=device, 4720 pin_memory=pin_memory, 4721 layout=layout, 4722 ) 4723 4724 4725@register_decomposition(aten.new_empty_strided) 4726@out_wrapper() 4727def new_empty_strided( 4728 a: TensorLikeType, 4729 size: ShapeType, 4730 stride: StrideType, 4731 *, 4732 dtype: Optional[torch.dtype] = None, 4733 layout: Optional[torch.layout] = None, 4734 device: Optional[DeviceLikeType] = None, 4735 pin_memory: bool = False, 4736) -> TensorLikeType: 4737 """ 4738 Reference implementation of torch.Tensor.new_empty_strided 4739 """ 4740 4741 dtype = a.dtype if dtype is None else dtype 4742 layout = a.layout if layout is None else layout 4743 device = a.device if device is None else device 4744 4745 return torch.empty_strided( 4746 size, 4747 stride, 4748 dtype=dtype, 4749 device=device, 4750 pin_memory=pin_memory, 4751 layout=layout, 4752 ) 4753 4754 4755@register_decomposition(aten.zeros.default) 4756@out_wrapper() 4757def zeros( 4758 *size, 4759 dtype: Optional[torch.dtype] = None, 4760 layout: torch.layout = torch.strided, 4761 device: Optional[DeviceLikeType] = None, 4762 pin_memory: bool = False, 4763 requires_grad: bool = False, 4764) -> TensorLikeType: 4765 size = utils.extract_shape_from_varargs(size) 4766 4767 if dtype is None: 4768 dtype = torch.get_default_dtype() 4769 4770 return torch.full( 4771 size, 4772 False if dtype == torch.bool else 0, 4773 dtype=dtype, 4774 layout=layout, 4775 device=device, 4776 pin_memory=pin_memory, 4777 requires_grad=requires_grad, 4778 ) 4779 4780 4781@register_decomposition(aten.new_zeros) 4782@out_wrapper() 4783def new_zeros( 4784 a: TensorLikeType, 4785 size: ShapeType, 4786 *, 4787 dtype: Optional[torch.dtype] = None, 4788 layout: Optional[torch.layout] = None, 4789 device: Optional[DeviceLikeType] = None, 4790 pin_memory: bool = False, 4791 requires_grad: bool = False, 4792) -> TensorLikeType: 4793 dtype = a.dtype if dtype is None else dtype 4794 layout = a.layout if layout is None else layout 4795 device = a.device if device is None else device 4796 4797 return torch.full( 4798 size, 4799 False if (dtype or a.dtype) == torch.bool else 0, 4800 dtype=dtype, 4801 layout=layout, 4802 device=device, 4803 pin_memory=pin_memory, 4804 requires_grad=requires_grad, 4805 ) 4806 4807 4808@register_decomposition(aten.ones.default) 4809@out_wrapper() 4810def ones( 4811 *size, 4812 dtype: Optional[torch.dtype] = None, 4813 layout: torch.layout = torch.strided, 4814 device: Optional[DeviceLikeType] = None, 4815 pin_memory: bool = False, 4816 requires_grad: bool = False, 4817) -> TensorLikeType: 4818 size = utils.extract_shape_from_varargs(size) 4819 4820 if dtype is None: 4821 dtype = torch.get_default_dtype() 4822 4823 return torch.full( 4824 size, 4825 True if dtype == torch.bool else 1, 4826 dtype=dtype, 4827 layout=layout, 4828 device=device, 4829 pin_memory=pin_memory, 4830 requires_grad=requires_grad, 4831 ) 4832 4833 4834@register_decomposition(aten.new_ones) 4835@out_wrapper() 4836def new_ones( 4837 a: TensorLikeType, 4838 size: ShapeType, 4839 *, 4840 dtype: Optional[torch.dtype] = None, 4841 layout: Optional[torch.layout] = None, 4842 device: Optional[DeviceLikeType] = None, 4843 pin_memory: bool = False, 4844 requires_grad: bool = 
False, 4845) -> TensorLikeType: 4846 dtype = a.dtype if dtype is None else dtype 4847 layout = a.layout if layout is None else layout 4848 device = a.device if device is None else device 4849 4850 return torch.full( 4851 size, 4852 True if (dtype or a.dtype) == torch.bool else 1, 4853 dtype=dtype, 4854 layout=layout, 4855 device=device, 4856 pin_memory=pin_memory, 4857 requires_grad=requires_grad, 4858 ) 4859 4860 4861@register_decomposition(aten.new_full) 4862@out_wrapper() 4863def new_full( 4864 a: TensorLikeType, 4865 size: ShapeType, 4866 fill_value: NumberType, 4867 *, 4868 dtype: Optional[torch.dtype] = None, 4869 layout: Optional[torch.layout] = None, 4870 device: Optional[DeviceLikeType] = None, 4871 pin_memory: bool = False, 4872) -> TensorLikeType: 4873 dtype = a.dtype if dtype is None else dtype 4874 layout = a.layout if layout is None else layout 4875 device = a.device if device is None else device 4876 4877 return torch.full( 4878 size, 4879 fill_value, 4880 dtype=dtype, 4881 layout=layout, 4882 device=device, 4883 pin_memory=pin_memory, 4884 ) 4885 4886 4887@register_decomposition(aten.empty_like) 4888@out_wrapper() 4889def empty_like( 4890 a: TensorLikeType, 4891 *, 4892 dtype: Optional[torch.dtype] = None, 4893 device: Optional[DeviceLikeType] = None, 4894 layout: Optional[torch.layout] = None, 4895 pin_memory: bool = False, 4896 requires_grad: bool = False, 4897 memory_format: torch.memory_format = torch.preserve_format, 4898) -> TensorLikeType: 4899 dtype = a.dtype if dtype is None else dtype 4900 layout = a.layout if layout is None else layout 4901 device = a.device if device is None else device 4902 4903 if memory_format != torch.preserve_format: 4904 return torch.empty( 4905 a.shape, 4906 dtype=dtype, 4907 layout=layout, 4908 device=device, 4909 requires_grad=requires_grad, 4910 pin_memory=pin_memory, 4911 memory_format=memory_format, 4912 ) 4913 4914 # memory_format == torch.preserve_format 4915 logical_to_physical_perm = ( 4916 utils.compute_elementwise_output_logical_to_physical_perm(a) 4917 ) 4918 # identity perm is [2, 1, 0] 4919 return torch.empty_permuted( 4920 a.shape, 4921 logical_to_physical_perm, 4922 dtype=dtype, 4923 layout=layout, 4924 device=device, 4925 pin_memory=pin_memory, 4926 requires_grad=requires_grad, 4927 ) 4928 4929 4930@register_decomposition([aten.arange.start_step, aten.arange.start_out]) 4931@out_wrapper() 4932def arange( 4933 start: NumberType = 0, 4934 end: Optional[NumberType] = None, 4935 step: NumberType = 1, 4936 *, 4937 dtype: Optional[torch.dtype] = None, 4938 layout: torch.layout = torch.strided, 4939 device: Optional[DeviceLikeType] = None, 4940 pin_memory: bool = False, 4941 requires_grad: bool = False, 4942) -> TensorLikeType: 4943 utils.check_layout(layout) 4944 utils.check_pin_memory(pin_memory) 4945 device = torch.device(utils.device_or_default(device)) 4946 4947 assert not isinstance(start, complex) 4948 assert not isinstance(end, complex) 4949 assert not isinstance(step, complex) 4950 4951 # Case: torch.arange(5) 4952 if end is None: 4953 end = start 4954 start = 0 4955 torch._check(step != 0, lambda: "step must be nonzero") 4956 if step > 0: 4957 torch._check( 4958 end >= start, 4959 lambda: "upper bound and lower bound inconsistent with step sign", 4960 ) 4961 elif step < 0: 4962 torch._check( 4963 end <= start, 4964 lambda: "upper bound and lower bound inconsistent with step sign", 4965 ) 4966 4967 def is_finite(x): 4968 return not isinstance(x, FloatWithoutSymFloat) or math.isfinite(x) 4969 4970 torch._check( 4971 
is_finite(start) and is_finite(end), 4972 lambda: f"unsupported range: {start} -> {end}", 4973 ) 4974 torch._check( 4975 is_finite(step), 4976 lambda: f"step must be finite but got {step}", 4977 ) 4978 4979 args = (start, end, step) 4980 integer_args = builtins.all(isinstance(arg, IntLike) for arg in args) 4981 4982 if dtype is None: 4983 dtype = torch.int64 if integer_args else torch.get_default_dtype() 4984 4985 is_integer = utils.is_integer_dtype(dtype) 4986 if is_integer or integer_args: 4987 xstart = sym_int(start) 4988 xend = sym_int(end) 4989 xstep = sym_int(step) 4990 4991 # For int64 we truncate arguments to int before calculating length, but 4992 # other integral dtypes we don't. Weird... but needed to match ATen shapes. 4993 if dtype == torch.int64 or integer_args: 4994 # Uses floordiv to avoid ceil in inductor. 4995 sgn = bool(xstep > 0) - bool(xstep < 0) # type: ignore[possibly-undefined] 4996 length = (xend - xstart + xstep - sgn) // xstep # type: ignore[possibly-undefined] 4997 else: 4998 length = math.ceil((end - start) / step) 4999 5000 if is_integer: 5001 return prims.iota( 5002 length, 5003 start=xstart, # type: ignore[possibly-undefined] 5004 step=xstep, # type: ignore[possibly-undefined] 5005 dtype=dtype, 5006 device=device, 5007 requires_grad=requires_grad, 5008 ) 5009 5010 index = prims.iota( 5011 length, 5012 start=0, 5013 step=1, 5014 dtype=torch.int64, 5015 device=device, 5016 requires_grad=False, 5017 ) 5018 5019 computation_dtype = ( 5020 torch.long if integer_args else utils.get_acc_type(dtype, device) 5021 ) 5022 index = _maybe_convert_to_dtype(index, computation_dtype) 5023 result = start + step * index 5024 result = _maybe_convert_to_dtype(result, dtype) 5025 5026 if requires_grad: 5027 result.requires_grad_(True) 5028 return result 5029 5030 5031@register_decomposition(aten.lerp) 5032@out_wrapper() 5033@elementwise_type_promotion_wrapper( 5034 type_promoting_args=("start", "end", "weight"), 5035 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, 5036) 5037def lerp(start: Tensor, end: Tensor, weight: Union[Tensor, NumberType]): 5038 inputs = [start, end] 5039 if isinstance(weight, Number): 5040 weight = start.new_full((), weight) # type: ignore[arg-type] 5041 else: 5042 inputs.append(weight) 5043 assert isinstance(weight, Tensor) # mypy 5044 # We implement it this way for numerical stability. We assume (in the stability optimisation) 5045 # that 0 <= weight <= 1. We take the abs to deal with complex numbers 5046 # We want to perform operations near zero, which is where floating points are most precise 5047 # thus, we perform the following optimisation: 5048 # If weight.abs() >= 0.5: 5049 # return (1 - weight) * (start - end) + end 5050 mask = weight.abs() >= 0.5 5051 coeff = torch.where(mask, weight - 1, weight) 5052 base = torch.where(mask, end, start) 5053 output = coeff * (end - start) + base 5054 # make sure the decomposition output's stride is same as non-decomposition path. 
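    # NOTE (added commentary, illustration only): both branches above compute the same
    # value algebraically, since end + (weight - 1) * (end - start) == start + weight * (end - start).
    # Selecting the branch by weight.abs() >= 0.5 keeps |coeff| <= 0.5 (for real weights
    # in [0, 1], the case the stability comment above assumes), so the multiplied term
    # added to `base` stays small, which is what preserves precision near the endpoints.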
5055 stride = utils.compute_elementwise_output_strides(*_maybe_broadcast(*inputs)) 5056 if output.stride() != stride: 5057 output = prims.copy_strided(output, stride) 5058 5059 return handle_noncontiguous_outputs(inputs, output) 5060 5061 5062@register_decomposition(aten.linspace) 5063@out_wrapper() 5064def linspace( 5065 start: Union[NumberType, TensorLikeType], 5066 end: Union[NumberType, TensorLikeType], 5067 steps: NumberType, 5068 *, 5069 dtype: Optional[torch.dtype] = None, 5070 device: Optional[DeviceLikeType] = None, 5071 layout: torch.layout = torch.strided, 5072 pin_memory: bool = False, 5073 requires_grad: bool = False, 5074) -> TensorLikeType: 5075 if isinstance(start, TensorLikeType): 5076 torch._check( 5077 start.dim() == 0, 5078 lambda: "linspace only supports 0-dimensional start and end tensors", 5079 ) 5080 start = _maybe_convert_to_dtype(start, torch.float64) 5081 if isinstance(end, TensorLikeType): 5082 torch._check( 5083 end.dim() == 0, 5084 lambda: "linspace only supports 0-dimensional start and end tensors", 5085 ) 5086 end = _maybe_convert_to_dtype(end, torch.float64) 5087 5088 if builtins.any(isinstance(arg, complex) for arg in (start, end, steps)): 5089 default_complex_dtype = utils.corresponding_complex_dtype( 5090 torch.get_default_dtype() 5091 ) 5092 if dtype is None: 5093 dtype = default_complex_dtype 5094 else: 5095 torch._check( 5096 utils.is_complex_dtype(dtype), 5097 lambda: f"linspace(): inferred dtype {default_complex_dtype} can't be safely cast to passed dtype {dtype}", 5098 ) 5099 else: 5100 dtype = dtype or torch.get_default_dtype() 5101 assert isinstance(dtype, torch.dtype) 5102 5103 # steps does not participate in the computation of the dtype 5104 torch._check_type( 5105 isinstance(steps, IntLike), 5106 lambda: f"received an invalid combination of arguments - got \ 5107({type(start).__name__}, {type(end).__name__}, {type(steps).__name__})", 5108 ) 5109 assert isinstance(steps, IntLike) # for mypy 5110 torch._check(steps >= 0, lambda: "number of steps must be non-negative") 5111 5112 factory_kwargs = { 5113 "layout": layout, 5114 "device": device, 5115 "pin_memory": pin_memory, 5116 "requires_grad": requires_grad, 5117 } 5118 if steps == 0: 5119 return torch.full((0,), 0, dtype=dtype, **factory_kwargs) # type: ignore[arg-type] 5120 if steps == 1: 5121 if isinstance(start, TensorLikeType): 5122 return torch.empty((steps,), dtype=dtype, **factory_kwargs).copy_(start) # type: ignore[arg-type] 5123 else: 5124 return torch.full((steps,), start, dtype=dtype, **factory_kwargs) # type: ignore[arg-type] 5125 5126 # Perform in arange in int because some backends like ATen or Triton do not support all the dtypes 5127 rg = torch.arange(0, steps, **factory_kwargs) # type: ignore[arg-type] 5128 5129 # Small types need to be computed in higher precision as this is, at heart, an associative scan 5130 dtype_red = ( 5131 torch.int64 5132 if (utils.is_boolean_dtype(dtype) or utils.is_integer_dtype(dtype)) 5133 else dtype 5134 ) 5135 computation_dtype, _ = utils.reduction_dtypes( 5136 rg, REDUCTION_OUTPUT_TYPE_KIND.SAME, dtype_red 5137 ) 5138 cast_rg = partial(_maybe_convert_to_dtype, dtype=computation_dtype) 5139 5140 # We implement torch.lerp without performing rg / (steps - 1) explicitly 5141 # With this we get out[0] == start, out[-1] == end 5142 step = (end - start) / (steps - 1) 5143 out = torch.where( 5144 rg < steps / 2, 5145 start + step * cast_rg(rg), # type: ignore[arg-type,operator] 5146 end - step * cast_rg((steps - 1) - rg), # type: 
ignore[arg-type,operator] 5147 ) 5148 return _maybe_convert_to_dtype(out, dtype) # type: ignore[return-value] 5149 5150 5151@register_decomposition(aten.logspace) 5152@out_wrapper() 5153def logspace( 5154 start: Union[NumberType, TensorLikeType], 5155 end: Union[NumberType, TensorLikeType], 5156 steps: NumberType, 5157 base: NumberType = 10, 5158 *, 5159 dtype: Optional[torch.dtype] = None, 5160 device: Optional[DeviceLikeType] = None, 5161 layout: torch.layout = torch.strided, 5162 pin_memory: bool = False, 5163 requires_grad: bool = False, 5164) -> TensorLikeType: 5165 if dtype is None: 5166 dtype = torch.get_default_dtype() 5167 5168 # NB: NumPy doesn't have this cast 5169 if prims.utils.is_integer_dtype(dtype): 5170 if isinstance(start, FloatLike): 5171 start = sym_int(start) 5172 elif isinstance(start, TensorLikeType): 5173 torch._check( 5174 start.dim() == 0, 5175 lambda: "logspace only supports 0-dimensional start and end tensors", 5176 ) 5177 start = _maybe_convert_to_dtype(start, dtype) 5178 if isinstance(end, FloatLike): 5179 end = sym_int(end) 5180 elif isinstance(end, TensorLikeType): 5181 torch._check( 5182 end.dim() == 0, 5183 lambda: "logspace only supports 0-dimensional start and end tensors", 5184 ) 5185 end = _maybe_convert_to_dtype(end, dtype) 5186 5187 if builtins.any(isinstance(arg, complex) for arg in (start, end, steps)): 5188 default_complex_dtype = utils.corresponding_complex_dtype( 5189 torch.get_default_dtype() 5190 ) 5191 dtype = default_complex_dtype 5192 _dtype = None # torch.linspace will update the correct dtype 5193 else: 5194 _dtype = torch.float64 5195 5196 assert not isinstance(base, complex) # for mypy 5197 if base < 0: 5198 raise NotImplementedError 5199 ret = torch.linspace( # type: ignore[misc] 5200 start, # type: ignore[arg-type] 5201 end, # type: ignore[arg-type] 5202 steps, # type: ignore[arg-type] 5203 dtype=_dtype, 5204 layout=layout, 5205 device=device, 5206 pin_memory=pin_memory, 5207 requires_grad=requires_grad, 5208 ) 5209 return _maybe_convert_to_dtype(torch.pow(base, ret), dtype) # type: ignore[arg-type,return-value] 5210 5211 5212@overload 5213def meshgrid(tensors: Sequence[TensorLikeType], indexing: str): 5214 pass 5215 5216 5217@overload 5218def meshgrid(*tensors: TensorLikeType, indexing: str): 5219 pass 5220 5221 5222@register_decomposition(aten.meshgrid) # type: ignore[misc] 5223def meshgrid( 5224 *tensors: Union[TensorLikeType, List[TensorLikeType], Tuple[TensorLikeType]], 5225 indexing: str, 5226) -> List[TensorLikeType]: 5227 # This ref simultaneously handles two overloads (see stubs above) 5228 # The `indexing` argument is currently optional for torch.meshgrid, but we 5229 # plan to make the argument required: https://github.com/pytorch/pytorch/issues/50276 5230 if isinstance(tensors[0], (list, tuple)): 5231 assert len(tensors) == 1 5232 tensors = tuple(tensors[0]) 5233 5234 torch._check( 5235 builtins.all(isinstance(a, TensorLike) for a in tensors), 5236 lambda: "meshgrid expects its inputs to be tensors", 5237 ) 5238 5239 torch._check(len(tensors) > 0, lambda: "meshgrid expects a non-empty TensorList") 5240 5241 for i in range(len(tensors) - 1): 5242 torch._check( 5243 tensors[i].dtype == tensors[i + 1].dtype, # type: ignore[union-attr] 5244 lambda: "meshgrid expects all tensors to have the same dtype", 5245 ) 5246 torch._check( 5247 tensors[i].device == tensors[i + 1].device, # type: ignore[union-attr] 5248 lambda: "meshgrid expects all tensors to have the same device", 5249 ) 5250 5251 swap_first_and_second_tensors = False 
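    # NOTE (added commentary, illustration only): with indexing="ij" the i-th output
    # varies along dimension i, while "xy" (Cartesian) swaps the roles of the first
    # two inputs. E.g. for 1D inputs x of length 2 and y of length 3,
    # meshgrid(x, y, indexing="ij") yields two (2, 3) grids, whereas indexing="xy"
    # yields two (3, 2) grids. The flag below implements "xy" by swapping the first
    # two inputs up front and swapping the first two outputs back at the end.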
5252 if indexing == "xy": 5253 swap_first_and_second_tensors = len(tensors) >= 2 5254 if swap_first_and_second_tensors: 5255 tensors = (tensors[1], tensors[0], *tensors[2:]) 5256 else: 5257 torch._check( 5258 indexing == "ij", 5259 lambda: ( 5260 'torch.meshgrid: indexing must be one of "xy" or "ij", ' 5261 f"but received: {indexing}" 5262 ), 5263 ) 5264 5265 result_shape: List[int] = [] 5266 for t in tensors: 5267 assert isinstance(t, TensorLike) # mypy 5268 torch._check( 5269 t.ndim == 0 or t.ndim == 1, 5270 lambda: f"torch.meshgrid: Expected 0D or 1D tensor in the tensor list but got: {t}", 5271 ) 5272 result_shape.append(t.numel()) 5273 5274 grids: List[TensorLikeType] = [] 5275 for i, t in enumerate(tensors): 5276 assert isinstance(t, TensorLike) # mypy 5277 if t.ndim == 0: 5278 t = t.view((1,)) 5279 grids.append(prims.broadcast_in_dim(t, result_shape, (i,))) 5280 5281 if swap_first_and_second_tensors: 5282 # Swap outputs if we originally swapped at the beginning 5283 grids[0], grids[1] = grids[1], grids[0] 5284 5285 return grids 5286 5287 5288# CompositeImplicitAutograd - don't register decomp 5289def movedim( 5290 input: TensorLikeType, 5291 source: Union[int, DimsSequenceType], 5292 destination: Union[int, DimsSequenceType], 5293) -> TensorLikeType: 5294 """ 5295 Reference implementation of torch.movedim 5296 """ 5297 if type(source) is int: 5298 source = (source,) 5299 if type(destination) is int: 5300 destination = (destination,) 5301 5302 # Converts to list to produce a compatible error message with core PyTorch, 5303 # which prints sequences in square brackets. 5304 torch._check( 5305 len(source) == len(destination), # type: ignore[arg-type] 5306 lambda: ( 5307 "movedim: Invalid source or destination dims: source " # type: ignore[arg-type] 5308 f"({list(source)} dims) should contain the same number " # type: ignore[arg-type] 5309 f"of dims as destination ({list(destination)} dims)" # type: ignore[arg-type] 5310 ), 5311 ) 5312 5313 rank = input.ndim 5314 ss = tuple(utils.canonicalize_dims(rank=rank, indices=source)) # type: ignore[arg-type] 5315 ds = tuple(utils.canonicalize_dims(rank=rank, indices=destination)) # type: ignore[arg-type] 5316 5317 sss = set(ss) 5318 dss = set(ds) 5319 5320 # See above on why this converts to list in error messages. 
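    # NOTE (added commentary, illustration only, on the permutation construction
    # further below): each source dim is pinned at its destination and the remaining
    # positions are filled with the unused dims in ascending order. E.g. for a rank-3
    # input, movedim(x, source=(0, 2), destination=(2, 0)) produces dims == [2, 1, 0],
    # i.e. the same result as x.permute(2, 1, 0).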
5321 torch._check( 5322 len(ss) == len(sss), 5323 lambda: f"movedim: repeated dim in `source` ({list(source)})", # type: ignore[arg-type] 5324 ) 5325 torch._check( 5326 len(ds) == len(dss), 5327 lambda: f"movedim: repeated dim in `destination` ({list(destination)})", # type: ignore[arg-type] 5328 ) 5329 5330 m = dict(zip(ds, ss)) 5331 dims = [] 5332 si = 0 # source index 5333 for di in range(rank): 5334 # check if the destination index is in the mapping 5335 s = m.get(di) 5336 if s is not None: 5337 # insert source index if found 5338 dims.append(s) 5339 else: 5340 # insert source index sequentially, skipping indices from the mapping 5341 while si in sss: 5342 si += 1 5343 dims.append(si) 5344 si += 1 5345 5346 result = torch.permute(input, tuple(dims)) 5347 5348 return result 5349 5350 5351# NOTE: for convenience, shape can be a tuple of ints or a tuple containing a tuple of ints 5352@register_decomposition(aten.empty_strided) 5353@out_wrapper() 5354def empty_strided( 5355 shape: Union[ShapeType, Tuple[ShapeType]], 5356 strides: StrideType, 5357 *, 5358 dtype: Optional[torch.dtype] = None, 5359 device: Optional[DeviceLikeType] = None, 5360 layout: torch.layout = torch.strided, 5361 requires_grad: bool = False, 5362 pin_memory: bool = False, 5363) -> TensorLikeType: 5364 # Layout == strided, pin_memory is False 5365 utils.check_layout(layout) 5366 utils.check_pin_memory(pin_memory) 5367 5368 shape = utils.extract_shape_from_varargs(shape) 5369 dtype = torch.get_default_dtype() if dtype is None else dtype 5370 device = torch.device("cpu") if device is None else device 5371 5372 return prims.empty_strided( 5373 shape, 5374 strides, 5375 dtype=dtype, 5376 device=device, 5377 requires_grad=requires_grad, 5378 ) 5379 5380 5381@register_decomposition(aten.eye) 5382@out_wrapper() 5383def eye( 5384 n: int, 5385 m: Optional[int] = None, 5386 *, 5387 dtype: Optional[torch.dtype] = None, 5388 layout: torch.layout = torch.strided, 5389 device: Optional[DeviceLikeType] = None, 5390 pin_memory: bool = False, 5391 requires_grad: bool = False, # TODO: unused 5392) -> TensorLikeType: 5393 """ 5394 Reference implementation of torch.eye 5395 """ 5396 if m is None: 5397 m = n 5398 5399 torch._check(n >= 0, lambda: f"n must be greater or equal to 0, got {n}") 5400 torch._check(m >= 0, lambda: f"m must be greater or equal to 0, got {m}") 5401 5402 range_n = torch.arange(n, dtype=torch.int64, device=device, requires_grad=False) 5403 range_m = torch.arange(m, dtype=torch.int64, device=device, requires_grad=False) 5404 5405 cond = range_n.unsqueeze(-1) == range_m 5406 if dtype is torch.bool: 5407 return cond 5408 else: 5409 one = torch.ones( 5410 (1,), 5411 dtype=dtype, 5412 layout=layout, 5413 device=device, 5414 pin_memory=pin_memory, 5415 requires_grad=False, 5416 ) 5417 return torch.where(cond, one, 0) 5418 # TODO: Use requires_grad. All refs taking the requires_grad kwarg must 5419 # return a leaf tensor. 
5420 # result.requires_grad_(requires_grad) 5421 5422 5423@register_decomposition([aten.full.default, aten.full.out]) 5424@out_wrapper() 5425def full( 5426 shape: ShapeType, 5427 fill_value: NumberType, 5428 *, 5429 dtype: Optional[torch.dtype] = None, 5430 layout: torch.layout = torch.strided, 5431 device: Optional[DeviceLikeType] = None, 5432 pin_memory: bool = False, 5433 requires_grad: bool = False, 5434) -> TensorLikeType: 5435 utils.check_layout(layout) 5436 utils.check_pin_memory(pin_memory) 5437 5438 dtype = dtype if dtype is not None else utils.type_to_dtype(type(fill_value)) 5439 device = device if device is not None else torch.device("cpu") 5440 5441 e = empty( 5442 shape, 5443 dtype=dtype, 5444 layout=layout, 5445 device=device, 5446 pin_memory=pin_memory, 5447 requires_grad=requires_grad, 5448 ) 5449 return torch.fill(e, fill_value) # type: ignore[arg-type] 5450 5451 5452def full_like( 5453 a: TensorLikeType, 5454 fill_value: NumberType, 5455 *, 5456 dtype: Optional[torch.dtype] = None, 5457 layout: Optional[torch.layout] = None, 5458 device: Optional[DeviceLikeType] = None, 5459 pin_memory: bool = False, 5460 requires_grad: bool = False, 5461 memory_format: torch.memory_format = torch.preserve_format, 5462) -> TensorLikeType: 5463 e = torch.empty_like( 5464 a, 5465 dtype=dtype, 5466 layout=layout, 5467 device=device, 5468 pin_memory=pin_memory, 5469 requires_grad=requires_grad, 5470 memory_format=memory_format, 5471 ) 5472 return fill(e, fill_value) 5473 5474 5475@register_decomposition(aten.zeros_like) 5476@out_wrapper() 5477def zeros_like( 5478 a: TensorLikeType, 5479 *, 5480 dtype: Optional[torch.dtype] = None, 5481 layout: Optional[torch.layout] = None, 5482 device: Optional[DeviceLikeType] = None, 5483 pin_memory: bool = False, 5484 requires_grad: bool = False, 5485 memory_format: torch.memory_format = torch.preserve_format, 5486) -> TensorLikeType: 5487 return torch.full_like( 5488 a, 5489 False if (dtype or a.dtype) == torch.bool else 0, 5490 dtype=dtype, 5491 layout=layout, 5492 device=device, 5493 pin_memory=pin_memory, 5494 requires_grad=requires_grad, 5495 memory_format=memory_format, 5496 ) 5497 5498 5499@register_decomposition(aten.ones_like) 5500@out_wrapper() 5501def ones_like( 5502 a: TensorLikeType, 5503 *, 5504 dtype: Optional[torch.dtype] = None, 5505 layout: Optional[torch.layout] = None, 5506 device: Optional[DeviceLikeType] = None, 5507 pin_memory: bool = False, 5508 requires_grad: bool = False, 5509 memory_format: torch.memory_format = torch.preserve_format, 5510) -> TensorLikeType: 5511 return torch.full_like( 5512 a, 5513 True if (dtype or a.dtype) == torch.bool else 1, 5514 dtype=dtype, 5515 layout=layout, 5516 device=device, 5517 pin_memory=pin_memory, 5518 requires_grad=requires_grad, 5519 memory_format=memory_format, 5520 ) 5521 5522 5523@register_decomposition(aten.randn.default) 5524@out_wrapper() 5525def randn( 5526 *shape, 5527 dtype: Optional[torch.dtype] = None, 5528 device: Optional[DeviceLikeType] = None, 5529 layout: Optional[torch.layout] = None, 5530 requires_grad: bool = False, 5531 pin_memory: bool = False, 5532) -> TensorLikeType: 5533 utils.check_pin_memory(pin_memory) 5534 5535 shape_ = utils.extract_shape_from_varargs(shape) 5536 5537 dtype = utils.dtype_or_default(dtype) 5538 device = utils.device_or_default(device) 5539 5540 return prims.normal( 5541 shape_, 5542 mean=0.0, 5543 std=1.0, 5544 dtype=dtype, 5545 device=device, 5546 requires_grad=requires_grad, 5547 ) 5548 5549 5550def scalar_tensor( 5551 a: NumberType, 5552 *, 5553 
dtype: Optional[torch.dtype] = None, 5554 layout: torch.layout = torch.strided, 5555 device: Optional[DeviceLikeType] = None, 5556 pin_memory: bool = False, 5557) -> TensorLikeType: 5558 utils.check_layout(layout) 5559 utils.check_pin_memory(pin_memory) 5560 dtype = dtype if dtype is not None else utils.type_to_dtype(type(a)) 5561 device = device if device is not None else torch.device("cpu") 5562 return prims.scalar_tensor(a, dtype=dtype, device=device) 5563 5564 5565# 5566# Randomness References 5567# 5568 5569 5570def _uniform_helper( 5571 shape: ShapeType, 5572 low: Union[bool, int, float] = 0.0, 5573 high: Union[bool, int, float] = 1.0, 5574 *, 5575 dtype: torch.dtype, 5576 device: DeviceLikeType, 5577) -> TensorLikeType: 5578 utils.validate_shape(shape) 5579 5580 assert isinstance(low, Number) 5581 assert isinstance(high, Number) 5582 low = sym_float(low) 5583 high = sym_float(high) 5584 5585 assert isinstance(dtype, torch.dtype) 5586 device = utils.canonicalize_device(device) 5587 5588 return prims._uniform_helper(shape, low=low, high=high, dtype=dtype, device=device) 5589 5590 5591@register_decomposition(aten.masked_fill) 5592@out_wrapper() 5593def masked_fill(a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLikeType): 5594 python_type = utils.dtype_to_type(a.dtype) 5595 if isinstance(value, Number): 5596 value_type = type(value) 5597 else: 5598 # NOTE: Could not use value = item(value) as it resulted in 5599 # RuntimeError: Cannot cast FakeTensor(cpu) to number 5600 value_ndim = value.ndim 5601 torch._check( 5602 value_ndim == 0, 5603 lambda: f"only supports a 0-dimensional value tensor, but got tensor with {value_ndim} dimension", 5604 ) 5605 # `masked_fill` allows cpu scalar to be moved to cuda, xpu and hpu but not otherwise. 5606 is_cpu_scalar = ( 5607 a.device.type 5608 in ["cuda", "xpu", torch._C._get_privateuse1_backend_name(), "hpu"] 5609 and value.device.type == "cpu" 5610 ) 5611 torch._check( 5612 is_cpu_scalar or value.device == a.device, 5613 lambda: "Expected `value` to be on same device as `a`", 5614 ) 5615 value_type = utils.dtype_to_type(value.dtype) 5616 5617 if value_type is complex: 5618 # only downcasting from complex to lower type is not allowed. 5619 # We allow casting `value` to lower type for other case 5620 # Eg. float -> int. 
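        # Hedged illustration (hypothetical values, for this comment only):
        # filling a float32 tensor with a complex `value` such as 1 + 2j fails the
        # check below, while filling an int64 tensor with 2.5 is allowed and the
        # value is cast (truncated) to 2 by the dtype conversion further down.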
        # Ref: https://github.com/pytorch/pytorch/issues/79195
        torch._check(
            utils.is_weakly_lesser_type(value_type, python_type),
            lambda: f"could not convert to type {python_type} without overflow",
        )

    # Since `where` allows type promotion,
    # cast `value` to the correct type before passing it to `where`
    value = _maybe_convert_to_dtype(value, a.dtype)
    r = torch.where(mask, value, a)  # type: ignore[arg-type]

    # aten.masked_fill always returns a new contiguous tensor
    # contiguous() is needed to correctly model the output stride
    return r.contiguous()


@register_decomposition(aten.masked_fill_)
def masked_fill_(
    a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLikeType
) -> TensorLikeType:
    b = torch.masked_fill(a, mask, value)  # type: ignore[arg-type]
    a.copy_(b)
    return a


# CompositeImplicitAutograd - don't register decomp
def allclose(
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> bool:
    """
    Reference implementation of torch.allclose
    """
    _check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol)

    return bool(
        torch.all(torch.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)).item()
    )


def equal(a: TensorLikeType, b: TensorLikeType) -> bool:
    utils.check_same_device(a, b, allow_cpu_scalar_tensors=False)
    utils.check_same_dtype(a, b)

    # Shape check
    if a.ndim != b.ndim:
        return False

    for x, y in zip(a.shape, b.shape):
        if x != y:
            return False

    # Short-circuits if there are no elements to validate
    if a.numel() == 0:
        return True

    return item(all(eq(a, b)))  # type: ignore[return-value]


@register_decomposition(aten.norm)
@out_wrapper(exact_dtype=True)
def norm(
    input: TensorLikeType,
    p: Optional[Union[float, str]] = "fro",
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    # In these cases we compute the "Frobenius norm"
    if (
        p == "fro" and (dim is None or isinstance(dim, Dim) or len(dim) <= 2)
    ) or p is None:
        p = 2
    if isinstance(dim, Dim):
        dim = [dim]
    if isinstance(p, str):
        # Here we either call the nuclear norm, or we call matrix_norm with some arguments
        # that will throw an error
        if dim is None:
            dim = tuple(range(input.ndim))
        return torch.linalg.matrix_norm(input, p, dim, keepdim, dtype=dtype)
    else:
        return torch.linalg.vector_norm(input, p, dim, keepdim, dtype=dtype)


@register_decomposition(aten.trace)
@out_wrapper()
def trace(self: TensorLikeType) -> TensorLikeType:
    torch._check(
        self.ndim == 2, lambda: f"expected a matrix, but got tensor with dim {self.ndim}"
    )
    return torch.sum(torch.diag(self, 0))


def _make_r_binary_op(base_op):
    def rop(
        a: Union[TensorLikeType, NumberType],
        b: Union[TensorLikeType, NumberType],
    ) -> TensorLikeType:
        return base_op(b, a)

    return rop


rtruediv = _make_r_binary_op(true_divide)
rfloordiv = _make_r_binary_op(floor_divide)
rpow = _make_r_binary_op(pow)


@register_decomposition(aten.triu)
@out_wrapper()
def triu(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType:
    torch._check(
a.ndim >= 2, lambda: "triu: input tensor must have at least 2 dimensions" 5739 ) 5740 h, w = a.shape[-2:] 5741 mask = ( 5742 torch.arange(w, device=a.device).unsqueeze(-2) 5743 - torch.arange(h, device=a.device).unsqueeze(-1) 5744 ) >= diagonal 5745 5746 # aten.triu always returns a new contiguous tensor 5747 # contiguous() is needed to correctly model the output stride 5748 return utils.mask_tensor(mask, a).contiguous() 5749 5750 5751@register_decomposition(aten.tril) 5752@out_wrapper() 5753def tril(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType: 5754 torch._check( 5755 a.ndim >= 2, lambda: "tril: input tensor must have at least 2 dimensions" 5756 ) 5757 h, w = a.shape[-2:] 5758 mask = ( 5759 torch.arange(w, device=a.device).unsqueeze(-2) 5760 - torch.arange(h, device=a.device).unsqueeze(-1) 5761 ) <= diagonal 5762 5763 # aten.tril always returns a new contiguous tensor 5764 # contiguous() is needed to correctly model the output stride 5765 return utils.mask_tensor(mask, a).contiguous() 5766 5767 5768# This is based on get_tril_size in aten/src/ATen/native/TensorFactories.h 5769# The components of the matrix that belong to the lower triangle with offset 5770# form a pentagon that can be broken down into a top trapezoid and a bottom 5771# rectangle. For the implementation of tril_indices, we need the sizes of 5772# both of these, as well as the length of the top side of the trapezoid. 5773def _get_tril_sizes(row: int, col: int, offset: int) -> Tuple[int, int, int]: 5774 if row == 0 or col == 0: 5775 return 0, 0, 0 5776 5777 m_first_row = min(col, 1 + offset) if offset > 0 else int(row + offset > 0) 5778 m_last_row = max(0, min(col, row + offset)) 5779 n_row_all = max(0, min(row, row + offset)) 5780 n_row_trapezoid = m_last_row - m_first_row + 1 5781 5782 # Number of elements in top trapezoid 5783 trapezoid_size = (m_first_row + m_last_row) * n_row_trapezoid // 2 5784 # Number of elements in bottom rectangle 5785 diff_row = n_row_all - n_row_trapezoid 5786 rectangle_size = max(0, diff_row * col) 5787 5788 return trapezoid_size, rectangle_size, m_first_row 5789 5790 5791def _trilu_checks( 5792 name: str, 5793 row: int, 5794 col: int, 5795 dtype: torch.dtype, 5796 layout: torch.layout, 5797 pin_memory: bool, 5798): 5799 torch._check(row >= 0, lambda: f"row must be non-negative, got {row}") 5800 torch._check(col >= 0, lambda: f"col must be non-negative, got {col}") 5801 torch._check( 5802 dtype in (torch.int32, torch.int64), 5803 lambda: f"\"{name}\" not implemented for '{dtype}'", 5804 ) 5805 5806 5807# This is based on tril_indices_cuda in aten/src/ATen/native/cuda/TensorFactories.cu 5808@register_decomposition(aten.tril_indices) 5809@out_wrapper() 5810def tril_indices( 5811 row: int, 5812 col: int, 5813 offset: int = 0, 5814 *, 5815 dtype: torch.dtype = torch.long, 5816 layout: torch.layout = torch.strided, 5817 device: DeviceLikeType = "cpu", 5818 pin_memory: bool = False, 5819) -> TensorLikeType: 5820 _trilu_checks("tril_indices", row, col, dtype, layout, pin_memory) 5821 5822 trapezoid_size, rectangle_size, m_first_row = _get_tril_sizes(row, col, offset) 5823 row_offset = max(0, -offset) 5824 5825 arange_kw = partial( 5826 torch.arange, layout=layout, device=device, pin_memory=pin_memory 5827 ) 5828 5829 # first we do the indices for top trapezoid 5830 xs1 = arange_kw(0, trapezoid_size, dtype=torch.float64) 5831 b = m_first_row - 0.5 5832 row_inds1 = torch.floor(-b + torch.sqrt(b * b + 2 * xs1)) 5833 col_inds1 = torch.floor(xs1 - (2 * m_first_row - 1 + row_inds1) * row_inds1 * 
0.5) 5834 row_inds1 = _maybe_convert_to_dtype(row_inds1 + row_offset, dtype) 5835 col_inds1 = _maybe_convert_to_dtype(col_inds1, dtype) 5836 5837 # then bottom rectangle 5838 xs2 = arange_kw(0, rectangle_size, dtype=dtype) 5839 row_inds2 = xs2 // col + (col - m_first_row + 1 + row_offset) 5840 col_inds2 = xs2 % col 5841 5842 return torch.stack( 5843 (torch.cat((row_inds1, row_inds2)), torch.cat((col_inds1, col_inds2))) 5844 ) 5845 5846 5847# Similar to _get_tril_sizes above, but here there is a top trapezoid and 5848# a bottom rectangle instead. Note that you can't reduce this to 5849# _get_tril_sizes(col, row, -offset) because that would correspond to 5850# decomposing into a left trapezoid and right rectangle. 5851def _get_triu_sizes(row: int, col: int, offset: int) -> Tuple[int, int, int]: 5852 if row == 0 or col == 0: 5853 return 0, 0, 0 5854 5855 m_first_row = max(0, col - offset) if offset > 0 else col 5856 5857 # Number of elements in top rectangle 5858 rectangle_size = max(0, min(row, -offset) * col) 5859 5860 # Number of elements in bottom trapezoid 5861 trapezoid_size_tril, rectangle_size_tril, _ = _get_tril_sizes(row, col, offset - 1) 5862 triu_size = row * col - (trapezoid_size_tril + rectangle_size_tril) 5863 trapezoid_size = triu_size - rectangle_size 5864 5865 return trapezoid_size, rectangle_size, m_first_row 5866 5867 5868@register_decomposition(aten.triu_indices) 5869@out_wrapper() 5870def triu_indices( 5871 row: int, 5872 col: int, 5873 offset: int = 0, 5874 *, 5875 dtype: torch.dtype = torch.long, 5876 layout: torch.layout = torch.strided, 5877 device: DeviceLikeType = "cpu", 5878 pin_memory: bool = False, 5879) -> TensorLikeType: 5880 _trilu_checks("triu_indices", row, col, dtype, layout, pin_memory) 5881 5882 trapezoid_size, rectangle_size, m_first_row = _get_triu_sizes(row, col, offset) 5883 col_offset = max(0, offset) 5884 5885 arange_kw = partial( 5886 torch.arange, layout=layout, device=device, pin_memory=pin_memory 5887 ) 5888 5889 # indices for top rectangle 5890 xs2 = arange_kw(0, rectangle_size, dtype=dtype) 5891 row_inds2 = xs2 // col 5892 col_inds2 = xs2 % col 5893 5894 # bottom trapezoid 5895 xs1 = arange_kw(0, trapezoid_size, dtype=torch.float64) 5896 b = -0.5 - m_first_row 5897 row_inds1 = torch.floor(-b - torch.sqrt(b * b - 2 * xs1)) 5898 col_inds1 = torch.floor(xs1 - ((2 * m_first_row - 1 - row_inds1) * row_inds1) * 0.5) 5899 row_inds1 = _maybe_convert_to_dtype(row_inds1, dtype) 5900 col_inds1 = _maybe_convert_to_dtype(col_inds1, dtype) 5901 5902 if col: 5903 row_inds1 = row_inds1 + (rectangle_size // col) 5904 col_inds1 = col_inds1 + col_offset 5905 5906 return torch.stack( 5907 (torch.cat((row_inds2, row_inds1)), torch.cat((col_inds2, col_inds1))) 5908 ) 5909 5910 5911@register_decomposition(aten.bucketize) 5912@out_wrapper(exact_dtype=True) 5913def bucketize( 5914 a: TensorOrNumberLikeType, 5915 boundaries: TensorLikeType, 5916 *, 5917 out_int32: bool = False, 5918 right: bool = False, 5919): 5920 torch._check( 5921 boundaries.dim() == 1, 5922 lambda: f"boundaries tensor must be 1 dimension but got dim({boundaries.dim()})", 5923 ) 5924 5925 a = a if isinstance(a, torch.Tensor) else torch.tensor(a) 5926 out_dtype = torch.int32 if out_int32 else torch.int64 5927 n_boundaries = boundaries.shape[-1] 5928 if n_boundaries == 0: 5929 return torch.zeros_like(a) 5930 # We are trying to find the bucket (defined by pairs of consecutive elements of `boundaries`) 5931 # each element of `a` belongs to. 
    # We use binary search to achieve logarithmic complexity,
    # but each step of the search is done "in parallel" over all elements of `a`.
    # We can't use int32 as indices, so we do all computations with int64 and convert at the end.
    start = torch.zeros(a.shape, device=a.device, dtype=torch.int64)
    end = start + n_boundaries
    # Max depth of the binary search.
    # Since we can't break out of the loop at different points for different elements of a,
    # we just do the max number of iterations that binary search requires and add a condition
    # tensor (cond_update below) to stop updating once the search terminates.

    # For the first iteration through the loop we can skip some checks, so we have a separate implementation.
    mid = start + (end - start) // 2
    mid_val = boundaries[mid]
    if right:
        cond_mid = mid_val > a
    else:
        cond_mid = mid_val >= a
    start = torch.where(cond_mid, start, mid + 1)

    if n_boundaries > 1:
        cond_update = torch.ones_like(a, dtype=torch.bool)
        niters = int(math.log2(n_boundaries))
        for _ in range(niters):
            end = torch.where(cond_mid & cond_update, mid, end)
            cond_update = start < end
            # start might end up pointing to 1 past the end, we guard against that
            mid = torch.where(cond_update, start + (end - start) // 2, 0)
            mid_val = boundaries[mid]
            # If right is true, the buckets are closed on the *left*
            # (i.e., we are doing the equivalent of std::upper_bound in C++)
            # Otherwise they are closed on the right (std::lower_bound)
            if right:
                cond_mid = mid_val > a
            else:
                cond_mid = mid_val >= a
            start = torch.where((~cond_mid) & cond_update, mid + 1, start)

    return start.to(dtype=out_dtype)


@register_decomposition(aten.cauchy)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def cauchy(self, median=0, sigma=1, generator=None):
    assert generator is None
    torch._check(
        not utils.is_complex_dtype(self.dtype)
        and not utils.is_integer_dtype(self.dtype)
        and not utils.is_boolean_dtype(self.dtype),
        lambda: f"Cauchy distribution is a continuous probability distribution. \
        dtype must be a floating point but you specified {self.dtype}",
    )
    torch._check(
        sigma > 0.0,
        lambda: f"cauchy_ expects sigma > 0.0, but found sigma={sigma}",
    )
    return median + sigma * torch.tan(math.pi * (torch.rand_like(self) - 0.5))


@register_decomposition(aten.exponential)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def exponential(self, rate=1, generator=None):
    assert generator is None
    torch._check(
        not utils.is_complex_dtype(self.dtype)
        and not utils.is_integer_dtype(self.dtype)
        and not utils.is_boolean_dtype(self.dtype),
        lambda: f"Exponential distribution is a continuous probability distribution. \
        dtype must be a floating point but you specified {self.dtype}",
    )
    torch._check(
        rate > 0.0,
        lambda: f"exponential_ expects lambda > 0.0, but found lambda={rate}",
    )

    uniform_val = torch.rand_like(self)

    # copying numerics of transformation::exponential, see the comment there:
    # curand_uniform has (0,1] bounds. log(1) is 0 and exponential excludes 0.
6017 # we need log to be not 0, and not underflow when converted to half 6018 # fast __logf approximation can underflow, so set log to -epsilon/2 for 1 or close to 1 args 6019 epsilon = torch.finfo(uniform_val.dtype).eps / 2 6020 condition = uniform_val >= 1.0 - epsilon 6021 log_uniform = torch.where(condition, -epsilon, torch.log(uniform_val)) 6022 6023 return -1 / rate * log_uniform 6024 6025 6026@register_decomposition(aten.geometric) 6027@out_wrapper() 6028@elementwise_type_promotion_wrapper( 6029 type_promoting_args=("self",), 6030 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, 6031) 6032def geometric(self, p, generator=None): 6033 assert generator is None 6034 # TODO: fix inductor rand_like for integer, bool dtypes 6035 torch._check( 6036 not utils.is_complex_dtype(self.dtype) 6037 and not utils.is_boolean_dtype(self.dtype), 6038 lambda: f"geometric not implemented for {self.dtype}", 6039 ) 6040 torch._check( 6041 0 < p and p < 1, 6042 lambda: f"geometric_ expects p to be in (0, 1), but got p={p}", 6043 ) 6044 return torch.floor(torch.log1p(-torch.rand_like(self)) / math.log1p(-p)) + 1 6045 6046 6047@register_decomposition(aten.log_normal) 6048@out_wrapper() 6049@elementwise_type_promotion_wrapper( 6050 type_promoting_args=("self",), 6051 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, 6052) 6053def log_normal(self, mean=1, std=2, generator=None): 6054 assert generator is None 6055 torch._check( 6056 not utils.is_complex_dtype(self.dtype) 6057 and not utils.is_integer_dtype(self.dtype) 6058 and not utils.is_boolean_dtype(self.dtype), 6059 lambda: f"log_normal not implemented for {self.dtype}", 6060 ) 6061 torch._check( 6062 0 < std, 6063 lambda: f"log_normal_ expects std > 0.0, but found std={std}", 6064 ) 6065 return torch.exp(std * torch.randn_like(self) + mean) 6066 6067 6068# TODO: add support for functionalization aten.normal_functional 6069# NOTE: the device and dtype will be ignored when shape is None 6070@register_decomposition(aten.normal) 6071@out_wrapper() 6072@elementwise_type_promotion_wrapper( 6073 type_promoting_args=( 6074 "mean", 6075 "std", 6076 ), 6077 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, 6078) 6079def normal( 6080 mean=0, 6081 std=1, 6082 size=None, 6083 *, 6084 generator=None, 6085 dtype=None, 6086 layout=None, 6087 device=None, 6088 pin_memory=None, 6089): 6090 assert layout is None or layout == torch.strided 6091 6092 if not isinstance(std, TensorLike): 6093 torch._check( 6094 std >= 0, lambda: f"normal expects std >= 0.0, but found std {std}" 6095 ) 6096 6097 if size is None: 6098 tensors = tuple(t for t in (mean, std) if isinstance(t, TensorLike)) 6099 torch._check( 6100 len(tensors) > 0, 6101 lambda: "normal expects that either mean or std is a tensor, or size is defined", 6102 ) 6103 torch._check( 6104 layout is None and pin_memory is None, 6105 lambda: "Cannot pass layout, or pin_memory without size", 6106 ) 6107 6108 size = _broadcast_shapes(*(t.shape for t in tensors)) 6109 dtype = tensors[0].dtype 6110 device = tensors[0].device 6111 else: 6112 torch._check( 6113 not isinstance(mean, TensorLike) and not isinstance(std, TensorLike), 6114 lambda: "normal expects mean and std to be scalars when size is defined", 6115 ) 6116 dtype = torch.get_default_dtype() if dtype is None else dtype 6117 device = torch.device("cpu") if device is None else device 6118 6119 normal_samples = prims.normal( 6120 size, 6121 mean=0.0, 6122 std=1.0, 6123 dtype=dtype, 6124 device=device, 6125 requires_grad=False, 6126 
        generator=generator,
    )
    return std * normal_samples + mean


@register_decomposition(aten.normal_)
def normal_(self, mean=0, std=1, *, generator=None):
    return normal(mean, std, self.shape, out=self, generator=generator)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def rad2deg(self: TensorLikeType):
    torch._check(
        not utils.is_complex_dtype(self.dtype),
        lambda: "rad2deg is not supported for complex tensors.",
    )
    M_180_PI = 57.295779513082320876798154814105170332405472466564
    return self * M_180_PI


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def deg2rad(self: TensorLikeType):
    torch._check(
        not utils.is_complex_dtype(self.dtype),
        lambda: "deg2rad is not supported for complex tensors.",
    )
    M_PI_180 = 0.017453292519943295769236907684886127134428718885417
    return self * M_PI_180


@register_decomposition(aten.count_nonzero)
@out_wrapper()
def count_nonzero(self, dim: Optional[DimsType] = None):
    return (self != 0).sum(dim)


def _dot_check(self, other):
    torch._check(
        self.dim() == 1 and other.dim() == 1,
        lambda: f"1D tensors expected, but got {self.dim()}D and {other.dim()}D tensors",
    )

    def numel_error():
        return (
            f"inconsistent tensor size, expected tensor [{self.numel()}] and src [{other.numel()}] to have the "
            f"same number of elements, but got {self.numel()} and {other.numel()} elements respectively"
        )

    torch._check(self.numel() == other.numel(), numel_error)


@register_decomposition(aten.dot)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self", "other"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def dot(self, other):
    if self.is_complex():
        if self.is_conj():
            if other.is_conj():
                return torch.dot(self.conj(), other.conj()).conj()
            else:
                return torch.vdot(self.conj(), other)
        elif other.is_conj():
            return torch.vdot(other.conj(), self)

    _dot_check(self, other)
    return (self * other).sum()


@register_decomposition(aten.vdot)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self", "other"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def vdot(self, other):
    if not self.is_complex():
        return torch.dot(self, other)

    if self.is_conj():
        if other.is_conj():
            return torch.vdot(other.conj(), self.conj())
        else:
            return torch.dot(self.conj(), other)
    elif other.is_conj():
        return torch.dot(self, other.conj()).conj()

    _dot_check(self, other)
    # The decomposition fails if you do self.conj()...
not sure why 6217 return (self.conj_physical() * other).sum() 6218 6219 6220@register_decomposition(aten.select_scatter) 6221@out_wrapper() 6222def select_scatter(x: TensorLikeType, src: TensorLikeType, dim: int, index: int): 6223 dim = utils.canonicalize_dim(x.ndim, dim) 6224 mask_shape = [1] * x.ndim 6225 mask_shape[dim] = -1 6226 if index < 0: 6227 index = index + x.shape[dim] 6228 mask = torch.arange(x.shape[dim], device=x.device).view(mask_shape) == index 6229 src = torch.unsqueeze(src, dim).expand(x.shape) 6230 return torch.where(mask, src, x) 6231 6232 6233# inplace 6234abs_ = _make_inplace(abs) 6235acos_ = _make_inplace(acos) 6236acosh_ = _make_inplace(acosh) 6237add_ = _make_inplace(add) 6238addcmul_ = _make_inplace(addcmul) 6239addcdiv_ = _make_inplace(addcdiv) 6240asin_ = _make_inplace(asin) 6241asinh_ = _make_inplace(asinh) 6242atan_ = _make_inplace(atan) 6243atanh_ = _make_inplace(atanh) 6244atan2_ = _make_inplace(atan2) 6245bitwise_and_ = _make_inplace(bitwise_and) 6246bitwise_left_shift_ = _make_inplace(bitwise_left_shift) 6247bitwise_not_ = _make_inplace(bitwise_not) 6248bitwise_or_ = _make_inplace(bitwise_or) 6249bitwise_right_shift_ = _make_inplace(bitwise_right_shift) 6250bitwise_xor_ = _make_inplace(bitwise_xor) 6251ceil_ = _make_inplace(ceil) 6252clamp_ = _make_inplace(clamp) 6253clamp_min_ = _make_inplace(clamp_min) 6254clamp_max_ = _make_inplace(clamp_max) 6255conj_physical_ = _make_inplace(conj_physical) 6256copysign_ = _make_inplace(copysign) 6257cos_ = _make_inplace(cos) 6258cosh_ = _make_inplace(cosh) 6259cumsum_ = _make_inplace(cumsum) 6260cumprod_ = _make_inplace(cumprod) 6261deg2rad_ = _make_inplace(deg2rad) 6262digamma_ = _make_inplace(digamma) 6263div_ = _make_inplace(div) 6264eq_ = _make_inplace(eq) 6265erf_ = _make_inplace(erf) 6266erfc_ = _make_inplace(erfc) 6267erfinv_ = _make_inplace(erfinv) 6268exp_ = _make_inplace(exp) 6269exp2_ = _make_inplace(exp2) 6270expm1_ = _make_inplace(expm1) 6271float_power_ = _make_inplace(float_power) 6272floor_ = _make_inplace(floor) 6273floor_divide_ = _make_inplace(floor_divide) 6274fmod_ = _make_inplace(fmod) 6275frac_ = _make_inplace(frac) 6276gcd_ = _make_inplace(gcd) 6277ge_ = _make_inplace(ge) 6278gt_ = _make_inplace(gt) 6279heaviside_ = _make_inplace(heaviside) 6280hypot_ = _make_inplace(hypot) 6281igamma_ = _make_inplace(igamma) 6282igammac_ = _make_inplace(igammac) 6283i0_ = _make_inplace(i0) 6284lcm_ = _make_inplace(lcm) 6285le_ = _make_inplace(le) 6286lerp_ = _make_inplace(lerp) 6287lgamma_ = _make_inplace(lgamma) 6288log10_ = _make_inplace(log10) 6289log1p_ = _make_inplace(log1p) 6290log2_ = _make_inplace(log2) 6291log_ = _make_inplace(log) 6292logical_and_ = _make_inplace(logical_and) 6293logical_not_ = _make_inplace(logical_not) 6294logical_or_ = _make_inplace(logical_or) 6295logical_xor_ = _make_inplace(logical_xor) 6296lt_ = _make_inplace(lt) 6297mul_ = _make_inplace(mul) 6298mvlgamma_ = _make_inplace(mvlgamma) 6299nan_to_num_ = _make_inplace(nan_to_num) 6300ne_ = _make_inplace(ne) 6301neg_ = _make_inplace(neg) 6302nextafter_ = _make_inplace(nextafter) 6303pow_ = _make_inplace(pow) 6304rad2deg_ = _make_inplace(rad2deg) 6305reciprocal_ = _make_inplace(reciprocal) 6306remainder_ = _make_inplace(remainder) 6307rsqrt_ = _make_inplace(rsqrt) 6308sgn_ = _make_inplace(sgn) 6309sigmoid_ = _make_inplace(sigmoid) 6310sign_ = _make_inplace(sign) 6311sin_ = _make_inplace(sin) 6312sinc_ = _make_inplace(sinc) 6313sinh_ = _make_inplace(sinh) 6314sqrt_ = _make_inplace(sqrt) 6315square_ = _make_inplace(square) 6316sub_ = 
_make_inplace(sub) 6317tan_ = _make_inplace(tan) 6318tanh_ = _make_inplace(tanh) 6319tril_ = _make_inplace(tril) 6320triu_ = _make_inplace(triu) 6321true_divide_ = _make_inplace(true_divide) 6322trunc_ = _make_inplace(trunc) 6323xlogy_ = _make_inplace(xlogy) 6324cauchy_ = _make_inplace(cauchy) 6325exponential_ = _make_inplace(exponential) 6326geometric_ = _make_inplace(geometric) 6327log_normal_ = _make_inplace(log_normal) 6328zero_ = _make_inplace(zero) 6329 6330alias_copy = _make_copy_from_view(aten.alias) 6331as_strided_copy = _make_copy_from_view(aten.as_strided) 6332diagonal_copy = _make_copy_from_view(aten.diagonal) 6333expand_copy = _make_copy_from_view(aten.expand) 6334# TODO: This must return a sparse tensor if the input is sparse, but refs have 6335# no sparse support. See narrow_copy_sparse in core. 6336narrow_copy = _make_copy_from_view(aten.narrow) 6337t_copy = _make_copy_from_view(aten.t) 6338unsqueeze_copy = _make_copy_from_view(aten.unsqueeze) 6339view_copy = _make_copy_from_view(aten.view) 6340 6341 6342# xref: isStorage in torch/csrc/DynamicTypes.cpp 6343def _isStorage(obj): 6344 return isinstance(obj, (torch.TypedStorage, torch.UntypedStorage)) 6345 6346 6347# xref: compute_sizes in torch/csrc/utils/tensor_new.cpp 6348def _compute_sizes(seq, scalar_type): 6349 MAX_DIMS = 128 6350 is_storage = _isStorage(seq) 6351 sizes = [] 6352 # TODO: this is inaccurate, we actually test PySequence_Check 6353 while isinstance(seq, (list, tuple)): 6354 length = len(seq) 6355 if is_storage: 6356 length //= scalar_type.itemsize 6357 sizes.append(length) 6358 if len(sizes) > MAX_DIMS: 6359 raise ValueError(f"too many dimensions '{type(seq).__name__}'") 6360 if length == 0: 6361 break 6362 try: 6363 handle = seq[0] 6364 except Exception: 6365 raise ValueError( # noqa: B904 6366 f"could not determine the shape of object type '{type(seq).__name__}'" 6367 ) 6368 seq = handle 6369 6370 return sizes 6371 6372 6373# xref: infer_scalar_type in torch/csrc/utils/tensor_new.cpp 6374def _infer_scalar_type(obj): 6375 if isinstance(obj, FloatLike): 6376 return torch.get_default_dtype() 6377 if isinstance(obj, IntLike) and not isinstance(obj, bool): # careful! 6378 return torch.int64 6379 if isinstance(obj, BoolLike): 6380 return torch.bool 6381 if isinstance(obj, complex): 6382 default_dtype = torch.get_default_dtype() 6383 if default_dtype is torch.float: 6384 return torch.cfloat 6385 elif default_dtype is torch.double: 6386 return torch.cdouble 6387 elif default_dtype is torch.half: 6388 return torch.chalf 6389 else: 6390 raise RuntimeError("invalid default scalar type for complex") 6391 if isinstance(obj, torch.Tensor): 6392 return obj.dtype 6393 if isinstance(obj, str): 6394 raise TypeError(f"new(): invalid data type '{type(obj).__name__}'") 6395 # TODO: this is inaccurate, we actually test PySequence_Check 6396 if isinstance(obj, (list, tuple)): 6397 scalarType = None 6398 length = len(obj) 6399 # match NumPy semantics, except use default tensor type instead of 6400 # double. 6401 if length == 0: 6402 return torch.get_default_dtype() 6403 for i in range(length): 6404 cur_item = obj[i] 6405 # TODO: test this 6406 """ 6407 if cur_item is obj: 6408 raise TypeError("new(): self-referential lists are incompatible") 6409 """ 6410 item_scalarType = _infer_scalar_type(cur_item) # recurse! 
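        # Hedged example of the promotion below (hypothetical inputs, comment only):
        # [1, 2.0] promotes torch.int64 with the default float dtype to float32,
        # and [1, 1j] promotes to complex64 (assuming the default dtype is float32),
        # mirroring torch.tensor's dtype inference.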
6411 if scalarType is not None: 6412 scalarType = torch.promote_types(scalarType, item_scalarType) 6413 else: 6414 scalarType = item_scalarType 6415 if scalarType is torch.cdouble: 6416 # this won't change (unless we hit undefined, but that will 6417 # fail later) 6418 return scalarType 6419 return scalarType 6420 raise RuntimeError(f"Could not infer dtype of {type(obj).__name__}") 6421 6422 6423# Analogous to recursive_store 6424# xref: recursive_store in torch/csrc/utils/tensor_new.cpp 6425def _recursive_build( 6426 scalarType: torch.dtype, obj: Union[TensorOrNumberLikeType, TensorSequenceType] 6427): 6428 if isinstance(obj, Tensor) and obj.numel() == 1: 6429 return obj.detach().to(dtype=scalarType, device="cpu", copy=True).view(()) 6430 elif isinstance(obj, Tensor): 6431 # It is invalid to call ".tensor([...])" with a non-scalar tensor in eager mode 6432 # >>> torch.tensor([torch.randn(2)]) 6433 # ValueError: only one element tensors can be converted to Python scalars 6434 # 6435 # But it is possible with a NumPy array 6436 # >>> torch.tensor([np.random.uniform(size=(2,))]).shape 6437 # torch.Size([1, 2]) 6438 return obj.detach().to(dtype=scalarType, device="cpu", copy=True) 6439 elif isinstance(obj, Number): 6440 return torch.scalar_tensor(obj, dtype=scalarType) 6441 6442 # seq can be a list of tensors 6443 seq = obj 6444 return torch.stack([_recursive_build(scalarType, item) for item in seq]) 6445 6446 6447# xref: internal_new_from_data in torch/csrc/utils/tensor_new.cpp 6448def _internal_new_from_data( 6449 options, 6450 scalar_type, 6451 device_opt, 6452 data, 6453 copy_variables, 6454 copy_numpy, 6455 type_inference, 6456 pin_memory=False, 6457): 6458 if isinstance(data, torch.Tensor): 6459 torch._check( 6460 not pin_memory, lambda: "Can't pin tensor constructed from a variable" 6461 ) 6462 var = data 6463 if copy_variables: 6464 var = var.detach() 6465 inferred_scalar_type = var.dtype if type_inference else scalar_type 6466 device = device_opt if device_opt is not None else var.device 6467 return var.to( 6468 device=device, 6469 dtype=inferred_scalar_type, 6470 non_blocking=False, 6471 copy=copy_variables, 6472 ) 6473 6474 # TODO 6475 if hasattr(data, "__cuda_array_interface__"): 6476 return NotImplemented 6477 6478 # TODO: test for numpy input with PyArray_Check 6479 6480 device = device_opt if device_opt is not None else options["device"] 6481 inferred_scalar_type = _infer_scalar_type(data) if type_inference else scalar_type 6482 6483 # NB: Don't need to avoid tracing, as we aren't going to do any manual 6484 # pointer filling tricks 6485 if _isStorage(data): 6486 return NotImplemented 6487 else: 6488 if torch.device(device).type == "meta": 6489 return NotImplemented 6490 6491 # In the C implementation, we would directly start poking the memory 6492 # of a freshly allocated CPU tensor. 
        # Here, we're going to do an alternate, heinously slow implementation:
        # turn each individual scalar into a tensor, and then repeatedly cat them together
        tensor = _recursive_build(inferred_scalar_type, data)

        tensor = tensor.to(device, inferred_scalar_type, non_blocking=False, copy=False)

    # NB: lift_fresh is not needed, because we built the tensor from scalars,
    # guaranteeing a fresh tensor in this case
    return tensor


# xref: tensor_ctor in torch/csrc/utils/tensor_new.cpp
def tensor(data, *, dtype=None, device=None, pin_memory=False, requires_grad=False):
    # TODO (or not): support names kwarg
    if isinstance(data, torch.Tensor):
        warnings.warn(
            "To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() "
            "or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor)"
        )
    type_inference = dtype is None
    new_tensor = _internal_new_from_data(
        # device="cpu" because that's what you get from torch.tensor(2) when no
        # device is passed
        {"device": "cpu"},  # TODO: use torch.get_default_tensor_type
        dtype if dtype is not None else torch.get_default_dtype(),
        device,
        data,
        copy_variables=True,
        copy_numpy=True,
        type_inference=type_inference,
        pin_memory=pin_memory,
    )
    new_tensor.detach_()
    if requires_grad:
        new_tensor.requires_grad_(requires_grad)
    return new_tensor


# Views
# We can't model these as above, as the pattern of doing `op(a, out=a)` does not
# work for a view function, given that it does not reshape the input (it just
# copies the result into it)

# squeeze_ = _make_inplace(squeeze)
# t_ = _make_inplace(t)
# transpose_ = _make_inplace(transpose)
# unsqueeze_ = _make_inplace(unsqueeze)


import torch._refs._conversions
import torch._refs.fft
import torch._refs.linalg
import torch._refs.nn.functional
import torch._refs.special
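

# A minimal sanity-check sketch (kept as a comment so importing this module stays
# side-effect free; it assumes only that these refs mirror their eager
# counterparts). Pasting the following into a Python session spot-checks a few
# refs against eager PyTorch:
#
#   import torch
#   import torch._refs as refs
#
#   x = torch.randn(3, 4)
#   assert torch.equal(refs.triu(x), torch.triu(x))
#   assert torch.equal(refs.tril(x, diagonal=1), torch.tril(x, diagonal=1))
#   assert refs.allclose(refs.movedim(x, 0, 1), torch.movedim(x, 0, 1))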