1# mypy: allow-untyped-defs 2from __future__ import annotations 3 4import operator 5import warnings 6from contextlib import nullcontext 7from enum import Enum 8from functools import reduce 9from typing import ( 10 Any, 11 Callable, 12 cast, 13 List, 14 NamedTuple, 15 Optional, 16 overload, 17 Sequence, 18 Tuple, 19 Type, 20 TYPE_CHECKING, 21 Union, 22) 23from typing_extensions import deprecated, TypeAlias 24 25 26if TYPE_CHECKING: 27 # Import the following modules during type checking to enable code intelligence features, 28 # such as auto-completion in tools like pylance, even when these modules are not explicitly 29 # imported in user code. 30 31 import sympy 32 33import torch 34from torch import sym_float, sym_int, sym_max 35 36 37ShapeType: TypeAlias = Union[torch.Size, List[int], Tuple[int, ...]] 38StrideType: TypeAlias = Union[List[int], Tuple[int, ...]] 39DimsType: TypeAlias = Union[int, List[int], Tuple[int, ...]] 40DimsSequenceType: TypeAlias = Union[List[int], Tuple[int, ...]] 41# TODO: Type[torch.SymInt], Type[torch.SymFloat] 42NumberTypeType: TypeAlias = Union[Type[bool], Type[int], Type[float], Type[complex]] 43# TODO: This needs a lot more type annotations 44# NumberType = Union[bool, int, float, complex, torch.SymInt, torch.SymFloat] 45NumberType: TypeAlias = Union[bool, int, float, complex] 46RealNumberType: TypeAlias = Union[bool, int, float] 47 48Number = (bool, int, float, complex, torch.SymInt, torch.SymFloat, torch.SymBool) 49# I don't call it Integral because numbers.Integral includes bool, but IntLike 50# does not 51Dim = int 52IntLike = (int, torch.SymInt) 53FloatLike = (float, torch.SymFloat) 54BoolLike = (bool, torch.SymBool) 55IntWithoutSymInt = int 56FloatWithoutSymFloat = float 57DeviceLikeType: TypeAlias = Union[str, torch.device, int] 58Tensor = torch.Tensor 59 60 61torch_function_passthrough = { 62 torch.device, 63 torch.sym_not, 64 torch.sym_float, 65 torch.sym_int, 66 torch.sym_max, 67 torch.sym_min, 68 torch._sym_sqrt, # type: ignore[attr-defined] 69 torch.sym_ite, 70 torch.Tensor.dim, 71 torch.Tensor.ndim.__get__, # type: ignore[attr-defined] 72 torch.Tensor.numel, 73 torch.Tensor.size, 74 torch.Tensor.storage_offset, 75 torch.Tensor.stride, 76 torch.Tensor.dtype.__get__, # type: ignore[attr-defined] 77 torch.Tensor.is_sparse.__get__, # type: ignore[attr-defined] 78 torch.Tensor.shape.__get__, # type: ignore[attr-defined] 79 torch.Tensor.device.__get__, # type: ignore[attr-defined] 80 torch.Tensor.requires_grad.__get__, # type: ignore[attr-defined] 81 torch.Tensor.layout.__get__, # type: ignore[attr-defined] 82 torch.Tensor.is_contiguous, 83 # For TorchRefsMode only 84 torch.Tensor.__format__, 85 torch.Tensor.__repr__, 86 torch.Tensor.requires_grad.__get__, # type: ignore[attr-defined] 87 torch.Tensor.__getitem__, 88} 89 90 91TensorLikeType = torch.Tensor 92TensorLike = torch.Tensor 93TensorSequenceType: TypeAlias = Union[List[TensorLikeType], Tuple[TensorLikeType, ...]] 94TensorOrNumberLikeType: TypeAlias = Union[TensorLikeType, NumberType] 95 96CustomOutParamAnnotation = "__custom_out_param__" 97 98 99def same_shape(a: ShapeType, b: ShapeType, *, allow_rhs_unbacked=False) -> bool: 100 from torch.fx.experimental.symbolic_shapes import guard_size_oblivious 101 102 if len(a) != len(b): 103 return False 104 105 for x, y in zip(a, b): 106 if allow_rhs_unbacked: 107 # TODO: We should check that the symbols are consistent 108 # with each other 109 if isinstance(y, torch.SymInt): 110 continue 111 # NB: Naively, you would not expect to have to do an oblivious guard 112 # here because there is seemingly no broadcasting here, but in fact we 113 # use this in some situations to determine if we need to do an expand 114 # on the tensor because they don't line up, so you can definitely end 115 # up trying to prove u0 != 1 in this situation. See 116 # python test/test_proxy_tensor.py -k test_cumsum_unbacked 117 if guard_size_oblivious(x != y): 118 return False 119 120 return True 121 122 123def _maybe_get_pytype(t): 124 if t is torch.SymFloat: 125 return float 126 elif t is torch.SymInt: 127 return int 128 elif t is torch.SymBool: 129 return bool 130 else: 131 return t 132 133 134# TODO: look at using torch.testing.assert_close instead with an option 135# to just compare metadata 136def compare_tensor_meta( 137 a: TensorLikeType, 138 b: TensorLikeType, 139 check_strides=False, 140 *, 141 allow_rhs_unbacked=False, 142 check_conj=True, 143): 144 """ 145 Checks that two tensor likes have the same shape, 146 dtype and device. 147 148 In the future this will validate additional metadata, like 149 strides. 150 """ 151 assert isinstance(a, TensorLike) 152 assert isinstance(b, TensorLike) 153 154 if not same_shape(a.shape, b.shape, allow_rhs_unbacked=allow_rhs_unbacked): 155 msg = f"Shapes {a.shape} and {b.shape} are not equal!" 156 raise AssertionError(msg) 157 158 if a.dtype != b.dtype: 159 msg = f"Dtypes {a.dtype} and {b.dtype} are not equal!" 160 raise AssertionError(msg) 161 162 if a.device != b.device: 163 # Handles special cuda:0 vs cuda case 164 # TODO: we should review why this happens and see about fixing it 165 if (str(a.device) == "cuda:0" or str(a.device) == "cuda") and ( 166 str(b.device) == "cuda:0" or str(b.device) == "cuda" 167 ): 168 pass 169 else: 170 msg = f"Devices {a.device} and {b.device} are not equal!" 171 raise AssertionError(msg) 172 173 # Stride checking is currently disabled, see https://github.com/pytorch/pytorch/issues/78050 174 if check_strides: 175 same_strides, idx = check_significant_strides(a, b) 176 if not same_strides: 177 msg = f"Stride mismatch! Strides are {a.stride()} and {b.stride()} (mismatched at {idx})!" 178 raise RuntimeError(msg) 179 180 if a.storage_offset() != b.storage_offset(): 181 msg = f"Storage offset mismatch! Storage offsets are {a.storage_offset()} and {b.storage_offset()}!" 182 raise RuntimeError(msg) 183 184 if check_conj: 185 if a.is_conj() != b.is_conj(): 186 raise RuntimeError( 187 f"Conj mismatch! is_conj is set to {a.is_conj()} and {b.is_conj()}" 188 ) 189 190 if a.is_neg() != b.is_neg(): 191 raise RuntimeError( 192 f"Neg mismatch! is_neg is set to {a.is_neg()} and {b.is_neg()}" 193 ) 194 195 196def _check_strides_helper( 197 a: TensorLikeType, b: TensorLikeType, *, only_cuda=True, significant_only=True 198) -> Tuple[bool, Optional[int]]: 199 # NOTE: only on CUDA because CPU elementwise strides are incorrect in PyTorch 200 # See https://github.com/pytorch/pytorch/issues/77553 201 # Only compares strides that are "meaningful" -- strides for dimensions with length > 1 202 # and for tensors with more than one element 203 if ( 204 not only_cuda or a.device.type == "cuda" or b.device.type == "cuda" 205 ) and a.numel() > 0: 206 for idx in range(a.ndim): 207 check = not significant_only or a.shape[idx] > 1 208 if a.stride()[idx] != b.stride()[idx] and check: 209 return False, idx 210 211 return True, None 212 213 214def check_significant_strides( 215 a: TensorLikeType, b: TensorLikeType, *, only_cuda=True 216) -> Tuple[bool, Optional[int]]: 217 return _check_strides_helper(a, b, only_cuda=only_cuda, significant_only=True) 218 219 220def check_all_strides( 221 a: TensorLikeType, b: TensorLikeType, *, only_cuda=True 222) -> Tuple[bool, Optional[int]]: 223 return _check_strides_helper(a, b, only_cuda=only_cuda, significant_only=False) 224 225 226# This function is equivalent to compute_contiguous() from TensorImpl.cpp 227def is_contiguous(a: TensorLikeType) -> bool: 228 """ 229 Tests whether a tensor is contiguous or not. 230 231 Tensors are contiguous when they have no elements, 232 one element, or when they have "nested" strides. 233 """ 234 from torch.fx.experimental.symbolic_shapes import guard_size_oblivious 235 236 if guard_size_oblivious(a.numel() < 2): 237 return True 238 239 expected_stride = 1 240 for x, y in reversed(tuple(zip(a.shape, a.stride()))): 241 # Skips checking strides when a dimension has length 1 242 if guard_size_oblivious(x == 1): 243 continue 244 245 if guard_size_oblivious(y != expected_stride): 246 return False 247 expected_stride = expected_stride * x 248 249 return True 250 251 252# This function is equivalent to compute_channels_last_contiguous_2d() in TensorImpl.cpp 253def is_channels_last_contiguous_2d(a: Tensor) -> bool: 254 # NHWC or not channels last 2D contiguous 255 if a.ndim != 4: 256 return False 257 258 from torch.fx.experimental.symbolic_shapes import guard_size_oblivious 259 260 expected_stride = 1 261 for idx in (1, 3, 2, 0): 262 length = a.shape[idx] 263 if guard_size_oblivious(length == 1): 264 continue 265 266 stride = a.stride()[idx] 267 if guard_size_oblivious(stride != expected_stride): 268 return False 269 270 expected_stride *= length 271 272 return True 273 274 275def is_channels_last_contiguous_3d(a: Tensor) -> bool: 276 # NDHWC or not channels last 3D contiguous 277 if a.ndim != 5: 278 return False 279 280 expected_stride = 1 281 for idx in (1, 4, 3, 2, 0): 282 length = a.shape[idx] 283 if length == 1: 284 continue 285 286 stride = a.stride()[idx] 287 if stride != expected_stride: 288 return False 289 290 expected_stride *= length 291 292 return True 293 294 295_memory_formats = { 296 torch.contiguous_format, 297 torch.preserve_format, 298 torch.channels_last, 299 torch.channels_last_3d, 300} 301 302 303def validate_memory_format(memory_format: torch.memory_format): 304 torch._check( 305 memory_format in _memory_formats, 306 lambda: f"Received unknown memory format {memory_format}!", 307 ) 308 309 310def is_contiguous_for_memory_format( # type: ignore[return] 311 a: Tensor, *, memory_format: torch.memory_format 312) -> bool: 313 validate_memory_format(memory_format) 314 315 if memory_format == torch.contiguous_format: 316 return is_contiguous(a) 317 if memory_format == torch.channels_last: 318 return is_channels_last_contiguous_2d(a) 319 if memory_format == torch.channels_last_3d: 320 return is_channels_last_contiguous_3d(a) 321 322 torch._check( 323 False, 324 lambda: f"is_contiguous received unsupported memory format {memory_format}", 325 ) 326 327 328# NOTE: that tensors with no elements and channels last is ??? 329def is_channels_last_contiguous(a: Tensor) -> bool: 330 """ 331 True when a tensor is channels-last contiguous. 332 333 This requires that: 334 335 - the tensor is conceptually either 4 (NHWC) or 5 (NDHWC) dimensions 336 - if we name the tensor's dimensions NCHW or NCDHW, then the strides are such that the 337 stride of the 'C' dimension (Cs) is 1 and the strides corresponding to 338 each dimension (Xs) can be ordered Cs <= Ws <= Hs <= (Ds) <= Ns and are 339 "nested" -- so Ws = Cs * Cl, where Cl is the length of the 'C' dimension, 340 for example. 341 """ 342 return is_channels_last_contiguous_2d(a) or is_channels_last_contiguous_3d(a) 343 344 345def is_non_overlapping_and_dense(a: Tensor) -> bool: 346 """ 347 True when a tensor is non-overlapping and dense. 348 349 A tensor is non-overlapping and dense when there exists a permutation of 350 its dimensions that is contiguous. 351 """ 352 353 from torch.fx.experimental.symbolic_shapes import guard_size_oblivious 354 355 if a.is_sparse: 356 return False 357 358 # Short-circuits if the tensor is already contiguous or channels-last contiguous 359 if is_contiguous(a) or is_channels_last_contiguous(a): 360 return True 361 362 # The following is equivalent to compute_non_overlapping_and_dense in TensorImpl.cpp 363 364 # Short-circuits for tensors of rank one, which are 365 # non-overlapping and "dense" if their stride is one 366 if a.ndim == 1: 367 return a.stride()[0] == 1 368 369 # Checks that there exists a permutation of the strides s.t. the tensor would be contiguous 370 # Sorts (length, stride) pairs by stride 371 # 372 # This sort is done in a size-oblivious way, which helps if we do a 373 # comparison like 2048*u0 > u0; we just want this to return True 374 # (and not worry about what if u0 is zero). 375 class K(NamedTuple): 376 size: int 377 stride: int 378 379 def __lt__(self, other): 380 return guard_size_oblivious(self.stride < other.stride) 381 382 def __gt__(self, other): 383 return guard_size_oblivious(self.stride > other.stride) 384 385 def __le__(self, other): 386 return guard_size_oblivious(self.stride <= other.stride) 387 388 def __ge__(self, other): 389 return guard_size_oblivious(self.stride >= other.stride) 390 391 def __eq__(self, other): 392 return guard_size_oblivious(self.stride == other.stride) 393 394 lengths_and_strides = sorted(map(K, a.shape, a.stride())) 395 396 expected_stride = 1 397 for length, stride in lengths_and_strides: 398 if guard_size_oblivious(length == 1): 399 continue 400 401 if stride != expected_stride: 402 return False 403 404 expected_stride *= length 405 406 return True 407 408 409# NOTE: Based on the implementation in TensorIterator.cpp, but note that 410# the note [Computing output strides] is incorrect, because it 411# says that strides will be preserved even if they are not 412# "non overlapping and dense", but this is incorrect. The 413# output of elementwise operations are always given 414# non overlapping and dense strides. 415# This is also INCORRECT because it does not model TensorIterator's 416# short-circuit, which can cause different strides. 417def compute_elementwise_output_logical_to_physical_perm( 418 *tensors, _skip_checks=False 419) -> List[int]: 420 from torch.fx.experimental.symbolic_shapes import guard_size_oblivious 421 422 if not _skip_checks and len(tensors) == 0: 423 msg = "Can't compute elementwise output strides for zero tensors!" 424 raise ValueError(msg) 425 426 if not _skip_checks: 427 check_same_shape(*tensors, allow_cpu_scalar_tensors=True) 428 429 # Filters the tensors to actual tensors 430 if not _skip_checks: 431 tensors = tuple( 432 a 433 for a in tensors 434 if isinstance(a, TensorLike) and not is_cpu_scalar_tensor(a) 435 ) 436 437 # Short-circuits for CPU scalar case 438 if len(tensors) == 0: 439 return [] 440 441 # Short-circuits for shapes with zero or one dimensions 442 # TODO: are these necessary? 443 ndim = tensors[0].ndim 444 if ndim == 0: 445 return [] 446 if ndim == 1: 447 return [0] 448 449 # Short-circuits if contiguous or channels last, following the fake fast path. 450 # This reduces the number of guards we end up making 451 is_contiguous = True 452 is_channels_last = True 453 for t in tensors: 454 is_contiguous = is_contiguous and t.is_contiguous( 455 memory_format=torch.contiguous_format 456 ) 457 is_channels_last = is_channels_last and t.is_contiguous( 458 memory_format=torch.channels_last 459 ) 460 461 if is_contiguous and not is_channels_last: 462 return list(range(ndim)) 463 464 if is_channels_last and not is_contiguous: 465 return [0, *list(range(2, ndim)), 1] 466 467 shape = tensors[0].shape 468 469 def should_swap(idx_a, idx_b): 470 for tensor in tensors: 471 stride_a = tensor.stride()[idx_a] 472 stride_b = tensor.stride()[idx_b] 473 474 if guard_size_oblivious(stride_a == 0) or guard_size_oblivious( 475 stride_b == 0 476 ): 477 continue 478 479 if guard_size_oblivious(stride_a < stride_b): 480 return -1 481 482 if guard_size_oblivious(stride_a > stride_b): 483 return 1 484 485 # stride_a == stride_b 486 if guard_size_oblivious(shape[idx_a] > shape[idx_b]): 487 return 1 488 489 # Note: this case is hit if all strides are zero, 490 # or all strides are equal and all dimensions have the same length 491 return 0 492 493 # The "sort" order for the permutation is back-to-front, but 494 # the natural order for permutations is front-to-back. Do the 495 # sorting back-to-front and then reverse it on output. 496 # 497 # also, note this returns the logical to physical shape permutation 498 perm = list(reversed(range(ndim))) 499 500 # insertion sort with support for ambiguous comparisons 501 for i in range(1, ndim): 502 dim1 = i 503 for dim0 in reversed(range(i)): 504 comparison = should_swap(perm[dim0], perm[dim1]) 505 if comparison > 0: 506 perm[dim0], perm[dim1] = perm[dim1], perm[dim0] 507 dim1 = dim0 508 elif comparison < 0: 509 break 510 511 return list(reversed(perm)) 512 513 514def compute_elementwise_output_strides(*tensors) -> Tuple[int, ...]: 515 """ 516 Computes the output strides for elementwise operations. 517 """ 518 if len(tensors) == 0: 519 msg = "Can't compute elementwise output strides for zero tensors!" 520 raise ValueError(msg) 521 522 check_same_shape(*tensors, allow_cpu_scalar_tensors=True) 523 524 # Filters the tensors to actual tensors 525 tensors = tuple( 526 a for a in tensors if isinstance(a, TensorLike) and not is_cpu_scalar_tensor(a) 527 ) 528 529 # Short-circuits for CPU scalar case 530 if len(tensors) == 0: 531 return () 532 533 ndim = tensors[0].ndim 534 shape = tensors[0].shape 535 536 if ndim == 0: 537 return () 538 if ndim == 1: 539 return (1,) 540 541 logical_to_physical_perm = compute_elementwise_output_logical_to_physical_perm( 542 *tensors, _skip_checks=True 543 ) 544 permuted_shape = apply_perm(shape, logical_to_physical_perm) # to physical 545 546 new_strides = make_contiguous_strides_for(permuted_shape) 547 permuted_strides = apply_perm( 548 new_strides, invert_perm(logical_to_physical_perm) 549 ) # to logical 550 551 return tuple(permuted_strides) 552 553 554# Identity permutation is [0, 1, 2] 555def apply_perm(inp, perm): 556 ndim = len(inp) 557 permuted_inp = [-1] * ndim 558 for idx, x in enumerate(perm): 559 permuted_inp[idx] = inp[x] 560 return permuted_inp 561 562 563def invert_perm(perm): 564 ndim = len(perm) 565 new_perm = [-1] * ndim 566 for idx, x in enumerate(perm): 567 new_perm[x] = idx 568 return new_perm 569 570 571# 572# Common helper functions 573# 574 575 576def validate_dim_length(length: int): 577 """ 578 Validates that an object represents a valid 579 dimension length. 580 """ 581 582 if isinstance(length, (int, torch.SymInt)): 583 torch._check_is_size(length) 584 else: 585 # sometimes called with sympy expression by inductor 586 assert length >= 0 587 588 589def validate_shape(shape: ShapeType): 590 """ 591 Validates that a sequence represents a valid shape. 592 """ 593 594 assert isinstance(shape, Sequence), type(shape) 595 for l in shape: 596 validate_dim_length(l) 597 598 599def validate_strides(strides: StrideType): 600 """ 601 Verifies the object specifies valid strides. 602 """ 603 604 assert isinstance(strides, Sequence) 605 for stride in strides: 606 assert stride >= 0 607 608 609def validate_idx(rank: int, idx: int): 610 """ 611 Validates that idx is a valid index for the given shape. 612 Assumes the index is already canonicalized. 613 """ 614 615 assert isinstance(idx, Dim) 616 assert isinstance(rank, Dim) 617 618 assert idx >= 0 and idx < rank or idx == 0 619 620 621def validate_dimension_indices(rank: int, indices: DimsSequenceType): 622 for idx in indices: 623 validate_idx(rank, idx) 624 625 626def validate_exclusive_idx(rank: int, ex_idx: int): 627 """ 628 Validates that ex_idx is a valid exclusive index 629 for the given shape. 630 """ 631 632 assert isinstance(ex_idx, Dim) 633 assert isinstance(rank, Dim) 634 assert ex_idx > 0 and ex_idx <= rank 635 636 637# "Wraps" a dim (up to one time) for the given rank, allowing dims to be 638# specified using negative indices. If `wrap_scalar` is true then scalar 639# tensors of rank 0 will allow dimensions in the range [-1, 0]. Otherwise, 640# idx should be in the range [-rank, rank-1]. 641def canonicalize_dim(rank: int, idx: int, wrap_scalar: bool = True) -> int: 642 if rank < 0: 643 msg = f"Rank cannot be negative but got {rank}" 644 raise IndexError(msg) 645 646 if rank == 0: 647 if not wrap_scalar: 648 msg = f"Dimension specified as {idx} but tensor has no dimensions" 649 raise IndexError(msg) 650 rank = 1 651 652 if idx >= 0 and idx < rank: 653 return idx 654 655 if idx < 0: 656 _idx = idx + rank 657 else: 658 _idx = idx 659 660 if _idx < 0 or _idx >= rank: 661 # Same error message as in aten/src/ATen/WrapDimUtils.h:49 662 msg = f"Dimension out of range (expected to be in range of [{-rank}, {rank - 1}], but got {idx})" 663 raise IndexError(msg) 664 665 return _idx 666 667 668# Takes a dimension or sequence of dimensions and "wraps" them, 669# mapping negative offsets to positive ones 670@overload 671def canonicalize_dims( 672 rank: int, indices: Sequence[int], wrap_scalar: bool = True 673) -> Tuple[int, ...]: 674 pass 675 676 677@overload 678def canonicalize_dims(rank: int, indices: int, wrap_scalar: bool = True) -> int: 679 pass 680 681 682def canonicalize_dims(rank, indices, wrap_scalar=True): 683 if isinstance(indices, Dim): 684 return canonicalize_dim(rank, indices, wrap_scalar) 685 686 return tuple(canonicalize_dim(rank, x, wrap_scalar) for x in indices) 687 688 689def is_valid_permutation(rank: int, perm: DimsSequenceType) -> bool: 690 """ 691 Validates that perm is a permutation of length rank. 692 """ 693 694 return isinstance(perm, Sequence) and sorted(perm) == list(range(rank)) 695 696 697def is_same_shape(a: Sequence, b: Sequence) -> bool: 698 """ 699 Compares two shapes a and b, returning True if they are the same 700 (their ranks and corresponding lengths match) and False otherwise. 701 """ 702 703 return tuple(a) == tuple(b) 704 705 706def is_cpu_scalar_tensor(a: Any) -> bool: 707 return isinstance(a, TensorLike) and a.ndim == 0 and a.device.type == "cpu" 708 709 710def check_same_device(*args, allow_cpu_scalar_tensors): 711 """ 712 Checks that all Tensors in args have the same device. 713 714 Raises a RuntimeError when: 715 - args contains an object whose type is not Tensor or Number 716 - two Tensor objects in args have different devices, unless one is a CPU scalar tensor and allow_cpu_scalar_tensors is True 717 """ 718 # Short-circuits if all (one or fewer) arguments are trivially on the same device 719 if len(args) <= 1: 720 return 721 722 # Note: cannot initialize device to the first arg's device (it may not have one) 723 device = None 724 for arg in args: 725 if isinstance(arg, Number): 726 continue 727 elif isinstance(arg, TensorLike): 728 if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg): 729 continue 730 731 if device is None: 732 device = arg.device 733 734 if device != arg.device: 735 msg = ( 736 "Tensor on device " 737 + str(arg.device) 738 + " is not on the expected device " 739 + str(device) 740 + "!" 741 ) 742 raise RuntimeError(msg) 743 else: 744 msg = ( 745 "Unexpected type when checking for same device, " + str(type(arg)) + "!" 746 ) 747 raise RuntimeError(msg) 748 749 750def canonicalize_device(device: DeviceLikeType) -> torch.device: 751 if isinstance(device, torch.device): 752 return device 753 754 assert isinstance(device, str) 755 return torch.device(device) 756 757 758# Asserts if any of the following are true: 759# - a non-scalar or non-Tensor is given 760# - the shape of any tensors is distinct 761def check_same_shape(*args, allow_cpu_scalar_tensors: bool): 762 """ 763 Checks that all Tensors in args have the same shape. 764 765 Raises a RuntimeError when: 766 - args contains an object whose type is not Tensor or Number 767 - two Tensor objects in args have different devices 768 """ 769 shape = None 770 771 for arg in args: 772 if isinstance(arg, Number): 773 continue 774 elif isinstance(arg, TensorLike): 775 if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg): 776 continue 777 778 if shape is None: 779 shape = arg.shape 780 781 if not is_same_shape(shape, arg.shape): 782 msg = f"Shape {arg.shape} is not the expected shape {shape}!" 783 raise RuntimeError(msg) 784 else: 785 msg = ( 786 "Unexpected type when checking for same shape, " + str(type(arg)) + "!" 787 ) 788 raise RuntimeError(msg) 789 790 791# Acquires a common shape, if it exists, from one or more tensor arguments, 792# filtering number arguments 793def extract_shape(*args, allow_cpu_scalar_tensors: bool) -> Optional[ShapeType]: 794 shape = None 795 scalar_shape = None 796 797 for arg in args: 798 if isinstance(arg, Number): 799 continue 800 elif isinstance(arg, TensorLike): 801 if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg): 802 scalar_shape = arg.shape 803 continue 804 805 if shape is None: 806 shape = arg.shape 807 808 if not is_same_shape(shape, arg.shape): 809 return None 810 else: 811 return None 812 813 return shape if shape is not None else scalar_shape 814 815 816# Extracts dimensions that might be passed either as a list/tuple or as varargs. 817# A typical case is Tensor.permute . 818def extract_dims_from_varargs( 819 dims: Union[DimsSequenceType, Tuple[DimsSequenceType, ...]] 820) -> DimsSequenceType: 821 if dims and isinstance(dims[0], Sequence): 822 assert len(dims) == 1 823 dims = cast(Tuple[DimsSequenceType], dims) 824 return dims[0] 825 else: 826 return cast(DimsSequenceType, dims) 827 828 829def extract_shape_from_varargs( 830 shape: Union[ShapeType, Tuple[ShapeType]], 831 validate=True, 832) -> Tuple[int, ...]: 833 """ 834 Returns a shape from varargs. 835 836 In PyTorch, operations that accept shapes often accept them as varargs, like 837 foo(*shape). However a user can pass the shape as a sequence of integers, 838 like this: 839 840 foo(1, 2, 3) 841 842 or as a sequence of integers 843 844 foo((1, 2, 3)) 845 846 In the first case shape will be a tuple of integers, and in the second case it's a tuple 847 containing a tuple of integers. This validates those inputs and canonicalizes them 848 to a tuple of integers. 849 """ 850 851 # Handles tuple unwrapping 852 if len(shape) == 1 and isinstance(shape[0], Sequence): 853 shape = shape[0] 854 855 if validate: 856 validate_shape(shape) # type: ignore[arg-type] 857 return shape # type: ignore[return-value] 858 859 860def infer_size_shapes(a: ShapeType, b: ShapeType) -> Tuple[int, ...]: 861 ndim = max(len(a), len(b)) 862 expandedSizes = [0] * ndim 863 864 for i in range(ndim - 1, -1, -1): 865 offset = ndim - 1 - i 866 dimA = len(a) - 1 - offset 867 dimB = len(b) - 1 - offset 868 sizeA = a[dimA] if dimA >= 0 else 1 869 sizeB = b[dimB] if dimB >= 0 else 1 870 871 torch._check( 872 (sizeA == sizeB) or (sizeA == 1) or (sizeB == 1), 873 lambda: ( 874 f"The size of tensor a ({sizeA}) must match the size of " 875 f"tensor b ({sizeB}) at non-jagged dimension {i}" 876 ), 877 ) 878 879 # 1s map to the other size (even 0) 880 expandedSizes[i] = sizeB if sizeA == 1 else sizeA 881 882 return tuple(expandedSizes) 883 884 885def infer_size(shape: ShapeType, numel: int) -> Tuple[int, ...]: 886 """ 887 Infers the size of a dim with size -1, if it exists. 888 Also checks that new shape is compatible with the number of elements. 889 """ 890 dim = None 891 newsize = 1 892 for i, d in enumerate(shape): 893 if d == -1: 894 torch._check(dim is None, lambda: "only one dimension can be inferred") 895 dim = i 896 elif d >= 0: 897 newsize *= d 898 else: 899 torch._check(False, lambda: f"invalid shape dimension {d}") 900 if dim is None: 901 torch._check( 902 numel == newsize, 903 lambda: f"shape '{list(shape)}' is invalid for input of size {numel}", 904 ) 905 else: 906 from torch.fx.experimental.symbolic_shapes import definitely_true 907 908 torch._check( 909 newsize != 0, 910 lambda: ( 911 f"cannot reshape tensor of 0 elements into shape {list(shape)} because the " 912 f"unspecified dimension size -1 can be any value and is ambiguous" 913 if definitely_true(numel == 0) 914 else f"shape '{list(shape)}' is invalid for input of size {numel}" 915 ), 916 ) 917 torch._check( 918 numel % newsize == 0, 919 lambda: f"shape '{list(shape)}' is invalid for input of size {numel}", 920 ) 921 # Convert to list to produce a compatible error message with core 922 # PyTorch, which prints sequences in square brackets. 923 shape = list(shape) 924 shape[dim] = numel // newsize 925 # NB: This is pretty important when you have unbacked SymInts. 926 # Suppose you have (i0, 12) resizing into (2, -1, 12). The old 927 # range for i0 is typically [2, inf], which means if you divide 928 # by two the new range should be [1, inf]. But this is bad news 929 # if you have an unbacked SymInt: we need to reapply the unsound 930 # assumption that the size is >= 2. 931 torch._check_is_size(shape[dim]) 932 return tuple(shape) 933 934 935_integer_dtypes = ( 936 torch.uint8, 937 torch.uint16, 938 torch.uint32, 939 torch.uint64, 940 torch.int8, 941 torch.int16, 942 torch.int32, 943 torch.int64, 944) 945_low_precision_dtypes = (torch.float16, torch.bfloat16, torch.complex32) 946_complex_dtypes = (torch.complex32, torch.complex64, torch.complex128) 947 948 949def is_boolean_dtype(dtype: torch.dtype) -> bool: 950 assert isinstance(dtype, torch.dtype) 951 return dtype is torch.bool 952 953 954def is_integer_dtype(dtype: torch.dtype) -> bool: 955 assert isinstance(dtype, torch.dtype) 956 return dtype in _integer_dtypes 957 958 959def is_low_precision_dtype(dtype: torch.dtype) -> bool: 960 assert isinstance(dtype, torch.dtype) 961 return dtype in _low_precision_dtypes 962 963 964def is_float_dtype(dtype: torch.dtype) -> bool: 965 assert isinstance(dtype, torch.dtype) 966 return dtype.is_floating_point 967 968 969def is_complex_dtype(dtype: torch.dtype) -> bool: 970 assert isinstance(dtype, torch.dtype) 971 return dtype in _complex_dtypes 972 973 974def is_grad_dtype(dtype: torch.dtype) -> bool: 975 """ 976 Checks if the dtype can require a gradient. 977 """ 978 return dtype.is_floating_point or is_complex_dtype(dtype) 979 980 981_complex_to_real_dtype_map = { 982 torch.complex128: torch.float64, 983 torch.complex64: torch.float32, 984 torch.complex32: torch.float16, 985} 986 987_real_to_complex_dtype_map = { 988 torch.float16: torch.complex32, 989 torch.bfloat16: torch.complex64, 990 torch.float32: torch.complex64, 991 torch.float64: torch.complex128, 992} 993 994 995def corresponding_real_dtype(dtype: torch.dtype) -> torch.dtype: 996 return _complex_to_real_dtype_map[dtype] 997 998 999def corresponding_complex_dtype(dtype: torch.dtype) -> torch.dtype: 1000 return _real_to_complex_dtype_map[dtype] 1001 1002 1003def dtype_to_type(dtype: torch.dtype) -> type: 1004 """ 1005 Computes the corresponding Python type (AKA "type kind") for the 1006 given dtype. 1007 """ 1008 assert isinstance(dtype, torch.dtype) 1009 1010 if dtype is torch.bool: 1011 return bool 1012 if dtype in _integer_dtypes: 1013 return int 1014 if dtype.is_floating_point: 1015 return float 1016 if dtype in _complex_dtypes: 1017 return complex 1018 1019 raise ValueError("Invalid dtype!") 1020 1021 1022def dtype_to_type_ctor(dtype: torch.dtype) -> Callable[[NumberType], NumberType]: 1023 """ 1024 Computes the corresponding Python type constructor for the 1025 given dtype. 1026 """ 1027 assert isinstance(dtype, torch.dtype) 1028 1029 if dtype is torch.bool: 1030 return lambda x: bool(x) 1031 if dtype in _integer_dtypes: 1032 return sym_int 1033 if dtype.is_floating_point: 1034 return sym_float 1035 if dtype in _complex_dtypes: 1036 # TODO: type error here is real, replace with sym_complex 1037 return lambda x: complex(x) # type: ignore[arg-type] 1038 1039 raise ValueError("Invalid dtype!") 1040 1041 1042def type_to_dtype(typ: type) -> torch.dtype: 1043 """ 1044 Computes the corresponding dtype for a Number type. 1045 """ 1046 1047 assert isinstance(typ, type) 1048 1049 if typ in (bool, torch.SymBool): 1050 return torch.bool 1051 if typ in (int, torch.SymInt): 1052 return torch.long 1053 if typ in (float, torch.SymFloat): 1054 return torch.get_default_dtype() 1055 # TODO: sym_complex_float? 1056 if typ is complex: 1057 return corresponding_complex_dtype(torch.get_default_dtype()) 1058 1059 raise ValueError(f"Invalid type {typ}!") 1060 1061 1062def get_dtype(x: Union[torch.Tensor, NumberType]): 1063 if isinstance(x, torch.Tensor): 1064 return x.dtype 1065 else: 1066 return type_to_dtype(type(x)) 1067 1068 1069_ordered_types = (bool, int, float, complex) 1070 1071 1072def check_fp_or_complex( 1073 dtype: torch.dtype, fn_name: str, allow_low_precision_dtypes: bool = True 1074): 1075 """ 1076 Checks whether the input is floating point or complex. 1077 If allow_low_precision_dtypes is True, it allows having float16, bfloat16, and complex32 1078 """ 1079 torch._check( 1080 is_float_dtype(dtype) or is_complex_dtype(dtype), 1081 lambda: f"{fn_name}: Expected a floating point or complex tensor as input. Got {dtype}", 1082 ) 1083 torch._check( 1084 allow_low_precision_dtypes or not is_low_precision_dtype(dtype), 1085 lambda: f"{fn_name}: Half precision dtypes not supported. Got {dtype}", 1086 ) 1087 1088 1089def check_is_matrix(A: TensorLikeType, f_name: str, arg_name: str = "A"): 1090 torch._check( 1091 len(A.shape) >= 2, 1092 lambda: f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions.", 1093 ) 1094 1095 1096def get_higher_type(a: type, b: type) -> type: 1097 """ 1098 Returns the higher of the two given Number types. 1099 1100 The types are ordered bool -> int -> float -> complex. 1101 """ 1102 a, b = _maybe_get_pytype(a), _maybe_get_pytype(b) 1103 # Type checking 1104 if a not in _ordered_types or b not in _ordered_types: 1105 raise RuntimeError(f"Expected builtin numeric types, found {a}, {b}") 1106 1107 if a is b: 1108 return a 1109 1110 for typ in _ordered_types: 1111 if a is typ: 1112 return b 1113 if b is typ: 1114 return a 1115 1116 raise ValueError("Unknown Python scalar type!") 1117 1118 1119# Returns the higher of two torch datatypes a and b or, if the two 1120# are not ordered relative to each other, the next 1121# higher datatype 1122def get_higher_dtype( 1123 a: Optional[Union[torch.dtype, TensorLikeType, NumberType]], 1124 b: Optional[Union[torch.dtype, TensorLikeType, NumberType]], 1125) -> Optional[torch.dtype]: 1126 """ 1127 Computes the "lowest" datatype that is weakly 1128 "higher" than both a and b. 1129 """ 1130 1131 # Type checking 1132 assert a is None or isinstance(a, (torch.dtype, TensorLike, Number)) 1133 assert b is None or isinstance(b, (torch.dtype, TensorLike, Number)) 1134 1135 def _extract_dtype( 1136 x: Optional[Union[torch.dtype, TensorLikeType, NumberType]] 1137 ) -> Optional[torch.dtype]: 1138 if x is None: 1139 return None 1140 if isinstance(x, torch.dtype): 1141 return x 1142 if isinstance(x, TensorLike): 1143 return x.dtype 1144 if isinstance(x, Number): 1145 return type_to_dtype(type(x)) 1146 1147 raise RuntimeError("Unexpected type given to _extract_dtype!") 1148 1149 a, b = _extract_dtype(a), _extract_dtype(b) 1150 1151 if a is b: 1152 return a 1153 1154 if a is None: 1155 return b 1156 1157 if b is None: 1158 return a 1159 1160 ordered_datatypes = ( 1161 (torch.bool,), 1162 (torch.uint8, torch.int8), 1163 (torch.int16,), 1164 (torch.int32,), 1165 (torch.int64,), 1166 (torch.float16, torch.bfloat16), 1167 (torch.float32,), 1168 (torch.float64,), 1169 (torch.complex32,), 1170 (torch.complex64,), 1171 (torch.complex128,), 1172 ) 1173 1174 for idx, dtypes in enumerate(ordered_datatypes): 1175 if a in dtypes and b in dtypes: 1176 return ordered_datatypes[idx + 1][0] 1177 if a in dtypes: 1178 return b 1179 if b in dtypes: 1180 return a 1181 1182 raise RuntimeError("Unexpected termination!") 1183 1184 1185def check_pin_memory(pin_memory: bool): 1186 torch._check_not_implemented( 1187 not pin_memory, lambda: "PrimTorch does not support pinned memory" 1188 ) 1189 1190 1191def check_layout(layout: torch.layout): 1192 torch._check_not_implemented( 1193 layout == torch.strided, lambda: f"PrimTorch doesn't support layout={layout}" 1194 ) 1195 1196 1197# TODO: maybe unify with can_cast_to? 1198def is_weakly_lesser_type(a: type, b: type) -> bool: 1199 """ 1200 Compares two types, a and b, returning True if a is weakly "less" than b. 1201 1202 The comparison is determined by the following type ordering: bool, int, float, complex. 1203 """ 1204 1205 a, b = _maybe_get_pytype(a), _maybe_get_pytype(b) 1206 1207 if a not in _ordered_types or b not in _ordered_types: 1208 raise RuntimeError(f"Expected builtin numeric types, found {a}, {b}") 1209 1210 for typ in _ordered_types: 1211 if a == typ: 1212 return True 1213 if b == typ: 1214 return False 1215 1216 raise RuntimeError("Unexpected termination!") 1217 1218 1219def can_safe_cast_to(*, cast_to: torch.dtype, cast_from: torch.dtype) -> bool: 1220 for fn in (is_complex_dtype, is_float_dtype, is_integer_dtype, is_boolean_dtype): 1221 if fn(cast_to): 1222 return True 1223 if fn(cast_from): 1224 return False 1225 1226 raise ValueError(f"Received unknown dtypes {cast_to}, {cast_from}!") 1227 1228 1229def check_same_dtype(*args): 1230 """ 1231 Checks that all Tensors in args have the same device and that all Numbers have the 1232 same corresponding Python type. 1233 1234 Raises a RuntimeError when: 1235 - args contains an object whose type is not Tensor or Number 1236 - two Tensors objects in args have different dtypes 1237 - two Number objects in args have different types 1238 - there are Tensors and Numbers in args, and one of those Tensors corresponding 1239 Python types is different from the type of one of those Numbers 1240 """ 1241 full_dtype = None 1242 scalar_type = None 1243 1244 for arg in args: 1245 if isinstance(arg, Number): 1246 # Scalar type checking is disabled (and may be removed in the future) 1247 continue 1248 # if scalar_type is None: 1249 # scalar_type = type(arg) 1250 1251 # if scalar_type is not type(arg): 1252 # msg = ( 1253 # "Scalar of type " 1254 # + str(type(arg)) 1255 # + " is not the expected type of " 1256 # + str(scalar_type) 1257 # + "!" 1258 # ) 1259 # raise RuntimeError(msg) 1260 elif isinstance(arg, TensorLike): 1261 if full_dtype is None: 1262 full_dtype = arg.dtype 1263 if scalar_type is None: 1264 scalar_type = dtype_to_type(arg.dtype) 1265 1266 if full_dtype is not arg.dtype: 1267 msg = ( 1268 "Tensor with dtype " 1269 + str(arg.dtype) 1270 + " is not the expected dtype of " 1271 + str(full_dtype) 1272 + "!" 1273 ) 1274 raise RuntimeError(msg) 1275 1276 arg_type = dtype_to_type(arg.dtype) 1277 if arg_type is not scalar_type: 1278 msg = ( 1279 "Tensor with corresponding Python type " 1280 + str(arg_type) 1281 + " is not the expected type of " 1282 + str(scalar_type) 1283 + "!" 1284 ) 1285 raise RuntimeError(msg) 1286 else: 1287 msg = ( 1288 "Unexpected type when checking for same dtype, " + str(type(arg)) + "!" 1289 ) 1290 raise RuntimeError(msg) 1291 1292 1293# Maps datatypes to their computation types for elementwise operations 1294_computation_dtype_map = { 1295 torch.bfloat16: torch.float32, 1296 torch.float16: torch.float32, 1297 torch.complex32: torch.complex64, 1298} 1299 1300 1301def get_computation_dtype(dtype: torch.dtype) -> torch.dtype: 1302 return _computation_dtype_map.get(dtype, dtype) 1303 1304 1305_cpu_acc_type_map = { 1306 torch.bfloat16: torch.float64, 1307 torch.float16: torch.float64, 1308 torch.float32: torch.float64, 1309 torch.complex32: torch.complex128, 1310 torch.complex64: torch.complex128, 1311} 1312 1313 1314def get_acc_type(dtype: torch.dtype, device: torch.device) -> torch.dtype: 1315 # Equivalent to at::toAccumulateType, prefer computation_dtype where possible 1316 if device.type == "cpu": 1317 return _cpu_acc_type_map.get(dtype, dtype) 1318 else: 1319 return get_computation_dtype(dtype) 1320 1321 1322class ELEMENTWISE_TYPE_PROMOTION_KIND(Enum): 1323 DEFAULT = (0,) 1324 NO_OPMATH = (1,) 1325 INT_TO_FLOAT = (2,) 1326 ALWAYS_BOOL = (3,) 1327 COMPLEX_TO_FLOAT = (4,) 1328 BOOL_TO_LONG = (5,) 1329 1330 1331class REDUCTION_OUTPUT_TYPE_KIND(Enum): 1332 SAME = (0,) 1333 COMPLEX_TO_FLOAT = (1,) # for complex types outputs corresponding real type 1334 KEEP_PROMOTED_TYPE = (2,) # keep output in opmath type, needed for mean 1335 ALWAYS_BOOL = (3,) 1336 1337 1338# Describes the return type of the primitive: 1339# 1340# - NEW, a new tensor is created 1341# - VIEW, a view of an input tensor is returned 1342# - INPLACE, one or more input tensors is modified 1343# 1344# these descriptors are mututally exclusive and exhaustive. 1345class RETURN_TYPE(Enum): 1346 NEW = (0,) 1347 VIEW = (1,) 1348 INPLACE = (2,) 1349 NONE = (3,) 1350 1351 1352# TODO: when NumberType contains the sym types, can simplify this 1353def number_type( 1354 x: Union[NumberType, torch.SymInt, torch.SymFloat, torch.SymBool] 1355) -> Type: 1356 if isinstance(x, torch.SymInt): 1357 return int 1358 elif isinstance(x, torch.SymFloat): 1359 return float 1360 elif isinstance(x, torch.SymBool): 1361 return bool 1362 else: 1363 return type(x) 1364 1365 1366def expr_type(x: sympy.Basic) -> Type: 1367 import sympy 1368 1369 if x.kind is sympy.core.kind.BooleanKind: 1370 return bool 1371 elif x.is_integer: # type: ignore[attr-defined] 1372 return int 1373 else: 1374 # NB: Not strictly correct, but we don't support SymPy complex or bool. 1375 return float 1376 1377 1378# TODO: document type promotion kinds 1379def elementwise_dtypes( 1380 *_args, 1381 type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND, 1382) -> Tuple[torch.dtype, torch.dtype]: 1383 """ 1384 Computes the computation and result dtypes for elementwise type promotion 1385 on the given arguments and with the given elementwise type promotion kind. 1386 1387 Note that not all inputs to an elementwise operation necessarily participate in type promotion. 1388 For example, the "alpha" parameter of torch.add does not participate in type promotion, 1389 although it may be cast to the Python type corresponding to the computation dtype that 1390 the type promotion algorithm determines. 1391 1392 Default elementwise type promotion, which all other type promotion kinds tweak (see below), 1393 first decides which of four ordered types to use: 1394 1395 bool -> integer -> floating point -> complex 1396 1397 The selected type is the "lowest" type in the above list such that all number arguments 1398 have a weakly "lower" type and all tensor arguments have a weakly lower corresponding 1399 type for their dtype. 1400 1401 Once the type is determined, the particular result dtype is found. The dtypes are 1402 partially ordered as follows: 1403 1404 bool -> uint8, int8 -> int16 -> int32 -> int64 -> 1405 float16, bfloat16 -> float32 -> float64 -> complex32 -> complex64 -> complex128 1406 1407 The result dtype is selected by: 1408 - if no tensor's dtype has the same corresponding type as the one selected, 1409 then the result dtype is the (default) dtype corresponding to the selected type 1410 (for example, 1.5 + an integer tensor has a result dtype of the default floating point dtype) 1411 - if the result type is complex then the dtype is: 1412 - the default complex dtype if there are no floating point or complex tensors 1413 - if there are floating point or complex tensors with one or more dimensions, then 1414 the complex dtype corresponding to the highest corresponding complex dtype among those tensors 1415 (for example, double + cfloat -> cdouble) 1416 - if there are only floating point or complex tensors with zero dimensions, then 1417 the complex dtype corresponding to the highest corresponding complex dtype among those tensors 1418 - if the first two cases do not apply, the result dtype is the highest dtype among 1419 all tensors with one or more dimensions of the output type, and if there are no such 1420 tensors then it's the highest dtype among all tensors with zero dimensions of the output type 1421 (for example, long + half -> half, even if the half tensor has zero dimensions) 1422 1423 The "corresponding complex dtypes" are: 1424 float16 -> complex32 1425 bfloat16 -> complex64 1426 float32 -> complex64 1427 float64 -> complex128 1428 complex32 -> complex32 1429 complex64 -> complex64 1430 complex128 -> complex128 1431 1432 The DEFAULT type promotion kind computes per above, and then uses the result dtype to pick a computation 1433 dtype by mapping low precision floating point and complex dtypes as follows: 1434 1435 float16 -> float32 1436 bfloat16 -> float32 1437 complex32 -> complex64 1438 1439 This is referred to as "op math", and the NO_OPMATH type promotion kind disables this mapping, making the 1440 computation dtype the same as the result dtype when it's selected. NO_OPMATH is appropriate for kernels 1441 which perform no mathematical operations on their tensors (see below for examples). 1442 1443 The INT_TO_FLOAT type promotion kind maps boolean and integer result dtypes to the default floating point dtype, 1444 and computation dtypes to the appropriate op math dtype. 1445 1446 The COMPLEX_TO_FLOAT type promotion kind maps complex result dtypes to the corresponding float dtype, following this 1447 mapping: 1448 1449 complex32 -> float16 1450 complex64 -> float32 1451 complex128 -> float64 1452 1453 Note that COMPLEX_TO_FLOAT derives the computation dtype as the DEFAULT setting does. 1454 1455 The BOOL_TO_LONG type promotion kind maps boolean computation and result dtypes to long. 1456 1457 The ALWAYS_BOOL type promotion kind always sets the result dtype to bool. 1458 1459 Example operators for each type promotion option: 1460 DEFAULT : add 1461 NO_OPMATH : where, nextafter, cat 1462 INT_TO_FLOAT : sin 1463 COMPLEX_TO_FLOAT : abs 1464 BOOL_TO_LONG : pow 1465 ALWAYS_BOOL : eq 1466 1467 """ 1468 1469 args = tuple(x for x in _args if x is not None) 1470 1471 highest_type: type = bool 1472 1473 # Import sympy locally, as importing it eagerly at a module level is too slow 1474 # See https://dev-discuss.pytorch.org/t/delving-into-what-happens-when-you-import-torch/1589 1475 import sympy 1476 1477 for x in args: 1478 if not isinstance(x, (Number, TensorLike, sympy.Basic)): 1479 msg = f"Unexpected type {str(type(x))} when computing elementwise type promotion!" 1480 raise ValueError(msg) 1481 1482 if isinstance(x, Number): 1483 highest_type = get_higher_type(highest_type, number_type(x)) 1484 elif isinstance(x, sympy.Basic): 1485 highest_type = get_higher_type(highest_type, expr_type(x)) 1486 else: 1487 # x is a TensorLike 1488 highest_type = get_higher_type(highest_type, dtype_to_type(x.dtype)) 1489 1490 result_dtype = None 1491 1492 def _find_highest_dtype_filtered( 1493 args, filter, *, float_as_complex=False 1494 ) -> Optional[torch.dtype]: 1495 zero_dim_tensor_dtype = None 1496 one_plus_dim_tensor_dtype = None 1497 for x in args: 1498 if isinstance(x, TensorLike) and filter(x.dtype): 1499 _dtype = x.dtype 1500 if float_as_complex and is_float_dtype(_dtype): 1501 _dtype = corresponding_complex_dtype(_dtype) 1502 if x.ndim == 0: 1503 zero_dim_tensor_dtype = get_higher_dtype( 1504 zero_dim_tensor_dtype, _dtype 1505 ) 1506 else: 1507 # x.ndim > 0 1508 one_plus_dim_tensor_dtype = get_higher_dtype( 1509 one_plus_dim_tensor_dtype, _dtype 1510 ) 1511 1512 # Prefers dtype of tensors with one or more dimensions 1513 if one_plus_dim_tensor_dtype is not None: 1514 return one_plus_dim_tensor_dtype 1515 1516 return zero_dim_tensor_dtype 1517 1518 if highest_type is float: 1519 result_dtype = _find_highest_dtype_filtered(args, is_float_dtype) 1520 result_dtype = ( 1521 torch.get_default_dtype() if result_dtype is None else result_dtype 1522 ) 1523 elif highest_type is complex: 1524 result_dtype = _find_highest_dtype_filtered( 1525 args, 1526 lambda x: is_float_dtype(x) or is_complex_dtype(x), 1527 float_as_complex=True, 1528 ) 1529 if result_dtype is None: 1530 result_dtype = corresponding_complex_dtype(torch.get_default_dtype()) 1531 elif highest_type is int: 1532 result_dtype = _find_highest_dtype_filtered(args, is_integer_dtype) 1533 result_dtype = torch.long if result_dtype is None else result_dtype 1534 else: 1535 # highest_type is bool 1536 result_dtype = torch.bool 1537 1538 if type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT: 1539 return get_computation_dtype(result_dtype), result_dtype 1540 elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH: 1541 return result_dtype, result_dtype 1542 elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT: 1543 if is_integer_dtype(result_dtype) or is_boolean_dtype(result_dtype): 1544 result_dtype = torch.get_default_dtype() 1545 return get_computation_dtype(result_dtype), result_dtype 1546 elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT: 1547 # NOTE: computation can still occur in a complex dtype 1548 computation_dtype = get_computation_dtype(result_dtype) 1549 if is_complex_dtype(result_dtype): 1550 result_dtype = corresponding_real_dtype(result_dtype) 1551 return computation_dtype, result_dtype 1552 elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG: 1553 if is_boolean_dtype(result_dtype): 1554 return torch.long, torch.long 1555 return get_computation_dtype(result_dtype), result_dtype 1556 elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL: 1557 return get_computation_dtype(result_dtype), torch.bool 1558 else: 1559 raise ValueError(f"Unknown type promotion kind {str(type_promotion_kind)}") 1560 1561 1562def reduction_dtypes( 1563 arg, 1564 output_dtype_kind: REDUCTION_OUTPUT_TYPE_KIND, 1565 dtype: Optional[torch.dtype] = None, 1566) -> Tuple[torch.dtype, Optional[torch.dtype]]: 1567 # even though some reductions, like amin or amax, don't strictly require type promotion, 1568 # all the math ops (including comparisons) are still defined only for a computation type, 1569 # so promotion will still happen. We are doing it explicitly here 1570 inp_dtype = dtype if dtype is not None else arg.dtype 1571 computation_dtype = get_computation_dtype(inp_dtype) 1572 if ( 1573 output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.SAME 1574 or output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT 1575 ): 1576 result_dtype = dtype if dtype else arg.dtype 1577 if ( 1578 output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT 1579 and is_complex_dtype(result_dtype) 1580 ): 1581 result_dtype = corresponding_real_dtype(result_dtype) 1582 elif output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE: 1583 result_dtype = None 1584 else: # ALWAYS_BOOL 1585 result_dtype = torch.bool 1586 return computation_dtype, result_dtype 1587 1588 1589# This function's logic is borrowed from the following functions defined in C++: 1590# batched_matrix_contiguous_strides and contiguous_strides 1591def make_contiguous_strides_for( 1592 shape: ShapeType, row_major: bool = True 1593) -> Tuple[int, ...]: 1594 """ 1595 Returns the strides of a contiguous tensor if row_major 1596 If row_major=True, it returns the strides of a contiguous batch of Fortran-contiguous matrices 1597 This is often used when calling external libraries like BLAS/LAPACK/cuSolver... 1598 """ 1599 # contiguous_strides from c10/util/strides.h 1600 validate_shape(shape) 1601 if not shape: 1602 return () 1603 1604 from torch.fx.experimental.symbolic_shapes import is_nested_int 1605 1606 multiplier = 1 1607 strides = [] 1608 for l in reversed(shape): 1609 strides.append(multiplier) 1610 multiplier *= l if is_nested_int(l) else sym_max(l, 1) 1611 1612 result = tuple(reversed(strides)) 1613 1614 # batched_matrix_contiguous_strides from aten/src/ATen/native/LinearAlgebraUtils.h 1615 if row_major: 1616 return result 1617 else: 1618 if len(shape) < 2: 1619 return result 1620 return result[:-2] + (1, max(shape[-2], 1)) 1621 1622 1623def make_channels_last_1d_strides_for(shape: ShapeType) -> Tuple[int, ...]: 1624 torch._check( 1625 len(shape) == 3, 1626 lambda: "Only tensors of rank 3 can use the channels_last_1d memory format", 1627 ) 1628 1629 multiplier = 1 1630 strides = [0] * 3 1631 for idx in (1, -1, 0): 1632 # NOTE: intentionally divergence from make_contiguous_strides_for 1633 # This is consistent with eager 1634 strides[idx] = multiplier 1635 multiplier *= shape[idx] 1636 1637 return tuple(strides) 1638 1639 1640def make_channels_last_2d_strides_for(shape: ShapeType) -> Tuple[int, ...]: 1641 # TODO: maybe inform the user of channels_last_3d if rank of the tensor is 5? 1642 torch._check( 1643 len(shape) == 4, 1644 lambda: "Only tensors of rank 4 can use the channels_last memory format", 1645 ) 1646 1647 multiplier = 1 1648 strides = [0] * 4 1649 for idx in (1, -1, -2, 0): 1650 # NOTE: intentionally divergence from make_contiguous_strides_for 1651 # This is consistent with eager 1652 strides[idx] = multiplier 1653 multiplier *= shape[idx] 1654 1655 return tuple(strides) 1656 1657 1658def make_channels_last_3d_strides_for(shape: ShapeType) -> Tuple[int, ...]: 1659 torch._check( 1660 len(shape) == 5, 1661 lambda: "Only tensors of rank 5 can use the channels_last_3d memory format", 1662 ) 1663 1664 multiplier = 1 1665 strides = [0] * 5 1666 for idx in (1, -1, -2, -3, 0): 1667 # NOTE: intentionally divergence from make_contiguous_strides_for 1668 # This is consistent with eager 1669 strides[idx] = multiplier 1670 multiplier *= shape[idx] 1671 1672 return tuple(strides) 1673 1674 1675def make_channels_last_strides_for(shape: ShapeType) -> Tuple[int, ...]: 1676 ndim = len(shape) if isinstance(shape, Sequence) else 1 1677 if ndim == 3: 1678 return make_channels_last_1d_strides_for(shape) 1679 elif ndim == 4: 1680 return make_channels_last_2d_strides_for(shape) 1681 elif ndim == 5: 1682 return make_channels_last_3d_strides_for(shape) 1683 else: 1684 raise RuntimeError( 1685 f"no channels last format strides exist in {ndim} dimensions" 1686 ) 1687 1688 1689def compute_reduction_output_shape( 1690 shape: ShapeType, dimensions: Sequence 1691) -> Tuple[int, ...]: 1692 for idx in dimensions: 1693 validate_idx(len(shape), idx) 1694 1695 new_shape = [] 1696 for idx in range(len(shape)): 1697 if idx in dimensions: 1698 continue 1699 1700 new_shape.append(shape[idx]) 1701 1702 return tuple(new_shape) 1703 1704 1705def validate_no_repeating_dims(dims: Sequence): 1706 if len(dims) != len(set(dims)): 1707 raise RuntimeError("duplicate value in the list of dims") 1708 1709 1710def reduction_dims(shape: ShapeType, dims: Optional[Sequence]) -> Tuple[int, ...]: 1711 if dims is None: 1712 return tuple(range(len(shape))) 1713 dims = tuple(canonicalize_dim(len(shape), idx) for idx in dims) 1714 validate_no_repeating_dims(dims) 1715 return dims 1716 1717 1718def set_correction( 1719 unbiased: Optional[bool] = None, 1720 correction: Optional[NumberType] = None, 1721) -> float: 1722 if correction is not None and unbiased is not None: 1723 raise RuntimeError("cannot specify both correction and unbiased arguments") 1724 elif correction is None and unbiased is None: 1725 correction = 1.0 1726 elif correction is None and unbiased is not None: 1727 correction = 0.0 if unbiased is False else 1.0 1728 # NB: we don't actually support symint here, but it's harmless to accept 1729 if not isinstance(correction, (IntLike, FloatLike)): 1730 raise ValueError("correction argument should be integer or float") 1731 if correction < 0: 1732 raise ValueError("correction argument should be non-negative") 1733 return sym_float(correction) 1734 1735 1736def compute_required_storage_length( 1737 shape: ShapeType, strides: StrideType, storage_offset: int 1738) -> int: 1739 """Computes the minimum storage size to hold the given tensor geometry. 1740 1741 Example 1742 ======= 1743 1744 This is the size of a newly allocated tensor's storage, in units of elements 1745 1746 >>> t = torch.empty((10, 20)) 1747 >>> compute_required_storage_length(t.shape, t.stride(), t.storage_offset()) 1748 200 1749 1750 >>> # xdoctest: +SKIP(failing) 1751 >>> t2 = torch.empty_strided((1, 2, 3), (5, 7, 11)) 1752 >>> size = compute_required_storage_length(t2.shape, t2.stride(), t2.storage_offset()) 1753 >>> size == t.storage().size() 1754 True 1755 1756 A valid tensor may have a larger storage size, but never smaller 1757 1758 >>> slice = torch.empty(100)[20:40] 1759 >>> slice.storage().size() 1760 100 1761 1762 >>> compute_required_storage_length(slice.shape, slice.stride(), slice.storage_offset()) 1763 40 1764 1765 """ 1766 from torch.fx.experimental.symbolic_shapes import guard_size_oblivious 1767 1768 # Short-circuits if the shape has no elements 1769 if guard_size_oblivious(reduce(operator.mul, shape, 1) == 0): 1770 return 0 1771 1772 max_offset = sum((x - 1) * y for x, y in zip(shape, strides)) 1773 # +1 to account for the first element which offsets are taken from 1774 return 1 + storage_offset + max_offset 1775 1776 1777def check_in_bounds_for_storage( 1778 a: torch.TypedStorage, shape: ShapeType, strides: StrideType, storage_offset: int 1779): 1780 """ 1781 Determines if the given shape, strides, and offset are valid for the given storage. 1782 """ 1783 1784 required_length = compute_required_storage_length(shape, strides, storage_offset) 1785 if a.size() < required_length: 1786 msg = ( 1787 f"Can't view a storage of size {a.size()} with an offset of {storage_offset}, " 1788 f"shape of {str(shape)}, and strides of {str(strides)}, " 1789 f"which requires a storage of size {required_length}" 1790 ) 1791 raise ValueError(msg) 1792 1793 1794# NOTE: This function should ideally be removed, but some Meta internal models 1795# packaged with `torch.package` are using it, so it will have to be removed 1796# at some point in the future when those models no longer use this function. 1797@deprecated( 1798 "`torch._prims_common.check` is deprecated and will be removed in the future. " 1799 "Please use `torch._check*` functions instead.", 1800 category=FutureWarning, 1801) 1802def check( 1803 b: bool, s: Callable[[], str], exc_type: Type[Exception] = RuntimeError 1804) -> None: 1805 """ 1806 Helper function for raising an error_type (default: RuntimeError) if a boolean condition fails. 1807 Error message is a callable producing a string (to avoid wasting time 1808 string formatting in non-error case, and also to make it easier for torchdynamo 1809 to trace.) 1810 1811 .. note:: This function is planned for removal in the future. Please use 1812 `torch._check*` functions instead. 1813 """ 1814 torch._check_with(exc_type, b, s) 1815 1816 1817# This combines is_channels_last_strides_2d and is_channels_last_strides_3d in 1818# c10/core/MemoryFormat.h into one function 1819def are_strides_like_channels_last( 1820 shape: Sequence[int], strides: Sequence[int] 1821) -> bool: 1822 from torch.fx.experimental.symbolic_shapes import guard_size_oblivious 1823 1824 ndim = len(shape) 1825 1826 if ndim == 4: 1827 # Check for channels_last_2d 1828 dim_order = [1, 3, 2, 0] 1829 elif ndim == 5: 1830 # Check for channels_last_3d 1831 dim_order = [1, 4, 3, 2, 0] 1832 else: 1833 return False 1834 1835 if guard_size_oblivious(strides[1] == 0): 1836 return False 1837 1838 min = 0 1839 for d in dim_order: 1840 if guard_size_oblivious(shape[d] == 0): 1841 return False 1842 if guard_size_oblivious(strides[d] < min): 1843 return False 1844 if d == 0 and min == strides[1]: 1845 return False 1846 min = strides[d] 1847 if guard_size_oblivious(strides[d] > 1): 1848 min *= shape[d] 1849 return True 1850 1851 1852def suggest_memory_format(x: TensorLikeType) -> torch.memory_format: 1853 if x.layout != torch.strided: 1854 return torch.contiguous_format 1855 1856 if are_strides_like_channels_last(x.shape, x.stride()): 1857 return torch.channels_last if x.ndim == 4 else torch.channels_last_3d 1858 1859 return torch.contiguous_format 1860 1861 1862def prod(xs: Sequence[NumberType]) -> NumberType: 1863 """Product of elements in input sequence. Returns 1 for empty sequence""" 1864 return reduce(operator.mul, xs, 1) 1865 1866 1867def is_expandable_to(shape: ShapeType, desired: ShapeType) -> bool: 1868 """Checks if a shape can be expanded to another shape. 1869 This is equivalent to checking if the two shapes are broadcastable. 1870 """ 1871 # This is a Python implementation of 1872 # aten/src/ATen/ExpandUtils.h:is_expandable_to 1873 if len(shape) > len(desired): 1874 return False 1875 for i in range(len(shape)): 1876 if shape[-i - 1] != desired[-i - 1] and shape[-i - 1] != 1: 1877 return False 1878 return True 1879 1880 1881def mask_tensor(mask: TensorLikeType, t: TensorLikeType): 1882 """ 1883 Similar to torch.where(mask, t, 0) but if t is boolean, 1884 result is also boolean and not promoted to int. 1885 """ 1886 # torch.where(mask, t, False) is equivalent 1887 # but feels hacky and might break in the future 1888 if t.dtype is torch.bool: 1889 return mask.logical_and(t) 1890 else: 1891 return torch.where(mask, t, 0) 1892 1893 1894def get_aten_op(fn: Callable, name: str): 1895 """ 1896 Given the __module__ of reference and its name, it returns 1897 (our best guess of) the ATen name of the associated operation 1898 1899 Note: In ATen, the __name__ of a function within a module often 1900 starts by the module name. E.g. linalg_eigh, or special_zeta 1901 """ 1902 module = fn.__module__ 1903 prefix = "torch._refs" 1904 assert module.startswith(prefix) 1905 module = module[len(prefix) :] 1906 # We want to go from .special / .nn.functional 1907 # to special and special_ / nn_functional_ 1908 if module: 1909 module = module[1:] 1910 module = module.replace(".", "_") 1911 module = module + "_" 1912 return getattr(torch._ops.ops.aten, f"{module}{name}") 1913 1914 1915def dtype_or_default(dtype: Optional[torch.dtype]) -> torch.dtype: 1916 return dtype if dtype is not None else torch.get_default_dtype() 1917 1918 1919def device_or_default(device: Optional[DeviceLikeType]) -> DeviceLikeType: 1920 return device if device is not None else torch.device("cpu") 1921 1922 1923def layout_or_default(layout: Optional[torch.layout]) -> torch.layout: 1924 return layout if layout is not None else torch.strided 1925 1926 1927def clone_preserve_strides(x): 1928 needed_size = compute_required_storage_length( 1929 x.size(), x.stride(), x.storage_offset() 1930 ) 1931 # Our eager implementations for *_scatter ops are all primitives w.r.t autograd, 1932 # so these as_strided() calls are not seen by autograd. 1933 # We need to mimic this behavior in our ref/prim implementations. 1934 # TODO: a better way to handle this would be with a new op, "_unsafe_as_strided" 1935 # We should revisit this when we add a compositional as_strided op, 1936 # and also as part of https://github.com/pytorch/pytorch/issues/90507 1937 try: 1938 old = torch._C._dispatch_tls_is_dispatch_key_excluded( 1939 torch._C.DispatchKey.ADInplaceOrView 1940 ) 1941 torch._C._dispatch_tls_set_dispatch_key_excluded( 1942 torch._C.DispatchKey.ADInplaceOrView, True 1943 ) 1944 buffer = torch.as_strided(x, (needed_size,), (1,), 0).clone() 1945 return torch.as_strided(buffer, x.size(), x.stride(), x.storage_offset()) 1946 finally: 1947 torch._C._dispatch_tls_set_dispatch_key_excluded( 1948 torch._C.DispatchKey.ADInplaceOrView, old 1949 ) 1950 1951 1952def alert_not_deterministic(caller: str): 1953 if torch.are_deterministic_algorithms_enabled(): 1954 if torch.is_deterministic_algorithms_warn_only_enabled(): 1955 warnings.warn( 1956 f"{caller} does not have a deterministic implementation, but you set " 1957 f"'torch.use_deterministic_algorithms(True, warn_only=True)'. " 1958 f"You can file an issue at https://github.com/pytorch/pytorch/issues " 1959 f"to help us prioritize adding deterministic support for this operation." 1960 ) 1961 else: 1962 torch._check( 1963 False, 1964 lambda: ( 1965 f"{caller} does not have a deterministic implementation, but you set " 1966 f"'torch.use_deterministic_algorithms(True)'. You can turn off " 1967 f"determinism just for this operation, or you can use the " 1968 f"'warn_only=True' option, if that's acceptable for your application. " 1969 f"You can also file an issue at https://github.com/pytorch/pytorch/issues " 1970 f"to help us prioritize adding deterministic support for this operation." 1971 ), 1972 ) 1973 1974 1975class CUDARngStateHelper: 1976 @staticmethod 1977 def get_torch_state_as_tuple(fake_mode=nullcontext()): 1978 if not torch.cuda.is_available(): 1979 raise RuntimeError("CUDA not available") 1980 1981 with fake_mode: 1982 seed = torch.tensor(torch.cuda.initial_seed()) 1983 offset = torch.tensor(torch.cuda._get_rng_state_offset()) 1984 return seed, offset 1985 1986 @staticmethod 1987 def set_torch_state_tensor(seed, offset): 1988 # Rng state is [64-bit seed, 64-bit offset] 1989 seed_portion = seed.reshape([1]).view(torch.uint8) 1990 offset_portion = offset.reshape([1]).view(torch.uint8) 1991 new_state = torch.cat([seed_portion, offset_portion]) 1992 torch.cuda.set_rng_state(new_state) 1993 1994 @staticmethod 1995 def set_new_offset(relative_offset): 1996 torch.cuda._set_rng_state_offset(relative_offset.item()) 1997