# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import math
from enum import Enum
from typing import List, Optional, Sequence, Tuple, Union

import torch
import torch._prims_common as utils
from torch import SymBool, SymFloat, Tensor
from torch._decomp import (
    _add_op_to_registry,
    _convert_out_params,
    global_decomposition_table,
    meta_table,
)
from torch._ops import OpOverload
from torch._prims import _prim_elementwise_meta, ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND
from torch._prims_common import (
    corresponding_complex_dtype,
    corresponding_real_dtype,
    elementwise_dtypes,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    IntLike,
    make_contiguous_strides_for,
    Number,
    TensorLike,
)
from torch._prims_common.wrappers import (
    _maybe_convert_to_dtype,
    _maybe_resize_out,
    _resize_output_check,
    _safe_copy_out,
    out_wrapper,
)
from torch._refs import _broadcast_shapes, _maybe_broadcast
from torch.utils import _pytree as pytree


aten = torch.ops.aten

_meta_lib_dont_use_me_use_register_meta = torch.library.Library("aten", "IMPL", "Meta")


def register_meta(op):
    def wrapper(fn):
        fn = _convert_out_params(fn)

        def register(op):
            _add_op_to_registry(meta_table, op, fn)

        pytree.tree_map_(register, op)
        return fn

    return wrapper


def elementwise_meta(
    *args,
    type_promotion: ELEMENTWISE_TYPE_PROMOTION_KIND,
):
    # Perform type promotion, as this is expected from prim_metafunction
    _, result_dtype = utils.elementwise_dtypes(
        *args,
        type_promotion_kind=type_promotion,
    )
    args = [_maybe_convert_to_dtype(x, result_dtype) for x in args]

    # Broadcast
    args = _maybe_broadcast(*args)

    # Perform prim checks
    return _prim_elementwise_meta(
        *args, type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT
    )


def toRealValueType(dtype):
    from_complex = {
        torch.complex32: torch.half,
        torch.cfloat: torch.float,
        torch.cdouble: torch.double,
    }
    return from_complex.get(dtype, dtype)


def check_inplace_broadcast(self_shape, *args_shape):
    broadcasted_shape = tuple(_broadcast_shapes(self_shape, *args_shape))
    torch._check(
        broadcasted_shape == self_shape,
        lambda: f"output with shape {self_shape} doesn't match the broadcast shape {broadcasted_shape}",
    )
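

# Illustrative sketch (editor-added comment, not part of the original file): meta
# kernels registered through register_meta only compute output metadata, so an op
# dispatched on the "meta" device allocates no data, e.g.
#
#     t = torch.linspace(0, 1, steps=50, device="meta")
#     assert t.shape == (50,) and t.device.type == "meta"
#
# toRealValueType maps a complex dtype to its real counterpart, e.g.
# toRealValueType(torch.cfloat) is torch.float.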


@register_meta([aten.linspace, aten.logspace])
@out_wrapper()
def meta_linspace_logspace(
    start,
    end,
    steps,
    base=None,
    dtype=None,
    device=None,
    layout=torch.strided,
    pin_memory=False,
    requires_grad=False,
):
    if isinstance(start, torch.Tensor):
        torch._check(
            start.dim() == 0,
            lambda: "linspace only supports 0-dimensional start and end tensors",
        )
    if isinstance(end, torch.Tensor):
        torch._check(
            end.dim() == 0,
            lambda: "linspace only supports 0-dimensional start and end tensors",
        )

    if any(isinstance(arg, complex) for arg in (start, end, steps)):
        default_complex_dtype = utils.corresponding_complex_dtype(
            torch.get_default_dtype()
        )
        if dtype is None:
            dtype = default_complex_dtype
        else:
            torch._check(
                utils.is_complex_dtype(dtype),
                lambda: f"linspace(): inferred dtype {default_complex_dtype} can't be safely cast to passed dtype {dtype}",
            )
    else:
        dtype = dtype or torch.get_default_dtype()
    assert isinstance(dtype, torch.dtype)

    # steps does not participate in the computation of the dtype
    torch._check_type(
        isinstance(steps, IntLike),
        lambda: f"received an invalid combination of arguments - got \
({type(start).__name__}, {type(end).__name__}, {type(steps).__name__})",
    )
    assert isinstance(steps, IntLike)  # for mypy
    torch._check(steps >= 0, lambda: "number of steps must be non-negative")

    return torch.empty(
        (steps,),  # type: ignore[arg-type]
        dtype=dtype,
        layout=layout,
        device="meta",
        pin_memory=pin_memory,
        requires_grad=requires_grad,
    )


@register_meta([aten.take.default, aten.take.out])
@out_wrapper()
def meta_take(self, index):
    # Type and device checks
    torch._check(
        index.dtype == torch.long,
        lambda: f"take(): Expected a long tensor for index, but got {index.dtype}",
    )
    # Index checks
    torch._check_index(
        not (self.numel() == 0 and index.numel() != 0),
        lambda: "take(): tried to take from an empty tensor",
    )
    return self.new_empty(index.shape)


@register_meta([aten.linalg_cross.default, aten.linalg_cross.out])
@out_wrapper()
def linalg_cross(self, other, *, dim=-1):
    x_d = self.ndim
    y_d = other.ndim
    torch._check(
        x_d == y_d,
        lambda: "linalg.cross: inputs must have the same number of dimensions.",
    )
    torch._check(
        self.size(dim) == 3 and other.size(dim) == 3,
        lambda: (
            f"linalg.cross: inputs dimension {dim} must have length 3. "
            f"Got {self.size(dim)} and {other.size(dim)}"
        ),
    )
    out_shape = _broadcast_shapes(self.shape, other.shape)
    return self.new_empty(out_shape)


@register_meta(aten.linalg_matrix_exp)
@out_wrapper()
def linalg_matrix_exp(self):
    squareCheckInputs(self, "linalg.matrix_exp")
    checkFloatingOrComplex(self, "linalg.matrix_exp")
    return torch.empty_like(self, memory_format=torch.contiguous_format)


@register_meta(
    [aten.cummax.default, aten.cummax.out, aten.cummin.default, aten.cummin.out]
)
@out_wrapper("values", "indices")
def cummaxmin(self, dim):
    values = torch.empty(self.shape, device=self.device, dtype=self.dtype)
    indices = torch.empty(self.shape, device=self.device, dtype=torch.int64)
    if self.numel() != 0 and self.ndim != 0:
        # Checks that dim is within bounds
        maybe_wrap_dim(dim, self.ndim)
    return values, indices


@register_meta([aten.logcumsumexp.default, aten.logcumsumexp.out])
@out_wrapper()
def logcumsumexp(self, dim):
    # Checks that dim is within bounds
    maybe_wrap_dim(dim, self.ndim)
    return torch.empty_like(self).contiguous()
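

# Illustrative sketch (editor-added comment, not part of the original file):
# linalg.cross requires both inputs to have the same ndim and length 3 along
# `dim`; the output takes the broadcast shape of the inputs, e.g.
#
#     a = torch.empty(4, 3, device="meta")
#     b = torch.empty(1, 3, device="meta")
#     torch.linalg.cross(a, b).shape  # torch.Size([4, 3])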


# Stride-related code from _exec_fft in aten/src/ATen/native/cuda/SpectralOps.cpp
def _exec_fft(out, self, out_sizes, dim, forward):
    ndim = self.ndim
    signal_ndim = len(dim)
    batch_dims = ndim - signal_ndim

    # Permute dimensions so batch dimensions come first, and in stride order
    dim_permute = list(range(ndim))

    is_transformed_dim = [False for _ in range(ndim)]
    for d in dim:
        is_transformed_dim[d] = True

    # std::partition
    left, right = [], []
    for d in dim_permute:
        if not is_transformed_dim[d]:
            left.append(d)
        else:
            right.append(d)
    dim_permute = left + right
    batch_end = len(left)

    self_strides = self.stride()
    tmp = dim_permute[:batch_end]
    tmp.sort(key=lambda x: self_strides[x], reverse=True)
    dim_permute = tmp + dim_permute[batch_end:]
    input = self.permute(dim_permute)

    # Collapse batch dimensions into a single dimension
    batched_sizes = [-1] + list(input.shape[batch_dims:])
    input = input.reshape(batched_sizes)

    batch_size = input.size(0)
    batched_sizes[0] = batch_size
    batched_out_sizes = batched_sizes
    for i in range(len(dim)):
        batched_out_sizes[i + 1] = out_sizes[dim[i]]
    out = out.reshape(batched_out_sizes)

    # Reshaping to original batch shape and inverting the dimension permutation
    out_strides = [0 for _ in range(ndim)]
    batch_numel = 1
    i = batch_dims - 1
    while i >= 0:
        out_strides[dim_permute[i]] = batch_numel * out.stride(0)
        batch_numel *= out_sizes[dim_permute[i]]
        i -= 1
    for i in range(batch_dims, ndim):
        out_strides[dim_permute[i]] = out.stride(1 + (i - batch_dims))
    return out.as_strided(out_sizes, out_strides, out.storage_offset())


# See _fft_c2c_cufft in aten/src/ATen/native/cuda/SpectralOps.cpp
# and _fft_c2c_mkl in aten/src/ATen/native/mkl/SpectralOps.cpp
@register_meta([aten._fft_c2c.default, aten._fft_c2c.out])
@out_wrapper()
def meta_fft_c2c(self, dim, normalization, forward):
    assert self.dtype.is_complex

    out_sizes = self.shape
    output = self.new_empty(out_sizes)

    if not dim:
        return output

    sorted_dims = dim[:]
    self_strides = self.stride()
    sorted_dims.sort(key=lambda x: self_strides[x], reverse=True)
    output = _exec_fft(output, self, out_sizes, sorted_dims, forward)

    return output


@register_meta([aten._fft_r2c.default, aten._fft_r2c.out])
@out_wrapper()
def meta_fft_r2c(self, dim, normalization, onesided):
    assert self.dtype.is_floating_point
    output_sizes = list(self.size())

    if onesided:
        last_dim = dim[-1]
        last_dim_halfsize = (output_sizes[last_dim] // 2) + 1
        output_sizes[last_dim] = last_dim_halfsize

    return self.new_empty(
        output_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
    )


@register_meta(aten.randperm.generator_out)
def meta_randperm(n, *, generator=None, out):
    return _maybe_resize_out(out, torch.Size([n]))


@register_meta(aten.randperm.default)
def meta_randperm_default(
    n,
    *,
    dtype=torch.long,
    layout=None,
    device=None,
    pin_memory=None,
):
    return torch.empty(
        n, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
    )


@register_meta([aten.randint.default, aten.randint.out])
@out_wrapper()
def meta_randint(
    high,
    size,
    *,
    dtype=torch.long,
    layout=None,
    device=None,
    pin_memory=None,
):
    return torch.empty(
        size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
    )


@register_meta([aten.randint.low, aten.randint.low_out])
@out_wrapper()
def meta_randint_low(
    low,
    high,
    size,
    *,
    dtype=torch.long,
    layout=None,
    device=None,
    pin_memory=None,
):
    return torch.empty(
        size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
    )


@register_meta([aten.rand.default, aten.rand.out])
@out_wrapper()
def meta_rand_default(size, *, dtype=None, layout=None, device=None, pin_memory=None):
    return torch.empty(
        size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
    )
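

# Illustrative sketch (editor-added comment, not part of the original file): a
# onesided real-to-complex FFT keeps only n // 2 + 1 frequency bins along the
# transformed dimension and promotes the dtype to the corresponding complex type,
# e.g.
#
#     x = torch.empty(8, dtype=torch.float32, device="meta")
#     torch.fft.rfft(x).shape  # torch.Size([5]), dtype torch.complex64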


@register_meta([aten._fft_c2r.default, aten._fft_c2r.out])
@out_wrapper()
def meta_fft_c2r(self, dim, normalization, lastdim):
    assert self.dtype.is_complex
    output_sizes = list(self.size())
    output_sizes[dim[-1]] = lastdim
    return self.new_empty(output_sizes, dtype=toRealValueType(self.dtype))


@register_meta(aten.copy_.default)
def meta_copy_(self, src, non_blocking=False):
    # This code simulates the original decomp from inductor,
    # which runs most of the meta checks that we care about.
    # In theory, we should make this more robust by carefully
    # auditing our C++ copy_() kernel and copying the checks here.
    from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols

    # TODO: Ideally, we'd insert a deferred runtime assert here, but if we are
    # calling an actual copy_, you'll get that automatically
    # https://github.com/pytorch/pytorch/issues/122477
    if (
        not free_unbacked_symbols(self) and torch._debug_has_internal_overlap(self) == 1
    ):  # 1 == MemOverlap::Yes
        raise RuntimeError(
            "more than one element of the written-to tensor refers to a single memory location"
        )

    if isinstance(src, Tensor):
        intermediate = src.to(self, non_blocking)
        if self.size() != intermediate.size():
            aten.expand_copy.default(intermediate, self.size())
    return self


def inferUnsqueezeGeometry(tensor, dim):
    result_sizes = list(tensor.size())
    result_strides = list(tensor.stride())
    new_stride = 1 if dim >= tensor.dim() else result_sizes[dim] * result_strides[dim]
    result_sizes.insert(dim, 1)
    result_strides.insert(dim, new_stride)
    return result_sizes, result_strides


@register_meta(aten.unsqueeze_.default)
def meta_unsqueeze_(self, dim):
    dim = maybe_wrap_dim(dim, self.dim() + 1)
    g_sizes, g_strides = inferUnsqueezeGeometry(self, dim)
    self.as_strided_(g_sizes, g_strides)
    return self


@register_meta(aten._sparse_semi_structured_linear)
def meta_sparse_structured_linear(
    input: Tensor,
    weight: Tensor,
    _meta: Tensor,
    bias: Optional[Tensor] = None,
    _activation_opt: Optional[str] = None,
    out_dtype: Optional[torch.dtype] = None,
):
    output_sizes = list(input.shape)
    if bias is not None:
        assert weight.size(0) == bias.size(0), "output size mismatch"
    assert weight.size(1) == input.size(-1) / 2
    output_sizes[-1] = weight.size(0)

    # see: https://github.com/pytorch/pytorch/pull/114477#issuecomment-1830121375
    # We assume that we have already squashed the inputs into a 2-D tensor
    # Then, as the output is transposed, we need to propagate the transposed
    # stride information to the output tensor
    assert len(input.shape) == 2, "we can only handle the squashed input case"
    transposed_strides = (1, input.size(0))

    if out_dtype is not None:
        assert (
            input.dtype == torch.int8 and out_dtype == torch.int32
        ), "out_dtype is only supported for i8i8->i32 linear operator"
    output = input.new_empty(
        output_sizes,
        dtype=input.dtype if out_dtype is None else out_dtype,
    ).as_strided(output_sizes, transposed_strides)

    return output
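

# Illustrative sketch (editor-added comment, not part of the original file): for
# 2:4 semi-structured sparsity the compressed weight stores half of the dense
# columns, which is what the shape checks above expect:
#
#     input:  (m, k)        # already squashed to 2-D
#     weight: (n, k // 2)   # compressed representation
#     output: (m, n)        # returned with transposed strides (1, m)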


@register_meta(aten._sparse_semi_structured_mm)
def meta_sparse_structured_mm(
    mat1: Tensor,
    mat1_meta: Tensor,
    mat2: Tensor,
    out_dtype: Optional[torch.dtype] = None,
):
    assert len(mat1.shape) == 2
    assert len(mat1_meta.shape) == 2
    assert len(mat2.shape) == 2
    assert mat1.size(1) == mat2.size(0) / 2
    output_sizes = [mat1.size(0), mat2.size(1)]

    if out_dtype is not None:
        assert (
            mat2.dtype == torch.int8 and out_dtype == torch.int32
        ), "out_dtype is only supported for i8i8->i32 linear operator"
    output = mat2.new_empty(
        output_sizes,
        dtype=mat2.dtype if out_dtype is None else out_dtype,
    )

    return output


@register_meta(aten._sparse_semi_structured_addmm)
def meta_sparse_structured_addmm(
    input: Tensor,
    mat1: Tensor,
    mat1_meta: Tensor,
    mat2: Tensor,
    *,
    alpha=1,
    beta=1,
    out_dtype: Optional[torch.dtype] = None,
):
    assert (
        len(input.shape) == 1
    ), "only input broadcasted to columns of mat1 * mat2 product is supported"
    assert len(mat1.shape) == 2
    assert len(mat1_meta.shape) == 2
    assert len(mat2.shape) == 2
    assert input.size(0) == mat1.size(
        0
    ), "only input broadcasted to columns of mat1 * mat2 product is supported"
    assert mat1.size(1) == mat2.size(0) / 2
    output_sizes = [mat1.size(0), mat2.size(1)]

    if out_dtype is not None:
        assert (
            mat2.dtype == torch.int8 and out_dtype == torch.int32
        ), "out_dtype is only supported for i8i8->i32 linear operator"
    output = mat2.new_empty(
        output_sizes,
        dtype=mat2.dtype if out_dtype is None else out_dtype,
    )

    return output


@register_meta(aten._cslt_sparse_mm)
def meta__cslt_sparse_mm(
    compressed_A: torch.Tensor,
    dense_B: torch.Tensor,
    bias: Optional[Tensor] = None,
    alpha: Optional[Tensor] = None,
    out_dtype: Optional[torch.dtype] = None,
    transpose_result: bool = False,
):
    assert dense_B.dtype in {
        torch.float32,
        torch.float16,
        torch.bfloat16,
        torch.int8,
    }, "_cslt_sparse_mm only supports fp16, bf16, and int8"
    assert compressed_A.dtype == dense_B.dtype, "inputs must have the same dtype"
    assert len(dense_B.shape) == 2, "_cslt_sparse_mm only supports 2d inputs"

    is_int8_input_type = compressed_A.dtype == torch.int8
    compression_factor = 10 if is_int8_input_type else 9
    k = dense_B.size(0)
    n = dense_B.size(1)
    m = (compressed_A.numel() * 16) // (compression_factor * k)
    if bias is not None:
        assert m == bias.size(0)

    if out_dtype is not None:
        assert is_int8_input_type and out_dtype in {
            torch.float16,
            torch.bfloat16,
            torch.int32,
        }, "out_dtype is only supported for i8i8->fp16, bf16, or i32 matmul"
    output_shape = (n, m) if transpose_result else (m, n)
    result = dense_B.new_empty(output_shape, dtype=out_dtype)
    return result


@register_meta(aten.index_reduce.default)
def meta_index_reduce(
    self: Tensor,
    dim: int,
    index: Tensor,
    source: torch.Tensor,
    reduce: str,
    *,
    include_self: bool = True,
) -> Tensor:
    return torch.empty_like(self, memory_format=torch.contiguous_format)


@register_meta(aten.index_reduce_.default)
def meta_index_reduce_(
    self: Tensor,
    dim: int,
    index: Tensor,
    source: torch.Tensor,
    reduce: str,
    *,
    include_self: bool = True,
) -> Tensor:
    return self


# Implementations below are taken from https://github.com/albanD/subclass_zoo/blob/main/python_meta_tensor.py
@out_wrapper()
@register_meta(aten.index_select.default)
def meta_index_select(self, dim, index):
    result_size = list(self.size())
    if self.dim() > 0:
        result_size[dim] = index.numel()
    return self.new_empty(result_size)
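

# Illustrative sketch (editor-added comment, not part of the original file):
# index_select replaces the size of `dim` with the number of gathered indices, e.g.
#
#     x = torch.empty(4, 5, device="meta")
#     idx = torch.zeros(3, dtype=torch.long, device="meta")
#     torch.index_select(x, 0, idx).shape  # torch.Size([3, 5])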


@register_meta(aten.segment_reduce.default)
def meta_segment_reduce(
    data: Tensor,
    reduce: str,
    *,
    lengths: Optional[Tensor] = None,
    indices: Optional[Tensor] = None,
    offsets: Optional[Tensor] = None,
    axis: int = 0,
    unsafe: bool = False,
    initial=None,
) -> Tensor:
    if indices is not None:
        raise NotImplementedError(
            "segment_reduce(): indices based reduction is not supported yet."
        )

    def segment_reduce_lengths_tensor(lengths_shape):
        return torch.empty(
            lengths_shape + data.shape[axis + 1 :],
            dtype=data.dtype,
            device="meta",
            memory_format=torch.contiguous_format,
        )

    if lengths is not None:
        return segment_reduce_lengths_tensor(lengths.shape)
    # FIXME should probably check that lengths and offset aren't both set, but
    # the ATen implementation neglects this too
    if offsets is not None:
        # lengths == torch.diff(offsets)
        lengths_shape = offsets.shape[:-1] + (offsets.shape[-1] - 1,)
        return segment_reduce_lengths_tensor(lengths_shape)
    raise RuntimeError("segment_reduce(): Either lengths or offsets must be defined.")


@register_meta([aten.max.default, aten.max.unary_out])
@out_wrapper()
def meta_max(self):
    return self.new_empty(())


@register_meta(aten.max.dim)
def meta_max_dim(self, dim, keepdim=False):
    dim = utils.reduction_dims(self.shape, (dim,))
    output_shape = _compute_reduction_shape(self, dim, keepdim)
    return (
        self.new_empty(output_shape),
        self.new_empty(output_shape, dtype=torch.long),
    )


@register_meta([aten.min.default, aten.min.unary_out])
@out_wrapper()
def meta_min(self):
    return self.new_empty(())


@register_meta(aten.min.dim)
def meta_min_dim(self, dim, keepdim=False):
    dim = utils.reduction_dims(self.shape, (dim,))
    output_shape = _compute_reduction_shape(self, dim, keepdim)
    return (
        self.new_empty(output_shape),
        self.new_empty(output_shape, dtype=torch.long),
    )


@register_meta(aten.angle.default)
def meta_angle(self):
    if self.is_complex():
        result_dtype = corresponding_real_dtype(self.dtype)
    else:
        _, result_dtype = elementwise_dtypes(
            self,
            type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        )
    return torch.empty_like(self, dtype=result_dtype)


@register_meta(aten.angle.out)
def meta_angle_out(self, out):
    torch._resize_output_(out, self.size(), self.device)
    return out.copy_(torch.angle(self))
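

# Illustrative sketch (editor-added comment, not part of the original file): angle
# keeps the real counterpart of a complex dtype and promotes integers to float, e.g.
#
#     torch.angle(torch.empty(2, dtype=torch.complex64, device="meta")).dtype  # torch.float32
#     torch.angle(torch.empty(2, dtype=torch.int64, device="meta")).dtype      # default float dtype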


@register_meta(aten._assert_async.default)
def assert_async(val):
    return


@register_meta(aten._assert_async.msg)
def assert_async_meta(val, assert_msg):
    return


@register_meta(aten._print.default)
def print_meta(s):
    return


@register_meta(aten._make_dep_token.default)
def make_dep_token(
    *,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
    memory_format=None,
):
    return torch.empty(0, device="meta")


@register_meta(aten.sym_constrain_range.default)
def sym_constrain_range(size, min=None, max=None):
    # Avoid importing sympy at a module level
    from torch.fx.experimental.symbolic_shapes import constrain_range

    if isinstance(size, (SymFloat, SymBool)):
        raise ValueError("Constraining SymFloat or SymBool is nyi")
    constrain_range(size, min=min, max=max)


@register_meta(aten._functional_sym_constrain_range.default)
def functional_sym_constrain_range(size, min=None, max=None, dep_token=None):
    aten.sym_constrain_range(size, min=min, max=max)
    return dep_token


@register_meta(aten.sym_constrain_range_for_size.default)
def sym_constrain_range_for_size(size, min=None, max=None):
    # Avoid importing sympy at a module level
    from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size

    if isinstance(size, (SymFloat, SymBool)):
        raise ValueError("Constraining SymFloat or SymBool is nyi")
    _constrain_range_for_size(size, min=min, max=max)


@register_meta(aten._functional_sym_constrain_range_for_size.default)
def functional_sym_constrain_range_for_size(size, min, max, dep_token):
    aten.sym_constrain_range_for_size(size, min=min, max=max)
    return dep_token


@register_meta(aten._functional_assert_async.msg)
def functional_assert_async_meta(val, assert_msg, dep_token):
    return dep_token


# From aten/src/ATen/native/LinearAlgebraUtils.h
def squareCheckInputs(self: Tensor, f_name: str):
    assert (
        self.dim() >= 2
    ), f"{f_name}: The input tensor must have at least 2 dimensions."
    assert (
        self.size(-1) == self.size(-2)
    ), f"{f_name}: A must be batches of square matrices, but they are {self.size(-2)} by {self.size(-1)} matrices"


# Validates input shapes and devices
# for linear solve methods (solve, cholesky_solve, lu_solve, triangular_solve)
# From aten/src/ATen/native/LinearAlgebraUtils.h
def linearSolveCheckInputs(self: Tensor, A: Tensor, name: str):
    torch._check(
        self.device == A.device,
        lambda: (
            f"Expected b and A to be on the same device, but found b on "
            f"{self.device} and A on {A.device} instead."
        ),
    )

    torch._check(
        self.dtype == A.dtype,
        lambda: (
            f"Expected b and A to have the same dtype, but found b of type "
            f"{self.dtype} and A of type {A.dtype} instead."
        ),
    )

    torch._check(
        A.size(-1) == A.size(-2),
        lambda: (
            f"A must be batches of square matrices, "
            f"but they are {A.size(-2)} by {A.size(-1)} matrices"
        ),
    )

    torch._check(
        A.size(-1) == self.size(-2),
        lambda: (
            f"Incompatible matrix sizes for {name}: each A "
            f"matrix is {A.size(-1)} by {A.size(-1)}"
            f" but each b matrix is {self.size(-2)} by {self.size(-1)}"
        ),
    )


# From aten/src/ATen/native/LinearAlgebraUtils.h
def checkFloatingOrComplex(
    t: Tensor,
    f_name: str,
    allow_low_precision_dtypes: bool = True,
):
    dtype = t.dtype
    torch._check(
        t.is_floating_point() or t.is_complex(),
        lambda: f"{f_name}: Expected a floating point or complex tensor as input. Got {dtype}",
    )
    if not allow_low_precision_dtypes:
        torch._check(
            dtype in (torch.float, torch.double, torch.cfloat, torch.cdouble),
            lambda: f"{f_name}: Low precision dtypes not supported. Got {dtype}",
        )


# From aten/src/ATen/native/LinearAlgebraUtils.h
def checkIsMatrix(A: Tensor, f_name: str, arg_name: str = "A"):
    torch._check(
        A.dim() >= 2,
        lambda: f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions.",
    )


def checkInputsSolver(A: Tensor, B: Tensor, left: bool, f_name: str):
    squareCheckInputs(A, f_name)
    checkIsMatrix(B, f_name)
    torch._check(
        A.size(-2) == B.size(-2) if left else A.size(-1) == B.size(-1),
        lambda: (
            f"{f_name}: Incompatible shapes of A and B for the equation "
            f"{'AX = B' if left else 'XA = B'}"
            f" ({A.size(-2)}x{A.size(-1)} and {B.size(-2)}x{B.size(-1)})"
        ),
    )


def checkSameDevice(
    fn_name: str,
    result: Tensor,
    input: Tensor,
    result_name: str = "result",
):
    torch._check(
        result.device == input.device,
        lambda: (
            f"{fn_name}: Expected {result_name} and input tensors to be on the same device, but got "
            f"{result_name} on {result.device} and input on {input.device}"
        ),
    )


def checkUplo(UPLO: str):
    UPLO_uppercase = UPLO.upper()
    torch._check(
        len(UPLO) == 1 and (UPLO_uppercase == "U" or UPLO_uppercase == "L"),
        lambda: f"Expected UPLO argument to be 'L' or 'U', but got {UPLO}",
    )


@register_meta([aten._linalg_eigh.default, aten._linalg_eigh.eigenvalues])
@out_wrapper("eigenvalues", "eigenvectors")
def meta__linalg_eigh(A: Tensor, UPLO: str = "L", compute_v: bool = True):
    squareCheckInputs(A, "linalg.eigh")
    checkUplo(UPLO)

    shape = list(A.shape)
    if compute_v:
        vecs = A.new_empty(shape)
        vecs.as_strided_(shape, make_contiguous_strides_for(shape, row_major=False))
    else:
        vecs = A.new_empty([0])

    shape.pop()
    vals = A.new_empty(shape, dtype=toRealValueType(A.dtype))

    return vals, vecs


@register_meta([aten._linalg_eigvals.default, aten.linalg_eigvals.out])
@out_wrapper()
def meta__linalg_eigvals(input: Tensor) -> Tensor:
    squareCheckInputs(input, "linalg.eigvals")
    complex_dtype = (
        input.dtype
        if utils.is_complex_dtype(input.dtype)
        else utils.corresponding_complex_dtype(input.dtype)
    )
    return input.new_empty(input.shape[:-1], dtype=complex_dtype)


@register_meta([aten.linalg_eig])
@out_wrapper("eigenvalues", "eigenvectors")
def meta_linalg_eig(input: Tensor):
    squareCheckInputs(input, "linalg.eig")
    complex_dtype = (
        input.dtype
        if utils.is_complex_dtype(input.dtype)
        else utils.corresponding_complex_dtype(input.dtype)
    )
    values = input.new_empty(input.shape[:-1], dtype=complex_dtype)
    vectors = input.new_empty(input.shape, dtype=complex_dtype)
    return values, vectors


def cloneBatchedColumnMajor(src: Tensor) -> Tensor:
    return src.mT.clone(memory_format=torch.contiguous_format).transpose(-2, -1)


@register_meta(aten._cholesky_solve_helper)
@out_wrapper()
def _cholesky_solve_helper(self: Tensor, A: Tensor, upper: bool) -> Tensor:
    return cloneBatchedColumnMajor(self)
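

# Illustrative sketch (editor-added comment, not part of the original file): many
# of the linalg meta kernels return results whose last two dimensions are
# column-major (Fortran-contiguous), matching LAPACK conventions. For a 2-D shape
# (m, n), make_contiguous_strides_for((m, n), row_major=False) yields strides
# (1, m), and cloneBatchedColumnMajor produces a tensor with exactly those strides.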


@register_meta(aten.cholesky_solve)
@out_wrapper()
def cholesky_solve(self: Tensor, A: Tensor, upper: bool = False) -> Tensor:
    torch._check(
        self.ndim >= 2,
        lambda: f"b should have at least 2 dimensions, but has {self.ndim} dimensions instead",
    )
    torch._check(
        A.ndim >= 2,
        lambda: f"u should have at least 2 dimensions, but has {A.ndim} dimensions instead",
    )
    self_broadcasted, A_broadcasted = _linalg_broadcast_batch_dims_name(
        self, A, "cholesky_solve"
    )
    return _cholesky_solve_helper(self_broadcasted, A_broadcasted, upper)


@register_meta(aten.cholesky)
@out_wrapper()
def cholesky(self: Tensor, upper: bool = False) -> Tensor:
    if self.numel() == 0:
        return torch.empty_like(self, memory_format=torch.legacy_contiguous_format)
    squareCheckInputs(self, "cholesky")
    return cloneBatchedColumnMajor(self)


@register_meta(aten.cholesky_inverse)
@out_wrapper()
def cholesky_inverse(self: Tensor, upper: bool = False) -> Tensor:
    squareCheckInputs(self, "cholesky_inverse")
    return cloneBatchedColumnMajor(self)


# From aten/src/ATen/native/BatchLinearAlgebra.cpp
@register_meta(aten.linalg_cholesky_ex.default)
def linalg_cholesky_ex(A: Tensor, upper: bool = False, check_errors: bool = False):
    squareCheckInputs(A, "linalg.cholesky")
    checkFloatingOrComplex(A, "linalg.cholesky")

    A_shape = A.shape
    ndim = len(A_shape)

    # L
    L_strides = make_contiguous_strides_for(A_shape, False)
    L = A.new_empty(A_shape)
    L.as_strided_(A_shape, L_strides)

    # infos
    infos = A.new_empty(A_shape[0 : ndim - 2], dtype=torch.int32)
    return L, infos


@register_meta(
    [aten.linalg_householder_product.default, aten.linalg_householder_product.out]
)
@out_wrapper()
def linalg_householder_product(input: Tensor, tau: Tensor) -> Tensor:
    torch._check(
        input.ndim >= 2,
        lambda: "torch.linalg.householder_product: input must have at least 2 dimensions.",
    )
    torch._check(
        input.size(-2) >= input.size(-1),
        lambda: "torch.linalg.householder_product: input.shape[-2] must be greater than or equal to input.shape[-1]",
    )
    torch._check(
        input.size(-1) >= tau.size(-1),
        lambda: "torch.linalg.householder_product: input.shape[-1] must be greater than or equal to tau.shape[-1]",
    )

    torch._check(
        input.ndim - tau.ndim == 1,
        lambda: (
            f"torch.linalg.householder_product: Expected tau to have one dimension less than input, "
            f"but got tau.ndim equal to {tau.ndim} and input.ndim is equal to {input.ndim}"
        ),
    )
    if input.ndim > 2:
        expected_batch_tau_shape = input.shape[:-2]
        actual_batch_tau_shape = tau.shape[:-1]
        torch._check(
            actual_batch_tau_shape == expected_batch_tau_shape,
            lambda: (
                f"torch.linalg.householder_product: Expected batch dimensions of tau to be "
                f"equal to input.shape[:-2], but got {actual_batch_tau_shape}"
            ),
        )

    torch._check(
        tau.dtype == input.dtype,
        lambda: (
            f"torch.linalg.householder_product: tau dtype {tau.dtype}"
            f" does not match input dtype {input.dtype}"
        ),
    )
    checkSameDevice("torch.linalg.householder_product", tau, input, "tau")

    return torch.empty_strided(
        size=input.shape,
        stride=make_contiguous_strides_for(input.shape, row_major=False),
        dtype=input.dtype,
        device=input.device,
    )


# From aten/src/ATen/native/BatchLinearAlgebra.cpp
@register_meta(aten.linalg_inv_ex.default)
def linalg_inv_ex_meta(A: Tensor, check_errors: bool = False):
    squareCheckInputs(A, "linalg.inv_ex")
    checkFloatingOrComplex(A, "linalg.inv_ex", allow_low_precision_dtypes=False)

    L = A.new_empty(A.shape)
    L.as_strided_(A.shape, make_contiguous_strides_for(A.shape, row_major=False))

    infos = A.new_empty(A.shape[:-2], dtype=torch.int32)
    return L, infos
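

# Illustrative sketch (editor-added comment, not part of the original file): the
# *_ex factorization metas return an `infos` tensor shaped like the batch
# dimensions, e.g. for A of shape (2, 5, 5), linalg_cholesky_ex / linalg_inv_ex
# report L / inverse of shape (2, 5, 5) and infos of shape (2,) with dtype
# torch.int32.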


@register_meta([aten.linalg_ldl_factor_ex.default, aten.linalg_ldl_factor_ex.out])
@out_wrapper("LD", "pivots", "info")
def linalg_ldl_factor_ex_meta(
    self: Tensor,
    *,
    hermitian: bool = False,
    check_errors: bool = False,
) -> Tuple[Tensor, Tensor, Tensor]:
    squareCheckInputs(self, "torch.linalg.ldl_factor_ex")
    checkFloatingOrComplex(self, "torch.linalg.ldl_factor_ex")
    LD = torch.empty_strided(
        size=self.shape,
        stride=make_contiguous_strides_for(self.shape, row_major=False),
        dtype=self.dtype,
        device=self.device,
    )
    pivots = self.new_empty(self.shape[:-1], dtype=torch.int)
    info = self.new_empty(self.shape[:-2], dtype=torch.int)
    return LD, pivots, info


@register_meta([aten.linalg_ldl_solve.default, aten.linalg_ldl_solve.out])
@out_wrapper()
def linalg_ldl_solve_meta(
    LD: Tensor,
    pivots: Tensor,
    B: Tensor,
    *,
    hermitian: bool = False,
) -> Tensor:
    squareCheckInputs(LD, "torch.linalg.ldl_solve")
    checkFloatingOrComplex(LD, "torch.linalg.ldl_solve")
    linearSolveCheckInputs(B, LD, "torch.linalg.ldl_solve")
    torch._check(
        B.ndim >= 2,
        lambda: (
            f"torch.linalg.ldl_solve: Expected B to have at least 2 dimensions, "
            f"but it has {B.ndim} dimensions instead"
        ),
    )
    expected_pivots_shape = LD.shape[:-1]
    torch._check(
        expected_pivots_shape == pivots.shape,
        lambda: (
            f"torch.linalg.ldl_solve: Expected LD.shape[:-1] and pivots.shape to be the same, "
            f"but got pivots with shape {pivots.shape} instead"
        ),
    )
    torch._check(
        utils.is_integer_dtype(pivots.dtype),
        lambda: f"torch.linalg.ldl_solve: Expected pivots to be integers. Got {pivots.dtype}",
    )
    torch._check(
        LD.dtype == B.dtype,
        lambda: f"torch.linalg.ldl_solve: LD dtype {LD.dtype} does not match b dtype {B.dtype}",
    )
    B_broadcast_size, _ = _linalg_broadcast_batch_dims(B, LD)
    return torch.empty_strided(
        size=B_broadcast_size,
        stride=make_contiguous_strides_for(B_broadcast_size, row_major=False),
        dtype=B.dtype,
        device=B.device,
    )


@register_meta([aten.linalg_lu.default, aten.linalg_lu.out])
@out_wrapper("P", "L", "U")
def linalg_lu_meta(A: Tensor, *, pivot: bool = True) -> Tuple[Tensor, Tensor, Tensor]:
    torch._check(
        A.ndim >= 2,
        lambda: f"linalg.lu: Expected tensor with 2 or more dimensions. Got size: {A.shape} instead",
    )

    sizes = list(A.shape)
    m = sizes[-2]
    n = sizes[-1]
    k = min(m, n)

    sizes[-1] = m
    if pivot:
        P = A.new_empty(sizes)
    else:
        P = A.new_empty([0])

    sizes[-1] = k
    L = A.new_empty(sizes)

    sizes[-2] = k
    sizes[-1] = n
    U = A.new_empty(sizes)
    return P, L, U
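

# Illustrative sketch (editor-added comment, not part of the original file): for A
# of shape (*, m, n) with k = min(m, n), linalg.lu returns P of shape (*, m, m)
# (or an empty tensor when pivot=False), L of shape (*, m, k), and U of shape
# (*, k, n).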


@register_meta([aten.linalg_lu_factor_ex.default, aten.linalg_lu_factor_ex.out])
@out_wrapper("LU", "pivots", "info")
def linalg_lu_factor_ex_meta(
    A: Tensor,
    *,
    pivot: bool = True,
    check_errors: bool = False,
) -> Tuple[Tensor, Tensor, Tensor]:
    torch._check(
        A.ndim >= 2,
        lambda: f"torch.lu_factor: Expected tensor with 2 or more dimensions. Got size: {A.shape} instead",
    )

    sizes = list(A.shape)
    m = sizes[-2]
    n = sizes[-1]

    LU = torch.empty_strided(
        size=sizes,
        stride=make_contiguous_strides_for(sizes, row_major=False),
        dtype=A.dtype,
        device=A.device,
    )

    # Sets sizes to the size of pivots
    sizes.pop()
    sizes[-1] = min(m, n)
    pivots = A.new_empty(sizes, dtype=torch.int)

    # Sets sizes to the size of info
    sizes.pop()
    info = A.new_empty(sizes, dtype=torch.int)

    return LU, pivots, info


@register_meta([aten.linalg_lu_solve.default, aten.linalg_lu_solve.out])
@out_wrapper()
def linalg_lu_solve_meta(
    LU: Tensor,
    pivots: Tensor,
    B: Tensor,
    *,
    left: bool = True,
    adjoint: bool = False,
) -> Tensor:
    # dtype
    checkFloatingOrComplex(LU, "torch.linalg.lu_solve")
    torch._check(
        LU.dtype == B.dtype,
        lambda: (
            f"linalg.lu_solve: Expected LU and B to have the same dtype, "
            f"but found LU of type {LU.dtype} and B of type {B.dtype} instead"
        ),
    )
    torch._check(
        pivots.dtype == torch.int,
        lambda: "linalg.lu_solve: pivots should be a Tensor of scalar type torch.int32",
    )

    # matrix shapes
    squareCheckInputs(LU, "torch.linalg.lu_solve")
    checkInputsSolver(LU, B, left, "linalg.lu_solve")
    torch._check(
        LU.size(-1) == pivots.size(-1),
        lambda: "linalg.lu_solve: Number of pivots per batch should be same as the dimension of the matrix",
    )

    # batches
    torch._check(
        LU.shape[:-1] == pivots.shape,
        lambda: (
            f"linalg.lu_solve: Expected LU.shape[:-1] and pivots.shape to be the same, "
            f"but got pivots with shape {pivots.shape} instead"
        ),
    )

    B_broadcast_size, _ = _linalg_broadcast_batch_dims(B, LU)

    result = torch.empty_strided(
        size=B_broadcast_size,
        stride=make_contiguous_strides_for(B_broadcast_size, row_major=not left),
        dtype=B.dtype,
        device=B.device,
    )

    if result.numel() != 0 and not left:
        if result.is_complex():
            result = result.conj()

    return result


@register_meta(aten.lu_unpack)
@out_wrapper("P", "L", "U")
def lu_unpack_meta(
    LU: Tensor,
    pivots: Tensor,
    unpack_data: bool = True,
    unpack_pivots: bool = True,
) -> Tuple[Tensor, Tensor, Tensor]:
    torch._check(
        LU.ndim >= 2,
        lambda: f"torch.lu_unpack: Expected tensor with 2 or more dimensions. Got size: {LU.shape} instead",
    )
    if unpack_pivots:
        torch._check(
            pivots.dtype == torch.int32,
            lambda: (
                "torch.lu_unpack: LU_pivots is expected to be a contiguous tensor of torch.int32 dtype.\n"
                "Note: this function is intended to be used with the output produced by torch.linalg.lu_factor"
            ),
        )
    sizes = list(LU.shape)
    m = sizes[-2]
    n = sizes[-1]
    k = min(m, n)
    sizes[-1] = m
    if unpack_pivots:
        P = LU.new_empty(sizes)
    else:
        P = LU.new_empty([0])
    if unpack_data:
        sizes[-1] = k
        L = LU.new_empty(sizes)
        sizes[-2] = k
        sizes[-1] = n
        U = LU.new_empty(sizes)
    else:
        L = LU.new_empty([0])
        U = LU.new_empty([0])
    return P, L, U


# parse the "mode" param in linalg_qr: return a tuple of bools (compute_q, reduced)
def _parse_qr_mode(mode: str) -> Tuple[bool, bool]:
    if mode == "reduced":
        compute_q = True
        reduced = True
    elif mode == "complete":
        compute_q = True
        reduced = False
    elif mode == "r":
        compute_q = False
        reduced = True  # this is actually irrelevant in this mode
    else:
        torch._check(
            False,
            lambda: (
                f"qr received unrecognized mode '{mode}' "
                f"but expected one of 'reduced' (default), 'r', or 'complete'"
            ),
        )
    return compute_q, reduced  # type: ignore[possibly-undefined]


@register_meta([aten.linalg_qr.default, aten.linalg_qr.out])
@out_wrapper("Q", "R")
def linalg_qr_meta(A: Tensor, mode: str = "reduced") -> Tuple[Tensor, Tensor]:
    checkIsMatrix(A, "linalg.qr")
    checkFloatingOrComplex(A, "linalg.qr")

    compute_q, reduced_mode = _parse_qr_mode(mode)

    m = A.shape[-2]
    n = A.shape[-1]
    k = min(m, n)

    if compute_q:
        Q_shape = list(A.shape)
        Q_shape[-1] = k if reduced_mode else m
        Q = A.new_empty(Q_shape)
        Q.as_strided_(Q_shape, make_contiguous_strides_for(Q_shape, row_major=False))
    else:
        Q = A.new_empty([0])

    # For readability
    R_shape = list(A.shape)
    R_shape[-2] = k if reduced_mode or not compute_q else m
    R = A.new_empty(R_shape)
    R.as_strided_(R_shape, make_contiguous_strides_for(R_shape, row_major=False))
    return Q, R


@register_meta([aten._linalg_slogdet.default, aten._linalg_slogdet.sign])
@out_wrapper("sign", "logabsdet", "LU", "pivots")
def _linalg_slogdet(A: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    squareCheckInputs(A, "linalg.slogdet")
    checkFloatingOrComplex(A, "linalg.slogdet", False)
    shape = A.shape
    sign = A.new_empty(shape[:-2])
    logabsdet = A.new_empty(shape[:-2], dtype=toRealValueType(A.dtype))
    LU = torch.empty_strided(
        size=shape,
        stride=make_contiguous_strides_for(shape, False),
        dtype=A.dtype,
        device=A.device,
    )
    pivots = A.new_empty(shape[:-1], dtype=torch.int32)
    return sign, logabsdet, LU, pivots
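

# Illustrative sketch (editor-added comment, not part of the original file):
# linalg.qr output shapes for A of shape (*, m, n) with k = min(m, n):
#
#     mode="reduced"  -> Q: (*, m, k), R: (*, k, n)
#     mode="complete" -> Q: (*, m, m), R: (*, m, n)
#     mode="r"        -> Q: empty,     R: (*, k, n)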


# From aten/src/ATen/native/BatchLinearAlgebra.cpp
# NOTE: matching defaults in aten/src/ATen/native/native_functions.yaml
@register_meta(aten._linalg_svd.default)
def _linalg_svd_meta(
    A: Tensor,
    full_matrices: bool = False,
    compute_uv: bool = True,
    driver: Optional[str] = None,
):
    checkIsMatrix(A, "linalg.svd")
    checkFloatingOrComplex(A, "linalg.svd")

    batch_dims = list(A.shape[:-2])
    m = A.shape[-2]
    n = A.shape[-1]
    k = min(m, n)

    if compute_uv:
        U_shape = batch_dims + [m, m if full_matrices else k]
        U = A.new_empty(U_shape)
        U.as_strided_(U_shape, make_contiguous_strides_for(U_shape, row_major=False))

        V_shape = batch_dims + [n if full_matrices else k, n]
        V = A.new_empty(V_shape)
        # NB: This checks for CUDA since there is no way to check for cuSolver.
        # Also, this might not work correctly on CPU when fake_device is not
        # available as device_hint just defaults to CUDA in that case. See
        # _linalg_svd meta in core.
        is_cuda = device_hint(A) == "cuda"
        V.as_strided_(V_shape, make_contiguous_strides_for(V_shape, row_major=is_cuda))
    else:
        # doesn't matter
        U = A.new_empty([0])
        V = A.new_empty([0])

    # S is always real, even when A is complex.
    S = A.new_empty(batch_dims + [k], dtype=toRealValueType(A.dtype))
    return U, S, V


def _linalg_broadcast_batch_dims(
    arg1: Tensor,
    arg2: Tensor,
) -> Tuple[List[int], List[int]]:
    # broadcast the batch dimensions of arg1 and arg2.
    arg1_batch_sizes = arg1.shape[:-2]
    arg2_batch_sizes = arg2.shape[:-2]
    expand_batch_portion = _broadcast_shapes(arg1_batch_sizes, arg2_batch_sizes)

    arg1_expand_size = list(expand_batch_portion)
    arg1_expand_size += [arg1.size(-2), arg1.size(-1)]

    arg2_expand_size = list(expand_batch_portion)
    arg2_expand_size += [arg2.size(-2), arg2.size(-1)]
    return arg1_expand_size, arg2_expand_size


def _linalg_broadcast_batch_dims_name(
    arg1: Tensor,
    arg2: Tensor,
    name: Optional[str],
) -> Tuple[Tensor, Tensor]:
    # If there's no name we assume we don't want to check the errors
    if name:
        linearSolveCheckInputs(arg1, arg2, name)

    arg1_expand_size, arg2_expand_size = _linalg_broadcast_batch_dims(arg1, arg2)

    arg1_broadcasted = (
        arg1 if arg1_expand_size == arg1.shape else arg1.expand(arg1_expand_size)
    )
    arg2_broadcasted = (
        arg2 if arg2_expand_size == arg2.shape else arg2.expand(arg2_expand_size)
    )
    return arg1_broadcasted, arg2_broadcasted


def linalg_solve_is_vector_rhs(input: Tensor, other: Tensor) -> bool:
    expected_batched_rhs_shape = input.shape[:-1]
    vector_case = other.ndim == 1 or (
        input.ndim - 1 == other.ndim and other.shape == expected_batched_rhs_shape
    )
    return vector_case


@register_meta(aten._linalg_solve_ex)
def _linalg_solve_ex(
    A: Tensor,
    B: Tensor,
    *,
    left: bool = True,
    check_errors: bool = False,
    result: Optional[Tensor] = None,
    LU: Optional[Tensor] = None,
    pivots: Optional[Tensor] = None,
    info: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    checkFloatingOrComplex(A, "linalg.solve")
    torch._check(
        A.dtype == B.dtype,
        lambda: (
            f"linalg.solve: Expected A and B to have the same dtype, but found A of type "
            f"{A.dtype} and B of type {B.dtype} instead"
        ),
    )
    vector_case = linalg_solve_is_vector_rhs(A, B)
    B_ = B.unsqueeze(-1) if vector_case else B
    checkInputsSolver(A, B_, left, "linalg.solve")
    B_broad_shape, _ = _linalg_broadcast_batch_dims(B_, A)
    torch._check(
        left or not vector_case,
        lambda: (
            "linalg.solve: Vector broadcasting of the left hand side is not supported for left=False. "
            "In this case linalg.solve is equivalent to B / A.squeeze(-1)"
        ),
    )
    result_shape = B_broad_shape[:-1] if vector_case else B_broad_shape
    result_ = torch.empty_strided(
        size=result_shape,
        stride=make_contiguous_strides_for(result_shape, not left),
        dtype=B.dtype,
        device=B.device,
    )
    shape = A.shape
    ndim = A.ndim
    LU_ = torch.empty_strided(
        size=shape,
        stride=make_contiguous_strides_for(shape, False),
        dtype=A.dtype,
        device=A.device,
    )
    pivots_ = A.new_empty(shape[:-1], dtype=torch.int32)
    info_ = A.new_empty(shape[:-2], dtype=torch.int32)
    out = (result, LU, pivots, info)
    res = (result_, LU_, pivots_, info_)
    if all(x is not None for x in out):
        for r, o in zip(res, out):
            # resize and copy operations are done in-place
            _maybe_resize_out(o, r.shape)  # type: ignore[arg-type]
            # strides are not copied in out_wrapper
            o.as_strided_(r.shape, r.stride())  # type: ignore[union-attr]
            _safe_copy_out(copy_from=r, copy_to=o, exact_dtype=False)  # type: ignore[arg-type]
    return res


@register_meta([aten.linalg_solve_triangular.default, aten.linalg_solve_triangular.out])
def linalg_solve_triangular_meta(
    A: Tensor,
    B: Tensor,
    *,
    upper: bool,
    left: bool = True,
    unitriangular: bool = False,
    out: Optional[Tensor] = None,
) -> Tensor:
    if out is None:
        out = A.new_empty([0])
    assert isinstance(out, TensorLike)
    checkInputsSolver(A, B, left, "linalg.solve_triangular")
    B_, A_ = _linalg_broadcast_batch_dims_name(B, A, None)
    avoid_copy_A = A_.transpose(-2, -1).is_contiguous() and A_.is_conj()
    if avoid_copy_A:
        out = _maybe_resize_out(out, B_.shape)
    else:
        # reimplementation of resize_output with result F-contig
        if _resize_output_check(out, B_.shape):
            out.resize_(B_.transpose(-2, -1).shape)
            out.transpose_(-2, -1)
    return out  # type: ignore[return-value]


@register_meta(aten.triangular_solve)
@out_wrapper("solution", "cloned_coefficient")
def triangular_solve_meta(
    self: Tensor,
    A: Tensor,
    upper: bool = True,
    transpose: bool = False,
    unitriangular: bool = False,
) -> Tuple[Tensor, Tensor]:
    torch._check(
        self.ndim >= 2,
        lambda: (
            f"torch.triangular_solve: Expected b to have at least 2 dimensions, "
            f"but it has {self.ndim} dimensions instead"
        ),
    )
    torch._check(
        A.ndim >= 2,
        lambda: (
            f"torch.triangular_solve: Expected A to have at least 2 dimensions, "
            f"but it has {A.ndim} dimensions instead"
        ),
    )

    linearSolveCheckInputs(self, A, "triangular_solve")

    if A.layout == torch.strided:
        self_broadcast_size, A_broadcast_size = _linalg_broadcast_batch_dims(self, A)
        solution = torch.empty_strided(
            size=self_broadcast_size,
            stride=make_contiguous_strides_for(self_broadcast_size, row_major=False),
            dtype=self.dtype,
            device=self.device,
        )
        cloned_coefficient = torch.empty_strided(
            size=A_broadcast_size,
            stride=make_contiguous_strides_for(A_broadcast_size, row_major=False),
            dtype=A.dtype,
            device=A.device,
        )
    elif A.layout == torch.sparse_csr or A.layout == torch.sparse_bsr:
        solution = torch.empty_like(self)
        cloned_coefficient = self.new_empty([0])
    else:
        torch._check(False, lambda: "triangular_solve: Got an unexpected layout.")
    return solution, cloned_coefficient  # type: ignore[possibly-undefined]


# From aten/src/ATen/native/LinearAlgebra.cpp
@register_meta(aten._linalg_det.default)
def _linalg_det_meta(A):
    squareCheckInputs(A, "linalg.det")
    checkFloatingOrComplex(A, "linalg.det")

    det = A.new_empty(A.shape[:-2])

    LU = A.new_empty(A.shape)
    LU.as_strided_(A.shape, make_contiguous_strides_for(A.shape, row_major=False))

    pivots = A.new_empty(A.shape[:-1], dtype=torch.int32)
    return det, LU, pivots


@register_meta(aten.ormqr)
@out_wrapper()
def ormqr(
    input: Tensor,
    tau: Tensor,
    other: Tensor,
    left: bool = True,
    transpose: bool = False,
) -> Tensor:
    torch._check(
        input.ndim >= 2, lambda: "torch.ormqr: input must have at least 2 dimensions."
    )
    torch._check(
        other.ndim >= 2, lambda: "torch.ormqr: other must have at least 2 dimensions."
    )

    left_size_condition = -2 if left else -1
    torch._check(
        other.shape[left_size_condition] >= tau.shape[-1],
        lambda: f"torch.ormqr: other.shape[{left_size_condition}] must be greater than or equal to tau.shape[-1]",
    )
    torch._check(
        other.shape[left_size_condition] == input.shape[-2],
        lambda: f"torch.ormqr: other.shape[{left_size_condition}] must be equal to input.shape[-2]",
    )

    torch._check(
        tau.shape[-1] <= input.shape[-1],
        lambda: "torch.ormqr: tau.shape[-1] must be less than or equal to input.shape[-1]",
    )

    torch._check(
        input.ndim - tau.ndim == 1,
        lambda: (
            f"torch.ormqr: Expected tau to have one dimension less than input, "
            f"but got tau.ndim equal to {tau.ndim} and input.ndim is equal to {input.ndim}"
        ),
    )
    torch._check(
        input.ndim == other.ndim,
        lambda: (
            f"torch.ormqr: Expected other to have the same number of dimensions as input, "
            f"but got other.ndim equal to {other.ndim} and input.ndim is equal to {input.ndim}"
        ),
    )

    if input.ndim > 2:
        expected_batch_shape = input.shape[:-2]
        actual_batch_tau_shape = tau.shape[:-1]
        torch._check(
            actual_batch_tau_shape == expected_batch_shape,
            lambda: (
                f"torch.ormqr: Expected batch dimensions of tau to be "
                f"equal to input.shape[:-2], but got {actual_batch_tau_shape}"
            ),
        )

        actual_batch_other_shape = other.shape[:-2]
        torch._check(
            actual_batch_other_shape == expected_batch_shape,
            lambda: (
                f"torch.ormqr: Expected batch dimensions of other to be "
                f"equal to input.shape[:-2], but got {actual_batch_other_shape}"
            ),
        )

    torch._check(
        tau.dtype == input.dtype,
        lambda: (
            f"torch.ormqr: Expected input and tau to have the same dtype, "
            f"but input has dtype {input.dtype} and tau has dtype {tau.dtype}"
        ),
    )
    torch._check(
        other.dtype == input.dtype,
        lambda: (
            f"torch.ormqr: Expected input and other to have the same dtype, "
            f"but input has dtype {input.dtype} and other has dtype {other.dtype}"
        ),
    )

    checkSameDevice("torch.ormqr", tau, input, "tau")
    checkSameDevice("torch.ormqr", other, input, "other")

    return torch.empty_strided(
        size=other.shape,
        stride=make_contiguous_strides_for(other.shape, row_major=False),
        dtype=other.dtype,
        device=other.device,
    )


def _padding_check_valid_input(input, padding, *, dim):
    torch._check(
        len(padding) == 2 * dim,
        lambda: f"padding size is expected to be {2 * dim}, but got: {len(padding)}",
    )

    input_dim = input.ndim

    is_batch_mode = input_dim == (dim + 2)

    valid_batch_mode = is_batch_mode
    valid_non_batch_mode = not is_batch_mode

    if is_batch_mode:
        # allow batch size of 0-dim.
        for d in range(1, input_dim):
            valid_batch_mode = valid_batch_mode and input.size(d) != 0
    else:
        for d in range(0, input_dim):
            valid_non_batch_mode = valid_non_batch_mode and input.size(d) != 0

    # allow empty batch size but not other dimensions.
    torch._check(
        valid_batch_mode or valid_non_batch_mode,
        lambda: (
            f"Expected {dim + 1}D or {dim + 2}D (batch mode) tensor with possibly 0 batch size "
            f"and other non-zero dimensions for input, but got: {input.shape}"
        ),
    )


def _pad1d_common(input, padding, *, is_reflection):
    dim_plane = 0
    dim_w = 1
    nbatch = 1

    if input.ndim == 3:
        nbatch = input.size(0)
        dim_w += 1
        dim_plane += 1

    _padding_check_valid_input(input, padding, dim=1)

    pad_l, pad_r = padding

    nplane = input.size(dim_plane)
    input_w = input.size(dim_w)
    output_w = input_w + pad_l + pad_r

    if is_reflection:
        torch._check(
            pad_l < input_w and pad_r < input_w,
            lambda: (
                f"Argument #4: Padding size should be less than the corresponding input dimension, "
                f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}"
            ),
        )

    torch._check(
        output_w >= 1,
        lambda: f"input (W: {input_w}) is too small. Calculated output W: {output_w}",
    )

    if input.ndim == 2:
        return input.new_empty((nplane, output_w))
    else:
        return input.new_empty((nbatch, nplane, output_w))


@register_meta(aten.reflection_pad1d)
@out_wrapper()
def meta_reflection_pad1d(input, padding):
    return _pad1d_common(input, padding, is_reflection=True)


@register_meta(aten.replication_pad1d)
@out_wrapper()
def meta_replication_pad1d(input, padding):
    return _pad1d_common(input, padding, is_reflection=False)
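

# Illustrative sketch (editor-added comment, not part of the original file): the
# padded width is input_w + pad_l + pad_r, and reflection padding additionally
# requires each pad to be smaller than the input width, e.g.
#
#     x = torch.empty(2, 3, 10, device="meta")
#     torch.nn.functional.pad(x, (2, 3), mode="reflect").shape  # torch.Size([2, 3, 15])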


def _pad1d_backward_common(grad_output, input, padding, *, is_reflection):
    dim_w = 1
    if not is_reflection:
        torch._check(len(padding) == 2, lambda: "padding size is expected to be 2")

    if input.ndim == 3:
        dim_w += 1

    pad_l, pad_r = padding

    input_w = input.size(dim_w)
    output_w = input_w + pad_l + pad_r

    if is_reflection:
        torch._check(
            pad_l < input_w and pad_r < input_w,
            lambda: (
                f"Argument #4: Padding size should be less than the corresponding input dimension, "
                f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}"
            ),
        )

    torch._check(
        output_w == grad_output.size(dim_w),
        lambda: f"grad_output width unexpected. Expected: {output_w}, Got: {grad_output.size(dim_w)}",
    )

    return input.new_empty(input.shape)


@register_meta(aten.reflection_pad1d_backward)
@out_wrapper("grad_input")
def meta_reflection_pad1d_backward(grad_output, input, padding):
    return _pad1d_backward_common(grad_output, input, padding, is_reflection=True)


@register_meta(aten.replication_pad1d_backward)
@out_wrapper("grad_input")
def meta_replication_pad1d_backward(grad_output, input, padding):
    return _pad1d_backward_common(grad_output, input, padding, is_reflection=False)


def _pad2d_common(input, padding, *, is_reflection):
    dim_w = 2
    dim_h = 1
    dim_slices = 0
    nbatch = 1

    _padding_check_valid_input(input, padding, dim=2)

    ndim = input.ndim
    if ndim == 4:
        nbatch = input.size(0)
        dim_w += 1
        dim_h += 1
        dim_slices += 1

    pad_l, pad_r, pad_t, pad_b = padding

    nplane = input.size(dim_slices)
    input_h = input.size(dim_h)
    input_w = input.size(dim_w)
    output_h = input_h + pad_t + pad_b
    output_w = input_w + pad_l + pad_r

    if is_reflection:
        torch._check(
            pad_l < input_w and pad_r < input_w,
            lambda: (
                f"Argument #4: Padding size should be less than the corresponding input dimension, "
                f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}"
            ),
        )
        torch._check(
            pad_t < input_h and pad_b < input_h,
            lambda: (
                f"Argument #6: Padding size should be less than the corresponding input dimension, "
                f"but got: padding ({pad_t}, {pad_b}) at dimension {dim_h} of input {input.shape}"
            ),
        )

    torch._check(
        output_w >= 1 or output_h >= 1,
        lambda: (
            f"input (H: {input_h} W: {input_w}) is too small. "
            f"Calculated output H: {output_h} W: {output_w}"
        ),
    )

    if input.ndim == 3:
        return input.new_empty((nplane, output_h, output_w))
    else:
        return input.new_empty((nbatch, nplane, output_h, output_w))


@register_meta(aten.reflection_pad2d)
@out_wrapper()
def meta_reflection_pad2d(input, padding):
    return _pad2d_common(input, padding, is_reflection=True)


@register_meta(aten.replication_pad2d)
@out_wrapper()
def meta_replication_pad2d(input, padding):
    return _pad2d_common(input, padding, is_reflection=False)
Expected: {output_h}, Got: {grad_output.size(dim_h)}", 1864 ) 1865 return self.new_empty(self.shape) 1866 1867 1868def _pad3d_common(input, padding, *, is_reflection): 1869 dim_w = 3 1870 dim_h = 2 1871 dim_d = 1 1872 dim_plane = 0 1873 1874 _padding_check_valid_input(input, padding, dim=3) 1875 1876 batch_mode = input.ndim == 5 1877 if batch_mode: 1878 nbatch = input.size(0) 1879 dim_w += 1 1880 dim_h += 1 1881 dim_d += 1 1882 dim_plane += 1 1883 1884 pad_l, pad_r, pad_t, pad_b, pad_f, pad_bk = padding 1885 1886 nplane = input.size(dim_plane) 1887 input_d = input.size(dim_d) 1888 input_h = input.size(dim_h) 1889 input_w = input.size(dim_w) 1890 output_d = input_d + pad_f + pad_bk 1891 output_h = input_h + pad_t + pad_b 1892 output_w = input_w + pad_l + pad_r 1893 1894 if is_reflection: 1895 torch._check( 1896 pad_l < input_w and pad_r < input_w, 1897 lambda: ( 1898 f"Argument #4: Padding size should be less than the corresponding input dimension, " 1899 f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}" 1900 ), 1901 ) 1902 torch._check( 1903 pad_t < input_h and pad_b < input_h, 1904 lambda: ( 1905 f"Argument #6: Padding size should be less than the corresponding input dimension, " 1906 f"but got: padding ({pad_t}, {pad_b}) at dimension {dim_h} of input {input.shape}" 1907 ), 1908 ) 1909 torch._check( 1910 pad_f < input_d and pad_bk < input_d, 1911 lambda: ( 1912 f"Argument #8: Padding size should be less than the corresponding input dimension, " 1913 f"but got: padding ({pad_f}, {pad_bk}) at dimension {dim_d} of input {input.shape}" 1914 ), 1915 ) 1916 1917 torch._check( 1918 output_w >= 1 or output_h >= 1 or output_d >= 1, 1919 lambda: ( 1920 f"input (D: {input_d} H: {input_h} W: {input_w}) is too small. " 1921 f"Calculated output D: {output_d} H: {output_h} W: {output_w}" 1922 ), 1923 ) 1924 1925 if batch_mode: 1926 return input.new_empty((nbatch, nplane, output_d, output_h, output_w)) # type: ignore[possibly-undefined] 1927 else: 1928 return input.new_empty((nplane, output_d, output_h, output_w)) 1929 1930 1931@register_meta(aten.reflection_pad3d) 1932@out_wrapper() 1933def meta_reflection_pad3d(input, padding): 1934 return _pad3d_common(input, padding, is_reflection=True) 1935 1936 1937@register_meta(aten.replication_pad3d) 1938@out_wrapper() 1939def meta_replication_pad3d(input, padding): 1940 return _pad3d_common(input, padding, is_reflection=False) 1941 1942 1943@register_meta( 1944 [ 1945 aten.reflection_pad3d_backward.default, 1946 aten.reflection_pad3d_backward.grad_input, 1947 aten.replication_pad3d_backward.default, 1948 aten.replication_pad3d_backward.grad_input, 1949 ] 1950) 1951@out_wrapper("grad_input") 1952def meta_pad3d_backward(grad_output, input, padding): 1953 torch._check(len(padding) == 6, lambda: "padding size is expected to be 6") 1954 assert input.ndim > 3 1955 assert grad_output.ndim == input.ndim 1956 1957 dim_w = 3 1958 dim_h = 2 1959 dim_d = 1 1960 1961 if input.ndim == 5: 1962 dim_w += 1 1963 dim_h += 1 1964 dim_d += 1 1965 1966 pad_l, pad_r, pad_t, pad_b, pad_f, pad_bk = padding 1967 1968 input_d = input.size(dim_d) 1969 input_h = input.size(dim_h) 1970 input_w = input.size(dim_w) 1971 output_d = input_d + pad_f + pad_bk 1972 output_h = input_h + pad_t + pad_b 1973 output_w = input_w + pad_l + pad_r 1974 1975 torch._check( 1976 output_w == grad_output.size(dim_w), 1977 lambda: f"grad_output width unexpected. 
Expected: {output_w}, Got: {grad_output.size(dim_w)}", 1978 ) 1979 torch._check( 1980 output_h == grad_output.size(dim_h), 1981 lambda: f"grad_output height unexpected. Expected: {output_h}, Got: {grad_output.size(dim_h)}", 1982 ) 1983 torch._check( 1984 output_d == grad_output.size(dim_d), 1985 lambda: f"grad_output depth unexpected. Expected: {output_d}, Got: {grad_output.size(dim_d)}", 1986 ) 1987 1988 return input.new_empty(input.shape) 1989 1990 1991@register_meta(aten._pdist_forward) 1992@out_wrapper() 1993def meta__pdist_forward(self: Tensor, p: float = 2) -> Tensor: 1994 torch._check( 1995 self.is_contiguous(), lambda: "_pdist_forward requires contiguous input" 1996 ) 1997 n = self.size(0) 1998 if n <= 1: 1999 return self.new_empty([0]).to(memory_format=torch.legacy_contiguous_format) # type: ignore[call-overload] 2000 else: 2001 return self.new_empty((n * (n - 1) // 2,)).to( 2002 memory_format=torch.legacy_contiguous_format 2003 ) # type: ignore[call-overload] 2004 2005 2006@register_meta(aten._pdist_backward) 2007@out_wrapper() 2008def meta__pdist_backward(grad: Tensor, self: Tensor, p: float, pdist: Tensor) -> Tensor: 2009 torch._check( 2010 self.is_contiguous(), lambda: "_pdist_backward requires self to be contiguous" 2011 ) 2012 torch._check( 2013 pdist.is_contiguous(), lambda: "_pdist_backward requires pdist to be contiguous" 2014 ) 2015 return torch.empty_like(self, memory_format=torch.legacy_contiguous_format) 2016 2017 2018@register_meta([aten.baddbmm.default, aten.baddbmm.out]) 2019@out_wrapper() 2020def meta_baddbmm(self, batch1, batch2, *, beta=1, alpha=1): 2021 dim1 = batch1.size(0) 2022 dim2 = batch1.size(1) 2023 dim3 = batch2.size(2) 2024 self = self.expand((dim1, dim2, dim3)) 2025 torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor") 2026 torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor") 2027 torch._check( 2028 self.dtype == batch1.dtype == batch2.dtype, 2029 lambda: f"Input dtypes must be the same, got: input: {self.dtype}, batch1: {batch1.dtype}, batch2: {batch2.dtype}", 2030 ) 2031 batch1_sizes = batch1.shape 2032 batch2_sizes = batch2.shape 2033 bs = batch1_sizes[0] 2034 contraction_size = batch1_sizes[2] 2035 torch._check( 2036 batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size, 2037 lambda: ( 2038 f"Expected size for first two dimensions of batch2 tensor to be: " 2039 f"[{bs}, {contraction_size}] but got: [{batch2_sizes[0]}, {batch2_sizes[1]}]." 
2040 ), 2041 ) 2042 return self.new_empty(self.size()) 2043 2044 2045@register_meta([aten.bernoulli.default, aten.bernoulli.out]) 2046@out_wrapper() 2047def meta_bernoulli(self, *, generator=None): 2048 # https://github.com/pytorch/pytorch/issues/88612 2049 return torch.empty_like(self).contiguous() 2050 2051 2052@register_meta(aten.bernoulli_.float) 2053def meta_bernoulli_(self, p=0.5, generator=None): 2054 return self 2055 2056 2057@register_meta(aten.bernoulli.p) 2058def meta_bernoulli_p(self, p=0.5, generator=None): 2059 # https://github.com/pytorch/pytorch/issues/88612 2060 return torch.empty_like(self).contiguous() 2061 2062 2063@register_meta([aten.poisson.default, aten.poisson.out]) 2064@out_wrapper() 2065def meta_poisson(self, generator=None): 2066 return torch.empty_like(self) 2067 2068 2069@register_meta(aten._fused_moving_avg_obs_fq_helper.default) 2070def meta__fused_moving_avg_obs_fq_helper( 2071 self, 2072 observer_on, 2073 fake_quant_on, 2074 running_min, 2075 running_max, 2076 scale, 2077 zero_point, 2078 averaging_const, 2079 quant_min, 2080 quant_max, 2081 ch_axis, 2082 per_row_fake_quant=False, 2083 symmetric_quant=False, 2084): 2085 torch._check( 2086 ch_axis < self.dim(), 2087 lambda: "Error in fused_moving_avg_obs_fake_quant_cpu: ch_axis must be < self.dim()", 2088 ) 2089 mask = torch.empty_like(self, dtype=torch.bool) 2090 return (torch.empty_like(self), mask) 2091 2092 2093@register_meta(aten.mm) 2094@out_wrapper() 2095def meta_mm(a, b): 2096 torch._check(a.dim() == 2, lambda: "a must be 2D") 2097 torch._check(b.dim() == 2, lambda: "b must be 2D") 2098 N, M1 = a.shape 2099 M2, P = b.shape 2100 torch._check( 2101 M1 == M2, 2102 lambda: f"a and b must have same reduction dim, but got [{N}, {M1}] X [{M2}, {P}].", 2103 ) 2104 return a.new_empty(N, P) 2105 2106 2107def _compute_reduction_shape(self, dims, keepdim): 2108 if keepdim: 2109 return tuple(self.shape[i] if i not in dims else 1 for i in range(self.ndim)) 2110 2111 return utils.compute_reduction_output_shape(self.shape, dims) 2112 2113 2114# FakeTensors (meta tensors with a device) will report device as meta 2115# when running meta kernels. Here, access the "fake device" of FakeTensor if it 2116# exists so meta kernels which have diverge per device will be more 2117# accurate when run with FakeTensors 2118def device_hint(tensor) -> "str": 2119 if isinstance(tensor, torch._subclasses.FakeTensor): 2120 return tensor.fake_device.type 2121 else: 2122 return "cuda" # default to cuda 2123 2124 2125def calc_conv_nd_return_shape( 2126 input_tensor: torch.Tensor, 2127 weight: torch.Tensor, 2128 stride: Union[List[int], int], 2129 padding: Union[List[int], int], 2130 dilation: Union[List[int], int], 2131 is_transposed: bool, 2132 groups: int, 2133 output_padding: Optional[Union[List[int], int]] = None, 2134): 2135 def _formula(ln: int, p: int, d: int, k: int, s: int) -> int: 2136 """ 2137 Formula to apply to calculate the length of some dimension of the output 2138 2139 See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html 2140 2141 Args: 2142 ln: length of the dimension 2143 p: padding in that dim 2144 d: dilation in that dim 2145 k: kernel size in that dim 2146 s: stride in that dim 2147 Returns: 2148 The output length 2149 """ 2150 return (ln + 2 * p - d * (k - 1) - 1) // s + 1 2151 2152 def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int: 2153 """ 2154 Formula to apply to calculate the length of some dimension of the output 2155 if transposed convolution is used. 
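        For example (illustrative values only), ln=4, p=1, d=1, k=3, s=2, op=0
        gives (4 - 1) * 2 - 2 * 1 + 1 * (3 - 1) + 0 + 1 = 7.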
2156 See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html 2157 2158 Args: 2159 ln: length of the dimension 2160 p: padding in that dim 2161 d: dilation in that dim 2162 k: kernel size in that dim 2163 s: stride in that dim 2164 op: output padding in that dim 2165 2166 Returns: 2167 The output length 2168 """ 2169 return (ln - 1) * s - 2 * p + d * (k - 1) + op + 1 2170 2171 kernel_size = weight.shape[2:] 2172 dims = input_tensor.shape[2:] 2173 if is_transposed: 2174 out_channels = groups * weight.shape[1] 2175 else: 2176 out_channels = weight.shape[0] 2177 if weight.shape[1] * groups != input_tensor.shape[1]: 2178 raise RuntimeError("Invalid channel dimensions") 2179 2180 ret_shape = [input_tensor.shape[0], out_channels] 2181 if isinstance(stride, IntLike): 2182 stride = [stride] * len(dims) 2183 elif len(stride) == 1: 2184 stride = [stride[0]] * len(dims) 2185 2186 if isinstance(padding, IntLike): 2187 padding = [padding] * len(dims) 2188 elif len(padding) == 1: 2189 padding = [padding[0]] * len(dims) 2190 2191 if isinstance(dilation, IntLike): 2192 dilation = [dilation] * len(dims) 2193 elif len(dilation) == 1: 2194 dilation = [dilation[0]] * len(dims) 2195 2196 output_padding_list: Optional[List[int]] = None 2197 if output_padding: 2198 if isinstance(output_padding, IntLike): 2199 output_padding_list = [output_padding] * len(dims) 2200 elif len(output_padding) == 1: 2201 output_padding_list = [output_padding[0]] * len(dims) 2202 else: 2203 output_padding_list = output_padding 2204 2205 for i in range(len(dims)): 2206 # If output_padding is present, we are dealing with a transposed convolution 2207 if output_padding_list: 2208 ret_shape.append( 2209 _formula_transposed( 2210 dims[i], 2211 padding[i], 2212 dilation[i], 2213 kernel_size[i], 2214 stride[i], 2215 output_padding_list[i], 2216 ) 2217 ) 2218 else: 2219 ret_shape.append( 2220 _formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i]) 2221 ) 2222 2223 return ret_shape 2224 2225 2226def is_channels_last(ten): 2227 return torch._prims_common.suggest_memory_format(ten) == torch.channels_last 2228 2229 2230@register_meta(aten.convolution.default) 2231def meta_conv( 2232 input_tensor: torch.Tensor, 2233 weight: torch.Tensor, 2234 bias: torch.Tensor, 2235 stride: List[int], 2236 padding: List[int], 2237 dilation: List[int], 2238 is_transposed: bool, 2239 output_padding: List[int], 2240 groups: int, 2241): 2242 def pick_memory_format(): 2243 if device_hint(input_tensor) == "cuda": 2244 if is_channels_last(input_tensor) or is_channels_last(weight): 2245 return torch.channels_last 2246 else: 2247 if is_channels_last(input_tensor): 2248 return torch.channels_last 2249 if input_tensor.is_contiguous(memory_format=torch.contiguous_format): 2250 return torch.contiguous_format 2251 elif input_tensor.is_contiguous(memory_format=torch.preserve_format): 2252 return torch.preserve_format 2253 2254 shape_out = calc_conv_nd_return_shape( 2255 input_tensor, 2256 weight, 2257 stride, 2258 padding, 2259 dilation, 2260 is_transposed, 2261 groups, 2262 output_padding if is_transposed else None, 2263 ) 2264 2265 input_channels_dim = 1 2266 output_channels_dim = 1 2267 if input_tensor.size(input_channels_dim) == 0: 2268 shape_out[output_channels_dim] = 0 2269 2270 out = input_tensor.new_empty(shape_out) 2271 out = out.to(memory_format=pick_memory_format()) # type: ignore[call-overload] 2272 return out 2273 2274 2275if torch._C._has_mkldnn: 2276 _meta_lib_dont_use_me_use_register_meta_for_mkldnn = torch.library.Library( 2277 
"mkldnn", "IMPL", "Meta" 2278 ) 2279 2280 @register_meta(torch.ops.mkldnn._convolution_pointwise.default) 2281 def meta_mkldnn_convolution_default( 2282 input_tensor, 2283 weight, 2284 bias, 2285 padding, 2286 stride, 2287 dilation, 2288 groups, 2289 attr, 2290 scalars, 2291 algorithm, 2292 ): 2293 shape_out = calc_conv_nd_return_shape( 2294 input_tensor, weight, stride, padding, dilation, False, groups, [] 2295 ) 2296 out = input_tensor.new_empty(shape_out) 2297 out_memory_format = torch.channels_last 2298 if input_tensor.dim() == 5: 2299 out_memory_format = torch.channels_last_3d 2300 out = out.to(memory_format=out_memory_format) # type: ignore[call-overload] 2301 return out 2302 2303 @register_meta(torch.ops.mkldnn._linear_pointwise.default) 2304 def meta_linear_pointwise_default( 2305 input_tensor, weight, bias, attr, scalars, algorithm 2306 ): 2307 return input_tensor.new_empty((*input_tensor.shape[:-1], weight.shape[0])) 2308 2309 if torch._C.has_mkl: 2310 _meta_lib_dont_use_me_use_register_meta_for_mkl = torch.library.Library( 2311 "mkl", "IMPL", "Meta" 2312 ) 2313 2314 @register_meta(torch.ops.mkl._mkl_linear) 2315 def meta_mkl_linear(input_tensor, packed_weight, orig_weight, bias, batch_size): 2316 return input_tensor.new_empty( 2317 (*input_tensor.shape[:-1], orig_weight.shape[0]) 2318 ) 2319 2320 _meta_lib_dont_use_me_use_register_meta_for_onednn = torch.library.Library( 2321 "onednn", "IMPL", "Meta" 2322 ) 2323 2324 @register_meta(torch.ops.onednn.qconv2d_pointwise.default) 2325 def meta_qconv2d_pointwise( 2326 x, 2327 x_scale, 2328 x_zp, 2329 w, # prepacked_weight 2330 w_scale, 2331 w_zp, 2332 bias, 2333 stride, 2334 padding, 2335 dilation, 2336 groups, 2337 output_scale, 2338 output_zero_point, 2339 output_dtype, 2340 attr, 2341 scalars, 2342 algorithm, 2343 ): 2344 shape_out = calc_conv_nd_return_shape( 2345 x, 2346 w, 2347 stride, 2348 padding, 2349 dilation, 2350 False, 2351 groups, 2352 None, 2353 ) 2354 assert output_dtype in [torch.float32, torch.bfloat16] 2355 out = x.new_empty(shape_out, dtype=output_dtype) 2356 out = out.to(memory_format=torch.channels_last) 2357 return out 2358 2359 @register_meta(torch.ops.onednn.qlinear_pointwise.default) 2360 @register_meta(torch.ops.onednn.qlinear_pointwise.tensor) 2361 def meta_qlinear_pointwise( 2362 x, 2363 x_scale, 2364 x_zp, 2365 w, 2366 w_scale, 2367 w_zp, 2368 bias, 2369 output_scale, 2370 output_zero_point, 2371 output_dtype, 2372 post_op_name, 2373 post_op_args, 2374 post_op_algorithm, 2375 ): 2376 output_shape = list(x.shape) 2377 # The weight has been transposed during the qlinear weight prepack process. 
2378 output_shape[-1] = w.shape[1] 2379 assert output_dtype in [torch.float32, torch.bfloat16] 2380 out = x.new_empty(output_shape, dtype=output_dtype) 2381 return out 2382 2383 _meta_lib_dont_use_me_use_register_meta_for_quantized = torch.library.Library( 2384 "quantized", "IMPL", "Meta" 2385 ) 2386 2387 @register_meta(torch.ops.quantized.max_pool2d) 2388 def meta_quantized_max_pool2d( 2389 input, 2390 kernel_size, 2391 stride=(), 2392 padding=(0,), 2393 dilation=(1,), 2394 ceil_mode=False, 2395 ): 2396 ( 2397 nInputPlane, 2398 outputHeight, 2399 outputWidth, 2400 ) = max_pool2d_checks_and_compute_shape( 2401 input, kernel_size, stride, padding, dilation, ceil_mode 2402 ) 2403 nbatch = input.size(-4) if input.dim() == 4 else 1 2404 memory_format = torch.channels_last 2405 if input.dim() == 3: 2406 size = [nInputPlane, outputHeight, outputWidth] 2407 else: 2408 size = [nbatch, nInputPlane, outputHeight, outputWidth] 2409 return torch.empty( 2410 size, 2411 dtype=input.dtype, 2412 device=input.device, 2413 memory_format=memory_format, 2414 ) 2415 2416 2417# from check_dim_size() in aten/src/ATen/TensorUtils.cpp. 2418def check_dim_size(tensor, dim, dim_size, size): 2419 torch._check( 2420 tensor.dim() == dim and tensor.shape[dim_size] == size, 2421 lambda: f"Expected a tensor of dimension {dim} and tensor.size[{dim_size}] == {size}, " 2422 + f"but got : dimension {tensor.dim()} and tensor.size[{dim_size}] = {tensor.shape[dim_size]}", 2423 ) 2424 2425 2426@register_meta(aten.avg_pool2d.default) 2427def meta_avg_pool2d( 2428 input, 2429 kernel_size, 2430 stride=(), 2431 padding=(0,), 2432 ceil_mode=False, 2433 count_include_pad=True, 2434 divisor_override=None, 2435): 2436 def unpack(name, val): 2437 torch._check( 2438 len(val) in [1, 2], 2439 lambda: f"avg_pool2d: {name} must either be a single int, or a tuple of two ints", 2440 ) 2441 H = val[0] 2442 W = H if len(val) == 1 else val[1] 2443 return H, W 2444 2445 kH, kW = unpack("kernel_size", kernel_size) 2446 torch._check( 2447 len(stride) in [0, 1, 2], 2448 lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints", 2449 ) 2450 if len(stride) == 0: 2451 dH, dW = kH, kW 2452 elif len(stride) == 1: 2453 dH, dW = stride[0], stride[0] 2454 else: 2455 dH, dW = unpack("stride", stride) 2456 2457 padH, padW = unpack("padding", padding) 2458 2459 torch._check( 2460 divisor_override is None or divisor_override != 0, 2461 lambda: "divisor must be not zero", 2462 ) 2463 2464 nbatch = input.size(-4) if input.dim() == 4 else 1 2465 nInputPlane = input.size(-3) 2466 inputHeight = input.size(-2) 2467 inputWidth = input.size(-1) 2468 2469 outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode) 2470 outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode) 2471 2472 memory_format = utils.suggest_memory_format(input) 2473 pool2d_shape_check( 2474 input, 2475 kH, 2476 kW, 2477 dH, 2478 dW, 2479 padH, 2480 padW, 2481 1, 2482 1, 2483 nInputPlane, 2484 inputHeight, 2485 inputWidth, 2486 outputHeight, 2487 outputWidth, 2488 memory_format, 2489 ) 2490 2491 if input.dim() == 3: 2492 size = [nInputPlane, outputHeight, outputWidth] 2493 else: 2494 size = [nbatch, nInputPlane, outputHeight, outputWidth] 2495 return torch.empty( 2496 size, 2497 dtype=input.dtype, 2498 device=input.device, 2499 memory_format=memory_format, 2500 ) 2501 2502 2503# from avg_pool2d_backward_shape_check() in aten/src/ATen/native/Pool.h. 
2504def avg_pool2d_backward_shape_check( 2505 input, 2506 gradOutput, 2507 nbatch, 2508 kH, 2509 kW, 2510 dH, 2511 dW, 2512 padH, 2513 padW, 2514 nInputPlane, 2515 inputHeight, 2516 inputWidth, 2517 outputHeight, 2518 outputWidth, 2519 mem_format, 2520): 2521 pool2d_shape_check( 2522 input, 2523 kH, 2524 kW, 2525 dH, 2526 dW, 2527 padH, 2528 padW, 2529 1, 2530 1, 2531 nInputPlane, 2532 inputHeight, 2533 inputWidth, 2534 outputHeight, 2535 outputWidth, 2536 mem_format, 2537 ) 2538 2539 ndim = input.dim() 2540 nOutputPlane = nInputPlane 2541 2542 check_dim_size(gradOutput, ndim, ndim - 3, nOutputPlane) 2543 check_dim_size(gradOutput, ndim, ndim - 2, outputHeight) 2544 check_dim_size(gradOutput, ndim, ndim - 1, outputWidth) 2545 2546 2547# Don't override the C++ registration. 2548@register_meta(aten.avg_pool2d_backward.default) 2549def meta_avg_pool2d_backward( 2550 gradOutput_, 2551 input, 2552 kernel_size, 2553 stride, 2554 padding, 2555 ceil_mode, 2556 count_include_pad, 2557 divisor_override, 2558): 2559 # From aten/src/ATen/native/AveragePool2d.cpp structured kernel meta func. 2560 torch._check( 2561 len(kernel_size) == 1 or len(kernel_size) == 2, 2562 lambda: "avg_pool2d: kernel_size must either be a single int, or a tuple of two ints", 2563 ) 2564 kH = kernel_size[0] 2565 kW = kH if len(kernel_size) == 1 else kernel_size[1] 2566 torch._check( 2567 len(stride) == 0 or len(stride) == 1 or len(stride) == 2, 2568 lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints", 2569 ) 2570 dH = kH if len(stride) == 0 else stride[0] 2571 dW = kW if len(stride) == 0 else dH if len(stride) == 1 else stride[1] 2572 torch._check( 2573 len(padding) == 1 or len(padding) == 2, 2574 lambda: "avg_pool2d: padding must either be a single int, or a tuple of two ints", 2575 ) 2576 padH = padding[0] 2577 padW = padH if len(padding) == 1 else padding[1] 2578 2579 torch._check( 2580 divisor_override is None or divisor_override != 0, 2581 lambda: "divisor must be not zero", 2582 ) 2583 2584 input_size = input.shape 2585 nbatch = input_size[-4] if input.dim() == 4 else 1 2586 nInputPlane = input_size[-3] 2587 inputHeight = input_size[-2] 2588 inputWidth = input_size[-1] 2589 2590 outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode) 2591 outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode) 2592 2593 mem_format = utils.suggest_memory_format(input) 2594 2595 avg_pool2d_backward_shape_check( 2596 input, 2597 gradOutput_, 2598 nbatch, 2599 kH, 2600 kW, 2601 dH, 2602 dW, 2603 padH, 2604 padW, 2605 nInputPlane, 2606 inputHeight, 2607 inputWidth, 2608 outputHeight, 2609 outputWidth, 2610 mem_format, 2611 ) 2612 2613 return torch.empty( 2614 input_size, 2615 dtype=input.dtype, 2616 device=input.device, 2617 memory_format=mem_format, 2618 ) 2619 2620 2621@register_meta(aten.avg_pool3d) 2622@out_wrapper() 2623def meta_avg_pool3d( 2624 input, 2625 kernel_size, 2626 stride=(), 2627 padding=(0,), 2628 ceil_mode=False, 2629 count_include_pad=True, 2630 divisor_override=None, 2631): 2632 torch._check( 2633 len(kernel_size) in (1, 3), 2634 lambda: "avg_pool3d: kernel_size must be a single int, or a tuple of three ints", 2635 ) 2636 kT = kernel_size[0] 2637 kH = kT if len(kernel_size) == 1 else kernel_size[1] 2638 kW = kT if len(kernel_size) == 1 else kernel_size[2] 2639 2640 torch._check( 2641 not stride or len(stride) in (1, 3), 2642 lambda: "avg_pool3d: stride must be omitted, a single int, or a tuple of three ints", 2643 ) 2644 dT = kT if not stride 
else stride[0] 2645 dH = kH if not stride else (dT if len(stride) == 1 else stride[1]) 2646 dW = kW if not stride else (dT if len(stride) == 1 else stride[2]) 2647 2648 torch._check( 2649 len(padding) in (1, 3), 2650 lambda: "avg_pool3d: padding must be a single int, or a tuple of three ints", 2651 ) 2652 padT = padding[0] 2653 padH = padT if len(padding) == 1 else padding[1] 2654 padW = padT if len(padding) == 1 else padding[2] 2655 2656 torch._check( 2657 input.ndim in (4, 5), 2658 lambda: "non-empty 4D or 5D (batch mode) tensor expected for input", 2659 ) 2660 2661 torch._check( 2662 not divisor_override or divisor_override != 0, 2663 lambda: "divisor must be not zero", 2664 ) 2665 2666 nbatch = input.size(0) 2667 nslices = input.size(-4) 2668 itime = input.size(-3) 2669 iheight = input.size(-2) 2670 iwidth = input.size(-1) 2671 2672 otime = pooling_output_shape(itime, kT, padT, dT, 1, ceil_mode) 2673 oheight = pooling_output_shape(iheight, kH, padH, dH, 1, ceil_mode) 2674 owidth = pooling_output_shape(iwidth, kW, padW, dW, 1, ceil_mode) 2675 2676 pool3d_shape_check( 2677 input, 2678 nslices, 2679 kT, 2680 kH, 2681 kW, 2682 dT, 2683 dH, 2684 dW, 2685 padT, 2686 padH, 2687 padW, 2688 1, 2689 1, 2690 1, 2691 itime, 2692 iheight, 2693 iwidth, 2694 otime, 2695 oheight, 2696 owidth, 2697 "avg_pool3d()", 2698 check_input_size=True, 2699 ) 2700 2701 if input.ndim == 4: 2702 return input.new_empty((nslices, otime, oheight, owidth)) 2703 else: 2704 return input.new_empty((nbatch, nslices, otime, oheight, owidth)) 2705 2706 2707@register_meta(aten.avg_pool3d_backward) 2708@out_wrapper("grad_input") 2709def meta_avg_pool3d_backward( 2710 grad_output, 2711 input, 2712 kernel_size, 2713 stride, 2714 padding, 2715 ceil_mode, 2716 count_include_pad, 2717 divisor_override, 2718): 2719 torch._check( 2720 len(kernel_size) in (1, 3), 2721 lambda: "avg_pool3d: kernel_size must be a single int, or a tuple of three ints", 2722 ) 2723 kT = kernel_size[0] 2724 kH = kT if len(kernel_size) == 1 else kernel_size[1] 2725 kW = kT if len(kernel_size) == 1 else kernel_size[2] 2726 2727 torch._check( 2728 not stride or len(stride) in (1, 3), 2729 lambda: "avg_pool3d: stride must be omitted, a single int, or a tuple of three ints", 2730 ) 2731 dT = kT if not stride else stride[0] 2732 dH = kH if not stride else (dT if len(stride) == 1 else stride[1]) 2733 dW = kW if not stride else (dT if len(stride) == 1 else stride[2]) 2734 2735 torch._check( 2736 len(padding) in (1, 3), 2737 lambda: "avg_pool3d: padding must be a single int, or a tuple of three ints", 2738 ) 2739 padT = padding[0] 2740 padH = padT if len(padding) == 1 else padding[1] 2741 padW = padT if len(padding) == 1 else padding[2] 2742 2743 torch._check( 2744 input.ndim in (4, 5), 2745 lambda: "non-empty 4D or 5D (batch mode) tensor expected for input", 2746 ) 2747 2748 torch._check( 2749 not divisor_override or divisor_override != 0, 2750 lambda: "divisor must be not zero", 2751 ) 2752 2753 nslices = input.size(-4) 2754 itime = input.size(-3) 2755 iheight = input.size(-2) 2756 iwidth = input.size(-1) 2757 2758 otime_for_shape_check = pooling_output_shape(itime, kT, padT, dT, 1, ceil_mode) 2759 oheight_for_shape_check = pooling_output_shape(iheight, kH, padH, dH, 1, ceil_mode) 2760 owidth_for_shape_check = pooling_output_shape(iwidth, kW, padW, dW, 1, ceil_mode) 2761 2762 avg_pool3d_backward_shape_check( 2763 input, 2764 grad_output, 2765 nslices, 2766 kT, 2767 kH, 2768 kW, 2769 dT, 2770 dH, 2771 dW, 2772 padT, 2773 padH, 2774 padW, 2775 itime, 2776 iheight, 
2777 iwidth, 2778 otime_for_shape_check, 2779 oheight_for_shape_check, 2780 owidth_for_shape_check, 2781 "avg_pool3d_backward()", 2782 ) 2783 2784 return input.new_empty(input.shape) 2785 2786 2787@register_meta(aten._adaptive_avg_pool2d.default) 2788def meta_adaptive_avg_pool2d(self, output_size): 2789 torch._check( 2790 self.ndim == 3 or self.ndim == 4, 2791 lambda: f"Expected 3D or 4D tensor, but got {self.shape}", 2792 ) 2793 output_shape = self.shape[:-2] + tuple(output_size) 2794 memory_format = utils.suggest_memory_format(self) 2795 # need to set memory_format to preserve the memory format of the input 2796 # channel last input should have channel last output 2797 return torch.empty( 2798 output_shape, 2799 dtype=self.dtype, 2800 device=self.device, 2801 memory_format=memory_format, 2802 ) 2803 2804 2805@register_meta(aten._adaptive_avg_pool3d.default) 2806def meta_adaptive_avg_pool3d(self, output_size): 2807 torch._check( 2808 self.ndim == 4 or self.ndim == 5, 2809 lambda: f"Expected 4D or 5D tensor, but got {self.shape}", 2810 ) 2811 return self.new_empty(self.shape[:-3] + tuple(output_size)) 2812 2813 2814@register_meta(aten._adaptive_avg_pool2d_backward.default) 2815def meta__adaptive_avg_pool2d_backward(grad_out, self): 2816 ndim = grad_out.ndim 2817 for i in range(1, ndim): 2818 torch._check( 2819 grad_out.size(i) > 0, 2820 lambda: f"adaptive_avg_pool2d_backward(): Expected grad_output to have non-zero \ 2821 size for non-batch dimensions, {grad_out.shape} with dimension {i} being empty", 2822 ) 2823 torch._check( 2824 ndim == 3 or ndim == 4, 2825 lambda: f"adaptive_avg_pool2d_backward(): Expected 3D or 4D tensor, but got {self.shape}", 2826 ) 2827 torch._check( 2828 self.dtype == grad_out.dtype, 2829 lambda: f"expected dtype {self.dtype} for `grad_output` but got dtype {grad_out.dtype}", 2830 ) 2831 memory_format = torch.contiguous_format 2832 if is_channels_last(self): 2833 memory_format = torch.channels_last 2834 return self.new_empty(self.shape).to(memory_format=memory_format) 2835 2836 2837@register_meta(aten._adaptive_avg_pool3d_backward) 2838@out_wrapper("grad_input") 2839def meta__adaptive_avg_pool3d_backward(grad_output, self): 2840 _adaptive_pool_empty_output_check(grad_output, "adaptive_avg_pool3d_backward") 2841 return torch.empty_like(self, memory_format=torch.legacy_contiguous_format) 2842 2843 2844def _adaptive_pool_empty_output_check(grad_output: Tensor, arg_name: str): 2845 ndim = grad_output.ndim 2846 for i in range(1, ndim): 2847 torch._check( 2848 grad_output.size(i) > 0, 2849 lambda: ( 2850 f"{arg_name}(): Expected grad_output to have non-zero size for non-batch dimensions, " 2851 f"but grad_output has sizes {grad_output.shape} with dimension {i} being empty" 2852 ), 2853 ) 2854 2855 2856@register_meta(aten.adaptive_max_pool2d) 2857@out_wrapper("out", "indices") 2858def meta_adaptive_max_pool2d(input, output_size): 2859 ndim = input.ndim 2860 torch._check( 2861 ndim in (3, 4), 2862 lambda: f"adaptive_max_pool2d(): Expected 3D or 4D tensor, but got: {input.shape}", 2863 ) 2864 for i in range(1, ndim): 2865 torch._check( 2866 input.size(i) > 0, 2867 lambda: ( 2868 f"adaptive_max_pool2d(): Expected input to have non-zero size for non-batch dimensions, " 2869 f"but input has sizes {input.shape} with dimension {i} being empty" 2870 ), 2871 ) 2872 2873 torch._check( 2874 len(output_size) == 2, 2875 lambda: "adaptive_max_pool2d(): internal error: output_size.size() must be 2", 2876 ) 2877 2878 dimH = 1 2879 sizeB = 1 2880 sizeD = 0 2881 2882 if input.ndim == 4: 
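        # 4D input means batch mode: dim 0 is the batch size and the plane dim shifts to index 1.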
2883 sizeB = input.size(0) 2884 dimH += 1 2885 2886 sizeD = input.size(dimH - 1) 2887 osizeH, osizeW = output_size 2888 2889 if input.ndim == 3: 2890 out_shape = (sizeD, osizeH, osizeW) 2891 out = input.new_empty(out_shape) 2892 indices = input.new_empty(out_shape, dtype=torch.int64) 2893 return out, indices 2894 else: 2895 out_shape = (sizeB, sizeD, osizeH, osizeW) # type: ignore[assignment] 2896 memory_format = utils.suggest_memory_format(input) 2897 out = input.new_empty(out_shape).to(memory_format=memory_format) 2898 indices = input.new_empty(out_shape, dtype=torch.int64).to( 2899 memory_format=memory_format 2900 ) 2901 return out, indices 2902 2903 2904@register_meta(aten.adaptive_max_pool2d_backward) 2905@out_wrapper("grad_input") 2906def meta_adaptive_max_pool2d_backward(grad_output, input, indices): 2907 ndim = grad_output.ndim 2908 torch._check( 2909 ndim in (3, 4), 2910 lambda: f"adaptive_max_pooling2d_backward(): Expected 3D or 4D grad_output, but got: {grad_output.shape}", 2911 ) 2912 2913 _adaptive_pool_empty_output_check(grad_output, "adaptive_max_pool2d_backward") 2914 2915 torch._check( 2916 input.dtype == grad_output.dtype, 2917 lambda: f"expected dtype {input.dtype} for `grad_output` but got dtype {grad_output.dtype}", 2918 ) 2919 2920 memory_format = utils.suggest_memory_format(input) 2921 return input.new_empty(input.shape).to(memory_format=memory_format) 2922 2923 2924@register_meta(aten.adaptive_max_pool3d) 2925@out_wrapper("out", "indices") 2926def meta_adaptive_max_pool3d(input, output_size): 2927 ndim = input.ndim 2928 torch._check( 2929 ndim in (4, 5), 2930 lambda: f"adaptive_max_pool3d(): Expected 4D or 5D tensor, but got: {input.shape}", 2931 ) 2932 for i in range(1, ndim): 2933 torch._check( 2934 input.size(i) > 0, 2935 lambda: ( 2936 f"adaptive_max_pool3d(): Expected input to have non-zero size for non-batch dimensions, " 2937 f"but input has sizes {input.shape} with dimension {i} being empty" 2938 ), 2939 ) 2940 2941 torch._check( 2942 len(output_size) == 3, 2943 lambda: "adaptive_max_pool3d(): internal error: output_size.size() must be 3", 2944 ) 2945 2946 dimD = 0 2947 sizeB = 1 2948 sizeD = 0 2949 2950 if ndim == 5: 2951 sizeB = input.size(0) 2952 dimD += 1 2953 2954 sizeD = input.size(dimD) 2955 osizeT, osizeH, osizeW = output_size 2956 2957 if ndim == 4: 2958 out_shape = (sizeD, osizeT, osizeH, osizeW) 2959 else: 2960 out_shape = (sizeB, sizeD, osizeT, osizeH, osizeW) # type: ignore[assignment] 2961 2962 out = input.new_empty(out_shape) 2963 indices = input.new_empty(out_shape, dtype=torch.int64) 2964 2965 return out, indices 2966 2967 2968@register_meta(aten.adaptive_max_pool3d_backward) 2969@out_wrapper("grad_input") 2970def meta_adaptive_max_pool3d_backward(grad_output, input, indices): 2971 _adaptive_pool_empty_output_check(grad_output, "adaptive_max_pool3d_backward") 2972 return input.new_empty(input.shape) 2973 2974 2975@register_meta(aten.repeat_interleave.Tensor) 2976def meta_repeat_interleave_Tensor(repeats, output_size=None): 2977 if output_size is None: 2978 raise RuntimeError("cannot repeat_interleave a meta tensor without output_size") 2979 return repeats.new_empty(output_size) 2980 2981 2982@register_meta([aten.complex.default, aten.complex.out]) 2983@out_wrapper() 2984def meta_complex(real, imag): 2985 assert real.dtype.is_floating_point 2986 assert imag.dtype.is_floating_point 2987 out_shape = _broadcast_shapes(real.shape, imag.shape) 2988 return real.new_empty(out_shape, dtype=corresponding_complex_dtype(real.dtype)) 2989 2990 
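# Illustrative sketch of the meta_complex contract above (comment only; the values
# are chosen for illustration): the result broadcasts the inputs and uses the
# complex dtype corresponding to the real input's dtype, e.g.
#
#   real = torch.empty(3, 1, dtype=torch.float32, device="meta")
#   imag = torch.empty(1, 4, dtype=torch.float32, device="meta")
#   out = torch.complex(real, imag)   # shape (3, 4), dtype torch.complex64
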
2991@register_meta([aten.nonzero_static.default, aten.nonzero_static.out]) 2992@out_wrapper() 2993def nonzero_static(self, *, size: int, fill_value: int = -1): 2994 return self.new_empty((size, self.dim()), dtype=torch.long) 2995 2996 2997@register_meta([aten.index.Tensor, aten._unsafe_index.Tensor]) 2998def meta_index_Tensor(self, indices): 2999 torch._check(bool(indices), lambda: "at least one index must be provided") 3000 # aten::index is the internal advanced indexing implementation 3001 # checkIndexTensorTypes and expandTensors 3002 result: List[Optional[Tensor]] = [] 3003 for i, index in enumerate(indices): 3004 if index is not None: 3005 torch._check( 3006 index.dtype in [torch.long, torch.int, torch.int8, torch.bool], 3007 lambda: "tensors used as indices must be long, int, byte or bool tensors", 3008 ) 3009 if index.dtype in [torch.int8, torch.bool]: 3010 nonzero = index.nonzero() 3011 k = len(result) 3012 torch._check_index( 3013 k + index.ndim <= self.ndim, 3014 lambda: f"too many indices for tensor of dimension {self.ndim}", 3015 ) 3016 for j in range(index.ndim): 3017 torch._check_index( 3018 index.shape[j] == self.shape[k + j], 3019 lambda: f"The shape of the mask {index.shape} at index {i} " 3020 f"does not match the shape of the indexed tensor {self.shape} at index {k + j}", 3021 ) 3022 result.append(nonzero.select(1, j)) 3023 else: 3024 result.append(index) 3025 else: 3026 result.append(index) 3027 indices = result 3028 torch._check( 3029 len(indices) <= self.ndim, 3030 lambda: f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})", 3031 ) 3032 # expand_outplace 3033 import torch._refs as refs # avoid import cycle in mypy 3034 3035 indices = list(refs._maybe_broadcast(*indices)) 3036 # add missing null tensors 3037 while len(indices) < self.ndim: 3038 indices.append(None) 3039 3040 # hasContiguousSubspace 3041 # true if all non-null tensors are adjacent 3042 # See: 3043 # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing 3044 # https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency 3045 state = 0 3046 has_contiguous_subspace = False 3047 for index in indices: 3048 if state == 0: 3049 if index is not None: 3050 state = 1 3051 elif state == 1: 3052 if index is None: 3053 state = 2 3054 else: 3055 if index is not None: 3056 break 3057 else: 3058 has_contiguous_subspace = True 3059 3060 # transposeToFront 3061 # This is the logic that causes the newly inserted dimensions to show up 3062 # at the beginning of the tensor, if they're not contiguous 3063 if not has_contiguous_subspace: 3064 dims = [] 3065 transposed_indices = [] 3066 for i, index in enumerate(indices): 3067 if index is not None: 3068 dims.append(i) 3069 transposed_indices.append(index) 3070 for i, index in enumerate(indices): 3071 if index is None: 3072 dims.append(i) 3073 transposed_indices.append(index) 3074 self = self.permute(dims) 3075 indices = transposed_indices 3076 3077 # AdvancedIndex::AdvancedIndex 3078 # Now we can assume the indices have contiguous subspace 3079 # This is simplified from AdvancedIndex which goes to more effort 3080 # to put the input and indices in a form so that TensorIterator can 3081 # take them. 
If we write a ref for this, probably that logic should 3082 # get implemented 3083 before_shape: List[int] = [] 3084 after_shape: List[int] = [] 3085 replacement_shape: List[int] = [] 3086 for dim, index in enumerate(indices): 3087 if index is None: 3088 if replacement_shape: 3089 after_shape.append(self.shape[dim]) 3090 else: 3091 before_shape.append(self.shape[dim]) 3092 else: 3093 replacement_shape = list(index.shape) 3094 return self.new_empty(before_shape + replacement_shape + after_shape) 3095 3096 3097@register_meta([aten.convolution_backward.default]) 3098def meta_convolution_backward( 3099 grad_output_, 3100 input_, 3101 weight_, 3102 bias_sizes_opt, 3103 stride, 3104 padding, 3105 dilation, 3106 transposed, 3107 output_padding, 3108 groups, 3109 output_mask, 3110): 3111 # High level logic taken from slow_conv3d_backward_cpu which should 3112 # be representative of all convolution_backward impls 3113 backend_grad_input = None 3114 backend_grad_weight = None 3115 backend_grad_bias = None 3116 3117 if output_mask[0]: 3118 backend_grad_input = grad_output_.new_empty(input_.size()) 3119 if output_mask[1]: 3120 backend_grad_weight = grad_output_.new_empty(weight_.size()) 3121 if output_mask[2]: 3122 backend_grad_bias = grad_output_.new_empty(bias_sizes_opt) 3123 3124 return (backend_grad_input, backend_grad_weight, backend_grad_bias) 3125 3126 3127@register_meta([aten.addbmm.default, aten.addbmm.out]) 3128@out_wrapper() 3129def meta_addbmm(self, batch1, batch2, *, beta=1, alpha=1): 3130 dim1 = batch1.size(1) 3131 dim2 = batch2.size(2) 3132 self = self.expand((dim1, dim2)) 3133 torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor") 3134 torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor") 3135 torch._check( 3136 batch1.size(0) == batch2.size(0), 3137 lambda: f"batch1 and batch2 must have same number of batches, got {batch1.size(0)} and {batch2.size(0)}", 3138 ) 3139 torch._check( 3140 batch1.size(2) == batch2.size(1), 3141 lambda: ( 3142 f"Incompatible matrix sizes for bmm ({batch1.size(1)}x{batch1.size(2)} " 3143 f"and {batch2.size(1)}x{batch2.size(2)})" 3144 ), 3145 ) 3146 torch._check( 3147 self.size(0) == dim1 and self.size(1) == dim2, 3148 lambda: "self tensor does not match matmul output shape", 3149 ) 3150 return self.new_empty(self.size()) 3151 3152 3153@register_meta([aten._fused_adam_.default, aten._fused_adamw_.default]) 3154def meta__fused_adam_( 3155 self, 3156 grads, 3157 exp_avgs, 3158 exp_avg_sqs, 3159 max_exp_avg_sqs, 3160 state_steps, 3161 *, 3162 lr, 3163 beta1, 3164 beta2, 3165 weight_decay, 3166 eps, 3167 amsgrad, 3168 maximize, 3169 grad_scale=None, 3170 found_inf=None, 3171): 3172 for l in [self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]: 3173 torch._check( 3174 isinstance(l, List), 3175 lambda: f"exponent must be a tensor list but got {type(l)}", 3176 ) 3177 3178 3179@register_meta([aten._fused_adam.default]) 3180def meta__fused_adam( 3181 self, 3182 grads, 3183 exp_avgs, 3184 exp_avg_sqs, 3185 max_exp_avg_sqs, 3186 state_steps, 3187 *, 3188 lr, 3189 beta1, 3190 beta2, 3191 weight_decay, 3192 eps, 3193 amsgrad, 3194 maximize, 3195 grad_scale=None, 3196 found_inf=None, 3197): 3198 for l in [self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]: 3199 torch._check( 3200 isinstance(l, List), 3201 lambda: f"exponent must be a tensor list but got {type(l)}", 3202 ) 3203 3204 def empty_like_list(tensor_list): 3205 return [torch.empty_like(t) for t in tensor_list] 3206 3207 return ( 3208 empty_like_list(self), 
3209 empty_like_list(grads), 3210 empty_like_list(exp_avgs), 3211 empty_like_list(exp_avg_sqs), 3212 empty_like_list(max_exp_avg_sqs), 3213 ) 3214 3215 3216@register_meta([aten._int_mm]) 3217@out_wrapper() 3218def meta__int_mm(a, b): 3219 torch._check(a.dim() == 2, lambda: "a must be a 2D tensor") 3220 torch._check(b.dim() == 2, lambda: "b must be a 2D tensor") 3221 torch._check( 3222 a.dtype is torch.int8, 3223 lambda: f"expected self to be int8, got {a.dtype}", 3224 ) 3225 torch._check( 3226 b.dtype is torch.int8, 3227 lambda: f"expected mat2 to be int8, got {b.dtype}", 3228 ) 3229 torch._check( 3230 a.size(1) == b.size(0), 3231 lambda: ( 3232 f"Incompatible matrix sizes for _int_mm ({a.size(0)}x{a.size(1)} " 3233 f"and {b.size(0)}x{b.size(1)})" 3234 ), 3235 ) 3236 return a.new_empty((a.size(0), b.size(1)), dtype=torch.int32) 3237 3238 3239@register_meta([aten._convert_weight_to_int4pack]) 3240def meta__convert_weight_to_int4pack(w, inner_k_tiles): 3241 torch._check(w.dim() == 2, lambda: "w must be a 2D tensor") 3242 torch._check( 3243 w.dtype is torch.uint8, 3244 lambda: f"expected w to be uint8, got {w.dtype}", 3245 ) 3246 n = w.size(0) 3247 k = w.size(1) * 2 # w is [n][k / 2] uint8 3248 return w.new_empty( 3249 ( 3250 n // 8, 3251 k // (inner_k_tiles * 16), 3252 32, 3253 inner_k_tiles // 2, 3254 ), 3255 dtype=torch.int32, 3256 ) 3257 3258 3259@register_meta([aten._weight_int4pack_mm]) 3260def meta__weight_int4pack_mm(x, w, q_group_size, q_scale_and_zeros): 3261 torch._check(x.dim() == 2, lambda: "x must be a 2D tensor") 3262 torch._check(w.dim() == 4, lambda: "w must be a 4D tensor") 3263 torch._check( 3264 x.dtype in [torch.float32, torch.float16, torch.bfloat16], 3265 lambda: f"expected x to be f32/f16/bf16, got {x.dtype}", 3266 ) 3267 torch._check( 3268 w.dtype is torch.int32, 3269 lambda: f"expected w to be int32, got {w.dtype}", 3270 ) 3271 return x.new_empty(x.size(0), w.size(0) * 8, dtype=x.dtype) 3272 3273 3274@register_meta([aten._weight_int8pack_mm]) 3275def meta__weight_int8pack_mm(x, w, q_scales): 3276 torch._check(x.dim() == 2, lambda: "x must be a 2D tensor") 3277 torch._check( 3278 x.dtype in [torch.float32, torch.float16, torch.bfloat16], 3279 lambda: f"expected x to be f32/f16/bf16, got {x.dtype}", 3280 ) 3281 torch._check(w.dim() == 2, lambda: "w must be a 2D tensor") 3282 torch._check( 3283 w.dtype is torch.int8, 3284 lambda: f"expected w to be int8, got {w.dtype}", 3285 ) 3286 return x.new_empty(x.size(0), w.size(0), dtype=x.dtype) 3287 3288 3289@register_meta(aten._cdist_forward.default) 3290def meta_cdist_forward(x1, x2, p, compute_mode): 3291 torch._check( 3292 x1.dim() >= 2, 3293 lambda: f"cdist only supports at least 2D tensors, X1 got: {x1.dim()}D", 3294 ) 3295 torch._check( 3296 x2.dim() >= 2, 3297 lambda: f"cdist only supports at least 2D tensors, X2 got: {x2.dim()}D", 3298 ) 3299 torch._check( 3300 x1.size(-1) == x2.size(-1), 3301 lambda: f"X1 and X2 must have the same number of columns. 
X1: {x1.size(-1)} X2: {x2.size(-1)}", 3302 ) 3303 torch._check( 3304 utils.is_float_dtype(x1.dtype), 3305 lambda: "cdist only supports floating-point dtypes, X1 got: {x1.dtype}", 3306 ) 3307 torch._check( 3308 utils.is_float_dtype(x2.dtype), 3309 lambda: "cdist only supports floating-point dtypes, X2 got: {x2.dtype}", 3310 ) 3311 torch._check(p >= 0, lambda: "cdist only supports non-negative p values") 3312 torch._check( 3313 compute_mode in (None, 1, 2), 3314 lambda: f"possible modes: None, 1, 2, but was: {compute_mode}", 3315 ) 3316 r1 = x1.size(-2) 3317 r2 = x2.size(-2) 3318 batch_tensor1 = x1.shape[:-2] 3319 batch_tensor2 = x2.shape[:-2] 3320 output_shape = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2)) 3321 output_shape.extend([r1, r2]) 3322 return x1.new_empty(output_shape) 3323 3324 3325@register_meta(aten._cdist_backward) 3326@out_wrapper() 3327def meta_cdist_backward(grad, x1, x2, p, cdist): 3328 c1 = x1.shape[-1] 3329 r1 = x1.shape[-2] 3330 r2 = x2.shape[-2] 3331 batch_tensor1 = x1.shape[:-2] 3332 batch_tensor2 = x2.shape[:-2] 3333 expand_batch_portion = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2)) 3334 tensor1_expand_size = expand_batch_portion.copy() 3335 tensor1_expand_size.extend([r1, c1]) 3336 batch_product = math.prod(expand_batch_portion) 3337 if r1 == 0 or r2 == 0 or c1 == 0 or batch_product == 0: 3338 return torch.zeros_like(x1) 3339 if tensor1_expand_size != list(x1.shape): 3340 x1 = x1.expand(tensor1_expand_size) 3341 return torch.empty_like(x1, memory_format=torch.contiguous_format) 3342 3343 3344# NB: This meta function accepts non-meta arguments! When this behavior 3345# was originally introduced this was accidental, but it is now load bearing 3346# as people are using this so that they can conveniently test code involving 3347# embeddings (feeding CPU tensor inputs with meta device EmbeddingBag module) 3348@register_meta(aten._embedding_bag.default) 3349def meta_embedding_bag( 3350 weight, 3351 indices, 3352 offsets, 3353 scale_grad_by_freq=False, 3354 mode=0, 3355 sparse=False, 3356 per_sample_weights=None, 3357 include_last_offset=False, 3358 padding_idx=-1, 3359): 3360 torch._check( 3361 indices.dtype in (torch.long, torch.int), 3362 lambda: f"expected indices to be long or int, got {indices.dtype}", 3363 ) 3364 torch._check( 3365 offsets.dtype in (torch.long, torch.int), 3366 lambda: f"expected offsets to be long or int, got {offsets.dtype}", 3367 ) 3368 torch._check( 3369 utils.is_float_dtype(weight.dtype), 3370 lambda: f"expected weight to be floating point type, got {weight.dtype}", 3371 ) 3372 3373 num_bags = offsets.size(0) 3374 if include_last_offset: 3375 torch._check( 3376 num_bags >= 1, 3377 lambda: "include_last_offset: numBags should be at least 1", 3378 ) 3379 num_bags -= 1 3380 3381 output = weight.new_empty(num_bags, weight.size(1)) 3382 MODE_SUM, MODE_MEAN, MODE_MAX = range(3) 3383 3384 if per_sample_weights is not None: 3385 torch._check( 3386 mode == MODE_SUM, 3387 lambda: "embedding_bag: per_sample_weights only supported with mode='sum'", 3388 ) 3389 torch._check( 3390 per_sample_weights.dtype == weight.dtype, 3391 lambda: f"expected weight ({weight.dtype}) and per_sample_weights ({per_sample_weights.dtype}) to have same dtype", 3392 ) 3393 torch._check( 3394 per_sample_weights.ndim == 1, 3395 lambda: f"expected per_sample_weights to be 1D tensor, got {per_sample_weights.ndim}D", 3396 ) 3397 torch._check( 3398 per_sample_weights.numel() == indices.numel(), 3399 lambda: ( 3400 f"expected per_sample_weights.numel() 
({per_sample_weights.numel()} " 3401 f"to be the same as indices.numel() ({indices.numel()})" 3402 ), 3403 ) 3404 3405 def is_fast_path_index_select_scale(src, scale, output, padding_idx): 3406 return ( 3407 is_fast_path_index_select(src, output, padding_idx) and scale.stride(0) == 1 3408 ) 3409 3410 def is_fast_path_index_select(src, output, padding_idx): 3411 return ( 3412 (src.dtype == torch.float or src.dtype == torch.half) 3413 and src.stride(1) == 1 3414 and output.stride(1) == 1 3415 and padding_idx < 0 3416 ) 3417 3418 def is_fast_path(src, scale, output, padding_idx): 3419 if scale is not None: 3420 return is_fast_path_index_select_scale(src, scale, output, padding_idx) 3421 else: 3422 return is_fast_path_index_select(src, output, padding_idx) 3423 3424 if device_hint(offsets) != "cpu": 3425 offset2bag = indices.new_empty(indices.size(0)) 3426 bag_size = indices.new_empty(offsets.size()) 3427 if mode == MODE_MAX: 3428 max_indices = indices.new_empty(num_bags, weight.size(1)) 3429 else: 3430 max_indices = indices.new_empty(0) 3431 else: 3432 fast_path_sum = is_fast_path(weight, per_sample_weights, output, padding_idx) 3433 if mode in (MODE_MEAN, MODE_MAX) or not fast_path_sum: 3434 offset2bag = offsets.new_empty(indices.size(0)) 3435 else: 3436 offset2bag = offsets.new_empty(0) 3437 bag_size = offsets.new_empty(num_bags) 3438 # This part of the logic comes from make_max_indices_out in EmbeddingBag.cpp 3439 numBags = offsets.shape[0] 3440 if mode == MODE_MAX: 3441 if include_last_offset: 3442 torch._check( 3443 numBags >= 1, 3444 lambda: "include_last_offset: numBags should be at least 1", 3445 ) 3446 numBags -= 1 3447 max_indices = offsets.new_empty(numBags, weight.shape[1]) 3448 else: 3449 max_indices = offsets.new_empty(bag_size.size()) 3450 return output, offset2bag, bag_size, max_indices 3451 3452 3453@register_meta(aten._embedding_bag_forward_only.default) 3454def meta_embedding_bag_forward_only(weight, indices, offsets, *args): 3455 output, offset2bag, bag_size, max_indices = meta_embedding_bag( 3456 weight, indices, offsets, *args 3457 ) 3458 if device_hint(offsets) == "cpu": 3459 bag_size = offsets.new_empty(offsets.size()) 3460 return output, offset2bag, bag_size, max_indices 3461 3462 3463def _get_reduction_dtype(input, dtype, promote_int_to_long=True): 3464 # if specified, dtype takes precedence 3465 if dtype: 3466 return dtype 3467 3468 if input.dtype.is_floating_point or input.dtype.is_complex: 3469 return input.dtype 3470 elif promote_int_to_long: 3471 return torch.long 3472 3473 return input.dtype 3474 3475 3476@register_meta([aten.nansum.default, aten.nansum.out]) 3477@out_wrapper() 3478def meta_nansum(input, dims=None, keepdim=False, *, dtype=None): 3479 output_dtype = _get_reduction_dtype(input, dtype, promote_int_to_long=True) 3480 dims = utils.reduction_dims(input.shape, dims) 3481 output_shape = _compute_reduction_shape(input, dims, keepdim) 3482 return input.new_empty(output_shape, dtype=output_dtype) 3483 3484 3485@register_meta([aten.median.default, aten.nanmedian.default]) 3486def meta_median(input): 3487 output_shape = utils.compute_reduction_output_shape( 3488 input.shape, tuple(range(input.dim())) 3489 ) 3490 return input.new_empty(output_shape) 3491 3492 3493@register_meta( 3494 [ 3495 aten.median.dim, 3496 aten.median.dim_values, 3497 aten.nanmedian.dim, 3498 aten.nanmedian.dim_values, 3499 aten.mode.default, 3500 aten.mode.values, 3501 ] 3502) 3503@out_wrapper("values", "indices") 3504def meta_median_mode_dim(input, dim=-1, keepdim=False): 3505 if 
device_hint(input) == "cuda": 3506 utils.alert_not_deterministic("median CUDA with indices output") 3507 dim = utils.reduction_dims(input.shape, (dim,)) 3508 output_shape = _compute_reduction_shape(input, dim, keepdim) 3509 return ( 3510 input.new_empty(output_shape), 3511 input.new_empty(output_shape, dtype=torch.long), 3512 ) 3513 3514 3515@register_meta(aten.logical_not_.default) 3516def meta_logical_not_(self): 3517 return self 3518 3519 3520@register_meta(aten.repeat.default) 3521def meta_repeat(self, repeats): 3522 torch._check( 3523 len(repeats) >= self.dim(), 3524 lambda: "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor", 3525 ) 3526 # Add new leading dimensions to the tensor if the 3527 # number of target dimensions is larger than the 3528 # number of source dimensions. 3529 num_new_dimensions = len(repeats) - self.dim() 3530 padded_size = (1,) * num_new_dimensions + tuple(self.shape) 3531 target_size = [padded_size[i] * repeats[i] for i in range(len(repeats))] 3532 return self.new_empty(target_size) 3533 3534 3535@register_meta(aten.zero_.default) 3536def meta_zero_(self): 3537 return self 3538 3539 3540@register_meta( 3541 [ 3542 aten.mul_.Scalar, 3543 aten.div_.Scalar, 3544 aten.mul_.Tensor, 3545 aten.div_.Tensor, 3546 aten.logical_and_.default, 3547 aten.logical_or_.default, 3548 aten.logical_xor_.default, 3549 ], 3550) 3551def meta_binop_inplace(self, other): 3552 if isinstance(other, torch.Tensor): 3553 check_inplace_broadcast(self.shape, other.shape) 3554 return self 3555 3556 3557@register_meta( 3558 [ 3559 aten.add_.Scalar, 3560 aten.sub_.Scalar, 3561 aten.add_.Tensor, 3562 aten.sub_.Tensor, 3563 ], 3564) 3565def meta_binop_inplace_alpha(self, other, alpha=1): 3566 if isinstance(other, torch.Tensor): 3567 check_inplace_broadcast(self.shape, other.shape) 3568 return self 3569 3570 3571@register_meta([aten.round.default, aten.round.decimals]) 3572def meta_round(self, **kwargs): 3573 return elementwise_meta( 3574 self, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT 3575 ) 3576 3577 3578def shift_dtype_check(fn_name, self, val): 3579 torch._check( 3580 utils.is_integer_dtype(self.dtype), 3581 lambda: f"{fn_name}: Expected input tensor to have an integral dtype. Got {self.dtype}", 3582 ) 3583 if isinstance(val, torch.Tensor): 3584 torch._check( 3585 utils.is_integer_dtype(val.dtype), 3586 lambda: f"{fn_name}: Expected shift value to have an integral dtype. Got {val.dtype}", 3587 ) 3588 else: 3589 torch._check( 3590 isinstance(val, IntLike), 3591 lambda: f"{fn_name}: Expected shift value to be an int. 
Got {val}", 3592 ) 3593 3594 3595@register_meta([aten.__rshift__.Tensor, aten.__rshift__.Scalar]) 3596def meta_rshifts(self, other): 3597 shift_dtype_check("rshift", self, other) 3598 return elementwise_meta( 3599 self, other, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT 3600 ) 3601 3602 3603@register_meta([aten.__lshift__.Tensor, aten.__lshift__.Scalar]) 3604def meta_lshifts(self, other): 3605 shift_dtype_check("lshift", self, other) 3606 return elementwise_meta( 3607 self, other, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT 3608 ) 3609 3610 3611@register_meta(aten.zero.default) 3612def meta_zero(self): 3613 return self.new_empty(self.shape) 3614 3615 3616@register_meta([aten.fill_.Tensor, aten.fill_.Scalar]) 3617def meta_fill_(self, val): 3618 return self 3619 3620 3621@register_meta([aten.fill.Tensor, aten.fill.Scalar]) 3622def meta_fill(self, val): 3623 return torch.empty_like(self) 3624 3625 3626@register_meta(aten.relu_.default) 3627def meta_relu_(self): 3628 return self 3629 3630 3631@register_meta([aten.index_put.default, aten._unsafe_index_put.default]) 3632def meta_index_put(self, indices, values, accumulate=False): 3633 return torch.empty_like(self) 3634 3635 3636@register_meta(aten.masked_fill_.Scalar) 3637def meta_masked_fill_(self, mask, value): 3638 check_inplace_broadcast(self.shape, mask.shape) 3639 return self 3640 3641 3642@register_meta(aten._masked_scale.default) 3643def meta__masked_scale(self, mask, scale): 3644 masked_scale = self.new_empty(self.size()).to( 3645 memory_format=utils.suggest_memory_format(self) 3646 ) 3647 return masked_scale 3648 3649 3650@register_meta(aten.masked_scatter_) 3651def meta_masked_scatter_(self, mask, source): 3652 torch._check( 3653 mask.dtype in (torch.bool, torch.uint8), lambda: "Mask must be bool or uint8" 3654 ) 3655 torch._check( 3656 self.dtype == source.dtype, 3657 lambda: "masked_scatter: expected self and source to have same " 3658 "dtypes but got {self.dtype} and {source.dtype}", 3659 ) 3660 return self 3661 3662 3663@register_meta(aten.masked_scatter) 3664@out_wrapper() 3665def meta_masked_scatter(self, mask, source): 3666 self, mask = _maybe_broadcast(self, mask) 3667 output = torch.empty_like(self, memory_format=torch.contiguous_format) 3668 return meta_masked_scatter_(output, mask, source) 3669 3670 3671@register_meta(aten.masked_scatter_backward) 3672def meta_masked_scatter_backward(self, mask, sizes): 3673 return self.new_empty(sizes) 3674 3675 3676@register_meta(aten.index_put_.default) 3677def meta_index_put_(self, indices, values, accumulate=False): 3678 return self 3679 3680 3681@register_meta(aten.alias.default) 3682def meta_alias(self): 3683 return self.view(self.shape) 3684 3685 3686def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None): 3687 torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor") 3688 torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor") 3689 3690 batch1_sizes = batch1.size() 3691 batch2_sizes = batch2.size() 3692 3693 bs = batch1_sizes[0] 3694 contraction_size = batch1_sizes[2] 3695 res_rows = batch1_sizes[1] 3696 res_cols = batch2_sizes[2] 3697 output_size = (bs, res_rows, res_cols) 3698 3699 torch._check( 3700 batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size, 3701 lambda: f"Expected size for first two dimensions of batch2 tensor to be: [{bs}" 3702 f", {contraction_size}] but got: [{batch2_sizes[0]}, {batch2_sizes[1]}].", 3703 ) 3704 3705 # TODO: handle out 3706 3707 output = batch2.new_empty(output_size) 3708 3709 
if not is_bmm and self_baddbmm is not None: 3710 torch._check(self_baddbmm.dim() == 3, lambda: "self must be a 3D tensor") 3711 torch._check( 3712 self_baddbmm.size() == output_size, 3713 lambda: f"Expected an input tensor shape with shape {output_size} but got shape: {self_baddbmm.size()}", 3714 ) 3715 3716 return output 3717 3718 3719@register_meta(aten.bmm.default) 3720def meta_bmm(self, mat2): 3721 return common_meta_baddbmm_bmm(self, mat2, True) 3722 3723 3724def div_rtn(x, y): 3725 q = x // y 3726 r = x % y 3727 # WARNING: explicit bool conversion here is necessary; 3728 # would be fixed by SymBool 3729 if r != 0 and (bool(r < 0) != bool(y < 0)): 3730 q -= 1 3731 return q 3732 3733 3734def pooling_output_shape_pad_lr( 3735 inputSize, 3736 kernelSize, 3737 pad_l, 3738 pad_r, 3739 stride, 3740 dilation, 3741 ceil_mode, 3742): 3743 outputSize = ( 3744 div_rtn( 3745 inputSize 3746 + pad_l 3747 + pad_r 3748 - dilation * (kernelSize - 1) 3749 - 1 3750 + (stride - 1 if ceil_mode else 0), 3751 stride, 3752 ) 3753 + 1 3754 ) 3755 if ceil_mode: 3756 if (outputSize - 1) * stride >= inputSize + pad_l: 3757 outputSize -= 1 3758 return outputSize 3759 3760 3761def pooling_output_shape(inputSize, kernelSize, pad, stride, dilation, ceil_mode): 3762 torch._check(stride != 0, lambda: "stride should not be zero") 3763 torch._check(pad >= 0, lambda: f"pad must be non-negative, but got pad: {pad}") 3764 torch._check( 3765 pad <= ((kernelSize - 1) * dilation + 1) // 2, 3766 lambda: ( 3767 f"pad should be at most half of effective kernel size, but got pad={pad}, " 3768 f"kernel_size={kernelSize} and dilation={dilation}" 3769 ), 3770 ) 3771 return pooling_output_shape_pad_lr( 3772 inputSize, kernelSize, pad, pad, stride, dilation, ceil_mode 3773 ) 3774 3775 3776def pool2d_shape_check( 3777 input, 3778 kH, 3779 kW, 3780 dH, 3781 dW, 3782 padH, 3783 padW, 3784 dilationH, 3785 dilationW, 3786 nInputPlane, 3787 inputHeight, 3788 inputWidth, 3789 outputHeight, 3790 outputWidth, 3791 memory_format, 3792): 3793 ndim = input.dim() 3794 nOutputPlane = nInputPlane 3795 3796 torch._check( 3797 kW > 0 and kH > 0, 3798 lambda: "kernel size should be greater than zero, but got kH: {kH}, kW: {kW}", 3799 ) 3800 torch._check( 3801 dW > 0 and dH > 0, 3802 lambda: "stride should be greater than zero, but got dH: {dH}, dW: {dW}", 3803 ) 3804 torch._check( 3805 dilationH > 0 and dilationW > 0, 3806 lambda: "dilation should be greater than zero, but got dilationH: {dilationH}, dilationW: {dilationW}", 3807 ) 3808 3809 valid_dims = input.size(1) != 0 and input.size(2) != 0 3810 3811 if memory_format == torch.channels_last: 3812 torch._check( 3813 ndim == 4 and valid_dims and input.size(3) != 0, 3814 lambda: "Expected 4D (batch mode) tensor expected for input with channels_last layout" 3815 " with optional 0 dim batch size for input, but got: {input.size()}", 3816 ) 3817 else: 3818 torch._check( 3819 (ndim == 3 and input.size(0) != 0 and valid_dims) 3820 or (ndim == 4 and valid_dims and input.size(3) != 0), 3821 lambda: f"Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got: {input.size()}", 3822 ) 3823 3824 torch._check( 3825 kW // 2 >= padW and kH // 2 >= padH, 3826 lambda: "pad should be smaller than or equal to half of kernel size, but got " 3827 f"padW = {padW}, padH = {padH}, kW = {kW}, kH = {kH}", 3828 ) 3829 3830 torch._check( 3831 outputWidth >= 1 and outputHeight >= 1, 3832 lambda: f"Given input size: ({nInputPlane}x{inputHeight}x{inputWidth}). 
" 3833 f"Calculated output size: ({nOutputPlane}x{outputHeight}x{outputWidth}). " 3834 "Output size is too small", 3835 ) 3836 3837 3838def pool3d_shape_check( 3839 input: Tensor, 3840 nslices: int, 3841 kT: int, 3842 kH: int, 3843 kW: int, 3844 dT: int, 3845 dH: int, 3846 dW: int, 3847 pT: int, 3848 pH: int, 3849 pW: int, 3850 dilationT: int, 3851 dilationH: int, 3852 dilationW: int, 3853 itime: int, 3854 iheight: int, 3855 iwidth: int, 3856 otime: int, 3857 oheight: int, 3858 owidth: int, 3859 fn_name: str, 3860 check_input_size: bool = False, 3861): 3862 ndim = input.ndim 3863 3864 torch._check( 3865 kT > 0 and kW > 0 and kH > 0, 3866 lambda: ( 3867 f"kernel size should be greater than zero, but got " 3868 f"kT: {kT}, kH: {kH}, kW: {kW}" 3869 ), 3870 ) 3871 torch._check( 3872 dT > 0 and dW > 0 and dH > 0, 3873 lambda: ( 3874 f"stride should be greater than zero, but got " 3875 f"dT: {dT}, dH: {dH}, dW: {dW}" 3876 ), 3877 ) 3878 torch._check( 3879 dilationT > 0 and dilationW > 0 and dilationH > 0, 3880 lambda: ( 3881 f"dilation should be greater than zero, but got " 3882 f"dilationT: {dilationT}, dilationH: {dilationH}, dilationW: {dilationW}" 3883 ), 3884 ) 3885 3886 torch._check( 3887 ndim in (4, 5), 3888 lambda: f"{fn_name}: Expected 4D or 5D tensor for input, but got: {input.shape}", 3889 ) 3890 3891 for i in range(ndim): 3892 if ndim == 5 and i == 0: 3893 # size of batch-dim can be 0. 3894 continue 3895 torch._check( 3896 input.size(i) > 0, 3897 lambda: ( 3898 f"{fn_name}: Expected input's non-batch dimensions to have positive length," 3899 f" but input has a shape of {input.shape}" 3900 f" and non-batch dimension {input.size(i)} has length zero!" 3901 ), 3902 ) 3903 3904 if check_input_size: # AveragePool3d 3905 torch._check( 3906 itime >= kT and iheight >= kH and iwidth >= kW, 3907 lambda: ( 3908 f"input image (T: {itime} H: {iheight} W: {iwidth}) smaller than " 3909 f"kernel size (kT: {kT} kH: {kH} kW: {kW})" 3910 ), 3911 ) 3912 3913 torch._check( 3914 kT / 2 >= pT and kW / 2 >= pW and kH / 2 >= pH, 3915 lambda: ( 3916 f"pad should be smaller than or equal to half of kernel size, but got " 3917 f"kT: {kT} kW: {kW} kH: {kH} padT: {pT} padW: {pW} padH: {pH}" 3918 ), 3919 ) 3920 3921 torch._check( 3922 otime >= 1 and owidth >= 1 and oheight >= 1, 3923 lambda: ( 3924 f"Given input size: ({nslices}x{itime}x{iheight}x{iwidth}). " 3925 f"Calculated output size: ({nslices}x{otime}x{oheight}x{owidth}). 
" 3926 f"Output size is too small" 3927 ), 3928 ) 3929 3930 3931def max_pool3d_backward_shape_check( 3932 input, 3933 grad_output, 3934 indices, 3935 nslices, 3936 kT, 3937 kH, 3938 kW, 3939 dT, 3940 dH, 3941 dW, 3942 pT, 3943 pH, 3944 pW, 3945 dilationT, 3946 dilationH, 3947 dilationW, 3948 itime, 3949 iheight, 3950 iwidth, 3951 otime, 3952 oheight, 3953 owidth, 3954 fn_name, 3955): 3956 ndim = input.ndim 3957 3958 pool3d_shape_check( 3959 input, 3960 nslices, 3961 kT, 3962 kH, 3963 kW, 3964 dT, 3965 dH, 3966 dW, 3967 pT, 3968 pH, 3969 pW, 3970 dilationT, 3971 dilationH, 3972 dilationW, 3973 itime, 3974 iheight, 3975 iwidth, 3976 otime, 3977 oheight, 3978 owidth, 3979 fn_name, 3980 ) 3981 3982 check_dim_size(grad_output, ndim, ndim - 4, nslices) 3983 check_dim_size(grad_output, ndim, ndim - 3, otime) 3984 check_dim_size(grad_output, ndim, ndim - 2, oheight) 3985 check_dim_size(grad_output, ndim, ndim - 1, owidth) 3986 3987 check_dim_size(indices, ndim, ndim - 4, nslices) 3988 check_dim_size(indices, ndim, ndim - 3, otime) 3989 check_dim_size(indices, ndim, ndim - 2, oheight) 3990 check_dim_size(indices, ndim, ndim - 1, owidth) 3991 3992 3993def avg_pool3d_backward_shape_check( 3994 input: Tensor, 3995 grad_output: Tensor, 3996 nslices: int, 3997 kT: int, 3998 kH: int, 3999 kW: int, 4000 dT: int, 4001 dH: int, 4002 dW: int, 4003 pT: int, 4004 pH: int, 4005 pW: int, 4006 itime: int, 4007 iheight: int, 4008 iwidth: int, 4009 otime: int, 4010 oheight: int, 4011 owidth: int, 4012 fn_name: str, 4013): 4014 ndim = input.ndim 4015 4016 pool3d_shape_check( 4017 input, 4018 nslices, 4019 kT, 4020 kH, 4021 kW, 4022 dT, 4023 dH, 4024 dW, 4025 pT, 4026 pH, 4027 pW, 4028 1, 4029 1, 4030 1, 4031 itime, 4032 iheight, 4033 iwidth, 4034 otime, 4035 oheight, 4036 owidth, 4037 fn_name, 4038 True, 4039 ) 4040 4041 check_dim_size(grad_output, ndim, ndim - 4, nslices) 4042 check_dim_size(grad_output, ndim, ndim - 3, otime) 4043 check_dim_size(grad_output, ndim, ndim - 2, oheight) 4044 check_dim_size(grad_output, ndim, ndim - 1, owidth) 4045 4046 4047def max_pool2d_checks_and_compute_shape( 4048 input, 4049 kernel_size, 4050 stride, 4051 padding, 4052 dilation, 4053 ceil_mode, 4054): 4055 # Reference: aten/src/ATen/native/DilatedMaxPool2d.cpp 4056 def unpack(name, val): 4057 torch._check( 4058 len(val) in [1, 2], 4059 lambda: f"max_pool2d: {name} must either be a single int, or a tuple of two ints", 4060 ) 4061 H = val[0] 4062 W = H if len(val) == 1 else val[1] 4063 return H, W 4064 4065 kH, kW = unpack("kernel_size", kernel_size) 4066 4067 torch._check( 4068 len(stride) in [0, 1, 2], 4069 lambda: "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints", 4070 ) 4071 if len(stride) == 0: 4072 dH, dW = kH, kW 4073 else: 4074 dH, dW = unpack("stride", stride) 4075 4076 padH, padW = unpack("padding", padding) 4077 dilationH, dilationW = unpack("dilation", dilation) 4078 nInputPlane = input.size(-3) 4079 inputHeight = input.size(-2) 4080 inputWidth = input.size(-1) 4081 4082 memory_format = utils.suggest_memory_format(input) 4083 if memory_format == torch.channels_last: 4084 torch._check( 4085 input.dim() == 4, 4086 lambda: "non-empty 4D (batch mode) tensor expected for input with channels_last layout", 4087 ) 4088 elif memory_format == torch.contiguous_format: 4089 torch._check( 4090 input.dim() in [3, 4], 4091 lambda: "non-empty 3D or 4D (batch mode) tensor expected for input", 4092 ) 4093 else: 4094 torch._check( 4095 False, 4096 lambda: "Unsupport memory format. 
Supports only ChannelsLast, Contiguous", 4097 ) 4098 4099 outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode) 4100 outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode) 4101 4102 pool2d_shape_check( 4103 input, 4104 kH, 4105 kW, 4106 dH, 4107 dW, 4108 padH, 4109 padW, 4110 dilationH, 4111 dilationW, 4112 nInputPlane, 4113 inputHeight, 4114 inputWidth, 4115 outputHeight, 4116 outputWidth, 4117 memory_format, 4118 ) 4119 4120 return nInputPlane, outputHeight, outputWidth 4121 4122 4123@register_meta(aten.max_pool2d_with_indices_backward.default) 4124def meta_max_pool2d_with_indices_backward( 4125 grad_output, 4126 self, 4127 kernel_size, 4128 stride, 4129 padding, 4130 dilation, 4131 ceil_mode, 4132 indices, 4133): 4134 ( 4135 nInputPlane, 4136 outputHeight, 4137 outputWidth, 4138 ) = max_pool2d_checks_and_compute_shape( 4139 self, kernel_size, stride, padding, dilation, ceil_mode 4140 ) 4141 4142 torch._check( 4143 self.dtype == grad_output.dtype, 4144 lambda: f"Expected dtype {self.dtype} for `gradOutput` but got dtype {grad_output.dtype}", 4145 ) 4146 4147 nOutputPlane = nInputPlane 4148 ndim = self.ndim 4149 4150 def _check_dim_size(t): 4151 check_dim_size(t, ndim, ndim - 3, nOutputPlane) 4152 check_dim_size(t, ndim, ndim - 2, outputHeight) 4153 check_dim_size(t, ndim, ndim - 1, outputWidth) 4154 4155 _check_dim_size(grad_output) 4156 _check_dim_size(indices) 4157 4158 memory_format = utils.suggest_memory_format(self) 4159 return torch.empty( 4160 self.shape, 4161 dtype=self.dtype, 4162 device=self.device, 4163 memory_format=memory_format, 4164 ) 4165 4166 4167@register_meta(aten.max_pool2d_with_indices.default) 4168def meta_max_pool2d_with_indices( 4169 input, 4170 kernel_size, 4171 stride=(), 4172 padding=(0,), 4173 dilation=(1,), 4174 ceil_mode=False, 4175): 4176 ( 4177 nInputPlane, 4178 outputHeight, 4179 outputWidth, 4180 ) = max_pool2d_checks_and_compute_shape( 4181 input, kernel_size, stride, padding, dilation, ceil_mode 4182 ) 4183 4184 nbatch = input.size(-4) if input.dim() == 4 else 1 4185 memory_format = utils.suggest_memory_format(input) 4186 if input.dim() == 3: 4187 size = [nInputPlane, outputHeight, outputWidth] 4188 else: 4189 size = [nbatch, nInputPlane, outputHeight, outputWidth] 4190 return ( 4191 torch.empty( 4192 size, 4193 dtype=input.dtype, 4194 device=input.device, 4195 memory_format=memory_format, 4196 ), 4197 torch.empty( 4198 size, 4199 dtype=torch.int64, 4200 device=input.device, 4201 memory_format=memory_format, 4202 ), 4203 ) 4204 4205 4206@register_meta(aten.fractional_max_pool2d.default) 4207def meta_fractional_max_pool2d(self, kernel_size, output_size, random_samples): 4208 torch._check( 4209 self.ndim in (3, 4), 4210 lambda: f"fractional_max_pool2d: Expected 3D or 4D tensor, but got: {self.ndim}", 4211 ) 4212 ndim = self.ndim 4213 4214 for d in range(ndim - 3, ndim): 4215 torch._check( 4216 self.size(d) > 0, 4217 f"fractional_max_pool2d: Expected input to have non-zero " 4218 f" size for non-batch dimenions, but got {self.size()} with dimension {d} empty", 4219 ) 4220 4221 # the check and message are out of sync, but this matches the structured meta 4222 torch._check( 4223 len(kernel_size) == 2, 4224 lambda: "fractional_max_pool2d: kernel_size must" 4225 "either be a single int or tuple of Ints", 4226 ) 4227 torch._check( 4228 len(output_size) == 2, 4229 lambda: "fractional_max_pool2d: output_size must " 4230 "either be a single int or tuple of Ints", 4231 ) 4232 4233 input_channels = 
self.size(-3)
4234 input_height = self.size(-2)
4235 input_width = self.size(-1)
4236 if ndim == 4:
4237 input_batch = self.size(0)
4238 else:
4239 input_batch = 1
4240
4241 torch._check(
4242 self.dtype == random_samples.dtype,
4243 lambda: "Expect _random_samples to have the same dtype as input",
4244 )
4245 torch._check(
4246 random_samples.ndim == 3,
4247 lambda: f"Expect _random_samples to have 3 dimensions, got {random_samples.ndim}",
4248 )
4249
4250 n = random_samples.size(0)
4251 c = random_samples.size(1)
4252 d = random_samples.size(2)
4253 torch._check(
4254 n >= input_batch,
4255 lambda: "Expect _random_samples.size(0) to be no less than the input batch size.",
4256 )
4257 torch._check(
4258 c == input_channels,
4259 lambda: "Expect _random_samples.size(1) to equal the input channel size.",
4260 )
4261 torch._check(d == 2, lambda: f"Expect _random_samples.size(2) to equal 2, got {d}.")
4262
4263 torch._check(
4264 output_size[0] + kernel_size[0] - 1 <= input_height,
4265 lambda: f"fractional_max_pool2d: kernel height {kernel_size[0]} is too large relative to input height {input_height}",
4266 )
4267 torch._check(
4268 output_size[1] + kernel_size[1] - 1 <= input_width,
4269 lambda: f"fractional_max_pool2d: kernel width {kernel_size[1]} is too large relative to input width {input_width}",
4270 )
4271
4272 if self.dim() == 4:
4273 size = [input_batch, input_channels, output_size[0], output_size[1]]
4274 else:
4275 size = [input_channels, output_size[0], output_size[1]]
4276
4277 return (
4278 torch.empty(
4279 size,
4280 dtype=self.dtype,
4281 device=self.device,
4282 ),
4283 torch.empty(
4284 size,
4285 dtype=torch.int64,
4286 device=self.device,
4287 ),
4288 )
4289
4290
4291@register_meta(aten.max_unpool2d)
4292@out_wrapper()
4293def meta_max_unpool2d(self, indices, output_size):
4294 utils.alert_not_deterministic("max_unpooling2d_forward_out")
4295
4296 torch._check(
4297 indices.dtype == torch.int64,
4298 lambda: f"elements in indices should be type int64 but got: {indices.dtype}",
4299 )
4300 torch._check(
4301 len(output_size) == 2,
4302 lambda: (
4303 f"There should be exactly two elements (height, width) in output_size, "
4304 f"but got {len(output_size)} elements."
4305 ),
4306 )
4307
4308 oheight, owidth = output_size
4309
4310 torch._check(
4311 self.ndim in (3, 4),
4312 lambda: (
4313 f"Input to max_unpooling2d should be a 3d or 4d Tensor, "
4314 f"but got a tensor with {self.ndim} dimensions."
4315 ),
4316 )
4317 torch._check(
4318 self.shape == indices.shape,
4319 lambda: (
4320 f"Expected shape of indices to be same as that of the input tensor ({self.shape}) "
4321 f"but got indices tensor with shape: {indices.shape}"
4322 ),
4323 )
4324
4325 for i in range(1, self.ndim):
4326 torch._check(
4327 self.size(i) > 0,
4328 lambda: (
4329 f"max_unpooling2d(): "
4330 f"Expected input to have non-zero size for non-batch dimensions, "
4331 f"but got {self.shape} with dimension {i} being empty."
4332 ), 4333 ) 4334 4335 self = self.contiguous() 4336 4337 if self.ndim == 3: 4338 nchannels = self.size(0) 4339 result = self.new_empty((nchannels, oheight, owidth)) 4340 else: 4341 nbatch = self.size(0) 4342 nchannels = self.size(1) 4343 result = self.new_empty((nbatch, nchannels, oheight, owidth)) 4344 4345 return result 4346 4347 4348def _max_unpooling3d_shape_check(input, indices, output_size, stride, padding, fn_name): 4349 torch._check( 4350 indices.dtype == torch.int64, lambda: "elements in indices should be type int64" 4351 ) 4352 torch._check( 4353 input.ndim in (4, 5), 4354 lambda: f"Input to max_unpooling3d should be a 4d or 5d Tensor, but got a tensor with {input.ndim} dimensions.", 4355 ) 4356 torch._check( 4357 len(output_size) == 3, 4358 lambda: ( 4359 f"There should be exactly three elements (depth, height, width) in output_size, " 4360 f"but got {len(output_size)} elements." 4361 ), 4362 ) 4363 torch._check( 4364 len(stride) == 3, 4365 lambda: f"There should be exactly three elements (depth, height, width) in stride, but got: {len(stride)} elements.", 4366 ) 4367 torch._check( 4368 len(padding) == 3, 4369 lambda: f"There should be exactly three elements (depth, height, width) in padding, but got: {len(padding)} elements.", 4370 ) 4371 torch._check( 4372 input.shape == indices.shape, 4373 lambda: ( 4374 f"Expected shape of indices to be same as that of the input tensor ({input.shape}) " 4375 f"but got indices tensor with shape: {indices.shape}" 4376 ), 4377 ) 4378 4379 for i in range(1, input.ndim): 4380 torch._check( 4381 input.size(i) > 0, 4382 lambda: ( 4383 f"{fn_name}: " 4384 f"Expected input to have non-zero size for non-batch dimensions, " 4385 f"but got {input.shape} with dimension {i} being empty." 4386 ), 4387 ) 4388 4389 torch._check( 4390 stride[0] > 0 and stride[1] > 0 and stride[2] > 0, 4391 lambda: f"strides should be greater than zero, but got stride: {stride}", 4392 ) 4393 4394 4395@register_meta(aten.max_unpool3d) 4396@out_wrapper() 4397def meta_max_unpool3d(self, indices, output_size, stride, padding): 4398 utils.alert_not_deterministic("max_unpooling3d_forward_out") 4399 4400 _max_unpooling3d_shape_check( 4401 self, indices, output_size, stride, padding, "max_unpooling3d()" 4402 ) 4403 4404 self = self.contiguous() 4405 4406 odepth, oheight, owidth = output_size 4407 4408 if self.ndim == 4: 4409 nchannels = self.size(0) 4410 result = self.new_empty((nchannels, odepth, oheight, owidth)) 4411 else: 4412 nbatch = self.size(0) 4413 nchannels = self.size(1) 4414 result = self.new_empty((nbatch, nchannels, odepth, oheight, owidth)) 4415 4416 return result 4417 4418 4419@register_meta(aten.max_pool3d_with_indices) 4420@out_wrapper("out", "indices") 4421def meta_max_pool3d_with_indices( 4422 input, 4423 kernel_size, 4424 stride=(), 4425 padding=(0,), 4426 dilation=(1,), 4427 ceil_mode=False, 4428): 4429 torch._check( 4430 len(kernel_size) in (1, 3), 4431 lambda: "max_pool3d: kernel_size must either be a single int, or a tuple of three ints", 4432 ) 4433 kT = kernel_size[0] 4434 kH = kT if len(kernel_size) == 1 else kernel_size[1] 4435 kW = kT if len(kernel_size) == 1 else kernel_size[2] 4436 4437 torch._check( 4438 not stride or len(stride) in (1, 3), 4439 lambda: "max_pool3d: stride must either be omitted, a single int, or a tuple of three ints", 4440 ) 4441 dT = kT if not stride else stride[0] 4442 dH = kH if not stride else (dT if len(stride) == 1 else stride[1]) 4443 dW = kW if not stride else (dT if len(stride) == 1 else stride[2]) 4444 4445 torch._check( 
4446 len(padding) in (1, 3), 4447 lambda: "max_pool3d: padding must either be a single int, or a tuple of three ints", 4448 ) 4449 pT = padding[0] 4450 pH = pT if len(padding) == 1 else padding[1] 4451 pW = pT if len(padding) == 1 else padding[2] 4452 4453 torch._check( 4454 len(dilation) in (1, 3), 4455 lambda: "max_pool3d: dilation must be either a single int, or a tuple of three ints", 4456 ) 4457 dilationT = dilation[0] 4458 dilationH = dilationT if len(dilation) == 1 else dilation[1] 4459 dilationW = dilationT if len(dilation) == 1 else dilation[2] 4460 4461 torch._check( 4462 input.ndim in (4, 5), 4463 lambda: "non-empty 4D or 5D (batch mode) tensor expected for input", 4464 ) 4465 4466 nbatch = input.size(-5) if input.ndim == 5 else 1 4467 nslices = input.size(-4) 4468 itime = input.size(-3) 4469 iheight = input.size(-2) 4470 iwidth = input.size(-1) 4471 4472 otime = pooling_output_shape(itime, kT, pT, dT, dilationT, ceil_mode) 4473 oheight = pooling_output_shape(iheight, kH, pH, dH, dilationH, ceil_mode) 4474 owidth = pooling_output_shape(iwidth, kW, pW, dW, dilationW, ceil_mode) 4475 4476 pool3d_shape_check( 4477 input, 4478 nslices, 4479 kT, 4480 kH, 4481 kW, 4482 dT, 4483 dH, 4484 dW, 4485 pT, 4486 pH, 4487 pW, 4488 dilationT, 4489 dilationH, 4490 dilationW, 4491 itime, 4492 iheight, 4493 iwidth, 4494 otime, 4495 oheight, 4496 owidth, 4497 "max_pool3d_with_indices()", 4498 ) 4499 4500 channels_last = ( 4501 input.ndim == 5 and utils.suggest_memory_format(input) == torch.channels_last_3d 4502 ) 4503 if input.ndim == 4: 4504 input_channels_last_check = input.unsqueeze(0) 4505 channels_last = ( 4506 not input_channels_last_check.is_contiguous() 4507 ) and input_channels_last_check.is_contiguous( 4508 memory_format=torch.channels_last_3d 4509 ) 4510 out_shape = (nslices, otime, oheight, owidth) 4511 else: 4512 out_shape = (nbatch, nslices, otime, oheight, owidth) # type: ignore[assignment] 4513 4514 out = input.new_empty(out_shape) 4515 indices = input.new_empty(out_shape, dtype=torch.int64) 4516 4517 if channels_last: 4518 out = out.to(memory_format=torch.channels_last_3d) 4519 indices = indices.to(memory_format=torch.channels_last_3d) 4520 4521 return out, indices 4522 4523 4524@register_meta(aten.max_pool3d_with_indices_backward) 4525@out_wrapper("grad_input") 4526def meta_max_pool3d_with_indices_backward( 4527 grad_output, 4528 input, 4529 kernel_size, 4530 stride, 4531 padding, 4532 dilation, 4533 ceil_mode, 4534 indices, 4535): 4536 torch._check( 4537 len(kernel_size) in (1, 3), 4538 lambda: "max_pool3d: kernel_size must either be a single int, or a tuple of three ints", 4539 ) 4540 kT = kernel_size[0] 4541 kH = kT if len(kernel_size) == 1 else kernel_size[1] 4542 kW = kT if len(kernel_size) == 1 else kernel_size[2] 4543 4544 torch._check( 4545 not stride or len(stride) in (1, 3), 4546 lambda: "max_pool3d: stride must either be omitted, a single int, or a tuple of three ints", 4547 ) 4548 dT = kT if not stride else stride[0] 4549 dH = kH if not stride else (dT if len(stride) == 1 else stride[1]) 4550 dW = kW if not stride else (dT if len(stride) == 1 else stride[2]) 4551 4552 torch._check( 4553 len(padding) in (1, 3), 4554 lambda: "max_pool3d: padding must either be a single int, or a tuple of three ints", 4555 ) 4556 pT = padding[0] 4557 pH = pT if len(padding) == 1 else padding[1] 4558 pW = pT if len(padding) == 1 else padding[2] 4559 4560 torch._check( 4561 len(dilation) in (1, 3), 4562 lambda: "max_pool3d: dilation must be either a single int, or a tuple of three ints", 
4563 ) 4564 dilationT = dilation[0] 4565 dilationH = dilationT if len(dilation) == 1 else dilation[1] 4566 dilationW = dilationT if len(dilation) == 1 else dilation[2] 4567 4568 torch._check( 4569 input.ndim in (4, 5), 4570 lambda: "non-empty 4D or 5D (batch mode) tensor expected for input", 4571 ) 4572 4573 nslices = input.size(-4) 4574 itime = input.size(-3) 4575 iheight = input.size(-2) 4576 iwidth = input.size(-1) 4577 4578 otime = grad_output.size(-3) 4579 oheight = grad_output.size(-2) 4580 owidth = grad_output.size(-1) 4581 4582 max_pool3d_backward_shape_check( 4583 input, 4584 grad_output, 4585 indices, 4586 nslices, 4587 kT, 4588 kH, 4589 kW, 4590 dT, 4591 dH, 4592 dW, 4593 pT, 4594 pH, 4595 pW, 4596 dilationT, 4597 dilationH, 4598 dilationW, 4599 itime, 4600 iheight, 4601 iwidth, 4602 otime, 4603 oheight, 4604 owidth, 4605 "max_pool3d_with_indices_backward()", 4606 ) 4607 4608 channels_last = ( 4609 input.ndim == 5 and utils.suggest_memory_format(input) == torch.channels_last_3d 4610 ) 4611 if input.ndim == 4: 4612 input_channels_last_check = input.unsqueeze(0) 4613 channels_last = ( 4614 not input_channels_last_check.is_contiguous() 4615 ) and input_channels_last_check.is_contiguous( 4616 memory_format=torch.channels_last_3d 4617 ) 4618 4619 grad_input = input.new_empty(input.shape) 4620 4621 if channels_last: 4622 grad_input = grad_input.to(memory_format=torch.channels_last_3d) 4623 4624 return grad_input 4625 4626 4627def check_grid_sampler_common(input: Tensor, grid: Tensor): 4628 torch._check( 4629 input.device == grid.device, 4630 lambda: ( 4631 f"grid_sampler(): expected input and grid to be on same device, but input " 4632 f"is on {input.device} and grid is on {grid.device}" 4633 ), 4634 ) 4635 torch._check( 4636 input.layout == torch.strided and grid.layout == torch.strided, 4637 lambda: ( 4638 f"grid_sampler(): expected input and grid to have torch.strided layout, but " 4639 f"input has {input.layout} and grid has {grid.layout}" 4640 ), 4641 ) 4642 torch._check( 4643 input.shape[0] == grid.shape[0], 4644 lambda: ( 4645 f"grid_sampler(): expected grid and input to have same batch size, but got " 4646 f"input with sizes {input.shape} and grid with sizes {grid.shape}" 4647 ), 4648 ) 4649 torch._check( 4650 grid.shape[-1] == input.ndim - 2, 4651 lambda: ( 4652 f"grid_sampler(): expected grid to have size {input.ndim - 2} in last " 4653 f"dimension, but got grid with sizes {grid.shape}" 4654 ), 4655 ) 4656 4657 for i in range(2, input.ndim): 4658 torch._check( 4659 input.shape[i] > 0, 4660 lambda: ( 4661 f"grid_sampler(): expected input to have non-empty spatial dimensions, " 4662 f"but input has sizes {input.shape} with dimension {i} being empty" 4663 ), 4664 ) 4665 4666 4667class GridSamplerInterpolation(Enum): 4668 BILINEAR = 0 4669 NEAREST = 1 4670 BICUBIC = 2 4671 4672 4673def check_grid_sampler_3d(input: Tensor, grid: Tensor, interpolation_mode: int): 4674 torch._check( 4675 input.ndim == 5 and input.ndim == grid.ndim, 4676 lambda: ( 4677 f"grid_sampler(): expected 5D input and grid with same number of " 4678 f"dimensions, but got input with sizes {input.shape}" 4679 f" and grid with sizes {grid.shape}" 4680 ), 4681 ) 4682 torch._check( 4683 not ( 4684 input.ndim == 5 4685 and interpolation_mode == GridSamplerInterpolation.BICUBIC.value 4686 ), 4687 lambda: "grid_sampler(): bicubic interpolation only supports 4D input", 4688 ) 4689 4690 4691@register_meta(aten.grid_sampler_2d_backward.default) 4692def grid_sampler_2d_backward_meta( 4693 grad_output, 4694 input, 4695 
grid, 4696 interpolation_mode, 4697 padding_mode, 4698 align_corners, 4699 output_mask, 4700): 4701 input_requires_grad = output_mask[0] 4702 if input_requires_grad: 4703 grad_input = torch.zeros_like(input, memory_format=torch.contiguous_format) 4704 else: 4705 grad_input = None 4706 grad_grid = torch.empty_like(grid, memory_format=torch.contiguous_format) 4707 return (grad_input, grad_grid) 4708 4709 4710@register_meta(aten.grid_sampler_3d) 4711@out_wrapper() 4712def grid_sampler_3d( 4713 input, 4714 grid, 4715 interpolation_mode, 4716 padding_mode, 4717 align_corners, 4718): 4719 check_grid_sampler_common(input, grid) 4720 check_grid_sampler_3d(input, grid, interpolation_mode) 4721 N = input.shape[0] 4722 C = input.shape[1] 4723 out_D = grid.shape[1] 4724 out_H = grid.shape[2] 4725 out_W = grid.shape[3] 4726 return input.new_empty((N, C, out_D, out_H, out_W)) 4727 4728 4729@register_meta(aten.grid_sampler_3d_backward) 4730@out_wrapper("grad_input", "grad_grid") 4731def grid_sampler_3d_backward( 4732 grad_output, 4733 input, 4734 grid, 4735 interpolation_mode, 4736 padding_mode, 4737 align_corners, 4738 output_mask, 4739): 4740 check_grid_sampler_common(input, grid) 4741 check_grid_sampler_3d(input, grid, interpolation_mode) 4742 input_requires_grad = output_mask[0] 4743 if input_requires_grad: 4744 grad_input = torch.zeros_like( 4745 input, memory_format=torch.legacy_contiguous_format 4746 ) 4747 else: 4748 grad_input = None 4749 grad_grid = torch.empty_like(grid, memory_format=torch.legacy_contiguous_format) 4750 return grad_input, grad_grid 4751 4752 4753@register_meta([aten.full.default]) 4754def full(size, fill_value, *args, **kwargs): 4755 dtype = kwargs.get("dtype", None) 4756 if not dtype: 4757 dtype = utils.get_dtype(fill_value) 4758 kwargs["dtype"] = dtype 4759 return torch.empty(size, *args, **kwargs) 4760 4761 4762# zeros_like is special cased to work for sparse 4763@register_meta(aten.zeros_like.default) 4764def zeros_like( 4765 self, 4766 dtype=None, 4767 layout=None, 4768 device=None, 4769 pin_memory=None, 4770 memory_format=None, 4771): 4772 if layout == torch.sparse_coo: 4773 torch._check( 4774 memory_format is None, 4775 lambda: "memory format option is only supported by strided tensors", 4776 ) 4777 4778 res = torch.empty( 4779 0, 4780 dtype=self.dtype if dtype is None else dtype, 4781 layout=layout, 4782 device=self.device if device is None else device, 4783 pin_memory=pin_memory, 4784 ) 4785 4786 if self.is_sparse: 4787 res.sparse_resize_and_clear_( 4788 self.size(), self.sparse_dim(), self.dense_dim() 4789 ) 4790 else: 4791 res.sparse_resize_and_clear_(self.size(), self.dim(), 0) 4792 4793 res._coalesced_(True) 4794 return res 4795 res = aten.empty_like.default( 4796 self, 4797 dtype=dtype, 4798 layout=layout, 4799 device=device, 4800 pin_memory=pin_memory, 4801 memory_format=memory_format, 4802 ) 4803 # device can be not "meta" 4804 res.fill_(0) 4805 return res 4806 4807 4808@register_meta(aten.select.int) 4809def meta_select(self, dim, index): 4810 ndim = self.dim() 4811 torch._check_index( 4812 ndim != 0, 4813 lambda: "select() cannot be applied to a 0-dim tensor.", 4814 ) 4815 4816 dim = dim if dim >= 0 else dim + ndim 4817 size = self.size(dim) 4818 4819 torch._check_index( 4820 not (-index > size or index >= size), 4821 lambda: f"select(): index {index} out of range for tensor of size " 4822 f"{self.size()} at dimension {dim}", 4823 ) 4824 4825 index = index if index >= 0 else index + size 4826 4827 new_size = list(self.size()) 4828 new_stride = 
list(self.stride()) 4829 4830 new_storage_offset = self.storage_offset() + index * new_stride[dim] 4831 del new_size[dim] 4832 del new_stride[dim] 4833 4834 return self.as_strided(new_size, new_stride, new_storage_offset) 4835 4836 4837@register_meta(aten.select_scatter.default) 4838def meta_select_scatter(self, src, dim, index): 4839 return utils.clone_preserve_strides(self) 4840 4841 4842@register_meta(aten.slice_scatter.default) 4843def meta_slice_scatter(self, src, dim=0, start=None, end=None, step=1): 4844 return utils.clone_preserve_strides(self) 4845 4846 4847# TODO: Deduplicate this with canonicalize_dim 4848def maybe_wrap_dim(dim: int, dim_post_expr: int, wrap_scalar: bool = True): 4849 if dim_post_expr <= 0: 4850 assert wrap_scalar 4851 dim_post_expr = 1 4852 min = -dim_post_expr 4853 max = dim_post_expr - 1 4854 assert not (dim < min or dim > max), f"dim {dim} out of bounds ({min}, {max})" 4855 if dim < 0: 4856 dim += dim_post_expr 4857 return dim 4858 4859 4860def ensure_nonempty_size(t, dim): 4861 return 1 if t.dim() == 0 else t.shape[dim] 4862 4863 4864# From aten/src/ATen/native/ScatterGatherChecks.h 4865def gather_shape_check(self, dim, index): 4866 self_dims = max(self.dim(), 1) 4867 index_dims = max(index.dim(), 1) 4868 torch._check( 4869 self_dims == index_dims, 4870 lambda: "Index tensor must have the same number of dimensions as input tensor", 4871 ) 4872 for i in range(self_dims): 4873 if i != dim: 4874 torch._check( 4875 ensure_nonempty_size(index, i) <= ensure_nonempty_size(self, i), 4876 lambda: f"Size does not match at dimension {i} expected index {index.shape}" 4877 + f" to be smaller than self {self.shape} apart from dimension {dim}", 4878 ) 4879 4880 4881@register_meta(aten.gather.default) 4882def meta_gather(self, dim, index, sparse_grad=False): 4883 from torch.fx.experimental.symbolic_shapes import guard_size_oblivious 4884 4885 wrapped_dim = maybe_wrap_dim(dim, self.dim()) 4886 is_index_empty = guard_size_oblivious(index.numel() == 0) 4887 if not is_index_empty: 4888 torch._check( 4889 index.dtype == torch.long, 4890 lambda: f"gather(): Expected dtype int64 for index, but got {index.dtype}", 4891 ) 4892 gather_shape_check(self, wrapped_dim, index) 4893 return self.new_empty(index.shape) 4894 4895 4896# From aten/src/ATen/native/TensorAdvancedIndexing.cpp 4897def get_operator_enum(reduce_, use_new_options=False): 4898 if use_new_options: 4899 if reduce_ == "sum": 4900 return "REDUCE_ADD" 4901 elif reduce_ == "prod": 4902 return "REDUCE_MULTIPLY" 4903 elif reduce_ == "mean": 4904 return "REDUCE_MEAN" 4905 elif reduce_ == "amax": 4906 return "REDUCE_MAXIMUM" 4907 elif reduce_ == "amin": 4908 return "REDUCE_MINIMUM" 4909 torch._check( 4910 False, 4911 lambda: "reduce argument must be either sum, prod, mean, amax or amin.", 4912 ) 4913 return 4914 else: 4915 if reduce_ == "add": 4916 return "REDUCE_ADD" 4917 elif reduce_ == "multiply": 4918 return "REDUCE_MULTIPLY" 4919 torch._check(False, lambda: "reduce argument must be either add or multiply.") 4920 return 4921 4922 4923# From aten/src/ATen/native/ScatterGatherChecks.h 4924def scatter_gather_dtype_check(method_name, self, index, src_opt=None): 4925 from torch.fx.experimental.symbolic_shapes import guard_size_oblivious 4926 4927 if guard_size_oblivious(index.numel() != 0): 4928 torch._check( 4929 index.dtype == torch.long, 4930 lambda: f"{method_name}(): Expected dtype int64 for index", 4931 ) 4932 4933 if src_opt is not None: 4934 torch._check( 4935 self.dtype == src_opt.dtype, 4936 lambda: f"{method_name}(): 
Expected self.dtype to be equal to src.dtype", 4937 ) 4938 4939 4940def ensure_nonempty_dim(dim): 4941 return max(dim, 1) 4942 4943 4944# From aten/src/ATen/native/ScatterGatherChecks.h 4945def scatter_shape_check(self, dim, index, src_opt=None): 4946 from torch.fx.experimental.symbolic_shapes import guard_size_oblivious 4947 4948 if guard_size_oblivious(index.numel() == 0): 4949 return 4950 torch._check( 4951 ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()), 4952 lambda: "Index tensor must have the same number of dimensions as self tensor", 4953 ) 4954 4955 is_wrong_shape = False 4956 self_dims = ensure_nonempty_dim(self.dim()) 4957 4958 # Check: index.size(d) <= self.size(d) for all d != dim 4959 for d in range(self_dims): 4960 index_d_size = ensure_nonempty_size(index, d) 4961 if d == dim: 4962 continue 4963 if index_d_size > ensure_nonempty_size(self, d): 4964 is_wrong_shape = True 4965 break 4966 4967 # Check: index.size(d) <= src.size(d) for all d if src is Tensor 4968 if not is_wrong_shape and src_opt is not None: 4969 for d in range(self_dims): 4970 index_d_size = ensure_nonempty_size(index, d) 4971 if index_d_size > ensure_nonempty_size(src_opt, d): 4972 is_wrong_shape = True 4973 break 4974 4975 if src_opt is not None: 4976 torch._check( 4977 ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()), 4978 lambda: "Index tensor must have the same number of dimensions as self tensor", 4979 ) 4980 torch._check( 4981 not is_wrong_shape, 4982 lambda: f"Expected index {index.shape} to be smaller than self {self.shape}" 4983 + f" apart from dimension {dim} and to be smaller than src {src_opt.shape}", 4984 ) 4985 else: 4986 torch._check( 4987 not is_wrong_shape, 4988 lambda: f"Expected index {index.shape} to be smaller than self {self.shape}" 4989 + f" apart from dimension {dim}", 4990 ) 4991 4992 4993# From aten/src/ATen/native/TensorAdvancedIndexing.cpp 4994def scatter_meta_impl(self, dim, index, src=None, reduce_=None, use_new_options=False): 4995 wrapped_dim = maybe_wrap_dim(dim, self.dim()) 4996 scatter_gather_dtype_check("scatter", self, index, src) 4997 scatter_shape_check(self, wrapped_dim, index, src) 4998 if reduce_ is not None: 4999 # Check if we have a valid reduce operator. 
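        # With use_new_options=True, "sum"/"prod"/"mean"/"amax"/"amin" map to
        # REDUCE_ADD/REDUCE_MULTIPLY/REDUCE_MEAN/REDUCE_MAXIMUM/REDUCE_MINIMUM;
        # with use_new_options=False only "add" and "multiply" are accepted. Any other
        # string fails the torch._check inside get_operator_enum.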
5000 get_operator_enum(reduce_, use_new_options) 5001 5002 5003@register_meta(aten.scatter_add.default) 5004def meta_scatter_add(self, dim, index, src): 5005 scatter_meta_impl(self, dim, index, src, "add") 5006 return self.new_empty(self.shape) 5007 5008 5009@register_meta(aten.scatter_add_) 5010def meta_scatter_add_(self, dim, index, src): 5011 scatter_meta_impl(self, dim, index, src, "add") 5012 return self 5013 5014 5015@register_meta( 5016 [ 5017 aten.scatter.src, 5018 aten.scatter.value, 5019 aten.scatter.reduce, 5020 aten.scatter.value_reduce, 5021 ] 5022) 5023@out_wrapper() 5024def meta_scatter(self, dim, index, src_or_value, reduce=None): 5025 src = src_or_value if isinstance(src_or_value, torch.Tensor) else None 5026 scatter_meta_impl(self, dim, index, src, reduce) 5027 return self.new_empty(self.shape) 5028 5029 5030@register_meta( 5031 [ 5032 aten.scatter_.src, 5033 aten.scatter_.value, 5034 aten.scatter_.reduce, 5035 aten.scatter_.value_reduce, 5036 ] 5037) 5038def meta_scatter_(self, dim, index, src_or_value, reduce=None): 5039 src = src_or_value if isinstance(src_or_value, torch.Tensor) else None 5040 scatter_meta_impl(self, dim, index, src, reduce) 5041 return self 5042 5043 5044@register_meta([aten._scaled_dot_product_flash_attention]) 5045def meta__scaled_dot_product_flash_attention( 5046 query: Tensor, 5047 key: Tensor, 5048 value: Tensor, 5049 dropout_p: float = 0.0, 5050 is_causal: bool = False, 5051 return_debug_mask: bool = False, 5052 scale: Optional[float] = None, 5053): 5054 batch_size = query.size(0) 5055 num_heads = query.size(1) 5056 max_seqlen_batch_q = query.size(2) 5057 head_dim = query.size(3) 5058 max_seqlen_batch_k = key.size(2) 5059 5060 query_t = query.transpose(1, 2) 5061 attention = torch.empty_like(query_t).transpose(1, 2) 5062 logsumexp = torch.empty( 5063 (batch_size, num_heads, max_seqlen_batch_q), 5064 dtype=torch.float, 5065 device=query.device, 5066 ) 5067 5068 if return_debug_mask: 5069 blocksize_c = 128 if head_dim > 64 else 256 5070 max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c) 5071 if max_seqlen_batch_k <= 128: 5072 max_seqlen_k = 128 5073 elif max_seqlen_batch_k <= 256: 5074 max_seqlen_k = 256 5075 debug_mask = torch.empty( 5076 (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k), 5077 dtype=query.dtype, 5078 device=query.device, 5079 ) 5080 else: 5081 debug_mask = torch.empty(0, dtype=query.dtype, device=query.device) 5082 5083 # Note [Seed and Offset]: device for seed and offset below depends on whether we are 5084 # capturing or not, but at the time of tracing we don't know if we 5085 # are going to use cudagraphs or not, so we return meta tensors here 5086 # it's possible we'll need to have some special handling in inductor for sdpa 5087 5088 return ( 5089 attention, 5090 logsumexp, 5091 None, 5092 None, 5093 max_seqlen_batch_q, 5094 max_seqlen_batch_k, 5095 torch.empty((), dtype=torch.long, device="meta"), 5096 torch.empty((), dtype=torch.long, device="meta"), 5097 debug_mask, 5098 ) 5099 5100 5101@register_meta([aten._scaled_dot_product_cudnn_attention]) 5102def meta__scaled_dot_product_cudnn_attention( 5103 query: Tensor, 5104 key: Tensor, 5105 value: Tensor, 5106 attn_bias: Optional[Tensor], 5107 compute_log_sumexp: bool, 5108 dropout_p: float = 0.0, 5109 is_causal: bool = False, 5110 return_debug_mask: bool = False, 5111 scale: Optional[float] = None, 5112): 5113 B = query.size(0) 5114 H = query.size(1) 5115 S_Q = query.size(2) 5116 S_KV = key.size(2) 5117 D_QK = query.size(-1) 5118 D_V = value.size(-1) 5119 5120 
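    # The attention output below keeps the query's (B, H, S_Q, D_V) layout, and
    # logsumexp is (B, H, S_Q) in float32, mirroring the flash-attention meta above.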
res = torch.empty((B, H, S_Q, D_V), dtype=query.dtype, device=query.device) 5121 logsum_exp = torch.empty( 5122 (B, H, S_Q), 5123 dtype=torch.float, 5124 device=query.device, 5125 ) 5126 5127 # See Note [Seed and Offset] 5128 seed = torch.empty((), dtype=torch.long, device="meta") 5129 offset = torch.empty((), dtype=torch.long, device="meta") 5130 5131 return ( 5132 res, 5133 logsum_exp, 5134 None, 5135 None, 5136 S_Q, 5137 S_KV, 5138 seed, 5139 offset, 5140 None, 5141 ) 5142 5143 5144@register_meta( 5145 [ 5146 aten._scaled_dot_product_flash_attention_backward, 5147 ] 5148) 5149def meta__scaled_dot_product_flash_backward( 5150 grad_out: Tensor, 5151 query: Tensor, 5152 key: Tensor, 5153 value: Tensor, 5154 out: Tensor, 5155 logsumexp: Tensor, 5156 cum_seq_q: Tensor, 5157 cum_seq_k: Tensor, 5158 max_q: int, 5159 max_k: int, 5160 dropout_p: float, 5161 is_causal: bool, 5162 philox_seed: Tensor, 5163 philox_offset: Tensor, 5164 scale: Optional[float] = None, 5165): 5166 grad_q = torch.empty_like(query.transpose(1, 2)).transpose(1, 2) 5167 grad_k = torch.empty_like(key.transpose(1, 2)).transpose(1, 2) 5168 grad_v = torch.empty_like(value.transpose(1, 2)).transpose(1, 2) 5169 return grad_q, grad_k, grad_v 5170 5171 5172@register_meta( 5173 [ 5174 aten._scaled_dot_product_flash_attention_for_cpu, 5175 ] 5176) 5177def meta__scaled_dot_product_flash_attention_for_cpu( 5178 query: Tensor, 5179 key: Tensor, 5180 value: Tensor, 5181 dropout_p: float = 0.0, 5182 is_causal: bool = False, 5183 attn_mask: Optional[Tensor] = None, 5184 scale: Optional[float] = None, 5185): 5186 batch_size = query.size(0) 5187 num_heads = query.size(1) 5188 max_seqlen_batch_q = query.size(2) 5189 head_dim = query.size(3) 5190 5191 attention = torch.empty_like(query) 5192 logsumexp = torch.empty( 5193 ( 5194 batch_size, 5195 max_seqlen_batch_q, 5196 num_heads, 5197 ), 5198 dtype=torch.float, 5199 device=query.device, 5200 ).transpose(1, 2) 5201 return ( 5202 attention, 5203 logsumexp, 5204 ) 5205 5206 5207@register_meta( 5208 [ 5209 aten._scaled_dot_product_flash_attention_for_cpu_backward, 5210 ] 5211) 5212def meta__scaled_dot_product_flash_attention_for_cpu_backward( 5213 grad_out: Tensor, 5214 query: Tensor, 5215 key: Tensor, 5216 value: Tensor, 5217 out: Tensor, 5218 logsumexp: Tensor, 5219 dropout_p: float, 5220 is_causal: bool, 5221 attn_mask: Optional[Tensor] = None, 5222 scale: Optional[float] = None, 5223): 5224 # cpus's grad layout is different from cuda's, 5225 # i.e. 
(batch_size, seq_len,num_heads, head_dim) 5226 batch_size = query.size(0) 5227 num_heads = query.size(1) 5228 head_dim = query.size(3) 5229 len_q = query.size(2) 5230 len_k = key.size(2) 5231 5232 grad_q = torch.empty_permuted( 5233 (batch_size, num_heads, len_q, head_dim), 5234 (0, 2, 1, 3), 5235 dtype=query.dtype, 5236 device=query.device, 5237 ) 5238 grad_k = torch.empty_permuted( 5239 (batch_size, num_heads, len_k, head_dim), 5240 (0, 2, 1, 3), 5241 dtype=key.dtype, 5242 device=key.device, 5243 ) 5244 grad_v = torch.empty_permuted( 5245 (batch_size, num_heads, len_k, head_dim), 5246 (0, 2, 1, 3), 5247 dtype=value.dtype, 5248 device=value.device, 5249 ) 5250 5251 return grad_q, grad_k, grad_v 5252 5253 5254@register_meta([aten._scaled_dot_product_efficient_attention]) 5255def meta__scaled_dot_product_efficient_attention( 5256 query: Tensor, 5257 key: Tensor, 5258 value: Tensor, 5259 attn_bias: Optional[Tensor], 5260 compute_log_sumexp: bool, 5261 dropout_p=0.0, 5262 is_causal: bool = False, 5263 scale: Optional[float] = None, 5264): 5265 query = query.transpose(1, 2) 5266 key = key.transpose(1, 2) 5267 value = value.transpose(1, 2) 5268 5269 B = query.size(0) 5270 M = query.size(1) 5271 N = key.size(1) 5272 num_heads = query.size(-2) 5273 K = query.size(-1) 5274 Kv = value.size(-1) 5275 5276 res = torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device=query.device) 5277 5278 logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0 5279 logsum_exp = torch.empty( 5280 (B, num_heads, logsumexp_dim), 5281 dtype=torch.float, 5282 device=query.device, 5283 ) 5284 5285 res = res.transpose(1, 2) 5286 5287 # See Note [Seed and Offset]: 5288 seed = torch.empty((), dtype=torch.long, device="meta") 5289 offset = torch.empty((), dtype=torch.long, device="meta") 5290 5291 return res, logsum_exp, seed, offset 5292 5293 5294@register_meta( 5295 [ 5296 aten._scaled_dot_product_efficient_attention_backward, 5297 ] 5298) 5299def meta__scaled_dot_product_efficient_backward( 5300 grad_out: Tensor, 5301 query: Tensor, 5302 key: Tensor, 5303 value: Tensor, 5304 attn_bias: Optional[Tensor], 5305 out: Tensor, 5306 logsumexp: Tensor, 5307 philox_seed: Tensor, 5308 philox_offset: Tensor, 5309 dropout_p: float, 5310 grad_input_mask: List[bool], 5311 is_causal: bool = False, 5312 scale: Optional[float] = None, 5313): 5314 batch_size = query.size(0) 5315 num_heads = query.size(1) 5316 max_q = query.size(2) 5317 head_dim = query.size(3) 5318 head_dim_v = value.size(3) 5319 5320 max_k = key.size(2) 5321 5322 grad_q = torch.empty_permuted( 5323 (batch_size, num_heads, max_q, head_dim), 5324 (0, 2, 1, 3), 5325 dtype=query.dtype, 5326 device=query.device, 5327 ) 5328 grad_k = torch.empty_permuted( 5329 (batch_size, num_heads, max_k, head_dim), 5330 (0, 2, 1, 3), 5331 dtype=key.dtype, 5332 device=key.device, 5333 ) 5334 grad_v = torch.empty_permuted( 5335 (batch_size, num_heads, max_k, head_dim_v), 5336 (0, 2, 1, 3), 5337 dtype=value.dtype, 5338 device=value.device, 5339 ) 5340 grad_bias = None 5341 if attn_bias is not None and grad_input_mask[3]: 5342 lastDim = attn_bias.size(-1) 5343 lastDimAligned = lastDim if lastDim % 16 == 0 else lastDim + 16 - lastDim % 16 5344 new_sizes = list(attn_bias.size()) 5345 new_sizes[-1] = lastDimAligned 5346 grad_bias = torch.empty( 5347 new_sizes, dtype=attn_bias.dtype, device=attn_bias.device 5348 ) 5349 grad_bias = grad_bias[..., :lastDim] 5350 5351 return grad_q, grad_k, grad_v, grad_bias 5352 5353 5354@register_meta( 5355 [ 5356 
aten._scaled_dot_product_cudnn_attention_backward, 5357 ] 5358) 5359def meta__scaled_dot_product_cudnn_backward( 5360 grad_out: Tensor, 5361 query: Tensor, 5362 key: Tensor, 5363 value: Tensor, 5364 out: Tensor, 5365 logsumexp: Tensor, 5366 philox_seed: Tensor, 5367 philox_offset: Tensor, 5368 attn_bias: Tensor, 5369 cum_seq_q: Tensor, 5370 cum_seq_k: Tensor, 5371 max_q: int, 5372 max_k: int, 5373 dropout_p: float, 5374 is_causal: bool, 5375 scale: Optional[float] = None, 5376): 5377 grad_q = torch.empty_like(query) 5378 grad_k = torch.empty_like(key) 5379 grad_v = torch.empty_like(value) 5380 return grad_q, grad_k, grad_v 5381 5382 5383@register_meta( 5384 [ 5385 aten._flash_attention_forward, 5386 ] 5387) 5388def meta__flash_attention_forward( 5389 query: Tensor, 5390 key: Tensor, 5391 value: Tensor, 5392 cum_seq_q: Optional[Tensor], 5393 cum_seq_k: Optional[Tensor], 5394 max_q: int, 5395 max_k: int, 5396 dropout_p: float, 5397 is_causal: bool, 5398 return_debug_mask: bool, 5399 scale: Optional[float] = None, 5400 window_size_left: Optional[int] = None, 5401 window_size_right: Optional[int] = None, 5402 seqused_k: Optional[Tensor] = None, 5403 alibi_slopes: Optional[Tensor] = None, 5404): 5405 # NB: there are two underlying paths: 5406 # 1. normal dense path; expect 4D inputs of shape (batch_size, seqlen, num_heads, head_dim) 5407 # 2. varseqlen path; expect 3D inputs of shape (total, num_heads, head_dim) where total 5408 # includes all batch item sequences. cum_seq_q / cum_seq_k contain offsets into total 5409 batch_size = query.size(0) if cum_seq_q is None else cum_seq_q.numel() - 1 5410 max_seqlen_batch_q = query.size(1) if cum_seq_q is None else max_q 5411 max_seqlen_batch_k = key.size(1) if cum_seq_k is None else max_k 5412 num_heads = query.size(-2) 5413 head_dim = query.size(-1) 5414 5415 # Cuda Path 5416 attention = torch.empty_like(query) 5417 logsumexp = torch.empty( 5418 (batch_size, num_heads, max_seqlen_batch_q), 5419 dtype=torch.float, 5420 device=query.device, 5421 ) 5422 5423 if return_debug_mask: 5424 blocksize_c = 128 if head_dim > 64 else 256 5425 max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c) 5426 if max_seqlen_batch_k <= 128: 5427 max_seqlen_k = 128 5428 elif max_seqlen_batch_k <= 256: 5429 max_seqlen_k = 256 5430 debug_mask = torch.empty( 5431 (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k), 5432 dtype=query.dtype, 5433 device=query.device, 5434 ) 5435 else: 5436 debug_mask = torch.empty(0, dtype=query.dtype, device=query.device) 5437 5438 # See Note [Seed and Offset]: 5439 return ( 5440 attention, 5441 logsumexp, 5442 torch.empty((), dtype=torch.long, device="meta"), 5443 torch.empty((), dtype=torch.long, device="meta"), 5444 debug_mask, 5445 ) 5446 5447 5448@register_meta( 5449 [ 5450 aten._flash_attention_backward, 5451 ] 5452) 5453def meta__flash_attention_backward( 5454 grad_out: Tensor, 5455 query: Tensor, 5456 key: Tensor, 5457 value: Tensor, 5458 out: Tensor, 5459 logsumexp: Tensor, 5460 cum_seq_q: Tensor, 5461 cum_seq_k: Tensor, 5462 max_q: int, 5463 max_k: int, 5464 dropout_p: float, 5465 is_causal: bool, 5466 philox_seed: Tensor, 5467 philox_offset: Tensor, 5468 scale: Optional[float] = None, 5469 window_size_left: Optional[int] = None, 5470 window_size_right: Optional[int] = None, 5471): 5472 grad_query = torch.empty_like(query) 5473 grad_key = torch.empty_like(key) 5474 grad_value = torch.empty_like(value) 5475 5476 return grad_query, grad_key, grad_value 5477 5478 5479@register_meta( 5480 [ 5481 aten._efficient_attention_forward, 
5482 ] 5483) 5484def meta__efficient_attention_forward( 5485 query: Tensor, 5486 key: Tensor, 5487 value: Tensor, 5488 bias: Optional[Tensor], 5489 cu_seqlens_q: Optional[Tensor], 5490 cu_seqlens_k: Optional[Tensor], 5491 max_seqlen_q: Optional[int], 5492 max_seqlen_k: Optional[int], 5493 dropout_p: float, 5494 custom_mask_type: int, 5495 compute_log_sumexp: bool = False, 5496 scale: Optional[float] = None, 5497 causal_diagonal: Optional[Tensor] = None, 5498 seqlen_k: Optional[Tensor] = None, 5499 window_size: Optional[int] = None, 5500): 5501 B = query.size(0) 5502 M = query.size(1) 5503 N = key.size(1) 5504 num_heads = query.size(-2) 5505 K = query.size(-1) 5506 Kv = value.size(-1) 5507 5508 res = torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device=query.device) 5509 5510 logsumexp_batch_dim = cu_seqlens_q.size(0) - 1 if (cu_seqlens_q is not None) else B 5511 actual_max_seqlen_q = M 5512 if cu_seqlens_q is not None: 5513 assert max_seqlen_q is not None 5514 actual_max_seqlen_q = max_seqlen_q 5515 actual_max_seqlen_k = max_seqlen_k if max_seqlen_k is not None else N 5516 logsumexp_dim = ( 5517 math.ceil(actual_max_seqlen_q / 32) * 32 if compute_log_sumexp else 0 5518 ) 5519 logsum_exp = torch.empty( 5520 (logsumexp_batch_dim, num_heads, logsumexp_dim), 5521 dtype=torch.float, 5522 device=query.device, 5523 ) 5524 5525 # See Note [Seed and Offset]: 5526 seed = torch.empty((), dtype=torch.long, device="meta") 5527 offset = torch.empty((), dtype=torch.long, device="meta") 5528 5529 return res, logsum_exp, seed, offset, actual_max_seqlen_q, actual_max_seqlen_k 5530 5531 5532@register_meta( 5533 [ 5534 aten._efficient_attention_backward, 5535 ] 5536) 5537def meta__efficient_attention_backward( 5538 grad_out: Tensor, 5539 query: Tensor, 5540 key: Tensor, 5541 value: Tensor, 5542 bias: Optional[Tensor], 5543 cu_seqlens_q: Optional[Tensor], 5544 cu_seqlens_k: Optional[Tensor], 5545 max_seqlen_q: torch.SymInt, 5546 max_seqlen_k: torch.SymInt, 5547 logsumexp: Tensor, 5548 dropout_p: float, 5549 philox_seed: Tensor, 5550 philox_offset: Tensor, 5551 custom_mask_type: int, 5552 bias_requires_grad: bool, 5553 scale: Optional[float] = None, 5554 num_splits_key: Optional[int] = None, 5555 shared_storage_dqdkdv: bool = False, 5556): 5557 if shared_storage_dqdkdv: 5558 torch._check( 5559 query.shape[1] == key.shape[1], 5560 lambda: "seqlen must match for `shared_storage_dqdkdv", 5561 ) 5562 torch._check( 5563 query.shape[3] == key.shape[3], 5564 lambda: "embedding dim must match for `shared_storage_dqdkdv", 5565 ) 5566 chunk = torch.empty( 5567 (*query.shape[0:-2], 3, query.shape[-2], query.shape[-1]), 5568 dtype=query.dtype, 5569 device=query.device, 5570 ) 5571 grad_query = chunk.select(-3, 0) 5572 grad_key = chunk.select(-3, 1) 5573 grad_value = chunk.select(-3, 2) 5574 else: 5575 grad_query = torch.empty_like(query) 5576 grad_key = torch.empty_like(key) 5577 grad_value = torch.empty_like(value) 5578 5579 if bias is not None: 5580 lastDim = bias.size(-1) 5581 lastDimAligned = lastDim if lastDim % 16 == 0 else lastDim + 16 - lastDim % 16 5582 new_sizes = list(bias.size()) 5583 new_sizes[-1] = lastDimAligned 5584 grad_bias = torch.empty(new_sizes, dtype=bias.dtype, device=bias.device) 5585 grad_bias = grad_bias[..., :lastDim] 5586 else: 5587 grad_bias = torch.empty((), device=query.device) 5588 5589 return grad_query, grad_key, grad_value, grad_bias 5590 5591 5592@register_meta([aten._scaled_mm.default]) 5593def meta_scaled_mm( 5594 self: torch.Tensor, 5595 mat2: torch.Tensor, 5596 scale_a: 
torch.Tensor, 5597 scale_b: torch.Tensor, 5598 bias: Optional[torch.Tensor] = None, 5599 scale_result: Optional[torch.Tensor] = None, 5600 out_dtype: Optional[torch.dtype] = None, 5601 use_fast_accum: bool = False, 5602): 5603 def is_row_major(stride): 5604 return stride[0] > stride[1] and stride[1] == 1 5605 5606 def is_col_major(stride): 5607 return stride[0] == 1 and stride[1] > 1 5608 5609 def is_fp8_type(dtype): 5610 return dtype in ( 5611 torch.float8_e4m3fn, 5612 torch.float8_e5m2, 5613 torch.float8_e4m3fnuz, 5614 torch.float8_e5m2fnuz, 5615 ) 5616 5617 torch._check( 5618 self.dim() == 2 and mat2.dim() == 2, 5619 lambda: f"Inputs must be 2D but got self.dim()={self.dim()} and mat2.dim()={mat2.dim()}", 5620 ) 5621 torch._check( 5622 is_row_major(self.stride()), 5623 lambda: "self must be row_major", 5624 ) 5625 torch._check( 5626 is_col_major(mat2.stride()), 5627 lambda: "mat2 must be col_major", 5628 ) 5629 torch._check( 5630 self.size(1) % 16 == 0, 5631 lambda: f"Expected self.size(1) to be divisible by 16, but got self.size(1)={self.size(1)}", 5632 ) 5633 torch._check( 5634 mat2.size(0) % 16 == 0 and mat2.size(1) % 16 == 0, 5635 lambda: f"Expected both dimensions of mat2 to be divisble by 16 but got {mat2.shape}", 5636 ) 5637 torch._check( 5638 is_fp8_type(self.dtype) and is_fp8_type(mat2.dtype), 5639 lambda: f"Expected both inputs to be fp8 types but got self.dtype={self.dtype} and mat2.dtype={mat2.dtype}", 5640 ) 5641 5642 # determine scaling type and check input dimensions (refer to Blas.cpp op) 5643 torch._check( 5644 scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32, 5645 lambda: "Both scale_a and scale_b must be float (fp32) tensors.", 5646 ) 5647 m, k = self.shape 5648 n = mat2.size(1) 5649 if scale_a.numel() == 1 and scale_b.numel() == 1: 5650 # tensorwise scaling 5651 pass 5652 else: 5653 # for non-tensorwise scaling, enforce 2D input tensors 5654 torch._check( 5655 scale_a.dim() == 2 and scale_b.dim() == 2, 5656 lambda: f"For non-tensorwise scaling, scale tensors must be 2D, but got {scale_a.dim()=} and {scale_b.dim()=}", 5657 ) 5658 5659 if ( 5660 scale_a.size(0) == m 5661 and scale_a.size(1) == 1 5662 and scale_b.size(0) == 1 5663 and scale_b.size(1) == n 5664 ): 5665 # rowwise scaling 5666 torch._check( 5667 scale_a.is_contiguous() and scale_b.is_contiguous(), 5668 lambda: "Both scale_a and scale_b must be contiguous for rowwise scaling.", 5669 ) 5670 else: 5671 # does not match any valid scaling type 5672 torch._check( 5673 False, 5674 lambda: ( 5675 "Invalid scaling configuration. " 5676 "For tensorwise scaling, both scales should be scalar. " 5677 f"For rowwise scaling, scale_a should be ({m}, 1), scale_b should be (1, {n}). 
" 5678 f"Got scale_a.size()=({scale_a.size(0)}, {scale_a.size(1)}) " 5679 f"and scale_b.size()=({scale_b.size(0)}, {scale_b.size(1)})" 5680 ), 5681 ) 5682 5683 _out_dtype = out_dtype if out_dtype is not None else self.dtype 5684 return torch.empty(self.size(0), mat2.size(1), dtype=_out_dtype, device=self.device) 5685 5686 5687@register_meta([aten.scatter_reduce.two, aten.scatter_reduce.two_out]) 5688@out_wrapper() 5689def meta_scatter_reduce_two(self, dim, index, src, reduce, include_self=True): 5690 scatter_meta_impl(self, dim, index, src, reduce, use_new_options=True) 5691 return self.new_empty(self.shape) 5692 5693 5694@register_meta(aten.scatter_reduce_.two) 5695def meta_scatter_reduce__two(self, dim, index, src, reduce, include_self=True): 5696 scatter_meta_impl(self, dim, index, src, reduce, use_new_options=True) 5697 return self 5698 5699 5700@register_meta([aten.multinomial.default, aten.multinomial.out]) 5701@out_wrapper() 5702def meta_multinomial(input, num_samples, replacement=False, *, generator=None): 5703 torch._check( 5704 0 < input.dim() <= 2, 5705 lambda: f"The probabilty distributions dimensions must be 1 or 2, but got {input.dim()}", 5706 ) 5707 if input.dim() == 1: 5708 return torch.empty(num_samples, dtype=torch.long, device=input.device) 5709 return torch.empty( 5710 input.size(0), num_samples, dtype=torch.long, device=input.device 5711 ) 5712 5713 5714def multiply_integers(vs): 5715 r = 1 5716 for v in vs: 5717 r *= v 5718 return r 5719 5720 5721def upsample_common_check(input_size, output_size, num_spatial_dims): 5722 torch._check( 5723 len(output_size) == num_spatial_dims, 5724 lambda: f"It is expected output_size equals to {num_spatial_dims}, but got size {len(output_size)}", 5725 ) 5726 expected_input_dims = num_spatial_dims + 2 # N, C, ... 
5727 torch._check( 5728 len(input_size) == expected_input_dims, 5729 lambda: f"It is expected input_size equals to {expected_input_dims}, but got size {len(input_size)}", 5730 ) 5731 5732 torch._check( 5733 all(s > 0 for s in input_size[2:]) and all(s > 0 for s in output_size), 5734 lambda: f"Input and output sizes should be greater than 0, but got " 5735 f"input size {input_size} and output size {output_size}", 5736 ) 5737 5738 nbatch, channels = input_size[:2] 5739 return (nbatch, channels, *output_size) 5740 5741 5742@register_meta( 5743 [aten.upsample_nearest1d.default, aten._upsample_nearest_exact1d.default] 5744) 5745def upsample_nearest1d(input, output_size, scales=None): 5746 torch._check( 5747 input.numel() != 0 or multiply_integers(input.size()[1:]), 5748 lambda: f"Non-empty 3D data tensor expected but got a tensor with sizes {input.size()}", 5749 ) 5750 full_output_size = upsample_common_check( 5751 input.size(), output_size, num_spatial_dims=1 5752 ) 5753 return input.new_empty(full_output_size).to( 5754 memory_format=utils.suggest_memory_format(input) 5755 ) 5756 5757 5758@register_meta( 5759 [aten.upsample_nearest2d.default, aten._upsample_nearest_exact2d.default] 5760) 5761def upsample_nearest2d(input, output_size, scales_h=None, scales_w=None): 5762 torch._check( 5763 input.numel() != 0 or multiply_integers(input.size()[1:]), 5764 lambda: f"Non-empty 4D data tensor expected but got a tensor with sizes {input.size()}", 5765 ) 5766 full_output_size = upsample_common_check( 5767 input.size(), output_size, num_spatial_dims=2 5768 ) 5769 output = input.new_empty(full_output_size) 5770 5771 # convert output to correct memory format, if necessary 5772 memory_format = utils.suggest_memory_format(input) 5773 5774 # following "heuristic: only use channels_last path when it's faster than the contiguous path" 5775 _, n_channels, _, _ = input.shape 5776 if input.device.type == "cuda" and n_channels < 4: 5777 memory_format = torch.contiguous_format 5778 5779 output = output.contiguous(memory_format=memory_format) 5780 5781 return output 5782 5783 5784@register_meta( 5785 [ 5786 aten.upsample_nearest2d_backward.default, 5787 aten._upsample_nearest_exact2d_backward.default, 5788 ] 5789) 5790def upsample_nearest2d_backward( 5791 grad_output: Tensor, 5792 output_size: Sequence[Union[int, torch.SymInt]], 5793 input_size: Sequence[Union[int, torch.SymInt]], 5794 scales_h: Optional[float] = None, 5795 scales_w: Optional[float] = None, 5796): 5797 full_output_size = upsample_common_check( 5798 input_size, output_size, num_spatial_dims=2 5799 ) 5800 torch._check( 5801 grad_output.ndim == 4, 5802 lambda: f"Expected grad_output to be a tensor of dimension 4 but got: dimension {grad_output.ndim}", 5803 ) 5804 for i in range(4): 5805 torch._check( 5806 grad_output.size(i) == full_output_size[i], 5807 lambda: ( 5808 f"Expected grad_output to have the same shape as output;" 5809 f" output.size({i}) = {full_output_size[i]}" 5810 f" but got grad_output.size({i}) = {grad_output.size(i)}" 5811 ), 5812 ) 5813 5814 return grad_output.new_empty(input_size).to( 5815 memory_format=utils.suggest_memory_format(grad_output) 5816 ) # type: ignore[call-overload] 5817 5818 5819@register_meta( 5820 [aten.upsample_nearest3d.default, aten._upsample_nearest_exact3d.default] 5821) 5822def upsample_nearest3d(input, output_size, scales_d=None, scales_h=None, scales_w=None): 5823 torch._check( 5824 input.numel() != 0 or multiply_integers(input.size()[1:]), 5825 lambda: f"Non-empty 5D data tensor expected but got a tensor 
with sizes {input.size()}", 5826 ) 5827 full_output_size = upsample_common_check( 5828 input.size(), output_size, num_spatial_dims=3 5829 ) 5830 return input.new_empty(full_output_size).to( 5831 memory_format=utils.suggest_memory_format(input) 5832 ) 5833 5834 5835@register_meta( 5836 [ 5837 aten.sort.default, 5838 aten.sort.stable, 5839 aten.sort.values, 5840 aten.sort.values_stable, 5841 ] 5842) 5843def meta_sort(self, stable=None, dim=-1, descending=False, values=None, indices=None): 5844 v, i = torch.empty_like(self), torch.empty_like(self, dtype=torch.int64) 5845 if values is not None and indices is not None: 5846 assert isinstance(values, TensorLike) 5847 assert isinstance(indices, TensorLike) 5848 # Makes sure values and indices have the same strides. For cases where 5849 # these have different shapes, like (5, 10, 5) and (0) in msort. 5850 out_shape = v.shape 5851 out_stride = v.stride() 5852 values = _maybe_resize_out(values, out_shape) 5853 indices = _maybe_resize_out(indices, out_shape) 5854 values.as_strided_(out_shape, out_stride) 5855 indices.as_strided_(out_shape, out_stride) 5856 _safe_copy_out(copy_from=v, copy_to=values) # type: ignore[arg-type] 5857 _safe_copy_out(copy_from=i, copy_to=indices) # type: ignore[arg-type] 5858 return values, indices 5859 return v, i 5860 5861 5862def rnn_cell_checkSizes( 5863 input_gates, 5864 hidden_gates, 5865 input_bias, 5866 hidden_bias, 5867 factor, 5868 prev_hidden, 5869): 5870 torch._check(input_gates.ndim == 2, lambda: f"{input_gates.ndim} != 2") 5871 torch._check( 5872 input_gates.shape == hidden_gates.shape, 5873 lambda: f"{input_gates.shape} != {hidden_gates.shape}", 5874 ) 5875 gates_size = input_gates.size(1) 5876 if input_bias is not None: 5877 torch._check(input_bias.ndim == 1, lambda: f"{input_bias.ndim} != 1") 5878 torch._check( 5879 input_bias.numel() == gates_size, 5880 lambda: f"{input_bias.numel()} != {gates_size}", 5881 ) 5882 torch._check( 5883 input_bias.shape == hidden_bias.shape, 5884 lambda: f"{input_bias.shape} != {hidden_bias.shape}", 5885 ) 5886 torch._check(prev_hidden.ndim == 2, lambda: f"{prev_hidden.ndim} != 2") 5887 expected_prev_hidden_numel = input_gates.size(0) * gates_size // factor 5888 torch._check( 5889 prev_hidden.numel() == expected_prev_hidden_numel, 5890 lambda: f"{prev_hidden.numel()} != {input_gates.size(0)} * {gates_size} // {factor} (aka {expected_prev_hidden_numel})", 5891 ) 5892 torch._check( 5893 all( 5894 x.device == input_gates.device 5895 for x in [hidden_gates, input_bias, hidden_bias, prev_hidden] 5896 ), 5897 lambda: "expected all inputs to be same device", 5898 ) 5899 5900 5901@register_meta(aten._thnn_fused_lstm_cell.default) 5902def _thnn_fused_lstm_cell_meta( 5903 input_gates, 5904 hidden_gates, 5905 cx, 5906 input_bias=None, 5907 hidden_bias=None, 5908): 5909 rnn_cell_checkSizes(input_gates, hidden_gates, input_bias, hidden_bias, 4, cx) 5910 workspace = torch.empty_like(input_gates, memory_format=torch.contiguous_format) 5911 hy = torch.empty_like(cx, memory_format=torch.contiguous_format) 5912 cy = torch.empty_like(cx, memory_format=torch.contiguous_format) 5913 return (hy, cy, workspace) 5914 5915 5916@register_meta(aten._cudnn_rnn.default) 5917def _cudnn_rnn( 5918 input, 5919 weight, 5920 weight_stride0, 5921 weight_buf, 5922 hx, 5923 cx, 5924 mode, 5925 hidden_size, 5926 proj_size, 5927 num_layers, 5928 batch_first, 5929 dropout, 5930 train, 5931 bidirectional, 5932 batch_sizes, 5933 dropout_state, 5934): 5935 is_input_packed = len(batch_sizes) != 0 5936 if is_input_packed: 
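        # Packed input (non-empty batch_sizes, e.g. from pack_padded_sequence):
        # the data is laid out as (sum(batch_sizes), input_size), so shape[0]
        # already gives the total number of time steps summed over the batch.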
5937 seq_length = len(batch_sizes) 5938 mini_batch = batch_sizes[0] 5939 batch_sizes_sum = input.shape[0] 5940 else: 5941 seq_length = input.shape[1] if batch_first else input.shape[0] 5942 mini_batch = input.shape[0] if batch_first else input.shape[1] 5943 batch_sizes_sum = -1 5944 5945 num_directions = 2 if bidirectional else 1 5946 out_size = proj_size if proj_size != 0 else hidden_size 5947 if is_input_packed: 5948 out_shape = [batch_sizes_sum, out_size * num_directions] 5949 else: 5950 out_shape = ( 5951 [mini_batch, seq_length, out_size * num_directions] 5952 if batch_first 5953 else [seq_length, mini_batch, out_size * num_directions] 5954 ) 5955 output = input.new_empty(out_shape) 5956 5957 cell_shape = [num_layers * num_directions, mini_batch, hidden_size] 5958 if cx is None: 5959 cy = torch.empty(0, device=input.device) 5960 else: 5961 cy = cx.new_empty(cell_shape) 5962 5963 hy = hx.new_empty([num_layers * num_directions, mini_batch, out_size]) 5964 5965 # TODO: Query cudnnGetRNNTrainingReserveSize (expose to python) 5966 reserve_shape = 0 if train else 0 5967 reserve = input.new_empty(reserve_shape, dtype=torch.uint8) 5968 5969 return output, hy, cy, reserve, weight_buf 5970 5971 5972@register_meta(aten.mkldnn_rnn_layer.default) 5973def mkldnn_rnn_layer( 5974 input, 5975 w0, 5976 w1, 5977 w2, 5978 w3, 5979 hx_, 5980 cx_, 5981 reverse, 5982 batch_sizes, 5983 mode, 5984 hidden_size, 5985 num_layers, 5986 has_biases, 5987 bidirectional, 5988 batch_first, 5989 train, 5990): 5991 seq_length = input.shape[1] if batch_first else input.shape[0] 5992 mini_batch = input.shape[0] if batch_first else input.shape[1] 5993 output_channels = hidden_size 5994 out_shape = ( 5995 [mini_batch, seq_length, output_channels] 5996 if batch_first 5997 else [seq_length, mini_batch, output_channels] 5998 ) 5999 output = input.new_empty(out_shape) 6000 if hx_ is None: 6001 hy = torch.empty(0, device=input.device) 6002 else: 6003 hy = hx_.new_empty(hx_.shape) 6004 if cx_ is None: 6005 cy = torch.empty(0, device=input.device) 6006 else: 6007 cy = cx_.new_empty(cx_.shape) 6008 workspace = torch.empty(0, device=input.device, dtype=torch.uint8) 6009 return output, hy, cy, workspace 6010 6011 6012def zero_numel_check_dims(self, dim, fn_name): 6013 if self.ndim == 0: 6014 torch._check_index( 6015 dim == 0 or dim == -1, 6016 lambda: f"{fn_name}: Expected reduction dim -1 or 0 for scalar but got {dim}", 6017 ) 6018 else: 6019 torch._check_index( 6020 self.size(dim) != 0, 6021 lambda: f"{fn_name}: Expected reduction dim {dim} to have non-zero size.", 6022 ) 6023 6024 6025# From aten/src/ATen/native/ReduceOps.cpp 6026def check_argmax_argmin(name, self, dim): 6027 if dim is not None: 6028 dim = maybe_wrap_dim(dim, self.dim()) 6029 zero_numel_check_dims(self, dim, name) 6030 else: 6031 torch._check( 6032 self.numel() != 0, 6033 lambda: f"{name}: Expected reduction dim to be specified for input.numel() == 0.", 6034 ) 6035 6036 6037@register_meta([aten.argmax.default, aten.argmin.default]) 6038def argmax_argmin_meta(self, dim=None, keepdim=False): 6039 check_argmax_argmin("argmax", self, dim) 6040 dims = utils.reduction_dims(self.shape, (dim,) if dim is not None else None) 6041 shape = _compute_reduction_shape(self, dims, keepdim) 6042 return self.new_empty(shape, dtype=torch.int64) 6043 6044 6045@register_meta(aten.scalar_tensor.default) 6046def scalar_tensor(s, dtype=None, layout=None, device=None, pin_memory=None): 6047 return torch.empty( 6048 (), dtype=dtype, layout=layout, device=device, pin_memory=pin_memory 6049 )
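# Illustrative sketch (added for exposition; not used elsewhere in this file):
# running an op on "meta" tensors exercises shape/dtype inference with no data,
# which is what the registrations above and below provide.  Only public torch
# APIs are used here.
def _example_argmax_meta_shapes():
    x = torch.empty(4, 5, 6, device="meta")
    # Reduction over dim=1 with keepdim=True keeps a size-1 dim, and the result
    # dtype is always int64 (cf. argmax_argmin_meta above).
    out = torch.argmax(x, dim=1, keepdim=True)
    assert out.shape == (4, 1, 6) and out.dtype == torch.int64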
6050 6051 6052@register_meta(aten.topk.default) 6053def topk_meta(self, k, dim=-1, largest=True, sorted=True): 6054 # From aten/src/ATen/native/Sorting.cpp 6055 dim = maybe_wrap_dim(dim, self.dim(), wrap_scalar=True) 6056 sliceSize = 1 if self.dim() == 0 else self.size(dim) 6057 torch._check(k >= 0 and k <= sliceSize, lambda: "k not in range for dimension") 6058 6059 topKSize = list(self.shape) 6060 if len(topKSize) > 0: 6061 topKSize[dim] = k 6062 return self.new_empty(topKSize), self.new_empty(topKSize, dtype=torch.int64) 6063 6064 6065@register_meta([aten.kthvalue.default, aten.kthvalue.values]) 6066@out_wrapper("values", "indices") 6067def kthvalue_meta(self, k, dim=-1, keepdim=False): 6068 dim = maybe_wrap_dim(dim, self.dim(), wrap_scalar=True) 6069 dimSize = self.size(dim) if self.dim() > 0 else 1 6070 torch._check( 6071 k >= 1 and k <= dimSize, 6072 lambda: f"kthvalue(): selected number k out of range for dimension {dim}", 6073 ) 6074 6075 shape = list(self.shape[:dim] + self.shape[dim + 1 :]) 6076 if keepdim and self.dim() > 0: 6077 shape.insert(dim, 1) 6078 return self.new_empty(shape), self.new_empty(shape, dtype=torch.int64) 6079 6080 6081legacy_contiguous_memory_format = torch.contiguous_format 6082 6083 6084# From aten/src/ATen/native/cuda/RNN.cu 6085def checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace): 6086 defined_grad = grad_hy if grad_hy is not None else grad_cy 6087 torch._check(defined_grad.dim() == 2, lambda: "") 6088 exp_size = defined_grad.size() 6089 if grad_hy is not None: 6090 torch._check(grad_hy.size() == exp_size, lambda: "") 6091 if grad_cy is not None: 6092 torch._check(grad_cy.size() == exp_size, lambda: "") 6093 torch._check(cx.size() == exp_size, lambda: "") 6094 torch._check(cy.size() == exp_size, lambda: "") 6095 torch._check(workspace.dim() == 2, lambda: "") 6096 torch._check(workspace.numel() == exp_size[0] * exp_size[1] * 4, lambda: "") 6097 6098 6099# From aten/src/ATen/native/cuda/RNN.cu 6100@register_meta(aten._thnn_fused_lstm_cell_backward_impl.default) 6101def _thnn_fused_lstm_cell_backward_impl(grad_hy, grad_cy, cx, cy, workspace, has_bias): 6102 if grad_hy is None and grad_cy is None: 6103 return None, None, None 6104 checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace) 6105 grad_gates = torch.empty_like( 6106 workspace, memory_format=legacy_contiguous_memory_format 6107 ) 6108 grad_cx = torch.empty_like(cx, memory_format=legacy_contiguous_memory_format) 6109 grad_bias = grad_gates.sum(0, keepdim=False) if has_bias else None 6110 return grad_gates, grad_cx, grad_bias 6111 6112 6113# From aten/src/ATen/native/mps/operations/Linear.mm 6114@register_meta(aten.linear_backward.default) 6115def linear_backward(input_, grad_output_, weight_, output_mask): 6116 grad_input = None 6117 grad_weight = None 6118 grad_bias = None 6119 if output_mask[0]: 6120 grad_input = grad_output_.new_empty(input_.size()) 6121 if output_mask[1] or output_mask[2]: 6122 grad_weight = grad_output_.new_empty((grad_output_.size(-1), input_.size(-1))) 6123 grad_bias = grad_output_.new_empty(grad_output_.size(-1)) 6124 return (grad_input, grad_weight, grad_bias) 6125 6126 6127@register_meta(aten.pixel_shuffle.default) 6128def meta_pixel_shuffle(self, upscale_factor): 6129 assert ( 6130 len(self.shape) > 2 and self.shape[-3] % (upscale_factor * upscale_factor) == 0 6131 ), f"Invalid input shape for pixel_shuffle: {self.shape} with upscale_factor = {upscale_factor}" 6132 6133 def is_channels_last(ten): 6134 return torch._prims_common.suggest_memory_format(ten) == 
torch.channels_last 6135 6136 def pick_memory_format(): 6137 if is_channels_last(self): 6138 if device_hint(self) == "cuda": 6139 return torch.contiguous_format 6140 else: 6141 return torch.channels_last 6142 elif self.is_contiguous(memory_format=torch.contiguous_format): 6143 return torch.contiguous_format 6144 elif self.is_contiguous(memory_format=torch.preserve_format): 6145 return torch.preserve_format 6146 6147 C = self.shape[-3] // (upscale_factor * upscale_factor) 6148 Hr = self.shape[-2] * upscale_factor 6149 Wr = self.shape[-1] * upscale_factor 6150 out_shape = (*self.shape[:-3], C, Hr, Wr) 6151 6152 out = self.new_empty(out_shape) 6153 out = out.to(memory_format=pick_memory_format()) # type: ignore[call-overload] 6154 return out 6155 6156 6157@register_meta(aten.mkldnn_rnn_layer_backward.default) 6158def mkldnn_rnn_layer_backward( 6159 input, 6160 weight0, 6161 weight1, 6162 weight2, 6163 weight3, 6164 hx_, 6165 cx_tmp, 6166 output, 6167 hy_, 6168 cy_, 6169 grad_output_r_opt, 6170 grad_hy_r_opt, 6171 grad_cy_r_opt, 6172 reverse, 6173 mode, 6174 hidden_size, 6175 num_layers, 6176 has_biases, 6177 train, 6178 bidirectional, 6179 batch_sizes, 6180 batch_first, 6181 workspace, 6182): 6183 diff_x = input.new_empty(input.shape) 6184 diff_hx = hx_.new_empty(hx_.shape) 6185 diff_cx = cx_tmp.new_empty(cx_tmp.shape) 6186 diff_w1 = weight0.new_empty(weight0.shape) 6187 diff_w2 = weight1.new_empty(weight1.shape) 6188 diff_b = weight2.new_empty(weight2.shape) 6189 return diff_x, diff_w1, diff_w2, diff_b, diff_b, diff_hx, diff_cx 6190 6191 6192@register_meta([aten.bucketize.Tensor, aten.bucketize.Tensor_out]) 6193@out_wrapper() 6194def meta_bucketize(self, boundaries, *, out_int32=False, right=False): 6195 return torch.empty_like( 6196 self, dtype=torch.int32 if out_int32 else torch.int64 6197 ).contiguous() 6198 6199 6200@register_meta([aten.histc]) 6201@out_wrapper() 6202def meta_histc(input, bins=100, min=0, max=0): 6203 fn_name = "histc()" 6204 if device_hint(input) == "cpu": 6205 torch._check( 6206 input.is_floating_point(), 6207 lambda: f"\"histogram_cpu\" not implemented for '{input.dtype}'", 6208 ) 6209 torch._check( 6210 isinstance(bins, IntLike), 6211 lambda: f"{fn_name}: argument 'bins' must be int, not {type(bins)}", 6212 ) 6213 torch._check(bins > 0, lambda: f"{fn_name}: bins must be > 0, but got {bins}") 6214 torch._check( 6215 isinstance(min, Number), 6216 lambda: f"{fn_name}: argument 'min' must be Number, not {type(min)}", 6217 ) 6218 torch._check( 6219 isinstance(max, Number), 6220 lambda: f"{fn_name}: argument 'max' must be Number, not {type(max)}", 6221 ) 6222 torch._check(max >= min, lambda: f"{fn_name}: max must be larger than min") 6223 return torch.empty(bins, device=input.device, dtype=input.dtype) 6224 6225 6226@register_meta( 6227 [aten._upsample_bilinear2d_aa.default, aten._upsample_bicubic2d_aa.default] 6228) 6229def meta_upsample_bimode2d_aa( 6230 input, 6231 output_size, 6232 align_corners, 6233 scales_h=None, 6234 scales_w=None, 6235): 6236 full_output_size = upsample_common_check( 6237 input.size(), output_size, num_spatial_dims=2 6238 ) 6239 torch._check( 6240 input.numel() != 0 or all(size > 0 for size in input.size()[1:]), 6241 lambda: f"Non-empty 4D data tensor expected but got a tensor with sizes {input.size()}", 6242 ) 6243 return input.new_empty(full_output_size).to( 6244 memory_format=utils.suggest_memory_format(input) 6245 ) 6246 6247 6248# From aten/src/ATen/native/cuda/AmpKernels.cu
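# The meta kernel below only validates its arguments: found_inf and inv_scale
# must each be single-element floating-point tensors.  The actual unscaling and
# inf/nan detection happen in the real (non-meta) kernel.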
6249@register_meta(aten._amp_foreach_non_finite_check_and_unscale_.default) 6250def _amp_foreach_non_finite_check_and_unscale_(self, found_inf, inv_scale): 6251 torch._check( 6252 found_inf.numel() == 1, lambda: "found_inf must be a 1-element tensor." 6253 ) 6254 torch._check( 6255 inv_scale.numel() == 1, lambda: "inv_scale must be a 1-element tensor." 6256 ) 6257 torch._check( 6258 found_inf.dtype.is_floating_point, 6259 lambda: "found_inf must be a float tensor.", 6260 ) 6261 torch._check( 6262 inv_scale.dtype.is_floating_point, 6263 lambda: "inv_scale must be a float tensor.", 6264 ) 6265 6266 6267# From aten/src/ATen/native/UnaryOps.cpp 6268@register_meta([aten.nan_to_num.default, aten.nan_to_num.out]) 6269@out_wrapper() 6270def nan_to_num(self, nan=None, posinf=None, neginf=None): 6271 result_size = list(self.size()) 6272 return self.new_empty(result_size) 6273 6274 6275@register_meta(torch.ops.aten.transpose_) 6276def transpose_(self, dim0, dim1): 6277 assert ( 6278 self.layout 6279 not in { 6280 torch.sparse_csr, 6281 torch.sparse_csc, 6282 torch.sparse_bsr, 6283 torch.sparse_bsc, 6284 } 6285 ), f"torch.transpose_: in-place transposition is not supported for {self.layout} layout" 6286 6287 ndims = self.ndim 6288 6289 dim0 = maybe_wrap_dim(dim0, ndims) 6290 dim1 = maybe_wrap_dim(dim1, ndims) 6291 6292 if dim0 == dim1: 6293 return self 6294 6295 size = list(self.size()) 6296 stride = list(self.stride()) 6297 6298 stride[dim0], stride[dim1] = stride[dim1], stride[dim0] 6299 size[dim0], size[dim1] = size[dim1], size[dim0] 6300 6301 self.as_strided_(size, stride) 6302 return self 6303 6304 6305@register_meta(torch.ops.aten.t_) 6306def t_(self): 6307 ndims = self.ndim 6308 6309 if self.is_sparse: 6310 sparse_dim = self.sparse_dim() 6311 dense_dim = self.dense_dim() 6312 assert ( 6313 sparse_dim <= 2 and dense_dim == 0 6314 ), f"t_ expects a tensor with <= 2 sparse and 0 dense dimensions, but got {sparse_dim} sparse and {dense_dim} dense dimensions" # noqa: B950 6315 else: 6316 assert ( 6317 self.dim() <= 2 6318 ), f"t_ expects a tensor with <= 2 dimensions, but self is {ndims}D" 6319 6320 return transpose_(self, 0, 0 if ndims < 2 else 1) 6321 6322 6323@register_meta(aten.searchsorted) 6324@out_wrapper() 6325def meta_searchsorted( 6326 sorted_sequence, 6327 self, 6328 *, 6329 out_int32=False, 6330 right=False, 6331 side=None, 6332 sorter=None, 6333): 6334 dtype = torch.int32 if out_int32 else torch.int64 6335 if isinstance(self, torch.Tensor): 6336 return torch.empty_like(self, dtype=dtype).contiguous() 6337 else: # Scalar 6338 return torch.empty((), dtype=dtype, device=sorted_sequence.device) 6339 6340 6341def _check_for_unsupported_isin_dtype(dtype): 6342 torch._check( 6343 dtype not in [torch.bool, torch.bfloat16, torch.complex128, torch.complex64], 6344 lambda: f"Unsupported input type encountered for isin(): {dtype}", 6345 ) 6346 6347 6348@register_meta(aten._embedding_bag_backward) 6349def meta_embedding_bag_backward( 6350 grad, 6351 indices, 6352 offsets, 6353 offset2bag, 6354 bag_size, 6355 maximum_indices, 6356 num_weights, 6357 scale_grad_by_freq, 6358 mode, 6359 sparse, 6360 per_sample_weights, 6361 padding_idx=-1, 6362): 6363 if sparse: 6364 return aten._embedding_bag_sparse_backward( 6365 grad, 6366 indices, 6367 offsets, 6368 offset2bag, 6369 bag_size, 6370 num_weights, 6371 scale_grad_by_freq, 6372 mode, 6373 per_sample_weights, 6374 padding_idx, 6375 ) 6376 else: 6377 return meta_embedding_bag_dense_backward( 6378 grad, 6379 indices, 6380 offset2bag, 6381 bag_size, 6382 
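            # maximum_indices is only consulted for mode='max'; the dense
            # backward meta below checks it is present in that case.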
maximum_indices, 6383 num_weights, 6384 scale_grad_by_freq, 6385 mode, 6386 per_sample_weights, 6387 padding_idx, 6388 ) 6389 6390 6391@register_meta(aten._embedding_bag_dense_backward) 6392def meta_embedding_bag_dense_backward( 6393 grad, 6394 indices, 6395 offset2bag, 6396 bag_size, 6397 maximum_indices, 6398 num_weights, 6399 scale_grad_by_freq, 6400 mode, 6401 per_sample_weights, 6402 padding_idx=-1, 6403): 6404 torch._check( 6405 grad.dtype in [torch.float16, torch.bfloat16, torch.float32, torch.float64], 6406 lambda: f"Unsupported input type encountered: {grad.dtype}", 6407 ) 6408 MODE_SUM, MODE_MEAN, MODE_MAX = range(3) 6409 if mode == MODE_MAX: 6410 torch._check(maximum_indices is not None) 6411 index_grad_weight = grad.new_empty((num_weights, grad.size(1))) 6412 return index_grad_weight 6413 6414 6415@register_meta(aten._embedding_bag_per_sample_weights_backward) 6416def meta_embedding_bag_per_sample_weights_backward( 6417 grad, 6418 weight, 6419 indices, 6420 offsets, 6421 offset2bag, 6422 mode, 6423 padding_idx=-1, 6424): 6425 MODE_SUM, MODE_MEAN, MODE_MAX = range(3) 6426 embedding_features = grad.size(1) 6427 torch._check( 6428 mode == MODE_SUM, 6429 "embedding_bag_backward: per_sample_weights only supported for mode='sum'", 6430 ) 6431 torch._check(grad.dim() == 2) 6432 torch._check(indices.dim() == 1) 6433 num_samples = indices.size(0) 6434 torch._check(weight.dim() == 2) 6435 torch._check(weight.size(1) == embedding_features) 6436 output = grad.new_empty((num_samples,)) 6437 return output 6438 6439 6440@register_meta(aten.isin) 6441@out_wrapper() 6442def meta_isin(elements, test_elements, *, assume_unique=False, invert=False): 6443 torch._check( 6444 isinstance(elements, Tensor) or isinstance(test_elements, Tensor), 6445 lambda: "At least one of elements and test_elements must be a Tensor.", 6446 ) 6447 if not isinstance(elements, Tensor): 6448 elements = torch.tensor(elements, device=test_elements.device) 6449 6450 if not isinstance(test_elements, Tensor): 6451 test_elements = torch.tensor(test_elements, device=elements.device) 6452 6453 _check_for_unsupported_isin_dtype(elements.dtype) 6454 _check_for_unsupported_isin_dtype(test_elements.dtype) 6455 return torch.empty_like(elements, dtype=torch.bool) 6456 6457 6458@register_meta(aten.polygamma) 6459@out_wrapper() 6460def meta_polygamma(n: int, self: Tensor) -> Tensor: 6461 torch._check(n >= 0, lambda: "polygamma(n, x) does not support negative n.") 6462 _, result_dtype = elementwise_dtypes( 6463 self, 6464 type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, 6465 ) 6466 return torch.empty_like(self, dtype=result_dtype) 6467 6468 6469@register_meta(aten._local_scalar_dense) 6470def meta_local_scalar_dense(self: Tensor): 6471 raise RuntimeError("Tensor.item() cannot be called on meta tensors") 6472 6473 6474@register_meta(aten._jagged_to_padded_dense_forward.default) 6475def meta__jagged_to_padded_dense_forward( 6476 values: Tensor, 6477 offsets: List[Tensor], 6478 max_lengths: List[int], 6479 padding_value: float = 0.0, 6480): 6481 # only one jagged dim is supported for now 6482 assert len(offsets) == 1 6483 assert len(max_lengths) == 1 6484 6485 B = offsets[0].shape[0] - 1 6486 S = max_lengths[0] 6487 output_shape = (B, S, *values.shape[1:]) 6488 return values.new_empty(output_shape) 6489 6490 6491@register_meta(aten._padded_dense_to_jagged_forward.default) 6492def meta__padded_dense_to_jagged_forward( 6493 padded: Tensor, 6494 offsets: List[Tensor], 6495 total_L: Optional[int] = None, 6496): 6497 # only one 
jagged dim is supported for now 6498 assert len(offsets) == 1 6499 6500 if not total_L: 6501 assert isinstance(padded, torch._subclasses.FakeTensor) 6502 shape_env = padded.fake_mode.shape_env 6503 assert shape_env is not None 6504 total_L = shape_env.create_unbacked_symint() 6505 torch.fx.experimental.symbolic_shapes._constrain_range_for_size( 6506 total_L, min=0, max=None 6507 ) 6508 6509 output_shape = (total_L, *padded.shape[2:]) 6510 return padded.new_empty(output_shape) 6511 6512 6513def _create_unary_float_meta_func(func): 6514 @register_meta(func) 6515 @out_wrapper() 6516 def _f(x): 6517 return elementwise_meta( 6518 x, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT 6519 ) 6520 6521 return _f 6522 6523 6524def _create_binary_float_meta_func(func): 6525 @register_meta(func) 6526 @out_wrapper() 6527 def _f(x, y): 6528 return elementwise_meta( 6529 x, y, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT 6530 ) 6531 6532 return _f 6533 6534 6535_create_unary_float_meta_func(aten.special_airy_ai) 6536_create_unary_float_meta_func(aten.special_bessel_y0) 6537_create_unary_float_meta_func(aten.special_bessel_y1) 6538_create_unary_float_meta_func(aten.special_modified_bessel_i0) 6539_create_unary_float_meta_func(aten.special_modified_bessel_i1) 6540_create_unary_float_meta_func(aten.special_modified_bessel_k0) 6541_create_unary_float_meta_func(aten.special_modified_bessel_k1) 6542_create_unary_float_meta_func(aten.special_scaled_modified_bessel_k0) 6543_create_unary_float_meta_func(aten.special_scaled_modified_bessel_k1) 6544 6545 6546_create_binary_float_meta_func(aten.special_chebyshev_polynomial_t) 6547_create_binary_float_meta_func(aten.special_chebyshev_polynomial_u) 6548_create_binary_float_meta_func(aten.special_chebyshev_polynomial_v) 6549_create_binary_float_meta_func(aten.special_chebyshev_polynomial_w) 6550_create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_t) 6551_create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_u) 6552_create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_v) 6553_create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_w) 6554_create_binary_float_meta_func(aten.special_hermite_polynomial_h) 6555_create_binary_float_meta_func(aten.special_hermite_polynomial_he) 6556_create_binary_float_meta_func(aten.special_laguerre_polynomial_l) 6557_create_binary_float_meta_func(aten.special_legendre_polynomial_p) 6558 6559 6560# We must also trigger meta registrations from PrimTorch ref 6561# decompositions 6562import torch._refs 6563import torch._refs.nn.functional 6564import torch._refs.special 6565 6566 6567def activate_meta(): 6568 activate_meta_table = {} 6569 6570 # For a given op, we pick the most specific decomp function from 6571 # global_decomp_table in the precedence order of meta > post_autograd > pre_autograd 6572 for type in ["meta", "post_autograd", "pre_autograd"]: 6573 registry = global_decomposition_table[type] 6574 6575 for opo in registry: 6576 if opo not in activate_meta_table: 6577 activate_meta_table[opo] = registry[opo] 6578 6579 for op_overload, fn in activate_meta_table.items(): 6580 # Don't register meta for HigherOrderOp's decomp. 6581 # We can reconsider this in the future, but in general, 6582 # the way you do a meta for a HigherOrderOp is different from 6583 # OpOverload. 
6584 if isinstance(op_overload, torch._ops.HigherOrderOperator): 6585 continue 6586 assert isinstance(op_overload, OpOverload) 6587 6588 op_overload.py_impl(torch._C.DispatchKey.Meta)(fn) 6589 6590 if torch._C._dispatch_has_kernel_for_dispatch_key( 6591 op_overload.name(), "CompositeImplicitAutograd" 6592 ): 6593 # Internally, we shouldn't be registering meta kernels for any operators that 6594 # have CompositeImplicitAutograd kernels. 6595 # Instead, we should be letting those decompositions run, and writing meta kernels 6596 # only for the base operators. 6597 if op_overload in global_decomposition_table["meta"]: 6598 raise RuntimeError( 6599 f"{op_overload} is a CompositeImplicitAutograd op, we shouldn't " 6600 "register meta function for it. Instead, we should let the decomposition run and write " 6601 "meta kernels for the base operators." 6602 ) 6603 elif op_overload.is_view: 6604 # Attempting to register a python meta kernel for a view operator. 6605 # We shouldn't do this, because the output will report as not having aliased storages. 6606 # All view ops have meta kernels in C++ today, so we should use those instead. 6607 pass 6608 elif ( 6609 op_overload.name() 6610 in { 6611 "aten::empty_strided", # causing infinite recursion, test_meta.py 6612 "aten::clone", # causing infinite recursion 6613 "aten::_to_copy", # causing infinite recursion, test_serialization.py -k test_tensor_subclass_getstate_overwrite # noqa: B950 6614 "aten::copy_", # Exception not raised, test_torch.py -k test_storage_meta_errors_cpu_int64 # noqa: B950 6615 "aten::constant_pad_nd", # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_amp_istft_cuda_float32 # noqa: B950 6616 "aten::rot90", # requires_grad mismatch! test_ops.py -k test_fake_crossref_backward_amp_rot90_cuda_float32 # noqa: B950 6617 "aten::as_strided_scatter", # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_no_amp_as_strided_scatter_cuda_float32 # noqa: B950 6618 } 6619 ): 6620 pass 6621 else: 6622 if "mkldnn::" in op_overload.name(): 6623 _meta_lib_dont_use_me_use_register_meta_for_mkldnn.impl(op_overload, fn) 6624 elif "mkl::" in op_overload.name(): 6625 _meta_lib_dont_use_me_use_register_meta_for_mkl.impl(op_overload, fn) 6626 elif "onednn::" in op_overload.name(): 6627 _meta_lib_dont_use_me_use_register_meta_for_onednn.impl(op_overload, fn) 6628 elif "quantized::" in op_overload.name(): 6629 _meta_lib_dont_use_me_use_register_meta_for_quantized.impl( 6630 op_overload, fn 6631 ) 6632 else: 6633 _meta_lib_dont_use_me_use_register_meta.impl(op_overload, fn) 6634 6635 6636activate_meta() 6637
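# Illustrative sketch (added for exposition; not part of the meta registry):
# once activate_meta() has run, ops invoked on "meta" tensors resolve to the
# kernels registered in this file, so shape and dtype inference works without
# any real data.  Only public torch APIs are used below.
def _example_shape_inference_on_meta_tensors():
    probs = torch.empty(8, 10, device="meta")
    # meta_multinomial: a 2-D input yields a (batch, num_samples) int64 output.
    samples = torch.multinomial(probs, 4)
    assert samples.shape == (8, 4) and samples.dtype == torch.int64

    x = torch.empty(2, 3, 4, 4, device="meta")
    # upsample_nearest2d: N and C are preserved, spatial sizes come from `size`.
    up = torch.nn.functional.interpolate(x, size=(8, 8), mode="nearest")
    assert up.shape == (2, 3, 8, 8)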